AWS上实现超大规模模型训练的近线性扩展
当前最先进的语言模型具有数十亿参数。要在可控时间内训练这些模型,需要将工作负载分配到大型计算集群上。理想情况下,训练时间应随着集群规模扩大而线性减少。但由于节点间协调工作所需的通信会抵消并行化带来的收益,实现线性扩展非常困难。
我们近期优化了微软DeepSpeed分布式训练库的通信效率,在最多64个GPU上显著提升了性能。但当规模从数十个GPU扩展到数百个GPU时,在公有云环境中通信开销再次成为效率瓶颈。
在即将于2023年VLDB会议上发表的论文中,我们提出了一种名为MiCS(最小化通信规模)的方法,可在云环境中实现数百个GPU的高效模型训练。与DeepSpeed和FairScale等现有框架将模型状态划分到所有GPU不同,MiCS会创建模型状态的多个副本,并将每个副本划分到GPU子集中。
实验结果显示,在不同规模的BERT模型上,使用p3dn.24xlarge实例集群评估时,MiCS在吞吐量和扩展效率方面都有显著提升。该方法能实现近线性扩展(如下图矩形框所示),相比DeepSpeed-v0.5.6内置的ZeRO优化器,吞吐量最高提升2.82倍。
规模感知的模型分区
MiCS将集群中的GPU划分为多个"分区组",每个组持有完整的模型状态副本。这种方法将频繁的通信操作(如参数收集)限制在固定数量的GPU内,有效控制了通信开销随集群规模增长的问题。
分层通信策略
当单个模型状态副本的内存需求超过单节点GPU总内存时,MiCS采用分层通信策略减少节点间通信参与者的数量。例如在双节点四GPU场景下,通信量因子从3/4降至1/2。
两跳梯度同步
MiCS通过将梯度同步开销分摊到多个微步中,实现了高效的两跳梯度同步机制。这使得在p4de.24xlarge机器上训练1750亿参数模型时,每个GPU能达到169万亿次浮点运算(理论峰值的54.2%)。
当集群规模从128GPU扩展到512GPU时,MiCS实现了99.4%的弱扩展效率,而DeepSpeed ZeRO第三阶段仅达到72%。我们正在将MiCS开源,相信它将大幅降低在Amazon EC2平台上训练大模型的时间和成本。
致谢:Yida Wang, Justin Chiu, Roshan Makhijani, RJ, Stephen Rawls, Xin Jin
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)
公众号二维码