AIxiv专栏是人工智能站发布学术、技术内容的栏目。过去数年,人工智能站AIxiv专栏接收报道了2000多篇内容,覆盖全球各大高校与企业的顶级实验室,有效促进了学术交流与传播。如果您有优秀的工作想要分享,欢迎投稿或者联系报道。投稿邮箱:liyazhou@jiqizhixin.com;zhaoyunfeng@jiqizhixin.com
项目核心开发者 Haosheng Zou 本科毕业于清华大学电子系,博士毕业于清华大学计算机系朱军教授组,目前在 360 智脑从事长文本和强化学习等后训练工作。开发者 Xiaowei Lv 目前在人民大学信息学院研二在读。Fenrui Xiao、Junchen Liu、Qi An 和 Xiaodong Sun 等在开发测试中亦有贡献。
大模型长序列的处理能力已越来越重要,像复杂长文本任务、多帧视频理解任务、以及 OpenAI 近期发布的 o1、o3 系列模型的高计算量模式,需要处理的输入 + 输出总 token 数从几万量级上升到了几百万量级。面对模型日益增长的长序列需求,在预训练(Pre-Training)和后训练(Post-Training)阶段,所用的平台框架都需要支持更长序列数据的训练。不同于预训练阶段基于 Megatron-LM 定制开发的常见选择,后训练阶段因后训练算法的多样性(比如仅 DPO 就有几十个变种)和训练需求的灵活性,至今没有一个框架同时在并行策略、后训练算法、GPU 显存优化和简单易用这 4 个方面上全部做到兼容并包。
在所有开源的后训练框架中,LLaMA-Factory 是用户最多的框架之一(GitHub star 数已 37k 多),保持长期迭代更新,支持丰富的模型和后训练算法,有各种 GPU 显存优化技巧和简单易用的方式。然而,LLaMA-Factory 在长序列后训练上支持仍有所欠缺,尚不支持长序列的关键技术 —— 序列并行。
项目主页:https://github.com/Qihoo360/360-LLaMA-Factory
最近,360 智脑基于 LLaMA-Factory 开源了 360-LLaMA-Factory,加入了序列并行功能,一行代码即可支持任意长序列的后训练(Post-Training)—— 仅需额外指定序列并行一个参数:
sequence_parallel_size: 16
按需增加序列并行的 GPU 卡数,即可在任意长度的序列上 SFT 或 DPO。
360-LLaMA-Factory 的实现经过了严格的正确性验证,已在主仓 Pull Request 中审核过。正式合并进 LLaMA-Factory 主仓之前,可先使用 360-LLaMA-Factory。
360 智脑早在 2023 年就开始了长文本大模型的研发,到目前为止已经成功应用于开源并更新了两个版本的 360Zhinao-7B-Chat-360k 模型,以及近日发布的长思维链推理模型 360gpt2-o1。在 360-LLaMA-Factory 中,我们将 360 智脑内部长序列后训练能力系统性地整合进了 LLaMA-Factory 中,用户仅需额外添加一行代码,即可进行理论上任意长度的长序列后训练(增加序列并行的 GPU 卡数即可):
sequence_parallel_size: 16
在原先使用 LLaMA-Factory 的基础上,只需额外增加一个参数
通过这种方式,360-LLaMA-Factory 将 LLaMA-Factory 的序列并行也做到了简单易用和兼容并包,和 LLaMA-Factory 的其他功能完全兼容。
粗粒度地测试 8 卡 80G 的全参数后训练(不考虑除了 zero3-offload 和 gradient checkpointing 外的任何优化技巧),360-LLaMA-Factory 至少可以训到 SFT 210k (7B) / 128k (72B) 和 DPO 84k (7B) / 46k (72B)。若加上注掉 logits = logits.float () 和 DPO 预计算等技巧,2 卡序列并行即可解决诸多常见的训练需求。360-LLaMA-Factory 让序列并行也真正成为了简单好用、效果也好的后训练工具。
作为开源社区的一份子,360-LLaMA-Factory 离不开 LLaMA-Factory、ring-flash-attention 和 EasyContext 等开源项目的开创性工作,我们的底层开发部分依赖了这些工作,但也有我们自己在具体实现方式上的不同和见解。我们相信我们的代码实现已做到尽可能好的模块化和尽可能少的原始代码修改,且严格检查过正确性,因此也已向 LLaMA-Factory 主仓提交了 Pull Request,初步审核通过。我们乐于同开源社区共建完善这项工作。
随着大模型训练数据长度的增长,预训练和后训练平台框架都需要支持长序列数据训练。
预训练阶段,英伟达的 Megatron-LM 凭借丰富高效的并行策略与出色的 GPU 显存优化,成为主流框架,基于它的定制开发往往是最通用的解法, Megatron-LM 本身已实现了序列并行(Megatron-LM 称之为 context parallelism,其他工作一般称为 sequence parallelism)。
后训练阶段情况相对复杂。后训练算法多样,如 DPO 就有诸多变种,且训练需求灵活多变,不同场景对算法、资源、并行性等要求各异。因此,至今没有一个框架能在并行策略、后训练算法、GPU 显存优化和易用性这四个关键方面做到近乎完美的兼容。虽有框架在部分方面表现尚可,但总体仍存在短板,这也限制了模型在长序列数据后训练上的进一步发展。
长序列后训练面临的关键瓶颈是:序列长度增加时,激活显存会大幅上升。虽然有 unsloth、liger kernel、LoRA 等多种降低显存占用的技巧,但均未从根本上解决序列长度增加的本质问题,其效果存在明确上限。
序列并行(sequence parallelism)被认为是解决长序列训练问题的通解,它通过把一条长序列切分到不同的显卡上进行计算,从而避免了每张显卡处理过长的序列,从根本上解决了 “每张显卡处理的序列长度增加” 的问题。然而,序列并行的实现难度较大,需要在切分后的序列之间进行通信计算 attention,需要侵入修改原始的 attention 函数。在开源的 Megatron-LM 中,序列并行也是所有并行策略中最后才添加的,LLaMA-Factory 之前还没有支持序列并行。
我们调研了其他一些支持序列并行的开源框架,有些实现上有错或小 bug、导致支持的后训练算法不全;有些更新维护不及时、训练较新的模型不方便、显示进度条等易用性不足。有的与 LLaMA-Factory 相比继承依赖更少,支持功能较少但更干净、更适合定制开发,有不同的使用场景。此外,各家的序列并行具体实现也不尽相同。详见下面的表 1 和 GitHub README,有未调研到的也请包涵并联系 360-LLaMA-Factory。
360-LLaMA-Factory 系统性地为 LLaMA-Factory 增加了序列并行的支持。以下将简要介绍 360-LLaMA-Factory 框架中的模块化修改和执行流程。
3.1 360-LLaMA-Factory 的框架和模块化封装
360-LLaMA-Factory 将序列并行的代码做到了尽可能好的模块化和尽可能少的原始代码修改。
我们认为序列并行本质上应认为是对模型的修改,因此在 model_args 中增加了参数并抽象为 apply_sequence_parallel 修改模型的函数。
# src/llamafactory/model/loader.py
sequence_parallel_group = apply_sequence_parallel(model_args) # 序列并行monkey patch,改动attention计算
model.sequence_parallel_group = sequence_parallel_group # 维护模型的序列并行组,不开则为None
相应地,数据处理部分也要相应地修改,我们将 zigzag ring attention 所需的数据处理抽象成了一个 decorator,装饰原来的数据处理函数。背后,这会将先 shuffle、packing、预处理好的数据进一步做好序列并行的准备:先将每行 pad 或截断到指定的训练长度,再按 zigzag 切分并按顺序写入数据集,最后在训练时用 SequentialSampler 读取训练数据。
# src/llamafactory/data/loader.py
@sequence_parallel_decorator
loss 计算则需要在 Trainer 中做序列并行组内的 reduce 汇总和计算。
# src/llamafactory/train/sft/trainer.py
dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(label_num, op=dist.ReduceOp.SUM, group=sp_group)
# src/llamafactory/train/dpo/trainer.py
dist.all_reduce(policy_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(policy_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
3.2 360-LLaMA-Factory 的 SFT 和 DPOTrainer
除了统一的模块化抽象,序列并行也需要对 360-LLaMA-Factory 的 Trainer 稍做定制化的修改,以适配各底层库。针对最普遍的后训练需求 SFT 和 DPO(及其变种),我们对 360-LLaMA-Factory 中的 SFT 和 DPOTrainer 做了尽可能少且清晰的修改。
其中,dummy_forward 是因为我们发现基于目前的底层序列并行实现,在第一次 forward 时 DPO loss 不等于 log (sigmoid (0)),但学习率设为 0 时之后的 DPO loss 全都等于。因此,训练最开始时先做且仅做一次假前传,不对正式训练循环造成任何影响。
从 SFT 和 DPO 的序列并行对比图中,可以清晰地看出 360-LLaMA-Factory 序列并行带来的改动。
图 3:360-LLaMA-Factory SFT 序列并行对比
图 4:360-LLaMA-Factory DPO 序列并行对比
内部 360-LLaMA-Factory 的早期版本已训练了开源的 360Zhinao2-7B-Chat-360k。
为验证本次开源的 360-LLaMA-Factory 的正确性,我们用总量为 30 条的小数据集,验证了序列并行开与不开的对比情况下,训练曲线的差别,以此来确保 360-LLaMA-Factory 所有实现的正确性。从下图可见,序列并行对训练曲线的影响几乎可以忽略不计,DPO 稍有一定数值误差,但我们也仔细检查了该误差与 DeepSpeed Ulysses 的误差范围一致,很可能部分是并行计算本身的随机性导致的,亦可参考 ring-flash-attention 的详细说明。
图 5:360-LLaMA-Factory SFT 和 DPO 序列并行开关对比
为便于对比效果,我们基于第三方全尺寸开源模型粗粒度压测了最大训练长度,如下表 2、表 3 所示,可见 8 卡 80G 的序列并行上限已可满足几十至几百 k 超长序列的需求:
360 智脑开源了 360-LLaMA-Factory,支持了序列并行,仅需额外 1 个参数控制。基于 LLaMA-Factory 和 ring-flash-attention 开发,360-LLaMA-Factory 的实现模块化、效果正确且在长序列上有效。
欢迎开发者们使用和开发。在本仓库(https://github.com/Qihoo360/360-LLaMA-Factory)下提交序列并行相关的 issue 或 PR 即可。
也欢迎研究者们,尤其是依赖长序列大模型的研究者们,在研究中使用我们的代码,可以这样引用我们的工作:
@software{360-llama-factory,
author = {Haosheng Zou, Xiaowei Lv, Shousheng Jia and Xiangzheng Zhang},
title = {360-LLaMA-Factory},
url = {https://github.com/Qihoo360/360-LLaMA-Factory},
建议同时引用 LLaMA-Factory 和 ring-flash-attention 相关工作。