论文阅读:speculative decoding

Fast Inference from Transformers via Speculative Decoding

论文地址:https://arxiv.org/pdf/2211.17192

speculative sampling

为了从分布 p ( x ) p(x) p(x) 中采样,我们实际上是从分布 q ( x ) q(x) q(x) 中采样 x x x,如果 q ( x ) ≤ p ( x ) q(x) \leq p(x) q(x)p(x),则保留该样本;如果 q ( x ) > p ( x ) q(x) > p(x) q(x)>p(x),则以概率 1 − p ( x ) q ( x ) 1 - \frac{p(x)}{q(x)} 1q(x)p(x) 拒绝该样本,并重新从调整后的分布 p ′ ( x ) = norm ( max ⁡ ( 0 , p ( x ) − q ( x ) ) ) p'(x) = \text{norm}(\max(0, p(x)-q(x))) p(x)=norm(max(0,p(x)q(x))) 中采样。对于任何分布 p ( x ) p(x) p(x) q ( x ) q(x) q(x),以及以此方式采样的 x x x,确实有 x ∼ p ( x ) x \sim p(x) xp(x)

给定通过在条件前缀上运行 M q M_q Mq 获得的分布 q ( x ) q(x) q(x),我们可以采样一个标记 x 1 ∼ q ( x ) x_1 \sim q(x) x1q(x)。然后,我们通过在前缀上运行 M p M_p Mp 来计算分布 p ( x ) p(x) p(x),同时并行地推测性地计算下一个标记 x 2 x_2 x2 的分布,即在前缀上追加 x 1 x_1 x1 后运行 M p M_p Mp。一旦两项计算都完成,我们就按上述方式处理:如果 x 1 x_1 x1 被拒绝,我们丢弃 x 2 x_2 x2 的计算,并从调整后的分布中重新采样 x 1 x_1 x1;如果 x 1 x_1 x1 被接受,我们就保留两个标记。算法 1 将这一想法推广为一次采样 1 到 γ + 1 \gamma + 1 γ+1 个标记。
运行算法

分析

有几个证明需要注意一下:

单次算法期望能生成的token
  1. 单次算法期望能生成的token数量服从几何分布,但是求和项是有限制的,这里推导下​

  2. ​接受率β的定义​
    设目标模型分布为 p(x),草稿模型分布为 q(x)。草稿模型生成的单个token被目标模型接受的概率为:

β = ∑ x min ⁡ ( q ( x ) , p ( x ) ) \beta = \sum_x \min\left(q(x), p(x)\right) β=xmin(q(x),p(x))

  1. ​拒绝率α的定义​

α = 1 − β = 1 − ∑ x min ⁡ ( p ( x ) , q ( x ) ) x \alpha = 1 - \beta = 1 - \sum_x \min(p(x), q(x)) x α=1β=1xmin(p(x),q(x))x

  • 假设每个token的接受事件独立且同分布(i.i.d.),草稿模型一次生成 K 个token:

  • ​首次拒绝发生在位置 r​ 的概率为:

    P ( r ) = ( 1 − β ) β r − 1 ( 1 ≤ r ≤ K ) P(r) = (1-\beta) \beta^{r-1} \quad (1 \leq r \leq K) P(r)=(1β)βr1(1rK)

    所有token均被接受​​ 的概率为: β K \beta^K βK

  • 综上期望能生成的token数量为:

    γ = ∑ r = 1 K r ⋅ P ( r ) ⏟ 拒绝前生成的token + K ⋅ β K ⏟ 全接受时生成K个token \gamma = \underbrace{\sum_{r=1}^K r \cdot P(r)}_{\text{拒绝前生成的token}} + \underbrace{K \cdot \beta^K}_{\text{全接受时生成K个token}} γ=拒绝前生成的token r=1KrP(r)+全接受时生成Ktoken KβK

代入 P ( r ) P(r) P(r) 后展开:

γ = ∑ r = 1 K r ⋅ ( 1 − β ) β r − 1 + K β K \gamma = \sum_{r=1}^K r \cdot (1-\beta) \beta^{r-1} + K \beta^K γ=r=1Kr(1β)βr1+KβK

  1. 几何级数求和​

几何级数求和公式为:

