irpas技术客

入门学习MNIST手写数字识别_Zkaisen_minst手写识别

网络投稿 6573

一、MNIST数据集 1.MNIST数据集简介

MNIST数据集是一个公开的数据集,相当于深度学习的hello world,用来检验一个模型/库/框架是否有效的一个评价指标。

MNIST数据集是由0?9手写数字图片和数字标签所组成的,由60000个训练样本和10000个测试样本组成,每个样本都是一张28 * 28像素的灰度手写数字图片。MNIST 数据集来自美国国家标准与技术研究所,整个训练集由250个不同人的手写数字组成,其中50%来自美国高中学生,50%来自人口普查的工作人员。

2.MNIST数据集包含四部分 ?train-images-idx3-ubyte.gz:训练集图片(9912422字节),55000张训练集,5000张验证集train-labels-idx1-ubyte.gz:训练集图片对应的标签(28881字节)t10k-images-idx3-ubyte .gz:测试集图片(1648877字节),10000张图片?t10k-labels-idx1-ubyte.gz:测试集图片对应的标签(4542字节)

MNIST数据集是集成的API无需手动下载,可以通过torch里面的API直接获取。

官方文档:https://pytorch.org/vision/stable/datasets.html

参数:

root:指的是下载的目录?train:如果设置成True的话表示取训练集,如果要取测试集就设置成Falsedownload:如果设置成True,会先判断是否下载过,如果未下载过,就会下载文件;如果已经下载过了,设置成True和False都一样,不会重新下载。transform:是对图片进行预处理的一些操作,可以将一个PIL 图片翻译成张量或其他内容? ? ? ? ? ? from torchvision.datasets import MNIST#获取MNIST的数据集 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=my_transforms) #root表示下载路径,训练模式取训练集,是否下载:是

注:当我们后期要训练自己的数据集的时候需要将MNIST 数据集换成我们自己的数据集

print(len(mnist_train))#len表示求它的长度,训练集有10000张 print(mnist_train[0])#getitem表示通过索引把图像取出来 运行结果: 60000 (<PIL.Image.Image image mode=L size=28x28 at 0x2175A412320>, 5) 解释说明: 运行结果中60000是训练集的长度,即训练集有60000张图片 第二行返回结果有两部分,一部分是PIL图像,另一部分是图片的标签 标签为5,就表示这个数字是5.

怎样能看一下这张图片是什么呢?

首先需要安装并导入matlab库

import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot from torchvision.datasets import MNIST#获取MNIST的数据集 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None) #root表示下载路径,训练模式取训练集,是否下载:是 print(len(mnist_train))#len表示求它的长度,训练集有10000张 print(mnist_train[0][0])#getitem表示通过索引把图像取出来 image = mnist_train[0][0]#取出具体的一张图片 plt.imshow(image) plt.show()#把图片展示出来 print(mnist_train[0][1])#把图片的标签打印出来 运行结果: 60000 <PIL.Image.Image image mode=L size=28x28 at 0x1DAF3D3BB00> 5

图片展示??

遇到OMP报错的话,在代码中添加下面两行代码即可

import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错

?可以参考下面链接:

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hint_fencecat的博客-CSDN博客OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.OMP: Hinthttps://blog.csdn.net/fencecat/article/details/122887204?spm=1001.2014.3001.5502

有时即使模型再好,识别率也达不到100%;因为有些数字写的实在太飘逸了,标签也是随心所欲😂

二、数据加载

MNIST数据集继承了torch.utils.data.Dataset

需要自己实现__len__和__getitem__两个方法:

__len__实现获取数据集长度的操作__getitem__实现获取第几个对象的操作,通过索引的方式把图片取出来。

torch已封装好的加载器

前边已经得到MNIST数据集的实例化对象,接下来就可以进行数据的加载,加载器功能较多,如果自己实现的话会比较复杂,我们可以借助torch已经封装好的加载器来处理

官方文档https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

from torchvision.datasets import MNIST from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None) dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True) #实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 print(dataloader)#打印一下 运行结果: <torch.utils.data.dataloader.DataLoader object at 0x00000297C76711D0>

迭代DataLoader类:

# 加载器 from torchvision.datasets import MNIST from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=None) dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True) #实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 # print(dataloader)#打印一下 for i in dataloader: print(i) 报错: TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'> #batch必须包含张量,numpy数组,数字,字典或列表,不支持PIL图像

怎样才能迭代将PIL图像打印呢?需要引入图像处理

三、transforms图像处理 1.导入transforms方法,并将MNIST数据集的transfrom改为transforms.ToTensor() #图片处理 #导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor() from torchvision import transforms#导入transforms方法 from torchvision.datasets import MNIST from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor()) dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True) #实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 # print(dataloader)#打印一下 for i in dataloader: print(i) 运行结果:将PIL图像转换成了张量形式 [tensor([[[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]]], [[[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]]]]), tensor([8, 0])] 2.集合transforms.Compose(transforms)可以将transforms组合起来使用 #图片处理 #导入transforms方法,并将MNIST数据集中transform改为transforms.ToTensor() from torchvision import transforms#导入transforms方法 from torchvision.datasets import MNIST my_transforms = transforms.Compose( [transforms.PILToTensor()]) from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor()) dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)#实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 for i in dataloader: print(i) exit()#打印一次后退出 运行结果: [tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1608, 0.5961, 0.9137, 0.5961, 0.5961, 0.2000, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.9922, 0.9882, 0.9922, 0.9882, 0.9922, 0.6745, 0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.4000, 0.2392, 0.6392, 0.9529, 0.9176, 0.2000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4000, 0.9922, 0.9882, 0.0000, 0.0000, 0.0000, 0.3176, 0.9922, 0.9098, 0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4000, 0.9961, 0.9922, 0.0000, 0.0000, 0.0000, 0.0000, 0.5176, 0.9922, 0.6392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0784, 0.8353, 0.9882, 0.1608, 0.0000, 0.0000, 0.1608, 0.5176, 0.9882, 0.8745, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3608, 0.9922, 0.8392, 0.2000, 0.4431, 0.9137, 0.7961, 0.7961, 0.3216, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.9882, 0.9922, 0.9882, 0.9922, 0.8314, 0.0784, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7961, 0.9961, 0.9922, 0.5569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6353, 0.9922, 0.9882, 0.0784, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5176, 0.9922, 0.9961, 0.9922, 0.2431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1608, 0.9922, 0.9882, 0.9922, 0.9882, 0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6392, 0.9961, 0.6745, 0.5961, 0.9922, 0.4000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0824, 0.8745, 0.8353, 0.0392, 0.2784, 0.9882, 0.7176, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2039, 0.9922, 0.7961, 0.0000, 0.1608, 0.9529, 0.9961, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5176, 0.9882, 0.7961, 0.0000, 0.0000, 0.7961, 0.9922, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9922, 0.8000, 0.0000, 0.1216, 0.9137, 1.0000, 0.1961, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5961, 0.9882, 0.7961, 0.0000, 0.6784, 0.9882, 0.6745, 0.0392, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3608, 0.9922, 1.0000, 0.9922, 1.0000, 0.9922, 0.1608, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0392, 0.5137, 0.8353, 0.9882, 0.9137, 0.2745, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]]), tensor([8])] Process finished with exit code 0 打印时将images与labels分开 # for i in dataloader: # print(i) # exit()#打印一次后退出 # for (images, labels) in dataloader: # print(images, labels) #for i in dataloader: # print(i[0], i[1]) ?3.transfroms方法

官方文档:torchvision.transforms — Torchvision 0.11.0 documentationhttps://pytorch.org/vision/stable/transforms.html

