5.2 nn.Transformer

使用 PyTorch 的 nn.Transformer 接口,实现德语-英语翻译。

创建日期: 2025-04-09

在讲完 注意力机制Transformer 论文 "Attention is All You Need" 之后,接下来将想法付诸实践,最有效的办法就是复现论文中德语-英语翻译。我们使用 PyTorch 提供的 nn.Transformer 接口,通过参数传递,快速实现 Transformer 模型,代码在 torch_impl.py 文件中,为后续手写 Transformer 实现打下基础。

5.2.1 模型概述

Transformer 架构的左边是编码器,右边是解码器,左右两边同时使用注意力机制,如下图所示:

这里不会细讲这个模型,感兴趣的读者可以查看之前的教程 Transformer 论文 。主要介绍 nn.Transformer 接口,全部参数如下:

def __init__(
    self,
    d_model: int = 512,
    nhead: int = 8,
    num_encoder_layers: int = 6,
    num_decoder_layers: int = 6,
    dim_feedforward: int = 2048,
    dropout: float = 0.1,
    activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
    custom_encoder: Optional[Any] = None,
    custom_decoder: Optional[Any] = None,
    layer_norm_eps: float = 1e-5,
    batch_first: bool = False,
    norm_first: bool = False,
    bias: bool = True,
    device=None,
    dtype=None,
) -> None:

对其中重要的参数做简单的解释:

  1. d_model:编码器/解码器输入中预期特征的数量(默认值为 512),可以理解为单词嵌入的维度。

  2. nhead:多头注意力模型中的头数量(默认值为 8)。

  3. num_encoder_layers:编码器中子层的数量(默认值为 6)。

  4. num_decoder_layers:解码器中子层的数量(默认值为 6)。

  5. dim_feedforward:前馈网络模型的维度(默认值为 2048)。

  6. dropout:随机失活值(默认值为 0.1)。

  7. activation:编码器/解码器中间层的激活函数,可以是字符串(例如 "relu" 或 "gelu")或一元可调用函数。默认值是 relu 。

  8. custom_encoder:自定义编码器(默认值为 None)。

  9. custom_decoder:自定义解码器(默认值为 None)。

  10. batch_first:如果为 "True",则输入和输出张量将以 (batch, seq, feature) 的形式提供。默认值为 False (seq, batch, feature)。

  11. norm_first:如果为 "True",则编码器和解码器层将在其他注意力机制和前馈操作之前执行 LayerNorm,否则在操作之后执行。默认值:False(之后)。

  12. bias:如果设置为 "False",则 "Linear" 层和 "LayerNorm" 层将不会学习加性偏差。默认值为 True。

在 Transformer 论文中,作者使用 8 台 P100 GPU,450 万句子对,到达当时最先进的水平需要训练 12 个小时。我们基本上也使用相同的参数,大约 3 万句子对,在单台 4070 GPU 上训练,40 个轮回需要大约半个小时,详细训练过程见 第 4 小节

5.2.2 德-英数据集

我们的任务是将德语翻译成英语,因此需要有英语-德语数据集。国内有个比较大的问题就是网络下载环境,这里每个人碰到的情况会不一样。另外一个问题就是 Python 库的兼容性,原本打算使用的是 spaCy,但是 pip install spacy 告知编译阶段报错(一个可行的办法是使用 Python 虚拟环境)。这里使用的是 nltk 代替,主要作用是 Tokenize ,将句子分成独立的单词。

数据集来自于 HuggingFacedatasets 仓库。下载方式如下:

from datasets import load_dataset
dataset = load_dataset('bentrevett/multi30k')
print(dataset)
DatasetDict({
                train: Dataset({
                    features: ['en', 'de'],
                    num_rows: 29000
                })
                validation: Dataset({
                    features: ['en', 'de'],
                    num_rows: 1014
                })
                test: Dataset({
                    features: ['en', 'de'],
                    num_rows: 1000
                })
            })

数据集已经包含 trainvalidationtest 三个部分,查看 train 部分第一个句子:

print(dataset['train'][0])
{'en': 'Two young, White males are outside near many bushes.',
                'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

5.2.3 预处理

预处理包含两部分,一是将句子进行拆分,转换为一个词汇表,并给首尾加上开始和结束标记,二是将转换后的文本使用 PyTorch 的 API 进行分批处理。

5.2.3.1 构建词汇表

NLTK (Natural Language Toolkit) 是用于构建处理人类语言数据的 Python 程序,可用于分类、标记化、词干提取等文本处理任务。

使用 nltk.download('punkt') 初始化 nltk 库标记部分:

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\15207\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\15207\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!

将标记器传递到 Field 结构体中,设置开始标记,结束标记,以及空标记,构建源语言和目标语言的单词库(单词至少要出现 2 次):

train_data = [(example['de'], example['en']) for example in dataset['train']]
valid_data = [(example['de'], example['en']) for example in dataset['validation']]
test_data = [(example['de'], example['en']) for example in dataset['test']]

SRC = Field(tokenize=word_tokenize, init_token='', eos_token='', pad_token='', lower=True, batch_first=True)
TRG = Field(tokenize=word_tokenize, init_token='', eos_token='', pad_token='', lower=True, batch_first=True)

# 创建 `Example` 列表
train_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in train_data]
valid_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in valid_data]
test_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in test_data]

