广告位联系
返回顶部
分享到

Pytorch中实现只导入部分模型参数

python 来源:互联网搜集 作者:秩名 发布时间:2020-01-02 22:24:38 人浏览
摘要

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(dont expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。
 

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
  
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
  
  
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
  
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
  
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型
 

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。
 

for k in state_dict.keys():
  print(k)
  
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
  
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

版权声明 : 本文内容来源于互联网或用户自行发布贡献,该文观点仅代表原作者本人。本站仅提供信息存储空间服务和不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权, 违法违规的内容, 请发送邮件至2530232025#qq.cn(#换@)举报,一经查实,本站将立刻删除。
原文链接 : https://blog.csdn.net/qq_34914551/article/details/87871134
相关文章
  • Python Django教程之实现新闻应用程序

    Python Django教程之实现新闻应用程序
    Django是一个用Python编写的高级框架,它允许我们创建服务器端Web应用程序。在本文中,我们将了解如何使用Django创建新闻应用程序。 我们将
  • 书写Python代码的一种更优雅方式(推荐!)

    书写Python代码的一种更优雅方式(推荐!)
    一些比较熟悉pandas的读者朋友应该经常会使用query()、eval()、pipe()、assign()等pandas的常用方法,书写可读性很高的「链式」数据分析处理代码
  • Python灰度变换中伽马变换分析实现

    Python灰度变换中伽马变换分析实现
    1. 介绍 伽马变换主要目的是对比度拉伸,将图像灰度较低的部分进行修正 伽马变换针对的是对单个像素点的变换,也就是点对点的映射 形
  • 使用OpenCV实现迷宫解密的全过程

    使用OpenCV实现迷宫解密的全过程
    一、你能自己走出迷宫吗? 如下图所示,可以看到是一张较为复杂的迷宫图,相信也有人尝试过自己一点一点的找出口,但我们肉眼来解谜
  • Python中的数据精度问题的介绍

    Python中的数据精度问题的介绍
    一、python运算时精度问题 1.运行时精度问题 在Python中(其他语言中也存在这个问题,这是计算机采用二进制导致的),有时候由于二进制和
  • Python随机值生成的常用方法

    Python随机值生成的常用方法
    一、随机整数 1.包含上下限:[a, b] 1 2 3 4 import random #1、随机整数:包含上下限:[a, b] for i in range(10): print(random.randint(0,5),end= | ) 查看运行结
  • Python字典高级用法深入分析讲解
    一、 collections 中 defaultdict 的使用 1.字典的键映射多个值 将下面的列表转成字典 l = [(a,2),(b,3),(a,1),(b,4),(a,3),(a,1),(b,3)] 一个字典就是一个键对
  • Python浅析多态与鸭子类型使用实例
    什么多态:同一事物有多种形态 为何要有多态=》多态会带来什么样的特性,多态性 多态性指的是可以在不考虑对象具体类型的情况下而直
  • Python字典高级用法深入分析介绍
    一、 collections 中 defaultdict 的使用 1.字典的键映射多个值 将下面的列表转成字典 l = [(a,2),(b,3),(a,1),(b,4),(a,3),(a,1),(b,3)] 一个字典就是一个键对
  • Python淘宝或京东等秒杀抢购脚本实现(秒杀脚本

    Python淘宝或京东等秒杀抢购脚本实现(秒杀脚本
    我们的目标是秒杀淘宝或京东等的订单,这里面有几个关键点,首先需要登录淘宝或京东,其次你需要准备好订单,最后要在指定时间快速
  • 本站所有内容来源于互联网或用户自行发布,本站仅提供信息存储空间服务,不拥有版权,不承担法律责任。如有侵犯您的权益,请您联系站长处理!
  • Copyright © 2017-2022 F11.CN All Rights Reserved. F11站长开发者网 版权所有 | 苏ICP备2022031554号-1 | 51LA统计