The Annotated Transformer 详解
1. 项目简介
The Annotated Transformer是由Harvard NLP团队开发的一个项目,旨在通过详细的代码注释和解释,帮助研究者和开发者理解Transformer模型的工作原理。该项目基于Vaswani等人在2017年发表的论文《Attention Is All You Need》,提供了一个清晰、可理解的Transformer实现。
1.1 主要功能
- 详细的代码注释:每一行代码都有详细的解释
- 完整的Transformer实现:包含编码器-解码器架构
- 注意力机制的可视化:帮助理解注意力权重的计算过程
- 教学友好:适合作为学习Transformer的教程
- 可运行的代码:可以直接运行和实验
1.2 应用场景
- 学习Transformer模型的内部工作原理
- 研究和实验注意力机制
- 作为自然语言处理课程的教学材料
- 基于Transformer构建自定义模型
- 理解现代NLP模型的基础架构
2. 核心概念
2.1 Transformer架构
Transformer是一种基于自注意力机制的序列到序列模型,主要由编码器和解码器组成:
- 编码器:将输入序列编码为上下文表示
- 解码器:根据编码器的输出和已生成的序列,生成目标序列
- 自注意力机制:允许模型关注输入序列的不同位置
- 位置编码:为模型提供序列的位置信息
- 前馈网络:对注意力输出进行进一步处理
2.2 注意力机制
注意力机制是Transformer的核心,它允许模型在处理每个位置时,关注输入序列的相关部分:
- 缩放点积注意力:计算查询和键的点积,然后缩放并应用softmax
- 多头注意力:使用多个注意力头,捕捉不同类型的关系
- 掩码注意力:在解码器中使用,防止关注未来的位置
2.3 位置编码
由于Transformer不包含循环或卷积结构,需要通过位置编码为模型提供序列的位置信息:
- 正弦位置编码:使用不同频率的正弦和余弦函数生成位置编码
- 相对位置编码:编码位置之间的相对距离
3. 代码分析
3.1 模型架构
The Annotated Transformer实现了完整的Transformer架构,包括:
class EncoderDecoder(nn.Module):
"""标准的编码器-解码器架构"""
def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
super(EncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_embed = src_embed # 源语言嵌入
self.tgt_embed = tgt_embed # 目标语言嵌入
self.generator = generator # 生成器
def forward(self, src, tgt, src_mask, tgt_mask):
"""前向传播"""
return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
def encode(self, src, src_mask):
"""编码源语言序列"""
return self.encoder(self.src_embed(src), src_mask)
def decode(self, memory, src_mask, tgt, tgt_mask):
"""解码目标语言序列"""
return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)3.2 注意力机制实现
def attention(query, key, value, mask=None, dropout=None):
"""缩放点积注意力"""
d_k = query.size(-1)
# 计算注意力分数
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim = -1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
class MultiHeadedAttention(nn.Module):
"""多头注意力"""
def __init__(self, h, d_model, dropout=0.1):
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
# 每个头的维度
self.d_k = d_model // h
self.h = h
# 线性层
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
if mask is not None:
# 所有头使用相同的掩码
mask = mask.unsqueeze(1)
nbatches = query.size(0)
# 1) 线性投影到h个头
query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linears, (query, key, value))]
# 2) 应用注意力到所有投影后的向量
x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
# 3) 拼接结果并应用最终的线性层
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
return self.linears[-1](x)3.3 位置编码实现
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, d_model, dropout, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# 计算位置编码
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)3.4 编码器实现
class Encoder(nn.Module):
"""Transformer编码器"""
def __init__(self, layer, N):
super(Encoder, self).__init__()
# 克隆N个编码器层
self.layers = clones(layer, N)
# 层归一化
self.norm = LayerNorm(layer.size)
def forward(self, x, mask):
"""前向传播"""
for layer in self.layers:
x = layer(x, mask)
return self.norm(x)
class EncoderLayer(nn.Module):
"""单个编码器层"""
def __init__(self, size, self_attn, feed_forward, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
# 子层连接
self.sublayer = clones(SublayerConnection(size, dropout), 2)
self.size = size
def forward(self, x, mask):
"""前向传播"""
# 自注意力子层
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
# 前馈网络子层
return self.sublayer[1](x, self.feed_forward)3.5 解码器实现
class Decoder(nn.Module):
"""Transformer解码器"""
def __init__(self, layer, N):
super(Decoder, self).__init__()
# 克隆N个解码器层
self.layers = clones(layer, N)
# 层归一化
self.norm = LayerNorm(layer.size)
def forward(self, x, memory, src_mask, tgt_mask):
"""前向传播"""
for layer in self.layers:
x = layer(x, memory, src_mask, tgt_mask)
return self.norm(x)
class DecoderLayer(nn.Module):
"""单个解码器层"""
def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
# 子层连接
self.sublayer = clones(SublayerConnection(size, dropout), 3)
def forward(self, x, memory, src_mask, tgt_mask):
"""前向传播"""
m = memory
# 掩码自注意力子层
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
# 编码器-解码器注意力子层
x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
# 前馈网络子层
return self.sublayer[2](x, self.feed_forward)4. 实践应用
4.1 机器翻译
使用The Annotated Transformer实现机器翻译:
# 加载数据
from torchtext.datasets import Multi30k
train_iter, valid_iter, test_iter = Multi30k(split=('train', 'valid', 'test'))
# 构建词汇表
SRC = Field(tokenize="spacy", tokenizer_language="de", init_token="<sos>", eos_token="<eos>", lower=True)
TRG = Field(tokenize="spacy", tokenizer_language="en", init_token="<sos>", eos_token="<eos>", lower=True)
SRC.build_vocab(train_iter, min_freq=2)
TRG.build_vocab(train_iter, min_freq=2)
# 初始化模型
d_model = 512
heads = 8
N = 6
d_ff = 2048
dropout = 0.1
# 创建模型实例
model = make_model(len(SRC.vocab), len(TRG.vocab), N=N, d_model=d_model, d_ff=d_ff, h=heads, dropout=dropout)
# 训练模型
train_model(model, train_iter, valid_iter, test_iter, SRC, TRG, epochs=10)
# 翻译示例
def translate(sentence, model, SRC, TRG):
model.eval()
tokens = [token.text.lower() for token in spacy_de(sentence)]
tokens = [SRC.init_token] + tokens + [SRC.eos_token]
src_indexes = [SRC.vocab.stoi[token] for token in tokens]
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
src_mask = (src_tensor != SRC.vocab.stoi[SRC.pad_token]).unsqueeze(-2)
memory = model.encode(src_tensor, src_mask)
tgt_indexes = [TRG.vocab.stoi[TRG.init_token]]
for i in range(50):
tgt_tensor = torch.LongTensor(tgt_indexes).unsqueeze(0).to(device)
tgt_mask = (tgt_tensor != TRG.vocab.stoi[TRG.pad_token]).unsqueeze(-2)
tgt_mask = tgt_mask & subsequent_mask(tgt_tensor.size(-1)).type_as(tgt_mask.data)
output = model.decode(memory, src_mask, tgt_tensor, tgt_mask)
prob = model.generator(output[:, -1])
_, next_word = torch.max(prob, dim=1)
next_word = next_word.item()
tgt_indexes.append(next_word)
if next_word == TRG.vocab.stoi[TRG.eos_token]:
break
tgt_tokens = [TRG.vocab.itos[i] for i in tgt_indexes]
return tgt_tokens[1:-1]
# 测试翻译
sentence = "ein mann in einem blauen hemd steht auf einem berg"
translation = translate(sentence, model, SRC, TRG)
print("源语言:", sentence)
print("目标语言:", ' '.join(translation))4.2 文本摘要
使用The Annotated Transformer实现文本摘要:
# 加载数据
from torchtext.datasets import CNNDM
train_iter, valid_iter, test_iter = CNNDM(split=('train', 'valid', 'test'))
# 构建词汇表
SRC = Field(tokenize="spacy", init_token="<sos>", eos_token="<eos>", lower=True)
TRG = Field(tokenize="spacy", init_token="<sos>", eos_token="<eos>", lower=True)
SRC.build_vocab(train_iter, min_freq=2)
TRG.build_vocab(train_iter, min_freq=2)
# 初始化模型
d_model = 512
heads = 8
N = 6
d_ff = 2048
dropout = 0.1
# 创建模型实例
model = make_model(len(SRC.vocab), len(TRG.vocab), N=N, d_model=d_model, d_ff=d_ff, h=heads, dropout=dropout)
# 训练模型
train_model(model, train_iter, valid_iter, test_iter, SRC, TRG, epochs=5)
# 生成摘要
def summarize(text, model, SRC, TRG):
model.eval()
tokens = [token.text.lower() for token in spacy_en(text)]
tokens = [SRC.init_token] + tokens + [SRC.eos_token]
src_indexes = [SRC.vocab.stoi[token] for token in tokens]
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
src_mask = (src_tensor != SRC.vocab.stoi[SRC.pad_token]).unsqueeze(-2)
memory = model.encode(src_tensor, src_mask)
tgt_indexes = [TRG.vocab.stoi[TRG.init_token]]
for i in range(100):
tgt_tensor = torch.LongTensor(tgt_indexes).unsqueeze(0).to(device)
tgt_mask = (tgt_tensor != TRG.vocab.stoi[TRG.pad_token]).unsqueeze(-2)
tgt_mask = tgt_mask & subsequent_mask(tgt_tensor.size(-1)).type_as(tgt_mask.data)
output = model.decode(memory, src_mask, tgt_tensor, tgt_mask)
prob = model.generator(output[:, -1])
_, next_word = torch.max(prob, dim=1)
next_word = next_word.item()
tgt_indexes.append(next_word)
if next_word == TRG.vocab.stoi[TRG.eos_token]:
break
tgt_tokens = [TRG.vocab.itos[i] for i in tgt_indexes]
return tgt_tokens[1:-1]
# 测试摘要
text = "The Annotated Transformer is a detailed implementation of the Transformer model. It provides clear code with extensive comments to help understand how the model works. The project is maintained by the Harvard NLP group and is widely used for educational purposes."
summary = summarize(text, model, SRC, TRG)
print("原文:", text)
print("摘要:", ' '.join(summary))4.3 注意力可视化
可视化Transformer的注意力权重:
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(src, tgt, model, SRC, TRG):
model.eval()
# 预处理输入
src_tokens = [token.text.lower() for token in spacy_de(src)]
src_tokens = [SRC.init_token] + src_tokens + [SRC.eos_token]
src_indexes = [SRC.vocab.stoi[token] for token in src_tokens]
src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
src_mask = (src_tensor != SRC.vocab.stoi[SRC.pad_token]).unsqueeze(-2)
# 预处理目标
tgt_tokens = [token.text.lower() for token in spacy_en(tgt)]
tgt_tokens = [TRG.init_token] + tgt_tokens + [TRG.eos_token]
tgt_indexes = [TRG.vocab.stoi[token] for token in tgt_tokens]
tgt_tensor = torch.LongTensor(tgt_indexes).unsqueeze(0).to(device)
tgt_mask = (tgt_tensor != TRG.vocab.stoi[TRG.pad_token]).unsqueeze(-2)
tgt_mask = tgt_mask & subsequent_mask(tgt_tensor.size(-1)).type_as(tgt_mask.data)
# 前向传播
memory = model.encode(src_tensor, src_mask)
output = model.decode(memory, src_mask, tgt_tensor, tgt_mask)
# 获取注意力权重
attn_weights = model.decoder.layers[0].src_attn.attn.detach().cpu().numpy()[0]
# 可视化
plt.figure(figsize=(10, 8))
sns.heatmap(attn_weights, xticklabels=src_tokens, yticklabels=tgt_tokens, cmap="viridis")
plt.title("Attention Weights")
plt.xlabel("Source")
plt.ylabel("Target")
plt.show()
# 测试注意力可视化
src = "ein mann in einem blauen hemd"
tgt = "a man in a blue shirt"
visualize_attention(src, tgt, model, SRC, TRG)5. 扩展与改进
5.1 模型扩展
基于The Annotated Transformer,可以实现以下扩展:
- BERT:添加掩码语言模型目标,实现双向编码器表示
- GPT:修改解码器,实现自回归语言模型
- T5:使用统一的文本到文本框架
- BART:结合双向编码器和自回归解码器
5.2 性能优化
提高Transformer模型的性能:
- 混合精度训练:使用FP16减少内存使用
- 梯度累积:模拟更大的批量大小
- 学习率调度:使用预热和衰减策略
- 批量归一化:加速训练收敛
- 模型并行:在多个GPU上并行训练
5.3 应用扩展
将Transformer应用到其他领域:
- 计算机视觉:ViT(Vision Transformer)
- 语音处理:Whisper
- 多模态学习:CLIP
- 推荐系统:基于Transformer的推荐模型
- 时间序列预测:用于预测任务
6. 总结与展望
The Annotated Transformer是一个优秀的教学资源,通过详细的代码注释和实现,帮助开发者和研究者理解Transformer模型的内部工作原理。它的主要优势包括:
- 详细的代码注释:每一行代码都有清晰的解释
- 完整的实现:包含Transformer的所有核心组件
- 教学友好:适合作为学习Transformer的教程
- 可扩展性:可以基于此实现各种Transformer变体
未来,Transformer模型有望在以下方面继续发展:
- 更高效的架构设计:减少计算复杂度和内存使用
- 更好的小样本学习能力:提高模型的泛化能力
- 更强的多模态理解:整合文本、图像、音频等多种模态
- 更广泛的应用场景:扩展到更多领域和任务
- 更易于部署:优化模型大小和推理速度
通过学习The Annotated Transformer,开发者可以深入理解Transformer的工作原理,为构建更强大的NLP模型打下基础。Transformer的出现标志着自然语言处理领域的重要突破,为未来的AI研究和应用开辟了新的方向。