從8-Bit到4-Bit。
原標題:4比特量化三倍加速不掉點!清華即插即用的SageAttention迎來升級
文章來源:機器之心
內容字數:6979字
清華大學陳鍵飛團隊提出SageAttention2:4-Bit即插即用注意力機制,實現3-4.5倍推理加速
本文總結了清華大學陳鍵飛團隊最新提出的SageAttention2論文要點。該工作在之前的SageAttention基礎上,進一步將注意力機制中的Q、K矩陣量化到INT4,實現了更高速的推理速度,并在多個大型模型上保持了端到端的精度。
1. 背景與挑戰
隨著大型模型序列長度的增加,注意力機制(Attention)的計算開銷成為瓶頸。雖然線性層的低比特量化已較為成熟,但注意力模塊大多仍使用高精度(FP16或FP32)運算。SageAttention率先將Attention中的QKT量化到INT8,取得了顯著加速效果。然而,INT8的矩陣乘法速度仍不及INT4,且FP16的乘法累加器加速僅在特定顯卡上有效。因此,將注意力機制量化到INT4成為進一步提升效率的關鍵,但也面臨巨大挑戰:直接量化到INT4會導致精度嚴重下降。
2. SageAttention2的技術方案
為了解決INT4量化帶來的精度損失問題,SageAttention2提出了以下技術方案:
對Q和K進行平滑處理: 在對K進行平滑處理的基礎上,新增對Q進行平滑處理(Q – mean(Q)),并補償到最終結果中,有效降低了量化誤差。
Per-thread量化: 將Q、K矩陣的量化粒度細化到per-thread級別,提高了4-Bit QKT乘法的精度,且不增加額外開銷。
FP32寄存器累加FP8 PV乘法結果: 使用FP32寄存器累加FlashAttention分塊粒度的PV的FP22乘法結果,避免了累積誤差。
使用E4M3格式的FP8: 實驗表明,E4M3格式的FP8精度最高,接近FP16。
可選的V矩陣平滑: 對V矩陣進行平滑處理,進一步提升PV矩陣乘法的精度。
3. 實驗結果
SageAttention2在速度和精度上都取得了顯著提升:
速度: 相比FlashAttention2和xformers分別實現了3倍和4.5倍的推理加速,在不同顯卡上均有不同程度的加速。
精度: 在視頻、圖像、文本生成等多種大型模型上保持了端到端的精度,例如在CogvideoX-1.5-5B模型上實現了1.8倍的端到端加速,且視頻質量無損。
4. 總結
SageAttention2通過一系列創新技術,成功地將注意力機制量化到INT4,實現了顯著的推理加速,并在多個大型模型上保持了精度。其即插即用的特性也方便了在實際應用中的部署,為大型模型的效率提升提供了有力支持。該工作已開源,方便開發者使用。
聯系作者
文章來源:機器之心
作者微信:
作者簡介:專業的人工智能媒體和產業服務平臺