# 转换为 `Dataset`
train_dataset = Dataset(examples=train_examples, fields=[('src', SRC), ('trg', TRG)])
valid_dataset = Dataset(examples=valid_examples, fields=[('src', SRC), ('trg', TRG)])
test_dataset = Dataset(examples=test_examples, fields=[('src', SRC), ('trg', TRG)])

# 构建词汇表
SRC.build_vocab(train_dataset, min_freq=2)
TRG.build_vocab(train_dataset, min_freq=2)

print('Source vocabulary size: ' + str(len(SRC.vocab)))
print('Target vocabulary size: ' + str(len(TRG.vocab)))

源数据有 7861 个标记,目标数据有 5920 个标记:

Source vocabulary size: 7861
Target vocabulary size: 5920

分别打印前 10 个标记:

print([word for word, _ in list(SRC.vocab.stoi.items())[:10]])
print([word for word, _ in list(TRG.vocab.stoi.items())[:10]])
['', '', '', '', '.', 'ein', 'einem', 'in', 'eine', ',']
['', '', '', '', 'a', '.', 'in', 'the', 'on', 'man']

5.2.3.2 数据迭代器

使用 BucketIterator 对标记后的数据集对进行分批处理,批次大小设置为 32 。判断是否 GPU 可用,如果可用,就将数据转移到 GPU 上。

src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)

BATCH_SIZE = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_size=BATCH_SIZE,
    device=device,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
)

for batch in train_iterator:
    src = batch.src
    trg = batch.trg
    
    print(src.shape)
    print(trg.shape)
    break

打印训练数据集的第一个批次,当前批次源数据的标记最大长度为 17 ,目标数据的标记最大长度为 21 :

torch.Size([32, 17])
torch.Size([32, 21])

做好上述准备后,就可以开始训练啦!

5.2.4 训练与验证

训练过程和之前学习的模型类似,定义模型的架构,分别定义 trainevaluate 函数。

5.2.4.1 架构定义

模型架构都封装在 nn.Transformer 结构体中,使用 nn.Embedding 函数将进行词嵌入,最后使用 nn.Linear 将模型的输出映射到目标词汇表:

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size,
                    emb_size=256, nhead=8, nhid=1024, nlayers=6, dropout=0.1):
        super(TransformerModel, self).__init__()

        self.src_emb = nn.Embedding(src_vocab_size, emb_size)
        self.trg_emb = nn.Embedding(trg_vocab_size, emb_size)
        
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            batch_first=True,
            dropout=dropout,
        )

        self.fc_out = nn.Linear(emb_size, trg_vocab_size)
    
    def forward(self, src, trg):
        src_emb = self.src_emb(src)
        trg_emb = self.trg_emb(trg)
        
        output = self.transformer(src_emb, trg_emb)
        
        return self.fc_out(output)

5.2.4.2 训练与验证

分别定义优化器和损失函数:

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi[TRG.pad_token])

训练函数:

# 调整训练轮数和梯度裁剪
EPOCHS = 40
CLIP = 1.0  # Gradient clipping

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src  # 源句子
        trg = batch.trg  # 目标句子
        
        # 确保 batch_size 一致
        assert src.size(0) == trg.size(0)

        optimizer.zero_grad()
        
        # Exclude the last token from target sequence (shifted target sequence)
        output = model(src, trg[:, :-1])  # tgt[:, :-1] 排除目标序列的最后一个 token
        
        output_dim = output.shape[-1]
        
        # Flatten output and trg to calculate loss
        output = output.view(-1, output_dim)  # (batch_size * seq_len, trg_vocab_size)
        trg = trg[:, 1:].contiguous().view(-1)  # Exclude first token from target sequence
        
        loss = criterion(output, trg)
        loss.backward()
        
        # Gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

验证函数:

def evaluate(model, iterator, criterion):
    model.eval()  # 设置为评估模式
    epoch_loss = 0
    
    with torch.no_grad():  # 评估时不计算梯度
        for i, batch in enumerate(iterator):
            src = batch.src  # 源句子
            trg = batch.trg  # 目标句子
            
            # 目标句子去掉最后一个 token
            output = model(src, trg[:, :-1])
            
            output_dim = output.shape[-1]
            
            # Flatten output 和 trg 用于计算 loss
            output = output.view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)
            
            loss = criterion(output, trg)  # 计算损失
            epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)

设置 40 个训练轮回:

