位置编码/绝对位置编码/相对位置编码/Rope原理+公式详细推导及代码实现

文章目录

    • 1. 位置编码概述
      • 1.1 为什么需要位置编码?
    • 2. 绝对位置编码 (Absolute Position Encoding)
      • 2.1 原理
      • 2.2 数学公式
      • 2.3 代码实现
      • 2.4 代码与公式的对应关系
      • 2.5 特性与优势
      • 2.6 可学习的绝对位置编码
    • 3. 相对位置编码 (Relative Position Encoding)
      • 3.1 原理
      • 3.2 数学公式
      • 3.3 Shaw et al. (2018) 相对位置编码
      • 3.4 代码与公式的对应关系
      • 3.5 特性与优势
      • 3.6 带相对位置的注意力计算
    • 4. RoPE (Rotary Position Embedding)
      • 4.1 原理
      • 4.2 数学公式
      • 4.3 公式的详细推导
        • 1. 旋转向量点积的展开
        • 2. 合并第二项和第三项
        • 3. 合并第一项和第四项
        • 4. 验证相对位置依赖
      • 4.5 代码实现
      • 4.6 代码与公式的对应关系
      • 4.7 特性与优势
      • 4.8 带RoPE的注意力机制
    • 5. 各种位置编码对比
      • 5.1 特点对比
      • 5.2 性能测试代码
    • 6. 总结

1. 位置编码概述

位置编码是Transformer架构中的关键组件,用于为序列中的每个位置提供位置信息。由于自注意力机制本身是位置无关的,需要额外的位置信息来理解序列中元素的顺序。

1.1 为什么需要位置编码?

输入序列: "我  爱  中   国"[1] [2] [3] [4]没有位置编码:自注意力机制无法区分词语的顺序
有位置编码:每个位置都有唯一的位置标识

2. 绝对位置编码 (Absolute Position Encoding)

2.1 原理

绝对位置编码为序列中的每个位置分配一个唯一的编码向量。最经典的是Transformer论文中的正弦余弦位置编码。

2.2 数学公式

正弦位置编码的数学公式如下:

对于位置 pospospos 和维度 iii

  • iii 为偶数时:
    PE(pos,2i)=sin⁡(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i)=sin(100002i/dmodelpos)

  • iii 为奇数时:
    PE(pos,2i+1)=cos⁡(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)PE(pos,2i+1)=cos(100002i/dmodelpos)

