renate.updaters.avalanche.plugins module

class renate.updaters.avalanche.plugins.RenateFileSystemCheckpointStorage(directory)

Bases: FileSystemCheckpointStorage

checkpoint_exists(checkpoint_name)

Checks whether a checkpoint exists.

Parameters:

checkpoint_name (str) – The name of the checkpoint to check.

Return type:

bool

Returns:

True if the checkpoint exists, False otherwise.

store_checkpoint(checkpoint_name, checkpoint_writer)

Stores a checkpoint.

This method expects a checkpoint name and a callable.

The callable must accept a file-like object as input. The file-like object is created by the checkpoint storage (this object) and it will accept binary write operations to store the byte representation of the checkpoint.

Parameters:
  • checkpoint_name (str) – The name of the checkpoint.

  • checkpoint_writer (Callable[[IO[bytes]], None]) – A callable that accepts a writable file-like object. The callable must write the checkpoint to the provided file object.

Returns:

None. Raises an exception if the checkpoint cannot be stored; the exception type depends on the specific implementation.
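The writer contract described above can be sketched as follows. This is a minimal illustration, not Renate's or Avalanche's implementation: `DemoCheckpointStorage`, `write_state`, `demo_store`, the `.pkl` file naming, and the pickled dictionary are all hypothetical stand-ins chosen so the example runs without either library installed. It only shows the division of labour: the storage opens a binary file, and the callable writes bytes into it.

```python
import pickle
import tempfile
from pathlib import Path
from typing import IO, Any, Callable, Tuple


class DemoCheckpointStorage:
    """Hypothetical stand-in (not Renate's class): one file per checkpoint name."""

    def __init__(self, directory: str) -> None:
        self.directory = Path(directory)
        self.directory.mkdir(parents=True, exist_ok=True)

    def _path(self, checkpoint_name: str) -> Path:
        # Assumed naming scheme for this sketch only.
        return self.directory / f"{checkpoint_name}.pkl"

    def checkpoint_exists(self, checkpoint_name: str) -> bool:
        return self._path(checkpoint_name).is_file()

    def store_checkpoint(
        self, checkpoint_name: str, checkpoint_writer: Callable[[IO[bytes]], None]
    ) -> None:
        # The storage creates the binary file object; the callable only writes to it.
        with open(self._path(checkpoint_name), "wb") as f:
            checkpoint_writer(f)


def write_state(f: IO[bytes]) -> None:
    """A checkpoint writer: receives a writable binary file and serializes into it."""
    pickle.dump({"weights": [0.1, 0.2], "experience": 3}, f)


def demo_store() -> Tuple[bool, bool, Any]:
    """Store one checkpoint in a temporary directory and read it back."""
    with tempfile.TemporaryDirectory() as tmp:
        storage = DemoCheckpointStorage(tmp)
        before = storage.checkpoint_exists("latest")
        storage.store_checkpoint("latest", write_state)
        after = storage.checkpoint_exists("latest")
        with open(Path(tmp) / "latest.pkl", "rb") as f:
            loaded = pickle.load(f)
    return before, after, loaded
```

Passing a callable rather than raw bytes lets the storage decide where and how the file is opened, while the caller decides the serialization format.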

class renate.updaters.avalanche.plugins.RenateCheckpointPlugin(storage, map_location=None)

Bases: CheckpointPlugin

load_checkpoint_if_exists()

Loads the latest checkpoint if it exists.

This loads the strategy (including the model weights, all the plugins, metrics, and loggers), restores the state of the global random number generators (torch, torch cuda, numpy, and Python’s random), and loads the number of training experiences so far.

The loaded checkpoint refers to the last successful evaluation.

Parameters:

update_checkpoint_plugin – Defaults to True, in which case the CheckpointPlugin in the unpickled strategy is replaced with self (this plugin instance).

Returns:

The loaded strategy and the number of experiences so far (this number can also be interpreted as the index of the next training experience).
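Because the returned count doubles as the index of the next training experience, a resumed run can slice its training stream with it. The sketch below illustrates that pattern with a hypothetical `StubCheckpointPlugin` that mimics only the documented return value; it is not the real plugin. The `(None, 0)` result for the no-checkpoint case is an assumption of this sketch, as are the `remaining_experiences` helper and the string placeholders used for the strategy and the experiences.

```python
from typing import Any, List, Optional, Tuple


class StubCheckpointPlugin:
    """Hypothetical stand-in that mimics only the documented return value."""

    def __init__(self, saved: Optional[Tuple[Any, int]] = None) -> None:
        self._saved = saved

    def load_checkpoint_if_exists(self) -> Tuple[Optional[Any], int]:
        # Assumption: without a checkpoint there is no strategy and the count is 0.
        if self._saved is None:
            return None, 0
        return self._saved


def remaining_experiences(plugin: StubCheckpointPlugin, stream: List[str]) -> List[str]:
    """Return the experiences that still need training after a possible resume."""
    strategy, initial_exp = plugin.load_checkpoint_if_exists()
    # initial_exp counts completed experiences, so it is also the next index.
    return stream[initial_exp:]


stream = ["exp0", "exp1", "exp2"]
fresh = remaining_experiences(StubCheckpointPlugin(), stream)
resumed = remaining_experiences(StubCheckpointPlugin(("strategy", 2)), stream)
```

Here `fresh` covers the whole stream, while `resumed` starts at the third experience because two were already completed before the checkpoint was written.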