irpas技术客

PyTorch(Python)训练MNIST模型移动端IOS上使用Swift实时数字识别_亚图跨际_pytorch ios

网络投稿 6354

识别手写数字是计算机视觉的基石问题,可以通过神经网络来解决。在此,我不会重复有关模型构建和训练的细节。

本文中,我的目的是将经过训练的模型移植到移动环境中。我使用 pytorch 构建模型,因为我想尝试一下 torchscript。对于 ios 应用程序,我使用 swift 和 swiftUI。

使用 PyTorch 进行手写数字识别 数据集

我们将使用流行的 MNIST 数据库。 它是 70000 个手写数字的集合,分为 60000 个和 10000 个图像的训练集和测试集。

在开始之前,我们需要进行所有必要的导入。

import numpy as np import torch import torchvision import matplotlib.pyplot as plt from time import time from torchvision import datasets, transforms from torch import nn, optim

首先,让我们定义要对数据执行哪些转换。 换句话说,您可以将其视为对图像执行的某种自定义编辑,以便所有图像具有相同的尺寸和属性。 我们使用 torchvision.transforms 来实现。

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)), ])

现在我们终于下载了数据集,将它们打乱并转换它们中的每一个。r然后,将它们加载到 DataLoader,它结合了数据集和采样器,并在数据集上提供单进程或多进程迭代器。

trainset = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform) valset = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)

其中,batch_size是我们想要一次读取的图像数量。

接下来,我们将对我们的图像和张量进行一些探索性数据分析。让我们检查图像和标签的形状。

dataiter = iter(trainloader) images, labels = dataiter.next() print(images.shape) print(labels.shape) 构建神经网络

我们将构建以下网络,如您所见,它包含一个输入层(第一层),一个由十个神经元(或单元,圆圈)组成的输出层和两个隐藏层。

调整权重

神经网络通过对可用数据进行多次迭代来学习。术语学习是指调整网络的权重以最小化损失。让我们想象一下它是如何工作的。

训练 | 测试评估 | 保存模型 本文建模

我选择的模型有两个卷积层和两个全连接层。它使用 LogSoftmax 作为输出层激活。

模型输入

该模型是在训练数据上训练的。 我想让它处理来自相机流的图像。 这称为生产数据。 该应用程序必须预处理生产数据以匹配训练数据的形状和语义。 否则它给出的结果将是次优的。

可视化 输出 TorchScript

Torchscript 是迈向非 Python 环境的第一步。简而言之,torchscript 为我们提供了一个可以在 c++ 中使用的模型。模型训练完成后,我们可以将其转换为 torchscript 模块。

进入ios环境

我们训练了一个模型,将像素数据(图像)转换为数字预测。我们可以从 c++ 调用它。我们如何在 ios 中运行它?这就是我们将在本节中发现的内容。

运行推理 let optionalImg = UIImage(named: "three") guard let inputImg = optionalImg else { return "An helpful error message" } 源代码

详情参阅 - 亚图跨际


1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,会注明原创字样,如未注明都非原创,如有侵权请联系删除!;3.作者投稿可能会经我们编辑修改或补充;4.本站不提供任何储存功能只提供收集或者投稿人的网盘链接。

标签: #pytorch #iOS #我使用 #构建模型因为我想尝试一下