PTQ (Post Training Quantization) Source Code Walkthrough, Part 1

I've been working on model quantization recently, so I dug into how PTQ works and how it is implemented. The theory of PTQ is already well covered by many articles, and I may write my own theory-focused summary later. This post focuses on the code implementation of PTQ.

Before walking through the code, let's look at how PTQ is used:

# load model
model = load_model(model_path)
model.eval()

# register quant_handle_hook in forward_post_hooks
ptq = PTQ()
model = ptq.quantize(model)

# calibration
for key, input in reader:
    model(input)

# compute quant params
ptq.ptq._convert(model)

# save quant model
jit.save(model, quant_model_path)

Let's first look at how activation quantization statistics are collected.

ImperativePTQ

class ImperativePTQ(object):
    """
    Static post training quantization.
    """

    def __init__(self, quant_config=ptq_config.default_ptq_config):
        """
        Constructor.

        Args:
            quant_config(PTQConfig): the config of post training quantization.
                The config has weight_quantizer and activation_quantizer.
                In default, the weight_quantizer is PerChannelAbsmaxQuantizer
                and the activation_quantizer is KLQuantizer.
        """
        super().__init__()

        assert isinstance(quant_config, ptq_config.PTQConfig)

        self._quant_config = quant_config

ImperativePTQ is the implementation class for PTQ. Its input parameter quant_config specifies the quantization methods for weights and activations. By default, the activation_quantizer is KLQuantizer and the weight_quantizer is PerChannelAbsmaxQuantizer.

class PTQConfig(object):
    """
    The PTQ config shows how to quantize the inputs and outputs.
    """

    def __init__(self, activation_quantizer, weight_quantizer):
        """
        Constructor.

        Args:
            activation_quantizer(BaseQuantizer): The activation quantizer.
                It should be the instance of BaseQuantizer.
            weight_quantizer(BaseQuantizer): The weight quantizer.
                It should be the instance of BaseQuantizer.
        """
        super().__init__()
        assert isinstance(activation_quantizer, tuple(SUPPORT_ACT_QUANTIZERS))
        assert isinstance(weight_quantizer, tuple(SUPPORT_WT_QUANTIZERS))

        self.in_act_quantizer = copy.deepcopy(activation_quantizer)
        self.out_act_quantizer = copy.deepcopy(activation_quantizer)
        self.wt_quantizer = copy.deepcopy(weight_quantizer)

        self.quant_hook_handle = None

        # In order to wrap simulated layers, use in_act_quantizer
        # to calculate the input thresholds for conv2d, linear and etc.
        self.enable_in_act_quantizer = False


default_ptq_config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())

Here quant_hook_handle is the handle of the Layer's forward post hook.

enable_in_act_quantizer controls whether in_act_quantizer is used to compute quantization parameters for the input activations.

_is_skip_layer and _is_quant_layer

A model is typically built by stacking layers, with the framework's nn.Conv2D, nn.Linear, and similar layers serving as the basic building blocks. When quantizing, we need to know which layers should be quantized and which should be skipped. This is determined by the two static methods _is_skip_layer and _is_quant_layer.

    @staticmethod
    def _is_skip_layer(layer):
        return hasattr(layer, "skip_quant") and layer.skip_quant == True

    @staticmethod
    def _is_quant_layer(layer):
        return hasattr(layer, "_quant_config")

is_leaf_layer

def is_leaf_layer(layer):
    """
    Whether the layer is leaf layer.
    """
    return isinstance(layer, paddle.nn.Layer) and len(layer.sublayers()) == 0

A layer is a leaf layer when its sublayers() list is empty.
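The same check can be sketched with a minimal stand-in class (FakeLayer below only mimics the sublayers() interface of paddle.nn.Layer for illustration):

```python
# Minimal stand-in for the leaf check; not Paddle's implementation.
class FakeLayer:
    def __init__(self, children=()):
        self._children = list(children)

    def sublayers(self):
        # Paddle's sublayers() walks all descendants; direct children
        # are enough for this sketch.
        return list(self._children)

def is_leaf_layer(layer):
    return isinstance(layer, FakeLayer) and len(layer.sublayers()) == 0

conv = FakeLayer()          # no sublayers -> leaf
block = FakeLayer([conv])   # has sublayers -> not leaf
print(is_leaf_layer(conv), is_leaf_layer(block))  # True False
```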

quantize

    def quantize(self, model, inplace=False, fuse=False, fuse_list=None):
        """
        Add quant config and hook to the target layer.

        Args:
            model(paddle.nn.Layer): The model to be quantized.
            inplace(bool): Whether apply quantization to the input model.
                           Default: False.
            fuse(bool): Whether to fuse layers.
                        Default: False.
            fuse_list(list): The layers' names to be fused. For example,
                "fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
                A TypeError would be raised if "fuse" was set as
                True but "fuse_list" was None.
                Default: None.
        Return
            quantized_model(paddle.nn.Layer): The quantized model.
        """
        assert isinstance(
            model, paddle.nn.Layer
        ), "The model must be the instance of paddle.nn.Layer."
        if not inplace:
            model = copy.deepcopy(model)
        if fuse:
            model.eval()
            model = fuse_utils.fuse_layers(model, fuse_list)

This is the entry point for model quantization. model is the model instance; inplace indicates whether to operate on the original model in place; fuse and fuse_list let the user request fusion of specified layers. The method ultimately returns the processed model, instrumented to collect activation statistics for each layer.

        for name, layer in model.named_sublayers():
            if (
                PTQRegistry.is_supported_layer(layer)
                and utils.is_leaf_layer(layer)
                and not self._is_skip_layer(layer)
            ):
                # Add quant config
                quant_config = copy.deepcopy(self._quant_config)
                if PTQRegistry.is_simulated_quant_layer(layer):
                    ## quant activation
                    quant_config.enable_in_act_quantizer = True
                layer._quant_config = quant_config

                # register hook
                hook = ptq_hooks.quant_forward_post_hook
                quant_hook_handle = layer.register_forward_post_hook(hook)
                quant_config.quant_hook_handle = quant_hook_handle
                layer._forward_post_hooks.move_to_end(
                    quant_hook_handle._hook_id, last=False
                )

        return model

First, iterate over the model's sublayers and check each layer against three conditions (visible in the code above): it must be a layer type supported by PTQRegistry (PTQRegistry.is_supported_layer), it must be a leaf layer (utils.is_leaf_layer), and it must not be marked to skip quantization (_is_skip_layer).

PTQRegistry is essentially a registry dictionary; we will look at its implementation later.

If a layer satisfies all of these conditions, quantization handling is added to it: a deep copy of the quant config is attached as layer._quant_config (with enable_in_act_quantizer turned on for simulated quant layers), a forward post hook is registered, and that hook is moved to the front of the hook list.
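The move_to_end(hook_id, last=False) trick works because _forward_post_hooks is an OrderedDict: moving the newly registered quant hook to the front ensures it observes the layer's raw outputs before any other post hooks run. The reordering itself is plain OrderedDict behavior:

```python
from collections import OrderedDict

hooks = OrderedDict()
hooks[0] = "existing_hook"
hooks[1] = "quant_hook"   # registered last, so it starts at the end

# Same trick as in quantize(): move the quant hook to the front.
hooks.move_to_end(1, last=False)
print(list(hooks.values()))   # ['quant_hook', 'existing_hook']
```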

Now let's look at the implementation of quant_forward_post_hook:

def quant_forward_post_hook(layer, inputs, outputs):
    """
    The forward_post_hook for PTQ.
    """
    assert hasattr(
        layer, '_quant_config'
    ), "The layer should have _quant_config attr"

    qc = layer._quant_config
    if qc.enable_in_act_quantizer:
        qc.in_act_quantizer.sample_data(layer, inputs)
    qc.out_act_quantizer.sample_data(layer, (outputs,))
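The forward-post-hook mechanism itself can be mimicked in a few lines of plain Python. This stand-in only illustrates the calling convention — the hook receives (layer, inputs, outputs) after forward completes — and is not Paddle's implementation:

```python
from collections import OrderedDict

class TinyLayer:
    """Stand-in layer: runs forward, then fires post hooks with
    (layer, inputs, outputs), mirroring the hook signature above."""
    def __init__(self):
        self._forward_post_hooks = OrderedDict()

    def register_forward_post_hook(self, hook):
        self._forward_post_hooks[len(self._forward_post_hooks)] = hook

    def __call__(self, *inputs):
        outputs = sum(inputs)   # toy forward computation
        for hook in self._forward_post_hooks.values():
            hook(self, inputs, outputs)
        return outputs

seen = []
layer = TinyLayer()
layer.register_forward_post_hook(lambda l, i, o: seen.append(o))
layer(1, 2)
print(seen)   # [3]
```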

After the forward pass completes, qc.out_act_quantizer collects activation data from outputs.

Whether activation data is also collected from inputs depends on the qc.enable_in_act_quantizer setting.

Recall that qc.enable_in_act_quantizer is only True when PTQRegistry.is_simulated_quant_layer(layer) is true (currently only for nn.Conv2D and nn.Linear).

We will discuss the implementations of KLQuantizer and PerChannelAbsmaxQuantizer later.
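As a rough flavor of what a quantizer's sample_data can do before we get to the real implementations, here is a simplified abs-max-style collector in plain Python. The class name and interface are illustrative, not Paddle's; the real KLQuantizer collects full histograms rather than a single running maximum:

```python
class AbsmaxQuantizerSketch:
    """Illustrative only: tracks the running max |x| over calibration
    batches and derives a symmetric int8 scale from it."""
    def __init__(self):
        self.abs_max = 0.0

    def sample_data(self, tensors):
        # tensors: an iterable of flat lists of floats in this sketch
        for t in tensors:
            self.abs_max = max(self.abs_max, max(abs(v) for v in t))

    def scale(self, bits=8):
        # Symmetric quantization: map [-abs_max, abs_max] onto the
        # signed integer range, e.g. 127 for int8.
        return self.abs_max / (2 ** (bits - 1) - 1)

q = AbsmaxQuantizerSketch()
q.sample_data([[0.5, -2.0, 1.5]])   # first calibration batch
q.sample_data([[3.0, -1.0]])        # second calibration batch
print(q.abs_max)    # 3.0
print(q.scale())    # 3.0 / 127
```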

At this point, every layer has been processed and the model object is returned. Next, calibration data is fed through the model to collect activation distributions.


Page updated: 2024-04-27
