Source code for renate.utils.distributed_strategies

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Optional

from pytorch_lightning.strategies import Strategy, StrategyRegistry

_SUPPORTED_STRATEGIES = [
    "ddp_find_unused_parameters_false",
    "ddp",
    "deepspeed",
    "deepspeed_stage_1",
    "deepspeed_stage_2",
    "deepspeed_stage_2_offload",
    "deepspeed_stage_3",
    "deepspeed_stage_3_offload",
    "deepspeed_stage_3_offload_nvme",
]
_UNSUPPORTED_STRATEGIES = [
    x for x in StrategyRegistry.available_strategies() if x not in _SUPPORTED_STRATEGIES
]


def create_strategy(devices: int = 1, strategy_name: Optional[str] = None) -> Optional[Strategy]:
    """Returns a distributed training strategy based on the number of devices and the strategy name."""
    devices = devices or 1
    if strategy_name in _UNSUPPORTED_STRATEGIES:
        raise ValueError(
            f"Current strategy: {strategy_name} is unsupported. Choose deepspeed variants or ddp."
        )
    if devices < 0:
        raise ValueError("Number of devices has to be at least 0.")
    elif devices == 1:
        # With a single device, use standard training. Enabled by passing strategy=None
        # to pl.Trainer.
        if strategy_name is not None:
            warnings.warn(f"With devices=1, strategy is ignored. But got {strategy_name}.")
        return None
    elif strategy_name in ["none", "None", None]:
        # Nothing is specified and devices > 1. Fall back to DDP.
        return StrategyRegistry.get("ddp")
    elif "deepspeed" in strategy_name:
        strategy = StrategyRegistry.get(strategy_name)
        # TODO: This should be changed to instantiating DeepSpeed and setting it in the
        # constructor. This works for now because the force-PyTorch-optimizer flag isn't
        # used anywhere by DeepSpeed.
        strategy.config["zero_force_ds_cpu_optimizer"] = False
        return strategy
    else:
        # A supported strategy name was given explicitly; look it up in the registry.
        return StrategyRegistry.get(strategy_name)
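
A minimal usage sketch: the returned strategy can be handed to a PyTorch Lightning Trainer. The devices and strategy_name values below are illustrative, and the pl.Trainer arguments assume a Lightning version compatible with this module.

import pytorch_lightning as pl

from renate.utils.distributed_strategies import create_strategy

# Multi-GPU DeepSpeed ZeRO stage 2 training (illustrative values).
strategy = create_strategy(devices=2, strategy_name="deepspeed_stage_2")
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)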