renate.utils.module module
- renate.utils.module.evaluate_and_record_results(results, model, data_module, transform=None, target_transform=None, task_id='default_task', logged_metrics=None, metric_postfix='', batch_size=32, accelerator='auto', devices=None, strategy='ddp', precision='32')
A helper function that performs the evaluation on test data and records quantitative metrics in a dictionary.
- Parameters:
results (Dict[str, List[List[float]]]) – The results dictionary to which the results should be saved.
model (RenateModule) – A RenateModule to be evaluated.
data_module (Union[Scenario, RenateDataModule]) – A Scenario or RenateDataModule from which the test data is queried.
transform (Optional[Callable]) – The transformation applied for evaluation.
target_transform (Optional[Callable]) – The target transformation applied for evaluation.
task_id (str) – The task ID for which the evaluation should be performed.
logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
metric_postfix (str) – The postfix for the metric names.
batch_size (int) – A batch size for the test loader.
accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to train the model.
devices (Optional[int]) – Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be “auto” and fetch the auto_device_count from the accelerator.
strategy (str) – Name of the distributed training strategy to use.
precision (str) – Type of bit precision to use.
- Return type:
Dict[str, List[List[float]]]
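The `Dict[str, List[List[float]]]` type can be read as one entry per metric name, holding one inner list per evaluation round. The following is a minimal sketch of that bookkeeping only, assuming each round yields a list of floats per metric (for example, one value per test set). `record_results` and the metric name are hypothetical stand-ins and not part of Renate.

```python
from typing import Dict, List


def record_results(
    results: Dict[str, List[List[float]]],
    new_values: Dict[str, List[float]],
) -> Dict[str, List[List[float]]]:
    """Append one evaluation round to the results dictionary (hypothetical helper)."""
    for name, values in new_values.items():
        # Create the outer list on first use, then append this round's values.
        results.setdefault(name, []).append(list(values))
    return results


results: Dict[str, List[List[float]]] = {}
record_results(results, {"accuracy": [0.91]})        # first round, one test set
record_results(results, {"accuracy": [0.88, 0.93]})  # second round, two test sets
print(results)  # {'accuracy': [[0.91], [0.88, 0.93]]}
```

Under this assumption, passing the same dictionary to successive calls lets the history of each metric grow by one inner list per call.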
- renate.utils.module.get_model(config_module, **kwargs)
Creates and returns a model instance.
- Return type:
- renate.utils.module.get_data_module(config_module, **kwargs)
Creates and returns a data module instance.
- Return type:
- renate.utils.module.get_loss_fn(config_module, convert, **kwargs)
Creates and returns the loss function from the config.
- Return type:
Module
- renate.utils.module.get_optimizer(config_module, **kwargs)
Creates a partial optimizer object from the config.
- Return type:
Optional[Callable[[List[Parameter]], Optimizer]]
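A "partial optimizer" of type `Callable[[List[Parameter]], Optimizer]` is an optimizer constructor with its hyperparameters already bound, waiting only for the model parameters. The sketch below illustrates the pattern with `functools.partial`. `ToySGD` and `toy_optimizer_fn` are hypothetical stand-ins (plain floats take the place of `Parameter` objects), so this shows the idea and not Renate's actual implementation.

```python
from functools import partial
from typing import Callable, List


class ToySGD:
    """Hypothetical stand-in for an optimizer class such as torch.optim.SGD."""

    def __init__(self, params: List[float], lr: float = 0.1) -> None:
        self.params = params
        self.lr = lr


def toy_optimizer_fn(lr: float) -> Callable[[List[float]], ToySGD]:
    """Bind the hyperparameters now; the parameters are supplied later."""
    return partial(ToySGD, lr=lr)


make_optimizer = toy_optimizer_fn(lr=0.01)
# The parameters only become known once a model exists.
optimizer = make_optimizer([1.0, 2.0])
print(optimizer.lr, optimizer.params)
```

Deferring construction this way lets the configuration fix the learning rate before any model has been instantiated.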
- renate.utils.module.get_learning_rate_scheduler(config_module, **kwargs)
Creates a partial learning rate scheduler object from the config.
- Return type:
Optional[Tuple[Callable[[Optimizer], _LRScheduler], Literal['epoch', 'step']]]
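The return type pairs a scheduler factory with a string that is either `'epoch'` or `'step'`. A plausible reading, stated here as an assumption, is that the string says how often the scheduler should be stepped. The sketch below mimics that shape with hypothetical stand-in classes (`ToyOptimizer`, `ToyStepLR`) and a hypothetical `toy_scheduler_fn`; none of these names come from Renate or PyTorch.

```python
from functools import partial
from typing import Callable, Literal, Tuple


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer holding a learning rate."""

    def __init__(self, lr: float) -> None:
        self.lr = lr


class ToyStepLR:
    """Hypothetical scheduler that multiplies the learning rate by gamma on each step."""

    def __init__(self, optimizer: ToyOptimizer, gamma: float = 0.5) -> None:
        self.optimizer = optimizer
        self.gamma = gamma

    def step(self) -> None:
        self.optimizer.lr *= self.gamma


def toy_scheduler_fn(
    gamma: float,
) -> Tuple[Callable[[ToyOptimizer], ToyStepLR], Literal["epoch", "step"]]:
    """Return a scheduler factory plus the assumed stepping interval."""
    return partial(ToyStepLR, gamma=gamma), "epoch"


make_scheduler, interval = toy_scheduler_fn(gamma=0.5)
optimizer = ToyOptimizer(lr=0.1)
scheduler = make_scheduler(optimizer)  # the optimizer is bound only at this point
scheduler.step()
print(interval, optimizer.lr)  # epoch 0.05
```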
- renate.utils.module.get_metrics(config_module, **kwargs)
Creates and returns a dictionary of metrics.
- Return type:
Optional[Dict[str, Metric]]
- renate.utils.module.get_and_prepare_data_module(config_module, **kwargs)
Prepares data.
- Return type: