深入研究 GRPO,發現了意外收獲。
原標題:DeepSeek用的GRPO占用大量內存?有人給出了些方法
文章來源:機器之心
內容字數:8253字
RTX 3080 移動版可訓練的大模型及GRPO訓練技巧
本文總結了使用RTX 3080移動版顯卡(16GB顯存)進行大型語言模型強化學習訓練的經驗,重點介紹了群組相對策略優化(GRPO)方法及其內存優化策略。
可訓練模型大小及方法選擇
作者使用GRPO方法,在RTX 3080移動版上進行訓練,發現模型大小和訓練方式對顯存需求影響很大。實驗在參數量從5億到140億不等的模型上進行,比較了全參數微調和參數高效微調(PEFT,使用LoRA)。全參數微調比PEFT需要更多內存。在H100上進行的實驗顯示,全參數微調所需的VRAM超過80GB。
GRPO的高內存需求原因
GRPO的高內存需求源于其內部涉及多個模型(策略模型、參考模型和獎勵模型),每個查詢都會產生多個輸出,導致內存占用迅速增加。即使獎勵模型非參數化,內存需求依然很高。
內存優化策略
為了降低內存占用,作者使用了兩種技術:8位優化器(例如8-bit AdamW)和梯度檢查點。8位優化器能更高效地存儲優化器跟蹤數據,而梯度檢查點則通過在訓練過程中拍攝快照來減少內存使用,雖然會降低訓練速度(約20-30%),但能顯著減少內存占用。
代碼示例及參數設置
作者提供了使用Hugging Face的trl庫進行GRPO訓練的代碼示例,該代碼簡潔易懂,適合小型模型(如meta-llama/Llama-3.2-1B-Instruct)和數據集(如openai/GSM8K)。文中詳細說明了各個參數(如`num_generations`、`batch_size`、`gradient_accumulation_steps`、`num_completions`、`max_prompt_length`、`max_completion_length`)對VRAM使用量的影響,并建議在內存瓶頸修復前使用`num_generations=4`。
VRAM使用量估算
作者給出了VRAM使用量的粗略估算方法,考慮了模型參數、梯度、優化器狀態等因素,并指出PEFT可以減少梯度的顯存占用。
實驗結果及結論
作者使用10億參數的Llama 3.2模型進行了完整訓練,結果顯示GRPO顯著提升了模型準確率(從19%提升到40.5%),展示了其強大潛力。
總而言之,本文為GPU資源有限的開發者提供了寶貴的GRPO訓練經驗,并通過內存優化策略和參數調整,幫助開發者在有限的硬件條件下訓練更大的模型。
聯系作者
文章來源:機器之心
作者微信:
作者簡介:專業的人工智能媒體和產業服務平臺