文章预览
作者:小莲子 原 文 地址: https://zhuanlan.zhihu.com/p/690588915 对于 LLM, 加速推理 并 降低显存 ,是两个至关重要的问题。本文将从 Key-Value Cache 出发,介绍两种相关的模型结构改进。分别是 ChatGLM 系列使用的 Multi-Query Attention(MQA) 和 LLama 系列使用的 Grouped-Query Attention(GQA) 。希望读完之后,能够明白“为何要做”和“如何去做”。 LLM 普遍采用 decoder-only 的结构,自回归地逐个 token 进行生成。在最基础的设置下,我们传给模型的输入顺序为(假设只生成 5 个词): prompt -> x0 -> x1 -> x2 -> x3 -> x4 -> x5 用 transformers 库的 pytorh 代码表示为: input_ids = tokenizer(prompt, return_tensors = "pt" )[ "input_ids" ] . to( "cuda" ) for _ in range ( 5 ): next_logits = model(input_ids)[ "logits" ][:, - 1 :] # 取最大概率token为当前时间步输出. next_token_id = torch . argmax(next_logits,dim =- 1
………………………………