renate.evaluation.evaluator module
- class renate.evaluation.evaluator.Evaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)

  Bases: LightningModule, ABC

  A general Evaluator module for collecting quantitative metrics on the test dataset.

  This is an abstract interface which is used with a PyTorch Lightning Trainer and its test() function. It collects quantitative observations with respect to a single dataset. The metrics that are collected are defined in the create_metrics function.

  - Parameters:
    - model (RenateModule) – A RenateModule to be evaluated.
    - batch_size (int) – The batch size to be used when creating the test data loader.
    - transform (Optional[Callable]) – The transformation applied for evaluation.
    - target_transform (Optional[Callable]) – The target transformation applied for evaluation.
    - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
- on_model_test_start(test_dataset, test_collate_fn=None, task_id=None)

  Called before a model test starts.

  - Return type: DataLoader
- test_step(batch, batch_idx)

  PyTorch Lightning function to perform the test step.

  - Return type: None
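To make the Trainer-driven flow described above concrete, here is a minimal sketch using the ClassificationEvaluator defined below. The exact call pattern is an assumption inferred from the on_model_test_start and test_step hooks; `my_model` and `my_test_dataset` are placeholders for a trained RenateModule and a torch Dataset.

```python
from pytorch_lightning import Trainer

from renate.evaluation.evaluator import ClassificationEvaluator

# Hedged sketch: on_model_test_start() prepares the test DataLoader, and
# Trainer.test() then drives test_step() over its batches. `my_model` and
# `my_test_dataset` are placeholders, not part of this API.
evaluator = ClassificationEvaluator(model=my_model, batch_size=32)
test_loader = evaluator.on_model_test_start(test_dataset=my_test_dataset)

trainer = Trainer(accelerator="auto", devices=1, logger=False)
trainer.test(evaluator, dataloaders=test_loader)
```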
- class renate.evaluation.evaluator.ClassificationEvaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)

  Bases: Evaluator

  A classification Evaluator module for collecting quantitative metrics on the test dataset.
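Custom metrics can be supplied through logged_metrics. The sketch below assumes a 10-class classification model (`my_model` is a placeholder) and logs a macro-averaged F1 score alongside the default metrics:

```python
import torchmetrics

from renate.evaluation.evaluator import ClassificationEvaluator

# Hedged sketch: `logged_metrics` maps names to torchmetrics.Metric objects
# that are logged in addition to the defaults. `my_model` is a placeholder.
evaluator = ClassificationEvaluator(
    model=my_model,
    batch_size=64,
    logged_metrics={
        "f1_macro": torchmetrics.F1Score(
            task="multiclass", num_classes=10, average="macro"
        )
    },
)
```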
- renate.evaluation.evaluator.evaluate(model, test_dataset, test_collate_fn=None, task_id='default_task', batch_size=32, transform=None, target_transform=None, logged_metrics=None, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32')

  Evaluate the model on a test dataset, or on a set of test datasets corresponding to distinct tasks.

  If test_dataset is specified as a list of datasets, it is assumed to be ordered. Similarly, if task_id is specified as a list, it is assumed to be ordered. A task ID list can be used to select a specific model part, for example an output head, for each test dataset in the input sequence.

  - Parameters:
    - model (RenateModule) – A RenateModule to be evaluated.
    - test_dataset (Union[List[Dataset], Dataset]) – The test dataset(s) to be evaluated.
    - test_collate_fn (Optional[Callable]) – collate_fn used in the DataLoader.
    - task_id (Union[List[str], str]) – The task id(s) of the test dataset(s).
    - batch_size (int) – The batch size to be used when creating the test data loader.
    - transform (Optional[Callable]) – The transformation applied for evaluation.
    - target_transform (Optional[Callable]) – The target transformation applied for evaluation.
    - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
    - logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.
    - accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to train the model.
    - devices (Optional[int]) – Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be "auto" and fetch the auto_device_count from the accelerator.
    - strategy (str) – Name of the distributed training strategy to use. More details in the PyTorch Lightning documentation.
    - precision (str) – Type of bit precision to use. More details in the PyTorch Lightning documentation.
  - Return type: Dict[str, List[float]]
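A hedged usage sketch of evaluate on two ordered test datasets, each mapped to its own task ID; the model and dataset variables, the task IDs, and the metric names in the comment are placeholders:

```python
from renate.evaluation.evaluator import evaluate

# Hedged sketch: the i-th task id is applied when evaluating the i-th test
# dataset, e.g. to select a task-specific output head. `my_model`,
# `test_set_a`, and `test_set_b` are placeholders.
results = evaluate(
    model=my_model,
    test_dataset=[test_set_a, test_set_b],
    task_id=["task_a", "task_b"],
    batch_size=64,
    accelerator="cpu",
    devices=1,
)

# The result maps each metric name to a list of per-dataset values in input
# order, e.g. {"accuracy": [0.91, 0.87]} (metric names are illustrative).
for name, values in results.items():
    print(name, values)
```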