pytorch模型保存 保存模型主要分为两类: 保存整个模型 只保存模型参数 1.保存加载整个模型(不推荐) 保存整个网络模型,网络结构+权重参数 1 torch.save(model,net.pth) 加载整个网络模型(可能比
pytorch模型保存保存模型主要分为两类:
1.保存加载整个模型(不推荐)保存整个网络模型,网络结构+权重参数
加载整个网络模型(可能比较耗时)
2.只保存加载模型参数(推荐)保存模型的权重参数(速度快,占内存少)
load 模型参数 因为我们只保存了 模型的参数,所以需要先定义一个网络对象,然后再加载模型参数。
#将模型参数加载到新模型中,torch.load返回的是一个OrderedDict,说明.state_dict()只是把所有模型的参数都已OrderedDict的形式存下来。
Note:保存模型进行推理测试时,只需保存训练好的模型的权重参数,即推荐第二种方法。
如果哪一天我们需要重新写这个网络的,比如使用new_model,如果直接load会出现unexpected key. 但是加上strict=False可以很容易地加载预训练的参数(注意检查key是否匹配),直接忽略不匹配的key,对于匹配的key则进行正常的赋值。 |
2019-06-18
2019-07-04
2021-05-23
2021-05-27
2021-05-27