Source code for renate.updaters.learner_components.reinitialization
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import torch
from renate.models.renate_module import RenateModule
from renate.updaters.learner_components.component import Component
from renate.utils.pytorch import reinitialize_model_parameters
[docs]
class ReinitializationComponent(Component):
"""Resets the model using each layer's built-in reinitialization logic.
See also `renate.utils.torch_utils.reinitialize_model_parameters`.
"""
[docs]
def on_train_start(self, model: RenateModule) -> None:
reinitialize_model_parameters(model)
[docs]
class ShrinkAndPerturbReinitializationComponent(Component):
"""Shrinking and Perturbation reinitialization through scaling the
weights and adding random noise.
Ash, J., & Adams, R. P. (2020). On warm-starting neural network training.
Advances in Neural Information Processing Systems, 33, 3884-3894.
Args:
shrink_factor: A scaling coefficient applied to shrink the weights.
sigma: variance of the random Gaussian noise added to the weights.
"""
def __init__(self, shrink_factor: float, sigma: float) -> None:
self._shrink_factor = shrink_factor
self._sigma = sigma
super().__init__()
def _verify_attributes(self) -> None:
"""Verify if attributes have valid values."""
super()._verify_attributes()
assert self._shrink_factor > 0.0, "Shrink factor must be positive."
assert self._sigma >= 0, "Sigma must be non-negative."
[docs]
@torch.no_grad()
def on_train_start(self, model: RenateModule) -> None:
"""Shrink and perturb the model's weights."""
for p in model.parameters():
if self._shrink_factor != 1.0:
p.mul_(self._shrink_factor)
if self._sigma != 0.0:
p.add_(self._sigma * torch.randn(p.size(), device=p.device))