来源:学长的的yuque ··· 我只是加以整理为笔记

Important

以一个具体的玩具配置,把一次 LLM 推理从 token 进入到第一个生成 token,再到 Step 2 的 KV Cache 复用,按维度一步步推一遍。所有形状都精确到张量级别。

0. 玩具配置(贯穿全文)

符号含义数值
V词表大小6000
d_eToken embedding 维度4096
d注意力层隐藏维度5200
h多头注意力头数8
d_k每个头的维度 d / h650
d_ffFFN 中间维度20800
d_last最后一层 hidden state 维度3200
Bbatch size1
L_pprompt token 数7

Prompt:「你好,请问一下」→ BPE 切分得到 7 个 token。


1. 输入端:从文本到第一层注意力的输入

  1. BPE 分词:「你好,请问一下」→ 7 个 token id。

  2. Embedding 查表:过 nn.Embedding(6000, 4096),得到

  3. 维度适配 Linear:因为第一层注意力期望的隐藏维度是 5200,所以接一个 把 embedding 投到 5200。

    • 维度变化:
  4. 加上 batch 维(B=1)后,进入第一层注意力的张量形状为

Important

真正的 LLM 一般会让 embedding 维度 = 隐藏维度(直接复用),这里人为拆开是为了把「embedding 维度」和「注意力隐藏维度」两个概念讲清楚。


2. Prefill Phase(预填充阶段,Step 1)

Step 1 模型只能看到 prompt 的 7 个 token,目标是:一次性并行算完这 7 个 token 的所有层,把每层的 K/V 写入 KV Cache,并算出最后一个 token 的 logits 用来采样

2.1 第一步:输入向量

第一层注意力的输入:

2.2 第二步:计算 Q、K、V

三个独立线性变换

2.3 第三步:多头拆分

多头拆分本质上是「把同一个大 5200 维向量切成 8 个 650 维的小子空间,让每个子空间独立做一场注意力」。这一步没有任何乘加运算,全是张量的 内存布局重排view / transpose)。

以 Q 为例(K、V 同理),起点是

  1. Reshape:拆最后一维

    • 操作:Q.view(7, 8, 650),把最后那个 5200 维拆成 (h, d_k) = (8, 650)

    • 形状变化:

    • 内存上数字所在位置一个都没动,只是「怎么看它」变了:原本「第 i 个 token 的第 j 维」现在读为「第 i 个 token 的第 k 个头的第 m 维」(其中 )。

    • 隐含约定: 训练出来后,沿输出维度的第 0–649 个分量学到了「head 0 的 query 表示」,650–1299 个是「head 1 的 query」……8 个头的参数其实是默默被下标切出来的

  2. Transpose:把 head 维提到前面

    • 操作:Q.transpose(0, 1),交换「seq_len 维」和「head 维」。

    • 形状变化:

    • 代价:这一步会让张量变「non-contiguous」(内存不连续),很多实现会接一个 .contiguous() 重新拷贝一份,换取后续 matmul 的高性能。

  3. 为什么要 transpose?——是语义对齐,不是性能优化

    • 关键事实:torch.matmul / @定义 就是「只在最后两维上做矩阵乘,前面所有维独立循环」。这不是技巧,是这个 op 的契约。

    • 所以「你想让哪两维参与矩阵乘」→「那两维必须物理上位于张量的最后两维」。reshape 后形状是 ,最后两维是 。不 transpose 直接乘,被乘的就是这个 (8, 650) 子矩阵——语义上压根不是「每个 head 内部 token-token 相似度」,而是「head 之间 × d_k 之间」一个完全错位的运算

    • transpose 把 head 挪走后,最后两维变成 ,这才是「每个 head 内部、token 之间」的那个子矩阵。

    • 后续 自动并行 8 个 是「位置对了」之后的 附送福利。B、h 两维被 matmul 当独立循环维看——PyTorch 文档里也叫它们 batch dim,但语义上和你设的样本 batch B 不是一回事。

    • 附记:用 torch.einsum('blhd,bshd->bhls', Q, K) 这种明式指定轴的 API,压根不需要 transpose。所以 transpose 不是数学上必须的,只是 matmul 这种「靠轴位置」的 API 必须靠它适配。

  4. 加上 batch 维后的真实形状

    • 实际实现里 B 维一直在,完整流程是:

Important

一句话拽清楚:reshape 是「思想上拆 head」;transpose 是「把『我要参与矩阵乘的那两维』挪到最后两维,调适 matmul 的语义契约」。数字本身一个都没变。

2.4 第四步:KV Cache 的核心操作(Step 1 特有)

Step 1 时 KV Cache 是 完全空 的。一次性并行处理 7 个 token,得到完整的 K、V(),直接 写入并分配显存 中的 KV Cache:

  • Cache_K 形状:(含 batch 维)

  • Cache_V 形状:

