广告位联系
返回顶部
分享到

pytorch GPU和CPU模型相互加载方式

python 来源:互联网 作者:佚名 发布时间:2024-09-09 21:50:37 人浏览
摘要

1 pytorch保存模型的两种方式 1.1 直接保存模型并读取 1 2 3 4 5 6 7 # 创建你的模型实例对象: model model = net() ## 保存模型 torch.save(model, model_name.pth) ## 读取模型 model = torch.load(model_name.pth) 1.2 只保存模

1 pytorch保存模型的两种方式

1.1 直接保存模型并读取

1

2

3

4

5

6

7

# 创建你的模型实例对象: model

model = net()

## 保存模型

torch.save(model, 'model_name.pth')

 

## 读取模型

model = torch.load('model_name.pth')

1.2 只保存模型中的参数并读取

1

2

3

4

5

6

7

## 保存模型

torch.save({'model': model.state_dict()}, 'model_name.pth')

 

## 读取模型

model = net()

state_dict = torch.load('model_name.pth')

model.load_state_dict(state_dict['model'])

  • 第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。
  • 第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

如何保存模型决定了如何读取模型,一般来选择第二种来保存和读取。

2 GPU / CPU模型相互加载

2.1 单个CPU和单个GPU模型加载

pytorch 允许把在GPU上训练的模型加载到CPU上,也允许把在CPU上训练的模型加载到GPU上。

加载模型参数的时候,在GPU和CPU训练的模型是不一样的,这两种模型是不能混为一谈的,下面分情况进行操作说明。

情况一:CPU -> CPU, GPU -> GPU

  • GPU训练的模型,在GPU上使用;
  • CPU训练的模型,在CPU上使用,

这种情况下我们都只用直接用下面的语句即可:

1

torch.load('model_dict.pth')

情况二:GPU -> CPG/GPU

GPU训练的模型,不知道放在CPU还是GPU运行,两种情况都要考虑

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

import torch

from torchvision import models

 

# 加载预训练的GPU模型权重文件

weights_path = 'model_gpu.pth'

 

# 定义一个与原模型结构相同的新模型

model = models.resnet50()

 

# 检查是否有可用的CUDA设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

# 将权重映射到相应的设备内存并加载到模型中

weights = torch.load(weights_path, map_location=device)

model.load_state_dict(weights)

 

# 设置为评估模式

model.eval()

 

print("Model is successfully loaded and can be used on a", device.type, "!")

情况三:CPU -> CPG/GPU

模型是在CPU上训练的,但不确定要在CPU还是GPU上运行时,两种情况都要考虑

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch

from torchvision import models

 

# 加载预训练的CPU模型权重文件

weights_path = 'model_cpu.pth'

 

# 定义一个与原模型结构相同的新模型

model = models.resnet50()

 

# 检查是否有可用的CUDA设备

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

# 将权重映射到相应的设备内存并加载到模型中

if device.type == 'cuda':

    model.to(device)

    weights = torch.load(weights_path, map_location=device)

else:

    weights = torch.load(weights_path, map_location='cpu')

 

model.load_state_dict(weights)

 

# 设置为评估模式

model.eval()

 

print("Model is successfully loaded and can be used on a", device.type, "!")


版权声明 : 本文内容来源于互联网或用户自行发布贡献,该文观点仅代表原作者本人。本站仅提供信息存储空间服务和不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权, 违法违规的内容, 请发送邮件至2530232025#qq.cn(#换@)举报,一经查实,本站将立刻删除。
原文链接 :
相关文章
  • Python使用切片移动元素位置的代码
    一.基本介绍 1.切片基础 在 Python 中,切片是指从序列类型(如列表、字符串、元组等)中提取子序列的过程。切片的基本语法如下: 1 seq
  • Python使用FastApi发送Post请求的步骤
    一.基本介绍 FastAPI 是一个现代、快速(高性能)的 Web 框架,用于构建 API,它基于 Python 3.6 及以上版本。在 FastAPI 中发送 POST 请求,通常是
  • pytorch GPU和CPU模型相互加载方式
    1 pytorch保存模型的两种方式 1.1 直接保存模型并读取 1 2 3 4 5 6 7 # 创建你的模型实例对象: model model = net() ## 保存模型 torch.save(model, model_name
  • pytorch模型保存方式介绍
    pytorch模型保存 保存模型主要分为两类: 保存整个模型 只保存模型参数 1.保存加载整个模型(不推荐) 保存整个网络模型,网络结构+权重
  • Python虚拟环境virtualenv安装的详细教程保姆级(Wi

    Python虚拟环境virtualenv安装的详细教程保姆级(Wi
    虚拟环境安装 工作中我们经常会根据不同的项目切换不同的python环境,如果仅仅是在本地就安装一个python环境,项目移植也要重新配置环境
  • python中eval的用法介绍

    python中eval的用法介绍
    python中eval的用法 eval(字符串) 能够以Python表达式的方式解析并执行字符串,并将返回结果输出。 eval()函数将去掉字符串的两个引号,将其解
  • 使用python生成定制化词云的代码

    使用python生成定制化词云的代码
    数据可视化已成为我们理解复杂信息的关键工具。词云,作为一种流行的数据可视化形式,能够将大量文本数据中的关键词以视觉化的方式
  • 通过Python实现在Word中添加和删除书签的操作

    通过Python实现在Word中添加和删除书签的操作
    本文中用到的方法需要用到Spire.Doc for Python库。可以直接通过pip进行安装: pip install Spire.Doc Python 在指定段落添加书签 加载Word文档; 获取指
  • 使用Python在PDF文档中创建动作

    使用Python在PDF文档中创建动作
    PDF格式因其跨平台兼容性和丰富的功能集而成为许多行业中的首选文件格式。其中,PDF中的动作(Action) 功能尤为突出,它允许开发者嵌入
  • Python解决ModuleNotFoundError: No module named 'PIL'的问题
    一、分析问题背景 ModuleNotFoundError: No module named PIL是一个常见的Python错误,通常出现在使用Pillow库时。Pillow是Python中用于图像处理的一个库,
  • 本站所有内容来源于互联网或用户自行发布,本站仅提供信息存储空间服务,不拥有版权,不承担法律责任。如有侵犯您的权益,请您联系站长处理!
  • Copyright © 2017-2022 F11.CN All Rights Reserved. F11站长开发者网 版权所有 | 苏ICP备2022031554号-1 | 51LA统计