<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>
"); //-->

博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 征程 6EM 常見(jiàn) QConfig 配置解讀與示例

征程 6EM 常見(jiàn) QConfig 配置解讀與示例

發(fā)布人:地平線(xiàn)開(kāi)發(fā)者 時(shí)間:2025-06-01 來(lái)源:工程師 發(fā)布文章

一、引言

在工具鏈用戶(hù)手冊《量化感知訓練(QAT)-開(kāi)發(fā)指南-QConfig 詳解》章節專(zhuān)門(mén)介紹了在 J6EM 上 qconfig 是怎么回事,從經(jīng)歷看,大家可能會(huì )存在看了依舊不懂,或懂了不知道怎么配置的情況,特別是一些 OE 包中示例沒(méi)有的配置,例如固定某節點(diǎn) scale、配置 linear weight int16 等操作。

qconfig 控制了模型所有節點(diǎn)的量化類(lèi)型,例如是采用 int8 還是 int16 量化,是固定校準階段的 scale 去 qat 還是不固定 scale 去 qat。

提供的模板可分為三類(lèi):基礎模板、敏感度模板、自定義模板。本文將常見(jiàn)配置通過(guò)示例方式進(jìn)行呈現。

二、基礎模板

基礎模板中 calibration / qat / qat_fixed_act_scale 區別在于使用的 observer 類(lèi)型和 scale 更新邏輯,分別用于校準,不固定 activation scaleqat 訓練,固定 activation scale qat 訓練。

default 模板 ( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter ) 會(huì )做三件事:

首先,將可以設置的高精度輸出都設置上,對于不支持高精度的輸出將給出提示;

然后,從 grid sample 算子的 grid 輸入向前搜索,直到出現第一個(gè) gemm 類(lèi)算子或者 QuantStub,將中間的所有算子都設置為 int16。根據經(jīng)驗這里的 grid 一般表達范圍較寬,int8 有較大可能不滿(mǎn)足精度需求;

最后,將其余算子設置為 int8。

int16 模板 ( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter ) 會(huì )做兩件事:

首先,將可以設置的高精度輸出都設置上,對于不支持高精度的輸出將給出提示;

其次,將其余算子設置為 int16。

from horizon_plugin_pytorch.quantization.qconfig_template import (
   default_calibration_qconfig_setter,
   default_qat_qconfig_setter,
   default_qat_fixed_act_qconfig_setter,
   qat_8bit_weight_16bit_act_qconfig_setter,
   qat_8bit_weight_16bit_fixed_act_qconfig_setter,
   calibration_8bit_weight_16bit_act_qconfig_setter,
)
qat_or_calib_model = prepare(
   float_model,
   example_inputs=example_inputs,  # 用來(lái)感知圖結構
   qconfig_setter=(

       default_qat_qconfig_setter,    # 根據需要配置setter模板
   ),
)

三、敏感度模板

敏感度模板有三個(gè):

sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter
sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter
sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter

三者的區別和基礎模板中三者的區別類(lèi)似,也是分別用于校準,不固定 activation scale qat 訓練,固定 activation scale qat 訓練。

敏感度模板的第一個(gè)輸入是精度 debug 工具產(chǎn)生的敏感度結果,第二個(gè)參數可以指定 ratio 或 topk,敏感度模板會(huì )根據配置,將量化敏感度最高的 topk 個(gè)算子設置為 int16。搭配固定模板,可以實(shí)現混合精度調優(yōu)。

若模型有多個(gè)輸出,每個(gè)輸出都會(huì )產(chǎn)生一個(gè)敏感度表,您可以設置多個(gè)敏感度模版。示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (
   default_calibration_qconfig_setter,
   sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
   sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
   sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
)

# 這兩個(gè)pt文件是通過(guò)debug工具得到的
table1 = torch.load("output_0-0_L1_sensitive_ops.pt")
table2 = torch.load("output_0-1_L1_sensitive_ops.pt")

calibration_model = prepare(
   float_model,
   example_inputs=example_input,
   qconfig_setter=(
       sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table1, ratio=0.2),
       sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter(table2, ratio=0.2),
       default_calibration_qconfig_setter,
   ),
)

四、自定義模板

自定義模板為 ModuleNameQconfigSetter,需要傳入模塊名和對應自定義的 qconfig,一般用于設置 fixed scale、配置 linear weight int16 等特殊需求,可以和固定模板,敏感度模板搭配使用。示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (
   calibration_8bit_weight_16bit_act_qconfig_setter,
   ModuleNameQconfigSetter,
)
from horizon_plugin_pytorch.quantization.qconfig import (
   get_qconfig,
   MSEObserver,
   MinMaxObserver,
   FixedScaleObserver,
   QConfig,
)
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize

