1-bit大模型還能再突破!新一代BitNet架構(gòu)啟用4位激活值
新智元報(bào)道編輯:alan【新智元導(dǎo)讀】近日,BitNet系列的原班人馬推出了新一代架構(gòu):BitNet a4.8,為1 bit大模型啟用了4位激活值,支持3 bit KV cache,效率再突破。量化到1 bit的LLM還能再突破?這次,他們對(duì)激活值下手了!近日,BitNet系列的原班人馬推出了新一代架構(gòu):BitNet a4.8,為1 bit大模型啟用了4位激活值:論文地址:https://arxiv.org/pdf/2411.04965眾所周知,激活值量化通常是比較難辦的。本次的BitNet a4.8采用混合量化和稀疏化策略,來(lái)減輕異常通道引入的量化誤差。簡(jiǎn)單來(lái)說(shuō)就是,對(duì)注意力層和FFN層的輸入采用4位量化,同時(shí)用8位整數(shù)稀疏化中間狀態(tài)。大量實(shí)驗(yàn)表明,BitNet a4.8在相同的訓(xùn)練成本下,實(shí)現(xiàn)了與前代BitNet b1.58相當(dāng)?shù)男阅埽瑫r(shí)因?yàn)榭梢猿缘?位(INT4/FP4)內(nèi)核的計(jì)算紅利,實(shí)現(xiàn)了更快的推理速度。BitNet a4.8僅激活55%的參數(shù),并支持3 bit KV cache,進(jìn)一步提升了大規(guī)模LLM部署和推理的效率。BitNet a4.8模型架構(gòu)模型的整體架構(gòu)如圖1所示,BitNet a4.8采用了與BitNet b1.58相同的布局。作者使用BitLinear替換注意力(MHA)和前饋網(wǎng)絡(luò)(FFN)中的線(xiàn)性投影,以從頭開(kāi)始學(xué)習(xí)1.58 bit權(quán)重。對(duì)于激活值,采用混合量化和稀疏化策略來(lái)減輕異常值維度引入的誤差。圖2說(shuō)明了模型大小為7B的BitNet b1.58中,每個(gè)模塊輸入的分布。注意力層和FFN層的輸入通常類(lèi)似高斯分布,而在FFN下采樣之前的激活值和注意力中的輸出投影中,發(fā)現(xiàn)了很多異常值通道和大量接近零的條目(全精度LLM也有類(lèi)似觀察結(jié)果)。如圖3所示,直接將低位量化應(yīng)用于這些中間狀態(tài)會(huì)引入很大的量化誤差。因此,作者使用Q-Sparse的稀疏化方法,將這些中間狀態(tài)保持在8位(同時(shí)消除了計(jì)算瓶頸)。對(duì)于自注意層的輸出投影,使用sparsify-then-quantize函數(shù):兩個(gè)Q分別表示權(quán)重W和激活X的量化函數(shù),M是掩碼,根據(jù)激活X的絕對(duì)值取topK,⊙是元素乘法。具體來(lái)說(shuō),權(quán)重量化和激活值量化函數(shù)可以表述為:對(duì)于FFN,這里采用squared ReLU和門(mén)控線(xiàn)性單元(GLU)來(lái)進(jìn)一步提高激活的稀疏性:根據(jù)初步實(shí)驗(yàn)的結(jié)果,使用squared ReLU時(shí),下采樣輸入的稀疏性超過(guò)了80%,且對(duì)性能的影響最小。此外,作者還觀察到gate + squared ReLU的輸出也表現(xiàn)出高激活稀疏性(7B模型為67.5%)。通過(guò)首先計(jì)算gate projection,然后僅在非零通道上執(zhí)行up projection,可以進(jìn)一步減少推理的計(jì)算量。相比之下,attention和FFN的輸入中包含的異常值特征要少得多,可以使用absmean函數(shù)將激活值量化為4位整數(shù):模型訓(xùn)練初始化BitNet a4.8使用BitNet b1.58的權(quán)重開(kāi)始訓(xùn)練,分為W1.58A8與W1.58A4兩階段。第一階段使用8位激活和GLU + squared ReLU訓(xùn)練模型;第二階段采用上面介紹過(guò)的混合量化和稀疏化。BitNet a4.8只需少量訓(xùn)練,即可快速適應(yīng)4bit位寬和稀疏激活,同時(shí)性能損失可以忽略不計(jì)。梯度近似作者使用直通估計(jì)器(STE)對(duì)BitNet a4.8進(jìn)行梯度逼近,使用混合精度訓(xùn)練來(lái)更新參數(shù)。這里直接繞過(guò)了不可微函數(shù),包括反向傳播過(guò)程中的量化函數(shù)和topK稀疏函數(shù)。對(duì)于混合精度訓(xùn)練,保持全精度latent weight來(lái)累積參數(shù)更新。模型量化浮點(diǎn)量化提供了比基于整數(shù)的量化更寬的動(dòng)態(tài)范圍,這對(duì)于處理激活值的長(zhǎng)尾分布至關(guān)重要。研究人員將FFN下采樣層的輸入保留為8位整數(shù),其他激活值使用MinMax量化器量化為FP4:公式中E和M分別表示指數(shù)和尾數(shù)部分的位寬。這里采用E2M1格式,因?yàn)樗膭?dòng)態(tài)范圍更大。實(shí)驗(yàn)本文將BitNet a4.8、BitNet b1.58,以及各種參數(shù)量大小的FP16精度LLaMA進(jìn)行了比較。其中的1.58 bit模型,遵循BitNet b1.58的訓(xùn)練方案,采用了兩階段權(quán)重衰減和學(xué)習(xí)率調(diào)度。所有模型都使用RedPajama數(shù)據(jù)集中的100B token進(jìn)行訓(xùn)練,以確保公平比較。對(duì)于BitNet a4.8,作者首先使用95B token來(lái)訓(xùn)練8位激活值的模型。然后重用優(yōu)化器狀態(tài),并使用5B token進(jìn)行混合量化和稀疏化的訓(xùn)練。實(shí)驗(yàn)將topK設(shè)置為50%(attention的輸出投影位置)。作者使用lm-evaluation-harness工具包,評(píng)估模型在一系列語(yǔ)言任務(wù)上的zero-shot準(zhǔn)確性,包括ARC-Easy(ARCe)、ARCChallenge(ARCc)、Hellaswag(HS)、Winogrande(WGe)和PIQA(PQ)。另外還測(cè)試了在C4數(shù)據(jù)集(測(cè)試集)上的困惑度。主要結(jié)果表1總結(jié)了BitNet a4.8、BitNet b1.58和FP16 LLaMA的詳細(xì)測(cè)試結(jié)果。全精度(FP16)LLaMA和BitNet b1.58之間的性能差距,隨著模型大小的增長(zhǎng)而縮小。對(duì)于7B模型,BitNet b1.58在語(yǔ)言模型困惑度和任務(wù)的平均準(zhǔn)確性方面與LLaMA相當(dāng)。此外,相比于BitNet b1.58,BitNet a4.8的平均精度幾乎沒(méi)有損失。表2展示了各種大小的BitNet a4.8、BitNet b1.58 和 FP16 LLaMA中每個(gè)模塊的詳細(xì)稀疏性(使用C4驗(yàn)證集上的非嵌入?yún)?shù)計(jì)算)。值得注意的是,BitNet a4.8的稀疏性明顯高于BitNet b1.58和LLaMA。比如在7B模型中,BitNet a4.8的整體稀疏性達(dá)到了44.5%,只有3.4B的活躍參數(shù)。down projection層的輸入顯示出特別高的稀疏性,且中間狀態(tài)分布以零為中心。此外,gate projection的輸出非常稀疏,導(dǎo)致了up projection的高稀疏性(因?yàn)橹恍枰趶腉ate中選擇非零通道來(lái)執(zhí)行投影)。具體來(lái)說(shuō),對(duì)于7B BitNet a4.8,Gate和up projection的稀疏率分別為67.5%和12.0%。表3顯示了BitNet a4.8在3B和7B模型大小下,low-bit attention的詳細(xì)情況。模型使用4位KV或QKV頭,精度損失可忽略不計(jì),同時(shí)KV cache可以量化為3位整數(shù)。low-bit attention對(duì)于高效的長(zhǎng)序列建模至關(guān)重要,它減少了KV cache的內(nèi)存占用和IO,并加速了注意力計(jì)算。在本文的實(shí)驗(yàn)中,作者采用RoPE后量化。使用absmax函數(shù)將QKV頭直接量化為無(wú)符號(hào)整數(shù),無(wú)需任何校準(zhǔn)數(shù)據(jù)集。對(duì)于3 bit KV量化,研究人員將bos token的頭保留為4 bit,因?yàn)樗嗟漠惓V堤卣鳌O趯?shí)驗(yàn)圖4顯示了700M BitNet a4.8的訓(xùn)練損耗曲線(xiàn),比較了使用完整的INT4/FP4量化,以及本文的混合量化和稀疏化。完整的INT4量化會(huì)導(dǎo)致發(fā)散,而混合架構(gòu)在訓(xùn)練困惑度方面明顯優(yōu)于完整的FP4架構(gòu)。使用RedPajama數(shù)據(jù)集中25B token,來(lái)進(jìn)行模型的第一階段訓(xùn)練,采用absmean和MinMax量化器分別進(jìn)行完整的INT4和FP4量化。對(duì)于完整的INT4量化,由于其輸入具有更大的異常值,這里設(shè)置β = 2*mean(|X|)。接下來(lái)為1.3B BitNet a4.8的down projection層輸入,設(shè)置不同的量化或激活函數(shù)。所有模型都使用RedPajama數(shù)據(jù)集中的50B token進(jìn)行第一階段訓(xùn)練。為了確保公平比較,其他激活值都保留在8位。圖5顯示了這些模型的訓(xùn)練損失曲線(xiàn)。Squared ReLU的訓(xùn)練困惑度比Swish略好,同時(shí)實(shí)現(xiàn)了更高的稀疏性。此外,對(duì)down projection的輸入應(yīng)用FP4量化會(huì)導(dǎo)致性能顯著下降,而將INT4激活與STE一起使用會(huì)導(dǎo)致發(fā)散。參考資料:https://arxiv.org/abs/2411.04965https://venturebeat.com/ai/how-microsofts-next-gen-bitnet-architecture-is-turbocharging-llm-efficiency/