Decoder-Only整体结构
我们以模型Llama-3.1-8B-Instruct为例,打印其结构如下(后面会慢慢解析每一部分,莫慌):
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer((self_attn): LlamaAttention((qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)(o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)(rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)(attn): RadixAttention())(mlp): LlamaMLP((gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)(down_proj): RowParallelLinear(input_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)(act_fn): SiluAndMul())(input_layernorm): RMSNorm()(post_attention_layernorm): RMSNorm()))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
 
Decoder-Only处理流程
我们以Llama-3.1-8B-Instruct模型为例,结合一个具体的聊天对话场景,详细说明Decoder-Only模型的处理流程,从用户输入到最终输出回答。整个过程会逐步拆解,并标注每个步骤的输入输出形状(假设batch_size=1,seq_len=10,hidden_dim=4096,词表大小=128000)。
1. 用户输入与聊天模板处理
场景:用户问:“如何做西红柿炒鸡蛋?”
 模型需求:需要根据历史对话和当前问题生成回答。
聊天模板处理
- 输入文本text:原始用户输入(如“如何做西红柿炒鸡蛋?”)
 - 模板化prompt:模型需要将输入包装成特定格式的prompt,例如:
[系统指令]:你是一个烹饪助手,请回答以下问题。 [用户]:如何做西红柿炒鸡蛋? [助手]: - 作用:模板化prompt让模型明确任务目标(如回答问题),并模拟对话上下文。
 
输入输出形状:
- 输入文本长度:假设为10个字符(实际长度取决于具体输入)。
 - 模板化后的prompt长度:假设为30个字符(包含系统指令、用户问题和占位符)。
 
2. Tokenizer处理:从prompt到input_ids
步骤:
- Tokenization:将模板化prompt拆分为模型能理解的Token(如“西红柿”→“西红柿”,“炒”→“炒”)。
 - 映射到input_ids:每个Token被映射为对应的ID(例如,“西红柿”→1234,“炒”→5678)。
 
示例:
 假设模板化Prompt被拆分为10个Token,其input_ids为:
[101, 1234, 5678, 8901, 2345, 6789, 102, 3456, 7890, 102]
 
(其中101和102是特殊标记,如<BOS>和<EOS>,表示开始和结束)
输入输出形状:
input_ids的形状为(batch_size, seq_len)→(1, 10)attention_mask(可选)的形状为(1, 10),标记哪些位置是有效Token(1)或填充(0)。
3. 嵌入层:input_ids → hidden_states
步骤:
- Token Embedding:将input_ids映射为高维向量(如4096维)。
 - Positional Encoding:添加位置信息,让模型知道每个Token在序列中的位置。
 
示例:
- input_ids 
[101, 1234, 5678, ...]→ 隐藏状态hidden_states的形状为(1, 10, 4096)。 - 每个Token对应的向量包含其语义和位置信息(例如,“西红柿”对应的食物相关特征,以及它在句子中的位置)。
 
输入输出形状:
hidden_states的形状为(batch_size, seq_len, hidden_dim)→(1, 10, 4096)
4. Decoder Block处理:逐层计算
核心流程:
-  
Masked Self-Attention(带掩码的自注意力):
- 每个Token只能看到自己及之前的Token(防止“偷看”未来内容)。
 - 例如,在生成“西红柿炒鸡蛋”时,模型会先处理“西红柿”,再处理“炒”,确保生成逻辑连贯。
 
 -  
前馈网络(FFN):
- 对每个Token的隐藏状态进行非线性变换,增强表达能力。
 
 
示例:
- 假设模型有32层Decoder Block,每层都会更新 
hidden_states。 - 最终的 
hidden_states保留了完整的上下文信息(如“西红柿炒鸡蛋”的步骤描述)。 
输入输出形状:
- 每层Decoder Block的输入输出形状不变,仍为 
(1, 10, 4096) 
5. LM Head:从hidden_states到下一个词
步骤:
- 线性层:将最后一个Token的隐藏状态(形状为 
(1, 10, 4096))映射到词表维度(128000)。- 例如,对最后一个位置(
seq_len=9)的隐藏状态取值:hidden_states[:, 9, :]→ 形状(1, 4096)。 
 - 例如,对最后一个位置(
 - Softmax:将输出转换为概率分布(每个词的概率)。
 
示例:
- 假设模型预测下一个词是“步骤一”,其ID为9876,则概率分布中9876的值最高。
 
输入输出形状:
- 线性层输出形状:
(1, 128000) - 概率分布形状:
(1, 128000) 
6. 采样策略:从概率分布到下一个词
方法:
- Top-k采样:从概率最高的前k个词(如k=50)中随机选一个。
 - Greedy Search:直接选概率最高的词(如“步骤一”)。
 
示例:
- 模型选择“步骤一”作为下一个词,并将其ID(9876)添加到 
input_ids中。 - 新的 
input_ids变为:[101, 1234, 5678, ..., 9876](长度+1)。 
输入输出形状:
- 新的 
input_ids形状为(1, 11) 
7. 迭代生成:重复步骤3-6直到完成
流程:
- 将新的 
input_ids和hidden_states送回Decoder Block。 - 重复计算,逐步生成完整回答(如“步骤一:热锅凉油…”)。
 - 直到生成终止标记(如
<EOS>)或达到最大长度(如2048 Token)。 
示例:
- 生成完整回答后,
input_ids的长度可能变为200(假设生成190个新Token)。 - 最终的 
input_ids包含原始Prompt和生成的回答。 
8. Tokenizer反向处理:从input_ids到用户文本
步骤:
- 将生成的 
input_ids(含prompt和回答)截取回答部分(去掉prompt)。 - 使用Tokenizer将 
input_ids转换回自然语言文本(如“步骤一:热锅凉油…”)。 
输入输出形状:
- 截取后的 
input_ids形状为(1, 190) - 最终输出文本长度取决于生成内容(如“步骤一:热锅凉油…”)
 
总结流程图
用户输入 → 模板化Prompt → Tokenizer → input_ids (1,10)  → 嵌入层 → hidden_states (1,10,4096)  → Decoder Block ×32 → hidden_states (1,10,4096)  → LM Head → 概率分布 (1,128000)  → 采样 → 新input_ids (1,11)  → 重复生成 → input_ids (1,200)  → Tokenizer反向 → 用户文本
 
LlamaForCausalLM结构分析
以模型Llama-3.1-8B-Instruct为例,将一部分子结构信息折叠起来,将显示如下:
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer(...))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
 
