14.2 Fully Sharded Data Parallel

PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation that targets performant eager-mode execution while using per-parameter sharding for improved usability.

14.2.1 How FSDP2 works

In DistributedDataParallel (DDP) training, each rank owns a model replica, processes a batch of data, and finally uses all-reduce to synchronize gradients across ranks.

Compared with DDP, FSDP reduces the GPU memory footprint by sharding model parameters, gradients, and optimizer states, which makes it feasible to train models that cannot fit on a single GPU. The sharding lifecycle works as follows:

  • Outside of forward and backward computation, parameters are fully sharded;

  • Before forward and backward, sharded parameters are all-gathered into unsharded parameters;

  • Inside backward, local unsharded gradients are reduce-scattered into sharded gradients;

  • Optimizer updates sharded parameters with sharded gradients, resulting in sharded optimizer states.

FSDP can be considered a decomposition of DDP’s all-reduce into reduce-scatter and all-gather operations.
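
This decomposition can be illustrated with raw collectives. The sketch below assumes a launch via torchrun with one GPU per rank and the NCCL backend (which supports reduce-scatter); it is not how FSDP is implemented internally, but it shows that a reduce-scatter followed by an all-gather reproduces the result of a single all-reduce.

    # Conceptual sketch: DDP's all-reduce vs. FSDP's reduce-scatter + all-gather.
    # Assumes a torchrun launch with one GPU per rank (NCCL backend).
    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    world_size = dist.get_world_size()
    device = torch.device("cuda")

    grad = torch.ones(4 * world_size, device=device)  # local unsharded gradient

    # DDP-style: a single all-reduce gives every rank the full summed gradient.
    ddp_grad = grad.clone()
    dist.all_reduce(ddp_grad)

    # FSDP-style: reduce-scatter leaves each rank with only its summed shard ...
    shard = torch.empty(4, device=device)
    dist.reduce_scatter_tensor(shard, grad)
    # ... and an all-gather reassembles the full tensor when it is needed.
    regathered = torch.empty(4 * world_size, device=device)
    dist.all_gather_into_tensor(regathered, shard)

    assert torch.allclose(ddp_grad, regathered)
    dist.destroy_process_group()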

14.2.2 How to use FSDP2

14.2.2.1 Model Initialization
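
A minimal initialization sketch is shown below. It assumes a recent PyTorch release in which fully_shard is exported from torch.distributed.fsdp (older releases expose it under torch.distributed._composable.fsdp); ToyTransformer is a stand-in model used only for illustration.

    # Sketch of FSDP2 model initialization with per-module sharding.
    import torch
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import fully_shard  # assumption: recent PyTorch export

    class ToyTransformer(nn.Module):
        def __init__(self, dim=1024, n_layers=4):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True)
                for _ in range(n_layers)
            )
            self.out = nn.Linear(dim, dim)

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return self.out(x)

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = ToyTransformer().cuda()
    # Shard each transformer block as its own FSDP unit, then the root module,
    # so parameters outside the blocks (here, `out`) are sharded as well.
    for layer in model.layers:
        fully_shard(layer)
    fully_shard(model)

    # Parameters are now DTensors, sharded on dim 0 across the data-parallel ranks.
    print(type(next(model.parameters())))

Wrapping each block as its own unit keeps the all-gather granularity at the block level, so only one block's parameters need to be unsharded in memory at a time during forward and backward.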

14.2.2.2 Forward/Backward with Prefetching
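
FSDP2 already overlaps the next unit's all-gather with the current unit's computation by default; the sketch below layers explicit prefetching on top, reusing the sharded ToyTransformer from the previous sketch. The set_modules_to_forward_prefetch and set_modules_to_backward_prefetch methods on FSDP-wrapped modules are assumed to be available in the installed release.

    # Sketch of a training step with explicit prefetching (assumed FSDP2 methods).
    import torch

    # Ask each unit to all-gather a neighboring unit's parameters while it is
    # still computing, hiding communication latency in forward and backward.
    for i in range(len(model.layers) - 1):
        model.layers[i].set_modules_to_forward_prefetch([model.layers[i + 1]])
    for i in range(1, len(model.layers)):
        model.layers[i].set_modules_to_backward_prefetch([model.layers[i - 1]])

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    x = torch.randn(8, 16, 1024, device="cuda")

    # Forward all-gathers each unit's parameters just in time; backward
    # reduce-scatters gradients as soon as a unit finishes its backward pass.
    loss = model(x).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()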

14.2.2.3 Enabling Mixed Precision
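
Mixed precision is configured per FSDP unit. The sketch below assumes the MixedPrecisionPolicy class and the mp_policy keyword of fully_shard as found in recent releases: parameters are all-gathered and computed in bfloat16, while gradient reduction and the sharded parameters stay in float32.

    # Sketch of FSDP2 mixed precision (assumed MixedPrecisionPolicy API).
    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    # Compute (and parameter all-gather) in bfloat16; gradient reduce-scatter
    # and the sharded parameters/optimizer states remain in float32.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    model = ToyTransformer().cuda()  # stand-in model from the earlier sketch
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)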

14.2.2.4 Gradient Clipping and Optimizer with DTensor
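
Because FSDP2 exposes parameters and gradients as DTensors, a plain torch.optim optimizer and torch.nn.utils.clip_grad_norm_ can be used directly, and the clipping norm is computed across all shards. A sketch, reusing the sharded model from the earlier snippets:

    # Sketch of gradient clipping and an optimizer step on DTensor parameters.
    import torch

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

    loss = model(torch.randn(8, 16, 1024, device="cuda")).sum()
    loss.backward()

    # clip_grad_norm_ works on sharded (DTensor) gradients; the returned total
    # norm may itself be a DTensor, so gather it before logging.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    if hasattr(total_norm, "full_tensor"):
        total_norm = total_norm.full_tensor()

    # The optimizer only ever sees sharded parameters and gradients, so its
    # states (e.g. Adam moments) are sharded as well.
    optim.step()
    optim.zero_grad()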

14.2.2.5 State Dicts with DTensor APIs
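
With FSDP2, model.state_dict() returns DTensor values whose placements describe the sharding, and DTensor's to_local() and full_tensor() methods convert between the local shard and the gathered full tensor. A sketch, reusing the model from above (note that full_tensor() is a collective and must run on every rank):

    # Sketch: inspecting and materializing FSDP2 state dicts via DTensor methods.
    import torch
    import torch.distributed as dist

    sharded_sd = model.state_dict()  # values are DTensors sharded on dim 0

    name, value = next(iter(sharded_sd.items()))
    local = value.to_local()      # this rank's shard only (regular tensor)
    full = value.full_tensor()    # all-gathered full tensor (collective call)
    print(name, tuple(local.shape), tuple(full.shape))

    # Materialize a full, unsharded state dict; every rank must participate in
    # the all-gathers, but only rank 0 writes the file.
    full_sd = {k: v.full_tensor().cpu() for k, v in sharded_sd.items()}
    if dist.get_rank() == 0:
        torch.save(full_sd, "full_model.pt")  # illustrative path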

14.2.2.6 State Dict with DCP APIs
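
Distributed Checkpoint (DCP) saves and loads sharded state dicts directly, so no rank ever needs the full unsharded model in memory. The sketch below assumes the dcp.save / dcp.load APIs of recent releases, uses "checkpoint/" as an illustrative path, and reuses the model and optimizer from the earlier sketches.

    # Sketch: sharded checkpointing with torch.distributed.checkpoint (DCP).
    import torch.distributed.checkpoint as dcp

    # Save: each rank writes only its own shards; DTensor metadata records the layout.
    state = {"model": model.state_dict(), "optim": optim.state_dict()}
    dcp.save(state, checkpoint_id="checkpoint/")

    # Load: build a state dict with the target (sharded) layout, let DCP fill it
    # in place, then load the result back into the modules.
    state = {"model": model.state_dict(), "optim": optim.state_dict()}
    dcp.load(state, checkpoint_id="checkpoint/")
    model.load_state_dict(state["model"])
    optim.load_state_dict(state["optim"])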