(1) transfroms简介 transfroms是一种常用的图像转换方法,他们可以通过Compose方法组合到一起,这样可以实现许多个transfroms对图像进行处理。transfroms方法提供图像的精细化处理,例如在分割任务的情况下? ,你必须建立一个更复杂的转换管道,这时transfroms方法是很有用的。很多转换器既接受PIL图像,也接受tensor图像。一张tensor图像是形状为(C,?H,?W)的张量,这里C表示通道数,H和W 是图像的高和宽。1batch 的tensor图像是一个形状为(B,?C,?H,?W)?的张量,这里B表示在batch上有多少张图片。transfroms方法处理过后,会把通道移到最前边。比如MNIST h*w*c为:28*28*1,tensor处理完,通道数会提前,并且做了轴交换,变为了c*h*w为:1*28*28,为什么要这样设计呢?据说是做矩阵加减乘除以及卷积等运算是需要用cuda和cudnn的函数的,而这些接口都设成chw格式了。

a. 轴交换

transfroms方法处理过后,如果我们需要把图片转回PIL,需要进行一次轴交换;因为无法处理一个28通道数的图片。

#轴交换之前打印一下图片的形状 from torchvision import transforms#导入transforms方法 from torchvision.datasets import MNIST my_transforms = transforms.Compose( [transforms.PILToTensor()] ) from torch.utils.data import DataLoader mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform= transforms.ToTensor()) dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True) for (images, labels) in dataloader: print(images.shape) exit()#打印一次后退出 #运行结果 torch.Size([1, 1, 28, 28]) #说明: 这里第一个“1”表1 batch_size,即一次加载一张图片 第二个“1”表示通道数,后边两个“28”分别表示图片的高和宽 #使用make_grid方法将两张图片融合 from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法 from torchvision import transforms#导入transforms方法 from torchvision.datasets import MNIST my_transforms = transforms.Compose( [transforms.PILToTensor()] )#将多个transforms组合在一起,还可以加入标准化等图像处理 from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor()) dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True) #实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 print(dataloader)#打印一下 for (images, labels) in dataloader: print(make_grid(images).shape) exit()#打印一次后退出 #运行结果: <torch.utils.data.dataloader.DataLoader object at 0x00000289D3634438> torch.Size([3, 28, 28]) Process finished with exit code 0 #结果图像的形状变成了3*28*28 #如果将上述代码中dataloader = DataLoader(mnist_train, batch_size=1, shuffle=True)换成 dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True) #运行结果变为torch.Size([3, 32, 62]),相当于把两张图片融合了

b. 使用轴交换边回去

#轴交换 import os os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"#前两行代码解决一个OMP报错 from torchvision.utils import make_grid##即使一张图片我们也要将它融合一下,使用make_grid方法 from torchvision import transforms#导入transforms方法 from torchvision.datasets import MNIST import matplotlib.pyplot as plt#安装matlab库并导入matplotlib.pyplot my_transforms = transforms.Compose( [transforms.PILToTensor()] )#将多个transforms组合在一起,还可以加入标准化等图像处理 from torch.utils.data import DataLoader#导入数据加载器 mnist_train = MNIST(root="/MNIST_data", train=True, download=True, transform=transforms.ToTensor()) dataloader = DataLoader(mnist_train, batch_size=2, shuffle=True) #实例化一个类,传入把训练集,batch_size设为1,shuffle设为true打乱 print(dataloader)#打印一下 for (images, labels) in dataloader: image = make_grid(images).permute(1, 2, 0).numpy() #permute(1, 2, 0)实际上就是把通道数移到后边的过程,忘记的话回看第二节课的视频 #轴交换了之后转换成numpy的数组,之后就可以做加载了 plt.imshow(image) plt.show() exit() #运行结果 <torch.utils.data.dataloader.DataLoader object at 0x0000027DCB7A4710> Process finished with exit code 0

图片展示:

添加代码:print(labels)可以将标签打印出来

这个操作一般只有调试时才会用,正常运算不需要把tensor图像转换成PIL图像再看一下

(2)进阶了解transfroms方法

?参考文档:PyTorch 学习笔记:transforms的二十二个方法(transforms用法非常详细)_liangbaqiang的博客-CSDN博客_transforms.scale

