广告位联系
返回顶部
分享到

Pytorch写数字识别LeNet模型的介绍

python 来源:互联网 作者:秩名 发布时间:2022-01-20 21:19:11 人浏览
摘要

LeNet网络 LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下 from PIL import Image import cv2 import matplotlib.pyplot as plt import torchvision from torchvision import transforms import

LeNet网络

在这里插入图片描述

LeNet网络过卷积层时候保持分辨率不变,过池化层时候分辨率变小。实现如下

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm

class LeNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
                                        nn.AvgPool2d(kernel_size=2,stride=2),
                                        nn.Flatten(),
                                        nn.Linear(16*25,120),nn.Sigmoid(),
                                        nn.Linear(120,84),nn.Sigmoid(),
                                        nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.sequential = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28,120),nn.Sigmoid(),
                          nn.Linear(120,84),nn.Sigmoid(),
                          nn.Linear(84,10))
        
    
    def forward(self,x):
        return self.sequential(x)

epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose  = transforms.Compose([transforms.ToTensor(),
                    ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)

model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
  train_loss = 0
  test_loss = 0
  correct_train = 0
  correct_test = 0
  for index,(x,y) in enumerate(train_loader):
    x = x.to(device)
    y = y.to(device)
    predict = model(x)
    L = loss(predict,y)
    optimizer.zero_grad()
    L.backward()
    optimizer.step()
    train_loss = train_loss + L
    correct_train += (predict.argmax(dim=1)==y).sum()
  acc_train = correct_train/(batch*len(train_loader))
  with torch.no_grad():
    for index,(x,y) in enumerate(test_loader):
      [x,y] = [x.to(device),y.to(device)]
      predict = model(x)
      L1 = loss(predict,y)
      test_loss = test_loss + L1
      correct_test += (predict.argmax(dim=1)==y).sum()
    acc_test = correct_test/(batch*len(test_loader))
  print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

训练结果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229 

训练结果

找了一张图片,将其分割成只含一个数字的图片进行测试

在这里插入图片描述

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#图片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
  for j in range(16):
    test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

在这里插入图片描述

此时预测结果为4,预测正确。从这段代码中可以看到有一个反色的步骤,若不反色,结果会受到影响,如下图所示,预测为0,错误。

模型用于输入的图片是单通道的黑白图片,这里由于可视化出现了黄色,但实际上是黑白色,反色操作说明了数据的预处理十分的重要,很多数据如果是不清理过是无法直接用于推理的。

在这里插入图片描述

将所有用来泛化性测试的图片进行准确率测试:

correct = 0
i = 0
cnt = 1
for sample in test_images:
  sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
  sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
  predict = model(sample_tensor)
  output = predict.argmax()
  if(output==i):
    correct+=1
  if(cnt%16==0):
    i+=1
  cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')

 如果不反色,acc_g=0.15

acc_g:0.50625


版权声明 : 本文内容来源于互联网或用户自行发布贡献,该文观点仅代表原作者本人。本站仅提供信息存储空间服务和不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权, 违法违规的内容, 请发送邮件至2530232025#qq.cn(#换@)举报,一经查实,本站将立刻删除。
原文链接 : https://blog.csdn.net/weixin_44823313/article/details/122581741
相关文章
  • Python Django教程之实现新闻应用程序

    Python Django教程之实现新闻应用程序
    Django是一个用Python编写的高级框架,它允许我们创建服务器端Web应用程序。在本文中,我们将了解如何使用Django创建新闻应用程序。 我们将
  • 书写Python代码的一种更优雅方式(推荐!)

    书写Python代码的一种更优雅方式(推荐!)
    一些比较熟悉pandas的读者朋友应该经常会使用query()、eval()、pipe()、assign()等pandas的常用方法,书写可读性很高的「链式」数据分析处理代码
  • Python灰度变换中伽马变换分析实现

    Python灰度变换中伽马变换分析实现
    1. 介绍 伽马变换主要目的是对比度拉伸,将图像灰度较低的部分进行修正 伽马变换针对的是对单个像素点的变换,也就是点对点的映射 形
  • 使用OpenCV实现迷宫解密的全过程

    使用OpenCV实现迷宫解密的全过程
    一、你能自己走出迷宫吗? 如下图所示,可以看到是一张较为复杂的迷宫图,相信也有人尝试过自己一点一点的找出口,但我们肉眼来解谜
  • Python中的数据精度问题的介绍

    Python中的数据精度问题的介绍
    一、python运算时精度问题 1.运行时精度问题 在Python中(其他语言中也存在这个问题,这是计算机采用二进制导致的),有时候由于二进制和
  • Python随机值生成的常用方法

    Python随机值生成的常用方法
    一、随机整数 1.包含上下限:[a, b] 1 2 3 4 import random #1、随机整数:包含上下限:[a, b] for i in range(10): print(random.randint(0,5),end= | ) 查看运行结
  • Python字典高级用法深入分析讲解
    一、 collections 中 defaultdict 的使用 1.字典的键映射多个值 将下面的列表转成字典 l = [(a,2),(b,3),(a,1),(b,4),(a,3),(a,1),(b,3)] 一个字典就是一个键对
  • Python浅析多态与鸭子类型使用实例
    什么多态:同一事物有多种形态 为何要有多态=》多态会带来什么样的特性,多态性 多态性指的是可以在不考虑对象具体类型的情况下而直
  • Python字典高级用法深入分析介绍
    一、 collections 中 defaultdict 的使用 1.字典的键映射多个值 将下面的列表转成字典 l = [(a,2),(b,3),(a,1),(b,4),(a,3),(a,1),(b,3)] 一个字典就是一个键对
  • Python淘宝或京东等秒杀抢购脚本实现(秒杀脚本

    Python淘宝或京东等秒杀抢购脚本实现(秒杀脚本
    我们的目标是秒杀淘宝或京东等的订单,这里面有几个关键点,首先需要登录淘宝或京东,其次你需要准备好订单,最后要在指定时间快速
  • 本站所有内容来源于互联网或用户自行发布,本站仅提供信息存储空间服务,不拥有版权,不承担法律责任。如有侵犯您的权益,请您联系站长处理!
  • Copyright © 2017-2022 F11.CN All Rights Reserved. F11站长开发者网 版权所有 | 苏ICP备2022031554号-1 | 51LA统计