其中:

  • pospospos 是序列中的位置(范围:000max_len−1max\_len-1max_len1
  • iii 是编码向量中的维度索引(范围:000dmodel−1d_{\text{model}}-1dmodel1
  • dmodeld_{\text{model}}dmodel 是模型的嵌入维度

2.3 代码实现

import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as pltclass SinusoidalPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.d_model = d_model# 创建位置编码矩阵pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1).float()# 计算除法项div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))# 应用正弦和余弦函数pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)# 注册为缓冲区(不参与梯度更新)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):# x shape: (batch_size, seq_len, d_model)seq_len = x.size(1)return x + self.pe[:, :seq_len]# 使用示例
d_model = 512
max_len = 100
pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)# 模拟输入
batch_size, seq_len = 2, 20
x = torch.randn(batch_size, seq_len, d_model)
output = pos_encoding(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")

2.4 代码与公式的对应关系

  1. 创建位置编码矩阵

    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    

    这里生成了一个形状为 (max_len, d_model) 的零矩阵,并准备好位置索引向量。

  2. 计算分母项

    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
    

    这对应公式中的分母部分 100002i/dmodel10000^{2i/d_{\text{model}}}100002i/dmodel,通过指数和对数运算转换为:
    exp⁡(−log⁡(10000)⋅2idmodel)\exp\left(-\frac{\log(10000) \cdot 2i}{d_{\text{model}}}\right)exp(dmodellog(10000)2i)

  3. 应用正弦和余弦函数

    pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维度使用正弦
    pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维度使用余弦
    

    这直接对应公式中的正弦和余弦部分,分别应用于偶数和奇数维度。

  4. 注册为缓冲区

    self.register_buffer('pe', pe.unsqueeze(0))
    

    将位置编码注册为模型的缓冲区(不参与训练),并添加批次维度。

  5. 前向传播

    def forward(self, x):seq_len = x.size(1)return x + self.pe[:, :seq_len]
    

    将位置编码加到输入张量上,只取与输入序列长度匹配的部分。

2.5 特性与优势

  1. 相对位置表示:正弦位置编码能够表达相对位置关系,因为对于任意固定偏移量 kkkPEpos+kPE_{pos+k}PEpos+k 可以表示为 PEposPE_{pos}PEpos 的线性函数。

  2. 泛化能力:可以推广到比训练期间见过的更长的序列长度。

  3. 计算高效:无需学习参数,在推理时直接生成位置编码。

  4. 梯度稳定性:由于使用固定函数生成,不会影响模型训练的梯度流动。

2.6 可学习的绝对位置编码

class LearnablePositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000):super().__init__()self.pos_embedding = nn.Embedding(max_len, d_model)self.max_len = max_lendef forward(self, x):batch_size, seq_len, _ = x.shapepositions = torch.arange(seq_len, device=x.device).unsqueeze(0).repeat(batch_size, 1)pos_encodings = self.pos_embedding(positions)return x + pos_encodings# 使用示例
learnable_pos = LearnablePositionalEncoding(d_model, max_len)
output_learnable = learnable_pos(x)
print(f"可学习位置编码输出形状: {output_learnable.shape}")

3. 相对位置编码 (Relative Position Encoding)

3.1 原理

相对位置编码关注的是位置之间的相对关系,而不是绝对位置。这种方法在处理长序列时表现更好。

3.2 数学公式

相对位置编码的核心思想是在注意力计算中引入相对位置信息。对于两个位置 iiijjj,其相对位置为 k=i−jk = i - jk=ij,编码公式主要体现在注意力得分的计算中:

  1. 标准自注意力公式(无相对位置):
    Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V

  2. 加入相对位置编码的注意力公式
    Attention(Q,K,V)=softmax((Q+Rq)(K+Rk)Tdk)(V+Rv)\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{(Q + R_q)(K + R_k)^T}{\sqrt{d_k}}\right)(V + R_v)Attention(Q,K,V)=softmax(dk(Q+Rq)(K+Rk)T)(V+Rv)

    其中:

    • RqR_qRqRkR_kRkRvR_vRv 分别是查询(Query)、键(Key)、值(Value)的相对位置编码矩阵;
    • RkR_kRkRvR_vRv 通常由相对位置索引 kkk 映射得到,即 Rk=Ek(k)R_k = E_k(k)Rk=Ek(k)Rv=Ev(k)R_v = E_v(k)Rv=Ev(k),其中 EkE_kEkEvE_vEv 是可学习的嵌入矩阵。

3.3 Shaw et al. (2018) 相对位置编码

class RelativePositionEncoding(nn.Module):def __init__(self, d_model, max_relative_position=50):super().__init__()self.d_model = d_modelself.max_relative_position = max_relative_position# 相对位置嵌入vocab_size = 2 * max_relative_position + 1self.relative_position_k = nn.Embedding(vocab_size, d_model)self.relative_position_v = nn.Embedding(vocab_size, d_model)def get_relative_positions(self, seq_len):"""生成相对位置矩阵"""range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)# 裁剪到最大相对位置distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)# 转换为正数索引final_mat = distance_mat_clipped + self.max_relative_positionreturn final_matdef forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)# 获取相对位置编码relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb# 使用示例
rel_pos_encoding = RelativePositionEncoding(d_model)
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)rel_k, rel_v = rel_pos_encoding(q, k, v)
print(f"相对位置编码K形状: {rel_k.shape}")
print(f"相对位置编码V形状: {rel_v.shape}")