∑ r = 1 K r β r − 1 \sum_{r=1}^K r \beta^{r-1} r=1Krβr1 求和处理:

  • ​令 S = ∑ r = 1 K β r − 1 S = \sum_{r=1}^K \beta^{r-1} S=r=1Kβr1​:

S = 1 + β + β 2 + ⋯ + β K − 1 = 1 − β K 1 − β S = 1 + \beta + \beta^2 + \cdots + \beta^{K-1} = \frac{1-\beta^K}{1-\beta} S=1+β+β2++βK1=1β1βK

  • ​对 S S S 求导​​:

∑ r = 1 K r β r − 1 = d d β ( ∑ r = 0 K β r ) = d d β ( 1 − β K + 1 1 − β ) = 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 \sum_{r=1}^K r \beta^{r-1} = \frac{d}{d\beta} \left( \sum_{r=0}^K \beta^r \right) = \frac{d}{d\beta} \left( \frac{1-\beta^{K+1}}{1-\beta} \right) = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} r=1Krβr1=dβd(r=0Kβr)=dβd(1β1βK+1)=(1β)21(K+1)βK+KβK+1

  • ​代入γ表达式​​:

γ = ( 1 − β ) ⋅ 1 − ( K + 1 ) β K + K β K + 1 ( 1 − β ) 2 + K β K = 1 − ( K + 1 ) β K + K β K + 1 1 − β + K β K \gamma = (1-\beta) \cdot \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{(1-\beta)^2} + K\beta^K = \frac{1 - (K+1)\beta^K + K\beta^{K+1}}{1-\beta} + K\beta^K γ=(1β)(1β)21(K+1)βK+KβK+1+KβK=1β1(K+1)βK+KβK+1+KβK

  • 化简​​:

γ = 1 − β K 1 − β \gamma = \frac{1 - \beta^K}{1-\beta} γ=1β1βK

​物理意义​​:

  • K → ∞ K \to \infty K时, γ → 1 1 − β = 1 α \gamma \to \frac{1}{1-\beta} = \frac{1}{\alpha} γ1β1=α1(理想无限长草稿)。
  • 例如 β \beta β = 0.8` 时, γ max = 5 \gamma_{\text{max}} = 5 γmax=5,即平均每次生成5个token。

得证

Walltime的时间优化

​定理 3.8​​:算法 1 在总运行时间上的预期改进因子为
‘ 1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) ‘ `\frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)}` (1α)(γc+1)1αγ+1

​证明​​:
记运行目标模型 M p M_p Mp​单步​​的成本为 T T T
算法 1 的​​单次运行成本​​为 T c γ + T Tc\gamma + T Tcγ+T(其中 c γ T c\gamma T cγT用于运行近似模型 M q M_q Mq γ \gamma γ 次, T T T 用于运行 M p M_p Mp 一次)。
根据单次算法期望能生成的token算法推导,单次运行​​平均生成 token 数量​​为 1 − α γ + 1 1 − α \dfrac{1 - \alpha^{\gamma + 1}}{1 - \alpha} 1α1αγ+1
因此,使用算法 1 生成单个 token 的​​总体预期成本​​为:
( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T ‘ \frac{(c\gamma + 1)(1 - \alpha)}{1 - \alpha^{\gamma + 1}}T` 1αγ+1(cγ+1)(1α)T
由于标准解码算法生成单个 token 的成本为 T
比较可得上述改进因子。∎
(注:符号 “∎” 表示证明结束)


关键术语说明:

