来源:学长的的yuque ··· 我只是加以整理为笔记
Important
以一个具体的玩具配置,把一次 LLM 推理从 token 进入到第一个生成 token,再到 Step 2 的 KV Cache 复用,按维度一步步推一遍。所有形状都精确到张量级别。
0. 玩具配置(贯穿全文)
| 符号 | 含义 | 数值 |
|---|---|---|
| V | 词表大小 | 6000 |
| d_e | Token embedding 维度 | 4096 |
| d | 注意力层隐藏维度 | 5200 |
| h | 多头注意力头数 | 8 |
| d_k | 每个头的维度 d / h | 650 |
| d_ff | FFN 中间维度 | 20800 |
| d_last | 最后一层 hidden state 维度 | 3200 |
| B | batch size | 1 |
| L_p | prompt token 数 | 7 |
Prompt:「你好,请问一下」→ BPE 切分得到 7 个 token。
1. 输入端:从文本到第一层注意力的输入
-
BPE 分词:「你好,请问一下」→ 7 个 token id。
-
Embedding 查表:过
nn.Embedding(6000, 4096),得到 。 -
维度适配 Linear:因为第一层注意力期望的隐藏维度是 5200,所以接一个 把 embedding 投到 5200。
- 维度变化:
-
加上 batch 维(B=1)后,进入第一层注意力的张量形状为 。
Important
真正的 LLM 一般会让 embedding 维度 = 隐藏维度(直接复用),这里人为拆开是为了把「embedding 维度」和「注意力隐藏维度」两个概念讲清楚。
2. Prefill Phase(预填充阶段,Step 1)
Step 1 模型只能看到 prompt 的 7 个 token,目标是:一次性并行算完这 7 个 token 的所有层,把每层的 K/V 写入 KV Cache,并算出最后一个 token 的 logits 用来采样。
2.1 第一步:输入向量
第一层注意力的输入:。
2.2 第二步:计算 Q、K、V
三个独立线性变换 :
2.3 第三步:多头拆分
多头拆分本质上是「把同一个大 5200 维向量切成 8 个 650 维的小子空间,让每个子空间独立做一场注意力」。这一步没有任何乘加运算,全是张量的 内存布局重排(view / transpose)。
以 Q 为例(K、V 同理),起点是 :
-
Reshape:拆最后一维
-
操作:
Q.view(7, 8, 650),把最后那个 5200 维拆成(h, d_k) = (8, 650)。 -
形状变化:
-
内存上数字所在位置一个都没动,只是「怎么看它」变了:原本「第 i 个 token 的第 j 维」现在读为「第 i 个 token 的第 k 个头的第 m 维」(其中 )。
-
隐含约定: 训练出来后,沿输出维度的第 0–649 个分量学到了「head 0 的 query 表示」,650–1299 个是「head 1 的 query」……8 个头的参数其实是默默被下标切出来的。
-
-
Transpose:把 head 维提到前面
-
操作:
Q.transpose(0, 1),交换「seq_len 维」和「head 维」。 -
形状变化:
-
代价:这一步会让张量变「non-contiguous」(内存不连续),很多实现会接一个
.contiguous()重新拷贝一份,换取后续 matmul 的高性能。
-
-
为什么要 transpose?——是语义对齐,不是性能优化
-
关键事实:
torch.matmul/@的 定义 就是「只在最后两维上做矩阵乘,前面所有维独立循环」。这不是技巧,是这个 op 的契约。 -
所以「你想让哪两维参与矩阵乘」→「那两维必须物理上位于张量的最后两维」。reshape 后形状是 ,最后两维是 。不 transpose 直接乘,被乘的就是这个
(8, 650)子矩阵——语义上压根不是「每个 head 内部 token-token 相似度」,而是「head 之间 × d_k 之间」一个完全错位的运算。 -
transpose 把 head 挪走后,最后两维变成 ,这才是「每个 head 内部、token 之间」的那个子矩阵。
-
后续 自动并行 8 个 是「位置对了」之后的 附送福利。B、h 两维被 matmul 当独立循环维看——PyTorch 文档里也叫它们 batch dim,但语义上和你设的样本 batch B 不是一回事。
-
附记:用
torch.einsum('blhd,bshd->bhls', Q, K)这种明式指定轴的 API,压根不需要 transpose。所以 transpose 不是数学上必须的,只是 matmul 这种「靠轴位置」的 API 必须靠它适配。
-
-
加上 batch 维后的真实形状
- 实际实现里 B 维一直在,完整流程是:。
Important
一句话拽清楚:reshape 是「思想上拆 head」;transpose 是「把『我要参与矩阵乘的那两维』挪到最后两维,调适 matmul 的语义契约」。数字本身一个都没变。
2.4 第四步:KV Cache 的核心操作(Step 1 特有)
Step 1 时 KV Cache 是 完全空 的。一次性并行处理 7 个 token,得到完整的 K、V(),直接 写入并分配显存 中的 KV Cache:
-
Cache_K形状:(含 batch 维) -
Cache_V形状:
这就是「冻结的记忆」,为 Step 2、Step 3… 的流式生成提供历史上下文,避免重复计算这 7 个 token。
2.5 第五步:注意力分数 + 因果掩码
-
点积 + 缩放:
-
-
每个头都得到一张 的 token 互注意力矩阵。
-
-
Causal Mask:把上三角(对角线以上)设为 ,确保第 i 个 token 只能看 ≤ i 的 token。
-
Softmax:在最后一维做归一化, 处变 0,得到 的权重矩阵。
2.6 第六步:注意力加权 + 输出投影
-
加权求和:
-
多头拼接:
transpose+reshape回 -
输出投影 :得到
2.7 第七步:第一次残差
2.8 第八步:第一次 LayerNorm
Important
LayerNorm 是 对每个 token 独立做 的:在 5200 这个隐藏维度上算均值和方差并标准化,再乘以可学习的 、加 。7 个 token 之间互不干扰。
2.9 第九步:FFN(升维 → 激活 → 降维)
FFN 中 token 之间不通讯,每个 token 独自走两层 MLP。
-
升维:
-
非线性激活:,逐元素操作,维度不变。
-
降维:
2.10 第十步:第二次残差 + 第二次 LayerNorm
这就是第一层 Transformer Block 的最终输出,作为下一层的输入,逐层堆叠直到最后一层。
3. 输出端:从最后一层 hidden 到第一个 token
-
经过所有 Transformer Block 后,最后一层输出 。
-
生成时只需要最后一个位置(位置 7)的 hidden:取 。
-
LM Head:过 ,得到 logits 。
-
Softmax:转成 6000 维概率分布。
-
采样(greedy / top-k / top-p / temperature):得到 Step 1 的输出 token,比如假设是 token id = 21。
至此,Prefill 完成:KV Cache 装着 7 个 prompt token 的 K/V,模型吐出第一个回答 token。
4. Decoding Phase(解码阶段,Step 2)
Step 2 输入不再是 7 个 token,而是 仅 1 个:上一步采样出来的 token id = 21。
Important
没有 KV Cache 的话每步都要重新算前面所有 token;有了 KV Cache,Step 2 的计算量出现 断崖式下降。
4.1 第一步:输入向量的极简转换
-
Token id = 21 过 Embedding:
-
过维度适配 Linear:
4.2 第二步:只算这 1 个 token 的 Q、K、V
仍然乘以 :
-
-
拆头后变形为
4.3 第三步:KV Cache 的核心操作(拼接与读取)
这是 Decoding 的灵魂。 从显存中唤醒 Layer 1 的 KV Cache:
-
历史 Cache:
Cache_K,Cache_V(7 个 prompt token) -
Concat 写入:把当前 token 的 K、V()在 sequence 维度上拼接进去
-
更新后:
Cache_K,Cache_V(7 prompt + 1 生成)
4.4 第四步:非对称的注意力分数计算
Step 1 是 矩阵;Step 2 是 向量 × 矩阵。
-
点积:
-
-
含义:当前 token 对前面 7 个 prompt + 自己这 8 个位置的注意力分数。
-
注意:Q 只有 1 个,且本就在序列末端,不需要 Causal Mask。
-
-
Softmax:在最后一维归一化。
-
加权求和:
当前这 1 个 token 成功吸收了 8 个位置的全部上下文。
4.5 第五步:第一层收尾
-
多头拼接:
transpose+reshape回 -
残差 + LayerNorm:,再 LN
-
FFN:
- 由于只有 1 个 token,这里的矩阵乘法量比 Step 1 小 7 倍
-
再一次残差 + LayerNorm:完成 Layer 1,输出
5. Prefill vs Decoding:核心对比一张表
| 对比项 | Prefill (Step 1) | Decoding (Step 2+) |
|---|---|---|
| 输入 token 数 | L_p = 7(整个 prompt 并行) | 1(上一步采样出的 token) |
| Q 形状(每层每头) | ||
| 注意力矩阵形状 | ||
| 是否需要 Causal Mask | 是(下三角) | 否(Q 本就在末端) |
| KV Cache 操作 | 从空 → 写入 7 个 token 的 K/V | 读出历史 + concat 1 个新 K/V |
| FFN 计算量 | 7 × (5200 → 20800 → 5200) | 1 × (5200 → 20800 → 5200) |
| 瓶颈类型 | Compute-bound(算力受限) | Memory-bound(显存带宽受限) |
Important
一句话总结:Prefill 是「一次性把 prompt 嚼完,吐出 KV Cache + 第一个 token」;Decoding 是「每步只啃 1 个 token,但要把整张 KV Cache 当上下文来读」。两者算力 / 显存特征完全不同,这也是为什么生产部署里要把 Prefill 和 Decode 分开调度(PD 分离)。