文章预览
来源:投稿 作者:175 编辑:学姐 unset unset 引言 unset unset 今天介绍LLAMA2模型引入的关于注意力的改进——分组查询注意力(Grouped-query attention,GQA)1。 Transformer中的多头注意力在解码阶段来说是一个性能瓶颈。多查询注意力2通过共享单个key和value头,同时不减少query头来提升性能。多查询注意力可能导致质量下降和训练不稳定,因此常用的是分组查询注意力。 然后我们结合上篇文章3探讨的旋转位置编码,将选择位置编码应用到分组查询注意力上。 unset unset 多头注意力 unset unset 我们先回顾以下原始多头注意力的实现。 import torch from torch import nn, Tensor import math from dataclasses import dataclass @dataclass class ModelArgs: hidden_size: int = 512 num_heads: int = 8 attention_dropout: float = 0.1 class MultiHeadAttention(nn.Module): def __init__(self, args: Mode
………………………………