renate.utils.deepspeed module#

renate.utils.deepspeed.ds_checkpoint_dir(checkpoint_dir, tag=None)[source]#
Return type:

str

renate.utils.deepspeed.search_key(state, substring)[source]#

This function searches the keys of a dictionary for a given substring and returns the first key that contains it.

Return type:

str
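
Example (a minimal usage sketch; how the function behaves when no key contains the substring is not documented here and is left open):

```python
import torch

from renate.utils.deepspeed import search_key

state = {
    "model.encoder.weight": torch.zeros(2, 2),
    "model._extra_state": torch.zeros(1),
}

# Returns the first key containing the substring, here "model._extra_state".
key = search_key(state, "_extra_state")
```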

renate.utils.deepspeed.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, tag=None)[source]#

Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed. Additionally, the function has been modified to keep the Lightning state inside the state dict, so that LightningModule.load_from_checkpoint('...') can still be used. Modifications to this version include the explicit handling of the _extra_state element of the state dict, since DeepSpeed's and Lightning's get-fp-32… functions only collate trainable parameters.

Parameters:
  • checkpoint_dir (Union[str, Path]) – Path to the desired checkpoint folder (one that contains the tag folder, e.g., global_step14).

  • tag (Optional[str]) – Checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag is read from the file named latest in the checkpoint folder, e.g., global_step14.

Return type:

Dict[str, Any]
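
Example (a hedged sketch of typical usage; the checkpoint path and the LightningModule subclass named in the comment are hypothetical placeholders, not part of Renate):

```python
import torch

from renate.utils.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# "outputs/checkpoint" is a hypothetical ZeRO checkpoint folder containing, e.g., global_step14.
state_dict = convert_zero_checkpoint_to_fp32_state_dict("outputs/checkpoint")

# The consolidated dictionary keeps the Lightning state, so it can be saved and
# restored without DeepSpeed, e.g.:
#   torch.save(state_dict, "consolidated.ckpt")
#   model = MyLightningModule.load_from_checkpoint("consolidated.ckpt")
torch.save(state_dict, "consolidated.ckpt")
```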

renate.utils.deepspeed.convert_to_tensor(obj)[source]#

This function converts a pickleable object to a torch tensor. It is only intended to aid saving with DeepSpeed.
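
The underlying idea can be sketched roughly as follows (an illustrative approximation of the pickle-to-tensor conversion, not Renate's actual implementation):

```python
import pickle

import torch


def convert_to_tensor_sketch(obj):
    """Serialize a pickleable object and wrap the resulting bytes in a uint8 tensor."""
    data = pickle.dumps(obj)
    return torch.tensor(list(data), dtype=torch.uint8)
```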

renate.utils.deepspeed.recover_object_from_tensor(tensor)[source]#

This function converts a tensor back to a byte stream that is passed through pickle to recover the underlying object. For use with DeepSpeed.
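
Together with convert_to_tensor, this allows arbitrary pickleable objects to survive a round trip through a tensor (assuming, as the docstrings suggest, that the two functions are exact inverses):

```python
from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor

# Round trip: the object is pickled into a tensor for DeepSpeed saving and recovered afterwards.
original = {"epoch": 3, "classes_seen": [0, 1, 4]}
tensor = convert_to_tensor(original)
assert recover_object_from_tensor(tensor) == original
```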