四、模型和优化器 1.简介

模型四深度学习的关键内容,是深度学习的核心。

深度神经网络的种类主要有:

传统神经网络CNN卷积网络CNN循环神经网络(递归神经网络)RNN

目前比较流行的深度神经模型几乎都是卷积和循环两种模型的延伸。

2.全连接层:torch.nn.Linear (1)简介

官方文档:https://pytorch.org/docs/stable/nn.html

对于MNIST数据集这种简单的,且样本数量足够多的项目,一个全连接层就能达到不错的效果。

后期会对这些模型的“层”进行组合实现。有卷积层、池化层、标准化层等等。

?全连接层指的是层中的每个节点都会连接它下一层的所有节点,它是模仿人脑神经结构来构建的。最左边是输入的是图像,实际上就是图像的像素点,全连接层每层之间都是线性关系。

比如:假设输入为、、……,那么与输入层直接相连的中间层就是这样计算来的

,同理可以计算出第二层的、……,同样中间各层之间都有一个权重,下一层的输出都是由上一层的每个输入乘以相应的权重累加得出的。最终得到的是两个输出结果,这是一个二分类的问题。输出几个值几分类问题。

(2)全连接层的实现 #全连接层 #首先我们要新建一个类,这个类要继承nn.Module class MnistModel(nn.Module): def __init__(self):#继承__init__方法 super(MnistModel, self).__init__() self.fc2 = nn.Linear(1*28*28, 10)#最初传入的图片的像素点是1*28*28的,最后我们要收敛成10个结果 #如果先收敛成100个,然后在写一个全连接层 # self.fc2 = nn.Linear(1 * 28 * 28, 100) # self.fc2 = nn.Linear(100, 10) #激活函数,激励函数,通过数学手段将线性计算过程进行优化,使其加速。最常用的线性激活函数Relu self.relu = nn.ReLU() def forward(self, image):#继承前向传播的方法 image_viwed = image.view(-1, 1*28*28)#此处需要拍平 out = self.fc2(image_viwed) fcl_out = self.relu(out)#激活函数对应一下 return out 3.优化器

#优化器官方文档:https://pytorch.org/docs/stable/optim.html

(1)简介 优化器的作用就是寻求模型最优解,优化器有梯度下降,动量优化,自适应优化等,梯度下降是最原始的,也是最基础的。梯度下降算法,载入数据集,计算所有梯度,然后执行决策。依据是损失函数,通过损失进行每一步计算,梯度下降算法分为:标准梯度下降法、批量梯度下降法和随机梯度下降法。 (2)优化器实现? from torch import optim#导入优化器 #需要把实例化的模型传进去 model = MnistModel() optim.Adam(model.parameters(), lr=1e-4)#这是一种自适应的优化器,不需要调参 #lr表示学习率,1e-4表示10的4次方 #优化器官方文档:https://pytorch.org/docs/stable/optim.html 4.损失函数 (1)简介 损失函数,设计一个损失函数的计算方法,让他统一一个损失值,算出一个结论,进而判断下次模型要朝着那个方向去优化权重,最终损失函数的选择取决于最终的结果和标签之间的关系。每一种损失函数都对应着一种数学模型计算,目的就是把模型训练结果与标签之间建立起关系,在梯度下降优化器中,让损失不断减小的方向就是训练的方向 #损失函数的实现 (2)损失函数实现 LOST = nn.CTCLoss()#调用nn的损失函数,实例化 LOST(MODEL_RESULT, LABELS)#把模型的结果和标签传进去,得到一个数字就是损失值,就是优化器朝哪个方向去做的一个依据


1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,会注明原创字样,如未注明都非原创,如有侵权请联系删除!;3.作者投稿可能会经我们编辑修改或补充;4.本站不提供任何储存功能只提供收集或者投稿人的网盘链接。

标签: #minst手写识别 #28像素的灰度手写数字图片 #MNIST