# 手動(dòng)設置某個(gè)算子的輸出scale
op_name_output_fix_scale_qconfig = QConfig(
   output=FakeQuantize.with_args(
       observer=FixedScaleObserver,
       dtype=qint16,
       scale=0.0625,
   )
)

# 設置某個(gè)算子weight與輸出activation的量化類(lèi)型
# 校準時(shí)用MSEObserver,qat時(shí)用MinMaxObserver
# 沒(méi)有weight的算子,配置了weight_dtype也不會(huì )起作用
calib_weight_act_both_int16_qconfig = get_qconfig(
   observer=MSEObserver,
   weight_dtype=qint16,
   out_dtype=qint16,
)

calib_weight_act_both_int8_qconfig = get_qconfig(
   observer=MSEObserver,
   weight_dtype=qint8,
   out_dtype=qint8,
)

qat_weight_act_both_int16_qconfig = get_qconfig(
   observer=MinMaxObserver,
   weight_dtype=qint16,
   out_dtype=qint16,
   fix_scale=True,    # 是否固定scale
)

放在一塊簡(jiǎn)單示例如下:

from horizon_plugin_pytorch.quantization.qconfig_template import (
   default_qat_qconfig_setter,
   sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
   ModuleNameQconfigSetter,
)

table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

# 自動(dòng)替換生成的算子只能通過(guò) ModuleNameQconfigSetter 配置自定義 qconfig。
module_name_to_qconfig = {
   "_generated_add_0": op_name_output_fix_scale_qconfig ,
}

qat_model = prepare(
   float_model,
   example_inputs=example_input,
   qconfig_setter=(
       ModuleNameQconfigSetter(module_name_to_qconfig),
       sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
       default_qat_qconfig_setter,
   ),
)

五、可運行的示例

將網(wǎng)絡(luò )中 linear2 的 weight 配置為 int16 量化、輸入配置為 int8 量化、輸出配置為 int16 量化,其他算子激活使用 int16 量化,weight 使用 int8 量化。

import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn


# 定義網(wǎng)絡(luò )結構
class SmallModel(nn.Module):
   def __init__(self):
       super(SmallModel, self).__init__()
       # 第一個(gè) Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 256]
       self.linear1 = nn.Linear(256, 256)
       self.layernorm = nn.LayerNorm(256)  # 對最后一維進(jìn)行歸一化
       self.relu = nn.ReLU()
       # 第二個(gè) Linear: 輸入 [2, 100, 256] -> 輸出 [2, 100, 60]
       self.linear2 = nn.Linear(256, 60)
       # 第三個(gè) Linear: 輸入 [2, 100, 60] -> 輸出 [2, 100, 60]
       self.linear3 = nn.Linear(60, 60)
       self.quant = QuantStub()
       self.dequant = DeQuantStub()

   def forward(self, x):
       x = self.quant(x)
       # 第一個(gè) Linear
       x = self.linear1(x)  # [2, 100, 256]
       x = self.layernorm(x)  # [2, 100, 256]
       x = self.relu(x)  # [2, 100, 256]
       # 第二個(gè) Linear
       x = self.linear2(x)  # [2, 100, 60]
       # 第三個(gè) Linear
       x = self.linear3(x)
       x = self.dequant(x)
       return x

example_input = torch.randn(2, 100, 256)
model = SmallModel()

# 前向傳播
output = model(example_input)
print("輸出形狀:", output.shape)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_M)

calib_weight_act_both_int16_qconfig = get_qconfig(
   observer=MSEObserver,
   weight_dtype=qint16,
   out_dtype=qint16,
)

# layernorm沒(méi)有weight,配置了weight_dtype也不會(huì )起作用
calib_weight_act_both_int8_qconfig = get_qconfig(
   observer=MSEObserver,
   weight_dtype=qint8,
   out_dtype=qint8,
)

qat_weight_act_both_int16_qconfig = get_qconfig(
   observer=MinMaxObserver,
   weight_dtype=qint16,
   out_dtype=qint16,
   fix_scale=True,
)
# 節點(diǎn)名稱(chēng),可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
module_name_to_qconfig = {
   "layernorm": calib_weight_act_both_int8_qconfig,
   "linear2": calib_weight_act_both_int16_qconfig,  
}

calib_model = prepare(model.eval(), example_input,
                     qconfig_setter=(
                         ModuleNameQconfigSetter(module_name_to_qconfig),
                         calibration_8bit_weight_16bit_act_qconfig_setter,
                         ),
                     )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)

