renate.utils.module module
- renate.utils.module.evaluate_and_record_results(results, model, data_module, transform=None, target_transform=None, task_id='default_task', logged_metrics=None, metric_postfix='', batch_size=32, accelerator='auto', devices=None, strategy='ddp', precision='32')
A helper function that performs the evaluation on test data and records quantitative metrics in a dictionary.
- Parameters:
  - results (Dict[str, List[List[float]]]) – The results dictionary to which the results should be saved.
  - model (RenateModule) – A RenateModule to be evaluated.
  - data_module (Union[Scenario, RenateDataModule]) – A Scenario or RenateDataModule from which the test data is queried.
  - transform (Optional[Callable]) – The transformation applied for evaluation.
  - target_transform (Optional[Callable]) – The target transformation applied for evaluation.
  - task_id (str) – The task ID for which the evaluation should be performed.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
  - metric_postfix (str) – The postfix appended to the metric names.
  - batch_size (int) – The batch size for the test loader.
  - accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to run the model.
  - devices (Optional[int]) – Devices used by PyTorch Lightning to run the model. If devices is not set, it is treated as “auto” and the device count is taken from the accelerator's auto_device_count.
  - strategy (str) – Name of the distributed strategy to use. See the PyTorch Lightning documentation for details.
  - precision (str) – Bit precision to use. See the PyTorch Lightning documentation for details.
- Return type:
Dict[str, List[List[float]]]
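The return type Dict[str, List[List[float]]] implies a nested layout for results. The following is a minimal sketch, assuming that each key is a metric name with metric_postfix appended and that every evaluation appends one inner list holding one value per evaluated task. The helper record_results and the metric names are hypothetical and not part of Renate.

```python
from typing import Dict, List


def record_results(
    results: Dict[str, List[List[float]]],
    per_task_metrics: Dict[str, List[float]],
    metric_postfix: str = "",
) -> Dict[str, List[List[float]]]:
    """Hypothetical helper: append one row of per-task values for each metric."""
    for name, values in per_task_metrics.items():
        # Assumed naming scheme: metric name followed by the postfix.
        results.setdefault(name + metric_postfix, []).append(list(values))
    return results


results: Dict[str, List[List[float]]] = {}
# First evaluation: only one task has been seen so far.
record_results(results, {"accuracy": [0.91]}, metric_postfix="_test")
# Second evaluation after another model update: two tasks are evaluated.
record_results(results, {"accuracy": [0.88, 0.93]}, metric_postfix="_test")
print(results)  # {'accuracy_test': [[0.91], [0.88, 0.93]]}
```

Under these assumptions, results["accuracy_test"][i][j] is the value for task j after the i-th evaluation.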
- renate.utils.module.get_model(config_module, **kwargs)
Creates and returns a model instance.
- Return type:
- renate.utils.module.get_data_module(config_module, **kwargs)
Creates and returns a data module instance.
- Return type:
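Both get_model and get_data_module take a config_module, i.e. a Python module object, and forward **kwargs. A minimal sketch of that pattern, assuming the module exposes factory functions named model_fn and data_module_fn. Those factory names, the helpers get_model_sketch and get_data_module_sketch, and the dictionary return values are illustrative assumptions standing in for a real RenateModule and data module.

```python
import types
from typing import Any


def get_model_sketch(config_module: types.ModuleType, **kwargs: Any) -> Any:
    # Assumed convention: the config module defines a `model_fn` factory.
    return getattr(config_module, "model_fn")(**kwargs)


def get_data_module_sketch(config_module: types.ModuleType, **kwargs: Any) -> Any:
    # Assumed convention: the config module defines a `data_module_fn` factory.
    return getattr(config_module, "data_module_fn")(**kwargs)


# Build an in-memory stand-in for a user's config file.
config = types.ModuleType("my_config")
config.model_fn = lambda num_outputs=10: {"kind": "model", "num_outputs": num_outputs}
config.data_module_fn = lambda data_path="./data": {"kind": "data", "path": data_path}

model = get_model_sketch(config, num_outputs=5)
data_module = get_data_module_sketch(config, data_path="/tmp/data")
print(model)        # {'kind': 'model', 'num_outputs': 5}
print(data_module)  # {'kind': 'data', 'path': '/tmp/data'}
```

Keeping construction behind factory functions lets the same helper build different models or datasets purely by swapping the config module.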
- renate.utils.module.get_loss_fn(config_module, convert, **kwargs)
Creates and returns the loss function from the config module.
- Return type:
Module
- renate.utils.module.get_optimizer(config_module, **kwargs)
Creates a partial optimizer object from the config module.
- Return type:
Optional[Callable[[List[Parameter]], Optimizer]]
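A partial optimizer is a factory whose hyperparameters are already bound and which only needs the model parameters to produce an optimizer, matching the Callable[[List[Parameter]], Optimizer] return type. A minimal sketch using functools.partial, with a hypothetical SGDStandIn class in place of a real torch.optim optimizer so that it runs without PyTorch.

```python
from functools import partial
from typing import Callable, List


class SGDStandIn:
    """Hypothetical stand-in for an optimizer such as torch.optim.SGD."""

    def __init__(self, params: List[float], lr: float, momentum: float = 0.0) -> None:
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


# Hyperparameters are fixed now; the parameters are supplied later.
optimizer_factory: Callable[[List[float]], SGDStandIn] = partial(SGDStandIn, lr=0.01, momentum=0.9)

# Once the model exists, the caller passes its parameters.
optimizer = optimizer_factory([0.5, -1.2])
print(optimizer.lr, optimizer.momentum, optimizer.params)  # 0.01 0.9 [0.5, -1.2]
```

With PyTorch, the analogous factory would be partial(torch.optim.SGD, lr=0.01, momentum=0.9), called with list(model.parameters()). Deferring construction lets the caller decide which parameters to optimize.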
- renate.utils.module.get_learning_rate_scheduler(config_module, **kwargs)
Creates a partial learning rate scheduler object from the config module.
- Return type:
Optional[Tuple[Callable[[Optimizer], _LRScheduler], Literal['epoch', 'step']]]
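The return value pairs a scheduler factory with an interval, 'epoch' or 'step'. Assuming the interval has the same meaning as in PyTorch Lightning's scheduler configuration (how often the scheduler is stepped), a caller might honor it as in the minimal sketch below. OptimizerStandIn, StepDecayStandIn, and train_sketch are hypothetical stand-ins for a real optimizer, _LRScheduler, and training loop.

```python
from functools import partial
from typing import Callable, List, Literal, Tuple


class OptimizerStandIn:
    """Hypothetical optimizer holding a single learning rate."""

    def __init__(self, lr: float) -> None:
        self.lr = lr


class StepDecayStandIn:
    """Hypothetical scheduler: multiplies the learning rate by gamma on each step()."""

    def __init__(self, optimizer: OptimizerStandIn, gamma: float) -> None:
        self.optimizer = optimizer
        self.gamma = gamma

    def step(self) -> None:
        self.optimizer.lr *= self.gamma


def train_sketch(
    scheduler_spec: Tuple[Callable[[OptimizerStandIn], StepDecayStandIn], Literal["epoch", "step"]],
    epochs: int,
    batches_per_epoch: int,
) -> List[float]:
    """Returns the learning rate observed at the end of each epoch."""
    factory, interval = scheduler_spec
    optimizer = OptimizerStandIn(lr=1.0)
    scheduler = factory(optimizer)
    lrs = []
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            # Forward pass, backward pass, and optimizer step would go here.
            if interval == "step":
                scheduler.step()
        if interval == "epoch":
            scheduler.step()
        lrs.append(optimizer.lr)
    return lrs


decay = partial(StepDecayStandIn, gamma=0.5)
print(train_sketch((decay, "epoch"), epochs=3, batches_per_epoch=4))  # [0.5, 0.25, 0.125]
print(train_sketch((decay, "step"), epochs=2, batches_per_epoch=2))   # [0.25, 0.0625]
```

With 'step', the decay is applied once per batch, so the learning rate shrinks much faster than with 'epoch' for the same factory.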
- renate.utils.module.get_metrics(config_module, **kwargs)
Creates and returns a dictionary of metrics.
- Return type:
Optional[Dict[str, Metric]]
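The Optional return type leaves room for a config module that defines no custom metrics. A minimal sketch, assuming the helper returns None when the module has no metrics_fn attribute. The factory name metrics_fn, the helper get_metrics_sketch, and the placeholder string values are illustrative assumptions; real values would be Metric objects.

```python
import types
from typing import Any, Dict, Optional


def get_metrics_sketch(config_module: types.ModuleType, **kwargs: Any) -> Optional[Dict[str, Any]]:
    # Assumed behavior: `metrics_fn` is optional; if absent, there are no extra metrics.
    metrics_fn = getattr(config_module, "metrics_fn", None)
    if metrics_fn is None:
        return None
    return metrics_fn(**kwargs)


with_metrics = types.ModuleType("with_metrics")
with_metrics.metrics_fn = lambda: {"f1": "placeholder-metric"}
without_metrics = types.ModuleType("without_metrics")

print(get_metrics_sketch(with_metrics))     # {'f1': 'placeholder-metric'}
print(get_metrics_sketch(without_metrics))  # None
```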
- renate.utils.module.get_and_prepare_data_module(config_module, **kwargs)
Creates a data module instance and prepares its data.
- Return type:
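The name suggests that get_and_prepare_data_module combines data module creation with data preparation. A minimal sketch, assuming the data module follows the PyTorch Lightning convention of a prepare_data() step (for example, downloading raw data) followed by setup() (building the datasets). DataModuleStandIn, the data_module_fn factory name, and get_and_prepare_data_module_sketch are hypothetical.

```python
import types
from typing import Any, List


class DataModuleStandIn:
    """Hypothetical data module that records which lifecycle hooks ran."""

    def __init__(self, data_path: str) -> None:
        self.data_path = data_path
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        # Typically: download or unpack raw data once.
        self.calls.append("prepare_data")

    def setup(self) -> None:
        # Typically: build the train, validation, and test datasets.
        self.calls.append("setup")


def get_and_prepare_data_module_sketch(config_module: types.ModuleType, **kwargs: Any) -> DataModuleStandIn:
    data_module = getattr(config_module, "data_module_fn")(**kwargs)  # assumed factory name
    data_module.prepare_data()
    data_module.setup()
    return data_module


config = types.ModuleType("my_config")
config.data_module_fn = DataModuleStandIn
dm = get_and_prepare_data_module_sketch(config, data_path="/tmp/data")
print(dm.calls)  # ['prepare_data', 'setup']
```

Calling prepare_data() before setup() mirrors the Lightning ordering, where datasets are built only after the raw data is available.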