英文术语中文翻译符号含义
walltime总运行时间-算法从启动到结束的时钟时间
expected improvement factor预期改进因子-优化后时间开销的缩减比例
cost per step单步成本 T T T目标模型 M p M_p Mp 推理一个 token 的时间
approximation model近似模型 M q M_q Mq快速但低精度的草稿模型
tokens标记(Token)-模型生成的基本文本单位
rejection rate拒绝率 α \alpha α草稿模型 M q M_q Mq 的 token 被目标模型 M p M_p Mp 拒绝的概率
γ \gamma γ生成长度 γ \gamma γ草稿模型单次运行的 token 生成数
cost ratio成本比 c c c M q M_q Mq M p M_p Mp 的单步时间比值( 0 < c < 1 0 < c < 1 0<c<1

公式解析:

  1. ​改进因子​
    1 − α γ + 1 ( 1 − α ) ( γ c + 1 ) \frac{1 - \alpha^{\gamma + 1}}{(1 - \alpha)(\gamma c + 1)} (1α)(γc+1)1αγ+1
  • ​分子​ 1 − α γ + 1 1 - \alpha^{\gamma+1} 1αγ+1:草稿模型连续生成 \gamma 个 token 均未被拒绝的概率补偿
  • ​分母​ ( 1 − α ) (1-\alpha) (1α):单 token 接受率, γ c + 1 \gamma c + 1 γc+1:草稿+验证的总时间成本

该值 ​​>1​​ 时表示加速,值越大加速效果越显著

  1. ​单 token 成本公式​
    ( c γ + 1 ) ( 1 − α ) 1 − α γ + 1 T \frac{(c\gamma+1)(1-\alpha)}{1-\alpha^{\gamma+1}}T 1αγ+1(cγ+1)(1α)T
  • ​分子​ ( c γ + 1 ) ( 1 − α ) T (c\gamma+1)(1-\alpha)T (cγ+1)(1α)T:草稿生成+验证的实际计算量
  • ​分母​ 1 − α γ + 1 1-\alpha^{\gamma+1} 1αγ+1:有效 token 产出的概率加权
操作数计算

操作数的计算量也是类似的,直接贴结论了

( 1 − α ) ( γ c ^ + γ + 1 ) 1 − α γ + 1 \frac{(1-\alpha)(\gamma \hat{c}+\gamma+1)}{1-\alpha^{\gamma+1}} 1αγ+1(1α)(γc^+γ+1)

采样和原分布的等价性证明

参考https://arxiv.org/pdf/2302.01318
其中需要一步代换证明下面两个公式等价:

原始公式

第一个公式:
= 1 − ∑ x ′ min ⁡ ( p ( x ′ ) , q ( x ′ ) ) =1-\sum_{x^{\prime}}\min\left(p\left(x^{\prime}\right),q\left(x^{\prime}\right)\right) =1xmin(p(x),q(x))

第二个公式:
= ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) =\sum_{x^{\prime}}\max\left(0,q\left(x^{\prime}\right)-p\left(x^{\prime}\right)\right) =xmax(0,q(x)p(x))

推导步骤

步骤 1: 应用 min 函数的恒等式

对于任何两个实数 a a a b b b,都存在以下恒等关系:
min ⁡ ( a , b ) = a − max ⁡ ( 0 , a − b ) \min(a,b) = a - \max(0, a - b) min(a,b)=amax(0,ab)

