5.3 nn.Transformer

Detailed explanation of the torch.nn.Transformer interface.

Created Date: 2025-07-02

In this tutorial, we give a detailed introduction to the torch.nn.Transformer interface through two examples, so that you can fully master its usage.

The first example is a generative language model taken from the official PyTorch examples repository, word_language_model. The second is an English-to-German translation model, a very common task in text processing.

5.3.1 Data Preprocessing

The Dictionary class is designed to map words to unique numerical IDs, a common practice in Natural Language Processing (NLP) to convert text into a format machine learning models can understand.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}  # word -> integer ID
        self.idx2word = []  # integer ID -> word

    def add_word(self, word):
        # Assign a new ID only the first time a word is seen.
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)
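
As a quick illustration (the example sentence here is ours), splitting a line on whitespace and passing each token through add_word both grows the vocabulary and returns the token IDs:

dictionary = Dictionary()

sentence = "the cat sat on the mat <eos>"
ids = [dictionary.add_word(word) for word in sentence.split()]

print(ids)                     # [0, 1, 2, 3, 0, 4, 5] -- "the" maps to the same ID both times
print(len(dictionary))         # 6 unique words
print(dictionary.idx2word[1])  # cat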

5.3.2 Word Language Model

The model below subclasses torch.nn.Transformer directly, but only the encoder stack is actually used: combined with a causal (lower-triangular) attention mask it behaves as a GPT-style autoregressive language model, and the final Linear layer (named decoder here) projects hidden states back to vocabulary logits.

import math
import torch


class GPTModel(torch.nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.1):
        super(GPTModel, self).__init__(
            d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers
        )
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)  # sinusoidal encoding, see below

        self.input_emb = torch.nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        # Replaces the parent class's TransformerDecoder with a plain output projection.
        self.decoder = torch.nn.Linear(ninp, ntoken)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        torch.nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        torch.nn.init.zeros_(self.decoder.bias)
        torch.nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                # Causal mask: log(1) = 0.0 on and below the diagonal, log(0) = -inf above it,
                # so every position attends only to itself and earlier positions.
                mask = torch.log(torch.tril(torch.ones(len(src), len(src)))).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        # Scale embeddings by sqrt(d_model), as in the original Transformer paper.
        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return torch.nn.functional.log_softmax(output, dim=-1)
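
GPTModel relies on a PositionalEncoding module that is not shown above. In the official word_language_model example it is the standard sinusoidal positional encoding from the Attention Is All You Need paper; a sketch of it, written for the default (seq, batch, feature) layout, looks like this:

import math
import torch


class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        # Precompute a (max_len, 1, d_model) table of sine/cosine values once.
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x: (seq, batch, d_model); add the encoding for the first seq positions.
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)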

5.3.3 torch.nn.Transformer

5.3.3.1 __init__ Function

class torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048, dropout=0.1, activation=<function relu>,
    custom_encoder=None, custom_decoder=None, layer_norm_eps=1e-05,
    batch_first=False, norm_first=False, bias=True, device=None, dtype=None)

This Transformer layer implements the original Transformer architecture described in the Attention Is All You Need paper. The intent of this layer is as a reference implementation for foundational understanding and thus it contains only limited features relative to newer Transformer architectures.

  • d_model (int) – the number of expected features in the encoder/decoder inputs (default=512).

  • nhead (int) – the number of heads in the multi-head attention layers (default=8).

  • num_encoder_layers (int) – the number of sub-encoder-layers in the encoder (default=6).

  • num_decoder_layers (int) – the number of sub-decoder-layers in the decoder (default=6).

  • dim_feedforward (int) – the dimension of the feedforward network model (default=2048).

  • dropout (float) – the dropout value (default=0.1).

  • activation (Union[str, Callable[[Tensor], Tensor]]) – the activation function of the encoder/decoder intermediate layer; can be a string (“relu” or “gelu”) or a unary callable. Default: relu.

  • custom_encoder (Optional[Any]) – custom encoder (default=None).

  • custom_decoder (Optional[Any]) – custom decoder (default=None).

  • layer_norm_eps (float) – the eps value in layer normalization components (default=1e-5).

  • batch_first (bool) – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

  • norm_first (bool) – if True, encoder and decoder layers will perform LayerNorms before other attention and feedforward operations, otherwise after. Default: False (after).

  • bias (bool) – If set to False, Linear and LayerNorm layers will not learn an additive bias. Default: True.
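
As a quick check of these arguments (the tensor sizes below are arbitrary), the module can be constructed and called directly. Note that src and tgt must already be d_model-dimensional feature sequences: torch.nn.Transformer contains no embedding layer or output projection of its own.

import torch

# batch_first=True makes every input and output (batch, seq, feature).
model = torch.nn.Transformer(
    d_model=512, nhead=8,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048, dropout=0.1,
    batch_first=True,
)

src = torch.rand(32, 10, 512)  # (batch, source length, d_model)
tgt = torch.rand(32, 20, 512)  # (batch, target length, d_model)

out = model(src, tgt)
print(out.shape)  # torch.Size([32, 20, 512]) -- same shape as tgt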

5.3.3.2 forward Function

forward(src, tgt,
    src_mask=None, tgt_mask=None, memory_mask=None,
    src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None,
    src_is_causal=None, tgt_is_causal=None, memory_is_causal=False)

Take in and process masked source/target sequences.

  • src (Tensor) – the sequence to the encoder (required).

  • tgt (Tensor) – the sequence to the decoder (required).

  • src_mask (Optional[Tensor]) – the additive mask for the src sequence (optional).

  • tgt_mask (Optional[Tensor]) – the additive mask for the tgt sequence (optional).

  • memory_mask (Optional[Tensor]) – the additive mask for the encoder output (optional).

  • src_key_padding_mask (Optional[Tensor]) – the Tensor mask for src keys per batch (optional).

  • tgt_key_padding_mask (Optional[Tensor]) – the Tensor mask for tgt keys per batch (optional).

  • memory_key_padding_mask (Optional[Tensor]) – the Tensor mask for memory keys per batch (optional).

  • src_is_causal (Optional[bool]) – If specified, applies a causal mask as src_mask. Default: None; try to detect a causal mask. Warning: src_is_causal provides a hint that src_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

  • tgt_is_causal (Optional[bool]) – If specified, applies a causal mask as tgt_mask. Default: None; try to detect a causal mask. Warning: tgt_is_causal provides a hint that tgt_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.

  • memory_is_causal (bool) – If specified, applies a causal mask as memory_mask. Default: False. Warning: memory_is_causal provides a hint that memory_mask is the causal mask. Providing incorrect hints can result in incorrect execution, including forward and backward compatibility.
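
Continuing the example from 5.3.3.1, a typical training-time call supplies a causal mask for the target together with boolean key padding masks of shape (batch, seq), where True marks padded positions to be ignored:

# Causal mask over the 20 target positions: 0.0 where attention is allowed, -inf elsewhere.
tgt_mask = model.generate_square_subsequent_mask(20)

# All-False padding masks, i.e. no padding in this toy batch.
src_key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
tgt_key_padding_mask = torch.zeros(32, 20, dtype=torch.bool)

out = model(
    src, tgt,
    tgt_mask=tgt_mask,
    src_key_padding_mask=src_key_padding_mask,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=src_key_padding_mask,  # memory usually reuses the src padding mask
)
print(out.shape)  # torch.Size([32, 20, 512])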

5.3.3.3 generate_square_subsequent_mask Function

generate_square_subsequent_mask(sz) generates a square causal mask for a sequence of length sz: masked (future) positions are filled with float('-inf') and unmasked positions with 0.0, so that when the mask is added to the attention scores each position can attend only to itself and earlier positions. This is exactly the mask GPTModel in 5.3.2 builds by hand with torch.log(torch.tril(...)).
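
In recent PyTorch releases this is a static method, so it can be called without constructing a model; for a length-4 sequence the mask looks like this:

mask = torch.nn.Transformer.generate_square_subsequent_mask(4)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])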

5.3.4 English-German Translation
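
An English-to-German translation model wraps torch.nn.Transformer between token embeddings and an output projection: one Dictionary (vocabulary) and embedding per language on the input side, and a Linear "generator" that maps decoder outputs to German vocabulary logits on the output side. A minimal sketch of such a model (the class and parameter names here are ours; PositionalEncoding is the module sketched in 5.3.2):

import math
import torch


class Seq2SeqTransformer(torch.nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=3, num_decoder_layers=3,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.src_emb = torch.nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = torch.nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.transformer = torch.nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward, dropout=dropout,
        )
        self.generator = torch.nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        # src, tgt: (seq, batch) token IDs; embeddings are scaled by sqrt(d_model).
        src = self.pos_encoder(self.src_emb(src) * math.sqrt(self.d_model))
        tgt = self.pos_encoder(self.tgt_emb(tgt) * math.sqrt(self.d_model))
        # Causal mask so each target position attends only to earlier target positions.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        out = self.transformer(
            src, tgt, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.generator(out)  # (tgt_seq, batch, tgt_vocab_size) logits

During training the decoder input is the German sentence shifted right (starting from a start-of-sequence token) and the loss is computed against the same sentence shifted left; at inference time the model decodes autoregressively, feeding its own predictions back one token at a time.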

5.3.5 Details