Transformer架构的核心:自注意力机制

1. 自注意力机制的提出背景

在Transformer架构出现之前,循环神经网络(RNN)及其变体(LSTM、GRU)是处理序列数据的主流模型。然而,RNN存在以下局限性:

  • 顺序处理:RNN必须按顺序处理序列数据,无法并行计算,导致训练效率低下
  • 长期依赖:虽然LSTM和GRU缓解了长期依赖问题,但对于长序列仍然存在挑战
  • 计算复杂度:RNN的时间复杂度为O(n²),其中n是序列长度

自注意力机制(Self-Attention)的提出解决了这些问题,它允许模型:

  • 直接计算序列中任意两个位置之间的依赖关系,无论距离远近
  • 并行处理整个序列,大幅提高计算效率
  • 捕捉全局上下文信息,更好地理解序列语义

2. 自注意力机制的基本原理

自注意力机制的核心思想是:对于序列中的每个元素,计算它与序列中所有元素的关联程度(注意力权重),然后根据这些权重对所有元素进行加权求和,得到该元素的表示。

2.1 自注意力的计算过程

  1. 输入嵌入:将序列中的每个元素转换为向量表示
  2. 计算查询、键、值:通过线性变换生成三个向量
  3. 计算注意力分数:使用查询和键计算注意力分数
  4. 注意力分数归一化:使用softmax函数归一化注意力分数
  5. 加权求和:根据注意力分数对值向量进行加权求和

2.2 数学公式

对于序列中的第i个元素,自注意力的计算过程可表示为:

Query_i = W_Q · X_i
Key_i = W_K · X_i
Value_i = W_V · X_i

Score_{i,j} = Query_i · Key_j / √d_k
AttentionWeight_{i,j} = softmax(Score_{i,j})
Output_i = Σ_j (AttentionWeight_{i,j} · Value_j)

其中:

  • X_i 是第i个输入元素的嵌入向量
  • W_Q, W_K, W_V 是可学习的权重矩阵
  • d_k 是键向量的维度,用于缩放注意力分数
  • Score_{i,j} 表示第i个元素对第j个元素的注意力分数
  • AttentionWeight_{i,j} 是归一化后的注意力权重
  • Output_i 是第i个元素的自注意力输出

3. 多头自注意力机制

多头自注意力(Multi-Head Attention)是对基本自注意力机制的扩展,它通过多个"头"并行计算自注意力,然后将结果拼接起来。

3.1 多头自注意力的优势

  • 捕获不同类型的依赖关系:不同的头可以学习到不同类型的特征和依赖关系
  • 增加模型容量:多头机制相当于在不同的子空间中并行计算注意力,增加了模型的表达能力
  • 提高注意力的多样性:避免了单一注意力机制可能带来的信息瓶颈

3.2 多头自注意力的计算过程

  1. 线性变换:对每个输入元素进行多次线性变换,生成多组查询、键、值向量
  2. 并行计算:每组向量独立计算自注意力
  3. 结果拼接:将多个头的输出拼接起来
  4. 最终线性变换:通过线性变换将拼接后的结果映射到目标维度

3.3 数学公式

# 第h个头的计算
Query_i^h = W_Q^h · X_i
Key_i^h = W_K^h · X_i
Value_i^h = W_V^h · X_i

Score_{i,j}^h = Query_i^h · Key_j^h / √d_k
AttentionWeight_{i,j}^h = softmax(Score_{i,j}^h)
HeadOutput_i^h = Σ_j (AttentionWeight_{i,j}^h · Value_j^h)

# 拼接所有头的输出
MultiHeadOutput_i = Concat(HeadOutput_i^1, HeadOutput_i^2, ..., HeadOutput_i^H)

# 最终线性变换
FinalOutput_i = W_O · MultiHeadOutput_i

其中:

  • H 是头的数量
  • W_Q^h, W_K^h, W_V^h 是第h个头的权重矩阵
  • W_O 是最终的输出权重矩阵