b = p ( x ′ ) b = p(x') b=p(x) a = q ( x ′ ) a = q(x') a=q(x),得到:
min ⁡ ( p ( x ′ ) , q ( x ′ ) ) = q ( x ′ ) − max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) \min(p(x'),q(x')) = q(x') - \max(0, q(x') - p(x')) min(p(x),q(x))=q(x)max(0,q(x)p(x))

步骤 2: 代入第一个公式

将恒等式代入原始公式:
1 − ∑ x ′ min ⁡ ( p ( x ′ ) , q ( x ′ ) ) = 1 − ∑ x ′ [ q ( x ′ ) − max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) ] \begin{aligned} &1 - \sum_{x^{\prime}} \min(p(x'),q(x')) \\ &= 1 - \sum_{x^{\prime}} \left[ q(x') - \max(0, q(x') - p(x')) \right] \end{aligned} 1xmin(p(x),q(x))=1x[q(x)max(0,q(x)p(x))]

步骤 3: 拆分求和运算

将求和符号分配到表达式内部:
= 1 − [ ∑ x ′ p ( x ′ ) − ∑ x ′ max ⁡ ( 0 , p ( x ′ ) − q ( x ′ ) ) ] = 1 - \left[ \sum_{x^{\prime}} p(x') - \sum_{x^{\prime}} \max(0, p(x') - q(x')) \right] =1[xp(x)xmax(0,p(x)q(x))]
= 1 − ∑ x ′ q ( x ′ ) + ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - \sum_{x^{\prime}} q(x') + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =1xq(x)+xmax(0,q(x)p(x))

步骤 4: 应用概率分布性质

因为 p p p q q q 都是概率分布函数,满足:
∑ x ′ p ( x ′ ) = 1 和 ∑ x ′ q ( x ′ ) = 1 \sum_{x^{\prime}} p(x') = 1 \quad \text{和} \quad \sum_{x^{\prime}} q(x') = 1 xp(x)=1xq(x)=1

代入表达式:
= 1 − 1 + ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = 1 - 1 + \sum_{x^{\prime}} \max(0, q(x') - p(x')) =11+xmax(0,q(x)p(x))
= ∑ x ′ max ⁡ ( 0 , q ( x ′ ) − p ( x ′ ) ) = \sum_{x^{\prime}} \max(0, q(x') - p(x')) =xmax(0,q(x)p(x))

得证

Reference

https://arxiv.org/pdf/2211.17192

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.tpcf.cn/diannao/86819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java操作word里的表格

依赖&#xff1a; <dependency><groupId>com.techCoLtd</groupId><artifactId>aspose-words-16.4.0-jdk16</artifactId><classifier>jdk16</classifier> </dependency>/*** 删除表格及表格的行* throws Exception*/ private s…

单链表经典算法题之分割链表

给定一个头结点和一个值x&#xff0c;是链表中所有小于x的值都在x前面 typedef struct ListNode ListNode; struct ListNode* partition(struct ListNode* head, int x) { //思路一&#xff1a;在原链表上进行修改 //思路二&#xff1a;创建新链表&#xff0c;使用哨兵位&…

Modbus TCP转DeviceNet网关连接ABB变频器配置案例

某工厂需要将支持Modbus TCP协议的上位机控制系统&#xff08;如PLC或SCADA&#xff09;与支持DeviceNet协议的变频器&#xff08;如ABB ACS880、施耐德ATV320等&#xff09;进行通信。为实现协议转换&#xff0c;采用开疆智能Modbus TCP转DeviceNet网关KJ-DVCZ-MTCPS作为中间设…

【力扣 简单 C++】206. 反转链表

目录 题目 解法一&#xff1a;迭代 解法二&#xff1a;递归 题目 待添加 解法一&#xff1a;迭代 class Solution { private:ListNode* reverse(ListNode* head){ListNode* newHead {};while (head){ListNode* nextNode {head->next};head->next newHead;newHead …

计算机视觉之三维重建(深入浅出SfM与SLAM核心算法)—— 1. 摄像机几何

文章目录 1. 针孔相机1.1. 针孔成像1.2. 光圈对成像的影响 2. 透视投影相机2.1. 透镜成像2.2. 失焦2.3. 径向畸变2.4. 透视投影的性质 3. 世界坐标系到像素坐标系的变换4. 其它相机模型4.1. 弱透视投影摄像机4.2. 正交投影摄像机4.3. 各种摄像机模型的应用场合 课程视频链接&am…

第十三节:第七部分:Stream流的中间方法、Stream流的终结方法

Stream流常见的中间方法 Stream流常见的终结方法 代码 学生类&#xff08;代码一与代码二共涉及到的类&#xff09; package com.itheima.day28_Stream;import java.util.Objects;public class Student implements Comparable<Student> {private String name;private i…

深入理解 Go 中的字节序(Endianness)检测代码

深入理解 Go 中的字节序&#xff08;大小端&#xff09;检测代码 在计算机系统中&#xff0c;字节序&#xff08;Endianness&#xff09; 是指多字节数据类型&#xff08;如 int16、int32 等&#xff09;在内存中的存储顺序。Go 语言标准库提供了对大端&#xff08;Big-endian&…

JAVA:RabbitMQ 消息持久化机制的技术指南

🐇 1、简述 在使用 RabbitMQ 构建可靠消息系统时,消息丢失是必须避免的问题。为此,RabbitMQ 提供了消息持久化机制(Message Durability),可以保障在 Broker 异常宕机后数据不会丢失。 本篇博客将从原理出发,结合 Spring Boot 实战讲解如何正确实现 RabbitMQ 消息持久…

tabs页签嵌套表格,切换表格保存数据不变并回勾

需求&#xff1a;点击左边的tab页签&#xff0c;请求右侧表格数据&#xff1b;如果返回的接口数据存在taskuser字段并不为null&#xff0c;那么按照这个字段去回勾数据。如果存在数据&#xff0c;但与后面所勾选的数据项不同&#xff0c;按照后面勾选的为主。 <el-tabs tab-…

Java Kafka消费者

基础 Java Kafka消费者主要通过以下核心类实现&#xff1a; KafkaConsumer&#xff1a;消费者的核心类&#xff0c;用于创建消费者对象进行数据消费1ConsumerConfig&#xff1a;获取各种配置参数&#xff0c;如果不配置就使用默认值1ConsumerRecord&#xff1a;每条数据都要封…

Git操作问题及解决方案-记录5

Git操作问题及解决方案 问题一&#xff1a;本地更改与远程更新冲突 问题描述 当本地文件有未提交的更改&#xff0c;同时远程仓库也有更新时&#xff0c;执行git pull会导致冲突。 $ git pull origin main error: Your local changes to the following files would be overw…

一[3]、ubuntu18.04环境 利用 yolov8 训练开源列车数据集,并实现列车轨道检测

一、开源车载数据集地址 (7 封私信) 轨道交通数据集-OSDaR23: Open Sensor Data for Rail 2023 - 知乎 二、参考资料 https://zhuanlan.zhihu.com/p/692608487 YOLOv8训练自己的数据集-CSDN博客 https://download.csdn.net/blog/column/12710137/140991739

C语言数据结构笔记5:Keil 编译器优化行为_malloc指针内存分配问题

记录俩个keil5 STM32 的c语言编程中 &#xff0c;编译器优化行为 和 指针内存分配问题。 目录 关闭Keil 编译器优化行为&#xff1a; malloc指针内存分配问题 多层嵌套的结构体&#xff1a; 用指针取值&#xff1a; 发现问题&#xff1a; 解决问题&#xff1a; 示例代码 关闭Ke…

每日八股文6.12

每日八股-6.12 计算机网络1.当我们在浏览器中输入一个 URL 并按下回车后&#xff0c;到页面最终显示出来&#xff0c;这中间都发生了哪些关键步骤&#xff1f;2.请简述一下 JWT&#xff08;JSON Web Tokens&#xff09;的原理和校验机制3.DNS 是如何进行域名解析的&#xff1f;…

什么是云计算的边缘原生应用?

关于作者&#xff1a;John Bradshaw阿卡迈公司欧洲、中东和非洲地区云计算技术与战略总监 当谈及云计算时&#xff0c;人们往往会联想到那些坐落于国际大都会核心地带的大型数据中心集群&#xff0c;这些设施作为数字时代的重要枢纽&#xff0c;承载着海量数据处理任务。尽管这…

Linux常用命令速查与面试高频命令总结

&#x1f427; Linux常用命令速查与面试高频命令总结 本文旨在帮助初学者快速掌握 Linux 的常用命令&#xff0c;同时为即将参加技术面试的朋友们提供一份高频命令清单和实用技巧。 &#x1f530; 一、基础命令&#xff1a;熟练使用命令行从这里开始 这些是你在 Linux 中最常用…

基础测试工具使用经验

背景 vtune&#xff0c;perf, nsight system等基础测试工具&#xff0c;都是用过的&#xff0c;但是没有记录&#xff0c;都逐渐忘了。所以写这篇博客总结记录一下&#xff0c;只要以后发现新的用法&#xff0c;就记得来编辑补充一下 perf 比较基础的用法&#xff1a; 先改这…

浅谈DaemonSet

1. DaemonSet 概述 ‌定义‌&#xff1a;DaemonSet 确保 Kubernetes 集群的每个节点上运行一个 Pod 实例。‌特性‌&#xff1a; 每个节点上只有一个 Pod 实例。新节点加入集群时&#xff0c;会自动在新节点上创建 Pod。旧节点被删除时&#xff0c;其上的 Pod 会被回收。 2.…

计算机系统(6)

◆指令寻址方式&#xff1a; 顺序寻址方式&#xff1a;执行一段程序时&#xff0c;是一条指令接着一条指令的顺序执行。 跳跃寻址方式:下一条指令的地址码不是由程序计数器给出&#xff0c;而是由本条指令直接给出。程序跳跃后&#xff0c;按新的指令地址开始顺序执行。因此&…

基于服务器使用 apt 安装、配置 Nginx

&#x1f9fe; 一、查看可安装的 Nginx 版本 首先&#xff0c;你可以运行以下命令查看可用版本&#xff1a; apt-cache madison nginx-core输出示例&#xff1a; nginx-core | 1.18.0-6ubuntu14.6 | http://archive.ubuntu.com/ubuntu focal-updates/main amd64 Packages ng…