5.3 nn.Transformer

Use the torch.nn.Transformer interface for the German-to-English translation task.

Created Date: 2025-07-02

In this tutorial, we first use the torch.nn.Transformer interface to quickly build a German-to-English translator and set up the framework of the translation project. We then replace the torch.nn.Transformer interface with hand-written multi-head attention, encoder, decoder, and feed-forward layers.

5.3.1 Dataset of Multi30k

For ease of use, we'll use the "multi30k" dataset, which is the "task 1" dataset of the WMT16 machine translation shared task.

Each example has an "en" and a "de" feature: "en" is an English sentence, and "de" is its German translation. The dataset has three splits: train, validation, and test. We load it from the Hugging Face Hub with the datasets library:

import datasets
import nltk

# Tokenizer data used by nltk.word_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')

dataset = datasets.load_dataset('bentrevett/multi30k')
print(dataset)
print(dataset['train'][0])
DatasetDict({
    train: Dataset({
        features: ['en', 'de'],
        num_rows: 29000
    })
    validation: Dataset({
        features: ['en', 'de'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['en', 'de'],
        num_rows: 1000
    })
})
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}

The file language_trans_torch.py uses the TorchText library to handle the complexities of text processing: tokenization, vocabulary building, and batching. Let's break down the code step by step:

1. Data Loading and Initial Preparation

After loading the dataset object with Hugging Face's datasets library, the code extracts parallel sentence pairs (German source, English target) from the 'train', 'validation', and 'test' splits.

2. Defining Field Objects

Field objects in TorchText define how raw text data should be processed. Let's look at the arguments:

  • tokenize=word_tokenize specifies the tokenization function. word_tokenize, from NLTK, splits a sentence into individual words and punctuation marks. Tokenization is the process of breaking a sequence of characters into pieces, called tokens.

  • init_token='<sos>' adds a special "start of sequence" token at the beginning of each sentence.

  • eos_token='<eos>' adds a special "end of sequence" token at the end of each sentence.

  • pad_token='<pad>' specifies a token to use for padding. When batching sentences, they often have different lengths. Padding ensures all sentences in a batch have the same length by adding <pad> tokens to shorter ones.

  • lower=True converts all text to lowercase. This helps reduce vocabulary size and makes the model less sensitive to case variations.

  • batch_first=True sets the dimension order of batched tensors: the batch dimension comes first, giving shape [batch_size, sequence_length].
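To make these arguments concrete, the sketch below applies the same steps (tokenize, lowercase, add <sos>/<eos>, pad into a batch-first rectangle) in plain Python. It is an illustration, not torchtext's code: simple_tokenize is a hypothetical regex stand-in for NLTK's word_tokenize, and preprocess and pad_batch are made-up helper names.

```python
import re

PAD, SOS, EOS = '<pad>', '<sos>', '<eos>'


def simple_tokenize(text):
    # Hypothetical stand-in for nltk.word_tokenize: words and punctuation
    # marks become separate tokens.
    return re.findall(r"\w+|[^\w\s]", text)


def preprocess(sentence, lower=True):
    # tokenize -> lowercase -> wrap with <sos>/<eos> (init_token, eos_token, lower)
    tokens = simple_tokenize(sentence)
    if lower:
        tokens = [t.lower() for t in tokens]
    return [SOS] + tokens + [EOS]


def pad_batch(batch):
    # Pad every sequence to the longest one (pad_token). The result is
    # batch-first: batch_size rows, each of length max_len.
    max_len = max(len(seq) for seq in batch)
    return [seq + [PAD] * (max_len - len(seq)) for seq in batch]


batch = pad_batch([preprocess("Zwei Männer in Schwarz."),
                   preprocess("Ein Kind planscht im Wasser.")])
for row in batch:
    print(row)
# ['<sos>', 'zwei', 'männer', 'in', 'schwarz', '.', '<eos>', '<pad>']
# ['<sos>', 'ein', 'kind', 'planscht', 'im', 'wasser', '.', '<eos>']
```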

3. Creating Example Objects

Example objects represent a single instance of data, associating raw text with their respective Field definitions. Here, each German-English sentence pair is converted into an Example.

4. Building Dataset Objects

Dataset objects in TorchText are containers for Example objects. They organize the data into a form that an iterator such as BucketIterator can split into batches (the next step).

5. Building Vocabularies

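This is a crucial step. build_vocab creates a vocabulary for each Field from the training split only. A vocabulary is a mapping from unique words (tokens) to numerical indices. With min_freq=2, only tokens that occur at least twice in the training data get their own index; rarer tokens are mapped to the unknown token <unk>.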

# Legacy torchtext API: torchtext.data up to 0.8, torchtext.legacy.data in 0.9-0.11 (removed in 0.12)
from torchtext.legacy.data import Field, Example, Dataset, BucketIterator
from nltk.tokenize import word_tokenize

# (German source, English target) sentence pairs for each split
train_data = [(example['de'], example['en']) for example in dataset['train']]
valid_data = [(example['de'], example['en']) for example in dataset['validation']]
test_data = [(example['de'], example['en']) for example in dataset['test']]

SRC = Field(tokenize=word_tokenize, init_token='<sos>', eos_token='<eos>', pad_token='<pad>', lower=True, batch_first=True)
TRG = Field(tokenize=word_tokenize, init_token='<sos>', eos_token='<eos>', pad_token='<pad>', lower=True, batch_first=True)

fields = [('src', SRC), ('trg', TRG)]

train_examples = [Example.fromlist([src, trg], fields=fields) for src, trg in train_data]
valid_examples = [Example.fromlist([src, trg], fields=fields) for src, trg in valid_data]
test_examples = [Example.fromlist([src, trg], fields=fields) for src, trg in test_data]

train_dataset = Dataset(examples=train_examples, fields=fields)
valid_dataset = Dataset(examples=valid_examples, fields=fields)
test_dataset = Dataset(examples=test_examples, fields=fields)

# Vocabularies come from the training split only; tokens seen fewer than 2 times map to <unk>
SRC.build_vocab(train_dataset, min_freq=2)
TRG.build_vocab(train_dataset, min_freq=2)

print('Source vocabulary size: ' + str(len(SRC.vocab)))
print('Target vocabulary size: ' + str(len(TRG.vocab)))

# First 10 tokens of each string-to-index (stoi) mapping
print([word for word, _ in list(SRC.vocab.stoi.items())[:10]])
print([word for word, _ in list(TRG.vocab.stoi.items())[:10]])

The code then prints the size of each vocabulary and the first 10 entries of its stoi ("string to index") mapping, which gives a peek at which tokens are included and the order of their numerical IDs. The first four entries are the special tokens.

Source vocabulary size: 7861
Target vocabulary size: 5920
['<unk>', '<pad>', '<sos>', '<eos>', '.', 'ein', 'einem', 'in', 'eine', ',']
['<unk>', '<pad>', '<sos>', '<eos>', 'a', '.', 'in', 'the', 'on', 'man']
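The vocabulary step can be approximated in a few lines of plain Python. This sketch assumes the ordering legacy torchtext uses (special tokens first, in the order <unk>, <pad>, <sos>, <eos>, then tokens by descending frequency with ties broken alphabetically); build_vocab and numericalize here are hypothetical helpers, not the Field methods.

```python
from collections import Counter

SPECIALS = ['<unk>', '<pad>', '<sos>', '<eos>']


def build_vocab(tokenized_sentences, min_freq=2):
    # Count token frequencies over the (training) sentences.
    counter = Counter(tok for sent in tokenized_sentences for tok in sent)
    # Sort alphabetically, then stable-sort by descending frequency,
    # so equally frequent tokens stay in alphabetical order.
    words = sorted(counter.items(), key=lambda kv: kv[0])
    words.sort(key=lambda kv: kv[1], reverse=True)
    itos = SPECIALS + [w for w, freq in words if freq >= min_freq]
    stoi = {w: i for i, w in enumerate(itos)}
    return itos, stoi


def numericalize(tokens, stoi):
    # Tokens missing from the vocabulary map to the <unk> index.
    return [stoi.get(tok, stoi['<unk>']) for tok in tokens]


train = [['a', 'man', 'in', 'a', 'hat'],
         ['a', 'dog', 'in', 'the', 'park'],
         ['the', 'man', 'runs']]
itos, stoi = build_vocab(train, min_freq=2)
print(itos)                                       # ['<unk>', '<pad>', '<sos>', '<eos>', 'a', 'in', 'man', 'the']
print(numericalize(['a', 'man', 'walks'], stoi))  # [4, 6, 0]
```

Here "hat", "dog", "park", and "runs" occur only once, so min_freq=2 drops them, and the unseen word "walks" becomes <unk> (index 0).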

BucketIterator groups sentences of similar length into the same batch, which reduces the amount of padding, and creates iterators for training, validation, and testing. The loop fetches one training batch and prints the shapes of its source (src) and target (trg) tensors, which are [batch_size, sequence_length].

import torch

BATCH_SIZE = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_size=BATCH_SIZE,
    device=device,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
)

for batch in train_iterator:
    src = batch.src
    trg = batch.trg
    
    print(src.shape)
    print(trg.shape)
    break
torch.Size([32, 19])
torch.Size([32, 26])
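Why bucketing helps can be shown by counting padding tokens. The bucket_batches function below is a simplified, hypothetical version that sorts the whole dataset by length; the real BucketIterator sorts within larger chunks of data and has its own shuffling logic, so treat this only as an illustration of the idea.

```python
import random


def bucket_batches(lengths, batch_size, seed=0):
    """Group example indices into batches of similar length (simplified bucketing)."""
    # Sort indices by sequence length so neighbours have similar lengths.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    # Shuffle the order of the batches, not their contents.
    random.Random(seed).shuffle(batches)
    return batches


def padding_tokens(batches, lengths):
    """Count <pad> tokens needed when each batch is padded to its longest sequence."""
    return sum(max(lengths[i] for i in b) * len(b) - sum(lengths[i] for i in b)
               for b in batches)


lengths = [5, 20, 6, 19, 7, 21, 5, 18]                 # made-up sentence lengths
sequential_batches = [[0, 1], [2, 3], [4, 5], [6, 7]]  # batches in dataset order
print('padding, dataset order:', padding_tokens(sequential_batches, lengths))          # 55
print('padding, bucketed:     ', padding_tokens(bucket_batches(lengths, 2), lengths))  # 3
```

Pairing a 5-token sentence with a 20-token one wastes 15 positions, while sorted neighbours differ by at most one token.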

5.3.2 nn.Transformer API

This section builds a sequence-to-sequence Transformer model with PyTorch's nn.Transformer. The model consists of embeddings, an encoder-decoder stack, and an output projection. Training uses teacher forcing, excludes padding tokens from the loss, and clips gradients. Finally, we inspect predictions on one test batch; this step also feeds the reference target sequence into the decoder, so it shows teacher-forced next-token predictions rather than free-running translations.

5.3.2.1 Make Transformer Model

The TransformerModel class wraps PyTorch's nn.Transformer module in a sequence-to-sequence model. Its constructor creates:

  • src_emb and trg_emb: embedding layers that map source and target token IDs to dense vectors of size emb_size.

  • transformer: the core nn.Transformer module, consisting of stacked encoder and decoder layers. The constructor arguments map onto its parameters: emb_size sets d_model (the embedding dimension), nhead the number of attention heads, nlayers both num_encoder_layers and num_decoder_layers, nhid sets dim_feedforward (the hidden size of the feed-forward networks), and dropout the dropout rate.

  • fc_out: a final linear layer that projects the Transformer's output to the size of the target vocabulary (trg_vocab_size), producing logits for word prediction.

The forward method embeds the source and target token IDs, passes them through the transformer, and applies the final linear layer to get the output logits. Note that nn.Transformer does not add positional encodings itself, and this minimal model passes no attention masks, so the decoder is not prevented from attending to later target positions.

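import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size,
                 emb_size=256, nhead=8, nhid=1024, nlayers=6, dropout=0.1):
        super().__init__()

        # Token embeddings for the source (German) and target (English) vocabularies
        self.src_emb = nn.Embedding(src_vocab_size, emb_size)
        self.trg_emb = nn.Embedding(trg_vocab_size, emb_size)

        # Encoder-decoder stack; batch_first=True matches the [batch, seq_len] batches
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            batch_first=True,
            dropout=dropout,
        )

        # Projects each decoder output vector to logits over the target vocabulary
        self.fc_out = nn.Linear(emb_size, trg_vocab_size)

    def forward(self, src, trg):
        # src: [batch, src_len], trg: [batch, trg_len] token IDs
        src_emb = self.src_emb(src)
        trg_emb = self.trg_emb(trg)

        # No positional encoding and no attention masks (causal or padding) are passed here
        output = self.transformer(src_emb, trg_emb)

        # [batch, trg_len, trg_vocab_size]
        return self.fc_out(output)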

The self.fc_out = nn.Linear(emb_size, trg_vocab_size) layer is the model's final output layer, applied to the decoder's output.

It converts the decoder's emb_size-dimensional representation at each target position into a vector of size trg_vocab_size, holding one logit (unnormalized score) for each token in the target vocabulary.

Essentially, it's the layer that translates the model's learned representation of "what word comes next" into concrete predictions over all possible words.
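TransformerModel passes only embeddings to nn.Transformer: there is no positional encoding, and no masks are supplied. The training logs in this section come from that version. A common way to address both, sketched below under stated assumptions, is to add sinusoidal positional encodings (as in the original Transformer paper), a causal mask for the decoder, and key padding masks for <pad> positions. MaskedTransformerModel, PositionalEncoding, and the pad_idx argument are hypothetical names, and training this variant would give different numbers from the logs shown here.

```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first inputs (batch, seq_len, d_model)."""

    def __init__(self, d_model, max_len=512, dropout=0.1):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))    # (1, max_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])


class MaskedTransformerModel(nn.Module):
    """Hypothetical variant of TransformerModel with positions and masks."""

    def __init__(self, src_vocab_size, trg_vocab_size, pad_idx,
                 emb_size=256, nhead=8, nhid=1024, nlayers=6, dropout=0.1):
        super().__init__()
        self.pad_idx = pad_idx  # assumed to be the same <pad> index in both vocabularies
        self.scale = math.sqrt(emb_size)
        self.src_emb = nn.Embedding(src_vocab_size, emb_size)
        self.trg_emb = nn.Embedding(trg_vocab_size, emb_size)
        self.pos_enc = PositionalEncoding(emb_size, dropout=dropout)
        self.transformer = nn.Transformer(
            d_model=emb_size, nhead=nhead,
            num_encoder_layers=nlayers, num_decoder_layers=nlayers,
            dim_feedforward=nhid, dropout=dropout, batch_first=True,
        )
        self.fc_out = nn.Linear(emb_size, trg_vocab_size)

    def forward(self, src, trg):
        trg_len = trg.size(1)
        # Causal mask: True above the diagonal blocks attention to later positions.
        tgt_mask = torch.triu(torch.ones(trg_len, trg_len, dtype=torch.bool,
                                         device=trg.device), diagonal=1)
        # Key padding masks: True marks <pad> positions that attention ignores.
        src_pad = src == self.pad_idx
        trg_pad = trg == self.pad_idx
        src_x = self.pos_enc(self.src_emb(src) * self.scale)
        trg_x = self.pos_enc(self.trg_emb(trg) * self.scale)
        out = self.transformer(src_x, trg_x, tgt_mask=tgt_mask,
                               src_key_padding_mask=src_pad,
                               tgt_key_padding_mask=trg_pad,
                               memory_key_padding_mask=src_pad)
        return self.fc_out(out)
```

With the objects in this tutorial it could be constructed as MaskedTransformerModel(len(SRC.vocab), len(TRG.vocab), pad_idx=TRG.vocab.stoi['<pad>']), which relies on both Fields using the same special tokens, as they do here. The causal mask is what makes the outputs at position i independent of target tokens after i.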

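import torch.optim as optim

src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)

model = TransformerModel(src_vocab_size, trg_vocab_size).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Target positions holding <pad> are ignored by the loss
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi[TRG.pad_token])

criterion.to(device)

EPOCHS = 40
CLIP = 1.0  # Maximum gradient norm for clipping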
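The train and evaluate functions rely on two conventions worth seeing in isolation: the one-token shift between decoder input and labels (teacher forcing), and ignore_index, which drops padding positions from the averaged loss. The token IDs and random logits below are made up, and PAD_IDX = 1 assumes the <pad> index produced by the Field setup earlier in this section.

```python
import torch
import torch.nn as nn

PAD_IDX = 1  # assumed <pad> index

# One target sentence: <sos>=2, three word IDs, <eos>=3, then one <pad>.
trg = torch.tensor([[2, 10, 11, 12, 3, PAD_IDX]])

decoder_input = trg[:, :-1]  # what model(src, trg[:, :-1]) receives
labels = trg[:, 1:]          # at each position, the next token to predict
print(decoder_input.tolist())  # [[2, 10, 11, 12, 3]]
print(labels.tolist())         # [[10, 11, 12, 3, 1]]

# Random logits stand in for model output: (batch, seq_len, vocab_size).
torch.manual_seed(0)
vocab_size = 20
logits = torch.randn(1, labels.size(1), vocab_size)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = criterion(logits.view(-1, vocab_size), labels.reshape(-1))

# The same value, averaged by hand over the non-pad positions only.
keep = labels.reshape(-1) != PAD_IDX
manual = nn.functional.cross_entropy(logits.view(-1, vocab_size)[keep],
                                     labels.reshape(-1)[keep])
print(torch.allclose(loss, manual))  # True
```

The last decoder input here is <eos>, whose label is <pad>, so that position contributes nothing to the loss.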

5.3.2.2 Training nn.Transformer

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        assert src.size(0) == trg.size(0), f"src batch size {src.size(0)} does not match trg batch size {trg.size(0)}"
        
        optimizer.zero_grad()
        
        # Exclude the last token from target sequence (shifted target sequence)
        output = model(src, trg[:, :-1])
        
        output_dim = output.shape[-1]
        
        # Flatten output and trg to calculate loss
        output = output.view(-1, output_dim)  # (batch_size * seq_len, trg_vocab_size)
        trg = trg[:, 1:].contiguous().view(-1)  # Exclude first token from target sequence
        
        loss = criterion(output, trg)
        loss.backward()
        
        # Gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            
            output = model(src, trg[:, :-1])
            
            output_dim = output.shape[-1]
            
            output = output.view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)
            
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    
    return epoch_loss / len(iterator)


for epoch in range(EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.3f} | Validation Loss: {valid_loss:.3f}")
Epoch 1 | Train Loss: 4.128 | Validation Loss: 3.180
Epoch 2 | Train Loss: 3.065 | Validation Loss: 2.667
Epoch 3 | Train Loss: 2.637 | Validation Loss: 2.352
Epoch 4 | Train Loss: 2.333 | Validation Loss: 2.126
Epoch 5 | Train Loss: 2.097 | Validation Loss: 1.961
Epoch 6 | Train Loss: 1.902 | Validation Loss: 1.835
Epoch 7 | Train Loss: 1.739 | Validation Loss: 1.736
Epoch 8 | Train Loss: 1.597 | Validation Loss: 1.651
Epoch 9 | Train Loss: 1.473 | Validation Loss: 1.584
Epoch 10 | Train Loss: 1.362 | Validation Loss: 1.525
Epoch 11 | Train Loss: 1.263 | Validation Loss: 1.481
Epoch 12 | Train Loss: 1.173 | Validation Loss: 1.445
Epoch 13 | Train Loss: 1.090 | Validation Loss: 1.425
Epoch 14 | Train Loss: 1.017 | Validation Loss: 1.392
Epoch 15 | Train Loss: 0.950 | Validation Loss: 1.388
Epoch 16 | Train Loss: 0.889 | Validation Loss: 1.376
Epoch 17 | Train Loss: 0.833 | Validation Loss: 1.358
Epoch 18 | Train Loss: 0.782 | Validation Loss: 1.363
Epoch 19 | Train Loss: 0.737 | Validation Loss: 1.359
Epoch 20 | Train Loss: 0.694 | Validation Loss: 1.360
Epoch 21 | Train Loss: 0.657 | Validation Loss: 1.369
Epoch 22 | Train Loss: 0.625 | Validation Loss: 1.371
Epoch 23 | Train Loss: 0.593 | Validation Loss: 1.393
Epoch 24 | Train Loss: 0.568 | Validation Loss: 1.411
Epoch 25 | Train Loss: 0.542 | Validation Loss: 1.398
Epoch 26 | Train Loss: 0.520 | Validation Loss: 1.406
Epoch 27 | Train Loss: 0.498 | Validation Loss: 1.405
Epoch 28 | Train Loss: 0.481 | Validation Loss: 1.429
Epoch 29 | Train Loss: 0.463 | Validation Loss: 1.439
Epoch 30 | Train Loss: 0.449 | Validation Loss: 1.447
Epoch 31 | Train Loss: 0.431 | Validation Loss: 1.455
Epoch 32 | Train Loss: 0.420 | Validation Loss: 1.474
Epoch 33 | Train Loss: 0.406 | Validation Loss: 1.484
Epoch 34 | Train Loss: 0.396 | Validation Loss: 1.499
Epoch 35 | Train Loss: 0.385 | Validation Loss: 1.508
Epoch 36 | Train Loss: 0.374 | Validation Loss: 1.524
Epoch 37 | Train Loss: 0.366 | Validation Loss: 1.540
Epoch 38 | Train Loss: 0.356 | Validation Loss: 1.560
Epoch 39 | Train Loss: 0.350 | Validation Loss: 1.573
Epoch 40 | Train Loss: 0.343 | Validation Loss: 1.563
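In this log, validation loss bottoms out at 1.358 in epoch 17 and then creeps up while training loss keeps falling, a typical sign of overfitting. One common remedy, which the original script does not use, is to keep the checkpoint with the lowest validation loss and stop after a fixed number of epochs without improvement. EarlyStopping is a hypothetical helper; replaying the logged validation losses shows where it would have stopped with patience=5.

```python
class EarlyStopping:
    """Track the best validation loss and signal when to stop.

    patience is the number of consecutive epochs without improvement
    tolerated before stopping.
    """

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float('inf')
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, valid_loss):
        # Returns True when training should stop.
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            # In the real loop, save the checkpoint here, e.g.
            # torch.save(model.state_dict(), 'best-model.pt')
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Validation losses of epochs 1-22 from the log above.
valid_losses = [3.180, 2.667, 2.352, 2.126, 1.961, 1.835, 1.736, 1.651, 1.584,
                1.525, 1.481, 1.445, 1.425, 1.392, 1.388, 1.376, 1.358, 1.363,
                1.359, 1.360, 1.369, 1.371]
stopper = EarlyStopping(patience=5)
for epoch, loss in enumerate(valid_losses, start=1):
    if stopper.step(epoch, loss):
        print(f"Stop at epoch {epoch}; best epoch {stopper.best_epoch} ({stopper.best_loss:.3f})")
        break
# Stop at epoch 22; best epoch 17 (1.358)
```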

5.3.2.3 Translation Inference

def translate_one_batch(model, iterator, SRC, TRG):
    model.eval()
    translations = []
    
    with torch.no_grad():
        batch = next(iter(iterator))  
        src = batch.src
        trg = batch.trg
        
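        # Teacher forcing: the reference target (minus its last token) is fed to the decoder
        output = model(src, trg[:, :-1])
        output = output.argmax(dim=-1)  # most likely token at each position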
        
        for i in range(src.size(0)):
            src_tokens = [SRC.vocab.itos[idx] for idx in src[i]]
            hyp_tokens = []
            for idx in output[i]:
                if idx == TRG.vocab.stoi[TRG.eos_token]:
                    break
                if idx != TRG.vocab.stoi[TRG.pad_token]:
                    hyp_tokens.append(TRG.vocab.itos[idx])

            translations.append({
                'src': ' '.join(src_tokens),
                'hyp': ' '.join(hyp_tokens),
                'trg': ' '.join([TRG.vocab.itos[idx] for idx in trg[i][1:].cpu().numpy() if idx != TRG.vocab.stoi[TRG.pad_token]])
            })
    
    return translations

def print_translations(translations):
    for translation in translations:
        print(f"Source: {translation['src']}")
        print(f"Prediction: {translation['hyp']}")
        print(f"Reference: {translation['trg']}")
        print("-" * 50)

translations = translate_one_batch(model, test_iterator, SRC, TRG)
print_translations(translations)
Source: <sos> ein mann arbeitet an einem <unk> . <eos>
Prediction: a man is working a man stand .
Reference: a man is working a hotdog stand . <eos>
--------------------------------------------------
Source: <sos> ein junges mädchen schwimmt in einem pool <eos>
Prediction: a young girl swimming in a young
Reference: a young girl swimming in a pool <eos>
--------------------------------------------------
Source: <sos> zwei männer in schwarz in einer stadt <eos>
Prediction: two men wearing black in a city
Reference: two men wearing black in a city <eos>
--------------------------------------------------
Source: <sos> eine große menschenmenge füllt eine straße . <eos>
Prediction: a large group of people fill a large .
Reference: a large group of people fill a street . <eos>
--------------------------------------------------
Source: <sos> ein mann schneidet äste von bäumen . <eos>
Prediction: a man cutting of . trees .
Reference: a man cutting branches of trees . <eos>
--------------------------------------------------
Source: <sos> drei leute sitzen in einer höhle . <eos>
Prediction: three people sit in a cave .
Reference: three people sit in a cave . <eos>
--------------------------------------------------
Source: <sos> ein typ arbeitet an einem gebäude . <eos>
Prediction: a building works on a building .
Reference: a guy works on a building . <eos>
--------------------------------------------------
Source: <sos> leute reparieren das dach eines hauses . <eos>
Prediction: people are fixing the roof of a house .
Reference: people are fixing the roof of a house . <eos>
--------------------------------------------------
Source: <sos> zwei jungen vor einem <unk> . <eos> <pad>
Prediction: two boys in front of a machine machine .
Reference: two boys in front of a soda machine . <eos>
--------------------------------------------------
Source: <sos> ein typ küsst einen anderen typ <eos> <pad>
Prediction: a guy also a guy
Reference: a guy give a kiss to a guy also <eos>
--------------------------------------------------
Source: <sos> zwei jungen spielen gegeneinander fußball . <eos> <pad>
Prediction: two boys play soccer against each other .
Reference: two boys play soccer against each other . <eos>
--------------------------------------------------
Source: <sos> kinder kämpfen um den ballbesitz . <eos> <pad>
Prediction: kids compete to the of of the soccer ball .
Reference: kids compete to <unk> possession of the soccer ball . <eos>
--------------------------------------------------
Source: <sos> ein wandgemälde auf einem gebäude . <eos> <pad>
Prediction: a mural on the side of a mural .
Reference: a mural on the side of a building . <eos>
--------------------------------------------------
Source: <sos> ärzte bei einer art operation . <eos> <pad>
Prediction: doctors performing some type of surgery .
Reference: doctors performing some type of surgery . <eos>
--------------------------------------------------
Source: <sos> feuerwehrmänner kommen aus einer u-bahnstation . <eos> <pad>
Prediction: firemen <unk> station a subway station .
Reference: firemen <unk> from a subway station . <eos>
--------------------------------------------------
Source: <sos> arbeiter diskutieren neben den schienen . <eos> <pad>
Prediction: construction workers having a construction by the tracks .
Reference: construction workers having a discussion by the tracks . <eos>
--------------------------------------------------
Source: <sos> ein cowboy <unk> seinen arm . <eos> <pad>
Prediction: a cowboy with his . arm up his cowboy .
Reference: a cowboy wrapping up his arm with a bandage . <eos>
--------------------------------------------------
Source: <sos> ein mann verwendet <unk> geräte . <eos> <pad>
Prediction: a man is using electronic equipment .
Reference: a man is using electronic equipment . <eos>
--------------------------------------------------
Source: <sos> zwei fußballmannschaften auf dem feld . <eos> <pad>
Prediction: two soccer teams are on the soccer .
Reference: two soccer teams are on the field . <eos>
--------------------------------------------------
Source: <sos> ein mann an seinem hochzeitstag . <eos> <pad>
Prediction: a man on his wedding day .
Reference: a man on his wedding day . <eos>
--------------------------------------------------
Source: <sos> ein hellbrauner hund läuft bergauf . <eos> <pad>
Prediction: a light brown dog is running up .
Reference: a light brown dog is running up . <eos>
--------------------------------------------------
Source: <sos> hunde laufen auf einer hunderennbahn . <eos> <pad>
Prediction: dogs run at a racetrack run .
Reference: dogs run at a dog racetrack . <eos>
--------------------------------------------------
Source: <sos> ein am strand <unk> auto . <eos> <pad>
Prediction: a car parked at the beach .
Reference: a car parked at the beach . <eos>
--------------------------------------------------
Source: <sos> leute sitzen in einem zug . <eos> <pad>
Prediction: people sit inside a train .
Reference: people sit inside a train . <eos>
--------------------------------------------------
Source: <sos> ein kind planscht im wasser . <eos> <pad>
Prediction: a child is splashing in the water
Reference: a child is splashing in the water <eos>
--------------------------------------------------
Source: <sos> leute bewundern ein kunstwerk . <eos> <pad> <pad>
Prediction: a are admiring a row . people of
Reference: people are admiring a work of art . <eos>
--------------------------------------------------
Source: <sos> junge frau klettert auf felswand <eos> <pad> <pad>
Prediction: young woman rock face face
Reference: young woman climbing rock face <eos>
--------------------------------------------------
Source: <sos> ein rockkonzert findet statt . <eos> <pad> <pad>
Prediction: a concert concert is taking place .
Reference: a rock concert is taking place . <eos>
--------------------------------------------------
Source: <sos> eine frau spielt volleyball . <eos> <pad> <pad>
Prediction: a woman is playing volleyball .
Reference: a woman is playing volleyball . <eos>
--------------------------------------------------
Source: <sos> ein <unk> inspiziert etwas . <eos> <pad> <pad>
Prediction: an army officer is inspecting something .
Reference: an army officer is inspecting something . <eos>
--------------------------------------------------
Source: <sos> zwei männer mit mützen . <eos> <pad> <pad>
Prediction: two men wearing hats .
Reference: two men wearing hats . <eos>
--------------------------------------------------
Source: <sos> drei männer gehen bergauf . <eos> <pad> <pad>
Prediction: three men are walking up walking .
Reference: three men are walking up hill . <eos>
--------------------------------------------------
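translate_one_batch feeds the reference translation into the decoder, so every prediction above is conditioned on the correct previous words. A real translator has to generate one token at a time from its own outputs. Below is a minimal greedy-decoding sketch, assuming a model with the same model(src, trg) -> logits interface as TransformerModel; greedy_decode and the toy StepModel are hypothetical. Because TransformerModel was trained without a causal mask or positional encoding, its free-running output may differ noticeably from the teacher-forced predictions above.

```python
import torch


@torch.no_grad()
def greedy_decode(model, src, sos_idx, eos_idx, max_len=50):
    """Greedy autoregressive decoding for a batch-first seq2seq model.

    Assumes model(src, ys) returns logits of shape (batch, len(ys), vocab).
    The whole model is re-run at every step, which is simple but slow.
    """
    model.eval()
    batch_size = src.size(0)
    # Every hypothesis starts with <sos>.
    ys = torch.full((batch_size, 1), sos_idx, dtype=torch.long, device=src.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
    for _ in range(max_len):
        logits = model(src, ys)                  # (batch, cur_len, vocab)
        next_tok = logits[:, -1].argmax(dim=-1)  # best token for the newest position
        next_tok = next_tok.masked_fill(finished, eos_idx)  # finished rows repeat <eos>
        ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
        finished |= next_tok == eos_idx
        if finished.all():
            break
    return ys[:, 1:]  # drop the leading <sos>


# Toy stand-in model (hypothetical): the next token is the last token plus
# the first source ID, capped at 5, which plays the role of <eos>.
class StepModel(torch.nn.Module):
    def forward(self, src, trg):
        nxt = torch.clamp(trg + src[:, :1], max=5)
        return torch.nn.functional.one_hot(nxt, num_classes=6).float()


out = greedy_decode(StepModel(), torch.tensor([[1], [2]]), sos_idx=0, eos_idx=5)
print(out.tolist())  # [[1, 2, 3, 4, 5], [2, 4, 5, 5, 5]]
```

With the tutorial's objects, a call might look like greedy_decode(model, batch.src, TRG.vocab.stoi['<sos>'], TRG.vocab.stoi['<eos>']), followed by mapping the IDs back to words through TRG.vocab.itos.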

5.3.3 Transformer from Scratch

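File translate_scratch.py replaces nn.Transformer with hand-written multi-head attention, encoder, decoder, and feed-forward layers. Its training log over 40 epochs: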

Epoch 1 | Train Loss: 4.187 | Validation Loss: 3.352
Epoch 2 | Train Loss: 3.199 | Validation Loss: 2.923
Epoch 3 | Train Loss: 2.832 | Validation Loss: 2.656
Epoch 4 | Train Loss: 2.580 | Validation Loss: 2.496
Epoch 5 | Train Loss: 2.384 | Validation Loss: 2.370
Epoch 6 | Train Loss: 2.220 | Validation Loss: 2.275
Epoch 7 | Train Loss: 2.078 | Validation Loss: 2.198
Epoch 8 | Train Loss: 1.950 | Validation Loss: 2.127
Epoch 9 | Train Loss: 1.837 | Validation Loss: 2.084
Epoch 10 | Train Loss: 1.733 | Validation Loss: 2.058
Epoch 11 | Train Loss: 1.636 | Validation Loss: 2.031
Epoch 12 | Train Loss: 1.548 | Validation Loss: 2.011
Epoch 13 | Train Loss: 1.461 | Validation Loss: 2.007
Epoch 14 | Train Loss: 1.382 | Validation Loss: 2.000
Epoch 15 | Train Loss: 1.307 | Validation Loss: 1.990
Epoch 16 | Train Loss: 1.234 | Validation Loss: 2.007
Epoch 17 | Train Loss: 1.163 | Validation Loss: 2.011
Epoch 18 | Train Loss: 1.099 | Validation Loss: 2.031
Epoch 19 | Train Loss: 1.037 | Validation Loss: 2.054
Epoch 20 | Train Loss: 0.978 | Validation Loss: 2.060
Epoch 21 | Train Loss: 0.922 | Validation Loss: 2.080
Epoch 22 | Train Loss: 0.869 | Validation Loss: 2.095
Epoch 23 | Train Loss: 0.822 | Validation Loss: 2.120
Epoch 24 | Train Loss: 0.775 | Validation Loss: 2.149
Epoch 25 | Train Loss: 0.733 | Validation Loss: 2.182
Epoch 26 | Train Loss: 0.691 | Validation Loss: 2.199
Epoch 27 | Train Loss: 0.654 | Validation Loss: 2.236
Epoch 28 | Train Loss: 0.617 | Validation Loss: 2.272
Epoch 29 | Train Loss: 0.584 | Validation Loss: 2.298
Epoch 30 | Train Loss: 0.553 | Validation Loss: 2.339
Epoch 31 | Train Loss: 0.524 | Validation Loss: 2.372
Epoch 32 | Train Loss: 0.497 | Validation Loss: 2.410
Epoch 33 | Train Loss: 0.472 | Validation Loss: 2.444
Epoch 34 | Train Loss: 0.450 | Validation Loss: 2.478
Epoch 35 | Train Loss: 0.428 | Validation Loss: 2.494
Epoch 36 | Train Loss: 0.407 | Validation Loss: 2.545
Epoch 37 | Train Loss: 0.387 | Validation Loss: 2.569
Epoch 38 | Train Loss: 0.370 | Validation Loss: 2.603
Epoch 39 | Train Loss: 0.353 | Validation Loss: 2.652
Epoch 40 | Train Loss: 0.338 | Validation Loss: 2.678
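translate_scratch.py is not reproduced in this section. Its log shows validation loss reaching its minimum of 1.990 at epoch 15, compared with 1.358 for the nn.Transformer version. As a rough illustration of the kind of hand-written component the scratch version is built from, the sketch below implements multi-head scaled dot-product attention. It is not the code from translate_scratch.py: the MultiHeadAttention class, its constructor arguments, and its mask convention (a boolean mask where True blocks attention) are assumptions.

```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Minimal multi-head scaled dot-product attention for batch-first inputs."""

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.nhead = nhead
        self.d_head = d_model // nhead
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split(self, x):
        # (batch, len, d_model) -> (batch, nhead, len, d_head)
        b, n, _ = x.shape
        return x.view(b, n, self.nhead, self.d_head).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        # mask: bool, broadcastable to (batch, nhead, q_len, k_len); True = blocked
        q = self._split(self.q_proj(query))
        k = self._split(self.k_proj(key))
        v = self._split(self.v_proj(value))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        weights = torch.softmax(scores, dim=-1)      # each row sums to 1
        out = self.dropout(weights) @ v              # (batch, nhead, q_len, d_head)
        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, -1)  # merge heads back to d_model
        return self.out_proj(out), weights


torch.manual_seed(0)
mha = MultiHeadAttention(d_model=16, nhead=4)
mha.eval()                  # disable dropout for a deterministic check
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True = future position
out, weights = mha(x, x, x, mask=causal)
print(out.shape, weights.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 4, 5, 5])
```

Passing the same tensor as query, key, and value gives self-attention; in a decoder's cross-attention, key and value would instead be the encoder output.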

5.3.4 Optimization

5.3.5 Conclusion