🌟 Gemma 4:Drafter 详解
Google Gemma 团队刚发的文章,介绍了Gemma 4 为提升推理速度而引入的 drafter 草稿模型机制。该技术不再完全依赖大型目标模型一个 token 一个 token 地生成内容,而是让一个更小更快的草稿模型先提前预测多个 token,再由目标模型并行验证这些预测,从而显著减少推理时的计算开销。下面是文章翻译。原文在Google Gemma的推上。
------------------------
为了提升 Gemma 4 模型的推理速度,Gemma 4 主系列模型之外还发布了一组新的自回归 “drafter” 模型。这里的主模型被称为 “target model”,也就是目标模型;而 drafter 草稿模型可以在目标模型处理一个 token 的时间里,提前预测多个 token。这种技术也被称为推测解码,即 speculative decoding。
在 drafter 预测出多个草稿 token 之后,目标模型只需要验证这些被建议的草稿 token。验证过程可以并行完成,因此能显著加快推理速度。它减少了目标模型为了生成每个 token 所需要执行的前向传播次数。由于 drafter 会生成一串 token 供目标模型验证,所以我们也把它称为 Multi-Token Prediction,MTP,多 token 预测头。
Gemma 4 系列发布的草稿模型体积较小,并引入了若干增强设计,以提高草稿 token 的质量,并进一步加速推理。例如,它会利用目标模型的激活值和 KV cache 来获得更好的预测结果。
这些增强带来了显著的解码加速,同时仍然能保证相近的生成质量。因此,这些 checkpoint 非常适合低延迟和端侧应用场景。
[ 图1 ]
这里面有很多内容需要拆解。下面我们依次讲解推测解码、MTP,以及 drafter 的设计。
🌟 什么是推测解码?
Gemma 4 模型以自回归方式生成文本,也就是一次生成一个 token。无论某个 token 是容易预测还是难以预测,每生成一个 token 大致都需要相同的计算量。因此,当某些 token 很容易预测时,这个过程就可能显得不必要地缓慢。
假设一个较大的模型正在生成文本,并且已经生成了 “Actions speak”。熟悉英语谚语的人会知道,这句话常见的完整表达是 “Actions speak louder than words.”,意思是“事实胜于雄辩”。由于这句话非常常见,小模型很可能生成与大模型完全相同的后续内容,也就是 “louder than words”。在这种情况下,让大模型一个 token 一个 token 地预测 “louder than words” 就会浪费时间和计算资源。
通过推测解码,我们可以使用一个更小的草稿模型提前预测多个 token。草稿模型接收同样的输入 “Actions speak”,然后仍然以自回归方式预测若干 token,比如四个 token。由于草稿模型的规模只有大模型的一小部分,因此它生成这些草稿 token 的速度会快得多。
[图2]
🌟 什么是多 token 预测?
不过,草稿模型生成的 token 并不一定正确,否则我们直接使用这个小模型就可以了。因此,这些 token 会被交给目标模型并行验证。由于目标模型可以在一次前向传播中完成验证,它就不需要为每个 token 单独执行一次前向传播。
drafter 就是我们所说的 Multi-Token Prediction Head,MTP Head,多 token 预测头。目标模型的每一次前向传播都会执行常规的 next-token prediction,也就是下一个 token 预测,并产生中间隐藏状态。drafter,也就是 MTP Head,会使用这些隐藏状态,并执行若干次自回归前向传播,从而生成多个 token。
因此,目标模型的一次前向传播不再只得到一个 token,而是能得到多个 token:其中一个来自目标模型自身的下一个 token 预测,另外多个则来自 drafter,也就是 MTP Head。
如果目标模型认可草稿模型给出的建议,那么所有草稿 token 都会被接受。原本需要目标模型逐个生成的四个 token,现在由小模型在很短时间内生成,目标模型只需要花费相当于生成一个 token 的时间来验证它们。并且,如果所有草稿 token 都被接受,目标模型本身还会额外生成一个新的 token。
如果目标模型只不同意部分草稿 token,它会接受前面那些它认可的 token,直到遇到第一个不认可的 token。随后,目标模型会用自己生成的 token 替换被拒绝的那个 token。
[图3]
这个过程实际上非常快,因为模型可以一次性验证所有草稿 token 的质量,而不是逐个验证。由于草稿模型很小,它预测单个 token 所需的时间远低于目标模型。这意味着,目标模型几乎可以在自己生成一个 token 的时间内,完成对多个 token 的验证。
需要注意的是,草稿模型和大多数语言模型一样,仍然是顺序生成这些 token 的;只是由于它的规模小,所以生成速度快得多。
目标模型认为足够好的 token 都会被选中。它拒绝的第一个 token,以及之后的所有 token,都会被丢弃。不过,由于目标模型已经完成了一次前向传播,它仍然可以执行一次下一个 token 预测。因此,即使像 “pens” 这样的 token 被拒绝,目标模型也会给出一个替代 token。
[图4]
因此,目标模型最终可能接受任意数量的草稿 token。从整体流程来看,这一点非常有意思:草稿模型以自回归方式逐步处理并一个接一个地生成 token,而目标模型则可以并行验证所有草稿 token。目标模型本身仍然是自回归模型,但它不再需要逐个生成这些草稿 token,而是可以一次性完成验证。
[图5]
🌟 Gemma 4 中的 MTP
Gemma 4 系列发布的草稿模型与 dense Gemma 4 模型非常相似,但要小得多。事实上,Gemma 4 E2B 对应的草稿模型只有大约 7600 万个参数、4 层结构,并且输入 embedding 的维度更小,只有 256,而不是 1536。
[图6]
可以注意到,decoder 本身与 dense Gemma 4 模型的 decoder 类似。不过,在 decoder 之前和之后还有很多额外设计。
草稿模型加入了多种专门用于提升效率和进一步加速推理的增强设计。同样,它也使用了一些有趣的技术来提升草稿 token 的质量,并降低 drafter 的延迟。毕竟,我们希望草稿 token 尽可能准确,同时生成速度也尽可能快。
这些变化可以概括如下:
⭕️ 目标模型激活值:草稿模型会使用目标模型最后一层的激活值,将其与 token embedding 拼接,然后下投影到 drafter 模型的维度。
⭕️ KV Cache 共享:草稿模型不会构建自己的 KV cache,而是通过 cross-attention 使用目标模型的 KV cache。
⭕️ 高效 Embedder:LM Head 使用一种稀疏解码技术,识别最可能预测到的 token cluster,也就是 token 簇。这个设计只用于 E2B 和 E4B。
下面我们更详细地看看每一项设计。
🌟 目标模型激活值
为了提高草稿模型生成 token 的质量,目标模型,例如 E2B 的最后激活值会被输入到草稿模型中。假设使用 E2B 模型,这些激活值会与草稿模型的 token embedding 拼接在一起,两者各有 1536 个数值。
拼接后的 embedding 很大,因此出于效率考虑,它们会被下投影到只有 256 个数值。这本质上是对大型目标模型已经处理过的状态,以及草稿模型新初始化的 embedding 进行压缩。既然目标模型之前已经算出了这些激活值,为什么要浪费掉呢?
[图7]
这些激活值只会在草稿模型第一轮处理时提供给它。请记住,在 drafter 完成初始一轮之后,它可能会通过自回归方式把生成的 token 再输入回自身,从而继续生成多个 token。因此,在第二轮中,草稿模型会使用自己在第一轮生成的激活值,因为这时已经生成了一个新的 token。
由于小型草稿模型的中间激活值只有 256 个数值,它们会被上投影,以匹配其输入 embedding 表的维度,也就是 1536。需要注意的是,为了进一步提高效率,输入 embedding 表在目标模型和草稿模型之间是共享的。
[图8]
随后,在第三轮中,草稿模型会使用第二轮生成的激活值,以此类推。
🌟 KV Cache 共享
KV cache 会占用相当多的空间,因为它包含序列中每个 token 在每一层中的 key 和 value 表示。虽然 Gemma 4 已经采取了很多措施来减少这部分开销,例如在 global attention 层中设置 K=V,但草稿模型又更进一步。
草稿模型不需要处理完整 prompt 并构建自己的 KV cache,而是通过 cross-attention 使用目标模型已经计算好的 KV cache。对于局部注意力层,草稿模型会直接复用目标模型最后计算出的 local KV cache。由于任意 Gemma 4 模型的最后一层始终是 global 层,因此该 global KV cache 也会被复用于草稿模型的 global attention 层。
[图9]
和前面一样,既然目标模型已经完成了大部分繁重计算,把这些结果丢弃就是一种浪费。
🌟 高效 Embedder
最后,还有一个高效 embedder,用来减少从 Language Model Head,也就是 LM Head,生成草稿 token 所需的计算量。
在传统模型中,LM Head 会把 decoder 生成的隐藏状态转换为 logits,也就是 token 概率。这个过程通常是通过将隐藏状态与一个巨大的权重矩阵相乘来完成的。这个权重矩阵与 embedding 层使用的是同一个矩阵,也就是 lookup table。对于这样一个小模型来说,这个过程可能会相当昂贵。
[图10]
你可能会问:我们真的需要为全部 262,144 个 token 都计算 logits 吗?毕竟其中大多数 token 很可能都不相关。
在 Gemma 4 E2B 和 E4B 模型的 drafter 中,这个过程通过一种经典技术变得更高效,也就是 clustering,聚类。所有 token embedding 会被单独聚类,划分为多个大组,每组包含语义相近的 token。然后,模型会为每个 cluster 找到一个 embedding 表示。
[图11]
随后,这个新的 lookup table 会用于 LM Head,并与隐藏状态相乘。这样得到的不是 token logits,而是 cluster logits。也就是说,cluster logits 会告诉我们哪些 cluster 最有可能包含下一个 token。
模型会选择最可能包含正确 token 的 cluster。然后,只对这些 cluster 内部的 token 计算最终的 token logits。
[图12]
这种计算要轻量得多,同时仍然可以告诉我们:下一个 token 几乎肯定在所选 cluster 中。这个高效 embedder 让草稿模型的速度进一步提升。
#AI创造营##How I AI#
