Source code for renate.utils.config_spaces

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Union

import numpy as np
from syne_tune.config_space import Domain, choice, loguniform, uniform


def _get_range(start, end, step):
    return [np.round(i, 8) for i in np.arange(start, end + 1e-8, step)]


_learner_config_space = {
    "optimizer": "SGD",
    "momentum": choice([0.0, 0.9, 0.99]),
    "weight_decay": loguniform(1e-6, 1e-2),
    "learning_rate": loguniform(0.001, 0.5),
    "batch_size": 64,
    "max_epochs": 50,
}
_replay_config_space = {
    **_learner_config_space,
    **{
        "batch_memory_frac": 0.5,
        "memory_size": 1000,
    },
}
_er_config_space = {
    **_replay_config_space,
    **{"updater": "ER", "alpha": 0.5, "loss_normalization": 0, "loss_weight": 0.5},
}
_der_config_space = {
    **_replay_config_space,
    **{
        "updater": "DER",
        "alpha": uniform(0.0, 1.0),
        "beta": uniform(0.0, 1.0),
        "loss_normalization": 0,
        "loss_weight": 1.0,
    },
}
_super_er_config_space = {
    **_replay_config_space,
    **{
        "updater": "Super-ER",
        "der_alpha": uniform(0.0, 1.0),
        "der_beta": uniform(0.0, 1.0),
        "sp_shrink_factor": choice(_get_range(0.3, 1, 0.1)),
        "sp_sigma": choice([0.0, 1e-5, 1e-4, 1e-3, 1e-2]),
        "cls_alpha": choice(_get_range(0, 1, 0.05)),
        "cls_stable_model_update_weight": 0.999,
        "cls_plastic_model_update_weight": 0.999,
        "cls_stable_model_update_probability": uniform(0.01, 0.49),
        "cls_plastic_model_update_probability": uniform(0.5, 1.0),
        "loss_normalization": 1,
        "ema_memory_update_gamma": uniform(0.95, 1.0),
    },
}
_repeated_distill_config_space = _replay_config_space
_offline_er_config_space = {
    **_replay_config_space,
    **{"loss_weight_new_data": choice([0.25, 0.5, 0.75])},
}


[docs] def config_space(updater: str) -> Dict[str, Union[Domain, str, int, float]]: """Returns the default configuration space for the updater.""" config_spaces = { "ER": _er_config_space, "DER": _der_config_space, "SUPER-ER": _super_er_config_space, "RD": _repeated_distill_config_space, "OfflineER": _offline_er_config_space, } return config_spaces[updater.upper()]