数据集

常用数据集汇总,神经网络训练的动力源泉。目前还在整理中,我们会一步一步完善该板块。

一、MNIST 数据集

MNIST(Modified National Institute of Standards and Technology)是一个广泛使用的手写数字图像数据集,由美国国家标准与技术研究院(NIST)收集并修改而成,常被用于机器学习领域的入门级图像分类任务。包含 0 到 9 的灰度手写数字图像,每张图像大小为 28×28 像素(共 784 个像素点),像素值范围 0~255(黑白),单通道。

类型下载链接文件大小
压缩包mnist.zip22.2MB

压缩包中共包含四个文件:

# train-images-idx3-ubyte.gz
# train-labels-idx1-ubyte.gz
# t10k-images-idx3-ubyte.gz
# t10k-labels-idx1-ubyte.gz

其中用于训练的字符有 60,000 张,测试样本 10,000 张,它们构成了理解图像分类和评估模型性能的基石数据。以下是使用 Python 读取数据的代码:

import os
import gzip
import struct
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
  
class MNIST(Dataset):
    def __init__(self, root, train=True):
        self.train = train

        prefix = 'train' if train else 't10k'
        images_path = os.path.join(root, prefix + '-images-idx3-ubyte.gz')
        labels_path = os.path.join(root, prefix + '-labels-idx1-ubyte.gz')

        if not os.path.exists(labels_path):
            raise FileNotFoundError(f"MNIST labels not found at {labels_path}")

        with gzip.open(labels_path, 'rb') as lbpath:
            _, n = struct.unpack('>II', lbpath.read(8))
            self.labels = np.frombuffer(lbpath.read(), dtype=np.uint8)

        if not os.path.exists(images_path):
            raise FileNotFoundError(f"MNIST images not found at {images_path}")

        with gzip.open(images_path, 'rb') as imgpath:
            _, n, rows, cols = struct.unpack('>IIII', imgpath.read(16))
            self.images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(n, rows, cols)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img = img.astype(np.float32)
        img = torch.from_numpy(img).unsqueeze(0) # 增加通道维度 (1, 28, 28)
        img = img.float() / 255.0 # 归一化到 [0, 1]
        img = (img - 0.5) / 0.5 # 归一化到 [-1, 1]

        # 可选的数据增强选项
        # if self.train:
        #     img = transforms.RandomRotation(10)(img)  # 示例:随机旋转10度

        label = self.labels[idx]
        label = torch.tensor(label, dtype=torch.long) # 标签转为 torch.long 类型(适用于 PyTorch 的交叉熵损失)

        return img, label

如果只需要取得某一个数字,可以使用如下方式:

import os
import gzip
import struct
import numpy as np
import torch
from torch.utils.data import Dataset

class MNIST3(Dataset):
    def __init__(self, root, train=True):
        self.train = train

        prefix = 'train' if train else 't10k'
        images_path = os.path.join(root, prefix + '-images-idx3-ubyte.gz')

        if not os.path.exists(images_path):
            raise FileNotFoundError(f"MNIST images not found at {images_path}")

        with gzip.open(images_path, 'rb') as imgpath:
            _, n, rows, cols = struct.unpack('>IIII', imgpath.read(16))
            images = np.frombuffer(imgpath.read(), dtype=np.uint8).reshape(n, rows, cols)

        self.images = images[labels == 3]  # 仅保留数字3

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        img = img.astype(np.float32)
        img = torch.from_numpy(img).unsqueeze(0) # 增加通道维度 (1, 28, 28)
        img = img.float() / 255.0 # 归一化到 [0, 1]
        img = (img - 0.5) / 0.5 # 归一化到 [-1, 1]
        return img

这样的数据集可用于训练 GAN 对某个特定的数字进行概率分布学习。