4. 自注意力机制在Transformer中的应用

在Transformer架构中,自注意力机制被广泛应用于编码器和解码器中:

4.1 编码器中的自注意力

编码器中的自注意力层允许每个位置关注输入序列中的所有位置,捕获全局依赖关系。这使得编码器能够更好地理解整个输入序列的语义信息。

4.2 解码器中的自注意力

解码器中的自注意力层有两种类型:

  1. 掩码自注意力:只允许当前位置关注之前的位置,确保生成过程的自回归性
  2. 编码器-解码器注意力:允许解码器关注编码器的输出,捕获输入和输出之间的依赖关系

4.3 位置编码

由于自注意力机制本身不包含位置信息,Transformer使用位置编码(Positional Encoding)来为输入序列添加位置信息:

PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 是位置索引
  • i 是维度索引
  • d_model 是模型维度

5. PyTorch实现自注意力机制

5.1 基本自注意力实现

import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super(SelfAttention, self).__init__()
        self.d_model = d_model
        
        # 线性变换层
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        # 输出线性变换
        self.out_linear = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.size()
        
        # 计算查询、键、值
        q = self.q_linear(x)
        k = self.k_linear(x)
        v = self.v_linear(x)
        
        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_model)
        
        # 应用掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 计算注意力权重
        attention_weights = torch.softmax(scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attention_weights, v)
        
        # 输出线性变换
        output = self.out_linear(output)
        
        return output, attention_weights

5.2 多头自注意力实现

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 线性变换层
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        
        # 输出线性变换
        self.out_linear = nn.Linear(d_model, d_model)
    
    def split_heads(self, x):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.size()
        # 分割成多个头: [batch_size, num_heads, seq_len, d_k]
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
    
    def forward(self, x, mask=None):
        # x: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = x.size()
        
        # 计算查询、键、值并分割成多个头
        q = self.split_heads(self.q_linear(x))
        k = self.split_heads(self.k_linear(x))
        v = self.split_heads(self.v_linear(x))
        
        # 计算注意力分数
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 应用掩码
        if mask is not None:
            # 扩展掩码维度以匹配注意力分数
            mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 计算注意力权重
        attention_weights = torch.softmax(scores, dim=-1)
        
        # 加权求和
        output = torch.matmul(attention_weights, v)
        
        # 拼接多个头的输出
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # 输出线性变换
        output = self.out_linear(output)
        
        return output, attention_weights

6. 自注意力机制的优势分析

6.1 并行计算能力

自注意力机制的最大优势是并行计算能力。与RNN的顺序计算不同,自注意力机制可以同时计算序列中所有位置的注意力,大幅提高了计算效率。

计算复杂度对比

  • RNN: O(n²·d)
  • 自注意力: O(n²·d),但可以并行计算
  • 其中n是序列长度,d是隐藏层维度

6.2 长距离依赖捕获

自注意力机制可以直接计算序列中任意两个位置之间的依赖关系,不受距离限制,因此能够更好地捕获长距离依赖。

6.3 可解释性

自注意力机制的注意力权重可以可视化,提供了模型决策的可解释性。通过分析注意力权重,我们可以了解模型在处理序列时关注了哪些位置。

7. 实用案例分析

7.1 机器翻译中的自注意力

在机器翻译任务中,自注意力机制允许模型在生成目标语言单词时,关注源语言中相关的单词,无论它们在序列中的位置如何。

案例:英中翻译

输入:"The cat sat on the mat"

自注意力机制会在生成"猫"时关注"cat",在生成"垫子"时关注"mat",即使它们在序列中的位置不同。

7.2 文本分类中的自注意力

在文本分类任务中,自注意力机制可以帮助模型识别文本中的关键信息,提高分类准确性。

代码示例:情感分析

import torch
import torch.nn as nn
import torch.optim as optim

class SentimentClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, d_model, num_heads, num_classes):
        super(SentimentClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.position_encoding = self.create_position_encoding(100, embedding_dim)
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.fc = nn.Linear(d_model, num_classes)
    
    def create_position_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)
    
    def forward(self, x):
        # x: [batch_size, seq_len]
        emb = self.embedding(x)
        # 添加位置编码
        seq_len = x.size(1)
        emb = emb + self.position_encoding[:, :seq_len, :]
        # 自注意力
        attn_output, _ = self.attention(emb)
        # 平均池化
        pooled = torch.mean(attn_output, dim=1)
        # 分类
        output = self.fc(pooled)
        return output

# 模型训练
vocab_size = 10000
embedding_dim = 512
d_model = 512
num_heads = 8
num_classes = 2

model = SentimentClassifier(vocab_size, embedding_dim, d_model, num_heads, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 训练过程
# ...

7.3 注意力权重可视化

通过可视化注意力权重,我们可以直观地了解模型在处理序列时的关注焦点。

可视化示例

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens):
    # attention_weights: [num_heads, seq_len, seq_len]
    # tokens: 输入序列的标记
    
    num_heads = attention_weights.size(0)
    seq_len = attention_weights.size(1)
    
    plt.figure(figsize=(15, 10))
    
    for i in range(num_heads):
        plt.subplot(2, 4, i+1)
        sns.heatmap(attention_weights[i].detach().numpy(), 
                   xticklabels=tokens, 
                   yticklabels=tokens, 
                   cmap='viridis')
        plt.title(f'Head {i+1}')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
    
    plt.tight_layout()
    plt.show()

# 使用示例
# tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat']
# _, attention_weights = model(...)  # 获取注意力权重
# visualize_attention(attention_weights[0], tokens)  # 可视化第一个样本的注意力权重

8. 自注意力机制的局限性与改进

8.1 计算复杂度

自注意力机制的时间复杂度和空间复杂度均为O(n²),其中n是序列长度。对于非常长的序列(如长文档),这会导致计算资源消耗过大。

8.2 改进方法

  • 稀疏自注意力:只计算与当前位置相关的部分位置的注意力,如局部自注意力、带状自注意力等
  • 线性自注意力:使用核函数或其他方法将复杂度降为O(n),如Performer、Linformer等
  • 层次化自注意力:通过层次化结构处理长序列,如Longformer、BigBird等

9. 总结与展望

自注意力机制是Transformer架构的核心创新,它通过直接计算序列中不同位置之间的依赖关系,解决了RNN的顺序计算和长期依赖问题。多头自注意力机制进一步增强了模型的表达能力,使Transformer在各种序列建模任务中取得了优异的性能。

9.1 核心要点回顾

  • 自注意力机制允许模型直接计算序列中任意两个位置之间的依赖关系
  • 多头自注意力通过多个"头"并行计算注意力,捕获不同类型的依赖关系
  • 自注意力机制具有并行计算能力,大幅提高了训练效率
  • 自注意力机制能够更好地捕获长距离依赖,适用于各种序列建模任务

9.2 未来发展方向

  • 更高效的自注意力变体,以处理更长的序列
  • 自注意力与其他机制的结合,如卷积、循环结构等
  • 自注意力在更多领域的应用,如计算机视觉、语音处理等
  • 自注意力机制的理论分析和解释,以更好地理解其工作原理

10. 课后练习

  1. 实现一个基本的自注意力机制,并测试其在简单序列任务上的性能。

  2. 比较不同头数的多头自注意力在情感分析任务上的性能差异。

  3. 尝试实现一种稀疏自注意力变体,如局部自注意力,并与标准自注意力进行比较。

  4. 可视化不同任务中自注意力权重的分布,分析模型关注的重点。

  5. 探索自注意力机制在其他领域的应用,如图像分类、语音识别等。

« 上一篇 注意力机制的基本思想 下一篇 » Transformer的整体架构与优势