项目实践】MNIST 手写数字识别 (上)
前言
本文将介绍如何在 PyTorch 中构建一个简单的卷积神经网络,并训练它使用 MNIST 数据集识别手写数字,这将可以被看做是图像识别的 “Hello, World!”;
MNIST 包含 70,000 张手写数字图像:60,000 张用于训练,10,000 张用于测试。这些图像是灰度的,28x28 像素,居中以减少预处理并更快地开始。
配置环境
在本文中,我们将使用 PyTorch 训练卷积神经网络来识别 MNIST 的手写数字。 PyTorch 是一个非常流行的深度学习框架,如 Tensorflow、CNTK 和 Caffe2。但与这些其他框架不同,PyTorch 具有动态执行图,这意味着计算图是动态创建的。
import torch
import torchvision
这里关于 PyTorch 的环境搭建就不再赘述了;
PyTorch 的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考,中文文档点这!
准备数据集
完成环境导入之后,我们可以继续准备我们将使用的数据。
但在此之前,我们将定义我们将用于实验的超参数。在这里,epoch
的数量定义了我们将在整个训练数据集上循环多少次,而 learning_rate
和 momentum
是我们稍后将使用的优化器的超参数。
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
对于可重复的实验,我们必须为任何使用随机数生成的值设置随机种子:numpy
和 random
;
而且,由于 cuDNN 使用非确定性算法,可以通过设置 torch.backends.cudnn.enabled = False
禁用该算法。
现在我们还需要数据集 DataLoaders,这就是 TorchVision 发挥作用的地方。它让我们以方便的方式使用加载 MNIST 数据集。下面用于 Normalize()
转换的值 0.1307 和 0.3081 是 MNIST 数据集的全局平均值和标准差,我们将在此处将它们作为给定值。
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(path, train=True, download=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(path, train=False, download=False,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])),
batch_size=batch_size_test, shuffle=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!
TIP: 如果你可以接受等待时间的话,可以改动 download=True
,不然的话,就自己先下载,然后在设置路径;
PyTorch 的 DataLoader
包含一些有趣的选项,而不是数据集和批量大小。例如,我们可以使用 num_workers > 1
来使用子进程异步加载数据或使用固定 RAM(via pin_memory)来加速 RAM 到 GPU 的传输。
使用数据集
接下来使用一下 test_loader
:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
所以一个测试数据批次是一个形状张量:这意味着我们有 1000 个 28x28 像素的灰度示例(即没有 rgb 通道,因此只有一个)。可以使用 matplotlib 绘制其中的一些:
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
后记
当你完成上述工作后,且一切正常,那么你的准备工作就完成了!接下来,就是要构建一个简单的卷积神经网络,并训练它使用 MNIST 数据集识别手写数字;
推荐阅读
-
【摩尔线程+Colossal-AI强强联手】MusaBert登上CLUE榜单TOP10:技术细节揭秘 - 技术实力:摩尔线程凭借"软硬兼备"的技术底蕴,让MusaBert得以从底层优化到顶层。其内置多功能GPU配备AI加速和并行计算模块,提供了全面的AI与科学计算支持,为AI推理和低资源条件下的大模型训练等场景带来了高效、经济且环保的算力。 - 算法层面亮点:依托Colossal-AI AI大模型开发系统,MusaBert在训练过程中展现出了卓越的并行性能与易用性,特别在预处理阶段对DataLoader进行了优化,适应低资源环境高效处理海量数据。同时,通过精细的建模优化、领域内数据增强以及Adan优化器等手段,挖掘和展示了预训练语言模型出色的语义理解潜力。基于MusaBert,摩尔线程自主研发的MusaSim通过对比学习方法微调,结合百万对标注数据,MusaSim在多个任务如语义相似度、意图识别和情绪分析中均表现出色。 - 数据资源丰富:MusaBert除了自家高质量语义相似数据外,还融合了悟道开源200GB数据、CLUE社区80GB数据,以及浪潮公司提供的1TB高质量数据,保证模型即便在较小规模下仍具备良好性能。 当前,MusaBert已成功应用于摩尔线程的智能客服与数字人项目,并广泛服务于语义相似度、情绪识别、阅读理解与声韵识别等领域。为了降低大模型开发和应用难度,MusaBert及其相关高质量模型代码已在Colossal-AI仓库开源,可快速训练优质中文BERT模型。同时,通过摩尔线程与潞晨科技的深度合作,仅需一张多功能GPU单卡便能高效训练MusaBert或更大规模的GPT2模型,显著降低预训练成本,进一步推动双方在低资源大模型训练领域的共享目标。 MusaBert荣登CLUE榜单TOP10,象征着摩尔线程与潞晨科技联合研发团队在中文预训练研究领域的领先地位。展望未来,双方将携手探索更大规模的自然语言模型研究,充分运用上游数据资源,产出更为强大的模型并开源。持续强化在摩尔线程多功能GPU上的大模型训练能力,特别是在消费级显卡等低资源环境下,致力于降低使用大模型训练的门槛与成本,推动人工智能更加普惠。而潞晨科技作为重要合作伙伴,将继续发挥关键作用。
-
10个简单好玩的人工智能项目实操指南(含Python源码) - 例子5:手写数字识别
-
PyTorch - (7) MNIST 手写数字识别示例
-
mnist 手写数字识别 - 深度学习入门项目(tensorflow+keras+序列模型)
-
使用 MNIST 的 CNN 手写数字识别实用代码和见解
-
项目实践】MNIST 手写数字识别 (上)
-
[完整项目] 基于 Mnist 的手写数字识别 - Pytorch 版
-
深度学习项目在行动--手写数字识别项目
-
MNIST 手写数字识别代码(KNN 手写数字识别)
-
MNIST 手写数字识别代码详细备注版 [零基础入门使用]。