闸北网站推广公司,个人博客网站备案,织梦手机网站模板删除,网页的设计与制作目录 前言1. 从MHA、MQA、GQA到MLA1.1 MHA1.2 瓶颈1.3 MQA1.4 GQA1.5 MLA1.5.1 Part 11.5.2 Part 21.5.3 Part 3 结语参考 前言 学习 DeepSeek 中的 MLA 模块#xff0c;究极缝合怪#xff0c;东抄抄西抄抄#xff0c;主要 copy 自苏神的文章#xff0c;仅供自己参考#… 目录 前言1. 从MHA、MQA、GQA到MLA1.1 MHA1.2 瓶颈1.3 MQA1.4 GQA1.5 MLA1.5.1 Part 11.5.2 Part 21.5.3 Part 3 结语参考 前言 学习 DeepSeek 中的 MLA 模块究极缝合怪东抄抄西抄抄主要 copy 自苏神的文章仅供自己参考 refer1缓存与效果的极限拉扯从MHA、MQA、GQA到MLA refer2: 博客分享从MHA、MQA、GQA到MLA 1. 从MHA、MQA、GQA到MLA 以下内容均来自于苏神的文章缓存与效果的极限拉扯从MHA、MQA、GQA到MLA 1.1 MHA
MHAMulti-Head Attention也就是多头注意力是开山之作 《Attention is all you need》 所提出的一种 Attention 的形式
在数学上多头注意力 MHA 等价于多个独立的单头注意力的拼接假设输入的行向量序列为 x 1 , x 2 , ⋯ , x l \bm{x}_1,\bm{x}_2,\cdots,\bm{x}_l x1,x2,⋯,xl其中 x i ∈ R d \bm{x}_i \in \mathbb{R}^d xi∈Rd那么 MHA 可以形式地记为 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) x i W k ( s ) ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \bm{o_{t}}\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{(s)}x_{i}W_{k}^{(s)}\in\mathbb{R}^{d_{k}},\quad W_{k}^{(s)}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{(s)}x_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d\times d_{v}}} \end{array} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)xiWk(s)∈Rdk,Wk(s)∈Rd×dkvi(s)xiWv(s)∈Rdv,Wv(s)∈Rd×dv
简单起见这里省略了 Attention 矩阵的缩放因子 1 d k \frac{1}{\sqrt{d_k}} dk 1
实践上常见的设置是 d k d v d / h d_kd_vd/h dkdvd/h例如对于 LLaMA2-7B 有 d 4096 , h 32 , d k d v 128 {d4096,h32,d_{k}d_{v}128} d4096,h32,dkdv128LLaMa2-70B 则是 d 8192 , h 64 , d k d v 128 {d8192,h64,d_{k}d_{v}128} d8192,h64,dkdv128
这里只考虑主流自回归 LLM 所用的 Causal Attention在 token by token 递归生成时新预测出来的第 t 1 t1 t1 个 token并不会影响到已经算好的 k ≤ t ( s ) , v ≤ t ( s ) {k_{\leq t}^{(s)},v_{\leq t}^{(s)}} k≤t(s),v≤t(s)因此这部分结果我们可以缓存下来供后续生成调用避免不必要的重复计算这就是所谓的 KV Cache
关于 KV Cache 大家感兴趣的可以看看KV Cache的原理与实现
后面的 MQA、GQA、MLA 都是围绕“如何减少 KV Cache 同时尽可能地保证效果”这个主题发展而来的产物 上图展示了标准 MHA 下的 KV Cache 是多大它和注意力头数、序列长度等相关此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ n u m _ h e a d ∗ h e a d _ d i m 2*seq\_len * num\_head * head\_dim 2∗seq_len∗num_head∗head_dim
Note该图片来自于 https://github.com/preacher-1/MLA_tutorial
代码实现如下
import math
import torch
import torch.nn as nn# Multi-Head Attention
class MHA(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model d_modelself.num_heads num_headsassert d_model % num_heads 0self.head_dim d_model // num_headsself.q_linear nn.ModuleList([nn.Linear(d_model, self.head_dim, biasFalse) for _ in range(num_heads)])self.k_linear nn.ModuleList([nn.Linear(d_model, self.head_dim, biasFalse) for _ in range(num_heads)])self.v_linear nn.ModuleList([nn.Linear(d_model, self.head_dim, biasFalse) for _ in range(num_heads)])self.out_linear nn.Linear(d_model, d_model, biasFalse) def forward(self, x):bsz, seq_len, _ x.shapeoutputs []# Parallelfor i in range(self.num_heads):q self.q_linear[i](x) # (bsz, seq_len, head_dim)k self.k_linear[i](x) # (bsz, seq_len, head_dim)v self.v_linear[i](x) # (bsz, seq_len, head_dim)# RoPE# TODO: Implement RoPE# Attentionattention torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, seq_len, seq_len)# Casual maskmask torch.triu(torch.ones(seq_len, seq_len), diagonal1).bool()mask mask.unsqueeze(0).to(x.device) # (1, seq_len, seq_len)attention attention.masked_fill(mask, float(-inf))attention torch.softmax(attention, dim-1)# Outputoutput torch.matmul(attention, v) # (bsz, seq_len, seq_len)outputs.append(output)# Linear projectionoutput torch.cat(outputs, dim-1) # (bsz, seq_len, d_model)output self.out_linear(output)return output# Another implement for Multi-Head Attention
class MHA2(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model d_modelself.num_heads num_headsassert d_model % num_heads 0self.head_dim d_model // num_headsself.q_linear nn.Linear(d_model, num_heads * self.head_dim, biasFalse)self.k_linear nn.Linear(d_model, num_heads * self.head_dim, biasFalse)self.v_linear nn.Linear(d_model, num_heads * self.head_dim, biasFalse)self.out_linear nn.Linear(d_model, d_model, biasFalse)def forward(self, x):bsz, seq_len, _ x.shapeq self.q_linear(x) # (bsz, seq_len, num_heads * head_dim)k self.k_linear(x) # (bsz, seq_len, num_heads * head_dim)v self.v_linear(x) # (bsz, seq_len, num_heads * head_dim)# matmul 只能在最后两个维度相乘, 需要对 NxD 的矩阵相乘, 做 1,2 维度的交换# (bsz, num_heads, seq_len, head_dim)q q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)k k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)v v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)# RoPE# TODO: Implement RoPE# Attentionattention torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_heads, seq_len, seq_len)# Casual maskmask torch.triu(torch.ones(seq_len, seq_len), diagonal1).bool()mask mask.unsqueeze(0).to(x.device) # (1, seq_len, seq_len)attention attention.masked_fill(mask, float(-inf))attention torch.softmax(attention, dim-1)# Outputoutput torch.matmul(attention, v) # (bsz, num_heads, seq_len, head_dim)output output.transpose(1, 2) # (bsz, seq_len, num_heads, head_dim) output output.contiguous().view(bsz, seq_len, -1) # (bsz, seq_len, d_model)output self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model 512
num_heads 8
mha MHA2(d_model, num_heads)
x torch.randn(10, 20, d_model) # (bsz, seq_len, d_mdeol)
output mha(x)
print(output.shape) # (10, 20, 512)Note代码参考自https://github.com/preacher-1/MLA_tutorial
此外 MHA 还有另外一种实现一次性将 Q , K , V Q,K,V Q,K,V 投影具体代码可以参考 https://github.com/karpathy/minGPT/tree/master/mingpt
1.2 瓶颈
为什么降低 KV Cache 的大小如此重要呢
众所周知一般情况下 LLM 的推理都是在 GPU 上进行的而单张 GPU 的显存是有限的一部分我们要用来存放模型的参数和前向计算的激活值这部分依赖于模型的体量选定模型后它就是个常数另外一部分我们要用来存放模型的 KV Cache这部分不仅依赖于模型的体量还依赖于模型的输入长度也就是在推理过程中是动态增长的当 Context 长度足够长时它的大小就会占主导地位可能超过一张卡甚至一台机8张卡的总显存量
在 GPU 上部署模型的原则是能一张卡部署的就不要跨多张卡能一台机部署的就不要跨多台机。这是因为“卡内通信带宽 卡间通信带宽 机间通信带宽”由于“木桶效应”模型部署时跨的设备越多受设备间通信带宽的的“拖累”就越大事实上即便是单卡 H100 内 SRAM 与 HBM 的带宽已经达到了 3TB/s但对于 Short Context 来说这个速度依然还是推理的瓶颈更不用说更慢的卡间、机间通信了
所以减少 KV Cache 的目的就是要实现在更少的设备上推理更长的 Context或者在相同的 Context 长度下让推理的 batch size 更大从而实现更快的推理速度或者更大的吞吐量。当然最终的目的都是为了实现更低的推理成本
1.3 MQA
MQAMulti-Query Attention是减少 KV Cache 的一次非常朴素的尝试首次提出自 《Fast Transformer Decoding: One Write-Head is All You Need》
MQA 的思路很简单直接让所有 Attention Head 共享同一个 K、V用公式来说就是取消 MHA 所有的 k , v \bm{k},\bm{v} k,v 的上标 ( s ) ^{(s)} (s) o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) x i W k ( s ) ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \bm{o_{t}}\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{\color{red}{\bcancel{(s)}}}, \bm{v}_{\leq t}^{\color{red}{\bcancel{(s)}}}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{\color{red}{\bcancel{(s)}}}{}^{\top}\Bigr)\bm{v}_{i}^{\color{red}{\bcancel{(s)}}}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{\color{red}{\bcancel{(s)}}}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{\color{red}{\bcancel{(s)}}}x_{i}W_{k}^{\color{red}{\bcancel{(s)}}}\in\mathbb{R}^{d_{k}},\quad W_{k}^{\color{red}{\bcancel{(s)}}}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{\color{red}{\bcancel{(s)}}}x_{i}W_{v}^{\color{red}{\bcancel{(s)}}}\in\mathbb{R}^{d_{v}},\quad W_{v}^{\color{red}{\bcancel{(s)}}}\in \mathbb{R}^{d\times d_{v}}} \end{array} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s) ,v≤t(s) )≜∑i≤texp(qt(s)ki(s) ⊤)∑i≤texp(qt(s)ki(s) ⊤)vi(s) qi(s)xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s) xiWk(s) ∈Rdk,Wk(s) ∈Rd×dkvi(s) xiWv(s) ∈Rdv,Wv(s) ∈Rd×dv
使用 MQA 的模型包括 PaLM、StarCoder、Gemini 等。很明显MQA 直接将 KV Cache 减少到了原来的 1 / h 1/h 1/h这是非常可观的单从节省显存角度看已经是天花板了
效果方面目前看来大部分任务的损失都比较有限且 MQA 的支持者相信这部分损失可以通过进一步训练来弥补回。此外注意到 MQA 由于共享了 K、V将会导致 Attention 的参数量减少了将近一半而为了模型总参数量的不变通常会相应地增大 FFN/GLU 的规模这也能弥补一部分效果损失 Note该图片来自于 https://github.com/preacher-1/MLA_tutorial
上图展示了标准 MQA 下的 KV Cache 是多大和标准的 MHA 相比 Q Q Q 保持不变但所有头的 K , V K,V K,V 共享此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ 1 ∗ h e a d _ d i m 2*seq\_len * 1 * head\_dim 2∗seq_len∗1∗head_dim
代码实现如下
import math
import torch
import torch.nn as nn# Multi-Query Attention
class MQA(nn.Module):def __init__(self, d_model, num_heads):super().__init__()self.d_model d_modelself.num_heads num_headsassert d_model % num_heads 0self.head_dim d_model // num_headsself.q_linear nn.Linear(d_model, d_model)self.k_linear nn.Linear(d_model, self.head_dim)self.v_linear nn.Linear(d_model, self.head_dim)self.out_linear nn.Linear(d_model, d_model)def forward(self, x):bsz, _, _ x.shape# Linear projections, all heads share the same K, Vq self.q_linear(x) # (bsz, seq_len, d_model)k self.k_linear(x) # (bsz, seq_len, head_dim)v self.v_linear(x) # (bsz, seq_len, head_dim)# Reshape for multi-head attentionq q.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)# (bsz, num_heads, seq_len, head_dim)k torch.unsqueeze(k, 1) # (bsz, 1, seq_len, head_dim)v torch.unsqueeze(v, 1) # (bsz, 1, seq_len, head_dim)# Attentionattention torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_heads, seq_len, seq_len)attention torch.softmax(attention, dim-1)# Outputoutput torch.matmul(attention, v) # (bsz, num_heads, seq_len, head_dim)output output.transpose(1, 2) # (bsz, seq_len, num_heads, head_dim) output output.contiguous().view(bsz, -1, d_model)# Linear projectionoutput self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model 512
num_heads 8
mqa MQA(d_model, num_heads)
x torch.randn(10, 20, d_model) # (bsz, seq_len, d_mdeol)
output mqa(x)
print(output.shape) # (10, 20, 512)Note代码参考自https://github.com/preacher-1/MLA_tutorial
这个代码和 MHA 实现类似不同的是由于 MQA 所有头共享同一个 K , V K,V K,V因此这里的 W k , W v W_k,W_v Wk,Wv 投影矩阵的维度是 head_dim 而不再是 num_heads * head_dim在 forward 时通过广播机制将 W k , W v W_k,W_v Wk,Wv 共享到其他头即可
1.4 GQA
然而也有人担心 MQA 对 KV Cache 的压缩太严重以至于会影响模型的学习效率以及最终结果。为此一个 MHA 和 MQA 之间的过渡版本 GQAGrouped-Query Attention应运而生出自论文 《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》是 23 年的工作
GQA 的思想也很朴素它就是将所有 Head 分为 g g g 个组 g g g 可以整除 h h h每组共享同一对 K、V用数学公式表示为 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( ⌈ s g / h ⌉ ) , v ≤ t ( ⌈ s g / h ⌉ ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( ⌈ s g / h ⌉ ) ⊤ ) v i ( ⌈ s g / h ⌉ ) ∑ i ≤ t exp ( q t ( s ) k i ( ⌈ s g / h ⌉ ) ⊤ ) q i ( s ) x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( ⌈ s g / h ⌉ ) x i W k ( ⌈ s g / h ⌉ ) ∈ R d k , W k ( ⌈ s g / h ⌉ ) ∈ R d × d k v i ( ⌈ s g / h ⌉ ) x i W v ( ⌈ s g / h ⌉ ) ∈ R d v , W v ( ⌈ s g / h ⌉ ) ∈ R d × d v \bm{o_{t}}\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{\color{red}{(\lceil sg/h \rceil)}}, \bm{v}_{\leq t}^{\color{red}{(\lceil sg/h \rceil)}}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{\color{red}{(\lceil sg/h \rceil)}}{}^{\top}\Bigr)\bm{v}_{i}^{\color{red}{(\lceil sg/h \rceil)}}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{\color{red}{(\lceil sg/h \rceil)}}{}^{\top}\Bigr)} \\ \bm{q_{i}^{(s)}x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{\color{red}{(\lceil sg/h \rceil)}}x_{i}W_{k}^{\color{red}{(\lceil sg/h \rceil)}}\in\mathbb{R}^{d_{k}},\quad W_{k}^{\color{red}{(\lceil sg/h \rceil)}}\in \mathbb{R}^{d\times d_{k}}}\\ \bm{v_{i}^{\color{red}{(\lceil sg/h \rceil)}}x_{i}W_{v}^{\color{red}{(\lceil sg/h \rceil)}}\in\mathbb{R}^{d_{v}},\quad W_{v}^{\color{red}{(\lceil sg/h \rceil)}}\in \mathbb{R}^{d\times d_{v}}} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(⌈sg/h⌉),v≤t(⌈sg/h⌉))≜∑i≤texp(qt(s)ki(⌈sg/h⌉)⊤)∑i≤texp(qt(s)ki(⌈sg/h⌉)⊤)vi(⌈sg/h⌉)qi(s)xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(⌈sg/h⌉)xiWk(⌈sg/h⌉)∈Rdk,Wk(⌈sg/h⌉)∈Rd×dkvi(⌈sg/h⌉)xiWv(⌈sg/h⌉)∈Rdv,Wv(⌈sg/h⌉)∈Rd×dv
其中 ⌈ ⋅ ⌉ \lceil\cdot\rceil ⌈⋅⌉ 是上取整符号
GQA 提供了从 MHA 到 MQA 的自然过渡当 g h gh gh 时就是 MHA当 g 1 g1 g1 时就是 MQA当 1 g h 1gh 1gh 时它只将 KV Cache 压缩到 g / h g/h g/h压缩率不如 MQA但同时也提供了更大的自由度效果上更有保证。GQA 最知名的使用者大概是 Meta 开源的 LLAMA2-70B以及 LLAMA3 全系列此外使用 GQA 的模型还有 TigerBot、DeepSeek-V1、StarCoder2、Yi、ChatGLM2、ChatGLM3 等相比使用 MQA 的模型更多 Note该图片来自于 https://github.com/preacher-1/MLA_tutorial
上图展示了标准 GQA 下的 KV Cache 是多大GQA 是 MHA 和 MQA 的一种折中它将 K , V K,V K,V 分成 group 组每组共享同一个 K , V K,V K,V此时 KV Cache 的大小是 2 ∗ s e q _ l e n ∗ n _ g r o u p s ∗ h e a d _ d i m 2*seq\_len * n\_groups * head\_dim 2∗seq_len∗n_groups∗head_dim
代码实现如下
import math
import torch
import torch.nn as nn# Grouped-Query Attention
class GQA(torch.nn.Module):def __init__(self, d_model, num_heads, num_groups):super().__init__()self.d_model d_modelself.num_heads num_headsself.num_groups num_groupsself.group_heads num_heads // num_groupsself.head_dim d_model // num_headsself.W_q nn.Linear(d_model, d_model)self.W_k nn.Linear(d_model, self.head_dim * num_groups)self.W_v nn.Linear(d_model, self.head_dim * num_groups)self.out_linear nn.Linear(d_model, d_model)def forward(self, x):bsz, seq_len, _ x.shape# Linear projections, each group share the same K, Vq self.W_q(x) # (bsz, seq_len, d_model)k self.W_k(x) # (bsz, seq_len, head_dim * num_groups)v self.W_v(x) # (bsz, seq_len, head_dim * num_groups)# Reshape for multi-head attention# (bsz, num_groups, gropus_head, seq_len, head_dim)q q.view(bsz, seq_len, self.num_groups, self.group_heads, self.head_dim).permute(0, 2, 3, 1, 4)k k.view(bsz, seq_len, self.num_groups, self.head_dim).transpose(1, 2) # (bsz, num_groups, seq_len, head_dim)v v.view(bsz, seq_len, self.num_groups, self.head_dim).transpose(1, 2) # (bsz, num_groups, seq_len, head_dim)k torch.unsqueeze(k, 2) # (bsz, num_groups, 1, seq_len, head_dim)v torch.unsqueeze(v, 2) # (bsz, num_groups, 1, seq_len, head_dim)# Attentionattention torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # (bsz, num_gropus, gropus_head, seq_len, seq_len)attention torch.softmax(attention, dim-1)# Outputoutput torch.matmul(attention, v) # (bsz, num_groups, gropus_head, seq_len, head_dim)output output.permute(0, 3, 1, 2, 4).contiguous().view(bsz, -1, self.d_model)# Linear projectionoutput self.out_linear(output)return output# Example usage
torch.manual_seed(10)
d_model 512
num_heads 8
num_groups 4
gqa GQA(d_model, num_heads, num_groups)
x torch.randn(32, 10, d_model) # (bsz, seq_len, d_mdeol)
output gqa(x)
print(output.shape) # (32, 10, 512)Note代码参考自https://github.com/preacher-1/MLA_tutorial
1.5 MLA
有了 MHA、MQA、GQA 的铺垫我们理解起 MLAMulti-head Latent Attention就相对容易一些了。DeepSeek-V2 的技术报告里是从低秩投影类似于 LoRA的角度引入 MLA 的但苏神认为低秩投影这个角度并不贴近本质MLA 的本质是低秩投影之后的工作
1.5.1 Part 1
GQA 在投影之后做了什么呢首先它将向量对半分两份分别作为 K、V然后每一份又均分为 g g g 份每一份复制 h / g h/g h/g 次以此来“凑”够 h h h 个 Attention Head 所需要的 K、V
我们知道分割、复制都是简单的线性变换所以 MLA 的第一个想法是将这些简单的线性变换换成一般的线性变换以增强模型的能力 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) x i W q ( s ) ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) c i W k ( s ) ∈ R d k , W k ( s ) ∈ R d c × d k v i ( s ) c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i x i W c ∈ R d c , W c ∈ R d × d c \bm{o_{t}}\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \begin{array}{l} \bm{q_{i}^{(s)}x_{i}W_{q}^{(s)}\in\mathbb{R}^{d_{k}}, \quad W_{q}^{(s)}\in\mathbb{R}^{d\times d_{k}}}\\ \bm{k_{i}^{(s)}c_{i}W_{k}^{(s)}\in\mathbb{R}^{d_{k}},\quad W_{k}^{(s)}\in \mathbb{R}^{d_c\times d_{k}}}\\ \bm{v_{i}^{(s)}c_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d_c\times d_{v}}} \end{array}\\ \bm{c_{i}x_{i}W_{c}\in\mathbb{R}^{d_{c}},\quad W_{c}\in \mathbb{R}^{d\times d_{c}}} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)xiWq(s)∈Rdk,Wq(s)∈Rd×dkki(s)ciWk(s)∈Rdk,Wk(s)∈Rdc×dkvi(s)ciWv(s)∈Rdv,Wv(s)∈Rdc×dvcixiWc∈Rdc,Wc∈Rd×dc 这里博主有些困惑那看了 博客分享从MHA、MQA、GQA到MLA 视频之后大概理解了 MLA 想要做的事情前面我们提到 GQA 的实现中有一些分割、复制的变换MLA 出于增强 GQA 性能的目的想要将这些简单的线性变换加上一些可学习的参数让其变成一般的线性变换 以下分析内容来自于博客分享从MHA、MQA、GQA到MLA MQA 和 GQA 的“升维”投影矩阵 原始的 GQA 先将输入 x i x_i xi 分别压缩到 g d k gd_k gdk 和 g d v g d_v gdv 维再复制 g g g 份得到可以直接和 q i q_i qi 相乘的 k i k_i ki 和 v i v_i vi。将 GQA 的投影矩阵记为 W c \boldsymbol{W}_c Wc那么有 c i x i W c ∈ R g ( d k d v ) , W c ∈ R d × g ( d k d v ) c i [ k i ( 1 ) , ⋯ , k i ( g ) , v i ( 1 ) , ⋯ , v i ( g ) ] [ c k i , c v i ] W s p l i t k [ I g d k 0 g d k ] W s p l i t v [ 0 g d v I g d v ] c k i c i W s p l i t k ∈ R g d k , c v i c i W s p l i t v ∈ R g d v \boldsymbol{c}_i \boldsymbol{x}_i \boldsymbol{W}_c \in \mathbb{R}^{g(d_kd_v)},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_kd_v)}\\\boldsymbol{c}_i [\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(g)}, \boldsymbol{v}_i^{(1)}, \cdots, \boldsymbol{v}_i^{(g)}] [\boldsymbol{ck}_i, \boldsymbol{cv}_i]\\ \boldsymbol{W}_{split}^k \begin{bmatrix}\boldsymbol{I}_{gd_k} \\\boldsymbol{0}_{gd_k}\end{bmatrix} \quad\boldsymbol{W}_{split}^v \begin{bmatrix}\boldsymbol{0}_{gd_v} \\\boldsymbol{I}_{gd_v}\end{bmatrix} \\\boldsymbol{ck}_i \boldsymbol{c}_i \boldsymbol{W}_{split}^k \in \mathbb{R}^{gd_k},\quad \boldsymbol{cv}_i \boldsymbol{c}_i \boldsymbol{W}_{split}^v \in \mathbb{R}^{gd_v}\\ cixiWc∈Rg(dkdv),Wc∈Rd×g(dkdv)ci[ki(1),⋯,ki(g),vi(1),⋯,vi(g)][cki,cvi]Wsplitk[Igdk0gdk]Wsplitv[0gdvIgdv]ckiciWsplitk∈Rgdk,cviciWsplitv∈Rgdv 这里 W s p l i t k , W s p l i t v \boldsymbol{W}_{split}^k,\boldsymbol{W}_{split}^v Wsplitk,Wsplitv 实现了形式上的分割操作得到 c k i , c v i \boldsymbol{ck}_i, \boldsymbol{cv}_i cki,cvi。下面我们将构造“复制”操作的投影矩阵 W k ∈ R g d k × h d k [ I d k I d k ⋯ I d k 0 d k 0 d k ⋯ 0 d k ⋯ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k I d k I d k ⋯ I d k ⋯ 0 d k 0 d k ⋯ 0 d k ⋮ ⋮ ⋱ ⋮ ⋮ ⋮ ⋱ ⋮ ⋱ ⋮ ⋮ ⋱ ⋮ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k ⋯ I d k I d k ⋯ I d k ] \boldsymbol{W}_k \in \mathbb{R}^{g d_k\times h d_k} \begin{bmatrix} \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \\ \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \\ \vdots \vdots \ddots \vdots \vdots \vdots \ddots \vdots \ddots \vdots \vdots \ddots \vdots \\ \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \cdots \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \end{bmatrix} Wk∈Rgdk×hdk Idk0dk⋮0dkIdk0dk⋮0dk⋯⋯⋱⋯Idk0dk⋮0dk0dkIdk⋮0dk0dkIdk⋮0dk⋯⋯⋱⋯0dkIdk⋮0dk⋯⋯⋱⋯0dk0dk⋮Idk0dk0dk⋮Idk⋯⋯⋱⋯0dk0dk⋮Idk 其中每行 I d k \boldsymbol{I}_{d_k} Idk 重复 h / g h/g h/g 遍代表从 groups 到 heads 的“放缩”倍数或“复制”次数一共有 g g g 行对应原来的 d i ( s ) d_i^{(s)} di(s)共 g g g个。 W v ∈ R g d v × h d v \boldsymbol{W}_v \in \mathbb{R}^{g d_v\times h d_v} Wv∈Rgdv×hdv 的形式与 W k \boldsymbol{W}_k Wk 相同故不赘述。将前者左乘 c k i \boldsymbol{ck}_i cki则有 c k i W k [ k i ( 1 ) , ⋯ , k i ( g ) ] ⋅ [ I d k I d k ⋯ I d k 0 d k 0 d k ⋯ 0 d k ⋯ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k I d k I d k ⋯ I d k ⋯ 0 d k 0 d k ⋯ 0 d k ⋮ ⋮ ⋱ ⋮ ⋮ ⋮ ⋱ ⋮ ⋱ ⋮ ⋮ ⋱ ⋮ 0 d k 0 d k ⋯ 0 d k 0 d k 0 d k ⋯ 0 d k ⋯ I d k I d k ⋯ I d k ] [ k i ( 1 ) I d k , k i ( 1 ) I d k , ⋯ , k i ( 1 ) I d k , k i ( 2 ) I d k , k i ( 2 ) I d k , ⋯ , k i ( 2 ) I d k , ⋯ , k i ( g ) I d k , k i ( g ) I d k , ⋯ , k i ( g ) I d k ] [ k i ( 1 ) , ⋯ , k i ( 1 ) , k i ( 2 ) , ⋯ , k i ( 2 ) , ⋯ , k i ( g ) , ⋯ , k i ( g ) ] ∈ R h d k \begin{aligned}\boldsymbol{ck}_i \boldsymbol{W}_k [\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(g)}] \cdot \begin{bmatrix} \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \\ \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \\ \vdots \vdots \ddots \vdots \vdots \vdots \ddots \vdots \ddots \vdots \vdots \ddots \vdots \\ \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \boldsymbol{0}_{d_k } \cdots \boldsymbol{0}_{d_k } \cdots \boldsymbol{I}_{d_k} \boldsymbol{I}_{d_k} \cdots \boldsymbol{I}_{d_k} \end{bmatrix}\\[\boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k}, \boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k}, \cdots,\boldsymbol{k}_i^{(1)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(2)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k},\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k},\cdots,\boldsymbol{k}_i^{(g)} \boldsymbol{I}_{d_k}]\\ [\boldsymbol{k}_i^{(1)}, \cdots, \boldsymbol{k}_i^{(1)},\boldsymbol{k}_i^{(2)}, \cdots, \boldsymbol{k}_i^{(2)},\cdots,\boldsymbol{k}_i^{(g)}, \cdots, \boldsymbol{k}_i^{(g)}] \in \mathbb{R}^{h d_k} \end{aligned} ckiWk[ki(1),⋯,ki(g)]⋅ Idk0dk⋮0dkIdk0dk⋮0dk⋯⋯⋱⋯Idk0dk⋮0dk0dkIdk⋮0dk0dkIdk⋮0dk⋯⋯⋱⋯0dkIdk⋮0dk⋯⋯⋱⋯0dk0dk⋮Idk0dk0dk⋮Idk⋯⋯⋱⋯0dk0dk⋮Idk [ki(1)Idk,ki(1)Idk,⋯,ki(1)Idk,ki(2)Idk,ki(2)Idk,⋯,ki(2)Idk,⋯,ki(g)Idk,ki(g)Idk,⋯,ki(g)Idk][ki(1),⋯,ki(1),ki(2),⋯,ki(2),⋯,ki(g),⋯,ki(g)]∈Rhdk 于是我们就得到了维度为 h d k h d_k hdk 的 k i \boldsymbol{k}_i ki其中每个 k i ( s ) \boldsymbol{k}_i^{(s)} ki(s) 都被复制了 h / g h/g h/g 次实现了“复制”操作。同理 c v i W v \boldsymbol{cv}_i \boldsymbol{W}_v cviWv 得到 v i \boldsymbol{v}_i vi维度为 h d v h d_v hdv。 这里我们讨论的都是单个 token 的行向量而对于实际输入序列其最后两个维度为 (seq_len, d)同样可以直接替换上面的单一向量。 在上面所构造的所有矩阵中最重要的是 W k \boldsymbol{W}_k Wk 和 W v \boldsymbol{W}_v Wv可以看出两者都是由若干单位矩阵组成的稀疏矩阵是 GQA 的分割、复制操作的矩阵形式描述那么正如苏神文章中所述我们可以将让这两个矩阵变成可学习的参数比如在“复制”过程中给每个头一个不同的权重这样理论上可以增强 GQA 的能力。 这就是 MLA 的思想它可以看作是 GQA 的一种改进在压缩到 c \boldsymbol{c} c 维之后又用一个上投影矩阵来恢复到更高的维度在 DeepSeek-V2 的技术报告中是先利用下投影矩阵 W D K V W^{DKV} WDKV 将隐藏层输入 h t \mathbf{h}_t ht 投影得到 c t K V \mathbf{c}_t^{KV} ctKV然后再用两个上投影矩阵 W U K , W U V W^{UK},W^{UV} WUK,WUV 将 c t K V \mathbf{c}_t^{KV} ctKV 还原得到 k t C , v t C \mathbf{k}_t^C,\mathbf{v}_t^C ktC,vtC 然而理论上这样是能增加模型能力但别忘了 GQA 的主要目的是减少 KV Cache出于节省计算和通信成本的考虑我们一般缓存的是投影后的 k i , v i \bm{k_{i}},\bm{v_{i}} ki,vi 而不是投影前的 c i \bm{c_{i}} ci 或 x i \bm{x_{i}} xi而 MLA 的这个做法通过不同的投影矩阵再次让所有的 K、V Head 都变得各不相同那么 KV Cache 的大小就恢复成跟 MHA 一样大了违背了 GQA 的初衷。
对此MLA 发现我们可以结合 Dot-Attention 的具体形式通过一个简单但不失巧妙的恒等变换来规避这个问题。首先在训练阶段还是照常进行此时优化空间不大然后在推理阶段我们利用 q t ( s ) k i ( s ) ⊤ ( x t W q ( s ) ) ( c i W k ( s ) ) ⊤ x t ( W q ( s ) W k ( s ) ⊤ ) c i ⊤ \bm{q_{t}^{(s)}}\bm{k_{i}^{(s)\top}}\left(\bm{x_{t}}\bm{W_{q}^{(s)}}\right) \left(\bm{c_{i}}\bm{W_{k}^{(s)}}\right)\bm{{}^{\top}}\bm{x_{t}}\left(\bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}}\right)\bm{c_{i}^{\top}} qt(s)ki(s)⊤(xtWq(s))(ciWk(s))⊤xt(Wq(s)Wk(s)⊤)ci⊤
这意味着推理阶段我们可以将 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 合并起来作为 Q 的投影矩阵那么 c i \bm{c_i} ci 则取代了原本的 k i \bm{k_{i}} ki
同理在 o t \bm{o_{t}} ot 后面我们还有一个投影矩阵于是 v i ( s ) c i W v ( s ) \bm{v_{i}^{(s)}c_{i}W_{v}^{(s)}} vi(s)ciWv(s) 的 W v ( s ) \bm{W_{v}^{(s)}} Wv(s) 也可以吸收到后面的投影矩阵中去于是等效地 v i \bm{v_{i}} vi 也可以用 c i \bm{c_i} ci 代替也就是说此时 KV Cache 只需要存下所有的 c i \bm{c_i} ci 就行而不至于存下所有的 k i ( s ) \bm{k_{i}^{(s)}} ki(s)、 v i ( s ) \bm{v_{i}^{(s)}} vi(s)。注意到 c i \bm{c_i} ci 跟 ( s ) ^{(s)} (s) 无关也就是说所有的头共享的即 MLA 在推理阶段它可以恒等变换为一个 MQA
再次强调我们的主题一直都是减少 KV Cache那到目前为止MLA 做到了什么呢答案是通过不同的投影矩阵来增强 GQA 的能力并且推理时可以保持同样大小的 KV Cache。那么反过来如果我们只需要跟 GQA 相近的能力那么是不是就可以再次减少 KV Cache 了换言之 d c d_c dc 没必要取 g ( d k d v ) g(d_kd_v) g(dkdv)而是取更小的值DeepSeek-V2 取了 512从而进一步压缩 KV Cache这就是 MLA 的核心思想
1.5.2 Part 2
一切似乎都很完美但到目前为止的 MLA 有一个难以绕开的缺陷—不兼容 RoPE旋转位置编码
关于 RoPE 大家感兴趣的可以看看RoPE旋转位置编码原理浅析
前面我们说了MLA 之所以能保持跟 GQA 一样大小的 KV Cache其关键一步是将 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 合并成一个跟位置无关的矩阵作为 Q 的投影矩阵但如果加了 RoPE 的话这一步就无法实现了。这是因为 RoPE 是一个跟位置相关的、 d k × d k d_k\times d_k dk×dk 的分块对角矩阵 R m \bm{\mathcal{R}_{m}} Rm满足 R m R n ⊤ R m − n \bm{\mathcal{R}_{m}}\bm{\mathcal{R}_{n}}^{\top}\bm{\mathcal{R}_{m-n}} RmRn⊤Rm−nMLA 加入 RoPE 之后会让 W q ( s ) W k ( s ) ⊤ \bm{W_{q }^{(s)}}\bm{W_{k}^{(s)\top}} Wq(s)Wk(s)⊤ 之间多插入了一项 R t − i \bm{\mathcal{R}_{t-i}} Rt−i q i ( s ) x i W q ( s ) R i , k i ( s ) c i W k ( s ) R i q t ( s ) k i ( s ) ⊤ ( x t W q ( s ) R i ) ( c i W k ( s ) R i ) ⊤ x t ( W q ( s ) R t − i W k ( s ) ⊤ ) c i ⊤ \bm{q_{i}^{(s)}}\bm{x_{i}W_{q}^{(s)}{\color{red}{\mathcal{R}_{i}}}}\quad,\quad\bm{k_{i}^{(s) }}\bm{c_{i}W_{k}^{(s)}{\color{red}{\mathcal{R}_{i}}}} \\ \bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\left(\bm{x}_{t}\bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{i}}}}\right)\left(\bm{c}_{i}\bm{W}_{k}^{(s)} {\color{red}{\bm{\mathcal{R}_{i}}}}\right) {}^{\top}\bm{x}_{t}\left(\bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{t-i}}}}\bm{W}_{k}^{(s)}{ }^{\top}\right)\bm{c}_{i}^{\top} qi(s)xiWq(s)Ri,ki(s)ciWk(s)Riqt(s)ki(s)⊤(xtWq(s)Ri)(ciWk(s)Ri)⊤xt(Wq(s)Rt−iWk(s)⊤)ci⊤
这里的 W q ( s ) R t − i W k ( s ) ⊤ \bm{W}_{q}^{(s)} {\color{red}{\bm{\mathcal{R}_{t-i}}}}\bm{W}_{k}^{(s)}{ }^{\top} Wq(s)Rt−iWk(s)⊤ 就无法合并为一个固定的投影矩阵了跟位置差 t − i t-i t−i 相关从而 MLA 的想法无法结合 RoPE 实现
最后发布的 MLA 通过将每个 Attention Head 的 Q、K 新增 d r d_r dr 个维度用来添加 RoPE其中 K 新增的维度每个 Head 共享 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) [ x i W q c ( s ) , x i W q r ( s ) R i ] ∈ R d k d r , W q c ( s ) ∈ R d × d k , W q r ( s ) ∈ R d × d r k i ( s ) [ c i W k c ( s ) , x i W k r ( s ) R i ] ∈ R d k d r , W k c ( s ) ∈ R d c × d k , W k r ( s ) ∈ R d × d r v i ( s ) c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i x i W c ∈ R d c , W c ∈ R d × d c \bm{o_{t}}\left[\bm{o_{t}^{(1)}},\bm{o_{t}^{(2)}},\cdots,\bm{o_{t}^{(h)}}\right] \\ \bm{o}_{t}^{(s)}\bm{Attention}\left(\bm{q}_{t}^{(s)},\bm{k}_{\leq t}^{(s)}, \bm{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\Bigl(\bm{q}_ {t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)\bm{v}_{i}^{(s)}}{\sum_{i\leq t}\exp \Bigl(\bm{q}_{t}^{(s)}\bm{k}_{i}^{(s)}{}^{\top}\Bigr)} \\ \bm{q}_{i}^{(s)}\left[\bm{x_{i}}\bm{W}_{qc}^{(s)}\bm{,x_{i}}\bm{W}_{qr}^{(s)} {\color{red}\bm{\mathcal{R}}_{i}}\right]\in\mathbb{R}^{d_{k}d_{r}}\bm{,}\quad\bm{W}_{qc}^ {(s)}\in\mathbb{R}^{d\times d_{k}}\bm{,W}_{qr}^{(s)}\in\mathbb{R}^{d\times d _{r}} \\ \bm{k}_{i}^{(s)}\left[\bm{c_{i}}\bm{W}_{kc}^{(s)}\bm{,x_{i}}\bm{W}_{kr}^{\color{red}\bcancel{(s)}} {\color{red}\bm{\mathcal{R}}_{i}}\right]\in\mathbb{R}^{d_{k}d_{r}}\bm{,}\quad\bm{W}_{kc}^ {(s)}\in\mathbb{R}^{d_c\times d_{k}}\bm{,W}_{kr}^{\color{red}\bcancel{(s)}}\in\mathbb{R}^{d\times d _{r}} \\ \bm{v_{i}^{(s)}c_{i}W_{v}^{(s)}\in\mathbb{R}^{d_{v}},\quad W_{v}^{(s)}\in \mathbb{R}^{d_c\times d_{v}}} \\ \bm{c_{i}x_{i}W_{c}\in\mathbb{R}^{d_{c}},\quad W_{c}\in \mathbb{R}^{d\times d_{c}}} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)[xiWqc(s),xiWqr(s)Ri]∈Rdkdr,Wqc(s)∈Rd×dk,Wqr(s)∈Rd×drki(s)[ciWkc(s),xiWkr(s) Ri]∈Rdkdr,Wkc(s)∈Rdc×dk,Wkr(s) ∈Rd×drvi(s)ciWv(s)∈Rdv,Wv(s)∈Rdc×dvcixiWc∈Rdc,Wc∈Rd×dc
这样一来没有 RoPE 的维度就可以重复 “Part 1” 的操作在推理时 KV Cache 只需要存 c i \bm{c_i} ci新增的带 RoPE 的维度就可以用来补充位置信息并且由于所有 Head 共享所以也就只有在 K Cache 这里增加了 d r d_r dr 个维度原论文取了 d r d k / 2 64 d_rd_k/264 drdk/264相比原本的 d c 512 d_c512 dc512增加的幅度不大
1.5.3 Part 3
最后有一个细节就是 MLA 的最终版本还将 Q 的输入也改为了低秩投影形式这与减少 KV Cache 无关主要是为了减少训练期间参数量和相应的梯度所占的显存 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) [ c i ′ W q c ( s ) , c i ′ W q r ( s ) R i ] ∈ R d k d r , W q c ( s ) ∈ R d c ′ , W q r ( s ) ∈ R d c ′ × d r k i ( s ) [ c i W k c ( s ) , x i W k r ( s ) R i ] ∈ R d k d r , W k c ( s ) ∈ R d c , W k r ( s ) ∈ R d × d r v i ( s ) c i W v ( s ) ∈ R d v , W v ( s ) ∈ R d c × d v c i ′ x i W c ′ ∈ R d c ′ , W c ′ ∈ R d × d c ′ c i x i W c ∈ R d c , W c ∈ R d × d c \begin{gathered} \boldsymbol{o}_t \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} \left[\boldsymbol{c}_i\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i\boldsymbol{W}_{qr}^{(s)}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_k d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c\times d_r}\\ \boldsymbol{k}_i^{(s)} \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_kd_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c}, \boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\ \boldsymbol{v}_i^{(s)} \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt] \boldsymbol{c}_i \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \boldsymbol{c}_i \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)[ci′Wqc(s),ci′Wqr(s)Ri]∈Rdkdr,Wqc(s)∈Rdc′,Wqr(s)∈Rdc′×drki(s)[ciWkc(s),xiWkr(s) Ri]∈Rdkdr,Wkc(s)∈Rdc,Wkr(s) ∈Rd×drvi(s)ciWv(s)∈Rdv,Wv(s)∈Rdc×dvci′xiWc′∈Rdc′,Wc′∈Rd×dc′cixiWc∈Rdc,Wc∈Rd×dc
注意 k i ( s ) \boldsymbol{k}_i^{(s)} ki(s) 中的第二项带 RoPE 的部分其输入还是 x i \boldsymbol{x}_i xi 而不是 c i \boldsymbol{c}_i ci这里保持了原论文的设置不是笔误 d c ′ d_c dc′ 原论文的取值是 1536跟 d c 512 d_c512 dc512 不同。
同时我们把带 RoPE 的 MHA 放在下面方便大家对比 o t [ o t ( 1 ) , o t ( 2 ) , ⋯ , o t ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , v ≤ t ( s ) ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) v i ( s ) ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) x i W q ( s ) R i ∈ R d k , W q ( s ) ∈ R d × d k k i ( s ) x i W k ( s ) R i ∈ R d k , W k ( s ) ∈ R d × d k v i ( s ) x i W v ( s ) ∈ R d v , W v ( s ) ∈ R d × d v \begin{gathered} \boldsymbol{o}_t \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)}\boldsymbol{x}_i\boldsymbol{W}_q^{(s)}{\color{red}\mathcal{R}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k} \\\boldsymbol{k}_i^{(s)}\boldsymbol{x}_i\boldsymbol{W}_k^{(s)}{\color{red}\mathcal{R}_i}\in\mathbb{R}^{d_k},\quad\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{v}_i^{(s)}\boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v} \end{gathered} ot[ot(1),ot(2),⋯,ot(h)]ot(s)Attention(qt(s),k≤t(s),v≤t(s))≜∑i≤texp(qt(s)ki(s)⊤)∑i≤texp(qt(s)ki(s)⊤)vi(s)qi(s)xiWq(s)Ri∈Rdk,Wq(s)∈Rd×dkki(s)xiWk(s)Ri∈Rdk,Wk(s)∈Rd×dkvi(s)xiWv(s)∈Rdv,Wv(s)∈Rd×dv
可以发现其实在训练阶段除了多了一步低秩投影以及只在部分维度加 RoPE 外MLA 与 Q、K 的 Head Size 由 d k d_k dk 换成 d k d r d_kd_r dkdr 的 MHA 基本无异
推理阶段的 MLA 则改为 o t [ o t ( 1 ) W v ( 1 ) , o t ( 2 ) W v ( 2 ) , ⋯ , o t ( h ) W v ( h ) ] o t ( s ) A t t e n t i o n ( q t ( s ) , k ≤ t ( s ) , c ≤ t ) ≜ ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) c i ∑ i ≤ t exp ( q t ( s ) k i ( s ) ⊤ ) q i ( s ) [ c i ′ W q c ( s ) W k c ( s ) ⊤ , c i ′ W q r ( s ) R i ] ∈ R d c d r k i ( s ) [ c i , x i W k r ( s ) R i ] ∈ R d c d r W q c ( s ) ∈ R d c ′ × d k , W k c ( s ) ∈ R d c × d k , W q r ( s ) ∈ R d c ′ × d r , W k r ( s ) ∈ R d × d r c i ′ x i W c ′ ∈ R d c ′ , W c ′ ∈ R d × d c ′ c i x i W c ∈ R d c , W c ∈ R d × d c \begin{gathered} \boldsymbol{o}_t \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt] \boldsymbol{o}_t^{(s)} Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{red}{\smash{\bcancel{(s)}}}} ,\boldsymbol{c}_{\leq t}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt] \boldsymbol{q}_i^{(s)} \left[\boldsymbol{c}_i\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i\boldsymbol{W}_{qr}^{(s)}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_c d_r}\\ \boldsymbol{k}_i^{\color{red}{\smash{\bcancel{(s)}}}} \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{red}{\smash{\bcancel{(s)}}}}{\color{red}{\boldsymbol{\mathcal{R}}_i}}\right]\in\mathbb{R}^{d_cd_r}\\ \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c\times d_r},\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\[10pt] \boldsymbol{c}_i \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \boldsymbol{c}_i \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\ \end{gathered} ot[ot(1)Wv(1),ot(2)Wv(2),⋯,ot(h)Wv(h)]ot(s)Attention(qt(s),k≤t(s) ,c≤t)≜∑i≤texp(qt(s)ki(s) ⊤)∑i≤texp(qt(s)ki(s) ⊤)ciqi(s)[ci′Wqc(s)Wkc(s)⊤,ci′Wqr(s)Ri]∈Rdcdrki(s) [ci,xiWkr(s) Ri]∈RdcdrWqc(s)∈Rdc′×dk,Wkc(s)∈Rdc×dk,Wqr(s)∈Rdc′×dr,Wkr(s) ∈Rd×drci′xiWc′∈Rdc′,Wc′∈Rd×dc′cixiWc∈Rdc,Wc∈Rd×dc
此时 Q、K 的 Head Size 变成了 d c d r d_c d_r dcdrV 的 Head Size 则变成了 d c d_c dc按照原论文的设置这是 d k d_k dk、 d v d_v dv 的 4 倍。所以实际上 MLA 在推理阶段做的这个转换虽然能有效减少KV Cache但其推理的计算量是增加的。
那为什么还能提高推理效率呢这又回到“瓶颈”一节所讨论的问题了我们可以将LLM的推理分两部分第一个 Token 的生成Prefill和后续每个 Token 的生成GenerationPrefill 阶段涉及到对输入所有 Token 的并行计算然后把对应的 KV Cache 存下来这部分对于计算、带宽和显存都是瓶颈MLA 虽然增大了计算量但 KV Cache 的减少也降低了显存和带宽的压力大家半斤八两但是 Generation 阶段由于每步只计算一个 Token实际上它更多的是带宽瓶颈和显存瓶颈因此 MLA 的引入理论上能明显提高 Generation 的速度。
还有一个细节充分体现了这个特性。一般的 LLM 架构参数满足 h d hd hd即 num_heads * head_size hidden_size但 DeepSeek-V2 不一样它 d k 128 , d 5120 d_k128,d5120 dk128,d5120但 h 128 h128 h128是一般设置的 3 倍这是因为 MLA 的 KV Cache 大小跟 h h h 无关增大 h h h 只会增加计算量和提升模型能力但不会增加 KV Cache所以不会带来速度瓶颈。
由于篇幅原因CSDN 对正文字数有限制真服了MLA 的代码实现我们放在下篇文章
结语 MLA 可以看作是 GQA 的优化通过投影矩阵的方式替换 GQA 中的分割、复制等线性变换操作并引入了一个恒等变换在推理阶段通过矩阵吸收来进一步压缩 KV Cache同时采用了一种混合方法通过新增维度来兼容 RoPE 旋转位置编码总的来说MLA 算得上一种非常实用的注意力变体 大家可以多看看苏神的文章来加深理解 参考
缓存与效果的极限拉扯从MHA、MQA、GQA到MLA博客分享从MHA、MQA、GQA到MLAhttps://github.com/preacher-1/MLA_tutorialhttps://github.com/deepseek-ai/DeepSeek-V3