for epoch in range(EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.3f} | Validation Loss: {valid_loss:.3f}")
Epoch 1 | Train Loss: 4.126 | Validation Loss: 3.174
Epoch 2 | Train Loss: 3.048 | Validation Loss: 2.654
Epoch 3 | Train Loss: 2.622 | Validation Loss: 2.346
Epoch 4 | Train Loss: 2.326 | Validation Loss: 2.142
Epoch 5 | Train Loss: 2.096 | Validation Loss: 1.967
Epoch 6 | Train Loss: 1.905 | Validation Loss: 1.845
Epoch 7 | Train Loss: 1.743 | Validation Loss: 1.724
Epoch 8 | Train Loss: 1.601 | Validation Loss: 1.648
Epoch 10 | Train Loss: 1.367 | Validation Loss: 1.518
Epoch 12 | Train Loss: 1.176 | Validation Loss: 1.447
Epoch 14 | Train Loss: 1.021 | Validation Loss: 1.386
Epoch 16 | Train Loss: 0.893 | Validation Loss: 1.359
Epoch 18 | Train Loss: 0.787 | Validation Loss: 1.355
Epoch 20 | Train Loss: 0.699 | Validation Loss: 1.352
Epoch 22 | Train Loss: 0.629 | Validation Loss: 1.358
Epoch 24 | Train Loss: 0.573 | Validation Loss: 1.384
Epoch 26 | Train Loss: 0.525 | Validation Loss: 1.402
Epoch 28 | Train Loss: 0.485 | Validation Loss: 1.439
Epoch 30 | Train Loss: 0.451 | Validation Loss: 1.452
Epoch 32 | Train Loss: 0.424 | Validation Loss: 1.488
Epoch 34 | Train Loss: 0.399 | Validation Loss: 1.506
Epoch 36 | Train Loss: 0.377 | Validation Loss: 1.525
Epoch 38 | Train Loss: 0.360 | Validation Loss: 1.552
Epoch 40 | Train Loss: 0.346 | Validation Loss: 1.574

5.2.4.3 待优化

观察每一轮的训练损失和验证损失,验证损失到达 1.4 之后,就趋于稳定,而训练损失还在不断降低,说明模型发生比较严重的过拟合现象。

推断引起该问题的主要原因是数据量太小,可以采用更大的数据集进行训练,但是这会极大地增加训练时间和训练难度,感兴趣的读者可以结合自己的配置进行处理。

5.2.5 翻译推理

翻译推理函数,这里取测试集的第一批数据进行推理,当输出遇到 "sos" 时结束就去掉句子后续的预测值:

def translate_one_batch(model, iterator, SRC, TRG):
    model.eval()  # 切换到评估模式
    translations = []  # 存储翻译结果
    
    with torch.no_grad():
        # 获取第一个 batch
        batch = next(iter(iterator))  
        src = batch.src  # 源句子
        trg = batch.trg  # 目标句子
        
        # 使用模型生成翻译
        output = model(src, trg[:, :-1])
        output = output.argmax(dim=-1)  # 选择最大概率的预测词
        
        # 将预测结果转为词汇表中的词
        for i in range(src.size(0)):  # 遍历每一个样本
            src_tokens = [SRC.vocab.itos[idx] for idx in src[i]]  # 源句子
            # hyp_tokens = [TRG.vocab.itos[idx] for idx in output[i] if idx != TRG.vocab.stoi[TRG.pad_token]]  # 预测的翻译
            hyp_tokens = []
            for idx in output[i]:  # 遍历每个生成的词
                if idx == TRG.vocab.stoi[TRG.eos_token]:  # 如果生成的是 EOS token,停止生成
                    break
                if idx != TRG.vocab.stoi[TRG.pad_token]:  # 如果不是 pad token,加入到翻译中
                    hyp_tokens.append(TRG.vocab.itos[idx])

            # 保存翻译结果
            translations.append({
                'src': ' '.join(src_tokens),  # 源句子
                'hyp': ' '.join(hyp_tokens),  # 生成的翻译
                'trg': ' '.join([TRG.vocab.itos[idx] for idx in trg[i][1:].cpu().numpy() if idx != TRG.vocab.stoi[TRG.pad_token]])  # 目标句子
            })
    
    return translations

下面是预测的部分结果,可以看到有些句子还是和参考一样的,有些句子则发生一些错误:

--------------------------------------------------
Source:  drei leute sitzen in einer höhle . 
Prediction: three people sit in a cave .
Reference: three people sit in a cave . 
--------------------------------------------------
Source:  ein typ arbeitet an einem gebäude . 
Prediction: a building works on a building .
Reference: a guy works on a building . 
--------------------------------------------------
Source:  leute reparieren das dach eines hauses . 
Prediction: people are fixing the roof of a house .
Reference: people are fixing the roof of a house . 
--------------------------------------------------
Source:  zwei jungen vor einem  .  
Prediction: two boys in front of a machine machine .
Reference: two boys in front of a soda machine . 
--------------------------------------------------

以上就是全部内容,还有很多值得研究的地方!比如文本生成是如何处理的。