我一直在寻找替代方法来保存PyTorch训练模型。 到目前为止,我已经找到了两个备选方案。
- torch.save()保存一个模型和torch.load()加载一个模型。
- model.state_dict() ,以节省训练模型和model.load_state_dict()来加载保存的模型。
我也碰到过这种讨论 ,其中方法2时建议在方法1。
我的问题是,为什么第二种方法是首选? 难道仅仅是因为torch.nn模块具有这两种功能,我们鼓励使用它们?
我一直在寻找替代方法来保存PyTorch训练模型。 到目前为止,我已经找到了两个备选方案。
我也碰到过这种讨论 ,其中方法2时建议在方法1。
我的问题是,为什么第二种方法是首选? 难道仅仅是因为torch.nn模块具有这两种功能,我们鼓励使用它们?
我发现这个页面上他们的GitHub库,我就贴上这里的内容。
有序列化和恢复模型的两种主要方法。
第一个(推荐)保存并只加载模型参数:
torch.save(the_model.state_dict(), PATH)
再后来:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二保存和加载整个模型:
torch.save(the_model, PATH)
再后来:
the_model = torch.load(PATH)
然而,在这种情况下,序列化的数据绑定到特定的类和使用的准确的目录结构,所以在其他项目中使用时,它可以通过各种方式突破,或在一些严重的refactors。
这取决于你想要做什么。
案例1:保存自己使用它的推理模型 :您保存模型,您还原它,然后你改变模型来评估模式。 这样做是因为你平时有BatchNorm
和Dropout
层,默认情况下是在建设训练模式:
torch.save(model.state_dict(), filepath)
#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
案例#2:保存模型后恢复训练 :如果你需要保持训练,你是要保存模型,您需要保存的不仅仅是更多的钱。 您还需要保存的优化,时代,成绩等,你会做这样的状态:
state = {
'epoch': epoch,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
...
}
torch.save(state, filepath)
恢复训练,你会做喜欢的东西: state = torch.load(filepath)
,然后还原每个单独的对象的状态,这样的事情:
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
既然你恢复训练, 不叫model.eval()
一旦你恢复加载时的状态。
案例3:型号被他人使用,没有进入到你的代码 :在Tensorflow您可以创建一个.pb
定义两者的架构和模型的权重文件。 这是非常方便,使用特别是当Tensorflow serve
。 等效的方式做到这一点Pytorch将是:
torch.save(model, filepath)
# Then later:
model = torch.load(filepath)
这样,仍然不防弹,自pytorch仍处于一个很大的变化,我不会推荐它。
该泡菜 Python库实现了序列化和反序列化Python对象二进制协议。
当您import torch
(或当您使用PyTorch)将import pickle
的你,你不需要调用pickle.dump()
和pickle.load()
直接,这是方法来保存和加载对象。
事实上, torch.save()
和torch.load()
将包裹pickle.dump()
和pickle.load()
为您服务。
一个state_dict
对方的回答值得提及的只是几个音符。
什么state_dict
我们有内部PyTorch? 实际上有两个state_dict
秒。
该PyTorch模型torch.nn.Module
有model.parameters()
调用来获得可以学习的参数(W和B)。 这些参数可学习,一次任意设定,会我们学习更新随着时间的推移。 可学习的参数是第一state_dict
。
第二state_dict
是优化状态字典。 优化器也是该模型的一部分。 你还记得,优化用于改善我们可以学习的参数。 但优化state_dict
是固定的。 在没有什么学习。
由于state_dict
对象是Python字典,它们可以方便地保存,更新,修改和恢复,加上模块化的大量工作PyTorch模型和优化。
让我们创建一个超级简单的模型来解释这一点:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
此代码将输出以下内容:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
请注意,这是一个最小模型。 你可以尝试添加顺序栈
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
注意有可学习的参数(卷积层,线性层等),注册缓冲器(batchnorm层),只有层在模型中的条目state_dict
。
非可学习的东西,属于优化对象state_dict
,其中包含有关优化器的状态信息,以及所使用的超参数。
这个故事的其余部分是相同的; 在推断阶段(这是一个阶段,当我们使用训练后的模型)预测; 我们根据我们了解到的参数预测。 所以对于推论,我们只需要保存参数model.state_dict()
torch.save(model.state_dict(), filepath)
而且使用后model.load_state_dict(torch.load(文件路径))model.eval()
注意:不要忘记最后一行model.eval()
这是加载模型后至关重要。
也不要试图挽救torch.save(model.parameters(), filepath)
。 所述model.parameters()
仅仅是生成器对象。
在另一边, torch.save(model, filepath)
保存模型对象本身,而是牢记模型没有优化的state_dict
。 检查由@Jadiel德阿马斯其他出色答卷保存优化的状态字典。