renate.evaluation.evaluator module

class renate.evaluation.evaluator.Evaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)

Bases: LightningModule, ABC

A general Evaluator module for collecting quantitative metrics on the test dataset.

This is an abstract interface meant to be used with a PyTorch Lightning Trainer and its test() function. It collects quantitative observations for a single dataset. The metrics that are collected are defined in the create_metrics function.

Parameters:

  • model (RenateModule) – A RenateModule to be evaluated.

  • batch_size (int) – The batch size to be used when creating the test data loader.

  • transform (Optional[Callable]) – The transformation applied for evaluation.

  • target_transform (Optional[Callable]) – The target transformation applied for evaluation.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.
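The call order described above (prepare the data, run a test step per batch, then report and reset the metrics) can be illustrated with a minimal plain-Python sketch that assumes no PyTorch or Renate installation. ToyEvaluator, ParityEvaluator, end_of_test and run_test are hypothetical names; a list of batches stands in for the test data loader, test_collate_fn is omitted, and a single accuracy count stands in for the metric collection. This is not Renate's implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

Sample = Tuple[Any, Any]  # (input, target) pair


class ToyEvaluator(ABC):
    """Hypothetical plain-Python analogue of the Evaluator hooks (not Renate code)."""

    def __init__(self, model: Callable[[Any], Any], batch_size: int,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        self.model = model
        self.batch_size = batch_size
        # Identity functions stand in for "no transform given".
        self.transform = transform or (lambda x: x)
        self.target_transform = target_transform or (lambda y: y)
        self.task_id: Optional[str] = None
        self.correct = 0
        self.total = 0

    def on_model_test_start(self, test_dataset: Sequence[Sample],
                            task_id: Optional[str] = None) -> List[List[Sample]]:
        # Store the task ID and return transformed batches (stand-in for a data loader).
        self.task_id = task_id
        data = [(self.transform(x), self.target_transform(y)) for x, y in test_dataset]
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]

    def test_step(self, batch: List[Sample], batch_idx: int) -> None:
        # Run the subclass-defined forward pass and update the running counts.
        for x, y in batch:
            self.correct += int(self.forward(x, self.task_id) == y)
            self.total += 1

    def end_of_test(self) -> Dict[str, float]:
        # Report the metric, then reset the state before the next dataset.
        result = {"accuracy": self.correct / self.total}
        self.correct, self.total = 0, 0
        return result

    @abstractmethod
    def forward(self, x: Any, task_id: Optional[str] = None) -> Any:
        """Subclasses define how the model output becomes a prediction."""


class ParityEvaluator(ToyEvaluator):
    def forward(self, x: Any, task_id: Optional[str] = None) -> Any:
        # The toy model already returns a class label, so pass it through.
        return self.model(x)


def run_test(evaluator: ToyEvaluator, dataset: Sequence[Sample],
             task_id: Optional[str] = None) -> Dict[str, float]:
    # Mimics the order in which a test loop would call the hooks.
    batches = evaluator.on_model_test_start(dataset, task_id)
    for batch_idx, batch in enumerate(batches):
        evaluator.test_step(batch, batch_idx)
    return evaluator.end_of_test()


evaluator = ParityEvaluator(model=lambda x: x % 2, batch_size=3)
print(run_test(evaluator, [(1, 1), (2, 0), (3, 1), (4, 1)]))  # {'accuracy': 0.75}
```

Keeping forward abstract mirrors the documented design: the base class owns the test loop, while subclasses decide how model outputs are turned into predictions.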

on_model_test_start(test_dataset, test_collate_fn=None, task_id=None)

Called before a model test starts.

Return type:


test_step(batch, batch_idx)

PyTorch Lightning function to perform the test step.

Return type:


abstract forward(x, task_id=None)

Forward pass of the model.

The task ID can be used to select, for example, the output head used to evaluate data with a specific chunk ID. Here, task_id is used only to compute the test metrics.
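As a sketch of how a task ID might select an output head, assume a model with a shared feature extractor and one head per task. MultiHeadModel and its doubling "backbone" are hypothetical, and falling back to a "default_task" head (the default task_id of evaluate) when no ID is given is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Optional


class MultiHeadModel:
    """Hypothetical model with a shared feature extractor and one output head per task."""

    def __init__(self, heads: Dict[str, Callable[[List[float]], List[float]]]) -> None:
        self.heads = heads

    def features(self, x: List[float]) -> List[float]:
        # Toy shared "backbone": doubles every input value.
        return [2.0 * v for v in x]

    def __call__(self, x: List[float], task_id: Optional[str] = None) -> List[float]:
        feats = self.features(x)
        # Assumed fallback: use the "default_task" head when no task ID is given.
        head = self.heads[task_id if task_id is not None else "default_task"]
        return head(feats)


model = MultiHeadModel({
    "default_task": lambda f: [sum(f)],
    "task_a": lambda f: [max(f), min(f)],
})
print(model([1.0, 3.0]))            # [8.0]
print(model([1.0, 3.0], "task_a"))  # [6.0, 2.0]
```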

Return type:



PyTorch Lightning hook called at the end of the test loop.

Logs the metrics and resets the metric collection.
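Resetting matters because metrics accumulate state across batches. The hypothetical sketch below uses a running-mean metric with update/compute/reset methods (modeled loosely on the torchmetrics Metric interface, not taken from it) to show that, without a reset, values from one test dataset would leak into the next. MeanMetric and end_of_test are illustrative names only.

```python
from typing import Dict, List


class MeanMetric:
    """Hypothetical running-mean metric with update/compute/reset."""

    def __init__(self) -> None:
        self.reset()

    def update(self, values: List[float]) -> None:
        # Accumulate state across batches.
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        return self.total / self.count

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


def end_of_test(metrics: Dict[str, MeanMetric]) -> Dict[str, float]:
    # Compute every metric once ("log" it), then reset so the next dataset starts clean.
    logged = {name: metric.compute() for name, metric in metrics.items()}
    for metric in metrics.values():
        metric.reset()
    return logged


metrics = {"mean_loss": MeanMetric()}
metrics["mean_loss"].update([1.0, 3.0])
print(end_of_test(metrics))  # {'mean_loss': 2.0}
metrics["mean_loss"].update([5.0])
print(end_of_test(metrics))  # {'mean_loss': 5.0}; without reset() it would be 3.0
```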

Return type:


class renate.evaluation.evaluator.ClassificationEvaluator(model, batch_size, transform=None, target_transform=None, logged_metrics=None)

Bases: Evaluator

A classification Evaluator module for collecting quantitative metrics on the test dataset.

forward(x, task_id=None)

Forward pass of the model.

The task ID can be used to select, for example, the output head used to evaluate data with a specific chunk ID. Here, task_id is used only to compute the test metrics.
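For classification, a forward pass commonly returns one score per class, and accuracy compares the highest-scoring class with the label. The helpers below are a hypothetical plain-Python illustration of that metric; they are not the ClassificationEvaluator's code, whose exact outputs and default metrics are not stated here.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    # Index of the largest score; ties resolve to the first maximum.
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(logits: List[Sequence[float]], targets: List[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target label."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    correct = sum(argmax(row) == y for row, y in zip(logits, targets))
    return correct / len(targets)


logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9]]
print(accuracy(logits, [1, 0, 1]))  # 2 of 3 predictions match the labels
```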

Return type:


renate.evaluation.evaluator.evaluate(model, test_dataset, test_collate_fn=None, task_id='default_task', batch_size=32, transform=None, target_transform=None, logged_metrics=None, logger=<pytorch_lightning.loggers.tensorboard.TensorBoardLogger object>, accelerator='auto', devices=None, strategy='ddp', precision='32')

Evaluate the model on the test dataset or a set of test datasets corresponding to distinct tasks.

If test_dataset is given as a list of datasets, the list is assumed to be ordered. Likewise, if task_id is given as a list, it is assumed to be ordered. A list of task IDs can be used to select a specific part of the model, for example an output head, for each test dataset in the input sequence.
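The pairing of datasets with task IDs and the shape of the returned Dict[str, List[float]] can be sketched as follows. evaluate_sketch and score_fn are hypothetical; reusing a single task_id string for every dataset, and representing each individual dataset as a tuple (so that a list unambiguously means "several datasets"), are assumptions of this sketch rather than documented behavior.

```python
from typing import Callable, Dict, List, Tuple, Union


def evaluate_sketch(
    score_fn: Callable[[Tuple, str], Dict[str, float]],
    test_dataset: Union[Tuple, List[Tuple]],
    task_id: Union[str, List[str]] = "default_task",
) -> Dict[str, List[float]]:
    """Hypothetical sketch: evaluate each dataset in order, one value per metric and dataset."""
    # A single dataset is wrapped in a list so both cases share one code path.
    datasets = test_dataset if isinstance(test_dataset, list) else [test_dataset]
    # Assumption: a single task ID is reused; a list must align with the datasets.
    task_ids = task_id if isinstance(task_id, list) else [task_id] * len(datasets)
    if len(task_ids) != len(datasets):
        raise ValueError("need one task ID per test dataset")
    results: Dict[str, List[float]] = {}
    for dataset, tid in zip(datasets, task_ids):
        for name, value in score_fn(dataset, tid).items():
            results.setdefault(name, []).append(value)
    return results


def fraction_of_ones(dataset: Tuple, tid: str) -> Dict[str, float]:
    # Hypothetical scorer: fraction of (input, label) pairs whose label is 1.
    return {"accuracy": sum(y for _, y in dataset) / len(dataset)}


d1 = ((0, 1), (1, 1))
d2 = ((0, 0), (1, 1), (2, 0), (3, 1))
print(evaluate_sketch(fraction_of_ones, [d1, d2], ["task_a", "task_b"]))
# {'accuracy': [1.0, 0.5]}
```

Because each metric maps to a list ordered like the input datasets, the i-th entry of every list refers to the i-th test dataset.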

Parameters:

  • model (RenateModule) – A RenateModule to be evaluated.

  • test_dataset (Union[List[Dataset], Dataset]) – The test dataset(s) to be evaluated.

  • test_collate_fn (Optional[Callable]) – collate_fn used in the DataLoader.

  • task_id (Union[List[str], str]) – The task id(s) of the test dataset(s).

  • batch_size (int) – The batch size to be used when creating the test data loader.

  • transform (Optional[Callable]) – The transformation applied for evaluation.

  • target_transform (Optional[Callable]) – The target transformation applied for evaluation.

  • logged_metrics (Optional[Dict[str, Metric]]) – Metrics logged in addition to the default ones.

  • logger (Logger) – Logger used by PyTorch Lightning to log intermediate results.

  • accelerator (Literal['auto', 'cpu', 'gpu', 'tpu']) – Accelerator used by PyTorch Lightning to evaluate the model.

  • devices (Optional[int]) – Devices used by PyTorch Lightning to evaluate the model. If devices is not set, it is treated as “auto” and the device count is obtained from the accelerator's auto_device_count.

  • strategy (str) – Name of the distributed strategy to use. See the PyTorch Lightning documentation for details.

  • precision (str) – Numerical (bit) precision to use. See the PyTorch Lightning documentation for details.

Return type:

Dict[str, List[float]]