# Source code for renate.utils.optimizer

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Callable, List

import torch
from torch.nn import Parameter
from torch.optim import Optimizer

import renate.defaults as defaults


def create_partial_optimizer(
    optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER,
    lr: float = defaults.LEARNING_RATE,
    momentum: float = defaults.MOMENTUM,
    weight_decay: float = defaults.WEIGHT_DECAY,
) -> Callable[[List[Parameter]], Optimizer]:
    """Builds an optimizer factory with its hyperparameters already bound.

    The returned object is a ``functools.partial`` wrapping the selected
    ``torch.optim`` class; calling it with the parameters to optimize yields
    the optimizer instance.

    Args:
        optimizer: Name of the optimizer to use. Must be ``"SGD"`` or ``"Adam"``.
        lr: Learning rate bound to the optimizer.
        momentum: Momentum value. Only forwarded when ``optimizer`` is ``"SGD"``.
        weight_decay: Weight decay value bound to the optimizer.

    Returns:
        A partial of ``torch.optim.SGD`` or ``torch.optim.Adam`` with the
        hyperparameters above pre-filled.

    Raises:
        ValueError: If ``optimizer`` is neither ``"SGD"`` nor ``"Adam"``.
    """
    if optimizer == "SGD":
        sgd_kwargs = {"lr": lr, "momentum": momentum, "weight_decay": weight_decay}
        return partial(torch.optim.SGD, **sgd_kwargs)

    if optimizer == "Adam":
        # `momentum` is intentionally not forwarded on this branch.
        adam_kwargs = {"lr": lr, "weight_decay": weight_decay}
        return partial(torch.optim.Adam, **adam_kwargs)

    raise ValueError(f"Unknown optimizer: {optimizer}.")