3.4 代码与公式的对应关系

  1. 初始化嵌入层

    vocab_size = 2 * max_relative_position + 1
    self.relative_position_k = nn.Embedding(vocab_size, d_model)
    self.relative_position_v = nn.Embedding(vocab_size, d_model)
    
    • vocab_size 对应所有可能的相对位置范围(从 -max_relative_position+max_relative_position);
    • relative_position_krelative_position_v 分别对应公式中的 EkE_kEkEvE_vEv,用于将相对位置索引映射为嵌入向量。
  2. 生成相对位置矩阵

    def get_relative_positions(self, seq_len):range_vec = torch.arange(seq_len)range_mat = range_vec.unsqueeze(0).repeat(seq_len, 1)distance_mat = range_mat - range_mat.transpose(0, 1)distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)final_mat = distance_mat_clipped + self.max_relative_positionreturn final_mat
    
    • 生成的 distance_mat 是所有位置对 (i,j)(i,j)(i,j) 的相对距离矩阵(即 k=i−jk = i - jk=ij);
    • clamp 操作将相对距离限制在预设范围内,避免过远的位置影响;
    • final_mat 将相对距离转换为非负索引(通过加上 max_relative_position),便于嵌入层查找。
  3. 获取相对位置编码

    def forward(self, query, key, value):seq_len = query.size(1)relative_positions = self.get_relative_positions(seq_len)relative_position_k_emb = self.relative_position_k(relative_positions)relative_position_v_emb = self.relative_position_v(relative_positions)return relative_position_k_emb, relative_position_v_emb
    
    • relative_position_k_embrelative_position_v_emb 分别对应公式中的 RkR_kRkRvR_vRv
    • 它们的形状均为 (seq_len, seq_len, d_model),表示任意两个位置之间的相对位置编码。

3.5 特性与优势

  1. 捕捉相对位置关系
    相比绝对位置编码(如Sinusoidal PE),相对位置编码直接建模token对之间的距离,更适合捕捉序列中的结构信息(如语法依赖关系)。

  2. 参数高效
    只需存储有限范围内的相对位置嵌入(通常为 2*max_relative_position+1 个向量),而不是为每个绝对位置存储一个向量。

  3. 泛化能力
    对于长度超过训练时所见的序列,仍能通过相对位置编码处理,而绝对位置编码可能超出预定义范围。

  4. 灵活应用
    可选择性地应用于注意力机制的不同组件(如仅应用于Key,或同时应用于Key和Value),根据任务需求调整。

  5. 提升长序列性能
    在长文本任务(如文档摘要、长对话生成)中,相对位置编码能更好地捕捉远距离依赖关系。

3.6 带相对位置的注意力计算

class RelativeMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_relative_position=50):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rel_pos_encoding = RelativePositionEncoding(self.d_k, max_relative_position)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 线性变换Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 计算注意力分数attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)# 添加相对位置编码rel_k, rel_v = self.rel_pos_encoding(query, key, value)rel_k = rel_k.unsqueeze(0).unsqueeze(0).repeat(batch_size, self.num_heads, 1, 1, 1)# 相对位置注意力rel_attention = torch.matmul(Q.unsqueeze(-2), rel_k.transpose(-2, -1)).squeeze(-2)attention_scores = attention_scores + rel_attention# 应用掩码if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)# Softmaxattention_weights = torch.softmax(attention_scores, dim=-1)# 应用注意力权重context = torch.matmul(attention_weights, V)# 重新整形并通过输出层context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights

4. RoPE (Rotary Position Embedding)

4.1 原理

RoPE(Rotary Positional Encoding)是一种基于旋转机制的位置编码方法,通过旋转向量空间来隐式表达token间的相对位置关系。它在Transformer模型中取得了显著效果,尤其是在长序列建模和语言理解任务中。

4.2 数学公式

RoPE的核心思想是通过旋转操作将位置信息直接融入到向量表示中。对于位置 mmm 处的向量 qmq_mqm 和位置 nnn 处的向量 knk_nkn,RoPE的计算公式如下:

  1. 旋转操作
    对于位置 mmm 和维度 ddd,将向量 qmq_mqm 旋转 θm\theta_mθm 角度:
    RoPE(qm,m)d=qm⋅cos⁡(mθd)+RotateHalf(qm)⋅sin⁡(mθd)\text{RoPE}(q_m, m)_d = q_m \cdot \cos(m\theta_d) + \text{RotateHalf}(q_m) \cdot \sin(m\theta_d) RoPE(qm,m)d=qmcos(mθd)+RotateHalf(qm)sin(mθd)
    其中:

    • θd=110000ddmodel\theta_d = \frac{1}{10000^{\frac{d}{d_{\text{model}}}}}θd=10000dmodeld1 是频率参数;
    • RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 表示将向量 xxx 的前半部分与后半部分交换符号后拼接,即 [xd/2+1,xd/2+2,...,xd,−x1,−x2,...,−xd/2][x_{d/2+1}, x_{d/2+2}, ..., x_d, -x_1, -x_2, ..., -x_{d/2}][xd/2+1,xd/2+2,...,xd,x1,x2,...,xd/2]
  2. 在注意力机制中的应用
    RoPE通过以下方式改变注意力得分计算:

    Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \end{aligned} Attention(qm,kn)=RoPE(qm,m)RoPE(kn,n)

