AI课堂第19讲DL深度学习PyTo

模型训练的目的是为了生成合适的模型文件,该文件可以部署到不同设备上进行使用。

模型文件一般是在每次迭代训练完成后保存到电脑硬盘,在训练完成后,我们根据生成的损失曲线图选择合适的模型文件作为实际部署模型发布。

torch提供了save和load函数用于数据和文件的读写,可以存储变量、Tensor列表、Tensor字典、模型文件等。

PyTorch模型保存有许多后缀(.pt,.pth,.pkl等),其保存内容没有区别,仅仅是后缀不同。

1.读写Tensor数据

使用save函数和load函数分别存储和读取Tensor。save使用python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而load使用pickleunpickle工具将pickle的对象文件反序列化为内存。

代码演示:

2.模型参数访问

读写模型参数之前我们应先知道参数是存在模型哪个位置,以及如何访问模型参数?

上节课我们讲了模型参数可以通过对net.named_parameters()循环遍历或者访问net.parameters()获取对应参数,比较麻烦,一般用于优化器的初始化。而模型的保存常用model.state_dict()来实现,因为state_dict()方法的返回值是一个python的字典对象,将每一层与它的对应参数建立映射关系,读写比较方便。

在网络模型中,只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

3.读写模型

PyTorch中保存和加载训练模型有两种常见的方法:

(a)仅保存和加载模型参数(state_dict)(官方推荐)

仅保存模型参数:

torch.save(model.state_dict(),my_model.pth)

加载:model=Model_Net(**)#这里需要重新创建模型网络结构

model.load_state_dict(torch.load(my_model.pth))#这里根据模型结构,调用存储的模型参数

(b)保存和加载整个模型保存整个模型(包括参数和结构):

torch.save(model,my_model.pth)

加载:

model=torch.load(my_model.pth)

如果想要保存模型参数、训练采用的优化器参数、模型保存路径等信息,可将这些信息组合起来构成一个字典,然后将字典保存起来:

4.模型文件互相加载:CPU-GPU

模型的加载还和设备硬件有关,我们实际应用中常会在GPU上训练数据,也会有在CPU训练的情况(模型小、无GPU时)。

训练方式的不同(CPU训练或GPU训练),会影响模型的加载,torch.load函数在加载模型时会自动选择加载设备。如果想要在GPU和CPU之间相互加载的话,就需要指定对应的参数。

总结

模型的保存是我们训练的成果,将模型文件应用于部署测试是我们训练成果的展示,尤其是我们在GPU上训练的模型部署到CPU上运行可以极大的降低发布成本。




转载请注明:http://www.aierlanlan.com/tzrz/2671.html

  • 上一篇文章:
  •   
  • 下一篇文章: