前言
最近有同学在面试中被问到了 vLLM 的 PagedAttention,这篇文章带大家了解其核心原理。
1、数据内存布局
(1)KV cache 整体内存布局
将所有可用于 kv cache 的内存按块进行预先分配,每一块的大小为 block_size * num_kv_heads * head_size,即一个 block 存放 block_size 个 token。
假如:
- num_kv_heads = 2
- head_size=64
- block_size=16
kv cache 的数据类型是 FP32。
block 的数量,由可用于 kv_cache 的内存大小决定的。
(2)单个 block 内存布局
整体布局以相同维度连续存储的方式。
key cache:
- 将每个 block 分为 num_kv_heads 份,num_kv_heads = 2 即分为两份。
- 将 head_size 按 16 字节(单个线程操作的最大 size)大小给拆分, 比如 head_size=64、FP32 类型,即每 4 个维度进行拆分,即 16 份。
- 按 16 字节大小的维度为单位,依次按 token 进行存储。
如图是 key cache 中的一个 block 的内存布局,如果多个 head ,依次往后延伸。
value cache
与 key cache 相比,依然采用相同维度连续存储的方式,只是不需要再按 16 字节拆分。
如图是 value cache 中的一个 block 的内存布局,如果多个 head,依次往后延伸。
2、CUDA 线程模型
要了解 paged attention 算子,该线程模型非常重要。
如上图,假设num_kv_heads = 8:
- 线程模型中的每一行(Y 方向)处理 query 的一个 seq 序列。
- X 方向的数量为 kv cahe 的头的数量,每个线程块处理一个头维度
- 每个线程块为 128 个线程
- 虽然每个线程块处理一个头维度,但并不是每个线程块处理一个头,而是线程块中的一个线程束处理一个头,比如,线程块 1 中的线程束 1 处理 block1 中的 head1,线程束 2 处理 block2 的 head1,依次轮推。
3、加载 query 到共享内存
通过上面的线程模型,可知线程块的每个头由一个线程束处理,而每个头里面 block_size 的数量即为 block 中 token 的数量。
理想情况下,是一个线程处理一个 token,因为根据 Q@K 是一个标量,它需要先乘,再归约求和两个环节(虽然两个环节可以合并处理, 但是逻辑上是分为两个环节)。
但是每个线程束 warp_size 是 32 个线程,而 block_size=16,并不满足 1:1 的关系。
所以,我们需要将 warp_size/block_size = 2,即每个 token 由 2 个线程处理,thread_group_size = 2。
为了达到理想的内存访问:访问合并&对齐访问,对于 key cache 的访问如下。
因此,在加载 query 的时候,以下面的这种方式,假如 query 的一个头向量为(k1,k2,k3,k4 … k64), 其内存布局如下:
注意,query 的加载是加载到静态共享内存中。
举个例子:在第一轮循环的时候:
- 线程 1 读取 k1、k2 和 token1 的 d1、d2,计算注意力分数
- 线程 2 读取 k3、k4 和 token1 的 d3、d4,计算注意力分数
- 线程 3 读取 k1、k2 和 token2 的 d1、d2,计算注意力分数
- 线程 4 读取 k3、k4 和 token2 的 d3、d4,计算注意力分数
如此, 我们即可实现 key cache 的连续读取。
4、加载 key 到寄存器
对于 key cache 并不是直接读取内存中的数据进行计算,而是先加载到寄存器中,再进行计算。
由上面加载 query 的时候,我们可以看到,是两个线程负责一个 token 的计算,所以,每个线程加载自身的一对应数据到自己线程的寄存器中即可,已经做好了内存访问连续和对齐。
以线程 1 计算 token1 为例子,它加载到寄存器中的数据如下:
(d1,d2,d5,d6,d9,d10 ... d61,d62)
线程 2 加载的数据如下:
(d3,d4,d7,d8,d11,d12 ... d63,d64)
5、计算Q@K写入共享内存logits中
计算 Q@K 的过程, 对于线程内,进行边计算边求和,当线程内计算完成,在线程组之间,使用 __shfl_xor_sync 进行归约。
最终的结果写入到动态共享内存 logits 中,其中 logits 的 key 为全局 token_index,值为注意力权重。
同时记录一个 qk_max,用于防止 softmax 计算过程中出现数据溢出的情况,qk_max 的值也是通过线程间 __shfl_sync 得到最大值。
有个小技巧,在线程束内归约使用 __shfl 相关函数,线程数之间规约完成后,需要在 block 块的线程束之间规约。
可以将线程束之间规约的结果写入到共享内存中,然后用一个线程束取对共享内存中的数据进行规约,这样就可以使用线程束内之间的函数 __shfl 了。
6、计算softmax
如下:
- 每个线程读取 logits 中自己线程所计算的 token 注意力权重,并计算 __expf(logits[i] - qk_max);而后求和得到 sum。
- 通过 __shfl 相关函数将线程束内的 sum 进行归约。
- 跨线程束之间的规约通过写入共享内存,再使用一个线程束进行归约,最后得到 exp_sum。
- 计算 softmax 的系数部分 inv_sum = __fdividef(1.f, exp_sum + 1e-6f)
- 每个线程读取 logits 中自己线程所计算的 token 注意力权重,使用 logits[i] *= inv_sum;计算并写回 logits
7、计算logits@V到寄存器
logits 中按 token_index 存储每个 token 的注意力权重,根据 value cache 中的内存布局。
我们的目标是:每一个线程以单次最大可操作读内存(16 字节)value 中的数据,假如我们 value cache 中的数据是 FP32,那么我们每次可以读取 4 个数据。
如图展示在一个 value cache 头内,线程与 v cache 之间的对应关系。
以线程 1 计算关系为例子:
(1)线程 1 一次性读取 token0_d1,token1_d1,token2_d2,token3_d3,然后从 logits 中得到 4 个 token 的注意力权重值 token0_weight、token1_weight、token2_weight、token3_weight
(2)计算:ret1 = token0_d1 * token0_weight + token1_d1 * token1_weight + token2_d2 * token2_weight + token3_d3 * token3_weight
(3)将第 2 步计算得到的值写入到寄存器中 accs[0] = ret1
(4)同样可以计算线程 1 在第 9 个维度上的结果 accs[1] = token0_d9 * token0_weight + token1_d9 * token1_weight + token2_d9 * token2_weight + token3_d9 * token3_weight
(5)对于 value cache 不同块之间,每个线程计算的结果进行累加,比如对于第 2 块 cache,当线程 1 计算第 1 个维度的时候,计算的结果,可以直接与 accs[0] 进行累加。
最后在将所有结果进行归约处理后,得到最终的 attention score。
8、总结
如下:
- 对于 query 是直接加载进入到静态共享内存中计算
- 对于 key cache 中的数据,按 token 维度进行计算,并且是先加载入寄存器中,再进行计算
- Q@K 的结果写入到共享内存中,并不占用全局内存(这也是为什么在线推理不适用 flash attention 的原因)
- Value cache 在计算的时候,直接读取全局内存中的数据进行计算,并没有做任何加载
- 线程内所有连续内存的访问都使用了向量化访存,极大的利用带宽并减少指令数量
- 每个线程与 Value cache 计算后的注意力分数,使用线程的寄存器进行存储
- 所有归约的过程,都使用线程束内 __shfl 相关操作完成,跨线程束使用共享内存中转
- Paged attention 也巧妙的数据结构和线程模型,充分的利用了 gpu 资源,达到高效的计算同时实现了灵活的存储
9、Q&A
1.对于 key cache,进行相同维度连续存储的时候,为什么要按 16 字节划分?
答: 进行划分是为了在 Q@K 的时候实现内存的连续访问,以达到访问合并的目的,以 16 字节的原因是:Q@K 的时候,如果每一个线程处理一个 token,那么一个线程单次可操作的内存是 16 字节,所以这个值不能大于 16 字节,否则很有可能出现内存访问不连续。
2.对于每个序列 kv cache 长度不一致,可能会导致的计算耗被最长的 kv cache 给拖慢?
答: 这也就是 v2 版本解决的问题,主要是通过分区的方式优化。
Paged attention 的模拟测试数据:
数据特征:
query 的数据依次增长,query_head1:[0,1,2…63],query_head2:[64,65…127]
Key cache 和 value cache 的数据也一样 key_cache:token1_head1:[0,1,2…63],token1_head2:[64,65…127],token2_head1:[128,129,130…191],token2_head2:[192,193…255]
该数据在调试 paged attention 的过程中方便识别计算中的数据存取关系:
num_seqs = 1 # 序列的数量 batch_sizenum_heads = 2num_kv_heads = 2head_size = 64block_size = 16 # 表示每个block存放的token总数num_blocks = 2max_seq_len = 32device = torch.device("cuda")def create_test_inputs(): query = torch.arange(0, num_seqs * num_heads * head_size, dtype=torch.float32, device=device).reshape(num_seqs, num_heads, head_size) value_cache = torch.arange(0, num_blocks * num_kv_heads * head_size * block_size, dtype=torch.float32, device=device).reshape(num_blocks, block_size, num_kv_heads, head_size).permute(0,2,3,1).contiguous() key_cache = torch.arange(0, num_blocks * num_kv_heads * head_size * block_size, dtype=torch.float32, device=device).reshape(num_blocks, block_size, num_kv_heads, head_size // 4, 4).permute(0,2,3,1,4).contiguous() max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.zeros((num_seqs, max_blocks_per_seq), dtype=torch.int32, device=device) # 只用第一个block seq_lens = torch.tensor([block_size, block_size], dtype=torch.int32, device=device) # 使用全部 16 token return query, key_cache, value_cache, block_tables, seq_lens