14.3 Tensor Parallel

Large Scale Transformer model training with Tensor Parallel (TP)

Created Date: 2025-07-07

This tutorial demonstrates how to train a large Transformer-like model across hundreds to thousands of GPUs using Tensor Parallel and Fully Sharded Data Parallel.

14.3.1 How Tensor Parallel works

Tensor Parallel (TP) was originally proposed in the Megatron-LM paper, and it is an efficient model parallelism technique for training large scale Transformer models. At its core, TP shards the large weight matrices of the Attention and FeedForward layers across GPUs, typically column-wise for the first linear layer of a pair and row-wise for the second, so that each pair needs only a single collective (an all-reduce) to reconstruct the full activation. Sequence Parallel (SP), which we mention in this tutorial, is a variant of Tensor Parallel that shards on the sequence dimension for nn.LayerNorm or RMSNorm layers to further save activation memory during training. As the model grows larger, activation memory becomes the bottleneck, so Tensor Parallel training usually applies Sequence Parallel to the LayerNorm or RMSNorm layers.
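To make the column-wise/row-wise sharding concrete, here is a minimal sketch using PyTorch's torch.distributed.tensor.parallel API. The toy ToyMLP module and its net1/net2 names are placeholders standing in for a FeedForward block, and the snippet assumes a distributed job launched with torchrun on 8 GPUs:

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


# Toy two-layer MLP standing in for a Transformer FeedForward block.
class ToyMLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net1 = nn.Linear(dim, hidden_dim)
        self.net2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()

    def forward(self, x):
        return self.net2(self.act(self.net1(x)))


# One TP group spanning, e.g., the 8 GPUs of a single host
# (assumes torchrun launched 8 processes and CUDA is available).
tp_mesh = init_device_mesh("cuda", (8,))

model = ToyMLP(dim=1024, hidden_dim=4096).to("cuda")
# net1 is sharded column-wise (each rank holds a slice of the hidden dim),
# net2 row-wise, so a single all-reduce restores the full output.
model = parallelize_module(
    model,
    tp_mesh,
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)
```

ColwiseParallel shards net1's weight along its output (hidden) dimension and RowwiseParallel shards net2 along its input dimension, so the partial results of the two matmuls are combined with a single all-reduce in the forward pass.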

14.3.2 When and Why you should apply Tensor Parallel

The PyTorch Fully Sharded Data Parallel (FSDP) already has the capability to scale model training to a certain number of GPUs. However, when it comes to further scaling model training in terms of model size and GPU quantity, many additional challenges arise that may require combining Tensor Parallel with FSDP:

  1. As the world size (number of GPUs) becomes excessively large (exceeding 128/256 GPUs), the FSDP collectives (such as allgather) become dominated by ring latency. By applying TP/SP on top of FSDP, the FSDP world size can be reduced by a factor of 8 (the typical number of GPUs per host) by restricting FSDP to be inter-host only, which decreases the latency cost by the same amount (see the sketch after this list).

  2. When data parallelism hits its limit, that is, you cannot raise the global batch size above the number of GPUs due to both convergence and GPU memory limitations, Tensor/Sequence Parallel is the only known way to keep the global batch size in the same ballpark and continue scaling with more GPUs. This means both model size and number of GPUs can continue to scale.

  3. For certain types of models, when the local batch size becomes smaller, TP/SP can yield matrix multiplication shapes that achieve better floating point (FLOPS) utilization.
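To make the world-size arithmetic in point 1 concrete, here is a minimal, hypothetical sketch (the 256-GPU job size and 8 GPUs per host are assumptions, and the script is assumed to be launched with torchrun across all ranks): placing TP inside each host shrinks the data parallel (FSDP) group from 256 ranks to 32.

```python
from torch.distributed.device_mesh import init_device_mesh

tp_size = 8                      # GPUs per host, connected via NVLink (assumed)
world_size = 256                 # total GPUs in the job (assumed)
dp_size = world_size // tp_size  # 256 / 8 = 32 inter-host FSDP ranks

# 2-D mesh: FSDP collectives now span 32 ranks instead of 256, while
# TP collectives stay inside the fast intra-host NVLink domain.
mesh_2d = init_device_mesh(
    "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
```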

So, when pre-training, how easy is it to hit those limits? As of now, pre-training a Large Language Model (LLM) with billions or trillions of tokens could take months, even when using thousands of GPUs.

  • Training an LLM at large scale will always hit limitation 1. For example, Llama 2 70B was trained with 2k GPUs for 35 days; multi-dimensional parallelism is needed at the 2k scale.

  • When the Transformer model becomes larger (such as Llama 2 70B), it will also quickly hit limitation 2: FSDP alone cannot be used even with local batch_size=1, due to memory and convergence constraints. For example, the Llama 2 global batch size is 1K, so data parallelism alone cannot be used with 2K GPUs.

14.3.3 Apply Sequence Parallel to LayerNorm/RMSNorm layers

Sequence Parallel works on top of the Tensor Parallel illustrated above. Compared with basic Tensor Parallel, which only shards tensors within the Attention and FeedForward modules and keeps their module inputs and outputs (namely the activations in the forward pass and the gradients in the backward pass) replicated, Sequence Parallel keeps them sharded on the sequence dimension.

In a typical TransformerBlock, the forward function combines norm layers (LayerNorm or RMSNorm), an attention layer, a feed forward layer, and residual connections. For example:
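Since the full tutorial code is not reproduced here, the following is a minimal sketch of such a TransformerBlock with Llama-style submodule names (wq/wk/wv/wo and w1/w2/w3); the shapes and the use of nn.LayerNorm in place of a custom RMSNorm are assumptions made to keep the sketch self-contained:

```python
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """Minimal self-attention with Llama-style projection names wq/wk/wv/wo."""

    def __init__(self, dim, n_heads):
        super().__init__()
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bsz, seqlen, _ = x.shape
        # Use -1 for the heads dimension: under TP each rank only holds
        # its local slice of the attention heads.
        q = self.wq(x).view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.wo(out.transpose(1, 2).reshape(bsz, seqlen, -1))


class FeedForward(nn.Module):
    """Minimal SwiGLU feed forward with Llama-style names w1/w2/w3."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)
        self.attention = Attention(dim, n_heads)
        self.ffn_norm = nn.LayerNorm(dim)
        self.feed_forward = FeedForward(dim, hidden_dim)

    def forward(self, x):
        # norm -> attention -> residual, then norm -> feed forward -> residual
        h = x + self.attention(self.attention_norm(x))
        return h + self.feed_forward(self.ffn_norm(h))
```

Under Sequence Parallel, the parallelization plan for such a block marks the norm layers as SequenceParallel and redistributes the sequence-sharded activations back to replicated before the column-wise projections; the row-wise output projections then reduce-scatter back to a sequence-sharded layout. A sketch of that plan follows (import paths assume a recent PyTorch release; older releases expose Shard/Replicate under torch.distributed._tensor), assuming the block's inputs and outputs are sequence-sharded as they would be in a full model plan:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

# One TP group of 8 GPUs (assumes a torchrun launch and CUDA devices).
tp_mesh = init_device_mesh("cuda", (8,))
block = TransformerBlock(dim=1024, n_heads=8, hidden_dim=4096).to("cuda")

layer_tp_plan = {
    # Norm layers (and their activations) are sharded on the sequence dim.
    "attention_norm": SequenceParallel(),
    "ffn_norm": SequenceParallel(),
    # Inputs to attention/feed forward arrive sequence-sharded and are
    # all-gathered back to replicated before the column-wise projections.
    "attention": PrepareModuleInput(
        input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),)
    ),
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    # Row-wise output projections reduce-scatter back to Shard(1).
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward": PrepareModuleInput(
        input_layouts=(Shard(1),), desired_input_layouts=(Replicate(),)
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w3": ColwiseParallel(),
}

block = parallelize_module(block, tp_mesh, layer_tp_plan)
```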

14.3.4 Apply Loss Parallel

Loss Parallel is a related technique for saving memory and communication when the loss function is computed, as model outputs are usually very large. In Loss Parallel, when the model outputs are sharded on the (often huge) vocabulary dimension, the cross-entropy loss can be computed efficiently without gathering all of the model outputs to every single GPU. This not only significantly reduces memory consumption, but also improves training speed by reducing communication overhead and doing sharded computation in parallel.
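As a minimal sketch of how this looks in user code, assume the model's final output projection was parallelized so that its logits remain a DTensor sharded on the last (vocabulary) dimension, for example with ColwiseParallel(output_layouts=Shard(-1), use_local_output=False); pred and labels below are placeholders for the model output and the target token ids:

```python
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

# pred: DTensor of logits sharded on the vocab dim, shape (bsz, seqlen, vocab)
# labels: regular tensor of target token ids, shape (bsz, seqlen)
with loss_parallel():
    # Cross-entropy is computed on the sharded logits; no rank ever
    # materializes the full vocabulary dimension.
    loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
    # The backward pass must also run inside the loss_parallel context.
    loss.backward()
```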

14.3.5 Combine Tensor Parallel with Fully Sharded Data Parallel together

Now that we have shown how to apply Tensor/Sequence Parallel to the model, let us also take a look at how Tensor Parallel and Fully Sharded Data Parallel can work together. Since Tensor Parallelism incurs communication that blocks computation, we want to make sure it runs within a fast communication channel, such as NVLink. In practice, we usually apply Tensor Parallel within each host and apply Fully Sharded Data Parallel across hosts.
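Here is a minimal sketch of that composition, assuming a model whose TransformerBlocks live in model.layers, a per-block plan tp_plan like the one in Section 14.3.3, and dp_size/tp_size computed as shown earlier (all of these names are placeholders). The classic FSDP wrapper is used with use_orig_params=True so that it composes with the DTensor-based TP parameters:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import parallelize_module

# 2-D mesh: "tp" spans the GPUs inside a host, "dp" spans the hosts.
mesh_2d = init_device_mesh(
    "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
tp_mesh, dp_mesh = mesh_2d["tp"], mesh_2d["dp"]

# First apply Tensor/Sequence Parallel to every TransformerBlock,
# within the fast intra-host "tp" mesh dimension ...
for block in model.layers:
    parallelize_module(block, tp_mesh, tp_plan)

# ... then shard the TP-parallelized model with FSDP, but only across
# hosts (the "dp" mesh dimension), shrinking the FSDP world size.
model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)
```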

14.3.6 Conclusion

This tutorial demonstrates how to train a large Transformer-like model across hundreds to thousands of GPUs using Tensor Parallel in combination with Fully Sharded Data Parallel. It explains how to apply Tensor Parallel to different parts of the model, with no code changes to the model itself. Tensor Parallel is an efficient model parallelism technique for large scale training.

To see the complete end-to-end code example explained in this tutorial, please refer to the Tensor Parallel examples in the pytorch/examples repository.