在线网站模板,wordpress响应瀑布主题,百度商城,四川网站设计0 开始之前
作者#xff1a;hadiii#xff0c;北京大学 电子信息硕士在读
本文从Llama 3报告出发#xff0c;基本整理一些现代LLM的技术。基本#xff0c;是说对一些具体细节不会过于详尽#xff0c;而是希望得到一篇相对全面#xff0c;包括预训练#xff0c;后训练hadiii北京大学 电子信息硕士在读
本文从Llama 3报告出发基本整理一些现代LLM的技术。基本是说对一些具体细节不会过于详尽而是希望得到一篇相对全面包括预训练后训练推理又能介绍清楚一些具体技术例如RMDPOKV CacheGQAPagedAttentionData Parallelism等等的索引向文章。由于东西比较多且无法详尽细节所以推荐大家二次整理为自己的笔记。
本文的主要参考是Llama Team的The Llama 3 Herd of Models报告原文以及沐神回归B站新出的论文精读系列。同时也包括一些知乎的优秀文章。
1 Intro Illustration of the overall architecture and training of Llama 3 Overview of the Llama 3 Herd of models.
1.1 现代基础模型训练的主要阶段
a预训练阶段pre-training stage算法相对直接一般是用大量的数据去做下一个词的预测next-word prediction。
b后训练阶段post-training stage算法比较丰富包括SFTRLHFDPO等等。任务上看包括让模型做一些指令跟随的任务instruction following将模型偏好对齐到人类喜好上align with human preferences或者提高模型在特定任务的能力例如codemathroleplay等等。
从过去的模型看基本上可以认为GPT123都是在做pre-training而InstructGPT和RLHF则是在做post-training。以上是较为笼统的介绍。
1.2 现代基础模型训练的关键 MetaWe believe there are three key levers in the development of high-quality foundation models: data, scale, and managing complexity. meta认为现代基础模型训练的关键是data, scale, and managing complexity。
a关于data Llama系列有堆数据的传统相较于Llama 2 1.8T的预训练语料Llama 3的预训练语料堆到了15T的multilingual tokens。 沐神15个T可能是目前在公有的网络上面能够抓到的文本数据的一个大概的上限这个上限的意思是指与其再找一些增量的数据不如去调整现有的数据的质量。 b关于scaleLlama 3.1提供了8B70B405B三个规模。每个规模的性能差异可参考下面的benchmark。
c关于managing complexity复杂度管理说白了即Llama 3的算法相对简单。Llama 3选择了一个标准的稠密Transformer模型架构只进行了少量调整而没有选择MOE。后训练方面Llama 3采用了SFT、RS和DPO即一套相对简单的过程而不是更复杂的RLHF算法因为后者往往稳定性较差且更难以扩展。这些都属于design choice。23章会详细介绍相关技术。
1.3 benchmark表现
Llama 3各规格模型的benchmark表现如下。简要介绍其中的MMLU和IFEval。 Performance of finetuned Llama 3 models on key benchmark evaluations.
aMMLU系列 类似于各种考试里面的选择题只是主要考察模型的知识面背答案。
Question: Glucose is transported into the muscle cell:Choices:
A. via protein transporters called GLUT4.
B. only in the presence of insulin.
C. via hexokinase.
D. via monocarbylic acid transporters.Correct answer: A
原版MMLU是比较老的benchmark存在大家overfit的可能性。MMLU-Pro相对更新一些可以看到在MMLU-Pro上8B70B405B的差距相当大说明参数规模和内化到权重中的知识量还是非常相关的。
bIFEval IF即Instruction Following考察模型对指令的理解和遵循能力。原文见IFEval Dataset | Papers With Code[1]。 IFEval 示例
在IFEVAL上8B和70B的差距还是很明显的80.4/87.5而70B和405B的差距已经不明显了87.5/88.6。说明参数规模到达一定程度后再想通过扩大规模来提IF能力可能会逐渐不显著。
c剩下的benchmark则偏垂直一些分别包含了CodeMathReasoningTool useLong contextMultilingual可参见报告原文。 补充上述评估集既然都有overfit和leaking的风险那还有没有其他的benchmark呢当然比如LiveBench这种monthly更新的benchmarkLiveBench[2]。不过天底下是没有完美的benchmark的尤其是对于具体业务而言。 总体上看8B和70B在各方面差距都还是比较明显但70B和405B在以上的评估集中则差异相对小一些。405B的推理和训练都比较慢一般情况下70B算是复杂应用的首选。如果特别复杂再考虑405B毕竟性价比还是会差一些。值得一提的是Llama 3.1 70B在IFEval上接近Claude3.5 sonnet的水准。
2 Pre-Training MetaLanguage model pre-training involves: (1) the curation and filtering of a large-scale training corpus, (2) the development of a model architecture and corresponding scaling laws for determining model size, (3) the development of techniques for efficient pre-training at large scale, and (4) the development of a pre-training recipe. We present each of these components separately below. 上文比较笼统地说明了Pre-Training的要点。
2.1 Pre-Training Data • Web Data Curation
预训练数据处理的要点包括de-duplication methods and data cleaning mechanisms即去重和清洗如果做得不好质量会很差。具体报告中的Web Data Curation章节提到了以下内容
aPII and safety filtering报告提到预训练数据中移除了包含PIIpersonally identifiable information关于人的身份信息隐私信息和成人内容的域名。但具体是什么一个标准来锚定该数据是否属于PII和成人内容未给出示例一类的说明所以大概率是混了一些进去的。
bText extraction and cleaning由于web data是raw HTML content所以Llama构建了一个parser来解析各类文档。有趣的观点是报告认为Markdown对模型的性能有害因此删除了所有Markdown marker。但挪掉之后具体怎么做的未加说明。
cDe-duplicationLlama使用了三个级别的去重URLdocument, and line level。具体来说URL去重即保留每个URL对应页面的最新版本。document级别则在整个数据集上采用了global MinHash来去除近似重复的文档。line level的具体做法则是按照每30M的documents进行搜索去除其中出现超过6次的文本行。
dHeuristic filtering启发式的过滤。包括n-gram的过滤如果n比较长重复较多则把该行去掉典型的例子是logging文本。也包括危险词的过滤如果一个网页的dirty word太多则去掉。报告还提到使用了基于token-distribution Kullback-Leibler divergenceKL散度的方法来过滤过于奇葩的数据。即如果一个文档和其他文档算KL的距离差太远的话就把该文档标记为奇怪的文档去掉。
KL散度的概念比较常用是用于衡量两个概率分布之间的差异程度。定义为 eModel-based quality filtering基于模型的分类。比如fasttext和基于Llama 2训练的Roberta-based classifiers分类包括分高质量or低质量也可以是打领域tag等等。
fCode and reasoning data and Multilingual data也是一些特定数据的抽取pipeline花钱花人力做的一些工作。 • 数据混合Data Mix
数据配比确实相当重要且是实验性较强的工作炼丹烧钱烧时间出成果。报告中提到了Knowledge classification和scaling law的一些实验。
aKnowledge classification. 即使用一个分类器划分数据的类别例如客观知识类娱乐八卦类成人内容类......娱乐八卦类的数据对模型就不太好分类后就可以让这类数据少来一些。
**bScaling laws for data mix. **即多做不同配比的实验看指标变化。稍详细一点说是在不同的小模型上做不同的配比实验然后用来预测更大scale的最优配比。
总结最后的预训练数据大概是50%的general knowledge25%的mathematical and reasoning数据17%的code数据8%的多语言数据。 • 退火数据Annealing Data
报告发现在少量高质量的code和math的数据上做一下学习率的退火能够提升预训练模型的benchmark performance。这很符合直觉即考前多背一下题目考的会更好一些。
具体来说是在大量通用数据的训练完成后用一小撮高质量的特定领域数据继续训练同时将学习率慢慢降低。Llama 3在预训练的最后40M token采取了将LR线性退火到0的方法同时配合数据配比调整。最后8B模型在GSM8k和MATH验证集上提升不错但对405B的模型提升却可以忽略不计说明该参数规模的模型也许不需要specific in-domain的训练样本来提升性能。
同时报告提到可以使用退火来评估domain-specific的小数据集的质量比做Scaling Law的相关实验效率更高。
2.2 Model Architecture
总体上看Llama 3相较于2做了以下改动GQA面向一个sequence内部的不同文档的attention mask128K tokens的词表RoPE的调整。 • 基本推理过程 - KV Cache - GQA
Llama 3使用标准的Dense Transformer架构性能的提高主要来自于数据质量和多样性的改进以及训练规模的增加很喜欢说一些实话。当然和Llama 2相比还算有一些改变
例如上述提到的Grouped Query AttentionGQA用于加速推理节省解码的内存。对于70B及以上的模型几乎是必须用的技术。GQA涉及到KV CacheKV Cache涉及到基本的推理过程因此从推理开始写。
a基本推理过程 LLM推理过程
1、输入的Text根据词表被切分成n个token/token idsn个token ids被映射为n个embedding向量即1个embedding矩阵
2、embedding矩阵通过L个transformer block内部有各种注意力计算和FFN层在最后一层输出一个与输入形状相同的embedding矩阵
3、输出的n个embedding再过一个线性层lm_head该线性层的输出形状和词表大小一致。线性层输出再接一个softmax就得到了next token的概率分
4、随后再根据解码策略采样即可。Next token被算出来后加入输入的token序列长度为n1继续计算第n2个token这就是自回归。
bKV Cache
由于在计算第n1个token时L个Transformer block的中间结果是可以被保存下来的所以也许可以复用它们。我们把第 层第 个token的输出记为 。不难发现需要计算第n2个token时有很大一部分中间结果和计算n1时相同。可表示为
输入token序列: 与输入 token 序列为 的中间结果 一致所以我们利用缓存来可以减少大量的计算。 因此LLM推理过程分为Prefill和Decode两个阶段Prefill阶段会对Prompt中所有的token做并行计算得到Prompt中所有Tokens的KV Cache以及计算得到首Token。Prompt Tokens计算得到的KV Cache会保存下来留给Decode阶段复用
Decode阶段是一个自回归过程每解码一个新的Token都需要用到所有之前计算得到的KV Cache来计算当前query token的Attention。因此当输出长度越来越大或者context很长时KV Cache将会占用大量的显存。
本段内容以及下图引用自[KV Cache优化] MQA/GQA/YOCO/CLA/MLKV笔记: 层内和层间KV Cache共享[3]。 所以现在也存在prefix caching的概念简单地说就是把特定前缀的KV Cache缓存起来保留备用。对于指令复杂prompt较长的任务或者多轮对话场景非常有效。vllm已经可以很方便地开启prefix caching对长输入短输出的固定任务优化较好。KV Cache有大量的方向可以做是LLM推理优化的核心之一。
cGQAGrouped Query Attention
GQA是从模型层面降低KV Cache大小的手段之一。聊GQA之前的惯例是聊MHA和MQA。
MHA即Multi Head Attention多头注意力Transformer原文的attention形式。如下图所示MHA中每个Query向量都会对应一个KeyValue其输出会把每个注意力头的输出拼接起来。因此也会存较多的KV Cache。
MQA即Multi Query Attention。如下图所示MQA的思路比较直接就是让每个注意力头共用一个KV很显然相较于MHAKV Cache的占用直接减少到了1/head_num。不过由于结构的修改和Attention部分的参数量降低模型效果也必然受到影响。MQA似乎还是有些暴力。
因此出现了平衡的版本即GQAGrouped Query Attention。和图中一致即将Queries进行分组每组对应一个KV用一种折中的方法实现了减少计算量和KV Cache大小。 • RoPE旋转位置编码
首先应该聊聊经典的正弦编码。上文在LM的一次推理过程中提到token会映射为embedding向量在经典transformer的结构中这个embedding向量是词嵌入向量实体的孤立语义和位置编码实体间的关联语义的叠加。如何表征token的位置则是位置编码研究的问题。
《动手学深度学习PyTorch版》全要点笔记[4]经典transformer架构的位置编码是正弦编码。 正弦编码存在一些可能的问题比如对相对位置的表示较弱。RoPE则尝试在解决这些问题。
2.3 Scaling Laws
最初的形式
简单来说就是可以用小模型的一些实验结果来预测更大模型的结果。Scaling Law由OpenAI提出有两个大家熟知的结论
1、对于Decoder-only的LM计算量 模型参数量 数据大小 三者满足 。其中 的单位是Flops 是token数
2、模型的最终性能主要与 相关与模型的具体结构高矮胖瘦相关性不高。 -** Llama报告的内容**
之前的Scaling Law的预测方法主要是从next-token prediction loss训练时的validation loss出发的但这个loss和具体的任务表现不一定是绝对相关的。因为next-token prediction loss并不和具体任务表现例如数学绝对挂钩。所以Llama 3在做Scaling Law的实验时做了一个two-stage的方法
step1预测模型在具体下游任务上的NLL loss这个NLL loss还是和computeFLOPs挂钩成函数关系
step2利用Scaling Law将step1中的loss和具体的task accuracy关联起来。例如1.4的NLL loss对应0.25的accuracy1.2的误差对应0.95的accuracy所以这个规律和具体也可以解耦得到对于一个具体benchmark的Scaling Law曲线xy轴分别为loss和accuracy。
具体可见下图。ARC Challenge benchmark是一个做推理的多选题任务集。发现Scaling Law的预测还是挺准的。不过要注意不同任务的benchmark曲线可能也长得不一样。 2.4 Training Recipe
Llama 3的预训练策略主要由三步构成分别为(1) initial pre-training, (2) long-context pre-training, and (3) annealing.
Initial Pre-Training
主要是一些细节。简单翻译下。我们使用 AdamW 对 Llama 3 405B 进行预训练peak learning rate 为 linear warm up为 8000 步以及cosine learning rate预计在 1,200,000 步中衰减到 。为了提高训练稳定性我们在训练初期使用了较小的批次大小并随后增加了批次大小以提高效率。具体来说我们使用的initial batch size为4M 的tokens长度为 4096 的序列在训练了 252M tokens后后将这些值加倍8M sequences of 8,192 tokens。在训练了2.87 T token后再次将加倍到 16M。我们发现这种训练配方非常稳定我们观察到的损失峰值loss spikes很少并且不需要进行干预来纠正模型训练的偏差。
同时也做了一些data mix的调整。比如多拿非英语数据数学数据更多的最新网络数据等等。
Long Context Pre-Training
简单翻译下。在预训练的最后阶段我们对 long sequences 进行训练以支持最多 128K tokens 的 context窗口。我们之前没有对 long sequences 进行训练因为在 self-attention layers 中的计算量随 sequence length 呈平方增长。我们逐步增加支持的 context length进行 pre-training直到模型成功适应了增加的 context length。
我们通过以下两点评估成功的适应性(1) 模型在 short-context evaluations 中的表现是否完全恢复具体来说可能就是MMLU这些评测集(2) 模型是否能完美解决长度达到该值的 needle in a haystack 任务大海捞针任务。
在 Llama 3 405B 的 pre-training 中我们逐步在六个阶段增加了 context length从最初的 8K context窗口开始最终达到 128K context窗口。这个 long-context pre-training 阶段使用了大约 0.8T tokens。
Annealing
见2.1 Pre-Training Data同退火数据Annealing Data一节的内容。
3 Post-Training
下图很清晰地概括了Llama 3的后训练思路要素包括RMSFTRSDPO。本章会一一介绍。后训练是业内绝大多数NLPer做的事情。 Illustration of the overall post-training approach for Llama 3.
Llama 3后训练策略的backbone是一个Reward Model和一个Language Model。首先利用人类标注的偏好数据在pre-trained checkpoint之上训练一个RM。然后对pre-trained checkpoint做SFT之后用DPO做对齐作为本轮的最佳模型进入下轮迭代参与Rejection Sampling过程。
注意到训练是迭代式的即有多轮方法相同的训练。具体来说Llama 3进行了6轮的循环。在每个周期中收集新的偏好标注和 SFT 数据并从最新的模型中采样合成数据。
3.1 Reward Model 红框部分是RM的训练路径
首先应该简介一下Reward ModelRM。Reward Model是一种通过”偏好排序数据“A B C D训练得到的模型能够给一段文本一个偏好性例如安全性拟人性或者某种综合性的偏好的分数。这个分数是一个标量体现了人类的某种偏好。
而且A B可能不仅是A B也可能是远好于稍好于这个其实也能在损失函数里体现出来margin loss即Llama 2论文中 的部分 Preference Data构建
Llama详细讲解了Preference Data的构建过程。大概是这样几个step
step 1. 使用不同的数据配比和训练策略训练出多个for annotation的模型。部署多个不同的模型针对一个具体的user prompt采样出两个来自不同模型的response。
step 2. 标注同学会按照“好多少”的标准对response对进行打分包括四个等级significantly better, better, slightly better, or marginally better。
step 3. 偏好标注好后鼓励标注同学去“edit”chosen response即他们上一步已经选择了更好的那个答案改的更好。既可以直接修改chosen response本身也可以修改prompt来refine这些数据。
所以最后有一部分偏好数据是有三个ranked response的即edited chosen rejected。最后得到了这样的数据构成。 训练
训练和Llama 2类似。但是Llama 3反而在损失函数中去掉了margin loss即上文的 因为观察到在数据规模扩大后margin的改进效果逐渐减弱不如简化。
3.2 SFT
SFT大概是大多数同学接触LLM训练的首选。SFT即使用标准的交叉熵损失standard cross entropy loss同时mask prompt部分的loss训练target tokens的过程。
SFT Data构建
SFT数据有很多个来源Rejection Sampling的数据针对特定能力的合成数据少量的人工标注数据。
Rejection Sampling
Rejection Sampling的过程就是固定模型和prompt让LM采样出K个不同的答案根据RM的K个不同的分数选出最优答案。然后将该最优答案作为SFT数据做迭代式的训练。其中模型一般是前一轮训练中表现最好的checkpointK则可以调整一般是10-30。采样也有很多细节涉及到preference pair构造比如rejected可能不能无脑选最差的这些需要实验。
为了提高拒绝采样的效率Llama 3采用了PagedAttention。在 PagedAttention 中内存浪费只会发生在序列的最后一个块中可以很好地提升吞吐量。PagedAttention的内存共享也是很好的优化在Rejection Sampling中多个response是由同一个prompt生成的。在这种情况下prompt 的计算和内存可以在输出序列中共享。这里做一些简单介绍。
PagedAttention think of blocks as pages, tokens as bytes and requests as processes。 PagedAttention也是主流推理加速框架vLLM之选。大家应该都学过OS课了解虚拟内存内存分页管理内存碎片的概念。PagedAttention也是受到OS的启发认为KV Cache 没有必要存储在连续的内存中而是像操作系统一样把块的概念引入为“page”byte的概念引入为“token”进程的概念引入为“request”。
2.2节中我们提到由于在计算第n1个token时L个Transformer block的中间结果是可以被保存下来的所以也许可以复用它们。这被称作KV Cache。
但是KV Cache非常大需要一块连续内存来存储。并且我们在接收到sequence之前并不知道需要预留多少连续内存所以只能预先分配一个最大可能长度的cache导致了很多浪费这被称为“内部碎片”。而由于我们给多个sequence分配了内存所以剩下的内存不足以分配给新的sequence这一部分内存实际上也没用了所以也造成了浪费这被称为“外部碎片”。
PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。具体来说 它将每个序列的 KV cache 划分为块每个块包含固定数量 token 的键和值。因此对于1个sequence最多会有1个page是有内存碎片的。由于按块分配外部碎片则彻底没有了。这和OS中的分页存储解决的问题一致。 回到SFT Data最后得到了这样的数据构成。 训练细节上Llama 3对405B进行微调时学习率为10⁻⁵训练步数在8.5K到9K之间。
3.3 Rejection Sampling
见 3.2 SFT 中的Rejection Sampling。
3.4 Direct Preference Optimization DPO在SFT之后进行目的是对齐人类的偏好。DPO是RLHF的简化目的是跳过复杂的RM训练等过程RLHF是先用标注的偏好数据去训练RM然后再指导RL的过程而DPO则这把上述两个步骤的loss融合到一起。
因此DPO的训练数据也是人类偏好数据格式类似于chosen-rejected对。DPO的损失如下 # DPO的数据格式
{ prompt: ,chosen: ,rejected:
}
DPO训练细节
在训练过程中Llama 3主要使用最新一批的偏好数据这些数据是通过前几轮对齐中表现最好的模型收集的需要用到RM。好处是这些数据更好地符合每轮正在优化的Policy Model的分布。所以这种DPO也是Iterative的属于on-policy。
a第一个细节是由于DPO损失函数的特点chosen response和rejected response中如果出现了一些共同的token则会导致相互冲突的学习目标因为模型需要同时增加和减少这些token的生成概率。所以Llama 3 Mask了formatting tokens 的 loss实验发现这些token如果算loss可能会导致tail repetition和突然生成终止的token。
b第二个细节是Llama 3给chosen sequence加上了一个negative log-likelihoodNLL loss从NLL loss和标准交叉熵损失的差别上看可以简单把NLL loss理解为SFT loss 加上NLL loss的好处是防止chosen response的log probability下降。坏处是chosen response如果本身不够好加这个SFT loss可能也不太好需要具体问题具体分析。
3.5 Data Processing and Quality Control
数据质量始终是最关键的。由于Llama 3的大部分训练数据是模型生成的因此需要仔细进行清洗和质量控制。这和绝大多数垂直业务模型也一致。
数据清洗Data cleaning
首先数据中往往存在一些不理想的模式Llama 3就有过度使用表情符号或感叹号的问题。一些非常经典的AI味语风也需要注意例如“过于喜欢滑跪”的语气问题遇事不决就“对不起”或“我道歉”这种样本应该不能在数据集中太多。
数据修剪Data pruning
Llama 3还应用了一些基于模型的技术来去除低质量的训练样本来提升模型整体性能
1、主题分类Topic classification首先对一个小模型如Llama 3 8B进行微调使其成为topic classifier例如专门用一大堆分类文本的任务数据去SFT一下。然后对所有训练数据进行分类将其分类为粗粒度类别如“数学推理”和细粒度类别如“几何和三角学”。
2、质量评分Quality scoring使用Reward model和基于Llama的信号为每个样本的质量打分。对于基于RM的评分我们将得分处于RM评分前四分之一的数据视为高质量数据。对于基于Llama的评分就是在Llama 3设计了一些打分的prompt一般英语数据使用三个维度的评分准确性、指令遵循性和语气/表达coding数据则使用两个维度的评分错误识别和用户意图并将获得最高分的样本视为高质量数据。
最后发现RM评分和Llama评分的分歧率较高但发现结合这两种机制能在meta内部测试集中取得最佳的召回率。最终选择被RM OR Llama 3分类模型标记为高质量的样本。
3、难度评分Difficulty scoring由于还希望优先处理对模型来说更复杂的样本因此报告提到两种难度评估方法对数据进行评分Instag和基于Llama的评分。对于Instag我们提示Llama 3 70B对SFT提示进行意图标注意图越多复杂性越高。基于Llama的思路和Quality scoring相似给了Llama 3一些prompt基于三个维度去打分。
4、语义去重Semantic deduplication最后进行语义去重。Llama 3首先使用RoBERTa对完整对话进行聚类然后在每个聚类内按质量分数 × 难度分数对其进行排序。接着遍历所有排序的样本进行贪婪选择仅保留与当前聚类中已见样本的余弦相似度小于阈值的样本。
4 Inference
首先请参考2.2 Model Architecture中关于基本推理过程KV CacheGQA部分的内容同时请参考3.2 SFT中关于PagedAttention的介绍。
4.1 Parallelism
ParallelismLLM分布式训练推理的一部分包括Data Parallelism和Model Parallelism本节做一些介绍。同样涉及到OS的一些概念。
Data Parallelism
Data Parallelism数据并行在每个设备上独立接收到不同的输入数据批次可称mini-batch并执行前向传播以计算该批次上的损失。在反向传播过程中每个设备会计算梯度并与所有其他设备交换这些梯度。然后使用这些梯度的平均值来更新每个设备上的模型权重确保在下一次训练步骤开始时所有设备都具有相同的模型权重。
好处是加快了batch的训练速度并且能够放下更大batch size的数据。坏处是每张卡也都使用了完整的模型权重得保证单卡能装得下。 Data Parallelism
Model Parallelism
Model Parallelism。模型并行包括Tensor Parallelism和Pipeline Parallelism。Model Parallelism解决的是单张卡放不下一个完整模型权重的问题每张显卡只放部分参数。一般来说会按照层进行划分参数按层划分一般叫Pipeline Parallelism。如果模型的一层如果都装不下了同一个模型层内拆分开训练是Tensor Parallelism。
好处是能放下更大的权重了坏处是后面层的卡需要等待前面层的计算结果所以GPU会有空闲状态。反向传播时也一样前面层的卡要等后面层的卡。 Llama 3中的Pipeline Parallelism
使用BF16数值表示模型参数时Llama 3 405B模型无法在一台配备8个Nvidia H100 GPU的单机内完全加载到GPU内存中。为了解决这一问题Llama 3 team使用两台机器node上的16个GPU并行进行BF16精度的模型推理。
在每个node内部利用NVLink的high bandwidth来启用tensor parallelism。而在node之间连接的带宽较低延迟较高因此采用pipeline parallelismGpipe。
在使用pipeline parallelism进行训练时bubble是一个主要的效率问题详见论文Gpipe。然而在推理过程中这并不是一个问题因为推理不涉及反向传递。因此Llama 3使用micro-batching来提高推理的吞吐量throughput。
Gpipe
在前向传播过程中GPipe 首先将每个大小为 N 的mini-batch划分为 M 个相等的micro-batch并将它们通过 K 个GPU进行流水线处理。在反向传播过程中每个micro-batch的梯度是基于前向传播时使用的相同模型参数计算的。在每个mini-batch结束时所有 M 个micro-batch的梯度会被累积并应用于所有GPU以更新模型参数。 micro-batching效果
报告在key-value cache pre-fill stage和decoding stage两个阶段见 2.2 Model Architecture 的讲解都评估了micro-batches的效果。在4096个输入 tokens和256 个输出 tokens的情况下报告发现在相同的local batch size下micro-batches提高了推理的吞吐量如下图所示。
这些改进归因于micro-batches在这两个阶段中实现了并发执行。由于micro-batches带来了额外的同步点synchronization points导致延迟增加但总体而言micro-batches仍然带来了更好的吞吐量-延迟平衡throughput-latency trade-off。
4.2 Quantization
Quantization量化也是当前热门的话题核心手段是通过降低模型参数的精度来减少GPU占用并减少计算量。和PagedAttention类似同样可以从OS中找到很多相关的东西。一些常见的精度表示如下 INT8 量化
INT 8量化相对简单。如图所示的是absmax的INT 8量化输入是一个FP16的向量。假设用 absmax 对向量[1.2, -0.5, -4.3, 1.2, -3.1, 0.8, 2.4, 5.4]进行量化。首先需要计算该向量的最大绝对值在本例中为5.4。Int8 的范围为[-127, 127]因此我们将127除以5.4得到缩放因子scaling factor23.5。最后将原始向量乘以缩放因子得到最终的量化向量[28, -12, -101, 28, -73, 19, 56, 127]。
要恢复原向量可以将 int8 量化值除以缩放因子但由于上面的过程是“四舍五入”的我们将丢失一些精度。 FP8 量化
Llama 3利用H100 GPU的原生FP8支持来执行低精度推理。为了启用低精度推理Llama 3对模型内部的大多数矩阵乘法应用FP8量化。实现细节见下面的两篇参考文章。特别是对模型中前馈网络层的大多数参数和激活值进行量化这些部分约占推理计算时间的50%。其中还有一些细节
Llama 3没有对模型的自注意力层中的参数进行量化。也没有在第一个和最后一个Transformer层中执行量化。并且采用了按行量化的方式对参数和激活矩阵的每一行计算缩放因子Scaling Factor。如下图所示。 量化结果
量化结果主要是两个方面一个是好处即efficiency的提升一个是坏处即accuracy的下降。
对于efficiencyLlama 3针对于4,096 input tokens and 256 output tokens做了定量实验在prefill阶段2.2 Model Architecture 中有详细介绍使用FP8推理可将吞吐量提高多达50%4k-9k在decode阶段也能更好地trade off throughput-latency。
对于accuracy在标准benchmark上即使不做上文所说的细节FP8推理的表现也与BF16推理相当。但是当Scaling Factor没有上限时模型有时会生成错误的响应所以benchmark无法正确和充分地反映FP8量化的影响。于是Llama 3使用FP8和BF16生成了100,000个响应选择用奖励模型的分布来分析。从下图可以看到FP8的得分几乎没有影响RM的得分分布。 Throughput-latency trade-off in FP8 inference with Llama 3 405B Reward score distribution for Llama 3 405B using BF16 and FP8 inference.
5 写在最后
最近平时工作可以说是把脑子想“干”了所以花大概三个周末完成了这篇接近2w字的文章。写完感觉有很多不足但还是随便找个时间发了吧。其一是本来是打算从Llama 3这种优质开源模型和报告出发进行一些知识上的梳理结果行文时几乎保留了论文原来的结构导致前一个知识点到下一个知识点不够丝滑
其二是由于水平不够和“综合性”考量的限制所以对很多需要深入的知识没有详尽。后面几个周末也许还会持续迭代一下本文主要是继续细化技术点。所以也恳请诸位指出错误或不足尽情提出需要补充内容的部分。
引用链接
[1] IFEval Dataset | Papers With Code: https://paperswithcode.com/dataset/ifeval[2] LiveBench: https://livebench.ai/[3] [KV Cache优化] MQA/GQA/YOCO/CLA/MLKV笔记: 层内和层间KV Cache共享: https://zhuanlan.zhihu.com/p/697311739[4] 《动手学深度学习PyTorch版》全要点笔记: https://zhuanlan.zhihu.com/p/664880302