Source code for renate.utils.misc
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Dict, Optional, Set, Tuple, Union
from pytorch_lightning import Callback, LightningModule, Trainer
import torch
from renate.utils.pytorch import complementary_indices


def int_or_str(x: str) -> Union[str, int]:
    """Casts a string to an int where possible, otherwise returns it unchanged.

    This is used to handle precision settings, which can be an int (16, 32) or a
    str ("bf16").
    """
    try:
        return int(x)
    except ValueError:
        return x
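

# Illustrative usage (a sketch, not part of the module): ``int_or_str`` targets
# values such as Lightning's ``precision`` flag, which may arrive from the
# command line as "16", "32", or "bf16".
def _example_int_or_str() -> None:
    assert int_or_str("32") == 32  # numeric strings are cast to int
    assert int_or_str("bf16") == "bf16"  # everything else stays a str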


def maybe_populate_mask_and_ignore_logits(
    use_masking: bool,
    class_mask: Optional[torch.Tensor],
    classes_in_current_task: Optional[Set[int]],
    logits: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Masks the logits of classes outside the current task, computing the class
    mask first if it has not been populated yet."""
    if use_masking:
        if class_mask is None:
            # Populate the class mask with the indices of all classes that are
            # not part of the current task.
            class_mask = torch.tensor(
                complementary_indices(logits.size(1), classes_in_current_task),
                device=logits.device,
                dtype=torch.long,
            )
        # Fill the masked logits with -inf so they are ignored downstream.
        logits.index_fill_(1, class_mask.to(logits.device), -float("inf"))
    return logits, class_mask
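

# Illustrative usage (a sketch, not part of the module): mask a batch of logits
# over 10 classes down to the classes {0, 1, 2} of the current task. The batch
# size, number of classes, and class set are assumptions for this example only.
def _example_masking() -> None:
    logits = torch.randn(4, 10)
    class_mask = None  # populated lazily on the first call
    logits, class_mask = maybe_populate_mask_and_ignore_logits(
        use_masking=True,
        class_mask=class_mask,
        classes_in_current_task={0, 1, 2},
        logits=logits,
    )
    # Columns 3..9 are now -inf, so a softmax assigns them zero probability.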


class AdditionalTrainingMetrics(Callback):
    """Callback collecting training time, memory usage, and parameter counts."""

    def __init__(self) -> None:
        self._training_start_time: Optional[float] = None
        self._curr_epoch_end_time: Optional[float] = None

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._training_start_time = time.time()

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._curr_epoch_end_time = time.time()

    def __call__(self, model: torch.nn.Module) -> Dict[str, Union[float, int]]:
        if self._training_start_time is not None and self._curr_epoch_end_time is not None:
            total_training_time = self._curr_epoch_end_time - self._training_start_time
        else:
            total_training_time = 0.0
        # Peak amount of memory allocated during training. This might not be the
        # best measure, but it is the most convenient one.
        peak_memory_usage = (
            torch.cuda.memory_stats()["allocated_bytes.all.peak"]
            if torch.cuda.is_available()
            else 0
        )
        trainable_params, total_params = self.parameters_count(model)
        return dict(
            total_training_time=total_training_time,
            peak_memory_usage=peak_memory_usage,
            trainable_params=trainable_params,
            total_params=total_params,
        )

    def parameters_count(self, model: torch.nn.Module) -> Tuple[int, int]:
        trainable_params, total_params = 0, 0
        for param in model.parameters():
            num_params = param.numel()
            total_params += num_params
            if param.requires_grad:
                trainable_params += num_params
        return trainable_params, total_params
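

# Illustrative usage (a sketch, not part of the module): attach the callback to
# a Lightning Trainer and query the collected metrics after fitting. ``module``
# and ``dataloader`` are placeholders for a LightningModule and a matching
# DataLoader; ``max_epochs=1`` is an arbitrary choice for the example.
def _example_training_metrics(module: LightningModule, dataloader) -> Dict[str, Union[float, int]]:
    metrics_callback = AdditionalTrainingMetrics()
    trainer = Trainer(max_epochs=1, callbacks=[metrics_callback])
    trainer.fit(module, dataloader)
    # Returns total_training_time, peak_memory_usage, trainable_params,
    # and total_params for the fitted module.
    return metrics_callback(module)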