4.3 公式的详细推导

1. 旋转向量点积的展开

RoPE(qm,m)⋅RoPE(kn,n)=[qmcos⁡(mθ)+RotateHalf(qm)sin⁡(mθ)]⋅[kncos⁡(nθ)+RotateHalf(kn)sin⁡(nθ)]=(qmcos⁡(mθ))⋅(kncos⁡(nθ))+(qmcos⁡(mθ))⋅(RotateHalf(kn)sin⁡(nθ))+(RotateHalf(qm)sin⁡(mθ))⋅(kncos⁡(nθ))+(RotateHalf(qm)sin⁡(mθ))⋅(RotateHalf(kn)sin⁡(nθ))\begin{aligned} &\text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) \\ =& \left[ q_m \cos(m\theta) + \text{RotateHalf}(q_m) \sin(m\theta) \right] \cdot \left[ k_n \cos(n\theta) + \text{RotateHalf}(k_n) \sin(n\theta) \right] \\ =& (q_m \cos(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (q_m \cos(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (k_n \cos(n\theta)) \\ &+ (\text{RotateHalf}(q_m) \sin(m\theta)) \cdot (\text{RotateHalf}(k_n) \sin(n\theta)) \end{aligned} ==RoPE(qm,m)RoPE(kn,n)[qmcos(mθ)+RotateHalf(qm)sin(mθ)][kncos(nθ)+RotateHalf(kn)sin(nθ)](qmcos(mθ))(kncos(nθ))+(qmcos(mθ))(RotateHalf(kn)sin(nθ))+(RotateHalf(qm)sin(mθ))(kncos(nθ))+(RotateHalf(qm)sin(mθ))(RotateHalf(kn)sin(nθ))

2. 合并第二项和第三项

根据 RotateHalf\text{RotateHalf}RotateHalf 的正交性:
q⋅RotateHalf(k)=−RotateHalf(q)⋅kq \cdot \text{RotateHalf}(k) = -\text{RotateHalf}(q) \cdot k qRotateHalf(k)=RotateHalf(q)k
以及,三角函数的角度差公式
sin⁡(a−b)=sin⁡acos⁡b−cos⁡asin⁡b\sin(a-b) = \sin a \cos b - \cos a \sin b sin(ab)=sinacosbcosasinb

将上述展开式中的第二和第三项重写:
第二项=qm⋅RotateHalf(kn)⋅cos⁡(mθ)sin⁡(nθ)=−RotateHalf(qm)⋅kn⋅cos⁡(mθ)sin⁡(nθ)\begin{aligned} \text{第二项} &= q_m \cdot \text{RotateHalf}(k_n) \cdot \cos(m\theta)\sin(n\theta) \\ &= -\text{RotateHalf}(q_m) \cdot k_n \cdot \cos(m\theta)\sin(n\theta) \end{aligned} 第二项=qmRotateHalf(kn)cos(mθ)sin(nθ)=RotateHalf(qm)kncos(mθ)sin(nθ)

第三项=RotateHalf(qm)⋅kn⋅sin⁡(mθ)cos⁡(nθ)\begin{aligned} \text{第三项} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin(m\theta)\cos(n\theta) \end{aligned} 第三项=RotateHalf(qm)knsin(mθ)cos(nθ)

将第二和第三项合并:
第二项+第三项=RotateHalf(qm)⋅kn⋅[sin⁡(mθ)cos⁡(nθ)−cos⁡(mθ)sin⁡(nθ)]=RotateHalf(qm)⋅kn⋅sin⁡((m−n)θ)\begin{aligned} \text{第二项} + \text{第三项} &= \text{RotateHalf}(q_m) \cdot k_n \cdot \left[ \sin(m\theta)\cos(n\theta) - \cos(m\theta)\sin(n\theta) \right] \\ &= \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} 第二项+第三项=RotateHalf(qm)kn[sin(mθ)cos(nθ)cos(mθ)sin(nθ)]=RotateHalf(qm)knsin((mn)θ)

3. 合并第一项和第四项

第一项=qm⋅kn⋅cos⁡(mθ)cos⁡(nθ)第四项=RotateHalf(qm)⋅RotateHalf(kn)⋅sin⁡(mθ)sin⁡(nθ)\begin{aligned} \text{第一项} &= q_m \cdot k_n \cdot \cos(m\theta)\cos(n\theta) \\ \text{第四项} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第一项第四项=qmkncos(mθ)cos(nθ)=RotateHalf(qm)RotateHalf(kn)sin(mθ)sin(nθ)

根据 RotateHalf\text{RotateHalf}RotateHalf 的旋转后点积不变性:
RotateHalf(q)⋅RotateHalf(k)=q⋅k\text{RotateHalf}(q) \cdot \text{RotateHalf}(k) = q \cdot k RotateHalf(q)RotateHalf(k)=qk
以及,三角函数的角度差公式
cos(a−b)=cos⁡acos⁡b+sin⁡asin⁡bcos(a-b) = \cos a \cos b + \sin a \sin bcos(ab)=cosacosb+sinasinb
第四项=RotateHalf(qm)⋅RotateHalf(kn)⋅sin⁡(mθ)sin⁡(nθ)=qm⋅kn⋅sin⁡(mθ)sin⁡(nθ)\begin{aligned} \text{第四项} &= \text{RotateHalf}(q_m) \cdot \text{RotateHalf}(k_n) \cdot \sin(m\theta)\sin(n\theta)\\ &= q_m \cdot k_n \cdot \sin(m\theta)\sin(n\theta) \end{aligned} 第四项=RotateHalf(qm)RotateHalf(kn)sin(mθ)sin(nθ)=qmknsin(mθ)sin(nθ)
所以,可以将这两项合并为:
第一项+第四项=qm⋅kn⋅cos⁡((m−n)θ)\begin{aligned} \text{第一项} + \text{第四项} &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) \end{aligned} 第一项+第四项=qmkncos((mn)θ)

最终,ROPE上述表达式可简化为:
Attention(qm,kn)=RoPE(qm,m)⋅RoPE(kn,n)=qm⋅kn⋅cos⁡((m−n)θ)+RotateHalf(qm)⋅kn⋅sin⁡((m−n)θ)\begin{aligned} \text{Attention}(q_m, k_n) &= \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n)\\ &= q_m \cdot k_n \cdot \cos\left((m-n)\theta\right) + \text{RotateHalf}(q_m) \cdot k_n \cdot \sin\left((m-n)\theta\right) \end{aligned} Attention(qm,kn)=RoPE(qm,m)RoPE(kn,n)=qmkncos((mn)θ)+RotateHalf(qm)knsin((mn)θ)

4. 验证相对位置依赖

最终表达式中的所有三角函数项均包含 (m−n)θ(m-n)\theta(mn)θ,即只依赖于位置差 (m-n),而非单独的 (m) 或 (n)。这表明:

  • 相对位置信息被隐式编码在注意力得分中
  • 当 (m-n) 固定时,无论 (m) 和 (n) 的绝对位置如何变化,注意力得分保持不变
  • 模型能够通过这种机制学习到序列中的相对距离关系

4.5 代码实现

class RoPEPositionalEncoding(nn.Module):def __init__(self, d_model, max_len=5000, base=10000):super().__init__()self.d_model = d_modelself.max_len = max_lenself.base = base# 预计算频率inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))self.register_buffer('inv_freq', inv_freq)# 预计算位置编码self._build_cache(max_len)def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)# 计算sin和cossin_angles = torch.sin(angles)cos_angles = torch.cos(angles)# 存储缓存self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)def rotate_half(self, x):"""旋转向量的一半维度"""x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)def forward(self, x, seq_len=None):if seq_len is None:seq_len = x.shape[-2]# 获取sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 扩展维度以匹配输入if x.dim() == 4:  # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重复sin和cos以匹配完整维度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 应用旋转return x * cos + self.rotate_half(x) * sin# 使用示例
rope = RoPEPositionalEncoding(d_model)
x_rope = torch.randn(batch_size, seq_len, d_model)
output_rope = rope(x_rope)
print(f"RoPE输出形状: {output_rope.shape}")

4.6 代码与公式的对应关系

  1. 预计算频率参数

    inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
    self.register_buffer('inv_freq', inv_freq)
    

    这对应公式中的 θd=110000d/dmodel\theta_d = \frac{1}{10000^{d/d_{\text{model}}}}θd=10000d/dmodel1,用于生成不同维度的旋转频率。

  2. 预计算位置角度的sin和cos值

    def _build_cache(self, max_len):positions = torch.arange(max_len).float()angles = torch.outer(positions, self.inv_freq)sin_angles = torch.sin(angles)cos_angles = torch.cos(angles)self.register_buffer('sin_cached', sin_angles)self.register_buffer('cos_cached', cos_angles)
    
    • angles 矩阵对应 mθdm\theta_dmθd,即位置 mmm 在维度 ddd 上的旋转角度;
    • sin_cachedcos_cached 分别存储 sin⁡(mθd)\sin(m\theta_d)sin(mθd)cos⁡(mθd)\cos(m\theta_d)cos(mθd),避免重复计算。
  3. 向量旋转操作

    def rotate_half(self, x):x1, x2 = x[..., :x.shape[-1]//2], x[..., x.shape[-1]//2:]return torch.cat([-x2, x1], dim=-1)
    

    实现了 RotateHalf(x)\text{RotateHalf}(x)RotateHalf(x) 操作,将向量后半部分取负后与前半部分拼接。

  4. 应用旋转位置编码

    def forward(self, x, seq_len=None):# 获取对应位置的sin和cos值sin = self.sin_cached[:seq_len, :].unsqueeze(0)cos = self.cos_cached[:seq_len, :].unsqueeze(0)# 扩展维度以匹配输入if x.dim() == 4:  # (batch, heads, seq_len, dim)sin = sin.unsqueeze(1)cos = cos.unsqueeze(1)# 重复以匹配完整维度sin = sin.repeat_interleave(2, dim=-1)cos = cos.repeat_interleave(2, dim=-1)# 应用旋转:x * cos + rotate_half(x) * sinreturn x * cos + self.rotate_half(x) * sin
    

    这直接对应RoPE的核心公式:
    RoPE(x,m)=x⋅cos⁡(mθ)+RotateHalf(x)⋅sin⁡(mθ)\text{RoPE}(x, m) = x \cdot \cos(m\theta) + \text{RotateHalf}(x) \cdot \sin(m\theta)RoPE(x,m)=xcos(mθ)+RotateHalf(x)sin(mθ)

4.7 特性与优势

  1. 隐式相对位置编码
    RoPE通过旋转操作隐式地将相对位置信息融入注意力计算,使得模型能够更好地捕捉序列中的相对距离关系,优于传统的绝对位置编码。

  2. 旋转不变性
    RoPE保证了位置编码的旋转不变性,即对于任意向量 xxx 和位置偏移 kkk,有:
    RoPE(x,m)⋅RoPE(y,m+k)=RoPE(x,0)⋅RoPE(y,k)\text{RoPE}(x, m) \cdot \text{RoPE}(y, m+k) = \text{RoPE}(x, 0) \cdot \text{RoPE}(y, k)RoPE(x,m)RoPE(y,m+k)=RoPE(x,0)RoPE(y,k)
    这使得模型在不同位置上具有一致的表示能力。

  3. 无需额外参数
    RoPE不需要像可学习位置编码那样引入大量额外参数,只需预计算 sin⁡\sinsincos⁡\coscos 值,计算效率高。

  4. 长序列建模能力
    实验表明,RoPE在长序列任务(如长文本生成、文档级NLP)中表现优于Sinusoidal PE和绝对位置编码,能够更有效地捕捉远距离依赖关系。

  5. 兼容性强
    可以直接应用于现有的Transformer架构,无需修改模型的整体结构,易于集成到各种NLP系统中。

4.8 带RoPE的注意力机制

class RoPEMultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads, max_len=5000):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.w_q = nn.Linear(d_model, d_model)self.w_k = nn.Linear(d_model, d_model)self.w_v = nn.Linear(d_model, d_model)self.w_o = nn.Linear(d_model, d_model)self.rope = RoPEPositionalEncoding(self.d_k, max_len)def forward(self, query, key, value, mask=None):batch_size, seq_len, _ = query.shape# 线性变换Q = self.w_q(query).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)K = self.w_k(key).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)V = self.w_v(value).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)# 应用RoPEQ = self.rope(Q)K = self.rope(K)# 计算注意力attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attention_scores.masked_fill_(mask == 0, -1e9)attention_weights = torch.softmax(attention_scores, dim=-1)context = torch.matmul(attention_weights, V)# 重新整形context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)output = self.w_o(context)return output, attention_weights

5. 各种位置编码对比

5.1 特点对比

编码类型优点缺点适用场景
绝对位置编码简单直观,计算效率高对长序列泛化能力差固定长度序列
相对位置编码更好的泛化能力计算复杂度高需要处理可变长度序列
RoPE完美的长度外推能力实现相对复杂长序列,语言模型

5.2 性能测试代码

import timedef benchmark_position_encodings():batch_size, seq_len, d_model = 32, 512, 512num_heads = 8# 创建模型models = {'Sinusoidal': SinusoidalPositionalEncoding(d_model),'Learnable': LearnablePositionalEncoding(d_model),'RoPE': RoPEPositionalEncoding(d_model)}x = torch.randn(batch_size, seq_len, d_model)# 基准测试for name, model in models.items():start_time = time.time()for _ in range(100):with torch.no_grad():output = model(x)end_time = time.time()print(f"{name}: {(end_time - start_time) * 1000:.2f}ms")# 运行基准测试
benchmark_position_encodings()

6. 总结

位置编码是Transformer架构中的关键组件,不同类型的位置编码各有特点:

  1. 绝对位置编码:简单高效,适用于固定长度序列
  2. 相对位置编码:关注位置关系,泛化能力更强
  3. RoPE:通过旋转矩阵优雅地处理位置信息,支持长度外推

选择合适的位置编码方式需要根据具体应用场景和性能需求来决定。现代大语言模型(如GPT、LLaMA等)普遍采用RoPE,因为它在处理长序列时表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.tpcf.cn/bicheng/88981.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络安全初级第一次作业

一,docker搭建和挂载vpm 1.安装 Docker apt-get install docker.io docker-compose 2.创建文件 mkdir /etc/docker.service.d vim /etc/docker.service.d/http-proxy.conf 3.改写文件配置 [Service] Environment"HTTP_PROXYhttp://192.168.10.103:7890…

交换类排序的C语言实现

交换类排序包括冒泡排序和快速排序两种。冒泡排序基本介绍冒泡排序是通过重复比较相邻元素并交换位置实现排序。其核心思想是每一轮遍历将未排序序列中的最大(或最小)元素"浮动"到正确位置,类似气泡上升。基本过程是从序列起始位置…

嵌入式 Linux开发环境构建之Source Insight 的安装和使用

目录 一、Source Insight 的安装 二、Source Insight 使用 一、Source Insight 的安装 这个软件是代码编辑和查看软件,打开开发板光盘软件,然后右键选择以管理员身份运行这个安装包。在弹出来的安装向导里面点击 next ,如下图所示。这里选择…

【字节跳动】数据挖掘面试题0016:解释AUC的定义,它解决了什么问题,优缺点是什么,并说出工业界如何计算AUC。

文章大纲 AUC(Area Under the Curve)详解一、定义:AUC是什么?二、解决了什么问题?三、优缺点分析四、工业界大规模计算AUC的方法1. 标准计算(小数据)2. 工业级大规模计算方案3.工业界最佳实践4.工业界方案选型建议总结:AUC的本质AUC(Area Under the Curve)详解 一、…

Python后端项目之:我为什么使用pdm+uv

在试用了一段时间的uv和pdm之后,上个月(2025.06)开始,逐步把用了几年的poetry替换成了pdmuv(pipx install pdm uv && pdm config use_uv true) ## 为什么poetry -> pdm: 1. 通过ssh连接到服务器并使用poetry shell激活虚拟环境之…

鸿蒙Next开发,配置Navigation的Route

1. 通过router_map.json配置文件进行 创建页面配置router_map.json {"routerMap": [{"name": "StateExamplePage","pageSourceFile": "src/main/ets/pages/state/StateExamplePage.ets","buildFunction": "P…

在 GitHub 上创建私有仓库

一、在 GitHub 上创建私有仓库打开 GitHub官网 并登录。点击右上角的 “” → 选择 “New repository”。填写以下内容: Repository name:仓库名称,例如 my-private-repo。Description:可选,仓库描述。Visibility&…

量产技巧之RK3588 Android12默认移除导航栏状态栏​

本文介绍使用源码编译默认去掉导航栏/状态栏方法,以触觉智能EVB3588开发板演示,Android12系统,搭载了瑞芯微RK3588芯片,该开发板是核心板加底板设计,音视频接口、通信接口等各类接口一应俱全,可帮助企业提高产品开发效…

Conda 安装与配置详解及常见问题解决

《Conda 安装与配置详解及常见问题解决》 安装 Conda 有两种主流方式,分别是安装 Miniconda(轻量级)和 Anaconda(包含常用数据科学包)。下面为你详细介绍安装步骤和注意要点。 一、安装 Miniconda(推荐&a…

Linux ——lastb定时备份清理

lastb 命令显示的是系统中 /var/log/btmp 文件中的SSH 登录失败记录。你可以像处理 wtmp 那样,对 btmp 文件进行备份与清理。✅ 一、备份 lastb 数据cp /var/log/btmp /var/log/btmp.backup.$(date %F)会保存为如 /var/log/btmp.backup.2025-07-14✅ 二、清空 lastb…

自定义类型 - 联合体与枚举(百度笔试题算法优化)

目录一、联合体1.1 联合体类型的声明1.2 联合体的特点1.3 相同成员的结构体和联合体对比1.4 联合体大小的计算1.5 联合练习二、枚举类型2.1 枚举类型的声明2.2 枚举类型的优点总结一、联合体 1.1 联合体类型的声明 像结构体一样,联合体也是由一个或者多个成员构成…

FS820R08A6P2LB——英飞凌高性能IGBT模块,驱动高效能源未来!

产品概述FS820R08A6P2LB 是英飞凌(Infineon)推出的一款高性能、高可靠性IGBT功率模块,采用先进的EconoDUAL™ 3封装,专为大功率工业应用设计。该模块集成了IGBT(绝缘栅双极型晶体管)和二极管,适…

python学智能算法(十八)|SVM基础概念-向量点积

引言 前序学习进程中,已经对向量的基础定义有所了解,已经知晓了向量的值和方向向量的定义,学习链接如下: 向量的值和方向 在此基础上,本文进一步学习向量点积。 向量点积 向量点积运算规则,我们在中学阶…

【windows办公小助手】比文档编辑器更好用的Notepad++轻量编辑器

Notepad 中文版软件下载:这个路径总是显示有百度无法下载,不推荐 更新:推荐下载路径 https://github.com/notepad-plus-plus/notepad-plus-plus/releases 参考博主:Notepad的安装与使用

2025年7月12日全国青少年信息素养大赛图形化(Scratch)编程小学高年级组复赛真题+答案解析

2025年7月12日全国青少年信息素养大赛图形化(Scratch)编程小学高年级组复赛真题+答案解析 选择题 题目一 运行如图所示的程序,舞台上一共会出现多少只小猫呢?( ) A. 5 B. 6 C. 7 D. 8 正确答案: B 答案解析: 程序中“当绿旗被点击”后,角色先移到指定位置,然后“重…

对于独热编码余弦相似度结果为0和词向量解决了词之间相似性问题的理解

文章目录深入理解简单案例结论词向量(Word Embedding)简介词向量如何解决相似性问题?简单案例:基于上下文的词向量训练总结对于独热表示的向量,如果采用余弦相似度计算向量间的相似度,可以明显的发现任意两…

数据结构·数状数组(BIT)

树状数组(Binary Index Tree) 英文名:使用二进制下标的树结构 理解:这个树实际上用数组来存,二进制下标就是将正常的下标拆为二进制来看。 求x的最低位1的函数lowbit(x) 假设x的二进制表示为x ...10000,…

uniapp video视频全屏播放后退出,页面字体变大,样式混乱问题

uniapp官方的说法是因为页面使用rpx,但是全屏和退出全屏自动计算屏幕尺寸不支持rpx,建议使用px。但是因为uniapp端的开发都是使用rpx作为屏幕尺寸计算参数,不可能因为video全屏播放功能就整个全部修改,工作量大,耗时耗…

重复频率较高的广告为何一直在被使用?

在日常生活中,重复评率较高的洗脑广告我们时常能够碰到。广告的本质是信息传递,而重复频率较高的广告往往可以通过洗脑式的传播方式来提升传播效率。下面就让我们一同来了解下,为何这类广告一直受到企业的青睐。一、语义凝练高频率广告的内容…

内容管理系统指南:企业内容运营的核心引擎

内容管理看似简单,实际上随着内容量的激增,管理难度也逐步提升。尤其是在面对大量页面、图文、视频资料等数字内容时,没有专业工具的支持,效率与准确性都会受到挑战。此时,内容管理系统(CMS)应运…