Source code for renate.utils.deepspeed
# Copyright 2020 The PyTorch Lightning team and Microsoft Corporation. All rights reserved.
# Modifications: Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved
# SPDX-License-Identifier: Apache-2.0
import os
import pickle as pkl
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
from deepspeed.utils.zero_to_fp32 import (
    get_fp32_state_dict_from_zero_checkpoint,
    get_model_state_file,
    get_optim_files,
)

CPU_DEVICE = torch.device("cpu")


def ds_checkpoint_dir(checkpoint_dir: Union[str, Path], tag: Optional[str] = None) -> str:
    """Return the directory holding the DeepSpeed checkpoint files for the given ``tag``.

    If ``tag`` is not provided, it is read from the ``latest`` file in ``checkpoint_dir``.
    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if os.path.isfile(latest_path):
            with open(latest_path) as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    directory = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Directory '{directory}' doesn't exist")
    return directory
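
# For illustration (hypothetical layout): if `checkpoint_dir` contains a `latest` file with the
# text "global_step14", ds_checkpoint_dir(checkpoint_dir) resolves to
# os.path.join(checkpoint_dir, "global_step14").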
# modified
def search_key(state: Dict[str, Any], substring: str) -> str:
    """Return the first key of ``state`` that contains ``substring``."""
    for k in state.keys():
        if substring in k:
            return k
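
# Illustrative sketch (hypothetical key name): for a module state dict containing the key
# "model._extra_state", search_key(state, "extra_state") returns "model._extra_state".
# If no key matches, the function implicitly returns None.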
# Modified script from
# https://github.com/Lightning-AI/lightning/blob/1.8.6/src/pytorch_lightning/utilities/deepspeed.py
# which is modified from
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
def convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir: Union[str, Path], tag: Optional[str] = None
) -> Dict[str, Any]:
    """Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
    that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training
    without DeepSpeed. Additionally, the script has been modified to keep the Lightning state
    inside the state dict so that ``LightningModule.load_from_checkpoint('...')`` can be used.

    Modifications in this version include the explicit handling of the ``_extra_state`` element
    of the state dict; DeepSpeed's and Lightning's ``get_fp32_state_dict_from_zero_checkpoint``
    functions only collate trainable parameters.

    Args:
        checkpoint_dir: Path to the desired checkpoint folder
            (one that contains the tag folder, like ``global_step14``).
        tag: Checkpoint tag used as a unique identifier for the checkpoint. If not provided,
            the tag is read from the file named ``latest`` in the checkpoint folder,
            e.g., ``global_step14``.
    """
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
    # Additional logic to ensure we keep the Lightning state dict as well from rank 0.
    deepspeed_states = [
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    ]
    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(checkpoint_dir, zero_stage)
    client_state = torch.load(model_file, map_location=CPU_DEVICE)
    # Assign extra_state by searching for the key it is stored under.
    extra_key = search_key(client_state["module"], "extra_state")
    extra_state = client_state["module"][extra_key]
    state_dict[extra_key] = extra_state
    # End of modifications.
    client_state = {
        key: value for key, value in client_state.items() if key not in deepspeed_states
    }
    # State dict keys will include a reference to the wrapper _LightningModuleWrapperBase.
    # Delete the `module` prefix before saving.
    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
    client_state["state_dict"] = state_dict
    return client_state
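
# Usage sketch (illustrative only; the checkpoint path and `my_module`, an arbitrary
# torch.nn.Module with matching parameters, are hypothetical):
#     checkpoint = convert_zero_checkpoint_to_fp32_state_dict("outputs/checkpoints/last.ckpt")
#     my_module.load_state_dict(checkpoint["state_dict"])
# The returned dict also retains the Lightning client state, so it can be persisted with
# torch.save(checkpoint, "consolidated.ckpt") and later consumed via
# LightningModule.load_from_checkpoint("consolidated.ckpt").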


def convert_to_tensor(obj):
    """Convert a picklable object to a torch tensor. This is only to aid saving with
    DeepSpeed."""
    return torch.as_tensor(list(pkl.dumps(obj)))


def recover_object_from_tensor(tensor):
    """Convert a tensor back into a byte stream that is passed through pickle to recover the
    underlying object. For usage with DeepSpeed."""
    return pkl.loads(bytes(tensor.tolist()))
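

# A minimal round-trip sketch (illustrative only): any picklable object can be stored as a
# tensor via convert_to_tensor and restored with recover_object_from_tensor.
if __name__ == "__main__":
    original = {"step": 3, "buffer_indices": [0, 1, 2]}  # hypothetical auxiliary state
    assert recover_object_from_tensor(convert_to_tensor(original)) == original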