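PTQ (Post Training Quantization) Source Code Reading, Part 2
The previous post mentioned the PTQRegistry class. Its main job is to act as a dict that stores the nn.Layer -> LayerInfo mapping. Let's look at its implementation, starting with LayerInfo.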
class LayerInfo(object):
    """
    Store the argnames of the inputs and outputs.
    """
    def __init__(self, layer, input_names: List[TEXT],
                 weight_names: List[TEXT], output_names: List[TEXT]):
        super().__init__()
        self.layer = layer
        self.input_names = input_names
        self.weight_names = weight_names
        self.output_names = output_names
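It mainly stores an nn.Layer class together with the names of its inputs, weights, and outputs.
The global variables PTQ_LAYERS_INFO, QUANT_LAYERS_INFO, and SIMULATED_LAYERS collect the layers currently supported for quantization: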
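To make the structure concrete, here is a minimal, self-contained sketch that mirrors LayerInfo without depending on Paddle. `FakeConv2D` and `FakeAdd` are hypothetical stand-ins for `paddle.nn.Conv2D` and `paddle.nn.quant.add`, and `str` replaces the original `TEXT` annotation; the argument names are copied from the tables that follow. Note that `layer` holds the class itself, not an instance.

```python
from typing import List


class FakeConv2D:
    """Hypothetical stand-in for paddle.nn.Conv2D."""


class FakeAdd:
    """Hypothetical stand-in for paddle.nn.quant.add."""


class LayerInfo:
    """Store the argument names of a layer's inputs, weights, and outputs."""

    def __init__(self, layer, input_names: List[str],
                 weight_names: List[str], output_names: List[str]):
        self.layer = layer                # the layer class, not an instance
        self.input_names = input_names    # names of the input arguments
        self.weight_names = weight_names  # names of the weights (empty if none)
        self.output_names = output_names  # names of the outputs


conv_info = LayerInfo(FakeConv2D, ['Input'], ['Filter'], ['Output'])
add_info = LayerInfo(FakeAdd, ['X', 'Y'], [], ['Out'])

for info in (conv_info, add_info):
    has_weight = bool(info.weight_names)
    print(info.layer.__name__, info.input_names, info.weight_names, has_weight)
# FakeConv2D ['Input'] ['Filter'] True
# FakeAdd ['X', 'Y'] [] False
```

Layers without parameters (activations, pooling, add) simply carry an empty `weight_names` list, which is how the tables below encode them.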
PTQ_LAYERS_INFO = [
    LayerInfo(paddle.nn.Conv2D, ['Input'], ['Filter'], ['Output']),
    LayerInfo(paddle.nn.Linear, ['X'], ['Y'], ['Out']),
    LayerInfo(paddle.nn.BatchNorm2D, ['X'], [], ['Y']),
    LayerInfo(paddle.nn.AdaptiveMaxPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.AdaptiveAvgPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.AvgPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.MaxPool2D, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.ReLU, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.ReLU6, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Hardswish, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Swish, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Sigmoid, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Softmax, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.Tanh, ['X'], [], ['Out']),
    LayerInfo(paddle.nn.quant.add, ['X', 'Y'], [], ['Out']),
]
QUANT_LAYERS_INFO = [
    LayerInfo(
        paddle.nn.quant.quant_layers.QuantizedConv2D,
        ['Input'],
        ['Filter'],
        ['Output'],
    ),
    LayerInfo(
        paddle.nn.quant.quant_layers.QuantizedLinear, ['X'], ['Y'], ['Out']
    ),
]
SIMULATED_LAYERS = [paddle.nn.Conv2D, paddle.nn.Linear]
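PTQ_LAYERS_INFO stores the layers currently supported for quantization, together with their input, weight, and output names.
QUANT_LAYERS_INFO holds the LayerInfo entries for the newer quantization implementation, which uses dedicated quantized layer classes (QuantizedConv2D, QuantizedLinear). This approach is equivalent to torch's nn.QuantModule-based implementation.
SIMULATED_LAYERS lists the layers whose input and weight are quantized. For these simulated-quantization layers, the distribution of the layer input is collected. The weight distribution does not need to be collected, because the weights are fixed after training and their quantization parameters can be computed from them directly.
Simulated quantization here presumably refers to fake quantization.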
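Fake (simulated) quantization keeps tensors in floating point but rounds them onto the integer grid and maps them back, so the model sees the quantization error. The sketch below is a generic symmetric abs-max example in plain Python, not Paddle's implementation; the helper names `abs_max_scale` and `fake_quant`, the 8-bit setting, and the calibration data are all assumptions for illustration (Paddle's PTQ also offers other calibration methods). It shows why only the input distribution needs collecting: the weight scale comes straight from the fixed weights, while the input scale has to be estimated from calibration batches.

```python
from typing import List


def abs_max_scale(values: List[float], bits: int = 8) -> float:
    """Symmetric per-tensor scale: map max |x| to the largest signed int."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def fake_quant(values: List[float], scale: float, bits: int = 8) -> List[float]:
    """Quantize to the integer grid, clip, then dequantize back to float."""
    qmax = 2 ** (bits - 1) - 1
    out = []
    for v in values:
        q = round(v / scale)          # nearest integer level
        q = max(-qmax, min(qmax, q))  # clip to [-qmax, qmax]
        out.append(q * scale)         # back to float: "fake" quantized value
    return out


# Weights are fixed, so their scale is computed directly from the tensor.
weights = [0.5, -1.27, 0.01]
w_scale = abs_max_scale(weights)

# Inputs change per sample, so a running max of |x| is kept over
# several (hypothetical) calibration batches to derive the input scale.
calib_batches = [[0.2, -0.8], [1.5, 0.3], [-0.9, 0.1]]
running_max = 0.0
for batch in calib_batches:
    running_max = max(running_max, max(abs(v) for v in batch))
in_scale = running_max / 127

print(fake_quant(weights, w_scale))
```

Values that fall inside the calibrated range are off by at most half a quantization step; values outside it are clipped, which is why a representative calibration set matters for the input scale.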
PTQRegistry provides query access to the three global variables above.
class PTQRegistry(object):
    """
    Register the supported layers for PTQ and provide layers info.
    """
    supported_layers_map = {}
    registered_layers_map = {}
    is_inited = False
    def __init__(self):
        super().__init__()
    @classmethod
    def _init(cls):
        if not cls.is_inited:
            for layer_info in PTQ_LAYERS_INFO:
                cls.supported_layers_map[layer_info.layer] = layer_info
            all_layers_info = PTQ_LAYERS_INFO + QUANT_LAYERS_INFO
            for layer_info in all_layers_info:
                cls.registered_layers_map[layer_info.layer] = layer_info
            cls.is_inited = True
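cls.supported_layers_map stores the contents of PTQ_LAYERS_INFO.
cls.registered_layers_map stores the contents of PTQ_LAYERS_INFO + QUANT_LAYERS_INFO.
Note that the keys here are nn.Layer subclasses (the classes themselves, not instances). Both maps are filled lazily: _init() does the work only on its first call and then sets is_inited, so later calls return immediately.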
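The lazy, class-level initialization can be reproduced without Paddle. In this sketch `Conv2D`, `ReLU`, and `QuantizedConv2D` are hypothetical stand-in classes, `LayerInfo` is replaced by a `namedtuple` with the same fields, and `Registry` is a trimmed copy of PTQRegistry's `_init` logic. It shows that the supported map is a subset of the registered map and that a second `_init()` call changes nothing.

```python
from collections import namedtuple

# Lightweight stand-in for the LayerInfo class shown earlier.
LayerInfo = namedtuple(
    "LayerInfo", ["layer", "input_names", "weight_names", "output_names"])


# Hypothetical stand-ins for the Paddle layer classes.
class Conv2D:
    pass


class ReLU:
    pass


class QuantizedConv2D:
    pass


PTQ_LAYERS_INFO = [
    LayerInfo(Conv2D, ['Input'], ['Filter'], ['Output']),
    LayerInfo(ReLU, ['X'], [], ['Out']),
]
QUANT_LAYERS_INFO = [
    LayerInfo(QuantizedConv2D, ['Input'], ['Filter'], ['Output']),
]


class Registry:
    # Class attributes: shared by every caller, no instance needed.
    supported_layers_map = {}
    registered_layers_map = {}
    is_inited = False

    @classmethod
    def _init(cls):
        # Build both maps once, on first use (lazy initialization).
        if not cls.is_inited:
            for info in PTQ_LAYERS_INFO:
                cls.supported_layers_map[info.layer] = info
            for info in PTQ_LAYERS_INFO + QUANT_LAYERS_INFO:
                cls.registered_layers_map[info.layer] = info
            cls.is_inited = True


Registry._init()
Registry._init()  # second call is a no-op
print(sorted(k.__name__ for k in Registry.supported_layers_map))
# ['Conv2D', 'ReLU']
print(sorted(k.__name__ for k in Registry.registered_layers_map))
# ['Conv2D', 'QuantizedConv2D', 'ReLU']
```

Because the maps live on the class rather than on an instance, PTQRegistry is used purely through its classmethods and never needs to be instantiated.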
The four query interfaces are listed below; they are straightforward, so I won't go into much detail:
    @classmethod
    def is_supported_layer(cls, layer):
        """
        Analyze whether the layer supports quantization.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer is supported.
        """
        cls._init()
        return layer in cls.supported_layers_map or isinstance(
            layer, tuple(cls.supported_layers_map.keys())
        )
    @classmethod
    def is_registered_layer(cls, layer):
        """
        Analyze whether the layer has registered layer_info.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer has registered layer_info.
        """
        cls._init()
        return layer in cls.registered_layers_map or isinstance(
            layer, tuple(cls.registered_layers_map.keys())
        )
    @classmethod
    def is_simulated_quant_layer(cls, layer):
        """
        Analyze whether the layer is simulated quant layer.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            flag(bool): Whether the layer is a simulated quant layer.
        """
        return layer in SIMULATED_LAYERS or isinstance(
            layer, tuple(SIMULATED_LAYERS)
        )
    @classmethod
    def layer_info(cls, layer):
        """
        Get the information for the layer.
        Args:
            layer(Layer): The input layer can be a python class or an instance.
        Returns:
            layer_info(LayerInfo): The layer info of the input layer.
        """
        assert cls.is_registered_layer(
            layer
        ), "The input layer is not registered."
        for layer_key, layer_info in cls.registered_layers_map.items():
            if layer == layer_key or isinstance(layer, layer_key):
                return layer_info
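One consequence of the `layer in map or isinstance(layer, ...)` pattern is that an instance of a user subclass matches, while the subclass itself, passed as a class, does not: dict membership is an exact key match, and a class object is an instance of its metaclass (`type`), not of `Conv2D`. The stand-alone sketch below reproduces only that lookup logic, using hypothetical stand-in classes (`Layer`, `Conv2D`, `Linear`, `MyConv`, `Dropout`) and string placeholders instead of LayerInfo objects. It is an observation about the code as shown, not documented Paddle behavior.

```python
class Layer:
    """Hypothetical stand-in for paddle.nn.Layer."""


class Conv2D(Layer):
    """Hypothetical stand-in for paddle.nn.Conv2D."""


class Linear(Layer):
    """Hypothetical stand-in for paddle.nn.Linear."""


class MyConv(Conv2D):
    """A hypothetical user-defined subclass of Conv2D."""


class Dropout(Layer):
    """A layer that is not registered."""


# Keys are classes, as in registered_layers_map; values are placeholders.
registered_layers_map = {Conv2D: "conv-info", Linear: "linear-info"}


def is_registered_layer(layer):
    # Exact class match, or an instance of any registered class.
    return layer in registered_layers_map or isinstance(
        layer, tuple(registered_layers_map.keys())
    )


def layer_info(layer):
    assert is_registered_layer(layer), "The input layer is not registered."
    # First key (insertion order) equal to `layer` or of which it is an instance.
    for layer_key, info in registered_layers_map.items():
        if layer == layer_key or isinstance(layer, layer_key):
            return info


print(is_registered_layer(Conv2D))     # True: the class is a key
print(is_registered_layer(Conv2D()))   # True: instance of a registered class
print(is_registered_layer(MyConv()))   # True: isinstance sees the base class
print(is_registered_layer(MyConv))     # False: the subclass object is not a key
print(is_registered_layer(Dropout()))  # False
print(layer_info(MyConv()))            # conv-info
```

If class-level subclass matching were ever needed, `issubclass` would be the natural check for the class branch.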
Page updated: 2024-03-31