renate.evaluation.evaluator module
- class renate.evaluation.evaluator.Evaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)
Bases: LightningModule, ABC
A general Evaluator module for collecting quantitative metrics on the test dataset.
This is an abstract interface that is run through a PyTorch Lightning Trainer and its test() function. It collects quantitative observations with respect to a single dataset. The metrics that are collected are defined in the create_metrics function.
- Parameters:
  - model (RenateModule) – A RenateModule to be evaluated.
  - batch_size (int) – The batch size to be used when creating the test data loader.
  - transform (Optional[Callable]) – The transformation applied for evaluation.
  - target_transform (Optional[Callable]) – The target transformation applied for evaluation.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
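The abstract design above (a subclass supplies the default metrics, logged_metrics are added on top, and a single dataset is evaluated batch by batch) can be illustrated without PyTorch Lightning. The following is a minimal, framework-free sketch and not Renate's implementation: SketchEvaluator, ExactMatchEvaluator, the run method, and the create_metrics signature used here are hypothetical, and metrics are modeled as plain functions that score one batch.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

# Hypothetical metric type: scores one batch of predictions against its targets.
MetricFn = Callable[[List[Any], List[Any]], float]


class SketchEvaluator(ABC):
    """Framework-free stand-in for the evaluator idea; NOT Renate's Evaluator."""

    def __init__(
        self,
        model: Callable[[Any], Any],
        batch_size: int,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
        logged_metrics: Optional[Dict[str, MetricFn]] = None,
    ) -> None:
        self.model = model
        self.batch_size = batch_size
        self.transform = transform or (lambda x: x)
        self.target_transform = target_transform or (lambda y: y)
        # Default metrics come from the subclass; logged_metrics are added on top.
        self.metrics = {**self.create_metrics(), **(logged_metrics or {})}

    @abstractmethod
    def create_metrics(self) -> Dict[str, MetricFn]:
        """Return the default metrics (signature assumed for this sketch)."""

    def run(self, dataset: Sequence[Tuple[Any, Any]]) -> Dict[str, float]:
        """Evaluate one dataset of (input, target) pairs, batch by batch."""
        if len(dataset) == 0:
            raise ValueError("dataset is empty")
        totals = {name: 0.0 for name in self.metrics}
        for start in range(0, len(dataset), self.batch_size):
            batch = dataset[start:start + self.batch_size]
            inputs = [self.transform(x) for x, _ in batch]
            targets = [self.target_transform(y) for _, y in batch]
            preds = [self.model(x) for x in inputs]
            for name, fn in self.metrics.items():
                # Weight each batch score by its size; valid for per-sample means.
                totals[name] += fn(preds, targets) * len(batch)
        return {name: total / len(dataset) for name, total in totals.items()}


class ExactMatchEvaluator(SketchEvaluator):
    """Hypothetical concrete subclass with one default metric."""

    def create_metrics(self) -> Dict[str, MetricFn]:
        return {"exact_match": lambda p, t: sum(a == b for a, b in zip(p, t)) / len(t)}


evaluator = ExactMatchEvaluator(model=lambda x: 2 * x, batch_size=2)
print(evaluator.run([(1, 2), (2, 4), (3, 5)]))  # {'exact_match': 0.6666666666666666}
```

Weighting each batch score by batch size, as run does, recovers the dataset-level value only for metrics that are per-sample averages, such as accuracy; other metrics would need their own accumulated state.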
- on_model_test_start(test_dataset, test_collate_fn=None, task_id=None)
Called before a model test starts.
- Return type:
DataLoader
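on_model_test_start returns a DataLoader and accepts an optional test_collate_fn. As a rough mental model, assuming the loader simply batches test_dataset in order using the evaluator's batch_size, the plain-Python sketch below mimics that batching. make_test_batches and default_collate_sketch are hypothetical helpers; in PyTorch the actual work is done by torch.utils.data.DataLoader and its collate_fn argument.

```python
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple


def default_collate_sketch(samples: List[Tuple[Any, ...]]) -> Tuple[List[Any], ...]:
    """Turn [(x1, y1), (x2, y2)] into ([x1, x2], [y1, y2])."""
    return tuple(list(field) for field in zip(*samples))


def make_test_batches(
    dataset: Sequence[Any],
    batch_size: int,
    collate_fn: Optional[Callable[[List[Any]], Any]] = None,
) -> Iterator[Any]:
    """Yield batches in dataset order (no shuffling), merged by collate_fn."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    collate = collate_fn if collate_fn is not None else default_collate_sketch
    for start in range(0, len(dataset), batch_size):
        end = min(start + batch_size, len(dataset))
        yield collate([dataset[i] for i in range(start, end)])


data = [(1, "a"), (2, "b"), (3, "c")]
print(list(make_test_batches(data, batch_size=2)))
# [([1, 2], ['a', 'b']), ([3], ['c'])]
```

Passing a custom collate function changes only how the samples of each batch are merged, not how the dataset is split into batches.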
- test_step(batch, batch_idx)
PyTorch Lightning function to perform the test step.
- Return type:
None
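Lightning test steps commonly follow the torchmetrics pattern of updating metric state on every batch and computing the final value once after the last batch. The sketch below shows that split in plain Python; RunningAccuracy and sketch_test_step are hypothetical stand-ins, and whether Renate's test_step does exactly this is an assumption. Like the documented test_step, sketch_test_step returns None and only mutates state.

```python
from typing import Any, Callable, Sequence, Tuple


class RunningAccuracy:
    """Hypothetical stateful metric with a torchmetrics-style update/compute/reset API."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Clear accumulated state, e.g. before evaluating the next dataset.
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[Any], targets: Sequence[Any]) -> None:
        # Accumulate counts for one batch; nothing is averaged yet.
        if len(preds) != len(targets):
            raise ValueError("preds and targets must have the same length")
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self) -> float:
        # Final value over every batch seen since the last reset.
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.correct / self.total


def sketch_test_step(
    model: Callable[[Any], Any],
    batch: Tuple[Sequence[Any], Sequence[Any]],
    metric: RunningAccuracy,
) -> None:
    """Hypothetical per-batch work: predict, then update metric state."""
    inputs, targets = batch
    preds = [model(x) for x in inputs]
    metric.update(preds, targets)


acc = RunningAccuracy()
for batch in [([0, 1], [0, 1]), ([2, 3], [2, 0])]:
    sketch_test_step(lambda x: x, batch, acc)  # identity "model" for the demo
print(acc.compute())  # 0.75
```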
- class renate.evaluation.evaluator.ClassificationEvaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)
Bases: Evaluator
A classification Evaluator module for collecting quantitative metrics on the test dataset.
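For classification, a common default metric is top-1 accuracy: each sample's predicted class is the index of its largest logit, compared against the integer label. The helpers below (argmax, top1_accuracy) are a hypothetical plain-Python illustration; which metrics ClassificationEvaluator actually logs by default is not stated here and is assumed only for the example.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score; ties resolve to the lowest index."""
    return max(range(len(scores)), key=lambda i: scores[i])


def top1_accuracy(logits: List[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target label."""
    if len(targets) == 0 or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and equally long")
    correct = sum(argmax(row) == label for row, label in zip(logits, targets))
    return correct / len(targets)


logits = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [3.0, 0.0, 0.5],   # predicts class 0
    [0.2, 0.1, 0.9],   # predicts class 2
    [0.0, 0.0, 5.0],   # predicts class 2
]
print(top1_accuracy(logits, [1, 0, 1, 2]))  # 0.75
```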
- renate.evaluation.evaluator.evaluate(model, test_dataset, test_collate_fn=None, task_id='default_task', batch_size=32, transform=None, target_transform=None, logged_metrics=None, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32')
Evaluate the model on the test dataset or a set of test datasets corresponding to distinct tasks.
If test_dataset is specified as a list of datasets, the list is assumed to be ordered. Similarly, if task_id is specified as a list, it is assumed to be ordered. A list of task IDs can be used to select which part of the model is used for each test dataset in the input sequence, for example, a specific output head.
- Parameters:
  - model (RenateModule) – A RenateModule to be evaluated.
  - test_dataset (Union[List[Dataset], Dataset]) – The test dataset(s) to be evaluated.
  - test_collate_fn (Optional[Callable]) – The collate_fn used in the DataLoader.
  - task_id (Union[List[str], str]) – The task ID(s) of the test dataset(s).
  - batch_size (int) – The batch size to be used when creating the test data loader.
  - transform (Optional[Callable]) – The transformation applied for evaluation.
  - target_transform (Optional[Callable]) – The target transformation applied for evaluation.
  - logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
  - logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.
  - accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to run the evaluation.
  - devices (Optional[int]) – Devices used by PyTorch Lightning to run the evaluation. If devices is not set, it is assumed to be “auto”, and the device count is obtained from the accelerator's auto_device_count.
  - strategy (str) – Name of the distributed training strategy to use. See the PyTorch Lightning documentation for more details.
  - precision (str) – Type of bit precision to use. See the PyTorch Lightning documentation for more details.
- Return type:
Dict[str, List[float]]
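The ordering rules for test_dataset and task_id, together with the Dict[str, List[float]] return type, suggest the following bookkeeping. This is a hedged sketch, not Renate's code: it assumes that a single task ID is reused for every dataset, that each key of the result is a metric name, and that each list holds one value per test dataset in input order. evaluate_sketch and the per-dataset function exact_match are hypothetical.

```python
from typing import Any, Callable, Dict, List, Union


def evaluate_sketch(
    evaluate_one: Callable[[Any, str], Dict[str, float]],
    test_dataset: Union[Any, List[Any]],
    task_id: Union[str, List[str]] = "default_task",
) -> Dict[str, List[float]]:
    """Evaluate one or several datasets in order; collect one value per dataset per metric."""
    # In this sketch, a Python list means "several datasets"; anything else is one dataset.
    datasets = test_dataset if isinstance(test_dataset, list) else [test_dataset]
    # A single task ID is reused for every dataset; a list is paired with datasets in order.
    task_ids = [task_id] * len(datasets) if isinstance(task_id, str) else list(task_id)
    if len(task_ids) != len(datasets):
        raise ValueError("expected one task ID per test dataset")
    results: Dict[str, List[float]] = {}
    for dataset, tid in zip(datasets, task_ids):
        for name, value in evaluate_one(dataset, tid).items():
            results.setdefault(name, []).append(value)
    return results


def exact_match(dataset, task_id):
    """Hypothetical per-dataset evaluation: fraction of (prediction, target) pairs that agree."""
    return {"accuracy": sum(p == t for p, t in dataset) / len(dataset)}


task_a = ((1, 1), (2, 2))  # datasets as tuples of (prediction, target) pairs
task_b = ((1, 0), (2, 2))
print(evaluate_sketch(exact_match, [task_a, task_b], ["task_a", "task_b"]))
# {'accuracy': [1.0, 0.5]}
```

Under these assumptions, entry i of each list is the score on the i-th test dataset, so averaging a list gives a mean across tasks.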