文章目录
- 1. 位置编码概述
- 1.1 为什么需要位置编码?
- 2. 绝对位置编码 (Absolute Position Encoding)
- 2.1 原理
- 2.2 数学公式
- 2.3 代码实现
- 2.4 代码与公式的对应关系
- 2.5 特性与优势
- 2.6 可学习的绝对位置编码
- 3. 相对位置编码 (Relative Position Encoding)
- 3.1 原理
- 3.2 数学公式
- 3.3 Shaw et al. (2018) 相对位置编码
- 3.4 代码与公式的对应关系
- 3.5 特性与优势
- 3.6 带相对位置的注意力计算
- 4. RoPE (Rotary Position Embedding)
- 4.1 原理
- 4.2 数学公式
- 4.3 公式的详细推导
- 1. 旋转向量点积的展开
- 2. 合并第二项和第三项
- 3. 合并第一项和第四项
- 4. 验证相对位置依赖
- 4.5 代码实现
- 4.6 代码与公式的对应关系
- 4.7 特性与优势
- 4.8 带RoPE的注意力机制
- 5. 各种位置编码对比
- 5.1 特点对比
- 5.2 性能测试代码
- 6. 总结
1. 位置编码概述
位置编码是Transformer架构中的关键组件,用于为序列中的每个位置提供位置信息。由于自注意力机制本身是位置无关的,需要额外的位置信息来理解序列中元素的顺序。
1.1 为什么需要位置编码?
输入序列: "我 爱 中 国"[1] [2] [3] [4]没有位置编码:自注意力机制无法区分词语的顺序
有位置编码:每个位置都有唯一的位置标识
2. 绝对位置编码 (Absolute Position Encoding)
2.1 原理
绝对位置编码为序列中的每个位置分配一个唯一的编码向量。最经典的是Transformer论文中的正弦余弦位置编码。
2.2 数学公式
正弦位置编码的数学公式如下:
对于位置 pospospos 和维度 iii:
-
当 iii 为偶数时:
PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i)=sin(100002i/dmodelpos) -
当 iii 为奇数时:
PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i+1)=cos(100002i/dmodelpos)
其中:
- pospospos 是序列中的位置(范围:000 到 max_len−1max\_len-1max_len−1)
- iii 是编码向量中的维度索引(范围:000 到 dmodel−1d_{\text{model}}-1dmodel−1)
- dmodeld_{\text{model}}dmodel 是模型的嵌入维度
2.3 代码实现
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as pltclass SinusoidalPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.d_model = d_model# 创建位置编码矩阵pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1).float()# 计算除法项div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))# 应用正弦和余弦函数pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 注册为缓冲区(不参与梯度更新)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):# x shape: (batch_size, seq_len, d_model)seq_len = x.size(1)return x + self.pe[:, :seq_len]# 使用示例
d_model = 512
max_len = 100
pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)# 模拟输入
batch_size, seq_len = 2, 20
x = torch.randn(batch_size, seq_len, d_model)
output = pos_encoding(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
2.4 代码与公式的对应关系
-
创建位置编码矩阵:
pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1).float()
这里生成了一个形状为
(max_len, d_model)
的零矩阵,并准备好位置索引向量。 -
计算分母项:
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
这对应公式中的分母部分 100002i/dmodel10000^{2i/d_{\text{model}}}100002i/dmodel,通过指数和对数运算转换为:
exp(−log(10000)⋅2idmodel)\exp\left(-\frac{\log(10000) \cdot 2i}{d_{\text{model}}}\right)exp(−dmodellog(10000)⋅2i) -
应用正弦和余弦函数:
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度使用正弦 pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度使用余弦
这直接对应公式中的正弦和余弦部分,分别应用于偶数和奇数维度。
-
注册为缓冲区:
self.register_buffer('pe', pe.unsqueeze(0))
将位置编码注册为模型的缓冲区(不参与训练),并添加批次维度。
-
前向传播:
def forward(self, x):seq_len = x.size(1)return x + self.pe[:, :seq_len]
将位置编码加到输入张量上,只取与输入序列长度匹配的部分。
2.5 特性与优势
-
相对位置表示:正弦位置编码能够表达相对位置关系,因为对于任意固定偏移量 kkk,PEpos+kPE_{pos+k}PEpos+k 可以表示为 PEposPE_{pos}PEpos 的线性函数。
-
泛化能力:可以推广到比训练期间见过的更长的序列长度。
-
计算高效:无需学习参数,在推理时直接生成位置编码。
-
梯度稳定性:由于使用固定函数生成,不会影响模型训练的梯度流动。
2.6 可学习的绝对位置编码
class LearnablePositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.pos_embedding = nn.Embedding(max_len, d_model)self.max_len = max_lendef forward(self, x):batch_size, seq_len, _ = x.shapepositions = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)pos_encodings = self.pos_embedding(positions)return x + pos_encodings# 使用示例
learnable_pos = LearnablePositionalEncoding(d_model, max_len)
output_learnable = learnable_pos(x)
print(f"可学习位置编码输出形状: {output_learnable.shape}")
3. 相对位置编码 (Relative Position Encoding)
3.1 原理
相对位置编码关注的是位置之间的相对关系,而不是绝对位置。这种方法在处理长序列时表现更好。
3.2 数学公式
相对位置编码的核心思想是在注意力计算中引入相对位置信息。对于两个位置 iii 和 jjj,其相对位置为 k=i−jk = i - jk=i−j,编码公式主要体现在注意力得分的计算中:
-
标准自注意力公式(无相对位置):
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V -
加入相对位置编码的注意力公式:
Attention(Q,K,V)=softmax((Q+Rq)(K+Rk)Tdk)(V+Rv)\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{(Q + R_q)(K + R_k)^T}{\sqrt{d_k}}\right)(V + R_v)Attention(Q,K,V)=softmax(dk(Q+Rq)(K+Rk)T)(V+Rv)其中:
- RqR_qRq、RkR_kRk、RvR_vRv 分别是查询(Query)、键(Key)、值(Value)的相对位置编码矩阵;
- RkR_kRk 和 RvR_vRv 通常由相对位置索引 kkk 映射得到,即 Rk=Ek(k)R_k = E_k(k)Rk=Ek(k) 和 Rv=Ev(k)R_v = E_v(k)Rv=Ev(k),其中 EkE_kEk 和 EvE_vEv 是可学习的嵌入矩阵。
3.3 Shaw et al. (2018) 相对位置编码
class RelativePositionEncoding(nn.Module):def __init__(self, d_model, max_relative_position=50):super().__init__()self.d_model = d_modelself.max_relative_position = max_relative_position# 相对位置嵌入vocab_size = 2 * max_relative_position + 1self.relative_position_k = nn.Embedding(vocab_size, d_model)self.relative_position_v = nn.Embedding(vocab_size, d_model)def get_relative_positions(self, seq_len):"""生成相对位置矩阵"""range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)# 裁剪到最大相对位置distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)# 转换为正数索引final_mat = distance_mat_clipped + self.max_relative_positionreturn final_matdef forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)# 获取相对位置编码relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb# 使用示例
rel_pos_encoding = RelativePositionEncoding(d_model)
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)rel_k, rel_v = rel_pos_encoding(q, k, v)
print(f"相对位置编码K形状: {rel_k.shape}")
print(f"相对位置编码V形状: {rel_v.shape}")
3.4 代码与公式的对应关系
-
初始化嵌入层:
vocab_size = 2 * max_relative_position + 1 self.relative_position_k = nn.Embedding(vocab_size, d_model) self.relative_position_v = nn.Embedding(vocab_size, d_model)
vocab_size
对应所有可能的相对位置范围(从-max_relative_position
到+max_relative_position
);relative_position_k
和relative_position_v
分别对应公式中的 EkE_kEk 和 EvE_vEv,用于将相对位置索引映射为嵌入向量。
-
生成相对位置矩阵:
def get_relative_positions(self, seq_len):range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionreturn final_mat
- 生成的
distance_mat
是所有位置对 (i,j)(i,j)(i,j) 的相对距离矩阵(即 k=i−jk = i - jk=i−j); clamp
操作将相对距离限制在预设范围内,避免过远的位置影响;final_mat
将相对距离转换为非负索引(通过加上max_relative_position
),便于嵌入层查找。
- 生成的
-
获取相对位置编码:
def forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb
relative_position_k_emb
和relative_position_v_emb
分别对应公式中的 RkR_kRk 和 RvR_vRv;- 它们的形状均为
(seq_len, seq_len, d_model)
,表示任意两个位置之间的相对位置编码。
3.5 特性与优势
-
捕捉相对位置关系
相比绝对位置编码(如Sinusoidal PE),相对位置编码直接建模token对之间的距离,更适合捕捉序列中的结构信息(如语法依赖关系)。 -
参数高效
只需存储有限范围内的相对位置嵌入(通常为2*max_relative_position+1
个向量),而不是为每个绝对位置存储一个向量。 -
泛化能力
对于长度超过训练时所见的序列,仍能通过相对位置编码处理,而绝对位置编码可能超出预定义范围。 -
灵活应用
可选择性地应用于注意力机制的不同组件(如仅应用于Key,或同时应用于Key和Value),根据任务需求调整。 -
提升长序列性能
在长文本任务(如文档摘要、长对话生成)中,相对位置编码能更好地捕捉远距离依赖关系。
3.6 带相对位置的注意力计算
class RelativeMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_relative_position=50):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rel_pos_encoding = RelativePositionEncoding(self.d_k, max_relative_position)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 线性变换Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 计算注意力分数attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# 添加相对位置编码rel_k, rel_v = self.rel_pos_encoding(query, key, value)rel_k = rel_k.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1, 1)# 相对位置注意力rel_attention = torch.matmul(Q.unsqueeze(-2), rel_k.transpose(-2, -1)).squeeze(-2)attention_scores = attention_scores + rel_attention# 应用掩码if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)# Softmaxattention_weights = torch.softmax(attention_scores, dim=-1)# 应用注意力权重context = torch.matmul(attention_weights, V)# 重新整形并通过输出层context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights
4. RoPE (Rotary Position Embedding)
4.1 原理
RoPE(Rotary Positional Encoding)是一种基于旋转机制的位置编码方法,通过旋转向量空间来隐式表达token间的相对位置关系。它在Transformer模型中取得了显著效果,尤其是在长序列建模和语言理解任务中。
4.2 数学公式
RoPE的核心思想是通过旋转操作将位置信息直接融入到向量表示中。对于位置 mmm 处的向量 qmq_mqm 和位置 nnn 处的向量 knk_nkn,RoPE的计算公式如下:
-
旋转操作:
对于位置 mmm 和维度 ddd,将向量 qmq_mqm 旋转 θm\theta_mθm 角度:
RoPE(qm,m)d=qm⋅cos(mθd)+RotateHalf(qm)⋅sin(mθd)\text{RoPE}(q_m, m)_d = q_m \cdot \cos(m\theta_d) + \text{RotateHalf}(q_m) \cdot \sin(m\theta_d) RoPE(qm,m)d=qm⋅cos(mθd)+RotateHalf(qm)⋅sin(mθd)
其中:- θd=110000ddmodel\theta_d = \frac{1}{10000^{\frac{d}{d_{\text{model}}}}}θd=10000dmodeld1 是频率参数;
- RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 表示将向量 xxx 的前半部分与后半部分交换符号后拼接,即 [xd/2+1,xd/2+2,...,xd,−x1,−x2,...,−xd/2][x_{d/2+1}, x_{d/2+2}, ..., x_d, -x_1, -x_2, ..., -x_{d/2}][xd/2+1,xd/2+2,...,xd,−x1,−x2,...,−xd/2]。
-
在注意力机制中的应用:
RoPE通过以下方式改变注意力得分计算:Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \end{aligned} Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)
4.3 公式的详细推导
1. 旋转向量点积的展开
RoPE(qm,m)⋅RoPE(kn,n)=[qmcos(mθ)+RotateHalf(qm)sin(mθ)]⋅[kncos(nθ)+RotateHalf(kn)sin(nθ)]=(qmcos(mθ))⋅(kncos(nθ))+(qmcos(mθ))⋅(RotateHalf(kn)sin(nθ))+(RotateHalf(qm)sin(mθ))⋅(kncos(nθ))+(RotateHalf(qm)sin(mθ))⋅(RotateHalf(kn)sin(nθ))\begin{aligned} &\text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \\ =& \left[ q_m \cos(m\theta) + \text{RotateHalf}(q_m) \sin(m\theta) \right] \cdot \left[ k_n \cos(n\theta) + \text{RotateHalf}(k_n) \sin(n\theta) \right] \\ =& (q_m \cos(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (q_m \cos(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \end{aligned} ==RoPE(qm,m)⋅RoPE(kn,n)[qmcos(mθ)+RotateHalf(qm)sin(mθ)]⋅[kncos(nθ)+RotateHalf(kn)sin(nθ)](qmcos(mθ))⋅(kncos(nθ))+(qmcos(mθ))⋅(RotateHalf(kn)sin(nθ))+(RotateHalf(qm)sin(mθ))⋅(kncos(nθ))+(RotateHalf(qm)sin(mθ))⋅(RotateHalf(kn)sin(nθ))
2. 合并第二项和第三项
根据 RotateHalf\text{RotateHalf}RotateHalf 的正交性:
q⋅RotateHalf(k)=−RotateHalf(q)⋅kq \cdot \text{RotateHalf}(k) = -\text{RotateHalf}(q) \cdot k q⋅RotateHalf(k)=−RotateHalf(q)⋅k
以及,三角函数的角度差公式
sin(a−b)=sinacosb−cosasinb\sin(a-b) = \sin a \cos b - \cos a \sin b sin(a−b)=sinacosb−cosasinb
将上述展开式中的第二和第三项重写:
第二项=qm⋅RotateHalf(kn)⋅cos(mθ)sin(nθ)=−RotateHalf(qm)⋅kn⋅cos(mθ)sin(nθ)\begin{aligned} \text{第二项} &= q_m \cdot \text{RotateHalf}(k_n) \cdot \cos(m\theta)\sin(n\theta) \\ &= -\text{RotateHalf}(q_m) \cdot k_n \cdot \cos(m\theta)\sin(n\theta) \end{aligned} 第二项=qm⋅RotateHalf(kn)⋅cos(mθ)sin(nθ)=−RotateHalf(qm)⋅kn⋅cos(mθ)sin(nθ)
第三项=RotateHalf(qm)⋅kn⋅sin(mθ)cos(nθ)\begin{aligned} \text{第三项} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin(m\theta)\cos(n\theta) \end{aligned} 第三项=RotateHalf(qm)⋅kn⋅sin(mθ)cos(nθ)
将第二和第三项合并:
第二项+第三项=RotateHalf(qm)⋅kn⋅[sin(mθ)cos(nθ)−cos(mθ)sin(nθ)]=RotateHalf(qm)⋅kn⋅sin((m−n)θ)\begin{aligned} \text{第二项} + \text{第三项} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \left[ \sin(m\theta)\cos(n\theta) - \cos(m\theta)\sin(n\theta) \right] \\ &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} 第二项+第三项=RotateHalf(qm)⋅kn⋅[sin(mθ)cos(nθ)−cos(mθ)sin(nθ)]=RotateHalf(qm)⋅kn⋅sin((m−n)θ)
3. 合并第一项和第四项
第一项=qm⋅kn⋅cos(mθ)cos(nθ)第四项=RotateHalf(qm)⋅RotateHalf(kn)⋅sin(mθ)sin(nθ)\begin{aligned} \text{第一项} &= q_m \cdot k_n \cdot \cos(m\theta)\cos(n\theta) \\ \text{第四项} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第一项第四项=qm⋅kn⋅cos(mθ)cos(nθ)=RotateHalf(qm)⋅RotateHalf(kn)⋅sin(mθ)sin(nθ)
根据 RotateHalf\text{RotateHalf}RotateHalf 的旋转后点积不变性:
RotateHalf(q)⋅RotateHalf(k)=q⋅k\text{RotateHalf}(q) \cdot \text{RotateHalf}(k) = q \cdot k RotateHalf(q)⋅RotateHalf(k)=q⋅k
以及,三角函数的角度差公式
cos(a−b)=cosacosb+sinasinbcos(a-b) = \cos a \cos b + \sin a \sin bcos(a−b)=cosacosb+sinasinb
第四项=RotateHalf(qm)⋅RotateHalf(kn)⋅sin(mθ)sin(nθ)=qm⋅kn⋅sin(mθ)sin(nθ)\begin{aligned} \text{第四项} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta)\\ &= q_m \cdot k_n \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第四项=RotateHalf(qm)⋅RotateHalf(kn)⋅sin(mθ)sin(nθ)=qm⋅kn⋅sin(mθ)sin(nθ)
所以,可以将这两项合并为:
第一项+第四项=qm⋅kn⋅cos((m−n)θ)\begin{aligned} \text{第一项} + \text{第四项} &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) \end{aligned} 第一项+第四项=qm⋅kn⋅cos((m−n)θ)
最终,ROPE上述表达式可简化为:
Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)=qm⋅kn⋅cos((m−n)θ)+RotateHalf(qm)⋅kn⋅sin((m−n)θ)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n)\\ &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) + \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)=qm⋅kn⋅cos((m−n)θ)+RotateHalf(qm)⋅kn⋅sin((m−n)θ)
4. 验证相对位置依赖
最终表达式中的所有三角函数项均包含 (m−n)θ(m-n)\theta(m−n)θ,即只依赖于位置差 (m-n),而非单独的 (m) 或 (n)。这表明:
- 相对位置信息被隐式编码在注意力得分中
- 当 (m-n) 固定时,无论 (m) 和 (n) 的绝对位置如何变化,注意力得分保持不变
- 模型能够通过这种机制学习到序列中的相对距离关系
4.5 代码实现
class RoPEPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000, base=10000):super().__init__()self.d_model = d_modelself.max_len = max_lenself.base = base# 预计算频率inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))self.register_buffer('inv_freq', inv_freq)# 预计算位置编码self._build_cache(max_len)def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)# 计算sin和cossin_angles = torch.sin(angles)cos_angles = torch.cos(angles)# 存储缓存self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)def rotate_half(self, x):"""旋转向量的一半维度"""x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)def forward(self, x, seq_len=None):if seq_len is None:seq_len = x.shape[-2]# 获取sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 扩展维度以匹配输入if x.dim() == 4: # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重复sin和cos以匹配完整维度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 应用旋转return x * cos + self.rotate_half(x) * sin# 使用示例
rope = RoPEPositionalEncoding(d_model)
x_rope = torch.randn(batch_size, seq_len, d_model)
output_rope = rope(x_rope)
print(f"RoPE输出形状: {output_rope.shape}")
4.6 代码与公式的对应关系
-
预计算频率参数:
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model)) self.register_buffer('inv_freq', inv_freq)
这对应公式中的 θd=110000d/dmodel\theta_d = \frac{1}{10000^{d/d_{\text{model}}}}θd=10000d/dmodel1,用于生成不同维度的旋转频率。
-
预计算位置角度的sin和cos值:
def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)sin_angles = torch.sin(angles)cos_angles = torch.cos(angles)self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)
angles
矩阵对应 mθdm\theta_dmθd,即位置 mmm 在维度 ddd 上的旋转角度;sin_cached
和cos_cached
分别存储 sin(mθd)\sin(m\theta_d)sin(mθd) 和 cos(mθd)\cos(m\theta_d)cos(mθd),避免重复计算。
-
向量旋转操作:
def rotate_half(self, x):x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)
实现了 RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 操作,将向量后半部分取负后与前半部分拼接。
-
应用旋转位置编码:
def forward(self, x, seq_len=None):# 获取对应位置的sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 扩展维度以匹配输入if x.dim() == 4: # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重复以匹配完整维度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 应用旋转:x * cos + rotate_half(x) * sinreturn x * cos + self.rotate_half(x) * sin
这直接对应RoPE的核心公式:
RoPE(x,m)=x⋅cos(mθ)+RotateHalf(x)⋅sin(mθ)\text{RoPE}(x, m) = x \cdot \cos(m\theta) + \text{RotateHalf}(x) \cdot \sin(m\theta)RoPE(x,m)=x⋅cos(mθ)+RotateHalf(x)⋅sin(mθ)
4.7 特性与优势
-
隐式相对位置编码
RoPE通过旋转操作隐式地将相对位置信息融入注意力计算,使得模型能够更好地捕捉序列中的相对距离关系,优于传统的绝对位置编码。 -
旋转不变性
RoPE保证了位置编码的旋转不变性,即对于任意向量 xxx 和位置偏移 kkk,有:
RoPE(x,m)⋅RoPE(y,m+k)=RoPE(x,0)⋅RoPE(y,k)\text{RoPE}(x, m) \cdot \text{RoPE}(y, m+k) = \text{RoPE}(x, 0) \cdot \text{RoPE}(y, k)RoPE(x,m)⋅RoPE(y,m+k)=RoPE(x,0)⋅RoPE(y,k)
这使得模型在不同位置上具有一致的表示能力。 -
无需额外参数
RoPE不需要像可学习位置编码那样引入大量额外参数,只需预计算 sin\sinsin 和 cos\coscos 值,计算效率高。 -
长序列建模能力
实验表明,RoPE在长序列任务(如长文本生成、文档级NLP)中表现优于Sinusoidal PE和绝对位置编码,能够更有效地捕捉远距离依赖关系。 -
兼容性强
可以直接应用于现有的Transformer架构,无需修改模型的整体结构,易于集成到各种NLP系统中。
4.8 带RoPE的注意力机制
class RoPEMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_len=5000):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rope = RoPEPositionalEncoding(self.d_k, max_len)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 线性变换Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 应用RoPEQ = self.rope(Q)K = self.rope(K)# 计算注意力attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)attention_weights = torch.softmax(attention_scores, dim=-1)context = torch.matmul(attention_weights, V)# 重新整形context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights
5. 各种位置编码对比
5.1 特点对比
编码类型 | 优点 | 缺点 | 适用场景 |
---|---|---|---|
绝对位置编码 | 简单直观,计算效率高 | 对长序列泛化能力差 | 固定长度序列 |
相对位置编码 | 更好的泛化能力 | 计算复杂度高 | 需要处理可变长度序列 |
RoPE | 完美的长度外推能力 | 实现相对复杂 | 长序列,语言模型 |
5.2 性能测试代码
import timedef benchmark_position_encodings():batch_size, seq_len, d_model = 32, 512, 512num_heads = 8# 创建模型models = {'Sinusoidal': SinusoidalPositionalEncoding(d_model),'Learnable': LearnablePositionalEncoding(d_model),'RoPE': RoPEPositionalEncoding(d_model)}x = torch.randn(batch_size, seq_len, d_model)# 基准测试for name, model in models.items():start_time = time.time()for _ in range(100):with torch.no_grad():output = model(x)end_time = time.time()print(f"{name}: {(end_time - start_time) * 1000:.2f}ms")# 运行基准测试
benchmark_position_encodings()
6. 总结
位置编码是Transformer架构中的关键组件,不同类型的位置编码各有特点:
- 绝对位置编码:简单高效,适用于固定长度序列
- 相对位置编码:关注位置关系,泛化能力更强
- RoPE:通过旋转矩阵优雅地处理位置信息,支持长度外推
选择合适的位置编码方式需要根据具体应用场景和性能需求来决定。现代大语言模型(如GPT、LLaMA等)普遍采用RoPE,因为它在处理长序列时表现出色。