PTQ (Post Training Quantization) Source Code Reading, Part 2

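The previous article mentioned the PTQRegistry class, whose main job is to act as a dict that stores the nn.Layer -> LayerInfo mapping. Let's look at how this class is implemented.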

LayerInfo

from typing import List


class LayerInfo(object):
    """
    Store the argnames of the inputs, weights and outputs.
    """

    def __init__(self, layer, input_names: List[str], weight_names: List[str], output_names: List[str]):
        super().__init__()
        self.layer = layer
        self.input_names = input_names
        self.weight_names = weight_names
        self.output_names = output_names

LayerInfo mainly stores an nn.Layer class together with the names of its inputs, weights and outputs.
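A minimal, Paddle-free sketch of how a LayerInfo entry can be read. FakeConv2D and FakeReLU are hypothetical stand-ins for paddle.nn.Conv2D and paddle.nn.ReLU, and the helper has_weights is also hypothetical and not part of Paddle. An empty weight_names list means the layer has no weights to quantize.

```python
from typing import List


class LayerInfo:
    """Same fields as the LayerInfo above: a layer class plus its argument names."""

    def __init__(self, layer, input_names: List[str], weight_names: List[str], output_names: List[str]):
        self.layer = layer
        self.input_names = input_names
        self.weight_names = weight_names
        self.output_names = output_names


def has_weights(info: LayerInfo) -> bool:
    # Hypothetical helper (not in Paddle): a layer has weights to quantize
    # only if its weight_names list is non-empty.
    return len(info.weight_names) > 0


class FakeConv2D:
    """Hypothetical stand-in for paddle.nn.Conv2D."""


class FakeReLU:
    """Hypothetical stand-in for paddle.nn.ReLU."""


conv_info = LayerInfo(FakeConv2D, ['Input'], ['Filter'], ['Output'])
relu_info = LayerInfo(FakeReLU, ['X'], [], ['Out'])

print(has_weights(conv_info))  # True
print(has_weights(relu_info))  # False
```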

The global variables PTQ_LAYERS_INFO, QUANT_LAYERS_INFO and SIMULATED_LAYERS summarize the layers that currently support quantization (the first two as lists of LayerInfo):


PTQ_LAYERS_INFO = [
    LayerInfo(paddle.nn.Conv2D, ['Input'], ['Filter'], ['Output']),
    LayerInfo(paddle.nn.Linear, ['X'], ['Y'], ['Out']),
    LayerInfo(paddle.nn.BatchNorm2D, ['X'], [], ['Y']),
    LayerInfo(paddle.nn.AdaptiveMaxPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.AdaptiveAvgPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.AvgPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.MaxPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Swish, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']),
]

QUANT_LAYERS_INFO = [
    LayerInfo(
        paddle.nn.quant.quant_layers.QuantizedConv2D,
        ['Input'],
        ['Filter'],
        ['Output'],
    ),
    LayerInfo(
        paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], ['Out']
    ),
]

SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear]

PTQ_LAYERS_INFO stores the layers that currently support quantization, together with the names of their inputs, weights and outputs.
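In the table above, only Conv2D and Linear declare a weight name; every other entry has an empty weight list. Those two are exactly the classes in SIMULATED_LAYERS. The sketch below checks this on a reduced copy of the table and builds a class -> LayerInfo lookup. The classes are hypothetical stand-ins (Add stands in for paddle.nn.quant.add), so it runs without Paddle.

```python
from typing import List


class LayerInfo:
    def __init__(self, layer, input_names: List[str], weight_names: List[str], output_names: List[str]):
        self.layer = layer
        self.input_names = input_names
        self.weight_names = weight_names
        self.output_names = output_names


# Hypothetical stand-ins for a few paddle.nn layers.
class Conv2D:
    pass


class Linear:
    pass


class BatchNorm2D:
    pass


class ReLU:
    pass


class Add:  # stands in for paddle.nn.quant.add
    pass


PTQ_LAYERS_INFO = [
    LayerInfo(Conv2D, ['Input'], ['Filter'], ['Output']),
    LayerInfo(Linear, ['X'], ['Y'], ['Out']),
    LayerInfo(BatchNorm2D, ['X'], [], ['Y']),
    LayerInfo(ReLU, ['X'], [], ['Out']),
    LayerInfo(Add, ['X', 'Y'], [], ['Out']),
]

# Class -> LayerInfo lookup, the same shape PTQRegistry builds internally.
info_by_layer = {info.layer: info for info in PTQ_LAYERS_INFO}
print(info_by_layer[Add].input_names)  # ['X', 'Y']

# Layers that declare a weight name: the same set as SIMULATED_LAYERS.
layers_with_weights = [info.layer for info in PTQ_LAYERS_INFO if info.weight_names]
print([cls.__name__ for cls in layers_with_weights])  # ['Conv2D', 'Linear']
```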

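QUANT_LAYERS_INFO holds the LayerInfo entries supported by the newer quantization implementation (the QuantizedConv2D and QuantizedLinear layers in paddle.nn.quant.quant_layers). This approach is equivalent to torch's implementation based on nn.QuantModule.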

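SIMULATED_LAYERS stores the layers whose inputs and weights are both quantized. For these simulated-quantization layers, the distribution of the layer's input is collected during calibration. The weight distribution does not need to be collected, because the weights are fixed after training and their range can be computed directly from the weight values.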
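The sketch below illustrates why inputs and weights are treated differently. The input range has to be accumulated over the calibration batches, while the weight range can be read off once. It assumes the simplest statistic, a running abs-max. Paddle's PTQ offers several calibration algorithms, and AbsMaxInputObserver is a hypothetical name, not a Paddle class.

```python
from typing import Iterable, List


def abs_max(values: Iterable[float]) -> float:
    """Largest absolute value in a flat sequence."""
    return max(abs(v) for v in values)


class AbsMaxInputObserver:
    """Hypothetical observer: tracks the running abs-max of a layer's input
    over calibration batches (assumed statistic, for illustration only)."""

    def __init__(self) -> None:
        self.max_seen = 0.0

    def observe(self, batch: List[float]) -> None:
        # Inputs differ from batch to batch, so the statistic is accumulated.
        self.max_seen = max(self.max_seen, abs_max(batch))


# Weights are fixed after training: their range is computed directly, once.
weight = [0.5, -1.25, 0.75]
weight_abs_max = abs_max(weight)

observer = AbsMaxInputObserver()
for batch in ([0.1, -2.0], [3.5, 0.2], [-1.0, 1.0]):  # toy calibration data
    observer.observe(batch)

print(weight_abs_max)     # 1.25
print(observer.max_seen)  # 3.5
```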

"Simulated quantization" here presumably refers to fake quantization.
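Fake quantization quantizes a float tensor to an integer grid and immediately dequantizes it back. The result stays in float, but it can only take the values that the integer format could represent. Below is a minimal sketch assuming symmetric, per-tensor quantization with the scale set to the tensor's abs-max. This is a common scheme, but it is not necessarily the exact kernel Paddle uses. Python's round() rounds halves to even, and real kernels may round differently.

```python
from typing import List


def fake_quant(x: List[float], scale: float, bits: int = 8) -> List[float]:
    """Quantize to signed integers and immediately dequantize back to float.

    Values are mapped onto [-(2**(bits-1) - 1), 2**(bits-1) - 1]
    (symmetric, per-tensor; scale is assumed to be the abs-max).
    """
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits, 7 for 4 bits
    out = []
    for v in x:
        q = round(v / scale * qmax)   # quantize (round half to even)
        q = max(-qmax, min(qmax, q))  # clamp to the integer range
        out.append(q * scale / qmax)  # dequantize back to float
    return out


# With 4 bits, 0.3 snaps to 2/7; -2.0 is clamped to -scale.
y = fake_quant([0.0, 0.3, 1.0, -2.0], scale=1.0, bits=4)
print(y)
```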

PTQRegistry

PTQRegistry provides query access to the three global variables above.

class PTQRegistry(object):
    """
    Register the supported layers for PTQ and provide layers info.
    """

    supported_layers_map = {}
    registered_layers_map = {}
    is_inited = False

    def __init__(self):
        super().__init__()

    @classmethod
    def _init(cls):
        if not cls.is_inited:
            for layer_info in PTQ_LAYERS_INFO:
                cls.supported_layers_map[layer_info.layer] = layer_info

            all_layers_info = PTQ_LAYERS_INFO + QUANT_LAYERS_INFO
            for layer_info in all_layers_info:
                cls.registered_layers_map[layer_info.layer] = layer_info
        cls.is_inited = True

cls.supported_layers_map stores the contents of PTQ_LAYERS_INFO.

cls.registered_layers_map stores the contents of PTQ_LAYERS_INFO + QUANT_LAYERS_INFO.

Note that the keys here are nn.Layer subclasses (the classes themselves, not instances).
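Because the keys are classes, a plain `in` test on the dict only finds a class, not an instance of it. That is why each query method below combines `layer in map` with `isinstance(layer, tuple(map.keys()))`. A small sketch with hypothetical stand-in classes:

```python
class Conv2D:
    """Hypothetical stand-in for paddle.nn.Conv2D."""


class ReLU:
    """Hypothetical stand-in for paddle.nn.ReLU."""


# Keys are classes; the values would be LayerInfo objects (plain strings here).
registered_layers_map = {Conv2D: "conv info", ReLU: "relu info"}


def is_registered(layer) -> bool:
    # Class: found by dict membership. Instance: found by isinstance.
    return layer in registered_layers_map or isinstance(
        layer, tuple(registered_layers_map.keys())
    )


conv = Conv2D()
print(Conv2D in registered_layers_map)  # True: the class itself is a key
print(conv in registered_layers_map)    # False: an instance is not a key
print(is_registered(conv))              # True: caught by the isinstance branch
```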

The four query interfaces are listed below. They are straightforward, so I won't go into detail:

    @classmethod
    def is_supported_layer(cls, layer):
        """
        Analyze whether the layer supports quantization.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer is supported.
        """
        cls._init()
        return layer in cls.supported_layers_map or isinstance(
            layer, tuple(cls.supported_layers_map.keys())
        )

    @classmethod
    def is_registered_layer(cls, layer):
        """
        Analyze whether the layer has a registered layer_info.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer has a registered layer_info.
        """
        cls._init()
        return layer in cls.registered_layers_map or isinstance(
            layer, tuple(cls.registered_layers_map.keys())
        )

    @classmethod
    def is_simulated_quant_layer(cls, layer):
        """
        Analyze whether the layer is simulated quant layer.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer is a simulated quant layer.
        """
        return layer in SIMULATED_LAYERS or isinstance(
            layer, tuple(SIMULATED_LAYERS)
        )

    @classmethod
    def layer_info(cls, layer):
        """
        Get the information for the layer.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            layer_info(LayerInfo): The layer info of the input layer.
        """
        assert cls.is_registered_layer(
            layer
        ), "The input layer is not registered."

        for layer_key, layer_info in cls.registered_layers_map.items():
            if layer == layer_key or isinstance(layer, layer_key):
                return layer_info
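One consequence of these checks is easy to miss. An instance of a user subclass of a registered layer is found through isinstance. The subclass itself, passed as a class, is not found, because `layer == layer_key` compares class identity and `isinstance(SubClass, Conv2D)` is False for a class object (issubclass would be needed). The sketch below reproduces the is_registered_layer and layer_info logic on hypothetical stand-ins (Conv2D, MyConv2D) so it runs without Paddle.

```python
class LayerInfo:
    def __init__(self, layer, input_names, weight_names, output_names):
        self.layer = layer
        self.input_names = input_names
        self.weight_names = weight_names
        self.output_names = output_names


class Conv2D:
    """Hypothetical stand-in for paddle.nn.Conv2D."""


class MyConv2D(Conv2D):
    """Hypothetical user subclass of Conv2D."""


registered_layers_map = {Conv2D: LayerInfo(Conv2D, ['Input'], ['Filter'], ['Output'])}


def is_registered_layer(layer) -> bool:
    # Same two-part test as PTQRegistry.is_registered_layer.
    return layer in registered_layers_map or isinstance(
        layer, tuple(registered_layers_map.keys())
    )


def layer_info(layer) -> LayerInfo:
    # Same linear scan as PTQRegistry.layer_info: the first matching key wins.
    assert is_registered_layer(layer), "The input layer is not registered."
    for layer_key, info in registered_layers_map.items():
        if layer == layer_key or isinstance(layer, layer_key):
            return info


print(layer_info(MyConv2D()).weight_names)  # ['Filter']: instance found via isinstance
print(is_registered_layer(MyConv2D))        # False: subclass passed as a class
```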


Page updated: 2024-03-31
