5.3 nn.Transformer
Detailed explanation of the torch.nn.Transformer interface.
Created Date: 2025-07-02
In this tutorial, we will provide a detailed introduction to the torch.nn.Transformer interface through two examples, helping you fully master its usage.
The first example is a generative language model adapted from the official PyTorch examples repository (word_language_model). The second is an English-to-German translation model, a very common task in text processing.
5.3.1 Data Preprocessing
The Dictionary class is designed to map words to unique numerical IDs, a common practice in Natural Language Processing (NLP) to convert text into a format machine learning models can understand.
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)
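As a quick check of the mapping behavior, here is a small, hypothetical usage example; note that repeated words keep the ID they were first assigned.

corpus = "the cat sat on the mat".split()
dictionary = Dictionary()
ids = [dictionary.add_word(word) for word in corpus]

print(ids)                     # [0, 1, 2, 3, 0, 4] -- "the" reuses ID 0
print(len(dictionary))         # 5 unique words
print(dictionary.idx2word[2])  # 'sat'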
5.3.2 Word Language Model
import math
import torch

class GPTModel(torch.nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1):
        # Only the encoder stack of nn.Transformer is used here; the "decoder"
        # attribute below is a linear projection back onto the vocabulary.
        super(GPTModel, self).__init__(
            d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers
        )
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.input_emb = torch.nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = torch.nn.Linear(ninp, ntoken)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        torch.nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        torch.nn.init.zeros_(self.decoder.bias)
        torch.nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                # Additive causal mask: 0 on and below the diagonal, -inf above it.
                mask = torch.log(torch.tril(torch.ones(len(src), len(src)))).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None
        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return torch.nn.functional.log_softmax(output, dim=-1)
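The model above references a PositionalEncoding module that the snippet does not define. Below is a minimal sketch of the standard sinusoidal positional encoding in the style of the official word_language_model example (max_len is an assumed cap on sequence length), followed by a smoke test with hypothetical hyperparameters.

import math
import torch

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        # Precompute the sinusoidal table once; final shape is (max_len, 1, d_model).
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x has shape (seq_len, batch, d_model); add the matching positions.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Smoke test with hypothetical hyperparameters.
model = GPTModel(ntoken=10000, ninp=200, nhead=2, nhid=200, nlayers=2)
src = torch.randint(0, 10000, (35, 20))  # (seq_len, batch) of token IDs
log_probs = model(src)                   # (35, 20, 10000) log-probabilities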
5.3.3 torch.nn.Transformer
5.3.3.1 __init__ Function
class torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
dim_feedforward=2048, dropout=0.1, activation=<function relu>,
custom_encoder=None, custom_decoder=None, layer_norm_eps=1e-05,
batch_first=False, norm_first=False, bias=True, device=None, dtype=None)
This Transformer layer implements the original Transformer architecture described in the Attention Is All You Need paper. The intent of this layer is as a reference implementation for foundational understanding and thus it contains only limited features relative to newer Transformer architectures.
d_model (int) – the number of expected features in the encoder/decoder inputs (default=512).
nhead (int) – the number of heads in the multiheadattention models (default=8).
num_encoder_layers (int) – the number of sub-encoder-layers in the encoder (default=6).
num_decoder_layers (int) – the number of sub-decoder-layers in the decoder (default=6).
dim_feedforward (int) – the dimension of the feedforward network model (default=2048).
dropout (float) – the dropout value (default=0.1).
activation (Union[str, Callable[[Tensor], Tensor]]) – the activation function of encoder/decoder intermediate layer, can be a string (“relu” or “gelu”) or a unary callable. Default: relu.
custom_encoder (Optional[Any]) – custom encoder (default=None).
custom_decoder (Optional[Any]) – custom decoder (default=None).
layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).
batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
norm_first (bool) – if True, encoder and decoder layers will perform LayerNorms before other attention and feedforward operations, otherwise after. Default: False (after).
bias (bool) – If set to False, Linear and LayerNorm layers will not learn an additive bias. Default: True.
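The parameters above translate directly into a constructor call. The following is a minimal sketch; the hyperparameters and tensor shapes are illustrative, not the defaults.

import torch

model = torch.nn.Transformer(
    d_model=256, nhead=4,
    num_encoder_layers=3, num_decoder_layers=3,
    dim_feedforward=1024, dropout=0.1,
    batch_first=True,              # tensors are (batch, seq, feature)
)

src = torch.rand(32, 10, 256)      # (batch, src_seq_len, d_model)
tgt = torch.rand(32, 20, 256)      # (batch, tgt_seq_len, d_model)
out = model(src, tgt)              # (32, 20, 256)

Note that nn.Transformer operates on feature vectors that are already embedded to d_model; token embedding and positional encoding are the caller's responsibility, as in the GPTModel example above.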
5.3.3.2 forward Function
forward(src, tgt,
src_mask=None, tgt_mask=None, memory_mask=None,
src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None,
src_is_causal=None, tgt_is_causal=None, memory_is_causal=False)
Take in and process masked source/target sequences.
src (Tensor) – the sequence to the encoder (required).
tgt (Tensor) – the sequence to the decoder (required).
src_mask (Optional[Tensor]) – the additive mask for the src sequence (optional).
tgt_mask (Optional[Tensor]) – the additive mask for the tgt sequence (optional).
memory_mask (Optional[Tensor]) – the additive mask for the encoder output (optional).
src_key_padding_mask (Optional[Tensor]) – the Tensor mask for src keys per batch (optional).
tgt_key_padding_mask (Optional[Tensor]) – the Tensor mask for tgt keys per batch (optional).
memory_key_padding_mask (Optional[Tensor]) – the Tensor mask for memory keys per batch (optional).
src_is_causal (Optional[bool]) – If specified, applies a causal mask as src_mask. Default: None; try to detect a causal mask. Warning: src_is_causal provides a hint that src_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
tgt_is_causal (Optional[bool]) – If specified, applies a causal mask as tgt_mask. Default: None; try to detect a causal mask. Warning: tgt_is_causal provides a hint that tgt_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
memory_is_causal (bool) – If specified, applies a causal mask as memory_mask. Default: False. Warning: memory_is_causal provides a hint that memory_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
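Continuing the sketch from 5.3.3.1, here is one illustrative way to pass the most common masks: a causal tgt_mask built with generate_square_subsequent_mask (covered in the next subsection) plus boolean key-padding masks, where True marks positions to be ignored.

tgt_seq_len = tgt.size(1)
tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(tgt_seq_len)

# Boolean padding masks: True marks padded positions that attention should ignore.
src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
tgt_key_padding_mask = torch.zeros(32, 20, dtype=torch.bool)
src_key_padding_mask[:, 8:] = True  # pretend the last two source tokens are padding

out = model(
    src, tgt,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
)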
5.3.3.3 generate_square_subsequent_mask Function
Generate a square causal mask for the sequence. The masked positions are filled with float('-inf') and the unmasked positions with 0.0, so the returned mask can be used directly as an additive attention mask (for example, as the tgt_mask above).
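A short sketch of what the mask looks like for a sequence of length 4, calling it as a static method:

mask = torch.nn.Transformer.generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])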