It’s 2023. Is PyTorch’s FSDP the best choice for training large models?

OpenMMLab
14 min read · Aug 1, 2023

The wave of interest sparked by ChatGPT has made many eager to try their hand at training large models. When looking for training baselines, you’ve surely noticed that codebases for large model training tend to use frameworks like DeepSpeed (MMEngine v0.8.0 also supports it, allowing one-click switching for convenience!) or ColossalAI (MMEngine will support it in the next version!), while paying scant attention to PyTorch’s native FSDP (FullyShardedDataParallel). Why is that? Is FSDP not memory-efficient enough? Is it too slow for training? Or is it simply inconvenient to use? Read on, and I’m sure you’ll gain some insights.

Background of FSDP

FSDP’s implementation was inspired by FairScale. When developing a large feature, PyTorch typically creates a new library to provide experimental support and collect user feedback, such as FairScale, Dynamo (the cornerstone of PyTorch 2.0), and torchdistx. Once the feature matures, it may be incorporated into PyTorch itself. Compared to the brief introduction of FSDP in PyTorch’s official tutorial, FairScale’s documentation does a much better job. Before diving in, it is worth pondering the question FairScale raises: do you really need FSDP? (The same question applies to other large-scale training frameworks.)

Introduction to the ZeRO Series

As the FairScale documentation shows, FairScale positions FSDP as ZeRO3. Considering that some readers may not be familiar with the ZeRO series of large model optimization strategies, let me give a brief introduction:

During model training, memory usage can be roughly divided into four parts: activation values, model weights, gradients, and optimizer states. For vision models, activations take up most of the memory, so mixed-precision (fp16) training can significantly reduce memory usage. For large language models or multimodal models, however, optimizing the memory usage of the latter three becomes more important.

Taking PyTorch as an example, when you use DistributedDataParallel, each process allocates memory for the model parameters, gradients, and optimizer states, and this data is updated synchronously during training. Although this approach speeds up training through data parallelism, its memory allocation strategy is evidently wasteful: since the parameters in every process are identical, why should each process keep a complete copy? ZeRO therefore advocates that each process keep only a shard of this data, gathering the full data across processes only when needed. ZeRO has three stages of optimization:

ZeRO1: Sharding only the optimizer state

ZeRO2: Sharding both the optimizer state and gradients

ZeRO3: Sharding optimizer state, gradients, and model parameters

Taking a model with 7.5B parameters (φ = 7.5B) as an example, let’s briefly calculate the memory used by model parameters, gradients, and optimizer states:

fp32 training:

The model parameters take φ, the gradients take another φ, and with Adam the optimizer state (momentum and variance) takes 2φ. For standard fp32 training, the memory used is therefore (1 + 1 + 2)φ × 4 bytes = 16φ bytes (each fp32 value occupies 4 bytes).

fp16 training:

If mixed-precision training is enabled, the optimizer state must stay in fp32 to preserve the precision of parameter updates, and an additional fp32 copy of the model parameters has to be kept. Memory usage is therefore 2φ (fp16 model parameters) + 2φ (fp16 gradients) + 8φ (fp32 optimizer state) + 4φ (fp32 copy of the model parameters, stored inside the optimizer in the DeepSpeed implementation) = 16φ bytes.

From this perspective, it’s clear why the memory usage of a 7.5B-parameter model can be as high as 120 GB, and why the ZeRO series is so effective.
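As a rough back-of-the-envelope check, here is a small sketch based on the per-GPU memory formulas from the ZeRO paper (the 64-GPU world size is purely an illustrative assumption):

phi = 7.5e9   # number of model parameters
N = 64        # data-parallel world size (illustrative assumption)

# Mixed-precision Adam training costs 16 * phi bytes in total (see above);
# each ZeRO stage shards one more piece of that across the N processes.
ddp = 16 * phi                          # everything replicated on every GPU
zero1 = (2 + 2) * phi + 12 * phi / N    # optimizer state sharded
zero2 = 2 * phi + (2 + 12) * phi / N    # gradients sharded as well
zero3 = 16 * phi / N                    # parameters sharded as well

for name, nbytes in [('DDP', ddp), ('ZeRO1', zero1), ('ZeRO2', zero2), ('ZeRO3', zero3)]:
    print(f'{name}: {nbytes / 1e9:.1f} GB per GPU')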

FSDP — ZeRO3?

Returning to the main topic: FairScale says that FSDP is equivalent to ZeRO3’s optimization. Let’s verify this with a simple example (the optimizer here is SGD, because PyTorch’s Adam has been heavily optimized and its actual memory usage is noticeably higher than the theoretical value). Before the FSDP test itself, let’s look at single-device fp32 training, single-device fp16 training, and DDP fp16 training:

Single Device fp32 / fp16 Training

import torch
import torch.nn as nn
from torch import autocast
from torch.cuda import max_memory_allocated
from torch.optim import SGD


class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            *(nn.Linear(10000, 10000) for _ in range(10))
        )

    def forward(self, x):
        return self.linear(x)


def test_fp32():
    model = Layer().cuda()
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
    data = torch.ones(10000).cuda()
    for _ in range(10):
        optimizer.zero_grad()
        output = model(data)
        loss = output.sum()
        loss.backward()
        optimizer.step()
    memory = max_memory_allocated()
    print(f'step memory allocate: {memory / 1e9:.3f}G')

def test_fp16():
    torch.cuda.init()
    model = Layer().cuda()
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
    data = torch.ones(10000).cuda()
    for _ in range(10):
        with autocast(device_type='cuda'):
            optimizer.zero_grad()
            output = model(data)
            loss = output.sum()
            loss.backward()
            optimizer.step()
    memory = max_memory_allocated()
    print(f'memory allocated: {memory / 1e9:.3f}G')

After running the code, we find that the memory usage is as follows:

fp32: 12.035G

fp16: 14.035G

What? Does amp use an additional 2 G of memory? Where does that come from? It comes down to how amp is implemented. PyTorch’s amp does not change the type of the model weights, so they are still stored in fp32; instead, for whitelisted operators it casts the fp32 weights to fp16 around the forward/backward to compute fp16 activations and gradients. The fp16 gradients are then converted back to fp32 to preserve the precision of parameter updates. But if both the weights and gradients remain in fp32 and the optimizer state is unchanged, where does the extra 2 G come from? The reason is that the fp16 casts of the weights produced during forward and backward are cached, which is implemented in amp’s C++ code. These cached fp16 weights are the source of the extra 2 G.

To avoid caching these fp16 weights, you need to pass cache_enabled=False to autocast.

def test_fp16():
    torch.cuda.init()
    model = Layer().cuda()
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
    data = torch.ones(10000).cuda()
    for _ in range(10):
        with autocast(device_type='cuda', cache_enabled=False):
            optimizer.zero_grad()
            output = model(data)
            loss = output.sum()
            loss.backward()
            optimizer.step()
    memory = max_memory_allocated()
    print(f'memory allocated: {memory / 1e9:.3f}G')

As a result, the memory consumption is 12.235G, which is basically consistent with fp32 and meets expectations.

DDP Training

DDP just creates and updates the model in each process, so memory usage should still be around 12G, right?

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def _test_ddp_fp16():
    rank = dist.get_rank()
    model = DistributedDataParallel(Layer().cuda())
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
    data = torch.ones(10000).cuda()
    for _ in range(10):
        with autocast(device_type='cuda', cache_enabled=False):
            optimizer.zero_grad()
            output = model(data)
            loss = output.sum()
            loss.backward()
            optimizer.step()
    memory = max_memory_allocated()
    if rank == 0:
        print(f'memory allocated: {memory / 1e9:.3f}G')

However, the result is:

16.036G

The reason is simple: DDP synchronizes gradients through communication buckets, and the buckets hold an extra copy of the gradients, which for this roughly 1B-parameter model costs about 4 G of additional memory.
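As a side note, DDP provides a gradient_as_bucket_view argument that lets the .grad tensors be views into these communication buckets, which should avoid most of this duplication (a hedged sketch; check whether it suits your training setup):

# Gradients become views into DDP's communication buckets instead of
# separate tensors, removing the extra gradient copy described above.
model = DistributedDataParallel(Layer().cuda(), gradient_as_bucket_view=True)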

FSDP Training

When using FSDP, we need to configure the auto_wrap_policy parameter to choose the model sharding strategy; otherwise, the memory savings only reach roughly the level of ZeRO-stage1. How auto_wrap_policy works and the principle behind it will be explained in detail in the following sections.

from functools import partial

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import _module_wrap_policy


def _test_fsdp_fp16():
    rank = dist.get_rank()
    fsdp_model = FullyShardedDataParallel(
        module=Layer(), device_id=rank,
        auto_wrap_policy=partial(
            _module_wrap_policy,
            module_classes=nn.Linear))
    optimizer = SGD(fsdp_model.parameters(), lr=0.1, momentum=0.9)
    data = torch.ones(10000).cuda()
    for _ in range(10):
        optimizer.zero_grad()
        output = fsdp_model(data)
        loss = output.sum()
        loss.backward()
        optimizer.step()
    memory = max_memory_allocated()
    if rank == 0:
        print(f'step memory allocate: {memory / 1e9:.3f}G')
    torch.cuda.reset_max_memory_allocated()

The result is 1.524 G, which is basically equivalent to the memory optimization effect of ZeRO3: roughly speaking, the 12 G of fp32 parameters, gradients, and momentum is now sharded across the participating ranks, with only the currently gathered submodule’s parameters held in full at any moment.

The point of analysing memory usage step by step like this is to help you form rational expectations about the memory savings when switching from DDP to FSDP.

FSDP Sharding Strategy

In the previous section, we mentioned that we need to specify the model sharding strategy through auto_wrap_policy. So how does this parameter work? And why does the optimization only reach ZeRO-stage1 level when this parameter is not configured?

Similar to DistributedDataParallel, FSDP implements the parameter sharding logic through a model wrapper: FullyShardedDataParallel. The wrapped module becomes the root fsdp module, and at construction time the root fsdp module recursively wraps submodules into child fsdp modules according to the user-defined auto_wrap_policy:

Take the officially implemented _module_wrap_policy as an example, where the key parameter module_classes is used to indicate which type of submodule should be wrapped into a child fsdp module.

def _module_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    module_classes: Set[Type[nn.Module]],
) -> bool:
    """
    This auto wrap policy wraps every module that is an instance of any type in
    ``module_classes`` as its own FSDP instance. The root module given by
    ``module`` is always wrapped as an FSDP instance regardless. Since the
    wrapping proceeds bottom up, each FSDP instance manages the parameters in
    its subtree excluding any already managed by a child FSDP instance.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        module_classes (Set[Type[nn.Module]]): Set of module classes that are
            wrapped as FSDP instances.

    Returns:
        ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
        if ``recurse=False``.
    """
    if recurse:
        return True  # always recurse
    if inspect.isclass(module_classes):
        module_classes = (module_classes, )
    return isinstance(module, tuple(module_classes))

In the previous section, we specified it as nn.Linear, which means each nn.Linear will be wrapped into a child fsdp module.
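A quick way to see the resulting wrapping is to iterate over the wrapped model’s submodules (a small sketch that reuses the fsdp_model built above):

# The root module itself plus one child fsdp module per nn.Linear should
# show up as FullyShardedDataParallel instances.
for name, submodule in fsdp_model.named_modules():
    if isinstance(submodule, FullyShardedDataParallel):
        print(name)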

Every fsdp module triggers parameter unsharding (all-gather) and resharding during the forward process:

  1. The forward of the root fsdp module: in the pre-forward stage it gathers the parameters from the different processes and registers the pre-backward-hooks and post-backward-hooks; in the post-forward stage it releases the parameters that do not belong to the current rank. The pre-backward-hook gathers the parameters again before backward runs, and the post-backward-hook performs the gradient reduce-scatter, i.e. gradient synchronization plus gradient sharding.

It should be noted that an fsdp module’s forward does not additionally gather the parameters of its child fsdp modules.

Compared with a child fsdp module, the forward of the root fsdp module also does some extra work such as CUDA stream initialization, which is not discussed further here.

  2. The forward of the child fsdp module: its main logic is essentially the same as that of the root fsdp module.
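To make the ordering concrete, here is an illustrative sketch of the control flow described above (this is not the actual PyTorch source; every helper below is a made-up placeholder):

def unshard_params(fsdp_module):
    """Placeholder: all-gather this module's own flat-parameter shards."""

def reshard_params(fsdp_module):
    """Placeholder: free the parameter shards owned by other ranks."""

def register_backward_hooks(fsdp_module):
    """Placeholder: pre-backward hooks re-gather the parameters,
    post-backward hooks reduce-scatter the gradients."""

def fsdp_forward(fsdp_module, *args, **kwargs):
    unshard_params(fsdp_module)            # pre-forward stage
    register_backward_hooks(fsdp_module)
    output = fsdp_module._fsdp_wrapped_module(*args, **kwargs)
    reshard_params(fsdp_module)            # post-forward stage
    return output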

As you can see, at any given time an fsdp module only gathers a portion of the parameters, which is in line with our expectations. So what happens if we don’t set auto_wrap_policy, i.e. there are no child fsdp modules?

In that case, the forward of the root fsdp module directly gathers all of the parameters at once, so the memory savings from parameter sharding in ZeRO-stage3 cannot be achieved. However, the sharding of gradients and optimizer states in ZeRO1 and ZeRO2 still works: the post-backward-hook is still registered during the forward stage, so gradient reduce-scatter still takes place, and since the optimizer is built from the root fsdp module’s (sharded) parameters, it updates the sharded parameters directly and only keeps optimizer state for them.

auto_wrap_policy needs to follow a certain interface specification, that is, accept the following parameters:

  • module: the submodule currently being visited during the recursive traversal
  • recurse: whether the policy is still recursing down the module tree; when True it only needs to decide whether to keep traversing, and when False it must decide whether the current module is wrapped as a child fsdp module
  • nonwrapped_numel: the number of parameters of the current module that have not yet been wrapped, i.e. excluding parameters already managed by a child fsdp module and parameters the user told FSDP to ignore (ignored_params). Based on this argument, a size-based wrap policy can be implemented, such as the officially provided size_based_auto_wrap_policy (see the sketch below).
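For instance, the officially provided size-based policy can be plugged in roughly like this (a sketch that reuses Layer and rank from the earlier examples; the 1e6 threshold is purely illustrative):

from functools import partial

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fsdp_model = FullyShardedDataParallel(
    module=Layer(), device_id=rank,
    # wrap any submodule whose not-yet-wrapped parameter count exceeds the threshold
    auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=int(1e6)))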

Letting users configure auto_wrap_policy does make FSDP more flexible, but it also quietly raises the learning cost: what effect does auto_wrap_policy actually have, and what do its input arguments mean? Users can easily feel puzzled when they first get started with FSDP.

Still, if the cost of using FSDP stopped there, I believe most people would be willing to learn and use it. Unfortunately, some implicit conventions and some strange errors are genuinely discouraging.

The painful lessons learned from experimenting with FSDP

Risks of Replacing Submodules

In the previous section, we mentioned that FSDP replaces submodules with child fsdp modules after wrapping. You might wonder: what happens if the parent module accesses attributes or methods of the original submodule? Will an AttributeError be raised? The answer is no, because FullyShardedDataParallel overrides __getattr__ to forward missing attributes to the wrapped module:

def __getattr__(self, name: str) -> Any:
    """Forward missing attributes to the wrapped module."""
    try:
        return super().__getattr__(name)  # defer to nn.Module's logic
    except AttributeError:
        return getattr(self._fsdp_wrapped_module, name)

In this way, attributes not found on the FSDP wrapper are looked up on the wrapped submodule. However, this still poses risks:

  1. If the attribute you access happens to share a name with an attribute of the FSDP wrapper itself, you may silently get the wrong one.
  2. If you directly access a submodule’s parameter and operate on it: parameters are only gathered during forward and backward, so outside of those phases what you get is a sharded parameter, and the operation will most likely throw an error.
  3. If you happen not to go through the child fsdp module’s __call__ method, for example in this situation:
class Layer(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.processor = nn.Linear(1, 1)
        self.linear1 = nn.Linear(1, 1)
        self.linear2 = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear1(x) + self.linear2(x)


class ToyModel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.linear = nn.Linear(1, 1)
        self.layer = Layer()  # wrapped as a child fsdp module by the auto_wrap_policy

    def forward(self, x):
        y = self.linear(self.layer.processor(x))
        return self.layer(y)

Suppose Layer is wrapped as a child fsdp module and ToyModel.forward calls self.layer.processor directly: an error will be raised, because Layer.forward has not been called and the parameters of processor are still sharded.

Or in this case:

class A:
    ...

    def loss(self, inputs: torch.Tensor, data_samples: List[DataSample]) -> dict:
        feats = self.extract_feat(inputs)
        return self.head.loss(feats, data_samples)


class B:
    ...

    def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], **kwargs) -> dict:
        cls_score = self(feats)  # does not go through FSDP's forward
        losses = self._get_loss(cls_score, data_samples, **kwargs)
        return losses

Here class B is the head submodule of class A, and A calls self.head.loss. If B is wrapped as a child fsdp module, the sharded parameters will not be gathered when self.head.loss is called, and a corresponding error will be raised.
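If you really do need to touch the full parameters outside of forward/backward (risk 2 above), recent FSDP versions provide a summon_full_params context manager that can be used roughly like this (a sketch):

# Temporarily unshard the parameters of fsdp_model on every rank; they are
# resharded again when the context exits.
with FullyShardedDataParallel.summon_full_params(fsdp_model):
    for name, param in fsdp_model.named_parameters():
        print(name, param.shape)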

Optimizer with Multiple Parameter Groups

PyTorch’s optimizer supports setting different learning rates, momentum, and other hyperparameters for different parameters in the model. The setup process looks something like this:

param_groups = []
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        param_groups.append({'params': module.weight, 'lr': 0.01})
        param_groups.append({'params': module.bias, 'lr': 0.1})
    else:
        ...  # other modules are handled similarly

optimizer = SGD(param_groups, lr=0.1)

However, the problem is that prior to PyTorch 2.0, once the root fsdp module and child fsdp modules are built, the original parameters such as bn.weight and bn.bias are deleted, and all of the not-yet-sharded parameters under each fsdp module are merged into one large flattened parameter.

For example, in the earlier example, if no auto_wrap_policy is specified, only the outermost root fsdp module is created; the parameters of all the linear layers are then reconstructed into one large flattened parameter held by the root fsdp module:

rank = dist.get_rank()
fsdp_model = FullyShardedDataParallel(
    module=Layer(), device_id=rank,
    # auto_wrap_policy=partial(
    #     _module_wrap_policy,
    #     module_classes=nn.Linear),
)
print(list(fsdp_model.parameters()))

At this point, each rank will only print out one parameter:

[Parameter containing:
Parameter(FlatParameter([-4.6519e-05, -6.2861e-03, 3.9519e-03, ..., -3.2763e-03,
7.1111e-04, -8.2136e-03], device='cuda:3', requires_grad=True))]

Therefore, before PyTorch 2.0, once FSDP was used it was difficult to set different learning rates for different parameters, because multiple parameters are merged into one after the fsdp wrap. The subsequent gradient sharding and parameter updates are also based on this flattened tensor.

Since parameter updates are based on the flattened tensor, FSDP requires every parameter under the same fsdp module to have the same dtype and requires_grad attribute; otherwise they cannot be concatenated into one large flattened tensor.

PyTorch 2.0 added a use_orig_params argument to FSDP. When it is enabled, FSDP does not delete the original parameters during wrapping; instead, each original parameter is kept as a view into a region of the flattened parameter.

This is a great update: without introducing additional GPU memory consumption, users can once again access the original parameters and set different optimizer hyperparameters for them. In theory, this should also lift the restriction that all parameters under the same fsdp module must have the same requires_grad attribute. Unfortunately, PyTorch 2.0 did not adjust that part of the logic; the issue has been fixed on the main branch, and the upcoming PyTorch 2.1 should resolve this pain point.
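As a sketch of what this enables (assuming PyTorch ≥ 2.0 and reusing Layer, rank, and _module_wrap_policy from the earlier examples; the learning rates are arbitrary):

fsdp_model = FullyShardedDataParallel(
    module=Layer(), device_id=rank,
    use_orig_params=True,   # keep the original parameters as views into the flat param
    auto_wrap_policy=partial(
        _module_wrap_policy, module_classes=nn.Linear))

# The original (named) parameters are visible again, so per-parameter
# optimizer hyperparameters become possible.
param_groups = [
    {'params': param, 'lr': 0.01 if 'bias' in name else 0.1}
    for name, param in fsdp_model.named_parameters()
]
optimizer = SGD(param_groups, lr=0.1)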

Stability of FSDP Interface

Although FSDP was already a beta feature as early as PyTorch 1.11, the FSDP module is still iterating rapidly to this day. In February 2023, the FSDP developers started a discussion introducing some of its design concepts and an internal restructuring.

In addition, FSDP’s external interface changes fairly quickly: open the PyTorch FSDP API documentation and you will find many interfaces marked as deprecated. Overall, though, the new interface is indeed easier to use and more flexible than the old one, and MMEngine’s integration of FSDP is based on it.

Conclusion

  1. FSDP, in terms of memory savings, is indeed equivalent to ZeRO3, but it should be noted that when mixed-precision training (autocast) is enabled, cache_enabled needs to be set to False.
  2. FSDP has a steeper learning curve in terms of ease of use. Users need to understand how FSDP wraps modules, what auto_wrap_policy does, and the associated limitations; without an overall understanding of FSDP, unexpected errors are likely, and since the error messages often bear little relation to the actual cause, debugging can be difficult.
  3. PyTorch 2.0 greatly improved the usability of FSDP with the use_orig_params parameter, but the restriction that requires_grad must be uniform still exists. To solve this, you can wait for PyTorch 2.1 and specify use_orig_params=True. If you need a workaround right now, you have to make some changes to auto_wrap_policy; since that relies on FSDP's internal conventions, it may not be very stable, so I won't go into details here.

In general, FSDP still leaves something to be desired in terms of ease of use, but in exchange it gives users more room to maneuver. As PyTorch continues to iterate, FSDP is expected to become as easy to use as DDP. MMEngine will closely follow FSDP’s updates, aiming to lower the entry barrier while preserving flexibility, and to distill a set of simple, easy-to-configure best practices.

If you’re interested, feel free to encourage more updates. If there’s an opportunity, we can further discuss the design philosophy of FSDP, the construction logic of flatten params, the rules for parameter slicing, and the parallel methods for gradient computation and synchronization in FSDP. Let’s exchange ideas on how to tackle the errors thrown by FSDP (hopefully, with fewer rounds of debugging after PyTorch updates).

What? You also want to get a comprehensive analysis of DeepSpeed, ColossalAI, and FSDP? MMEngine also supports DeepSpeed from version v0.8.0, and we will bring an introduction to DeepSpeed next time. Please pay more attention to MMEngine and give it a star. We believe that in the near future, you will be able to switch freely between FSDP, DeepSpeed, and ColossalAI with just a few lines of code and experience the pros and cons of various training frameworks yourself.