qat_bc = export(calib_model, example_input)

配置 add 單算子輸入和輸出均使用固定 scale

import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_E)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver, FixedScaleObserver, QConfig
from horizon_plugin_pytorch.quantization.fake_quantize import FakeQuantize
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn


class AddNet(nn.Module):
   def __init__(self):
       super(AddNet, self).__init__()
       self.quant_x = QuantStub()
       self.quant_y = QuantStub()
       self.dequant = DeQuantStub()

   def forward(self, x, y):
       x = self.quant_x(x)
       y = self.quant_y(y)
       z = torch.add(x, y)
       z = self.dequant(z)
       return z

# 創(chuàng )建模型
model = AddNet()

# 生成兩個(gè)相同形狀的輸入張量
torch.manual_seed(42)
x = torch.randn(1, 1, 2, 6)
y = torch.randn(1, 2, 2, 6)
example_input = (x,y)

# 前向傳播
output = model(example_input[0], example_input[1])
print("float輸出數據:", output)
print("輸入形狀:", example_input[0].shape)
print("輸出形狀:", output.shape)

# A global march indicating the target hardware version must be setted before prepare qat.
set_march(March.NASH_E)

add_input_fix_scale_qconfig = QConfig(
   output=FakeQuantize.with_args(
       observer=FixedScaleObserver,
       dtype=qint16,
       scale=0.03125,
   )
)
add_output_fix_scale_qconfig = QConfig(
   output=FakeQuantize.with_args(
       observer=FixedScaleObserver,
       dtype=qint16,
       scale=0.0625,
   )
)

# 節點(diǎn)名稱(chēng),可以從model_check_result.txt中獲取,也可以從敏感度文件中獲取
module_name_to_qconfig = {
   "quant_x": add_input_fix_scale_qconfig,

   "quant_y": add_input_fix_scale_qconfig,

   "_generated_add_0": add_output_fix_scale_qconfig,
}

calib_model = prepare(model.eval(), example_input,
                     qconfig_setter=(
                         ModuleNameQconfigSetter(module_name_to_qconfig),
                         calibration_8bit_weight_16bit_act_qconfig_setter,
                         ),
                     )

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input[0], example_input[1])

calib_model.eval()                            
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input[0], example_input[1])
print("calib輸出數據:", calib_out)

qat_bc = export(calib_model, example_input)

六、凍結部分網(wǎng)絡(luò )結構 qat 的配置

補充常見(jiàn)凍結網(wǎng)絡(luò )結構,去進(jìn)行 qat 的做法

from horizon_plugin_pytorch.quantization import (
   QuantStub,
   prepare,
   set_fake_quantize,
   FakeQuantState,
)
#prepare QAT模型
qat_model = prepare(
   model,
   example_inputs=xxx,
   qconfig_setter=(
       xxx,
   )
)
#加載calib權重
qat_model.load_state_dict(torch.load("calib-checkpoint.ckpt"))
#QAT訓練
qat_model.train()
#固定backbone部分的權重,requires_grad不影響drop bn的行為,需要與eval聯(lián)合用
for param in qat_model.backbone.parameters():
   param.requires_grad = False
#固定backbone部分的scale,eval只影響drop bn的行為,如果發(fā)生了backward仍然會(huì )改變權重,需要與requires_grad聯(lián)合使用
qat_model.backbone.eval()
set_fake_quantize(qat_model.backbone, FakeQuantState.VALIDATION)
#配置head的FakeQuant為QAT狀態(tài)
set_fake_quantize(qat_model.head, FakeQuantState.QAT)


*博客內容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權請聯(lián)系工作人員刪除。




相關(guān)推薦

技術(shù)專(zhuān)區

關(guān)閉
国产精品自在自线亚洲|国产精品无圣光一区二区|国产日产欧洲无码视频|久久久一本精品99久久K精品66|欧美人与动牲交片免费播放
<dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><small id="yhprb"></small><dfn id="yhprb"></dfn><small id="yhprb"><delect id="yhprb"></delect></small><small id="yhprb"></small><small id="yhprb"></small> <delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"></dfn><dfn id="yhprb"></dfn><s id="yhprb"><noframes id="yhprb"><small id="yhprb"><dfn id="yhprb"></dfn></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><small id="yhprb"></small><dfn id="yhprb"><delect id="yhprb"></delect></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn> <small id="yhprb"></small><delect id="yhprb"><strike id="yhprb"></strike></delect><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn><dfn id="yhprb"><s id="yhprb"><strike id="yhprb"></strike></s></dfn><dfn id="yhprb"><s id="yhprb"></s></dfn>