这就是「冻结的记忆」,为 Step 2、Step 3… 的流式生成提供历史上下文,避免重复计算这 7 个 token。

2.5 第五步:注意力分数 + 因果掩码

  1. 点积 + 缩放

    • 每个头都得到一张 的 token 互注意力矩阵。

  2. Causal Mask:把上三角(对角线以上)设为 ,确保第 i 个 token 只能看 ≤ i 的 token。

  3. Softmax:在最后一维做归一化, 处变 0,得到 的权重矩阵。

2.6 第六步:注意力加权 + 输出投影

  • 加权求和:

  • 多头拼接:transpose + reshape

  • 输出投影 :得到

2.7 第七步:第一次残差

2.8 第八步:第一次 LayerNorm

Important

LayerNorm 是 对每个 token 独立做 的:在 5200 这个隐藏维度上算均值和方差并标准化,再乘以可学习的 、加 。7 个 token 之间互不干扰。

2.9 第九步:FFN(升维 → 激活 → 降维)

FFN 中 token 之间不通讯,每个 token 独自走两层 MLP。

  1. 升维

  2. 非线性激活,逐元素操作,维度不变。

  3. 降维

2.10 第十步:第二次残差 + 第二次 LayerNorm

这就是第一层 Transformer Block 的最终输出,作为下一层的输入,逐层堆叠直到最后一层。


3. 输出端:从最后一层 hidden 到第一个 token

  1. 经过所有 Transformer Block 后,最后一层输出

  2. 生成时只需要最后一个位置(位置 7)的 hidden:取

  3. LM Head:过 ,得到 logits

  4. Softmax:转成 6000 维概率分布。

  5. 采样(greedy / top-k / top-p / temperature):得到 Step 1 的输出 token,比如假设是 token id = 21。

至此,Prefill 完成:KV Cache 装着 7 个 prompt token 的 K/V,模型吐出第一个回答 token。


4. Decoding Phase(解码阶段,Step 2)

Step 2 输入不再是 7 个 token,而是 仅 1 个:上一步采样出来的 token id = 21。

Important

没有 KV Cache 的话每步都要重新算前面所有 token;有了 KV Cache,Step 2 的计算量出现 断崖式下降

4.1 第一步:输入向量的极简转换

  1. Token id = 21 过 Embedding:

  2. 过维度适配 Linear:

4.2 第二步:只算这 1 个 token 的 Q、K、V

仍然乘以

  • 拆头后变形为

4.3 第三步:KV Cache 的核心操作(拼接与读取)

这是 Decoding 的灵魂。 从显存中唤醒 Layer 1 的 KV Cache:

  • 历史 Cache:Cache_K, Cache_V (7 个 prompt token)

  • Concat 写入:把当前 token 的 K、V()在 sequence 维度上拼接进去

  • 更新后:Cache_K, Cache_V (7 prompt + 1 生成)

4.4 第四步:非对称的注意力分数计算

Step 1 是 矩阵;Step 2 是 向量 × 矩阵

  1. 点积

    • 含义:当前 token 对前面 7 个 prompt + 自己这 8 个位置的注意力分数。

    • 注意:Q 只有 1 个,且本就在序列末端,不需要 Causal Mask

  2. Softmax:在最后一维归一化。

  3. 加权求和

当前这 1 个 token 成功吸收了 8 个位置的全部上下文。

4.5 第五步:第一层收尾

  1. 多头拼接transpose + reshape

  2. 残差 + LayerNorm,再 LN

  3. FFN

    • 由于只有 1 个 token,这里的矩阵乘法量比 Step 1 小 7 倍
  4. 再一次残差 + LayerNorm:完成 Layer 1,输出


5. Prefill vs Decoding:核心对比一张表

对比项Prefill (Step 1)Decoding (Step 2+)
输入 token 数L_p = 7(整个 prompt 并行)1(上一步采样出的 token)
Q 形状(每层每头)
注意力矩阵形状
是否需要 Causal Mask是(下三角)否(Q 本就在末端)
KV Cache 操作从空 → 写入 7 个 token 的 K/V读出历史 + concat 1 个新 K/V
FFN 计算量7 × (5200 → 20800 → 5200)1 × (5200 → 20800 → 5200)
瓶颈类型Compute-bound(算力受限)Memory-bound(显存带宽受限)

Important

一句话总结:Prefill 是「一次性把 prompt 嚼完,吐出 KV Cache + 第一个 token」;Decoding 是「每步只啃 1 个 token,但要把整张 KV Cache 当上下文来读」。两者算力 / 显存特征完全不同,这也是为什么生产部署里要把 Prefill 和 Decode 分开调度(PD 分离)。