Inside FSDP: A Look at the Flat-Parameter Design

Jan 27, 2025

Fully Sharded Data Parallel (FSDP) brings DeepSpeed’s Zero Redundancy Optimizer (ZeRO) into PyTorch’s ecosystem to facilitate large model training. To understand the design philosophy behind FSDP’s flat-parameter abstraction, we’ll step through a minimal implementation specifically designed for nanoGPT, building our intuition for the underlying mechanics. The full code is available here.

FSDP in a Nutshell

Unlike traditional data parallel training where each GPU maintains a complete model copy, FSDP divides a model into discrete units and shards each unit’s parameters, gradients, and optimizer states across workers. During the forward pass, FSDP reconstructs each unit on demand by collecting parameter shards from participating ranks through an AllGather operation, then discards the gathered parameters immediately after computing the unit’s activations. This choreography repeats during the backward pass, computing gradients instead of activations. After backpropagation through a unit is complete, FSDP frees the non-local parameters (as in the forward pass) and distributes gradients across ranks through a ReduceScatter operation. In this way, we only materialize a single unit at any given time, significantly curbing peak memory consumption.

Implementation Walkthrough

Classically in PyTorch, instantiating a model loads all parameters onto the GPU at once - which becomes impossible for large models. For context, mixed-precision training with PyTorch AMP requires 18 bytes per parameter. Even with an H100’s 80GB of VRAM, we can store models with at most ~4.4B parameters - and this limit doesn’t account for the additional memory needed for activations during training.
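The 18-byte figure and the ~4.4B ceiling can be reproduced with a few lines of arithmetic. The per-parameter breakdown below is an assumption (one common accounting for mixed-precision training with Adam), not something spelled out above, and 80GB is taken to mean 80 × 10^9 bytes.

```python
# Assumed breakdown of bytes per parameter for mixed-precision training with Adam.
# This accounting is illustrative; the article only states the 18-byte total.
BYTES_PER_PARAM = {
    "fp32 master weights": 4,
    "fp16 weight copy": 2,
    "fp32 gradients": 4,
    "Adam first moment (fp32)": 4,
    "Adam second moment (fp32)": 4,
}


def max_params(vram_bytes: float, bytes_per_param: int) -> float:
    """Upper bound on the parameter count if VRAM held nothing but model state."""
    return vram_bytes / bytes_per_param


total_bytes = sum(BYTES_PER_PARAM.values())
h100_vram = 80e9  # 80GB read as decimal gigabytes (assumption)
limit = max_params(h100_vram, total_bytes)
print(f"{total_bytes} bytes/param -> about {limit / 1e9:.1f}B parameters")
# prints: 18 bytes/param -> about 4.4B parameters
```

Activations, temporary buffers, and allocator fragmentation all push the practical limit lower than this bound.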
Creating the model on a meta device instead lets us define the architecture and parameter metadata without allocating any GPU memory:

    device_context = torch.device('meta') if self.is_distributed and self.args.deferred_init else torch.device(self.device)
    with device_context:
        model = GPT(GPTConfig(vocab_size=50304))

We then wrap our model with a constructor that mirrors PyTorch’s minimally invasive API design:

    model = CustomFSDP(model, process_group=self.dp_group, param_init_fn=model._init_weights)

The constructor’s _wrap_blocks() subroutine recursively traverses the model and wraps each transformer block as its own FSDP module - imitating transformer_auto_wrap_policy from torch.distributed.fsdp.wrap. Partitioning the model into these self-contained units enables us to materialize it incrementally, starting with the transformer blocks at the top of the recursive call stack and ending with the root unit at the bottom. The root unit is the outermost one: it comprises the weight shared by the text embedding layer and the language modelling head, the position embedding layer’s weight, and the final layernorm’s bias and gain.
    import math
    from collections import deque
    from functools import reduce

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Block is nanoGPT's transformer block class.

    class CustomFSDP(nn.Module):
        def __init__(self, module, process_group, param_init_fn):
            super().__init__()
            self.process_group = process_group
            self.world_size = dist.get_world_size(self.process_group)
            self.rank = dist.get_rank(self.process_group)
            self._fsdp_wrapped_module = module
            self.param_numels = []
            self.param_shapes = []
            self.param_names = []
            self.flat_param = None
            self.local_shard = None
            # self.module is assumed to be an accessor for self._fsdp_wrapped_module;
            # its definition is not shown in this excerpt.
            self._wrap_blocks(self.module, param_init_fn)
            self._record_param_metadata()
            self._create_and_shard_flat_param(param_init_fn)
            self.register_full_backward_pre_hook(lambda m, go: self._pre_backward())
            self._register_post_backward_hooks()

        def _wrap_blocks(self, module, param_init_fn):
            for name, child in module.named_children():
                if isinstance(child, Block):
                    # Each transformer block becomes its own FSDP unit.
                    fsdp_unit = CustomFSDP(child, self.process_group, param_init_fn)
                    setattr(module, name, fsdp_unit)
                else:
                    self._wrap_blocks(child, param_init_fn)

At the core of each FSDP unit is a single ‘flat parameter’, stored in the .flat_param attribute of our implementation. It is a 1D tensor created by flattening the constituent model parameters and concatenating them. The flat parameter and its gradient own the underlying storage of the original model parameters and their gradients, respectively. Individual parameters are implemented as views: windows into specific sections of the flat parameter. When you access a parameter’s data, you’re actually looking at a reshaped slice of the underlying .data tensor of .flat_param. Similarly, each parameter’s gradient is a reshaped view into the corresponding section of the .grad tensor of .flat_param. This setup maintains the illusion of independent parameters from the model’s perspective.

Figure: Toy example of an FSDP unit with 3 parameters (A, B, C) in a 2-way data parallel setup.

Returning to our deferred initialization procedure, we find that our parameters still reside on the meta device.
However, PyTorch doesn’t support direct conversion of meta parameters to CUDA parameters via .to() since meta parameters, by design, store no actual data to transfer. We can work around this by replacing each parameter with a freshly constructed nn.Parameter that temporarily holds an empty CUDA tensor, and then immediately redirecting it to view the appropriate section of .flat_param. A prerequisite for this reconstruction is recording crucial metadata such as each parameter’s shape and total number of elements, as well as identifying any parameters that are shared across different parts of the model. With our parameters now properly structured on device, we apply the initialization function to populate them with the correct starting values.

    def _record_param_metadata(self):
        for n, p in self.module.named_parameters():
            # Skip parameters that belong to nested FSDP units.
            if '_fsdp_wrapped_module' not in n:
                self.param_numels.append(p.numel())
                self.param_shapes.append(p.shape)
                self.param_names.append(n)

    def _materialize_params(self):
        def _replace_param(param_path, new_param):
            *module_path, leaf = param_path.split('.')
            submodule = reduce(getattr, module_path, self.module)
            setattr(submodule, leaf, new_param)

        # Swap every meta parameter for an empty CUDA placeholder.
        for name in self.param_names:
            _replace_param(name, nn.Parameter(torch.empty(0, device='cuda')))

    def _create_and_shard_flat_param(self, param_init_fn):
        total_numel = sum(self.param_numels)
        # Pad so the flat parameter divides evenly across ranks.
        padded_size = math.ceil(total_numel / self.world_size) * self.world_size
        shard_size = padded_size // self.world_size
        self.flat_param = torch.zeros(padded_size, device='cuda')
        self.local_shard = torch.zeros(shard_size, device='cuda')
        self.local_shard.grad = torch.zeros_like(self.local_shard)

        devices = {self.module.get_parameter(name).device for name in self.param_names}
        assert len(devices) == 1, "All parameters must be on the same device"
        is_materialized = (devices.pop() != torch.device('meta'))

        if is_materialized:
            # Copy existing values into the flat parameter.
            offset = 0
            for name, numel in zip(self.param_names, self.param_numels):
                self.flat_param[offset:offset+numel] = self.module.get_parameter(name).data.view(-1)
                offset += numel
        else:
            self._materialize_params()
            self._update_module_params()

            def _apply_param_init_fn(root_module, param_init_fn):
                # Breadth-first traversal that stops at nested FSDP units.
                queue = deque([root_module])
                while queue:
                    module = queue.popleft()
                    if not isinstance(module, CustomFSDP):
                        param_init_fn(module)
                    for child in module.children():
                        if not isinstance(child, CustomFSDP):
                            queue.append(child)

            _apply_param_init_fn(self.module, param_init_fn)

        # Copy this rank's slice of the flat parameter into its local shard.
        start_idx = self.rank * shard_size
        end_idx = start_idx + shard_size
        self.local_shard.data.add_(self.flat_param[start_idx: end_idx])
        self._shard()

It’s worth noting that parameters are initialized in a breadth-first fashion, matching exactly the way it’s done natively in FSDP. By preserving this sequence, we ensure the random number generator produces identical samples, resulting in consistent parameter values across implementations. This consistency is a helpful sanity check. Because floating-point arithmetic is non-associative, bitwise-identical results across implementations usually cannot be guaranteed. However, by initializing parameters in the same order, we can directly compare loss values across training runs.

We’re now ready to shard this unit before moving on to initialize the next one. The flat parameter’s total size is padded to ensure even division across workers. Each worker allocates a buffer whose length is the padded length divided by the world size, stores it in .local_shard, and copies over its designated portion of the flat parameter.

The sharding process involves carefully eliminating all references to the large gathered flat tensor to ensure proper memory deallocation. First, we redirect .flat_param’s data pointer to .local_shard’s data tensor. Then, we must update each model parameter’s data pointer to the appropriate section of .local_shard, as these parameters still hold references to the original flat tensor.
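Before looking at the implementation, the padding and slicing arithmetic described above can be checked in isolation. This is a minimal single-process sketch: shard_layout is a hypothetical helper that is not part of the article’s code, and the sizes are made up.

```python
import math


def shard_layout(total_numel: int, world_size: int):
    """Return (padded_size, shard_size, per-rank [start, end) slices)."""
    # Round the total up to the next multiple of world_size.
    padded_size = math.ceil(total_numel / world_size) * world_size
    shard_size = padded_size // world_size
    bounds = [(rank * shard_size, (rank + 1) * shard_size) for rank in range(world_size)]
    return padded_size, shard_size, bounds


padded_size, shard_size, bounds = shard_layout(total_numel=10, world_size=4)
print(padded_size, shard_size, bounds)
# prints: 12 3 [(0, 3), (3, 6), (6, 9), (9, 12)]
```

With 10 real elements and 4 ranks, the last rank’s slice covers positions 9 to 11, of which only position 9 holds a real value; the other two are padding.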
    def _shard(self, include_grads=False):
        self.flat_param.data = self.local_shard.data
        if include_grads:
            self.flat_param.grad = self.local_shard.grad
        self._update_module_params(include_grads=include_grads)

    def _update_module_params(self, include_grads=False):
        is_sharded = self.flat_param.data_ptr() == self.local_shard.data_ptr()
        local_shard_size = self.local_shard.numel()
        # When sharded, offsets are expressed relative to the start of this rank's shard.
        offset = -local_shard_size * self.rank if is_sharded else 0
        for name, numel, shape in zip(self.param_names, self.param_numels, self.param_shapes):
            data_tensor, grad_tensor = self._retrieve_data_and_grad_tensors(
                offset, numel, shape, is_sharded, local_shard_size, include_grads)
            parameter = self.module.get_parameter(name)
            parameter.data = data_tensor
            if include_grads:
                parameter.grad = grad_tensor
            offset += numel

    def _retrieve_data_and_grad_tensors(self, offset, numel, shape, is_sharded, local_shard_size, include_grads):
        if is_sharded:
            # Handle cases where parameter lies outside this shard
            if offset + numel < 0 or offset >= local_shard_size:
                return torch.empty(0, device='cuda'), None
            # Get slice of parameter from local shard
            start = max(offset, 0)
            end = min(offset + numel, local_shard_size)
            data_tensor = self.local_shard[start:end]
            grad_tensor = self.local_shard.grad[start:end] if include_grads else None
        else:
            # Get slice from full flattened parameter
            data_tensor = self.flat_param[offset:offset+numel].view(shape)
            grad_tensor = self.flat_param.grad[offset:offset+numel].view(shape) if include_grads else None
        return data_tensor, grad_tensor

Assigning views of the local shard requires careful handling of parameter boundaries. While each rank maintains references to all model parameters, it only holds the actual values for the portion of the flat parameter copied into its local shard. For each parameter, we calculate its position relative to the local shard using the offset calculation at the beginning of _update_module_params().
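As a toy illustration of this offset arithmetic, consider three parameters laid out back to back in a flat parameter that is split across two ranks. The sketch below is pure Python and mirrors the boundary checks and clamping in _retrieve_data_and_grad_tensors(); local_range and the parameter sizes are hypothetical, chosen only for illustration.

```python
def local_range(param_start, numel, rank, shard_size):
    """Return the [start, end) slice of the local shard holding this parameter, or None."""
    offset = param_start - rank * shard_size  # position relative to this rank's shard
    if offset + numel < 0 or offset >= shard_size:
        return None  # the parameter does not overlap this shard
    return max(offset, 0), min(offset + numel, shard_size)


numels = {"A": 4, "B": 5, "C": 3}  # hypothetical sizes, 12 elements in total
shard_size = 6                     # 12 elements split over 2 ranks

layout = {}
for rank in range(2):
    start = 0
    for name, numel in numels.items():
        layout[(rank, name)] = local_range(start, numel, rank, shard_size)
        start += numel

for key, value in layout.items():
    print(key, value)
```

Rank 0 holds all of A and the first two elements of B, while rank 1 holds the remaining three elements of B and all of C. Parameters with no overlap come back as None, which corresponds to the empty tensor assigned in the article’s code.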
The actual tensor indexing and view retrieval is handled in _retrieve_data_and_grad_tensors(), which addresses three possible scenarios for each parameter:

- Parameters entirely outside the shard - either ending before the shard begins (offset + numel < 0) or starting after it ends (offset >= local_shard_size) - get assigned an empty tensor.
- Parameters partially contained in the shard require extracting just the overlapping portion. The start and end indices are clamped using max and min operations - max clamps negative indices to 0, while min clamps indices greater than the shard length.
- Parameters fully contained in the shard (those whose range [offset, offset + numel) falls entirely within [0, local_shard_size)) are handled by the same logic and assigned views containing all their values.

During training, we call forward() on our top-level FSDP module that wraps the entire GPT model. Since our implementation inherits from nn.Module and each transformer block has also been wrapped in its own FSDP instance, we can override the forward method to intercept calls to the underlying model.

    def forward(self, *args, **kwargs):
        self._gather()
        output = self._fsdp_wrapped_module(*args, **kwargs)
        self._shard()
        return output

The distributed operations are coordinated as follows: first _gather() reconstructs the full flat parameter tensor on each rank via torch.distributed.all_gather_into_tensor(), then the wrapped module performs its forward pass exactly as it would in a non-distributed setting, and finally _shard() frees the memory used by the gathered parameters.
    def _gather(self, include_grads=False):
        # Output buffer that will hold the shards from every rank.
        full_tensor = torch.zeros(self.local_shard.numel() * self.world_size, device=self.local_shard.device)
        dist.all_gather_into_tensor(full_tensor, self.local_shard, group=self.process_group)
        self.flat_param.data = full_tensor
        if include_grads:
            full_grads_tensor = torch.zeros_like(self.flat_param)
            self.flat_param.grad = full_grads_tensor
        self._update_module_params(include_grads=include_grads)

The backward pass follows the same gather-compute-shard sequence, but we orchestrate it through a system of hooks rather than explicit calls. We register a pre-backward hook on each FSDP unit (via register_full_backward_pre_hook() in the constructor) that fires just before backpropagation through the unit begins. When triggered, this hook calls _pre_backward(), which gathers parameters with include_grads=True, constructing a gradient tensor that will own the storage for individual parameter gradients, and initializes .grad_counter, which helps determine when to shard the unit. To track the backward pass through the unit, we attach a post-accumulation hook to each parameter that increments this counter. Once it equals the number of parameters in the unit - indicating all gradients have accumulated into the flat parameter’s gradient - we trigger a ReduceScatter to aggregate gradients across ranks into our local shard. This synchronization point marks when it’s safe to shard the unit.
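To make the ReduceScatter step concrete, here is a single-process simulation of its semantics with averaging: every rank contributes a full-length gradient, and rank r ends up with the average of chunk r. It uses plain Python lists rather than torch.distributed, and simulate_reduce_scatter_avg is a hypothetical name introduced only for this sketch.

```python
def simulate_reduce_scatter_avg(per_rank_grads):
    """Simulate ReduceScatter with averaging.

    per_rank_grads[r] is the full (padded) flat gradient computed on rank r.
    Returns one list per rank: the average of that rank's chunk across all ranks.
    """
    world_size = len(per_rank_grads)
    total = len(per_rank_grads[0])
    assert total % world_size == 0, "the flat gradient must be padded to a multiple of world_size"
    shard_size = total // world_size

    outputs = []
    for rank in range(world_size):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        # Element-wise average over ranks, restricted to this rank's chunk.
        shard = [sum(grads[i] for grads in per_rank_grads) / world_size for i in range(lo, hi)]
        outputs.append(shard)
    return outputs


full_grads = [
    [1.0, 2.0, 3.0, 4.0],  # flat gradient on rank 0
    [3.0, 4.0, 5.0, 6.0],  # flat gradient on rank 1
]
shards = simulate_reduce_scatter_avg(full_grads)
print(shards)
# prints: [[2.0, 3.0], [4.0, 5.0]]
```

Each rank therefore keeps only the averaged gradient for the slice of the flat parameter it owns, which is what the local shard’s optimizer step needs.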
    def _pre_backward(self):
        # If zero_grad() reset the parameter gradients to None, clear the shard's gradient too.
        if all(self._fsdp_wrapped_module.get_parameter(name).grad is None for name in self.param_names):
            self.local_shard.grad = torch.zeros_like(self.local_shard)
        self.grad_counter = 0
        self._gather(include_grads=True)

    def _post_backward(self):
        self.grad_counter += 1
        # Fire once every parameter in the unit has accumulated its gradient.
        if self.grad_counter == len(self.param_names):
            grad_shards = list(self.flat_param.grad.chunk(self.world_size))
            buffer = torch.empty(self.local_shard.shape, device='cuda')
            dist.reduce_scatter(buffer, grad_shards, op=dist.ReduceOp.AVG, group=self.process_group)
            self.local_shard.grad.add_(buffer)
            self._shard(include_grads=True)

In most training setups, optimizer.zero_grad() is called before each forward pass to clear the gradients of model parameters for the next step. However, this operation is not aware of CustomFSDP’s local shard gradient buffer. We thus add a check to the _pre_backward() hook that detects when the parameters owned by an FSDP unit have had their gradients reset to None and, in that case, zeroes out the local shard’s gradient.

Advantages of the Flat-Parameter Design

Consolidating parameters and their gradients into flat contiguous containers addresses two key constraints: efficient communication and minimal memory usage. As we’ve seen, FSDP relies heavily on AllGather and ReduceScatter. The flat-parameter abstraction optimizes these communications in two key ways:

- Reduced Fixed Overhead: Instead of paying the latency cost for each individual parameter/gradient, we pay it once per flat parameter. This matters particularly for these collective operations because they’re implemented in a ring fashion, in which data passes through every participating GPU in turn. The latency therefore grows linearly with cluster size, making this optimization increasingly important as we scale to larger deployments.
- Better Bandwidth Utilization: By sending larger, consolidated payloads, we make better use of the available network bandwidth. The communication cost becomes dominated by actual data transfer, creating an amortization effect where the fixed latency cost gets spread over more useful work.

The memory savings stem from the flat parameter’s contiguous layout perfectly matching NCCL’s AllGather API requirements. Since individual parameters are implemented as views into this flat tensor rather than separate allocations, we can directly use the AllGather’s output buffer as parameter storage, avoiding extra memory copies. The only additional buffer required is for gradient accumulation, since performing a ReduceScatter directly into the local shard’s gradient would overwrite previous values.
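The fixed-overhead argument from this section can be made quantitative with a simplified latency-bandwidth cost model of a ring collective: world_size - 1 steps, each paying a fixed latency alpha plus the time to move one chunk. The model and every number below (latency, bandwidth, message sizes) are illustrative assumptions, not measurements, and ring_time is a hypothetical helper.

```python
import math


def ring_time(message_bytes, world_size, alpha, bandwidth):
    """Estimated duration of one ring AllGather or ReduceScatter under the simplified model."""
    steps = world_size - 1
    chunk_bytes = message_bytes / world_size
    return steps * (alpha + chunk_bytes / bandwidth)


world_size = 8
alpha = 10e-6      # assumed fixed latency per step, in seconds
bandwidth = 100e9  # assumed link bandwidth, in bytes per second
param_bytes = [1_000_000, 4_000, 2_000_000, 4_000, 500_000, 2_000]  # hypothetical tensors

# One collective per parameter versus a single collective over the flat parameter.
per_param = sum(ring_time(b, world_size, alpha, bandwidth) for b in param_bytes)
flat = ring_time(sum(param_bytes), world_size, alpha, bandwidth)
saved = per_param - flat

print(f"per-parameter: {per_param * 1e6:.1f} us, flat: {flat * 1e6:.1f} us, saved: {saved * 1e6:.1f} us")
```

Under this model the bandwidth term is identical in both cases, so the saving is exactly (k - 1) × (world_size - 1) × alpha for k parameters: 350 microseconds with the numbers above. The saving grows with both the number of parameters per unit and the number of ranks.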