隨著深度學(xué)習(xí)的發(fā)展,其應(yīng)用場(chǎng)景也越發(fā)的廣泛與多樣。這些多樣化的場(chǎng)景往往會(huì)對(duì)實(shí)際的部署提出更加“定制化”的限制。例如,自動(dòng)駕駛汽車對(duì)人體識(shí)別的精度要求肯定比圖像識(shí)別動(dòng)物分類的精度要求更加嚴(yán)苛,因?yàn)槎叩膽?yīng)用場(chǎng)景和錯(cuò)誤預(yù)測(cè)帶來(lái)的后果截然不同。這些“定制化”帶來(lái)的差異,對(duì)于實(shí)際部署的模型在精度、速度、空間占用上有更具體的要求。在很多場(chǎng)景中由于部署的設(shè)備算力不強(qiáng)、內(nèi)存較小,導(dǎo)致對(duì)于模型的速度和空間占用具有嚴(yán)格要求,而經(jīng)過(guò)量化的模型具有速度快、空間占用小的特性,恰恰能滿足這種需求。
因此量化模型被廣泛使用在推理側(cè),量化也成為了一個(gè)重要且非?;钴S的研究領(lǐng)域。近期,MegEngine 開(kāi)源了 4 bits 的量化的相關(guān)內(nèi)容,通過(guò) MegEngine 4 bits 量化實(shí)現(xiàn)的 ResNet-50 模型在 ImageNet 數(shù)據(jù)集上的精度表現(xiàn)與 8 bits 量化模型相差無(wú)幾,并且速度是 TensorRT-v7 8 bits ResNet-50 模型的推理速度的 1.3 倍。這次實(shí)踐為 MegEngine 積累了 4 bits 量化的相關(guān)經(jīng)驗(yàn)。同時(shí),MegEngine 決定將 4 bits 量化的相關(guān)代碼開(kāi)源,為大家提供可參考的完整方案,推動(dòng)在更低比特推理領(lǐng)域的探索與發(fā)展。
背景
深度學(xué)習(xí)領(lǐng)域的模型量化是將輸入從連續(xù)或其他較大的值集約束到離散集的過(guò)程。量化具有以下兩點(diǎn)優(yōu)勢(shì):
在存儲(chǔ)空間上,相較于 FLOAT 的 32 bits 的大小,量化值占用的空間更小。
在性能上,各類計(jì)算設(shè)備對(duì)量化值的計(jì)算能力要高于 FLOAT 的計(jì)算能力。
本文中提到的 n bits 量化,就是將 FP32 的數(shù)據(jù)約束到 n bits 表示的整型數(shù)據(jù)的過(guò)程。量化依據(jù)數(shù)據(jù)的映射特征可以分為線性量化和非線性量化,MegEngine 中采用的是線性量化,使用的量化公式和反量化公式如下:
其中,Q 是量化方法,r 是真實(shí)獲取的輸入 FLOAT 值,S 是 FLOAT 類型的縮放因子,Z 是 INT 類型“零點(diǎn)”。
圖1 4 bits 非對(duì)稱線性量化
圖2 4 bits 對(duì)稱線性量化
如圖 1 所示,MegEngine 用數(shù)據(jù)類型 UINT4 表示 4 bits 的非對(duì)稱線性量化,量化值的取值范圍為[0,15];當(dāng) Z 取 0 時(shí)即為對(duì)稱線性量化,此時(shí) 4bits 量化值的取值范圍為[-8, 7],在 MegEngine 中用數(shù)據(jù)類型 INT4 表示,如圖 2 所示。
目前 8 bits 量化模型在一些場(chǎng)景下被業(yè)界廣泛運(yùn)用,我們想去了解 4 bits 量化模型的落地的可能性。這要解決兩個(gè)問(wèn)題:一方面,4 bits 量化模型的精度要如何保證;另一方面,4 bits 量化模型的速度能提升多少。要解答這兩個(gè)問(wèn)題,需要算法研究員和工程開(kāi)發(fā)人員的通力協(xié)作進(jìn)行驗(yàn)證。整件事情投入高,收益不明確。我們想找到開(kāi)源代碼,快速?gòu)脑韺用鎸?duì)這兩個(gè)問(wèn)題有個(gè)判斷,但經(jīng)過(guò)調(diào)研發(fā)現(xiàn)目前并沒(méi)有 4 bits 量化相關(guān)開(kāi)源內(nèi)容可供研究參考。所以,MegEngine 決定開(kāi)發(fā) 4 bits 量化并解答這兩方面的問(wèn)題。
緩解精度下降
保證 4 bits 量化模型的精度是重中之重,如果模型精度無(wú)法滿足需求,則 4 bits 量化的開(kāi)發(fā)將毫無(wú)意義。為了避免精度的大幅下降,MegEngine 采取的舉措是輸入和輸出采用非對(duì)稱量化 UINT4,weights 采用對(duì)稱量化 INT4,bias 采用 FP32。接下來(lái),從計(jì)算公式的推演上,來(lái)看這樣設(shè)計(jì)的合理性:
FP32 原始計(jì)算一次卷積輸出結(jié)果的公式:
結(jié)合公式 [1]、[2] 推導(dǎo)的 4 bits 量化的公式:
優(yōu)化之后的公式:
在上述公式中,ZI、ZW 是否等于 0,表明輸入/輸出和 weights 采用 INT4 還是 UINT4。并且在該公式中,除了Q(Ii)的值需要推理時(shí)確定,其余值均可在推理前獲得。所以,依據(jù)數(shù)據(jù)的計(jì)算特性,將這個(gè)公式分為了三個(gè)部分,分別用三種顏色表示:
黑色表示無(wú)論輸入/輸出以及 weights 數(shù)據(jù)類型如何選擇,一定有的計(jì)算量。因?yàn)闊o(wú)法避免,所以不用考慮這部分的數(shù)據(jù)特性。
藍(lán)色表示可以在推理前計(jì)算好的數(shù)據(jù)。
紅色表示必須在推理時(shí)才能計(jì)算的數(shù)據(jù)。
推理前可以計(jì)算好的這部分?jǐn)?shù)據(jù)可以提前計(jì)算并融合進(jìn) bias 中加入后續(xù)計(jì)算,所以 bias 必須用 FP32 數(shù)據(jù)類型表示,否則精度會(huì)大大降低。
至于輸入/輸出以及 weights 的數(shù)據(jù)類型選擇,結(jié)合上述公式可以推導(dǎo)得出:
全用 INT4 時(shí),即ZI、ZW 均等于 0, 計(jì)算量最小,只有黑色部分公式。
輸入/輸出用 UINT4,weights 用 INT4,即ZI 不等于 0,ZW 等于 0 時(shí),會(huì)增加藍(lán)色公式部分的計(jì)算量,但是這個(gè)部分是可以提前運(yùn)算好的,對(duì)整體計(jì)算時(shí)間影響不大。
weights 用 UINT4,即ZW 不等于 0 時(shí), 會(huì)增加紅色公式部分的計(jì)算量,會(huì)對(duì)整體的計(jì)算時(shí)間帶來(lái)較大影響。
由于 ResNet-50 模型 conv_relu 算子中的 relu 操作,輸入/輸出層的數(shù)據(jù)比較符合非對(duì)稱的特性,采用非對(duì)稱量化能更好地保留數(shù)據(jù)信息減少精度損失,所以輸入/輸出應(yīng)該選擇 UINT4,排除了上面三種方案中的第一種。第三種方案計(jì)算量會(huì)大很多,但是對(duì)精度的收益并不明顯。所以,最終選擇輸入和輸出采用非對(duì)稱量化 UINT4,weights 采用對(duì)稱量化INT4的方案。
緩解精度下降
提升模型性能并非一個(gè)簡(jiǎn)單的“因?yàn)橛?jì)算設(shè)備的 4 bits 算力大于 8 bits 算力,所以易知......”的推導(dǎo),計(jì)算設(shè)備 4 bits 算力大于 8 bits 算力是已知的,但是需要一些方法將這部分的算力“兌現(xiàn)”,算力需要合適的算子釋放出來(lái),其次,4 bits 量化所追求的也并非在某個(gè)算子的性能上超過(guò) 8 bits 量化,而是在模型層次超越 8 bits 量化??紤]到ResNet-50 模型以及卷積算子非常具有代表性,我們最終決定用 ResNet-50 模型作為基準(zhǔn)測(cè)試模型。經(jīng)過(guò)對(duì)模型的分析,發(fā)現(xiàn) ResNet-50 模型的性能瓶頸主要集中在兩個(gè)方面:
小算子比如 relu、add 較多,這些細(xì)瑣算子帶來(lái)的啟動(dòng)以及帶寬上的開(kāi)銷較大。
conv 計(jì)算非常多,占用了全圖 80% 以上的運(yùn)算時(shí)間。
為解決這兩方面的瓶頸,MegEngine 做了以下兩個(gè)方面的優(yōu)化工作:圖層次的算子融合以及算子層次的優(yōu)化。
算子融合優(yōu)化
MegEngine 通過(guò)對(duì)計(jì)算圖進(jìn)行掃描匹配,并將匹配到的圖結(jié)構(gòu)替換為優(yōu)化后的圖結(jié)構(gòu)。ResNet-50 模型所用的兩種 pass 轉(zhuǎn)換如下圖所示:
圖3 兩種Pass優(yōu)化方法
圖 3 中的大方塊表示圖中各種算子,小方塊表示這些算子的讀/寫(xiě)數(shù)據(jù)操作以及啟動(dòng)開(kāi)銷。從圖中可以看到經(jīng)過(guò)算子融合的優(yōu)化可以有效減少算子的讀/寫(xiě)數(shù)據(jù)的操作以及啟動(dòng)開(kāi)銷。
將這兩個(gè) pass 應(yīng)用于原始的 ResNet-50 的結(jié)構(gòu),就可以得到優(yōu)化后的圖。
圖4 Pass優(yōu)化在ResNet-50模型中的應(yīng)用
從圖 4 可以看到,通過(guò)對(duì) ResNet-50 模型的網(wǎng)絡(luò)結(jié)構(gòu)的優(yōu)化,add 和 relu 這些計(jì)算強(qiáng)度較小的算子已經(jīng)被 conv 這種計(jì)算強(qiáng)度大的算子所吸收,減少了小算子帶來(lái)的啟動(dòng)以及讀寫(xiě)上的開(kāi)銷。
conv 算子優(yōu)化
經(jīng)過(guò)算子融合優(yōu)化后,可以看到 ResNet-50 模型調(diào)用的算子主要是各種 conv fuse 的算子,如 Conv_Relu、Conv_Add_Relu,這些算子的主體部分都是 conv,所以主要的優(yōu)化也都落實(shí)在了 conv 算子優(yōu)化上。
conv 采用 implicit gemm 算法并通過(guò) mma 指令調(diào)度 tensor core 進(jìn)行計(jì)算加速。顧名思義,implicit gemm就是將 conv 運(yùn)算轉(zhuǎn)換為矩陣乘的一種算法,是對(duì) img2col 的算法的改進(jìn),傳統(tǒng)的 img2col 算法如下:
圖5 img2col示意圖
從圖 5 中可以看到,img2col 是將輸入 shape 為(N,IC,IH,IW),卷積核 shape 為(OC,IC,FH,FW)的卷積運(yùn)算變?yōu)?shape 分別為(OC,ICFHFW)和(ICFHFW,NOHOW)的兩個(gè)矩陣的乘法運(yùn)算。implict geem 的整體運(yùn)算邏輯與 img2col 相同,其區(qū)別在于 img2col 會(huì)“顯式”地完成圖 6 中數(shù)據(jù)的卷積排布到矩陣排布的轉(zhuǎn)換,需要額外開(kāi)辟一塊矩陣大小的空間用以存儲(chǔ)轉(zhuǎn)換后的矩陣,implict gemm 的轉(zhuǎn)換則是“隱式”的,沒(méi)有這部分空間開(kāi)銷,在 implicit gemm 算法中并沒(méi)有開(kāi)辟額外的空間存儲(chǔ)卷積核矩陣(OCxICFHFW)和輸入矩陣(ICFHFWxNOHOW),而是在分塊后,每個(gè) block 會(huì)按照上圖中的對(duì)應(yīng)邏輯,在 global memory 到 shared memory 的加載過(guò)程中完成從數(shù)據(jù)的原始卷積排布到 block 所需的矩陣分塊排布的轉(zhuǎn)換。
針對(duì) 4 bits 的 implict gemm 的優(yōu)化主要參照 cutlass 的優(yōu)化方案,并在此基礎(chǔ)上加入了 output 重排的優(yōu)化。由于篇幅問(wèn)題,本節(jié)僅講解 output 重排的優(yōu)化,想要了解更多技術(shù)細(xì)節(jié),建議參考閱讀之前的文章以及開(kāi)源代碼。
先分析 output 目前的排布情況,implict geem 的計(jì)算最終都落實(shí)在了 mma 指令上,而 mma 指令輸出的排布與 warp 中 32 個(gè)線程的關(guān)系如下:
圖6 mma輸出排布示意圖
如圖 6 中所示,在一次 mma 指令運(yùn)算中,一個(gè) warp 的 32 個(gè)線程負(fù)責(zé) 64 個(gè)運(yùn)算結(jié)果,且這些結(jié)果都存儲(chǔ)在寄存器上。每個(gè)線程負(fù)責(zé) 8x8 的結(jié)果矩陣同一行內(nèi)連續(xù)的兩個(gè)運(yùn)算結(jié)果,每四個(gè)線程負(fù)責(zé)同一行的 8 個(gè)運(yùn)算結(jié)果。
結(jié)合 implict geem 的結(jié)果矩陣 OCxNOHOW(由 OCxICFHFW 和 ICFHFWxNOHOW 乘積得到),在MegEngine 4 bits 量化的卷積算子設(shè)計(jì)中,一個(gè) warp 的 32 個(gè)線程和輸出的排布關(guān)系如下:
圖7 warp輸出排布示意圖
一個(gè) warp 負(fù)責(zé) 64x64 大小的輸出矩陣,該矩陣由 8x8 個(gè) mma 的 8x8 輸出矩陣組成,輸出和線程的排布關(guān)系如圖所示,黃色部分表示線程 0 所擁有的數(shù)據(jù)。圖 7 中的所有數(shù)據(jù)都在寄存器上,算子的最后一步操作,也就是將這些數(shù)據(jù)寫(xiě)回到 global memory 上并按照 NCHW64 的方式進(jìn)行排布。
一眼看上去,這些數(shù)據(jù)的排布都是間隔開(kāi)的,雖然橫坐標(biāo)上的數(shù)據(jù)連續(xù),但對(duì)于寫(xiě)回到 global memory 并按照 NCHW64 排布而言,并沒(méi)有什么幫助。直接的寫(xiě)回方式是將這些寄存器上的數(shù)據(jù)進(jìn)行壓縮,先將 8 個(gè)32 bits的數(shù)據(jù)轉(zhuǎn)換為 8 個(gè)4 bits 的數(shù)據(jù),再將這 8 個(gè) 4 bits 的數(shù)據(jù)放到一個(gè) 32 bits 大小的空間,然后寫(xiě)回到 global memory,這種處理方式將面臨幾個(gè)問(wèn)題:
每個(gè)線程中的數(shù)據(jù)都不連續(xù),增大了數(shù)據(jù)處理難度,這些額外的處理計(jì)算可能會(huì)導(dǎo)致性能下降。
需要在縱向的 8 個(gè)線程間交換數(shù)據(jù),會(huì)有同步的開(kāi)銷。
這無(wú)疑是一個(gè)開(kāi)銷比較大的處理方式,為了解決寫(xiě)回?cái)?shù)據(jù)帶來(lái)的性能問(wèn)題,MegEngine 采用了以下處理方式:
注意到 NCHW64 的排布方式,每 64 個(gè) OC 是連續(xù)的,嘗試將矩陣旋轉(zhuǎn)一下,想象這是一個(gè) NOHOWxOC 的矩陣,那么 T0、T1、T2、T3 四個(gè)線程所負(fù)責(zé)的數(shù)據(jù)在 OC 維度上是連續(xù)的,它們對(duì)于的 OC 維度分別是
T0{0,1; 8,9;16,17;24,25;32,33;40,41;48,49;56,57}、
T1{2,3;10,11;18,19;26,27;34,35;42,43;50,51;58,59}......
可以看到,現(xiàn)在是四個(gè)線程負(fù)責(zé) 64 個(gè)連續(xù)的輸出,那么只要這四個(gè)線程交換數(shù)據(jù)再壓縮、寫(xiě)回即可。相比于之前 8 個(gè)線程間數(shù)據(jù)交換和寫(xiě)回,現(xiàn)在的處理方式更加簡(jiǎn)單,內(nèi)部偏移計(jì)算與同步開(kāi)銷會(huì)更少。所以實(shí)現(xiàn)output轉(zhuǎn)置是一種切實(shí)可行的優(yōu)化方法。這也體現(xiàn)了 NCHW64 的排布方式使得 4 bits 類型的數(shù)據(jù)在傳輸過(guò)程能被連續(xù)訪存,充分利用硬件資源的特點(diǎn)。
但是線程間交換數(shù)據(jù)的開(kāi)銷在output轉(zhuǎn)置處理中依然沒(méi)有被徹底解決。如果可以得到
T0{0,1;2,3;4,5;6,7;8,9;10,11;12,13;14,15}、
T1{16,17;18,19;20,21;22,23;24,25;26,27;28,29;30,31}......
這樣的輸出OC 維度和線程對(duì)應(yīng)關(guān)系。那么就只需要在線程內(nèi)部進(jìn)行數(shù)據(jù)打包和寫(xiě)回,并且 16 個(gè)4 bits 的數(shù)據(jù)正好占用 2 個(gè)32 bits 大小的空間,非常規(guī)整。要實(shí)現(xiàn)這個(gè)效果也是非常簡(jiǎn)單的:對(duì)于 AxB=C 的矩陣乘法,要實(shí)現(xiàn) C 矩陣的列順序變換,只需要對(duì) B 矩陣進(jìn)行對(duì)應(yīng)的列順序變換即可,如下圖所示:
圖8 矩陣乘積的列變換
從圖 可以看出,將乘積矩陣 AxB=C 中的 B 矩陣的第1列和第5列進(jìn)行對(duì)調(diào),結(jié)果矩陣 C 對(duì)應(yīng)的列的運(yùn)算結(jié)果也會(huì)發(fā)生同步的對(duì)調(diào)。利用這一特點(diǎn),可以在 conv 算子運(yùn)算前,將 weights 的列進(jìn)行重排序,使得最終輸出OC 維度在對(duì)應(yīng)的相同線程中保持連續(xù),T0{0,1;2,3;4,5;6,7;8,9;10,11;12,13;14,15}...
所以總結(jié)一下 output 重排的策略,其實(shí)就兩點(diǎn):
將 OCxICFHFW 和 ICFHFWxNOHOW 的矩陣乘,變?yōu)?NOHOWxICFHFW 和 ICFHFWxOC 的矩陣乘,實(shí)現(xiàn)output 結(jié)果的轉(zhuǎn)置,確保在 OC 維度上的數(shù)據(jù)連續(xù),配合 NCHW64 的排布方式,便于將數(shù)據(jù)從寄存器上寫(xiě)回到 global memory 上。
通過(guò)對(duì) ICFHFWxOC 矩陣的 OC 進(jìn)行重新排序,實(shí)現(xiàn) output 矩陣 NOHOWxOC 的 OC 維度和線程的對(duì)應(yīng)關(guān)系更加合理,確保線程內(nèi)部的數(shù)據(jù)連續(xù)性,避免線程間數(shù)據(jù)交換的開(kāi)銷。
總結(jié) & 展望
本次開(kāi)源提供了和 TensorRT(TRT) ResNet-50 8 bits 量化模型在 ImageNet 數(shù)據(jù)集上速度以及精度對(duì)比結(jié)果:
圖9 速度對(duì)比
圖10 精度對(duì)比
通過(guò)在 ResNet50 上的測(cè)試可以看到,MegEngine 的 INT4 方案可以比 fp32 推理速度提升 5.65 倍至多,相比于現(xiàn)在業(yè)內(nèi)較為常用的 INT8 方案也仍然可以提升 1.3 倍的速度。在速度大幅提升的同時(shí),uint4*int4 的方案盡可能的保證了精度,精度下降能夠控制在 top1 -0.3% 左右。
在速度和精度兩方面的努力,讓 INT4 的方案能夠在實(shí)際的業(yè)務(wù)場(chǎng)景中帶來(lái)顯著的優(yōu)勢(shì),而不只是停留在論文上。
審核編輯 :李倩
-
開(kāi)源
+關(guān)注
關(guān)注
3文章
3407瀏覽量
42713 -
數(shù)據(jù)集
+關(guān)注
關(guān)注
4文章
1209瀏覽量
24834 -
量化
+關(guān)注
關(guān)注
0文章
34瀏覽量
2346
原文標(biāo)題:提速還能不掉點(diǎn)!深度解析 MegEngine 4 bits 量化開(kāi)源實(shí)現(xiàn)
文章出處:【微信號(hào):AI前線,微信公眾號(hào):AI前線】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
深度解析deepseek開(kāi)源是什么意思
蘋(píng)果開(kāi)源Swift Build,強(qiáng)化開(kāi)發(fā)者生態(tài)建設(shè)
玻璃通孔(TGV)技術(shù)深度解析
黃鶴開(kāi)源社區(qū)正式發(fā)布
深度解析研華全棧式AI產(chǎn)品布局
汽車智能化開(kāi)源創(chuàng)新論壇精選:OS是未來(lái)新興汽車產(chǎn)業(yè)生態(tài)構(gòu)建的核心
![汽車智能<b class='flag-5'>化開(kāi)源</b>創(chuàng)新論壇精選:OS是未來(lái)新興汽車產(chǎn)業(yè)生態(tài)構(gòu)建的核心](https://file1.elecfans.com/web2/M00/C4/8A/wKgZomX0EhWACv8DAAAUet8ikhs451.png)
論壇介紹 | RT-Thread出席汽車智能化開(kāi)源創(chuàng)新論壇
![論壇介紹 | RT-Thread出席汽車智能<b class='flag-5'>化開(kāi)源</b>創(chuàng)新論壇](https://file1.elecfans.com/web2/M00/C4/8A/wKgZomX0EhWACv8DAAAUet8ikhs451.png)
深度神經(jīng)網(wǎng)絡(luò)模型量化的基本方法
深度學(xué)習(xí)模型量化方法
![<b class='flag-5'>深度</b>學(xué)習(xí)模型<b class='flag-5'>量化</b>方法](https://file1.elecfans.com/web2/M00/FC/79/wKgZomaUkiyAe9zcAAAKurIXg1w692.png)
深度神經(jīng)網(wǎng)絡(luò)(DNN)架構(gòu)解析與優(yōu)化策略
I2S Master bits_per_sample != bits_per_chan情況下工作不正常是怎么回事?
存內(nèi)計(jì)算技術(shù)工具鏈——量化篇
![存內(nèi)計(jì)算技術(shù)工具鏈——<b class='flag-5'>量化</b>篇](https://file1.elecfans.com/web2/M00/E5/D3/wKgZomZEmdiAIdAqABnBLeia9y8361.png)
深度解析深度學(xué)習(xí)下的語(yǔ)義SLAM
![<b class='flag-5'>深度</b><b class='flag-5'>解析</b><b class='flag-5'>深度</b>學(xué)習(xí)下的語(yǔ)義SLAM](https://file1.elecfans.com/web2/M00/D6/82/wKgZomYnfe-ARm_pAAAcYiwkMFk951.png)
評(píng)論