14.2 Fully Sharded Data Parallel
PyTorch FSDP2 provides a fully sharded data parallelism (FSDP) implementation that targets performant eager-mode execution while using per-parameter sharding for improved usability.
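A minimal usage sketch is shown below. It assumes PyTorch 2.6 or later, where fully_shard is exposed under torch.distributed.fsdp, and a torchrun launch; the toy model, sizes, and hyperparameters are illustrative, not a recommended configuration.

```python
# Minimal FSDP2 sketch: shard each submodule, then the root module.
# Launch with: torchrun --nproc_per_node=<N> this_script.py
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Toy stand-in for a Transformer stack; each Linear becomes its own FSDP unit.
model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)  # root wrapping catches any remaining parameters

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).sum()   # forward: each unit all-gathers its parameters just in time
loss.backward()         # backward: gradients are reduce-scattered into shards
optim.step()            # optimizer updates only the local parameter shards
optim.zero_grad()

dist.destroy_process_group()
```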
14.2.1 How FSDP2 works
In DistributedDataParallel (DDP) training, each rank owns a full model replica and processes its own batch of data, then uses all-reduce to synchronize gradients across ranks.
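For contrast, here is a minimal DDP sketch; the model, batch shape, and torchrun launch are illustrative assumptions.

```python
# Minimal DDP sketch: every rank holds a full replica; backward all-reduces gradients.
# Launch with: torchrun --nproc_per_node=<N> this_script.py
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = DDP(nn.Linear(1024, 1024).cuda())   # full model replica on every rank
optim = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 1024, device="cuda")      # each rank processes its own batch
model(x).sum().backward()                    # backward triggers gradient all-reduce
optim.step()                                 # replicas stay in sync after the update

dist.destroy_process_group()
```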
Compared with DDP, FSDP reduces the GPU memory footprint by sharding model parameters, gradients, and optimizer states across ranks, which makes it feasible to train models that do not fit on a single GPU. The sharding workflow is as follows (a code sketch follows the list):
Outside of forward and backward computation, parameters are fully sharded;
Before forward and backward, sharded parameters are all-gathered into unsharded parameters;
Inside backward, local unsharded gradients are reduce-scattered into sharded gradients;
The optimizer updates the sharded parameters with the sharded gradients, resulting in sharded optimizer states.
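To make the sharded state concrete, the following inspection sketch shows that, outside forward and backward, FSDP2 represents each parameter as a DTensor sharded on dim 0, so every rank stores only its slice. It assumes PyTorch 2.6+ and a torchrun launch; the model and sizes are illustrative.

```python
# Inspect the sharded parameters that FSDP2 leaves outside forward/backward.
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import DTensor

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = nn.Linear(1024, 1024).cuda()
fully_shard(model)

for name, param in model.named_parameters():
    assert isinstance(param, DTensor)              # sharded representation
    print(name,
          "global", tuple(param.shape),            # full (logical) shape
          "local", tuple(param.to_local().shape),  # this rank's shard
          "placements", param.placements)          # e.g. (Shard(dim=0),)

dist.destroy_process_group()
```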
FSDP can be considered a decomposition of DDP’s all-reduce into reduce-scatter and all-gather operations.
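This equivalence can be checked directly with the raw collectives: reduce-scattering a full gradient into per-rank shards and then all-gathering those shards reproduces the result of an all-reduce. A toy check, with illustrative shapes, the NCCL backend, and a torchrun launch:

```python
# Toy check: all-reduce == reduce-scatter + all-gather (elementwise sum).
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world = dist.get_world_size()

# A fake "full gradient" that differs per rank.
grad = torch.arange(4 * world, dtype=torch.float32, device="cuda") + dist.get_rank()

# DDP-style: all-reduce the full gradient on every rank.
allreduced = grad.clone()
dist.all_reduce(allreduced)

# FSDP-style: reduce-scatter into a local shard, then all-gather to rebuild the full tensor.
shard = torch.empty(4, device="cuda")
dist.reduce_scatter_tensor(shard, grad)
regathered = torch.empty_like(grad)
dist.all_gather_into_tensor(regathered, shard)

assert torch.allclose(allreduced, regathered)
dist.destroy_process_group()
```

In FSDP the two halves are separated in time: the reduce-scatter runs during backward, while the matching all-gather runs just before the next forward or backward that needs the unsharded parameters.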