renate.updaters.learner_components.reinitialization module

class renate.updaters.learner_components.reinitialization.ReinitializationComponent(weight=0, sample_new_memory_batch=False)

Bases: Component

Resets the model using each layer’s built-in reinitialization logic.

See also renate.utils.torch_utils.reinitialize_model_parameters.
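To illustrate what "each layer’s built-in reinitialization logic" means, here is a framework-free sketch. It assumes that the built-in logic is a per-layer `reset_parameters()` method (the convention core `torch.nn` layers follow) and that layers without such a method are skipped with a warning rather than failing. `ToyModule`, `ToyLinear`, `ToyCustom` and `reinitialize` are hypothetical stand-ins, not Renate or PyTorch APIs.

```python
import logging
import random
from typing import Iterator, List, Optional

logger = logging.getLogger(__name__)


class ToyModule:
    """Hypothetical stand-in for a framework module (e.g. torch.nn.Module)."""

    def __init__(self, children: Optional[List["ToyModule"]] = None) -> None:
        self.params: List[float] = []  # parameters owned directly by this module
        self.children = children or []

    def modules(self) -> Iterator["ToyModule"]:
        # Depth-first traversal: this module first, then all descendants.
        yield self
        for child in self.children:
            yield from child.modules()


class ToyLinear(ToyModule):
    """Hypothetical layer that ships its own initialization rule."""

    def __init__(self, size: int) -> None:
        super().__init__()
        self.params = [0.0] * size
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Layer-specific rule; uniform in [-0.1, 0.1] is an arbitrary choice here.
        self.params = [random.uniform(-0.1, 0.1) for _ in self.params]


class ToyCustom(ToyModule):
    """Hypothetical layer with parameters but no reset_parameters() method."""

    def __init__(self) -> None:
        super().__init__()
        self.params = [1.0]


def reinitialize(model: ToyModule) -> List[ToyModule]:
    """Resets every parameter-owning module via its own reset_parameters().

    Modules that lack the method are logged, left unchanged, and returned.
    """
    skipped: List[ToyModule] = []
    for module in model.modules():
        if not module.params:
            continue  # pure containers own no parameters of their own
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
        else:
            logger.warning("%s has no reset_parameters(); left unchanged",
                           type(module).__name__)
            skipped.append(module)
    return skipped


if __name__ == "__main__":
    random.seed(0)
    layer = ToyLinear(3)
    model = ToyModule([layer, ToyCustom()])
    before = list(layer.params)
    skipped = reinitialize(model)
    print(before != layer.params, [type(m).__name__ for m in skipped])
```

Delegating to each layer keeps every initialization rule next to the layer that defines it; in this sketch, logging instead of raising lets a model with custom layers still train, at the cost of leaving those layers warm-started.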

on_train_start(model)

Reinitializes the model parameters at the start of training.

Parameters:

model (RenateModule) – The model used for training.

Return type:

None

class renate.updaters.learner_components.reinitialization.ShrinkAndPerturbReinitializationComponent(shrink_factor, sigma)

Bases: Component

Shrink-and-perturb reinitialization: scales the weights by a constant factor and then adds random Gaussian noise.

Ash, J., & Adams, R. P. (2020). On warm-starting neural network training. Advances in Neural Information Processing Systems, 33, 3884-3894.

Parameters:
  • shrink_factor (float) – A scaling coefficient applied to shrink the weights.

  • sigma (float) – Standard deviation of the zero-mean Gaussian noise added to the weights.
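Written out for a single weight tensor θ, with λ = shrink_factor and assuming (as the symbol σ conventionally indicates) that sigma acts as the standard deviation of the noise, the update applied at the start of training is:

```latex
\theta \;\leftarrow\; \lambda\,\theta + \sigma\,\varepsilon,
\qquad \varepsilon \sim \mathcal{N}(0, I)
\quad\Longleftrightarrow\quad
\theta \;\leftarrow\; \lambda\,\theta + \eta,
\qquad \eta \sim \mathcal{N}(0, \sigma^{2} I)
```

A factor λ < 1 pulls the warm-started weights toward zero, and the noise reintroduces some of the randomness of a fresh initialization; Ash & Adams (2020) propose this combination to close the generalization gap between warm-started and freshly initialized training.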

on_train_start(model)

Shrinks and perturbs the model’s weights.

Parameters:

model (RenateModule) – The model used for training.

Return type:

None
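A minimal, framework-free sketch of the shrink-and-perturb step follows. It assumes the weights are given as a dictionary of named flat lists of floats and that sigma is the standard deviation of the noise (Python’s `random.Random.gauss(mu, sigma)` takes a standard deviation). The function name `shrink_and_perturb` is hypothetical and not part of Renate.

```python
import random
from typing import Dict, List


def shrink_and_perturb(
    params: Dict[str, List[float]],
    shrink_factor: float,
    sigma: float,
    rng: random.Random,
) -> Dict[str, List[float]]:
    """Returns a copy with every weight w replaced by shrink_factor * w + noise.

    Assumption: sigma is the standard deviation of the zero-mean Gaussian noise.
    """
    return {
        # Shrink each weight, then add an independent Gaussian sample to it.
        name: [shrink_factor * w + rng.gauss(0.0, sigma) for w in weights]
        for name, weights in params.items()
    }


if __name__ == "__main__":
    weights = {"layer1.weight": [0.8, -1.2, 0.4], "layer1.bias": [0.1]}
    print(shrink_and_perturb(weights, shrink_factor=0.5, sigma=0.01, rng=random.Random(0)))
```

With shrink_factor=1 and sigma=0 the weights pass through unchanged (a plain warm start); lowering shrink_factor toward 0 increasingly replaces the previous weights with noise.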