PyTorch Lightning
nnScaler supports PyTorch Lightning through NnScalerStrategy and NnScalerPrecision. You can use the nnScaler strategy in PyTorch Lightning like this:
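    # The "..." values are placeholders for your own configuration.
    compute_config = ComputeConfig(...)
    policy = ...
    trainer = Trainer(
        ...,
        strategy=NnScalerStrategy(
            compute_config=compute_config, pas_policy=policy, gen_savedir=...,
            ...
        ),
        # precision is one of the supported values, e.g. "bf16-mixed"
        plugins=[NnScalerPrecision(precision, ...)],
        ...
    )
    trainer.fit(...)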
Model
Dummy input
We need a dummy input to trace the forward function. You can specify it in one of two ways:
Add a dummy_forward_args property to your model class. It should be a dictionary of forward inputs.
Add dummy_forward_args_fn, which will be used to convert a sample (loaded from the train dataloader) to forward inputs.
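The two options can be pictured with a stand-in model class. This is a minimal, stdlib-only sketch under stated assumptions: ToyModel is hypothetical and does not subclass LightningModule, plain lists stand in for tensors, forward takes a single argument x, and dummy_forward_args_fn is assumed to be a method that receives one dataloader sample shaped as an (x, y) pair.

```python
class ToyModel:
    """Hypothetical stand-in for a LightningModule; real code would use tensors."""

    def forward(self, x):
        # Toy computation: double every input value.
        return [2 * v for v in x]

    @property
    def dummy_forward_args(self):
        # Option 1: a dict mapping forward() argument names to example inputs.
        return {"x": [0.0, 0.0, 0.0]}

    def dummy_forward_args_fn(self, sample):
        # Option 2 (assumed signature): convert one dataloader sample to forward inputs.
        # The sample is assumed to be an (x, y) pair; the label y is not a forward input.
        x, _y = sample
        return {"x": x}


model = ToyModel()
# Either option yields keyword arguments that can be passed straight to forward.
out_from_property = model.forward(**model.dummy_forward_args)
out_from_sample = model.forward(**model.dummy_forward_args_fn(([1.0, 2.0, 3.0], 1)))
```

Keeping the dictionary keys identical to the parameter names of forward is what allows the inputs to be unpacked as keyword arguments.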
Rewritten members
We rewrite two members of your model:
forward: As explained before, the forward function is replaced with a distributed version.
log: The log function is rewritten to force sync_dist_group to be set properly when sync_dist=True.
We also set all trainable modules to None to reduce memory usage.
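The effect of the log rewrite can be illustrated with a wrapper that fills in sync_dist_group whenever sync_dist=True. This is a hedged sketch of the idea only, not nnScaler's implementation: make_forced_log, fake_log and the "my_group" value are hypothetical names introduced for illustration.

```python
import functools


def make_forced_log(original_log, group):
    """Wrap a log(name, value, **kwargs) callable so that sync_dist_group is forced.

    Hypothetical sketch: 'group' stands for whatever process group the strategy
    wants metrics to be synchronized over.
    """

    @functools.wraps(original_log)
    def log(name, value, **kwargs):
        if kwargs.get("sync_dist", False):
            # Override any user-supplied group so the reduction uses the intended ranks.
            kwargs["sync_dist_group"] = group
        return original_log(name, value, **kwargs)

    return log


calls = []


def fake_log(name, value, **kwargs):
    # Records calls instead of logging, so the injected argument is visible.
    calls.append((name, value, kwargs))


log = make_forced_log(fake_log, group="my_group")
log("train_loss", 0.5, sync_dist=True)  # group is injected
log("lr", 1e-3)                         # left untouched: sync_dist not requested
```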
To make sure the model works with the nnScaler strategy, follow these rules:
All trainable parameters should be used only in the forward function. If a parameter is used outside forward, it must be inside a torch.no_grad() context; otherwise its gradient will be incorrect, because no reduce operations are created outside forward.
Training, validation, and testing should use exactly the same graph.
All functions relying on the trainable modules should be rewritten to go through the forward function, because after conversion all those modules will be None.
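Why the last rule matters can be shown with a stdlib-only simulation. Everything here is hypothetical: Model, encoder and simulate_conversion are illustrative names, and the "conversion" only mimics the documented effect (forward replaced, trainable modules set to None) rather than reproducing what nnScaler does internally.

```python
class Model:
    def __init__(self):
        # Stand-in for a trainable submodule.
        self.encoder = lambda x: x + 1

    def forward(self, x):
        return self.encoder(x)

    def embed_bad(self, x):
        # Breaks the rule: uses the trainable module directly, outside forward.
        return self.encoder(x)

    def embed_good(self, x):
        # Follows the rule: goes through forward.
        return self.forward(x)


def simulate_conversion(model):
    # Mimic the documented effect. The stand-in "distributed forward" keeps its
    # own reference to the computation (the real one owns its own parameters),
    # and the original trainable module attribute is set to None.
    encoder = model.encoder
    model.forward = lambda x: encoder(x)
    model.encoder = None
    return model


model = simulate_conversion(Model())
good = model.embed_good(1)
try:
    model.embed_bad(1)
    bad_failed = False
except TypeError:  # calling None raises TypeError
    bad_failed = True
```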
Strategy
The constructor arguments of NnScalerStrategy combine those of Lightning's Strategy constructor and the nnscaler.parallelize function. Refer to the documentation of Strategy and nnscaler.parallelize for details.
One special argument is state_dict_type, which specifies the format in which the model and optimizer states are saved into the checkpoint:
"sharded": Each rank saves its shard of weights and optimizer states to a file. The checkpoint is a folder with as many files as the local world size.
"deduped": Each rank saves its deduplicated shard of weights and optimizer states to a file, so state replicated across ranks is saved only once. The checkpoint is a folder with as many files as the local world size.
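The difference between the two formats can be illustrated by listing what each rank writes. This is a hedged, stdlib-only sketch: it assumes "deduped" means a piece of state replicated on several ranks is written only by the lowest-numbered rank that holds it, and files_per_rank, the two-rank layout and the key names are made up for illustration.

```python
def files_per_rank(rank_states, state_dict_type):
    """Return {rank: sorted keys written by that rank} for a toy layout.

    rank_states maps rank -> set of state keys held on that rank;
    keys replicated across ranks appear in several sets.
    """
    if state_dict_type == "sharded":
        # Every rank writes everything it holds, replicas included.
        return {rank: sorted(keys) for rank, keys in rank_states.items()}
    if state_dict_type == "deduped":
        seen = set()
        out = {}
        for rank in sorted(rank_states):
            # Write only keys that no lower-numbered rank has written yet.
            own = sorted(set(rank_states[rank]) - seen)
            seen.update(own)
            out[rank] = own
        return out
    raise ValueError(f"unknown state_dict_type: {state_dict_type!r}")


# Two ranks: "embed.weight" is replicated, the "fc" weight is split into shards.
layout = {
    0: {"embed.weight", "fc.weight.shard0"},
    1: {"embed.weight", "fc.weight.shard1"},
}
sharded = files_per_rank(layout, "sharded")
deduped = files_per_rank(layout, "deduped")
```

In both formats each rank still writes one file; deduplication only shrinks the files of ranks whose replicated state is already covered by another rank.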
Precision
NnScalerPrecision takes exactly the same constructor arguments as Lightning's Precision constructor.
Currently, the supported precision values are 32-true, 16-true, bf16-true, 16-mixed, and bf16-mixed.
You can specify a gradient scaler when using 16-true.
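The supported values follow Lightning's naming convention: a "-true" value runs weights and computation in the named dtype, while a "-mixed" value keeps float32 weights and autocasts computation to the lower-precision dtype. The lookup below is a small illustrative helper, assuming nnScaler follows the same convention; SUPPORTED_PRECISIONS and describe_precision are hypothetical names, not an nnScaler API.

```python
# Supported values mapped to (weight dtype, compute dtype) under Lightning's
# naming convention. For "-mixed", the compute dtype is the autocast target.
SUPPORTED_PRECISIONS = {
    "32-true": ("float32", "float32"),
    "16-true": ("float16", "float16"),
    "bf16-true": ("bfloat16", "bfloat16"),
    "16-mixed": ("float32", "float16"),
    "bf16-mixed": ("float32", "bfloat16"),
}


def describe_precision(precision):
    """Return (weight_dtype, compute_dtype), or raise ValueError for unsupported values."""
    try:
        return SUPPORTED_PRECISIONS[precision]
    except KeyError:
        raise ValueError(
            f"unsupported precision {precision!r}; "
            f"choose one of {sorted(SUPPORTED_PRECISIONS)}"
        ) from None
```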
Checkpoint
If this is the first time you run the model and you have a pretrained model, you must load the pretrained weights before passing the model to the Trainer constructor. The tracing process uses the pretrained model weights when tracing the forward function.
Just like with other PyTorch Lightning strategies, you can resume from a checkpoint by passing the ckpt_path argument to Trainer.fit.
Please note that when the parallel plan changes (i.e., you re-trace the model with a different configuration), existing checkpoints become incompatible and can no longer be loaded directly. You must first merge them into a single checkpoint with NnScalerStrategy.merge_checkpoint, and then load the merged checkpoint as a regular checkpoint:
    @classmethod
    def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str) -> None: ...
Here checkpoint_files is the list of checkpoint files to merge, and output_file is the path of the merged output file.
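Because a sharded or deduped checkpoint is a folder of per-rank files, the file list for merge_checkpoint can be built from the folder contents. The helper below is a hedged sketch: collect_checkpoint_files and the paths are hypothetical, and it assumes the folder holds only the per-rank checkpoint files. The merge call itself is left as a comment because it requires nnScaler.

```python
import os
from typing import List


def collect_checkpoint_files(checkpoint_dir: str) -> List[str]:
    """Return the regular files in checkpoint_dir, sorted by name for a stable order."""
    return sorted(
        os.path.join(checkpoint_dir, name)
        for name in os.listdir(checkpoint_dir)
        if os.path.isfile(os.path.join(checkpoint_dir, name))
    )


# Usage (requires nnScaler; the paths are placeholders):
# files = collect_checkpoint_files("path/to/checkpoint_dir")
# NnScalerStrategy.merge_checkpoint(files, "path/to/merged.ckpt")
```

Sorting is lexicographic, so a name like "10.ckpt" comes before "2.ckpt"; the sketch only aims for a deterministic order.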
Limitations
Currently, nnScaler supports only:
a single parameter group
a single optimizer
a single learning rate scheduler