# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset

from renate import defaults
from renate.memory import InfiniteBuffer
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.learner import Learner
from renate.updaters.model_updater import SingleTrainingLoopUpdater


class JointLearner(Learner):
    """A Learner that implements the Joint strategy.

    This is a simple strategy that trains the model on all previously observed data.
    Each time a new chunk of data is observed, the model is reinitialized and retrained on
    all previously observed data together with the new chunk. The buffer holding the
    previously observed data is updated with the new chunk.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
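        # Buffer that keeps all data observed so far; it is extended with each new chunk
        # in `on_model_update_start`.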
        self._memory_buffer = InfiniteBuffer(
            transform=self._train_transform,
            target_transform=self._train_target_transform,
        )

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
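        """Adds the memory buffer state to the Lightning checkpoint."""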
        super().on_save_checkpoint(checkpoint=checkpoint)
        checkpoint["memory_buffer"] = self._memory_buffer.state_dict()

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
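        """Restores the memory buffer state from the Lightning checkpoint."""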
        super().on_load_checkpoint(checkpoint)
        self._memory_buffer.load_state_dict(checkpoint["memory_buffer"])

    def save(self, output_state_dir: str) -> None:
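        """Saves the learner state, including the memory buffer, to ``output_state_dir``."""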
        super().save(output_state_dir)
        buffer_dir = os.path.join(output_state_dir, "memory_buffer")
        os.makedirs(buffer_dir, exist_ok=True)
        self._memory_buffer.save(buffer_dir)

    def load(self, input_state_dir: str) -> None:
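        """Loads the learner state, including the memory buffer, from ``input_state_dir``."""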
        super().load(input_state_dir)
        self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer"))

    def on_model_update_start(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> None:
        """Called before a model update starts."""
        super().on_model_update_start(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            train_dataset_collate_fn=train_dataset_collate_fn,
            val_dataset_collate_fn=val_dataset_collate_fn,
            task_id=task_id,
        )
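        # Add the new chunk of data to the buffer so training covers all data seen so far.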
        self._memory_buffer.update(train_dataset)

    def train_dataloader(self) -> DataLoader:
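        """Returns a dataloader over the memory buffer, i.e., over all data observed so far."""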
        return DataLoader(
            self._memory_buffer,
            batch_size=self._batch_size,
            shuffle=True,
            generator=self._rng,
            pin_memory=True,
            collate_fn=self._train_collate_fn,
        )

    def training_step(
        self,
        batch: Tuple[Tuple[NestedTensors, torch.Tensor], Dict[str, torch.Tensor]],
        batch_idx: int,
    ) -> STEP_OUTPUT:
        """PyTorch Lightning function to return the training loss."""
        batch, _ = batch
        return super().training_step(batch=batch, batch_idx=batch_idx)


class JointModelUpdater(SingleTrainingLoopUpdater):
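    """Updater for the Joint strategy, based on the :py:class:`JointLearner`."""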

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        learning_rate_scheduler: Optional[partial] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        batch_size: int = defaults.BATCH_SIZE,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: str = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        seed: int = defaults.SEED,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
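        # Constructor arguments forwarded to the JointLearner; all other arguments are
        # passed on to the base updater.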
        learner_kwargs = {
            "batch_size": batch_size,
            "seed": seed,
        }
        super().__init__(
            model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            learner_class=JointLearner,
            learner_kwargs=learner_kwargs,
            input_state_folder=input_state_folder,
            output_state_folder=output_state_folder,
            max_epochs=max_epochs,
            learning_rate_scheduler=learning_rate_scheduler,
            learning_rate_scheduler_interval=learning_rate_scheduler_interval,
            train_transform=train_transform,
            train_target_transform=train_target_transform,
            test_transform=test_transform,
            test_target_transform=test_target_transform,
            metric=metric,
            mode=mode,
            logged_metrics=logged_metrics,
            early_stopping_enabled=early_stopping_enabled,
            logger=logger,
            accelerator=accelerator,
            devices=devices,
            strategy=strategy,
            precision=precision,
            deterministic_trainer=deterministic_trainer,
            gradient_clip_algorithm=gradient_clip_algorithm,
            gradient_clip_val=gradient_clip_val,
            mask_unused_classes=mask_unused_classes,
        )