可以看到LlamaForCausalLM主要由几个关键部分组成:model, lm_head, logits_processor和pooler。这几个组件作用各不相同,我们现在来介绍一下他们。
1. model:核心解码器结构
 
(1) embed_tokens:词嵌入层
 
- 作用:将输入的Token ID(如“西红柿”→ID=1234)映射为4096维的向量,表示Token的语义和位置信息。
 - 技术细节: 
- 使用VocabParallelEmbedding(并行词嵌入,仅需了解,无需深入),支持分布式训练。
 - 词表大小为128256,覆盖多语言和特殊符号(如
<BOS>、<EOS>)。 
 - 输入输出形状: 
- 输入:
(batch_size, seq_len)→(1, 10)(假设输入10个Token) - 输出:
(batch_size, seq_len, hidden_dim)→(1, 10, 4096) 
 - 输入:
 
(2) layers:32层Decoder Block
 
- 核心结构: 
- 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。 
- 查询(Q)、键(K)、值(V)的维度:
d_model=4096,num_heads=32,head_dim=128。 - GQA机制:将K/V头数减少为
num_key_value_heads=8,降低计算开销。 
 - 查询(Q)、键(K)、值(V)的维度:
 - 前馈网络(MLP):使用SwiGLU激活函数(Sigmoid + Gated Linear Unit),替代传统ReLU。 
- 输入:
4096维 → 中间层:11008维 → 输出:4096维。 
 - 输入:
 - 归一化:每层使用RMSNorm(均方根归一化),稳定训练并加速收敛。
 
 - 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。 
 - 输入输出形状: 
- 每层输入/输出:
(1, 10, 4096)(与输入形状一致) 
 - 每层输入/输出:
 
(3) norm:最终归一化层
 
- 作用:对32层Decoder Block的输出进行最后一次归一化,确保数值稳定性。
 - 技术细节: 
- 使用RMSNorm,无需计算均值,直接对向量的模长标准化。
 - 公式:
hidden_states = hidden_states / sqrt(variance + ε),其中ε=1e-6。 
 
2. lm_head:语言模型头部
 
- 作用:将最终的隐藏状态(
hidden_dim=4096)映射为词表大小(vocab_size=128256)的概率分布,预测下一个词。 - 技术细节: 
- 使用ParallelLMHead(并行线性层),加速大规模词表的计算。
 - 参数量:
4096 × 128256 ≈ 5.16B(占模型总参数量的约76%)。 
 - 输入输出形状: 
- 输入:
(1, 4096)(取最后一个位置的隐藏状态) - 输出:
(1, 128256)(每个词的概率值) 
 - 输入:
 
3. logits_processor:概率分布处理器
 
- 作用:对
lm_head输出的概率分布进行后处理,控制生成策略。 - 常用功能: 
- 温度调节(Temperature):降低温度(
<1)使输出更确定,升高温度(>1)增加多样性。 - Top-k/Top-p采样:从概率最高的
k个词或累积概率达p的词中随机选择,平衡质量和多样性。 - 重复惩罚(Repetition Penalty):抑制重复生成相同词(如避免“西红柿西红柿”)。
 
 - 温度调节(Temperature):降低温度(
 - 输入输出形状: 
- 输入:
(1, 128256)(原始概率分布) - 输出:
(1, 128256)(处理后的概率分布) 
 - 输入:
 
4. pooler:池化层
 
- 作用:将整个序列的隐藏状态压缩为固定长度的向量,用于下游任务(如分类、相似度计算)。
 - 技术细节: 
- 默认取第一个Token(如
<BOS>)的隐藏状态作为全局表示。 - 或使用平均池化/最大池化,但Llama 3.1通常直接取
<BOS>。 
 - 默认取第一个Token(如
 - 输入输出形状: 
- 输入:
(1, 10, 4096)(全序列隐藏状态) - 输出:
(1, 4096)(固定长度的全局向量) 
 - 输入:
 
总结:组件协同工作流程
- 输入处理:用户输入文本 → 模板化Prompt → 
embed_tokens→(1, 10, 4096) - 特征提取:32层Decoder Block → 
hidden_states→(1, 10, 4096) - 归一化:
norm→ 稳定输出 - 生成预测: 
lm_head→(1, 128256)概率分布logits_processor→ 调整概率分布- 采样生成下一个词 → 更新 
input_ids 
 - 迭代生成:重复步骤1-4,直到生成终止标记(
<EOS>)或达到最大长度。 - 任务适配:
pooler提取全局向量 → 用于分类、相似度等任务。 
model:像一个厨师,逐步处理食材(Token)并调整火候(注意力机制)。lm_head:厨师的“味觉”,决定下一步该加什么调料(预测下一个词)。logits_processor:厨房的“规则制定者”,确保菜谱不重复且口味可控。pooler:食客的“总结笔记”,用一句话概括整道菜的风味(全局语义)。