This post is part of a series about distributed AI across multiple GPUs.
Introduction
In the previous post, we saw how Distributed Data Parallelism (DDP) speeds up training by splitting batches across GPUs. DDP solves the throughput problem, but it introduces a new challenge: memory redundancy.
In vanilla DDP, every GPU holds a complete copy of the model parameters, gradients, and optimizer states. For large models like GPT-3 (175B parameters), this redundancy becomes a big waste of precious VRAM.
ZeRO (Zero Redundancy Optimizer) solves this. It comes in three stages:
- ZeRO-1 partitions only optimizer states
- ZeRO-2 partitions optimizer states + gradients
- ZeRO-3 partitions optimizer states + gradients + model parameters
ZeRO isn’t a new parallelism technique: every GPU still runs the full forward and backward pass on its own slice of the data, just as in DDP. It’s a memory optimization strategy that eliminates redundancy across GPUs, letting you train larger models on the same hardware.
The Memory Problem in DDP
Let’s break down what actually consumes memory during training. For a model with $\Psi$ parameters:
- Model Parameters: $\Psi$ values (the weights of your neural network)
- Gradients: $\Psi$ values (one gradient per parameter)
- Optimizer States (Adam): $2\Psi$ values (first and second moments for each parameter)
- Activations: intermediate outputs stored during the forward pass for use in the backward pass
The first three scale with model size and are replicated on every GPU in DDP. Activations scale with batch size, sequence length, and model width and depth, and they’re unique per GPU since each GPU processes different data. ZeRO doesn’t touch activation memory.
Let’s calculate the memory usage for a 7B-parameter model using Adam and FP32:
- Parameters: 7 billion * 4 bytes = 28 GB
- Gradients: 7 billion * 4 bytes = 28 GB
- Optimizer states: 7 billion * 2 * 4 bytes = 56 GB
- Memory per GPU in DDP: 112 GB
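The same arithmetic as a small helper, parameterized by the ZeRO stage described in the sections below, so you can plug in other model sizes or GPU counts. This is a sketch, and `model_state_gb` is a hypothetical name: it counts only model states (FP32 values, Adam with two states per parameter, 1 GB = 10^9 bytes, as above), and `stage=0` stands for plain DDP.

```python
def model_state_gb(num_params: float, num_gpus: int, stage: int,
                   bytes_per_value: int = 4, adam_states: int = 2) -> float:
    """Per-GPU memory (GB) for parameters, gradients and Adam states; activations excluded."""
    params = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optim = num_params * adam_states * bytes_per_value
    if stage >= 1:
        optim /= num_gpus   # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= num_gpus  # ZeRO-3: also shard parameters
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(f"stage {stage}: {model_state_gb(7e9, 8, stage):.1f} GB")
# stage 0: 112.0 GB, stage 1: 63.0 GB, stage 2: 38.5 GB, stage 3: 14.0 GB
```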
Activations add significant memory on top of this, but since they’re unique per GPU, ZeRO can’t partition them. Techniques like activation checkpointing can help: they discard some activations during the forward pass and recompute them as needed during the backward pass. But that’s outside the scope of this article.
Let’s understand how ZeRO works by implementing it from the ground up, starting with ZeRO-1 and working our way to ZeRO-3.
ZeRO-1: Optimizer State Partitioning
In ZeRO-1, only the optimizer states are partitioned. Each GPU:
- Still holds the full model parameters and gradients
- Stores only 1/N of the optimizer states (N = number of GPUs)
- Updates only the corresponding 1/N of the parameters
This is the sequence of actions taken during training:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- All-reduce gradients: every GPU gets all the (averaged) gradients
- Optimizer step: each GPU updates its parameter partition
- All-gather parameters: sync the updated model across GPUs
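To check that this sequence gives exactly the same update as DDP, here is a framework-free simulation in which plain Python lists stand in for per-GPU tensors. It is a sketch under simplifying assumptions: plain SGD instead of Adam, so there is no optimizer state to store; a parameter count that divides evenly by the number of ranks; and collectives simulated in a single process. `all_reduce_mean`, `all_gather` and `zero1_sgd_step` are hypothetical helpers, not `torch.distributed` calls.

```python
def all_reduce_mean(per_rank):
    """Simulated all-reduce (averaged): every rank ends up with the element-wise mean."""
    n = len(per_rank)
    mean = [sum(vals) / n for vals in zip(*per_rank)]
    return [list(mean) for _ in range(n)]


def all_gather(shards):
    """Simulated all-gather: every rank ends up with all shards concatenated in rank order."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def zero1_sgd_step(params, grads_per_rank, lr):
    """One ZeRO-1 step with plain SGD. Assumes len(params) divides evenly by the rank count."""
    n = len(grads_per_rank)
    size = len(params) // n
    grads = all_reduce_mean(grads_per_rank)  # every rank now holds all averaged gradients
    shards = []
    for rank in range(n):
        lo, hi = rank * size, (rank + 1) * size
        # Each rank updates only its own slice; with Adam, its optimizer state would cover just this slice
        shards.append([p - lr * g for p, g in zip(params[lo:hi], grads[rank][lo:hi])])
    return all_gather(shards)  # every rank receives the full updated model


params = [1.0, 2.0, 3.0, 4.0]
grads_per_rank = [[0.5, 0.5, 0.5, 0.5], [1.5, 1.5, 1.5, 1.5]]  # 2 ranks, different micro-batches
updated = zero1_sgd_step(params, grads_per_rank, lr=0.1)
# DDP would apply the averaged gradient (1.0 everywhere) to the full model on every rank
ddp = [p - 0.1 * 1.0 for p in params]
assert all(rank_params == ddp for rank_params in updated)
```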

Here’s a simplified implementation:
```python
import torch
import torch.distributed as dist


class ZeRO_1:
    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_shards = list()  # each rank holds only its shard of the optimizer states
        self.param_metadata = list()  # metadata to reconstruct shards

        for param in self.model.parameters():
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            # Pad so every parameter splits into world_size equal shards
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard.requires_grad_(True)
            self.param_shards.append(shard)

        # The optimizer only sees this rank's shards, so its states are 1/N-sized
        self.optimizer = optimizer_cls(self.param_shards)

    def training_step(self, inputs, targets, loss_fn):
        output = self.model(inputs)  # forward
        loss = loss_fn(output, targets)  # compute loss
        loss.backward()  # backward
        self._sync_gradients()  # all-reduce gradients across GPUs
        self.optimizer.step()  # update local shard of parameters
        self._sync_params()  # all-gather model params

        # clear gradients for the next step
        for param in self.model.parameters():
            param.grad = None

    def _sync_gradients(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size
            # Pad the flat gradient like the parameter, otherwise the last
            # rank's slice would be shorter than its (padded) shard
            flat_grad = param.grad.view(-1)
            pad_size = meta["padded_numel"] - meta["numel"]
            if pad_size > 0:
                flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
            self.param_shards[idx].grad = flat_grad[meta["shard_start"]:meta["shard_end"]]

    def _sync_params(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )
            # Drop the padding and restore the original shape
            reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
            param.data.copy_(reconstructed)
```
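The padding arithmetic is the fiddliest part of the class, so here it is pulled out into a plain-Python sketch, with lists standing in for tensors. `shard_bounds` is a hypothetical helper that mirrors the computation in `__init__`. It shows how a parameter whose size doesn’t divide evenly gets split, and how concatenating the shards and trimming the padding, as `_sync_params` does, recovers the original values. The last rank’s shard ends in padding zeros, so any gradient slice handed to the optimizer has to be padded the same way to match the shard’s size.

```python
def shard_bounds(numel: int, world_size: int, rank: int):
    """Return (padded_numel, shard_start, shard_end) for one rank, as in ZeRO_1.__init__."""
    pad_size = (world_size - numel % world_size) % world_size
    padded_numel = numel + pad_size
    shard_size = padded_numel // world_size
    start = rank * shard_size
    return padded_numel, start, start + shard_size


# A 10-element parameter split across 4 ranks is padded to 12 (3 values per rank).
flat = list(range(10))
world_size = 4
padded_numel, _, _ = shard_bounds(len(flat), world_size, 0)
padded = flat + [0] * (padded_numel - len(flat))
shards = []
for rank in range(world_size):
    _, start, end = shard_bounds(len(flat), world_size, rank)
    shards.append(padded[start:end])
print(shards)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 0]]

# What the all-gather in _sync_params amounts to: concatenate every shard, then drop the padding.
gathered = [x for shard in shards for x in shard]
assert gathered[:len(flat)] == flat
```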
Notice that the all-reduce syncs all gradients, but each GPU only uses the gradients for its own parameter partition, so it’s overcommunicating. ZeRO-2 fixes this by sharding the gradients too.
In practice, you’d rarely use ZeRO-1, since ZeRO-2 gives better memory savings at essentially the same communication cost. It’s still worth walking through for learning purposes.
Memory with ZeRO-1, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB (fully replicated)
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 63 GB (down from 112 GB)
ZeRO-2: Gradient Partitioning
ZeRO-2 partitions both optimizer states and gradients. Since each GPU only updates a partition of parameters, it only needs the corresponding gradients.
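ZeRO-1 uses all-reduce, which gives every GPU all the gradients. ZeRO-2 replaces it with reduce-scatter: each GPU receives only the summed gradients for its own partition. This saves memory, and it also cuts communication: with ring-based collectives, a reduce-scatter moves about half as much data per GPU as an all-reduce.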
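A single-process simulation of the difference, with lists standing in for tensors. `reduce_scatter_mean` and `ring_traffic_per_rank` are hypothetical helpers; real code would call `torch.distributed.reduce_scatter_tensor`. Each rank ends up with exactly the shard it would have sliced out of a full all-reduce. The traffic helper uses the common ring-algorithm cost model, not a measurement.

```python
def reduce_scatter_mean(per_rank):
    """Simulated reduce-scatter (averaged): rank r receives the mean of chunk r across all ranks.

    Assumes every rank's vector has the same length, divisible by the number of ranks.
    """
    n = len(per_rank)
    size = len(per_rank[0]) // n
    mean = [sum(vals) / n for vals in zip(*per_rank)]
    return [mean[r * size:(r + 1) * size] for r in range(n)]


grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]  # full gradients on 2 ranks
print(reduce_scatter_mean(grads))  # [[2.0, 3.0], [4.0, 5.0]]


def ring_traffic_per_rank(num_values: int, n: int):
    """Values each rank sends with ring collectives: (reduce-scatter, all-reduce)."""
    reduce_scatter = (n - 1) / n * num_values
    all_reduce = 2 * reduce_scatter  # a ring all-reduce is a reduce-scatter followed by an all-gather
    return reduce_scatter, all_reduce


print(ring_traffic_per_rank(7_000_000_000, 8))  # (6125000000.0, 12250000000.0)
```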
Training steps:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- Reduce-scatter gradients: each GPU gets only its partition
- Optimizer step: each GPU updates its parameter partition
- All-gather parameters: sync the updated model across GPUs

The implementation is very similar to ZeRO-1, but the gradient synchronization step uses reduce-scatter instead of all-reduce:
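Here is a minimal sketch of the part that changes, written as a standalone helper rather than a full ZeRO_2 class. `reduce_scatter_grad` is a hypothetical name, and the sketch assumes the same `param_metadata` dictionaries that ZeRO_1 builds. `torch.distributed.reduce_scatter_tensor` requires each rank’s input to be exactly `world_size` times the output size, which is what the padding guarantees.

```python
import torch
import torch.distributed as dist


def reduce_scatter_grad(param: torch.Tensor, meta: dict, world_size: int) -> torch.Tensor:
    """Return this rank's averaged gradient shard for `param`.

    `meta` is assumed to hold the "numel", "padded_numel" and "shard_size"
    entries that ZeRO_1 stores in param_metadata.
    """
    flat_grad = param.grad.reshape(-1)
    pad = meta["padded_numel"] - meta["numel"]
    if pad > 0:
        # Pad so the flat gradient splits into world_size equal chunks
        flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad)])
    grad_shard = flat_grad.new_empty(meta["shard_size"])
    if world_size > 1:
        # Sums chunk r of every rank's gradient and delivers the result to rank r only
        dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)
    else:
        grad_shard.copy_(flat_grad)  # single process: nothing to reduce
    return grad_shard / world_size
```

A ZeRO_2 subclass of ZeRO_1 could then override `_sync_gradients` to loop over `self.model.parameters()` and set `self.param_shards[idx].grad = reduce_scatter_grad(param, self.param_metadata[idx], self.world_size)`, leaving `training_step` and `_sync_params` unchanged.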
But wait, if every GPU computes all gradients during backprop, how does this actually save VRAM? Here’s how:
- As the parameter gradients are computed layer by layer, they’re immediately reduce-scattered and the local full copy is freed (our simplified implementation doesn’t do this).
- During backprop, computing a layer’s parameter gradients only requires the gradient flowing back from the next layer (plus that layer’s saved activations), not the gradients of every other layer.
- That way you can free gradient memory as you move backwards, keeping only each GPU’s assigned partition.
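A toy model of that argument. `peak_grad_memory` is a hypothetical helper that counts gradient values, not bytes, and it assumes backprop materializes one layer’s full gradient at a time. With streaming reduce-scatter, a rank’s peak is one layer’s full gradient plus the shards it has already kept, rather than every layer’s full gradient.

```python
def peak_grad_memory(layer_sizes, num_gpus, stream=True):
    """Peak number of gradient values one rank holds during backprop.

    stream=True: each layer's full gradient is reduce-scattered as soon as it is
    computed, then freed except for this rank's 1/num_gpus shard (ZeRO-2 as described).
    stream=False: every full gradient is kept until the end (as in DDP).
    """
    held, peak = 0.0, 0.0
    for size in reversed(layer_sizes):  # backprop visits layers last-to-first
        peak = max(peak, held + size)    # this layer's full gradient is materialized
        held += size / num_gpus if stream else size
    return peak


layers = [100, 100, 100, 100]
print(peak_grad_memory(layers, 4, stream=False))  # 400.0
print(peak_grad_memory(layers, 4, stream=True))   # 175.0
```

The streaming peak (175) is one layer’s full gradient (100) plus three kept shards (3 × 25); after the last layer, each rank retains only 100 values, a quarter of the total.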
Memory with ZeRO-2, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 38.5 GB (down from 112 GB)
ZeRO-3: Parameter Partitioning
ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the entire model state.
During forward and backward passes, each layer needs its full parameters, but each GPU only stores a fraction. So we all-gather parameters just-in-time, use them, then discard immediately after.
Training steps:
- Forward pass (per layer):
  - All-gather the layer’s parameters from all GPUs
  - Run the layer’s forward pass using the previous layer’s activations as input
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Backward pass (per layer, in reverse):
  - All-gather the layer’s parameters again
  - Compute gradients for the current layer using the activation gradients from the next layer
  - Reduce-scatter the gradients (each GPU keeps its shard)
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Optimizer step: each GPU updates its own partition
- No final all-gather is needed, since parameters are gathered layer by layer during the next forward pass
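To make the just-in-time gathering concrete, here is a single-process simulation with plain Python lists. `shard` and `zero3_forward` are hypothetical helpers. One process holds every rank’s shard so it can play the all-gather itself, and each toy “layer” just scales its input element-wise by its weights. The output matches an unsharded forward pass, yet at most one layer’s full weights exist at any time.

```python
def shard(values, n):
    """Split a list into n equal shards, zero-padding the tail (same scheme as the classes above)."""
    size = -(-len(values) // n)  # ceiling division
    padded = values + [0.0] * (size * n - len(values))
    return [padded[r * size:(r + 1) * size] for r in range(n)]


def zero3_forward(x, sharded_layers, numels):
    """Forward pass over layers stored only as shards.

    Returns the output and the largest number of full (gathered) weights held at once.
    """
    peak_full = 0
    for shards, numel in zip(sharded_layers, numels):
        full = [w for s in shards for w in s][:numel]  # all-gather, then drop padding
        peak_full = max(peak_full, len(full))
        x = [xi * wi for xi, wi in zip(x, full)]  # run the layer
        del full  # discard the gathered weights; only the shards remain
    return x, peak_full


weights = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [2.0, 1.0, 4.0]]  # 3 layers, 3 weights each
sharded = [shard(w, 2) for w in weights]  # 2 simulated GPUs
out, peak = zero3_forward([1.0, 1.0, 1.0], sharded, [len(w) for w in weights])
print(out, peak)  # [1.0, 1.0, 6.0] 3
```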

Here’s a simplified implementation:
```python
import torch
import torch.distributed as dist


class ZeRO_3:
    """
    ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).
    At rest, each rank holds only param_shards[idx]: a 1/world_size slice
    of each parameter. Full parameters are materialised temporarily during
    the forward and backward passes via all_gather, then immediately freed.
    """

    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_metadata = []
        shard_list = []
        self._param_to_idx = {}

        for idx, param in enumerate(self.model.parameters()):
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size

            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )

            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat

            shard = flat_padded[shard_start:shard_end].clone()
            shard_list.append(shard)

            # Replace the full tensor with only this rank's shard.
            # The model's param.data now points to a tiny slice; the full
            # weight will be reconstructed on demand during forward/backward.
            param.data = shard.detach()
            self._param_to_idx[param] = idx

        self.param_shards = [s.requires_grad_(True) for s in shard_list]
        self.optimizer = optimizer_cls(self.param_shards)
        self._register_hooks()

    def _gather_param(self, idx, device, dtype):
        """All-gather the full parameter tensor for parameter `idx`."""
        meta = self.param_metadata[idx]
        full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
        dist.all_gather_into_tensor(
            output_tensor=full_flat,
            input_tensor=self.param_shards[idx].data,
        )
        return full_flat[: meta["numel"]].view(meta["original_shape"])

    def _gather_module_params(self, module):
        """Gather full params for every parameter that belongs to this module only (not children)."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self._gather_param(idx, param.device, param.dtype)

    def _reshard_param(self, param):
        """Point param.data back at this rank's shard, freeing the gathered copy."""
        param.data = self.param_shards[self._param_to_idx[param]].data

    def _reshard_module_params(self, module):
        """Reshard params back to local shard for every direct param of this module."""
        for param in module.parameters(recurse=False):
            self._reshard_param(param)

    def _register_hooks(self):
        self._hooks = []
        for module in self.model.modules():
            # Skip container modules that have no direct parameters
            if not list(module.parameters(recurse=False)):
                continue
            # Forward: gather -> run -> reshard
            h1 = module.register_forward_pre_hook(
                lambda mod, _inputs: self._gather_module_params(mod)
            )
            h2 = module.register_forward_hook(
                lambda mod, _inputs, _output: self._reshard_module_params(mod)
            )
            # Backward: gather before this module's gradient computation
            h3 = module.register_full_backward_pre_hook(
                lambda mod, _grad_output: self._gather_module_params(mod)
            )
            self._hooks.extend([h1, h2, h3])
        # Backward: reshard each parameter only once its gradient has been accumulated.
        # A module-level full backward hook can fire before that happens (e.g. for
        # the first layer, whose inputs don't require grad).
        for param in self.model.parameters():
            self._hooks.append(param.register_post_accumulate_grad_hook(self._reshard_param))

    def _sync_gradients(self):
        """Reduce-scatter, as in ZeRO-2: each rank keeps only its averaged gradient shard."""
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            flat_grad = param.grad.reshape(-1)
            pad_size = meta["padded_numel"] - meta["numel"]
            if pad_size > 0:
                flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
            grad_shard = flat_grad.new_empty(meta["shard_size"])
            dist.reduce_scatter_tensor(grad_shard, flat_grad, op=dist.ReduceOp.SUM)
            self.param_shards[idx].grad = grad_shard / self.world_size

    def training_step(self, inputs, targets, loss_fn):
        # Hooks handle all gather/reshard around each module automatically
        output = self.model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        self._sync_gradients()
        # Each rank updates only its local shard
        self.optimizer.step()
        for param in self.model.parameters():
            param.grad = None
```
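Each layer’s parameters are gathered right before they’re needed and freed immediately after. This keeps peak parameter memory close to the local shards plus one layer’s full weights, at the cost of more communication: every parameter is all-gathered twice per step, once in the forward pass and once in the backward pass. In practice, implementations overlap the all-gather for layer N+1 with the forward of layer N to hide latency.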
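A back-of-the-envelope model of what that overlap buys. `forward_time` is a hypothetical helper: it assumes every layer takes the same time to gather and to compute, with one all-gather prefetched ahead. Real timings depend on the interconnect, layer sizes, and how many gathers the framework keeps in flight.

```python
def forward_time(num_layers, gather_time, compute_time, prefetch=True):
    """Toy timing model for a ZeRO-3 forward pass with uniform layers."""
    if not prefetch:
        # Each layer waits for its own all-gather before computing
        return num_layers * (gather_time + compute_time)
    # The first gather is exposed; afterwards, gathering layer i+1 runs while layer i
    # computes, so each step costs the slower of the two; the last compute has nothing to overlap
    return gather_time + (num_layers - 1) * max(gather_time, compute_time) + compute_time


print(forward_time(32, gather_time=1.0, compute_time=1.0, prefetch=False))  # 64.0
print(forward_time(32, gather_time=1.0, compute_time=1.0, prefetch=True))   # 33.0
```

When `gather_time` is at most `compute_time`, communication is almost entirely hidden; when it’s larger, the pass becomes communication-bound no matter how the work is scheduled.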
Memory with ZeRO-3, 7B model, 8 GPUs:
- Parameters: 28 GB / 8 = 3.5 GB
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 14 GB (down from 112 GB)
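That’s an 8x reduction in model-state memory (112 GB down to 14 GB), exactly what we’d expect from partitioning everything across 8 GPUs. Activations and the temporarily gathered layer parameters come on top of this.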
Using ZeRO in PyTorch
PyTorch ships with two implementations of ZeRO-3: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). For new code, use FSDP2.
FSDP (Fully Sharded Data Parallel) handles parameter gathering, gradient scattering, communication overlap, and memory management automatically:
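```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # older releases: torch.distributed._composable.fsdp

dist.init_process_group("nccl")  # assumes a torchrun launch on GPU nodes
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
model = Transformer()  # your model; assumed to expose its blocks as model.layers
for layer in model.layers:
    fully_shard(layer)  # each block is all-gathered and freed as one unit
fully_shard(model)  # root call shards the remaining params (embeddings, output head, ...)
```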
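Apply fully_shard to each transformer block first, then to the whole model. Each call defines a unit whose parameters are all-gathered and freed together, so wrapping only the root would gather the entire model at once and give up most of the memory savings.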
Conclusion
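ZeRO trades extra communication for lower memory, so it’s not a free lunch: ZeRO-3 in particular all-gathers every parameter in both the forward and backward passes. For smaller models (e.g., BERT) that already fit comfortably on each GPU, it’s generally not worth it, but for larger models it’s a game changer.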
Congratulations on making it to the end! In this post, you learned about:
- The memory redundancy problem in standard DDP
- How ZeRO partitions optimizer states, gradients, and parameters across GPUs
- The three stages of ZeRO and their memory/communication trade-offs
- How to use ZeRO-3 via PyTorch’s FSDP
In the next article, we’ll explore Tensor Parallelism, a model parallelism technique that splits the computation of individual layers across GPUs.