renate.utils.deepspeed module#
- renate.utils.deepspeed.search_key(state, substring)[source]#
Searches the keys of a dictionary for a substring and returns the first key that matches; see the usage sketch below.
- Return type:
str
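A minimal usage sketch, assuming a hypothetical state dict whose keys follow PyTorch's dotted naming convention:

```python
from renate.utils.deepspeed import search_key

# Hypothetical state dict with dotted parameter names.
state = {"model.encoder.weight": 0, "model.decoder.weight": 1}

# Returns the first key containing the substring, here "model.decoder.weight".
key = search_key(state, "decoder")
```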
- renate.utils.deepspeed.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, tag=None)[source]#
Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed. Additionally, the script has been modified to keep the Lightning state inside the state dict so that LightningModule.load_from_checkpoint('...') can still be run. Modifications to this version include the explicit handling of the _extra_state element of the state dict; DeepSpeed's and Lightning's get-fp-32… functions only collate trainable parameters. See the usage sketch after this entry.
- Parameters:
  - checkpoint_dir (Union[str, Path]) – path to the desired checkpoint folder (one that contains the tag folder, like global_step14).
  - tag (Optional[str]) – checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag will be read from the file named latest in the checkpoint folder, e.g., global_step14.
- Return type:
Dict[str, Any]
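A minimal usage sketch, assuming a hypothetical checkpoint directory produced by a DeepSpeed ZeRO training run:

```python
import torch
from renate.utils.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Hypothetical checkpoint folder that contains a tag folder such as global_step14.
checkpoint_dir = "outputs/checkpoint"

# Consolidate the sharded ZeRO checkpoint into a single fp32 state dict;
# the Lightning state and _extra_state entries are preserved.
state_dict = convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir)

# The consolidated dict can be saved and later used without DeepSpeed.
torch.save(state_dict, "consolidated.ckpt")
```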