renate.utils.misc module

renate.utils.misc.int_or_str(x)

Casts the input to an int or a str.

This is used to handle the training precision, which can be an int (16, 32) or a str (bf16).

Return type:

Union[str, int]
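The following is a minimal sketch of this casting behaviour. It assumes the function tries `int()` first and falls back to returning the original string. `int_or_str_sketch` is a hypothetical stand-in, not the library's implementation. The `--precision` flag in the usage lines is likewise illustrative and only shows how such a converter can be passed as an `argparse` `type=` callable.

```python
import argparse
from typing import Union


def int_or_str_sketch(x: str) -> Union[str, int]:
    """Hypothetical stand-in: return int(x) if x parses as an integer, else x unchanged."""
    try:
        return int(x)
    except ValueError:
        # Non-numeric values such as "bf16" are kept as strings.
        return x


# Usage as an argparse converter (the --precision flag is illustrative).
parser = argparse.ArgumentParser()
parser.add_argument("--precision", type=int_or_str_sketch, default=32)
print(parser.parse_args(["--precision", "16"]).precision)    # 16
print(parser.parse_args(["--precision", "bf16"]).precision)  # bf16
```

Returning a single `Union[str, int]` lets one command-line option carry both numeric bit widths and named formats.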

renate.utils.misc.maybe_populate_mask_and_ignore_logits(use_masking, class_mask, classes_in_current_task, logits)

Computes the class mask if required, then determines which logits to ignore.
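The following is a minimal sketch of the masking idea. The docstring does not specify the mask semantics or the return value, so several details here are assumptions:

- `True` in `class_mask` marks a class whose logit is ignored.
- A missing mask is built from `classes_in_current_task`.
- Ignored logits are replaced with `-inf`.
- The function returns `(logits, class_mask)`.

`mask_logits_sketch` is hypothetical. It uses plain nested lists instead of PyTorch tensors so it runs without PyTorch installed.

```python
import math
from typing import List, Optional, Set, Tuple


def mask_logits_sketch(
    use_masking: bool,
    class_mask: Optional[List[bool]],
    classes_in_current_task: Optional[Set[int]],
    logits: List[List[float]],
) -> Tuple[List[List[float]], Optional[List[bool]]]:
    """Hypothetical: hide logits of classes outside the current task (assumed semantics)."""
    if not use_masking:
        # Masking disabled: logits and mask pass through unchanged.
        return logits, class_mask
    num_classes = len(logits[0]) if logits else 0
    if class_mask is None:
        # Build the mask lazily; True marks a class to ignore.
        if classes_in_current_task is None:
            raise ValueError("classes_in_current_task is required to build the mask")
        class_mask = [c not in classes_in_current_task for c in range(num_classes)]
    # Replace ignored logits with -inf so a softmax assigns them zero probability.
    masked = [
        [-math.inf if ignore else value for value, ignore in zip(row, class_mask)]
        for row in logits
    ]
    return masked, class_mask
```

Returning the mask alongside the logits lets a caller compute it once and reuse it for later batches of the same task.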

class renate.utils.misc.AdditionalTrainingMetrics

Bases: Callback

on_train_start()

Called when training begins.

Return type:

None
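A callback can use this hook to record state for later hooks. The sketch below is a hypothetical timer that stores a start timestamp when training begins. It does not claim to describe what `AdditionalTrainingMetrics` records.

A few details of the sketch:

- It uses Lightning's hook signature `on_train_start(self, trainer, pl_module)`. The rendered signature above omits those parameters.
- It falls back to `object` as the base class when Lightning is not installed, so it can run on its own.
- `TrainingTimer` and `elapsed` are illustrative names.

```python
import time
from typing import Optional

try:
    from pytorch_lightning import Callback
except ImportError:
    Callback = object  # fallback so the sketch runs without Lightning installed


class TrainingTimer(Callback):
    """Hypothetical callback: remembers when training started."""

    def __init__(self) -> None:
        super().__init__()
        self._start: Optional[float] = None

    def on_train_start(self, trainer, pl_module) -> None:
        # perf_counter is monotonic, so differences are safe to use as durations.
        self._start = time.perf_counter()

    def elapsed(self) -> float:
        """Seconds since on_train_start was called."""
        if self._start is None:
            raise RuntimeError("on_train_start has not been called yet")
        return time.perf_counter() - self._start
```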

on_train_epoch_end()

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access the outputs via the module, or

  2. Cache data across train batch hooks inside the callback implementation and post-process it in this hook.

Return type:

None
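Option 2 (caching data across batch hooks) can be sketched as follows. `EpochLossCollector` is a hypothetical callback, not part of Renate, and it rests on these assumptions:

- Lightning 2.x hook signatures.
- `training_step` returns either a dict with a scalar `"loss"` or the loss itself.
- The base class falls back to `object` when Lightning is unavailable.

```python
from typing import List

try:
    from pytorch_lightning import Callback
except ImportError:
    Callback = object  # fallback so the sketch runs without Lightning installed


class EpochLossCollector(Callback):
    """Hypothetical callback: caches per-batch losses and averages them per epoch."""

    def __init__(self) -> None:
        super().__init__()
        self._batch_losses: List[float] = []
        self.epoch_means: List[float] = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
        # Assumes outputs is {"loss": scalar} or the scalar loss itself.
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs
        self._batch_losses.append(float(loss))

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Post-process the cached values, then reset the cache for the next epoch.
        if self._batch_losses:
            self.epoch_means.append(sum(self._batch_losses) / len(self._batch_losses))
        self._batch_losses.clear()
```

Clearing the cache at the end of each epoch keeps memory bounded and stops one epoch's values from leaking into the next.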

parameters_count(model)

Return type:

Tuple[int, int]
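The documentation gives only the return type `Tuple[int, int]`, not what the two integers mean. The sketch below assumes they are `(trainable, total)` parameter counts. It relies only on the standard `torch.nn.Module.parameters()`, `Tensor.numel()` and `requires_grad` attributes, accessed through duck typing, so it needs no imports. `count_parameters_sketch` is a hypothetical name, and the ordering of the tuple is an assumption.

```python
from typing import Tuple


def count_parameters_sketch(model) -> Tuple[int, int]:
    """Hypothetical: return (trainable, total) parameter counts (assumed ordering)."""
    trainable, total = 0, 0
    for param in model.parameters():
        n = param.numel()  # number of scalar elements in this parameter tensor
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total
```

With PyTorch installed, `count_parameters_sketch(torch.nn.Linear(4, 2))` would count 8 weights plus 2 biases, all trainable by default.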