PyTorch 最佳实践:模型保存和加载
PyTorch模型保存和加载有两种方法,官方最佳实践指南推荐其中一种,但似乎效果没啥区别。最近做模型量化,遇到一个意外的错误,才理解了最佳实践背后的原理,以及不遵循它可能会遇到什么问题。
作者:Lernapparat 编译:McGL
我们研究了一些最佳实践,同时尝试阐明其背后的基本原理。
你是中级 PyTorch 程序员吗?你是否遵循官方文档的最佳实践指南?你对哪些应该坚持,哪些可以放弃而不会搞出问题有自己的经验和看法吗?
我承认有时候很难遵循最佳实践,因为他们反对的方法似乎也能工作,而我并不完全理解他们的基本原理。这是发生在我身上的一件小事。
一个我做量化 (Quantization)的故事
在Raspberry Pi 上搭建 PyTorch 之后,我一直期待着用它做一些有趣的项目。当然,我找到了一个模型,我想在Pi上适配并跑起来。我很快就让它跑起来了,但是它没有我想象的那么快。所以我开始着手量化它。
量化使得任何操作都是有状态的 / 暂时的(stateful / temporarily)
如果你把 PyTorch 计算看作是一组由操作链接起来的值(张量),量化包括对每个操作进行量化,并形成一个意见(opinion),即通过一个仿射变换对量化元素类型进行整数范围近似,张量值输出的范围应该是多少。如果这听起来很复杂,不要担心,重点是现在每个操作都需要与“一个意见”相关联,或者更准确的说,是一个观察者,记录模型的一些典型应用中所看到的最小值和最大值。但是现在这意味着在量化期间,所有操作都是有状态的。更准确的说,在准备量化和进行量化之前,它们都是有状态的。
我经常提到这一点,我主张不要声明一次激活函数,然后多次重用。这是因为在使用函数的计算中的各个点上,观察者通常会看到不同的值,所以现在它们的工作方式不同了。
这种新的有状态特性也适用于简单的事情,比如张量相加,通常表示为 a + b。为此, PyTorch 提供了 torch.nn.quantized.FloatFunctional模块。这是一个常见的 Module ,但是做了修改,在计算中不使用 forward ,而是有几种方法对应基本的操作,如我们这里的.add
所以我使用了残差(residual)模块,它看起来大概像这样(注意它是如何分开独立声明激活的,这是一件好事!):
class ResBlock(torch.nn.Module):
def __init__(self, ...):
self.conv1 = ...
self.act1 = ...
self.conv2 = ...
self.act2 = ...
def forward(self, x):
return self.act2(x + self.conv2(self.act1(self.conv1(x))))
我还添加了 self.add = torch.nn.quantized.FloatFunctional() 到 __init__ 并把 x + ... 替换为 self.add.add(x, ...)。搞定!
根据准备好的模型,我可以添加量化本身,依据PyTorch 教程执行很简单。在评估脚本的最后,模型全部加载、设置为 eval 等之后,我添加了以下内容并重新启动了正在使用的 notebook kernel,然后运行了所有这些。
#config
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.backends.quantized.engine = 'qnnpack'
# wrap in quantization of inputs / de-quantization of output)
model = torch.quantization.QuantWrapper(model)
# insert observers
torch.quantization.prepare(model, inplace=True)
因此稍后(在运行模型以获得观察结果之后) ,我会调用
torch.quantization.convert(model, inplace=True)
来得到一个模型。很简单!
一个意外的错误
现在我只需要运行几个批次的输入。
preds = model(inp)
但是发生了什么呢?
ModuleAttributeError: 'ResBlock' object has no attribute 'add'
糟糕!
出什么问题了? 是不是我在 ResBlock 中有拼写错误?
在 Jupyter中你可以非常容易地使用 ?? model.resblock1来检查。但是这没问题,没有拼写错误。
这就是 PyTorch 最佳实践的用武之地。
序列化(Serialization)最佳实践
PyTorch 官方文档有个关于序列化的说明,其中包含一个最佳实践部分。它这样开头
序列化和还原模型主要有两种方法。第一个(推荐)是只保存和加载模型参数:
然后展示了如何用 state_dict() 和 load_state_dict() 方法来运作. 第二种方法是保存和加载模型。
该说明提供了优先只使用序列化参数的理由如下:
然而,在[保存模型的情况]下,序列化的数据绑定到特定的类和所使用的确切目录结构,因此在其他项目中使用时,或在一些重度的重构之后,它可能会以各种方式中断。
事实证明,这是一个相当轻描淡写的说法,甚至在我们非常温和的修改中——几乎算不上重大的修改——也遇到了它所提到的问题。
什么出了问题?
为了找到问题的核心,我们必须思考 Python 中的对象是什么。在一个粗略的过度简化中,它完全由其 __dict__属性定义, 该属性包含所有("data")成员,其__class__ 属性指向它的类型( 例如,对于 Module 实例,是Module, 而对于 Module 本身 (一个类) ,是 type) 。当我们调用一个方法时,它通常不在 __dict__ 中(其实也可以,但改动会比较复杂)。但是 Python 会自动查询 __class__ 来寻找方法 (或者其他在 __dict__中找不到的东西)。
当反序列化模型时(我使用的模型的作者没有遵循最佳实践建议) ,Python 将通过查找 __class__ 的类型并将其与反序列化__dict__组合来构造一个对象。但是它(正确地)没有做的是调用 __init__ 来设置类(它不应该这样做,尤其是担心在 __init__ 和序列化之间可能已经修改了内容,或者它可能有我们不希望的副作用)。这意味着,当我们调用模块时,我们使用了新的forward 但是得到了原作者的__init__ 准备的__dict__ 和后续的训练,而没有我们修改过的 __init__ 添加的新属性add。
所以简而言之,这就是为什么在 Python 中序列化 PyTorch 模块或通常意义上的对象是危险的: 你很容易就会得到数据属性和代码不同步的结果。
保持兼容性
这里有一个显而易见的问题——也可以说是一个缺点——那就是除了状态字典(state dict)之外,我们还需要跟踪 setup 的配置。但是如果你愿意的话,你可以轻松地序列化所有参数以及状态字典——只需将它们粘贴到一个联合字典中。
但是不序列化模块本身还有其他优点:
显而易见的是,我们可以使用状态字典。可以无需模块加载状态字典,如果我们改变了一些重要的东西,可以检查和修改状态字典。
不太明显的是,实现者或用户还可以自定义模块处理状态字典。这有两个方面:
- 对于用户来说,有钩子(hooks)。好吧,它们不是非常官方,但是有_register_load_state_dict_pre_hook ,你可以用它来注册钩子,在更新模型之前处理状态字典,还有_register_state_dict_hook 来注册钩子,这些钩子在状态字典被收集之后和从 state_dict()返回之前被调用。
- 更重要的是,实现者可以覆写 _load_from_state_dict 。当类具有属性 _version时,这将在状态字典中保存为 version 元数据(metadata). 有了这个,你可以添加来自旧状态字典的转换。BatchNorm提供了一个怎么做到这点的例子,大致看起来像这样:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if (version is None or version < 2) and self.have_new_thing:
new_key = prefix + 'new_thing_param'
if new_key not in state_dict:
state_dict[new_key] = ... # some default here
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
因此在这里我们检查版本是否是旧的,并且需要一个新的key,先添加它,然后再交给超类 (通常是 torch.nn.Module)常规处理。
总结
当保存整个模型而不是按照最佳实践只保存参数时,我们已经看到了什么出错了的非常详细的描述。 我个人的看法是,保存模型的陷阱是相当大的,很容易掉坑里,所以我们真的应该注意只保存模型参数,而不是 Module 类。
希望你喜欢这个深入 PyTorch 最佳实践的小插曲。
原文:http://lernapparat.de/pytorch-best-practices/
推荐阅读
-
[姿势估计] 实践记录:使用 Dlib 和 mediapipe 进行人脸姿势估计 - 本文重点介绍方法 2):方法 1:基于深度学习的方法:。 基于深度学习的方法:基于深度学习的方法利用深度学习模型,如卷积神经网络(CNN)或递归神经网络(RNN),直接从人脸图像中学习姿势估计。这些方法能够学习更复杂的特征表征,并在大规模数据集上取得优异的性能。方法二:基于二维校准信息估计三维姿态信息(计算机视觉 PnP 问题)。 特征点定位:人脸姿态估计的第一步是通过特征点定位来检测和定位人脸的关键点,如眼睛、鼻子和嘴巴。这些关键点提供了人脸的局部结构信息,可用于后续的姿势估计。 旋转表示:常见的旋转表示方法包括欧拉角和旋转矩阵。欧拉角通过三个旋转角度(通常是俯仰、偏航和滚动)描述头部的旋转姿态。旋转矩阵是一个 3x3 矩阵,表示头部从一个坐标系到另一个坐标系的变换。 三维模型重建:根据特征点的定位结果,三维人脸模型可用于姿势估计。通过将人脸的二维图像映射到三维模型上,可以估算出人脸的旋转和平移信息。这就需要建立人脸的三维模型,然后通过优化方法将模型与特征点对齐,从而获得姿势估计结果。 特征点定位 特征点定位是用于检测人脸关键部位的五官基础部分,还有其他更多的特征点表示方法,大家可以参考我上一篇文章中介绍的特征点检测方案实践:人脸校正二次定位操作来解决人脸校正的问题,客户在检测关键点的代码上略有修改,坐标转换部分客户见上图 def get_face_info(image). img_copy = image.copy image.flags.writeable = False image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) results = face_detection.process(image) # 在图像上绘制人脸检测注释。 image.flags.writeable = True image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) box_info, facial = None, None if results.detections: for detection in results. for detection in results.detections: mp_drawing.Drawing.detection = 无 mp_drawing.draw_detection(image, detection) 面部 = detection.location_data.relative_keypoints 返回面部 在上述代码中,返回的数据是五官(6 个关键点的坐标),这是用 mediapipe 库实现的,下面我们可以尝试用另一个库:dlib 来实现。 使用 dlib 使用 Dlib 库在 Python 中实现人脸关键点检测的步骤如下: 确保已安装 Dlib 库,可使用以下命令: pip install dlib 导入必要的库: 加载 Dlib 的人脸检测器和关键点检测器模型: 读取图像并将其灰度化: 使用人脸检测器检测图像中的人脸: 对检测到的人脸进行遍历,并使用关键点检测器检测人脸关键点: 显示绘制了关键点的图像: 以下代码将参数 landmarks_part 添加到要返回的关键点坐标中。
-
PyTorch 模型静态量化、保存、加载 int8 量化模型
-
PyTorch 最佳实践:模型保存和加载
-
静态量化、保存并在 PyTorch 模型训练完成后加载 int8 量化模型
-
Zero One Everything Yi-34B-Chat 微调模型和量化版本开源!Magic Hitch 社区最佳实践教程!
-
PyTorch: 如何保存、加载和绘制训练过程中的损失图?