renate.utils.module module

renate.utils.module.evaluate_and_record_results(results, model, data_module, transform=None, target_transform=None, task_id='default_task', logged_metrics=None, metric_postfix='', batch_size=32, accelerator='auto', devices=None, strategy='ddp', precision='32')

Evaluates the model on the test data and records quantitative metrics in the results dictionary.

Parameters:
  • results (Dict[str, List[List[float]]]) – The results dictionary to which the results should be saved.

  • model (RenateModule) – A RenateModule to be evaluated.

  • data_module (Union[Scenario, RenateDataModule]) – A Scenario or RenateDataModule from which the test data is queried.

  • transform (Optional[Callable]) – The transformation applied for evaluation.

  • target_transform (Optional[Callable]) – The target transformation applied for evaluation.

  • task_id (str) – The task ID for which the evaluation should be performed.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics to log in addition to the default ones.

  • metric_postfix (str) – The postfix for the metric names.

  • batch_size (int) – The batch size of the test data loader.

  • accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to evaluate the model.

  • devices (Optional[int]) – Devices used by PyTorch Lightning to evaluate the model. If devices is not set, it is assumed to be “auto” and the device count is taken from the accelerator’s auto_device_count.

  • strategy (str) – Name of the distributed training strategy to use. See the PyTorch Lightning documentation for details.

  • precision (str) – Numerical (bit) precision to use, for example '32'. See the PyTorch Lightning documentation for details.

Return type:

Dict[str, List[List[float]]]
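The nested return type is easiest to read with a concrete example. The sketch below assumes, without confirmation from this page, that each metric name (with metric_postfix appended) maps to a list of evaluation rounds, and that each round holds one value per test set. The helper record_results is hypothetical and is not part of renate.utils.module.

```python
from typing import Dict, List


def record_results(
    results: Dict[str, List[List[float]]],
    metric_values: Dict[str, List[float]],
    metric_postfix: str = "",
) -> Dict[str, List[List[float]]]:
    """Hypothetical helper: append one evaluation round per metric.

    Assumption: the inner list holds one value per test set/task.
    """
    for name, values in metric_values.items():
        key = f"{name}{metric_postfix}"
        # Create the metric's history on first use, then append this round.
        results.setdefault(key, []).append(list(values))
    return results


results: Dict[str, List[List[float]]] = {}
# First round: one test set seen so far.
record_results(results, {"accuracy": [0.91]}, metric_postfix="_init")
# Second round: two test sets seen so far.
record_results(results, {"accuracy": [0.88, 0.93]}, metric_postfix="_init")
print(results)  # {'accuracy_init': [[0.91], [0.88, 0.93]]}
```

Under this assumed layout, results[key][i][j] would be the value of a metric after the i-th evaluation round on the j-th test set.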

renate.utils.module.get_model(config_module, **kwargs)

Creates and returns a model instance.

Return type:

RenateModule
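get_model, like get_data_module, receives a config module and forwards **kwargs. A minimal sketch of that factory-lookup pattern follows. It assumes the config module exposes a function named model_fn (treat this name as an assumption), and get_from_config is a hypothetical helper, not the library's implementation.

```python
import types
from typing import Any


def get_from_config(config_module: types.ModuleType, fn_name: str, **kwargs: Any) -> Any:
    """Hypothetical helper: look up a factory by name and call it with kwargs."""
    factory = getattr(config_module, fn_name)  # AttributeError if the config lacks it
    return factory(**kwargs)


# Stand-in for a user config file; normally it would be imported from disk.
config = types.ModuleType("renate_config")


def model_fn(num_outputs: int = 10) -> dict:
    # A real config would build and return a RenateModule here.
    return {"kind": "model", "num_outputs": num_outputs}


config.model_fn = model_fn

model = get_from_config(config, "model_fn", num_outputs=5)
print(model)  # {'kind': 'model', 'num_outputs': 5}
```

Forwarding **kwargs lets the caller pass hyperparameters into the user-defined factory without the utility needing to know their names.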

renate.utils.module.get_data_module(config_module, **kwargs)

Creates and returns a data module instance.

Return type:

RenateDataModule

renate.utils.module.get_loss_fn(config_module, convert, **kwargs)

Creates and returns the loss function defined in the config module.

Return type:

Module

renate.utils.module.get_optimizer(config_module, **kwargs)

Creates a partial optimizer object from the config module.

Return type:

Optional[Callable[[List[Parameter]], Optimizer]]
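The return type describes a "partial" optimizer: a callable that still needs the model parameters before it yields an optimizer. The sketch below illustrates this with functools.partial. ToySGD is a hypothetical stand-in for torch.optim.SGD so the example runs without PyTorch, optimizer_fn is an assumed config function name, and returning None for a config that defines no optimizer is an assumption based on the Optional return type.

```python
import functools
import types
from typing import Any, Callable, List, Optional


class ToySGD:
    """Hypothetical stand-in for torch.optim.SGD."""

    def __init__(self, params: List[Any], lr: float) -> None:
        self.params = list(params)
        self.lr = lr


def optimizer_fn() -> Callable[[List[Any]], ToySGD]:
    # Fix the hyperparameters now; the model parameters are supplied later.
    return functools.partial(ToySGD, lr=0.01)


def get_optimizer_sketch(
    config_module: types.ModuleType, **kwargs: Any
) -> Optional[Callable[[List[Any]], ToySGD]]:
    """Hypothetical helper: return the partial optimizer, or None if undefined."""
    fn = getattr(config_module, "optimizer_fn", None)
    return None if fn is None else fn(**kwargs)


config = types.ModuleType("renate_config")
config.optimizer_fn = optimizer_fn

partial_optimizer = get_optimizer_sketch(config)
optimizer = partial_optimizer(["w", "b"])  # called once the model's parameters exist
print(optimizer.lr, optimizer.params)  # 0.01 ['w', 'b']
print(get_optimizer_sketch(types.ModuleType("empty_config")))  # None
```

Deferring construction this way is useful because the optimizer needs the model's parameters, which only exist after the model has been created.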

renate.utils.module.get_learning_rate_scheduler(config_module, **kwargs)

Creates a partial learning rate scheduler object from the config module.

Return type:

Optional[Tuple[Callable[[Optimizer], _LRScheduler], Literal['epoch', 'step']]]
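The returned tuple pairs a partial scheduler with the interval at which it steps. One plausible way to consume it is to build PyTorch Lightning's lr_scheduler configuration dictionary, whose "interval" key accepts "epoch" or "step". This is a sketch only: ToyStepLR is a hypothetical stand-in for torch.optim.lr_scheduler.StepLR, to_lightning_config is a hypothetical helper, and how Renate itself consumes the tuple is not documented here.

```python
import functools
from typing import Any, Callable, Dict, Literal, Tuple


class ToyStepLR:
    """Hypothetical stand-in for torch.optim.lr_scheduler.StepLR."""

    def __init__(self, optimizer: Any, step_size: int) -> None:
        self.optimizer = optimizer
        self.step_size = step_size


SchedulerSpec = Tuple[Callable[[Any], ToyStepLR], Literal["epoch", "step"]]


def to_lightning_config(spec: SchedulerSpec, optimizer: Any) -> Dict[str, Any]:
    """Hypothetical helper: turn (partial scheduler, interval) into a Lightning-style dict."""
    partial_scheduler, interval = spec
    if interval not in ("epoch", "step"):
        raise ValueError(f"unsupported interval: {interval!r}")
    # The scheduler can only be built once the optimizer exists.
    return {"scheduler": partial_scheduler(optimizer), "interval": interval}


spec: SchedulerSpec = (functools.partial(ToyStepLR, step_size=10), "epoch")
scheduler_config = to_lightning_config(spec, optimizer="opt")
print(scheduler_config["interval"], scheduler_config["scheduler"].step_size)  # epoch 10
```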

renate.utils.module.get_metrics(config_module, **kwargs)

Creates and returns a dictionary of metrics.

Return type:

Optional[Dict[str, Metric]]

renate.utils.module.get_and_prepare_data_module(config_module, **kwargs)

Creates a data module and prepares its data.

Return type:

RenateDataModule

renate.utils.module.get_and_setup_data_module(config_module, prepare_data, **kwargs)

Creates a data module and, depending on prepare_data, calls its prepare_data function, which is needed before setup.

Return type:

RenateDataModule
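The prepare_data flag follows the common PyTorch Lightning split between one-off data preparation (for example, downloading files) and per-process setup. The sketch below assumes, based on the function name, that setup is always called after the optional prepare_data step. ToyDataModule and get_and_setup_sketch are hypothetical names used only for illustration.

```python
from typing import List


class ToyDataModule:
    """Hypothetical data module following the prepare_data/setup convention."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        # One-off work such as downloading files to disk.
        self.calls.append("prepare_data")

    def setup(self) -> None:
        # Per-process work such as building the dataset splits.
        self.calls.append("setup")


def get_and_setup_sketch(prepare_data: bool) -> ToyDataModule:
    data_module = ToyDataModule()
    if prepare_data:
        data_module.prepare_data()
    data_module.setup()  # assumption: setup always follows
    return data_module


print(get_and_setup_sketch(prepare_data=True).calls)   # ['prepare_data', 'setup']
print(get_and_setup_sketch(prepare_data=False).calls)  # ['setup']
```

Passing prepare_data=False would be useful when the data has already been prepared, for example by another process in a distributed run.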

renate.utils.module.import_module(module_name, location)

Imports a Python module from a file location.

Return type:

ModuleType
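Loading a module from an explicit file path is commonly done with the standard importlib recipe shown below. This is a minimal sketch of that technique; whether import_module uses exactly these calls is an assumption, and import_module_from_file is a hypothetical name.

```python
import importlib.util
import os
import tempfile
from types import ModuleType


def import_module_from_file(module_name: str, location: str) -> ModuleType:
    """Load a Python module from a file path using importlib."""
    spec = importlib.util.spec_from_file_location(module_name, location)
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load {module_name!r} from {location!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    return module


# Write a tiny config file to a temporary directory and import it.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "renate_config.py")
    with open(path, "w") as f:
        f.write("def model_fn():\n    return 'model'\n")
    config = import_module_from_file("renate_config", path)

print(config.model_fn())  # model
```

Note that this recipe does not register the module in sys.modules, so importing the same file twice produces two independent module objects.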