# Source code for renate.updaters.avalanche.plugins

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Callable, Dict, IO, Optional, Union

import torch
from avalanche.training.plugins.checkpoint import CheckpointPlugin, FileSystemCheckpointStorage

from renate import defaults


class RenateFileSystemCheckpointStorage(FileSystemCheckpointStorage):
    """File-system checkpoint storage that keeps its checkpoint in one fixed file.

    Both path helpers ignore ``checkpoint_name``: the directory is always
    ``self.directory`` and the file path is derived from that directory alone,
    so every checkpoint name resolves to the same location.
    """

    def _make_checkpoint_dir(self, checkpoint_name: str) -> Path:
        """Return ``self.directory``; ``checkpoint_name`` is not used."""
        return self.directory

    def _make_checkpoint_file_path(self, checkpoint_name: str) -> Path:
        """Return the learner state file path inside the checkpoint directory."""
        checkpoint_dir = self._make_checkpoint_dir(checkpoint_name)
        state_file = defaults.learner_state_file(str(checkpoint_dir))
        return Path(state_file)

    def checkpoint_exists(self, checkpoint_name: str) -> bool:
        """Return whether the checkpoint file is present on disk."""
        target = self._make_checkpoint_file_path(checkpoint_name)
        return target.exists()

    def store_checkpoint(
        self, checkpoint_name: str, checkpoint_writer: Callable[[IO[bytes]], None]
    ):
        """Write a checkpoint, removing a previously stored file first.

        Args:
            checkpoint_name: Name forwarded to the parent implementation.
            checkpoint_writer: Callable that receives a binary stream and writes
                the checkpoint content to it.
        """
        stale_file = self._make_checkpoint_file_path(checkpoint_name)
        # Drop the old file before delegating.
        # NOTE(review): presumably the parent implementation refuses to overwrite
        # an existing checkpoint - confirm against the avalanche version in use.
        if stale_file.exists():
            stale_file.unlink()
        super().store_checkpoint(
            checkpoint_name=checkpoint_name,
            checkpoint_writer=checkpoint_writer,
        )
class RenateCheckpointPlugin(CheckpointPlugin):
    """Checkpoint plugin that loads the checkpoint named ``defaults.LEARNER_CHECKPOINT_NAME``."""

    def __init__(
        self,
        storage: RenateFileSystemCheckpointStorage,
        map_location: Optional[Union[str, torch.device, Dict[str, str]]] = None,
    ):
        """Forward ``storage`` and ``map_location`` unchanged to the parent plugin.

        Args:
            storage: Storage backend used to look up and load the checkpoint.
            map_location: Passed through to the parent constructor as is.
        """
        super().__init__(storage=storage, map_location=map_location)

    def load_checkpoint_if_exists(self):
        """Load the learner checkpoint when the storage reports that it exists.

        Returns:
            A tuple made of the ``"strategy"`` and ``"exp_counter"`` entries of
            the loaded checkpoint, or ``(None, 0)`` when no checkpoint exists.
        """
        checkpoint_name = defaults.LEARNER_CHECKPOINT_NAME
        if self.storage.checkpoint_exists(checkpoint_name):
            state = self.storage.load_checkpoint(checkpoint_name, self.load_checkpoint)
            return state["strategy"], state["exp_counter"]
        return None, 0