# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
from typing import Callable, Dict, List, Optional, Union
import torch
import torchmetrics
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers.logger import Logger
from torch.utils.data import DataLoader, Dataset
from renate import defaults
from renate.data.datasets import _TransformedDataset
from renate.models import RenateModule
from renate.utils.distributed_strategies import create_strategy
from renate.utils.misc import int_or_str


class Evaluator(LightningModule, abc.ABC):
"""A general Evaluator module for collection of quantitative metrics on the test dataset.

    This is an abstract interface meant to be used with a PyTorch Lightning `Trainer` and its
    `.test()` function. It collects quantitative observations with respect to a single dataset.
    The metrics to be collected are provided via the `logged_metrics` argument.

Args:
model: A `RenateModule` to be evaluated.
batch_size: The batch size to be used when creating the test data loader.
transform: The transformation applied for evaluation.
target_transform: The target transformation applied for evaluation.
        logged_metrics: Metrics logged in addition to the default ones.
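
    A minimal sketch of the `logged_metrics` argument (the metric choice and `num_classes`
    below are illustrative, not prescribed by this class)::

        # Hypothetical values: any dict of torchmetrics metrics works here.
        logged_metrics = {
            "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)
        }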
"""
def __init__(
self,
model: RenateModule,
batch_size: int,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
) -> None:
super().__init__()
self._model = model
self._model.deregister_hooks()
self._batch_size = batch_size
self._transform = transform
self._target_transform = target_transform
        # `logged_metrics` may be None (the default); fall back to an empty collection.
        self._metric_collection = torchmetrics.MetricCollection(logged_metrics or {})
def on_model_test_start(
self,
test_dataset: Dataset,
test_collate_fn: Optional[Callable] = None,
task_id: Optional[str] = None,
) -> DataLoader:
"""Called before a model test starts."""
test_dataset = _TransformedDataset(
test_dataset,
transform=self._transform,
target_transform=self._target_transform,
)
self._task_id = task_id
return DataLoader(
test_dataset,
batch_size=self._batch_size,
shuffle=False,
pin_memory=True,
collate_fn=test_collate_fn,
)
def test_step(self, batch: List[torch.Tensor], batch_idx: int) -> None:
"""PyTorch Lightning function to perform the test step."""
x, y = batch
outputs = self(x)
self._metric_collection(outputs, y)
@abc.abstractmethod
def forward(self, x, task_id: Optional[str] = None) -> torch.Tensor:
"""Forward pass of the model.

        The task ID can be used, for example, to select the output head associated with a
        specific data chunk ID. Here, `task_id` is used only for computing the test metrics.
        """
pass
    def on_test_epoch_end(self) -> None:
        """PyTorch Lightning hook called at the end of the test loop.

        Logs the metrics and resets the metric collection.
        """
self.log_dict(self._metric_collection.compute(), on_step=False, on_epoch=True)
self._metric_collection.reset()


class ClassificationEvaluator(Evaluator):
    """An `Evaluator` for classification models that collects quantitative metrics on the test
    dataset.
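
    A minimal usage sketch (`model`, `test_data`, and the parameter values below are
    placeholders, not values prescribed by this module)::

        evaluator = ClassificationEvaluator(
            model=model,  # a RenateModule (placeholder)
            batch_size=32,
            logged_metrics={
                "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)
            },
        )
        test_loader = evaluator.on_model_test_start(test_data, task_id=defaults.TASK_ID)
        Trainer(enable_checkpointing=False, logger=False).test(evaluator, test_loader)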
"""
def forward(self, x, task_id: Optional[str] = None) -> torch.Tensor:
"""Forward pass of the model.

        The task ID can be used, for example, to select the output head associated with a
        specific data chunk ID. Here, `task_id` is used only for computing the test metrics.
        """
if task_id is None:
task_id = self._task_id
return self._model.get_logits(x, task_id=task_id)


def evaluate(
model: RenateModule,
test_dataset: Union[List[Dataset], Dataset],
test_collate_fn: Optional[Callable] = None,
task_id: Union[List[str], str] = defaults.TASK_ID,
batch_size: int = defaults.BATCH_SIZE,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
devices: Optional[int] = None,
strategy: str = defaults.DISTRIBUTED_STRATEGY,
precision: str = defaults.PRECISION,
) -> Dict[str, List[float]]:
    """Evaluate the model on a test dataset or on a list of test datasets corresponding to
    distinct tasks.

    If `test_dataset` is given as a list of datasets, it is assumed to be ordered. Similarly, if
    `task_id` is given as a list, it is assumed to be ordered and aligned with the datasets. The
    task ID list can be used to select the part of the model to be used with a specific test
    dataset in the input sequence, for example, an output head.

Args:
model: A `RenateModule` to be evaluated.
test_dataset: The test dataset(s) to be evaluated.
        test_collate_fn: The `collate_fn` used in the test `DataLoader`.
task_id: The task id(s) of the test dataset(s).
batch_size: The batch size to be used when creating the test data loader.
transform: The transformation applied for evaluation.
target_transform: The target transformation applied for evaluation.
        logged_metrics: Metrics logged in addition to the default ones.
logger: Logger used by PyTorch Lightning to log intermediate results.
        accelerator: Accelerator used by PyTorch Lightning to run the evaluation.
        devices: Devices used by PyTorch Lightning to run the evaluation. If the devices flag is
            not defined, it will assume devices to be "auto" and fetch the `auto_device_count`
            from the `accelerator`.
strategy: Name of the distributed training strategy to use.
`More details <https://lightning.ai/docs/pytorch/stable/extensions/strategy.html>`__
precision: Type of bit precision to use.
`More details <https://lightning.ai/docs/pytorch/stable/common/precision_basic.html>`__
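
    Returns:
        A dictionary mapping each logged metric name to a list of values, one entry per test
        dataset, in the order the datasets were passed.

    A minimal usage sketch (`model`, the datasets, and the metric configuration below are
    placeholders)::

        results = evaluate(
            model=model,
            test_dataset=[test_data_task1, test_data_task2],
            task_id=["task1", "task2"],
            batch_size=64,
            logged_metrics={
                "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10)
            },
        )
        # e.g. results["accuracy"] == [accuracy_on_task1, accuracy_on_task2]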
"""
if isinstance(test_dataset, Dataset):
test_dataset = [test_dataset]
if isinstance(task_id, str):
task_id = [task_id] * len(test_dataset)
assert len(task_id) == len(test_dataset)
evaluator = ClassificationEvaluator(
model=model,
batch_size=batch_size,
transform=transform,
target_transform=target_transform,
logged_metrics=logged_metrics,
)
trainer = Trainer(
accelerator=accelerator,
devices=devices,
logger=logger,
enable_checkpointing=False,
enable_progress_bar=False,
strategy=create_strategy(devices, strategy),
precision=int_or_str(precision),
)
results = {}
for i in range(len(test_dataset)):
test_loader = evaluator.on_model_test_start(test_dataset[i], test_collate_fn, task_id[i])
trainer.test(
evaluator,
test_loader,
)
for metric_name, value in trainer.logged_metrics.items():
if metric_name not in results:
results[metric_name] = []
results[metric_name].append(value.item())
return results