常用数据集汇总,神经网络训练的动力源泉。目前还在整理中,我们会一步一步完善该板块。
MNIST(Modified National Institute of Standards and Technology)是一个广泛使用的手写数字图像数据集,由美国国家标准与技术研究院(NIST)收集并修改而成,常被用于机器学习领域的入门级图像分类任务。包含 0 到 9 的灰度手写数字图像,每张图像大小为 28×28 像素(共 784 个像素点),像素值范围 0~255(黑白),单通道。
类型 | 下载链接 | 文件大小 |
压缩包 | mnist.zip | 22.2MB |
压缩包中共包含四个文件:
# train-images-idx3-ubyte.gz
# train-labels-idx1-ubyte.gz
# t10k-images-idx3-ubyte.gz
# t10k-labels-idx1-ubyte.gz
其中用于训练的字符有 60,000 张,测试样本 10,000 张,它们构成了理解图像分类和评估模型性能的基石数据。以下是使用 Python 读取数据的代码:
import os
import gzip
import struct
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class MNIST(Dataset):
def __init__(self, root, train=True):
self.train = train
prefix = 'train' if train else 't10k'
images_path = os.path.join(root, prefix + '-images-idx3-ubyte.gz')
labels_path = os.path.join(root, prefix + '-labels-idx1-ubyte.gz')
if not os.path.exists(labels_path):
raise FileNotFoundError(f"MNIST labels not found at {labels_path}")
with gzip.open(labels_path, 'rb') as lbpath:
_, n = struct.unpack('>II', lbpath.read(8))
self.labels = np.frombuffer(lbpath.read(), dtype=np.uint8)
if not os.path.exists(images_path):
raise FileNotFoundError(f"MNIST images not found at {images_path}")
with gzip.open(images_path, 'rb') as imgpath:
_, n, rows, cols = struct.unpack('>IIII', imgpath.read(16))
self.images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(n, rows, cols)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img = self.images[idx]
img = img.astype(np.float32)
img = torch.from_numpy(img).unsqueeze(0) # 增加通道维度 (1, 28, 28)
img = img.float() / 255.0 # 归一化到 [0, 1]
img = (img - 0.5) / 0.5 # 归一化到 [-1, 1]
# 可选的数据增强选项
# if self.train:
# img = transforms.RandomRotation(10)(img) # 示例:随机旋转10度
label = self.labels[idx]
label = torch.tensor(label, dtype=torch.long) # 标签转为 torch.long 类型(适用于 PyTorch 的交叉熵损失)
return img, label
如果只需要取得某一个数字,可以使用如下方式:
import os
import gzip
import struct
import numpy as np
import torch
from torch.utils.data import Dataset
class MNIST3(Dataset):
def __init__(self, root, train=True):
self.train = train
prefix = 'train' if train else 't10k'
images_path = os.path.join(root, prefix + '-images-idx3-ubyte.gz')
if not os.path.exists(images_path):
raise FileNotFoundError(f"MNIST images not found at {images_path}")
with gzip.open(images_path, 'rb') as imgpath:
_, n, rows, cols = struct.unpack('>IIII', imgpath.read(16))
images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(n, rows, cols)
self.images = images[labels == 3] # 仅保留数字3
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img = self.images[idx]
img = img.astype(np.float32)
img = torch.from_numpy(img).unsqueeze(0) # 增加通道维度 (1, 28, 28)
img = img.float() / 255.0 # 归一化到 [0, 1]
img = (img - 0.5) / 0.5 # 归一化到 [-1, 1]
return img
这样的数据集可用于训练 GAN 对某个特定的数字进行概率分布学习。