5.3 nn.Transformer
Use the torch.nn.Transformer interface for the German-to-English translation task.
Created Date: 2025-07-02
In this tutorial, we will use the torch.nn.Transformer interface to quickly translate German to English and build the framework of the translation project. We will then replace the torch.nn.Transformer interface with handwritten multi-head attention, encoder, decoder, and feed-forward layers.
5.3.1 Dataset of Multi30k
For ease of use, we'll use the Multi30k dataset, which is the "task 1" dataset of the WMT16 multimodal machine translation task.
Each example consists of an "en" and a "de" feature: "en" is an English sentence, and "de" is its German translation. The Multi30k dataset has three splits: train, validation, and test. We load it with the datasets library from Hugging Face:
import datasets
import nltk
# Tokenizer data needed by NLTK's word_tokenize
nltk.download('punkt')
nltk.download('punkt_tab')
dataset = datasets.load_dataset('bentrevett/multi30k')
print(dataset)
# {'en': 'Two young, White males are outside near many bushes.',
# 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
print(dataset['train'][0])
DatasetDict({
    train: Dataset({
        features: ['en', 'de'],
        num_rows: 29000
    })
    validation: Dataset({
        features: ['en', 'de'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['en', 'de'],
        num_rows: 1000
    })
})
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
File language_trans_torch.py uses the TorchText library to handle the complexities of text processing, such as tokenization, vocabulary building, and batching. Let's break down the code step-by-step:
1. Data Loading and Initial Preparation
After we get the dataset object from Hugging Face's datasets library, we extract parallel sentence pairs (German source, English target) from its 'train', 'validation', and 'test' splits.
2. Defining Field Objects
Field objects in TorchText define how raw text data should be processed. Let's look at the arguments:
tokenize=word_tokenize: specifies the tokenization function, here word_tokenize from NLTK, which splits a sentence into individual words and punctuation marks. Tokenization is the process of breaking down a sequence of characters into pieces, called tokens.
init_token='<sos>': adds a special "start of sequence" token at the beginning of each sentence.
eos_token='<eos>': adds a special "end of sequence" token at the end of each sentence.
pad_token='<pad>': specifies the token to use for padding. Sentences in a batch often have different lengths; padding gives them all the same length by appending <pad> tokens to the shorter ones.
lower=True: converts all text to lowercase. This helps reduce vocabulary size and makes the model less sensitive to case variations.
batch_first=True: determines the dimension order of batched data. If True, the batch dimension comes first (e.g., [batch_size, sequence_length]).
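To make the effect of these arguments concrete, the following is a minimal pure-Python sketch of the same processing steps. It is not TorchText's implementation: the helper names preprocess and pad_batch are hypothetical, and str.split() stands in for NLTK's word_tokenize so that the sketch runs without any downloads.

```python
# Illustrative stand-in for the Field processing steps (not TorchText code).
# Assumption: str.split() replaces NLTK's word_tokenize for simplicity.

def preprocess(sentence, init_token="<sos>", eos_token="<eos>", lower=True):
    """Lowercase, tokenize, and wrap a sentence in start/end markers."""
    if lower:
        sentence = sentence.lower()
    tokens = sentence.split()
    return [init_token] + tokens + [eos_token]


def pad_batch(batch, pad_token="<pad>"):
    """Right-pad every token list to the length of the longest one."""
    max_len = max(len(tokens) for tokens in batch)
    return [tokens + [pad_token] * (max_len - len(tokens)) for tokens in batch]


sentences = ["Zwei junge Männer", "Ein Mann arbeitet an einem Stand"]
batch = pad_batch([preprocess(s) for s in sentences])
for row in batch:
    print(row)
```

After padding, every row has the same length, which is what allows a batch to be stored as a single [batch_size, sequence_length] tensor.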
3. Creating Example Objects
Example objects represent a single instance of data, associating raw text with their respective Field definitions. Here, each German-English sentence pair is converted into an Example.
4. Building Dataset Objects
Dataset objects in TorchText are containers for Example objects. They organize the data into a format suitable for an iterator such as BucketIterator (the next step, which iterates through batches).
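5. Building Vocabularies
This is a crucial step. build_vocab creates a vocabulary for each Field. A vocabulary is a mapping from unique words (tokens) to numerical indices. With min_freq=2, tokens that occur fewer than twice in the training data are left out of the vocabulary and mapped to <unk>.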
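The idea behind build_vocab can be sketched in a few lines of plain Python. This is an illustration only, not TorchText's code: the function name, the tie-breaking rule, and the toy corpus are assumptions.

```python
from collections import Counter


def build_vocab(tokenized_sentences, min_freq=2,
                specials=("<unk>", "<pad>", "<sos>", "<eos>")):
    """Map tokens to indices; tokens rarer than min_freq are left out."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    itos = list(specials)  # index -> string, special tokens first
    # Most frequent tokens first; ties broken alphabetically (an assumption)
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and tok not in specials:
            itos.append(tok)
    stoi = {tok: idx for idx, tok in enumerate(itos)}  # string -> index
    return stoi, itos


corpus = [["ein", "mann", "läuft"], ["ein", "hund", "läuft"], ["eine", "frau"]]
stoi, itos = build_vocab(corpus)

# Out-of-vocabulary tokens fall back to the <unk> index
unk = stoi["<unk>"]
ids = [stoi.get(tok, unk) for tok in ["ein", "hund", "läuft"]]
print(itos)
print(ids)
```

Here "hund" occurs only once, so it is not in the vocabulary and is encoded with the index of <unk>.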
from nltk.tokenize import word_tokenize
# Legacy TorchText API; the import path below is an assumption:
# torchtext.legacy.data in torchtext 0.9-0.11, torchtext.data in earlier releases
from torchtext.legacy.data import Field, Example, Dataset, BucketIterator
train_data = [(example['de'], example['en']) for example in dataset['train']]
valid_data = [(example['de'], example['en']) for example in dataset['validation']]
test_data = [(example['de'], example['en']) for example in dataset['test']]
SRC = Field(tokenize=word_tokenize, init_token='<sos>', eos_token='<eos>', pad_token='<pad>', lower=True, batch_first=True)
TRG = Field(tokenize=word_tokenize, init_token='<sos>', eos_token='<eos>', pad_token='<pad>', lower=True, batch_first=True)
train_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in train_data]
valid_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in valid_data]
test_examples = [Example.fromlist([src, trg], fields=[('src', SRC), ('trg', TRG)]) for src, trg in test_data]
train_dataset = Dataset(examples=train_examples, fields=[('src', SRC), ('trg', TRG)])
valid_dataset = Dataset(examples=valid_examples, fields=[('src', SRC), ('trg', TRG)])
test_dataset = Dataset(examples=test_examples, fields=[('src', SRC), ('trg', TRG)])
# Build the vocabularies from the training split only
SRC.build_vocab(train_dataset, min_freq=2)
TRG.build_vocab(train_dataset, min_freq=2)
print('Source vocabulary size: ' + str(len(SRC.vocab)))
print('Target vocabulary size: ' + str(len(TRG.vocab)))
print([word for word, _ in list(SRC.vocab.stoi.items())[:10]])
print([word for word, _ in list(TRG.vocab.stoi.items())[:10]])
We then print some information about the created vocabularies: their sizes and the first 10 tokens of each. stoi stands for "string to index", the mapping from each token to its numerical ID. This gives a peek into which words are included in the vocabulary.
{'en': 'Two young, White males are outside near many bushes.', 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
Source vocabulary size: 7861
Target vocabulary size: 5920
['<unk>', '<pad>', '<sos>', '<eos>', '.', 'ein', 'einem', 'in', 'eine', ',']
['<unk>', '<pad>', '<sos>', '<eos>', 'a', '.', 'in', 'the', 'on', 'man']
BucketIterator groups sentences of similar length into batches for training, validation, and testing, which reduces the amount of padding needed. The loop below shows how to access the batched source (src) and target (trg) tensors and prints the shapes of the first batch.
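Why grouping by length helps can be shown with a small pure-Python experiment. This is only a sketch of the idea, not BucketIterator's actual batching algorithm; the function name and the randomly generated sentence lengths are assumptions.

```python
import random


def count_padding(lengths, batch_size):
    """Total number of pad tokens when consecutive items form batches."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        chunk = lengths[i:i + batch_size]
        # Every sentence is padded up to the longest one in its batch
        total += sum(max(chunk) - n for n in chunk)
    return total


random.seed(0)
lengths = [random.randint(5, 40) for _ in range(1000)]  # hypothetical sentence lengths

unsorted_pad = count_padding(lengths, batch_size=32)
bucketed_pad = count_padding(sorted(lengths), batch_size=32)
print(f"pad tokens without bucketing: {unsorted_pad}")
print(f"pad tokens with length-sorted batches: {bucketed_pad}")
```

Batches built from length-sorted sentences need far fewer pad tokens, so less computation is spent on positions that carry no information.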
import torch
BATCH_SIZE = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_dataset, valid_dataset, test_dataset),
    batch_size=BATCH_SIZE,
    device=device,
    sort_within_batch=True,
    sort_key=lambda x: len(x.src),
)
# Inspect the shapes of the first training batch
for batch in train_iterator:
    src = batch.src
    trg = batch.trg
    print(src.shape)
    print(trg.shape)
    break
torch.Size([32, 19])
torch.Size([32, 26])
5.3.2 nn.Transformer API
This section builds a Transformer model using PyTorch's nn.Transformer for sequence-to-sequence tasks. It defines embeddings, encoder-decoder layers, and output projection. Training employs teacher forcing, loss masking, and gradient clipping. Finally, inference generates translations by feeding target sequences and decoding predictions from model outputs.
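The teacher-forcing shift and the loss masking mentioned above can be illustrated on a single padded target sequence. The sketch uses plain Python lists and a made-up sentence; the training code in this section applies the same slicing to batched tensors (trg[:, :-1] and trg[:, 1:]).

```python
PAD = "<pad>"
# One padded target sentence (hypothetical example)
trg = ["<sos>", "a", "man", "works", ".", "<eos>", PAD, PAD]

decoder_input = trg[:-1]  # what the decoder is fed: everything but the last token
labels = trg[1:]          # what it must predict: everything but the first token

# Positions whose label is <pad> do not count toward the loss
# (this is the role of ignore_index in nn.CrossEntropyLoss)
loss_mask = [label != PAD for label in labels]

for inp, lab, keep in zip(decoder_input, labels, loss_mask):
    print(f"{inp:>6} -> {lab:<6} counted={keep}")
```

At every position the decoder receives the correct previous token and is asked to predict the next one, regardless of what it predicted earlier.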
5.3.2.1 Make Transformer Model
This TransformerModel class defines a standard sequence-to-sequence Transformer architecture using PyTorch's nn.Transformer module. It contains:
src_emb and trg_emb: Embedding layers that convert source and target token IDs into dense vectors of size emb_size.
transformer: The core nn.Transformer module, consisting of stacked encoder and decoder layers. Its key parameters are configurable: d_model (embedding dimension), nhead (number of attention heads), num_encoder_layers and num_decoder_layers (both set to nlayers), dim_feedforward (hidden dimension of the feed-forward networks, set to nhid), and dropout.
fc_out: A final linear layer that projects the Transformer's output to the size of the target vocabulary (trg_vocab_size), producing logits for word prediction.
The forward method embeds the source and target token IDs, passes the embeddings through the transformer, and then applies the final linear layer to get the output logits. This kind of model is commonly used for tasks like neural machine translation.
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size,
                 emb_size=256, nhead=8, nhid=1024, nlayers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, emb_size)
        self.trg_emb = nn.Embedding(trg_vocab_size, emb_size)
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            batch_first=True,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(emb_size, trg_vocab_size)

    def forward(self, src, trg):
        src_emb = self.src_emb(src)  # (batch_size, src_len, emb_size)
        trg_emb = self.trg_emb(trg)  # (batch_size, trg_len, emb_size)
        output = self.transformer(src_emb, trg_emb)
        return self.fc_out(output)  # (batch_size, trg_len, trg_vocab_size)
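As written, forward passes only the two embedded sequences to nn.Transformer: no causal target mask, no padding masks, and no positional encoding (nn.Transformer does not add one itself). Without a causal mask, each decoder position can attend to later target tokens during teacher-forced training. The sketch below shows one way the masks could be built and passed. PAD_IDX, make_masks, and the toy sizes are assumptions for illustration and are not part of language_trans_torch.py.

```python
import torch
import torch.nn as nn

PAD_IDX = 1  # assumed index of <pad> in both vocabularies


def make_masks(src, trg_in, pad_idx=PAD_IDX):
    """Build a causal mask and padding masks (True = position is masked out)."""
    tgt_len = trg_in.size(1)
    # Upper triangle above the diagonal: position i may not attend to j > i
    tgt_mask = torch.triu(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=trg_in.device),
        diagonal=1,
    )
    src_pad_mask = src == pad_idx     # (batch_size, src_len)
    tgt_pad_mask = trg_in == pad_idx  # (batch_size, trg_len)
    return tgt_mask, src_pad_mask, tgt_pad_mask


# Toy sizes so the sketch runs quickly
transformer = nn.Transformer(d_model=16, nhead=2, num_encoder_layers=1,
                             num_decoder_layers=1, dim_feedforward=32,
                             batch_first=True)
emb = nn.Embedding(10, 16, padding_idx=PAD_IDX)

src = torch.tensor([[2, 5, 6, 3, 1]])  # last token is padding
trg_in = torch.tensor([[2, 7, 8, 1]])  # last token is padding
tgt_mask, src_pad_mask, tgt_pad_mask = make_masks(src, trg_in)

out = transformer(
    emb(src), emb(trg_in),
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_pad_mask,
    tgt_key_padding_mask=tgt_pad_mask,
    memory_key_padding_mask=src_pad_mask,
)
print(tgt_mask)
print(out.shape)
```

The keyword arguments tgt_mask, src_key_padding_mask, tgt_key_padding_mask, and memory_key_padding_mask belong to nn.Transformer.forward; how they would be wired into TransformerModel is left open here.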
The fc_out layer, self.fc_out = nn.Linear(emb_size, trg_vocab_size), is the final output layer, applied to the output of the Transformer decoder.
It converts the decoder's internal representation of each target position (emb_size dimensions) into a vector whose size matches the total number of unique words/tokens in the target language's vocabulary (trg_vocab_size).
Essentially, it's the layer that translates the model's learned representation of "what word comes next" into concrete scores (logits) over all possible words.
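How one row of logits becomes a word prediction can be sketched in plain Python. The tiny vocabulary and the logit values below are made up for illustration.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical 7-token target vocabulary and one row of fc_out output
itos = ["<unk>", "<pad>", "<sos>", "<eos>", "a", "man", "dog"]
logits = [0.1, -2.0, -1.5, 0.3, 1.2, 2.5, 0.4]

probs = softmax(logits)
best = max(range(len(probs)), key=probs.__getitem__)  # same index as argmax of the logits
print(itos[best], round(probs[best], 3))
```

Because softmax preserves the ordering of the scores, taking argmax directly on the logits selects the same token, which is what the inference code later in this section does.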
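import torch.optim as optim
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
model = TransformerModel(src_vocab_size, trg_vocab_size).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Padding positions in the target do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=TRG.vocab.stoi[TRG.pad_token])
criterion.to(device)
EPOCHS = 40
CLIP = 1.0 # Gradient clipping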
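The effect of ignore_index can be checked with a small pure-Python version of the loss. This is a sketch of the averaging rule (mean over the targets that are not ignored), with made-up logits; the function name and PAD_IDX value are assumptions.

```python
import math


def masked_cross_entropy(logits, targets, ignore_index):
    """Mean negative log-likelihood over targets that are not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padding position: contributes nothing
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count


PAD_IDX = 1  # assumed index of <pad>
logits = [[2.0, 0.0, 0.5], [0.1, 0.2, 3.0], [5.0, 0.0, 0.0]]
targets = [0, 2, PAD_IDX]  # the last position is padding

loss = masked_cross_entropy(logits, targets, ignore_index=PAD_IDX)
loss_without_pad_row = masked_cross_entropy(logits[:2], targets[:2], ignore_index=PAD_IDX)
print(loss, loss_without_pad_row)
```

Both values are identical: the padded position changes neither the sum nor the number of terms in the average.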
5.3.2.2 Training nn.Transformer
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg
        assert src.size(0) == trg.size(0), f"src batch size {src.size(0)} does not match trg batch size {trg.size(0)}"
        optimizer.zero_grad()
        # Exclude the last token from target sequence (shifted target sequence)
        output = model(src, trg[:, :-1])
        output_dim = output.shape[-1]
        # Flatten output and trg to calculate loss
        output = output.view(-1, output_dim)  # (batch_size * seq_len, trg_vocab_size)
        trg = trg[:, 1:].contiguous().view(-1)  # Exclude first token from target sequence
        loss = criterion(output, trg)
        loss.backward()
        # Gradient clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_dim = output.shape[-1]
            output = output.view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

for epoch in range(EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.3f} | Validation Loss: {valid_loss:.3f}")
Epoch 1 | Train Loss: 4.128 | Validation Loss: 3.180
Epoch 2 | Train Loss: 3.065 | Validation Loss: 2.667
Epoch 3 | Train Loss: 2.637 | Validation Loss: 2.352
Epoch 4 | Train Loss: 2.333 | Validation Loss: 2.126
Epoch 5 | Train Loss: 2.097 | Validation Loss: 1.961
Epoch 6 | Train Loss: 1.902 | Validation Loss: 1.835
Epoch 7 | Train Loss: 1.739 | Validation Loss: 1.736
Epoch 8 | Train Loss: 1.597 | Validation Loss: 1.651
Epoch 9 | Train Loss: 1.473 | Validation Loss: 1.584
Epoch 10 | Train Loss: 1.362 | Validation Loss: 1.525
Epoch 11 | Train Loss: 1.263 | Validation Loss: 1.481
Epoch 12 | Train Loss: 1.173 | Validation Loss: 1.445
Epoch 13 | Train Loss: 1.090 | Validation Loss: 1.425
Epoch 14 | Train Loss: 1.017 | Validation Loss: 1.392
Epoch 15 | Train Loss: 0.950 | Validation Loss: 1.388
Epoch 16 | Train Loss: 0.889 | Validation Loss: 1.376
Epoch 17 | Train Loss: 0.833 | Validation Loss: 1.358
Epoch 18 | Train Loss: 0.782 | Validation Loss: 1.363
Epoch 19 | Train Loss: 0.737 | Validation Loss: 1.359
Epoch 20 | Train Loss: 0.694 | Validation Loss: 1.360
Epoch 21 | Train Loss: 0.657 | Validation Loss: 1.369
Epoch 22 | Train Loss: 0.625 | Validation Loss: 1.371
Epoch 23 | Train Loss: 0.593 | Validation Loss: 1.393
Epoch 24 | Train Loss: 0.568 | Validation Loss: 1.411
Epoch 25 | Train Loss: 0.542 | Validation Loss: 1.398
Epoch 26 | Train Loss: 0.520 | Validation Loss: 1.406
Epoch 27 | Train Loss: 0.498 | Validation Loss: 1.405
Epoch 28 | Train Loss: 0.481 | Validation Loss: 1.429
Epoch 29 | Train Loss: 0.463 | Validation Loss: 1.439
Epoch 30 | Train Loss: 0.449 | Validation Loss: 1.447
Epoch 31 | Train Loss: 0.431 | Validation Loss: 1.455
Epoch 32 | Train Loss: 0.420 | Validation Loss: 1.474
Epoch 33 | Train Loss: 0.406 | Validation Loss: 1.484
Epoch 34 | Train Loss: 0.396 | Validation Loss: 1.499
Epoch 35 | Train Loss: 0.385 | Validation Loss: 1.508
Epoch 36 | Train Loss: 0.374 | Validation Loss: 1.524
Epoch 37 | Train Loss: 0.366 | Validation Loss: 1.540
Epoch 38 | Train Loss: 0.356 | Validation Loss: 1.560
Epoch 39 | Train Loss: 0.350 | Validation Loss: 1.573
Epoch 40 | Train Loss: 0.343 | Validation Loss: 1.563
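In the log above, the validation loss reaches its minimum of 1.358 at epoch 17 and then drifts upward while the training loss keeps falling, the usual sign of overfitting. The training loop in this section always runs all 40 epochs; the following sketch shows how patience-based early stopping, a technique not used by that loop, would have behaved on the first 25 logged validation losses. The function name and the patience value are assumptions.

```python
def early_stop_epoch(valid_losses, patience=3):
    """Return (best_epoch, stop_epoch), both 1-based.

    Training stops once `patience` epochs have passed without a new best loss.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(valid_losses)


# Validation losses of epochs 1-25, copied from the log above
valid_losses = [3.180, 2.667, 2.352, 2.126, 1.961, 1.835, 1.736, 1.651, 1.584, 1.525,
                1.481, 1.445, 1.425, 1.392, 1.388, 1.376, 1.358, 1.363, 1.359, 1.360,
                1.369, 1.371, 1.393, 1.411, 1.398]

best_epoch, stop_epoch = early_stop_epoch(valid_losses, patience=3)
print(f"best epoch: {best_epoch}, stop after epoch: {stop_epoch}")
```

In practice the model weights from the best epoch would also be saved so they can be restored after stopping.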
5.3.2.3 Translation Inference
def translate_one_batch(model, iterator, SRC, TRG):
    model.eval()
    translations = []
    with torch.no_grad():
        batch = next(iter(iterator))
        src = batch.src
        trg = batch.trg
        # The reference target (minus its last token) is fed to the decoder
        output = model(src, trg[:, :-1])
        output = output.argmax(dim=-1)  # most likely token ID at each position
        for i in range(src.size(0)):
            src_tokens = [SRC.vocab.itos[idx] for idx in src[i]]
            hyp_tokens = []
            for idx in output[i]:
                # Stop at the first predicted <eos>; skip <pad>
                if idx == TRG.vocab.stoi[TRG.eos_token]:
                    break
                if idx != TRG.vocab.stoi[TRG.pad_token]:
                    hyp_tokens.append(TRG.vocab.itos[idx])
            translations.append({
                'src': ' '.join(src_tokens),
                'hyp': ' '.join(hyp_tokens),
                'trg': ' '.join([TRG.vocab.itos[idx] for idx in trg[i][1:].cpu().numpy() if idx != TRG.vocab.stoi[TRG.pad_token]])
            })
    return translations

def print_translations(translations):
    for translation in translations:
        print(f"Source: {translation['src']}")
        print(f"Prediction: {translation['hyp']}")
        print(f"Reference: {translation['trg']}")
        print("-" * 50)

translations = translate_one_batch(model, test_iterator, SRC, TRG)
print_translations(translations)
Source: <sos> ein mann arbeitet an einem <unk> . <eos>
Prediction: a man is working a man stand .
Reference: a man is working a hotdog stand . <eos>
--------------------------------------------------
Source: <sos> ein junges mädchen schwimmt in einem pool <eos>
Prediction: a young girl swimming in a young
Reference: a young girl swimming in a pool <eos>
--------------------------------------------------
Source: <sos> zwei männer in schwarz in einer stadt <eos>
Prediction: two men wearing black in a city
Reference: two men wearing black in a city <eos>
--------------------------------------------------
Source: <sos> eine große menschenmenge füllt eine straße . <eos>
Prediction: a large group of people fill a large .
Reference: a large group of people fill a street . <eos>
--------------------------------------------------
Source: <sos> ein mann schneidet äste von bäumen . <eos>
Prediction: a man cutting of . trees .
Reference: a man cutting branches of trees . <eos>
--------------------------------------------------
Source: <sos> drei leute sitzen in einer höhle . <eos>
Prediction: three people sit in a cave .
Reference: three people sit in a cave . <eos>
--------------------------------------------------
Source: <sos> ein typ arbeitet an einem gebäude . <eos>
Prediction: a building works on a building .
Reference: a guy works on a building . <eos>
--------------------------------------------------
Source: <sos> leute reparieren das dach eines hauses . <eos>
Prediction: people are fixing the roof of a house .
Reference: people are fixing the roof of a house . <eos>
--------------------------------------------------
Source: <sos> zwei jungen vor einem <unk> . <eos> <pad>
Prediction: two boys in front of a machine machine .
Reference: two boys in front of a soda machine . <eos>
--------------------------------------------------
Source: <sos> ein typ küsst einen anderen typ <eos> <pad>
Prediction: a guy also a guy
Reference: a guy give a kiss to a guy also <eos>
--------------------------------------------------
Source: <sos> zwei jungen spielen gegeneinander fußball . <eos> <pad>
Prediction: two boys play soccer against each other .
Reference: two boys play soccer against each other . <eos>
--------------------------------------------------
Source: <sos> kinder kämpfen um den ballbesitz . <eos> <pad>
Prediction: kids compete to the of of the soccer ball .
Reference: kids compete to <unk> possession of the soccer ball . <eos>
--------------------------------------------------
Source: <sos> ein wandgemälde auf einem gebäude . <eos> <pad>
Prediction: a mural on the side of a mural .
Reference: a mural on the side of a building . <eos>
--------------------------------------------------
Source: <sos> ärzte bei einer art operation . <eos> <pad>
Prediction: doctors performing some type of surgery .
Reference: doctors performing some type of surgery . <eos>
--------------------------------------------------
Source: <sos> feuerwehrmänner kommen aus einer u-bahnstation . <eos> <pad>
Prediction: firemen <unk> station a subway station .
Reference: firemen <unk> from a subway station . <eos>
--------------------------------------------------
Source: <sos> arbeiter diskutieren neben den schienen . <eos> <pad>
Prediction: construction workers having a construction by the tracks .
Reference: construction workers having a discussion by the tracks . <eos>
--------------------------------------------------
Source: <sos> ein cowboy <unk> seinen arm . <eos> <pad>
Prediction: a cowboy with his . arm up his cowboy .
Reference: a cowboy wrapping up his arm with a bandage . <eos>
--------------------------------------------------
Source: <sos> ein mann verwendet <unk> geräte . <eos> <pad>
Prediction: a man is using electronic equipment .
Reference: a man is using electronic equipment . <eos>
--------------------------------------------------
Source: <sos> zwei fußballmannschaften auf dem feld . <eos> <pad>
Prediction: two soccer teams are on the soccer .
Reference: two soccer teams are on the field . <eos>
--------------------------------------------------
Source: <sos> ein mann an seinem hochzeitstag . <eos> <pad>
Prediction: a man on his wedding day .
Reference: a man on his wedding day . <eos>
--------------------------------------------------
Source: <sos> ein hellbrauner hund läuft bergauf . <eos> <pad>
Prediction: a light brown dog is running up .
Reference: a light brown dog is running up . <eos>
--------------------------------------------------
Source: <sos> hunde laufen auf einer hunderennbahn . <eos> <pad>
Prediction: dogs run at a racetrack run .
Reference: dogs run at a dog racetrack . <eos>
--------------------------------------------------
Source: <sos> ein am strand <unk> auto . <eos> <pad>
Prediction: a car parked at the beach .
Reference: a car parked at the beach . <eos>
--------------------------------------------------
Source: <sos> leute sitzen in einem zug . <eos> <pad>
Prediction: people sit inside a train .
Reference: people sit inside a train . <eos>
--------------------------------------------------
Source: <sos> ein kind planscht im wasser . <eos> <pad>
Prediction: a child is splashing in the water
Reference: a child is splashing in the water <eos>
--------------------------------------------------
Source: <sos> leute bewundern ein kunstwerk . <eos> <pad> <pad>
Prediction: a are admiring a row . people of
Reference: people are admiring a work of art . <eos>
--------------------------------------------------
Source: <sos> junge frau klettert auf felswand <eos> <pad> <pad>
Prediction: young woman rock face face
Reference: young woman climbing rock face <eos>
--------------------------------------------------
Source: <sos> ein rockkonzert findet statt . <eos> <pad> <pad>
Prediction: a concert concert is taking place .
Reference: a rock concert is taking place . <eos>
--------------------------------------------------
Source: <sos> eine frau spielt volleyball . <eos> <pad> <pad>
Prediction: a woman is playing volleyball .
Reference: a woman is playing volleyball . <eos>
--------------------------------------------------
Source: <sos> ein <unk> inspiziert etwas . <eos> <pad> <pad>
Prediction: an army officer is inspecting something .
Reference: an army officer is inspecting something . <eos>
--------------------------------------------------
Source: <sos> zwei männer mit mützen . <eos> <pad> <pad>
Prediction: two men wearing hats .
Reference: two men wearing hats . <eos>
--------------------------------------------------
Source: <sos> drei männer gehen bergauf . <eos> <pad> <pad>
Prediction: three men are walking up walking .
Reference: three men are walking up hill . <eos>
--------------------------------------------------
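translate_one_batch feeds the reference target trg[:, :-1] to the model, so every prediction above is conditioned on reference tokens. When no reference translation is available, the target has to be generated one token at a time. The following is a minimal greedy-decoding sketch, assuming a model with the same forward(src, trg) signature as TransformerModel. greedy_decode, ToyModel, and the token indices are hypothetical, and the untrained toy model only demonstrates the mechanics.

```python
import torch
import torch.nn as nn


def greedy_decode(model, src, sos_idx, eos_idx, max_len=50):
    """Generate target token IDs step by step for a (batch_size, src_len) tensor."""
    model.eval()
    batch_size = src.size(0)
    # Every hypothesis starts with <sos>
    ys = torch.full((batch_size, 1), sos_idx, dtype=torch.long, device=src.device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)
    with torch.no_grad():
        for _ in range(max_len - 1):
            logits = model(src, ys)                   # (batch_size, cur_len, vocab)
            next_tok = logits[:, -1].argmax(dim=-1)   # best token at the last position
            ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
            finished |= next_tok == eos_idx
            if finished.all():                        # every sentence produced <eos>
                break
    return ys


class ToyModel(nn.Module):
    """Tiny untrained stand-in with the same forward(src, trg) signature."""

    def __init__(self, vocab_size=10, emb_size=16):
        super().__init__()
        self.src_emb = nn.Embedding(vocab_size, emb_size)
        self.trg_emb = nn.Embedding(vocab_size, emb_size)
        self.transformer = nn.Transformer(d_model=emb_size, nhead=2,
                                          num_encoder_layers=1, num_decoder_layers=1,
                                          dim_feedforward=32, batch_first=True)
        self.fc_out = nn.Linear(emb_size, vocab_size)

    def forward(self, src, trg):
        return self.fc_out(self.transformer(self.src_emb(src), self.trg_emb(trg)))


src = torch.randint(0, 10, (2, 5))
out = greedy_decode(ToyModel(), src, sos_idx=2, eos_idx=3, max_len=8)
print(out.shape)
```

Sequences that finish early keep receiving tokens until the whole batch is done, so anything after the first <eos> in a row should be discarded when converting IDs back to words.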
5.3.3 Transformer from Scratch
File translate_scratch.py replaces the torch.nn.Transformer interface with the handwritten implementation. Its training log:
Epoch 1 | Train Loss: 4.187 | Validation Loss: 3.352
Epoch 2 | Train Loss: 3.199 | Validation Loss: 2.923
Epoch 3 | Train Loss: 2.832 | Validation Loss: 2.656
Epoch 4 | Train Loss: 2.580 | Validation Loss: 2.496
Epoch 5 | Train Loss: 2.384 | Validation Loss: 2.370
Epoch 6 | Train Loss: 2.220 | Validation Loss: 2.275
Epoch 7 | Train Loss: 2.078 | Validation Loss: 2.198
Epoch 8 | Train Loss: 1.950 | Validation Loss: 2.127
Epoch 9 | Train Loss: 1.837 | Validation Loss: 2.084
Epoch 10 | Train Loss: 1.733 | Validation Loss: 2.058
Epoch 11 | Train Loss: 1.636 | Validation Loss: 2.031
Epoch 12 | Train Loss: 1.548 | Validation Loss: 2.011
Epoch 13 | Train Loss: 1.461 | Validation Loss: 2.007
Epoch 14 | Train Loss: 1.382 | Validation Loss: 2.000
Epoch 15 | Train Loss: 1.307 | Validation Loss: 1.990
Epoch 16 | Train Loss: 1.234 | Validation Loss: 2.007
Epoch 17 | Train Loss: 1.163 | Validation Loss: 2.011
Epoch 18 | Train Loss: 1.099 | Validation Loss: 2.031
Epoch 19 | Train Loss: 1.037 | Validation Loss: 2.054
Epoch 20 | Train Loss: 0.978 | Validation Loss: 2.060
Epoch 21 | Train Loss: 0.922 | Validation Loss: 2.080
Epoch 22 | Train Loss: 0.869 | Validation Loss: 2.095
Epoch 23 | Train Loss: 0.822 | Validation Loss: 2.120
Epoch 24 | Train Loss: 0.775 | Validation Loss: 2.149
Epoch 25 | Train Loss: 0.733 | Validation Loss: 2.182
Epoch 26 | Train Loss: 0.691 | Validation Loss: 2.199
Epoch 27 | Train Loss: 0.654 | Validation Loss: 2.236
Epoch 28 | Train Loss: 0.617 | Validation Loss: 2.272
Epoch 29 | Train Loss: 0.584 | Validation Loss: 2.298
Epoch 30 | Train Loss: 0.553 | Validation Loss: 2.339
Epoch 31 | Train Loss: 0.524 | Validation Loss: 2.372
Epoch 32 | Train Loss: 0.497 | Validation Loss: 2.410
Epoch 33 | Train Loss: 0.472 | Validation Loss: 2.444
Epoch 34 | Train Loss: 0.450 | Validation Loss: 2.478
Epoch 35 | Train Loss: 0.428 | Validation Loss: 2.494
Epoch 36 | Train Loss: 0.407 | Validation Loss: 2.545
Epoch 37 | Train Loss: 0.387 | Validation Loss: 2.569
Epoch 38 | Train Loss: 0.370 | Validation Loss: 2.603
Epoch 39 | Train Loss: 0.353 | Validation Loss: 2.652
Epoch 40 | Train Loss: 0.338 | Validation Loss: 2.678
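The core operation behind handwritten multi-head attention is scaled dot-product attention. translate_scratch.py is not reproduced here, so the function below is an illustrative pure-Python sketch of that operation for a single head, not the file's code. The function name and the toy matrices are assumptions.

```python
import math


def scaled_dot_product_attention(Q, K, V, causal=False):
    """softmax(Q K^T / sqrt(d_k)) V for one head, using lists of row vectors."""
    d_k = len(K[0])
    output = []
    for i, q in enumerate(Q):
        # Similarity of query i with every key, scaled by sqrt(d_k)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        if causal:
            # Query i may only look at keys at positions <= i
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value vectors
        output.append([sum(w * v[c] for w, v in zip(weights, V))
                       for c in range(len(V[0]))])
    return output


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

full = scaled_dot_product_attention(Q, K, V)
masked = scaled_dot_product_attention(Q, K, V, causal=True)
print(full)
print(masked)
```

With the causal mask, the first position can only attend to itself, so its output equals the first value vector; a multi-head layer runs several such heads on projected inputs and concatenates their results.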