Source code for renate.updaters.model_updater

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import abc
import logging
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type

import torch
import torchmetrics
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from syne_tune import Reporter
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import Dataset

from renate import defaults
from renate.utils.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
from renate.utils.distributed_strategies import create_strategy
from renate.utils.file import unlink_file_or_folder
from renate.utils.misc import AdditionalTrainingMetrics, int_or_str
from .learner import Learner, ReplayLearner
from ..models import RenateModule


logging_logger = logging.getLogger(__name__)


class SyneTuneCallback(Callback):
    """Callback to report metrics to Syne Tune.

    Args:
        val_enabled: Whether validation was enabled in the Learner.
    """

    def __init__(self, val_enabled: bool):
        super().__init__()
        self._report = Reporter()
        self._val_enabled = val_enabled
        self._additional_metrics = AdditionalTrainingMetrics()

    @rank_zero_only
    def _log(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Report the current epoch's results to Syne Tune.

        If validation was run (`_val_enabled` is True), the results are reported at the end
        of the validation epoch. Otherwise, they are reported at the end of the training
        epoch.
        """
        training = pl_module.training
        if trainer.sanity_checking or (training and self._val_enabled):
            return
        to_report = {k: v.item() for k, v in trainer.logged_metrics.items()}
        to_report.update(self._additional_metrics(pl_module))
        self._report(
            **to_report,
            step=trainer.current_epoch,
            epoch=trainer.current_epoch + 1,
        )

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._additional_metrics.on_train_epoch_end()
        self._log(trainer=trainer, pl_module=pl_module)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._log(trainer=trainer, pl_module=pl_module)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._additional_metrics.on_train_start()
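
# --- Usage sketch (illustrative, not part of the module). --------------------------------
# `SyneTuneCallback` is a standard PyTorch Lightning callback and can be attached to any
# `Trainer` directly. `my_learner` below is a hypothetical `Learner` instance; `val_enabled`
# should mirror whether the learner was given a validation set.
#
#     from pytorch_lightning import Trainer
#
#     trainer = Trainer(max_epochs=5, callbacks=[SyneTuneCallback(val_enabled=False)])
#     trainer.fit(my_learner)
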
class RenateModelCheckpoint(ModelCheckpoint):
    """Callback to save the Renate state after each epoch.

    Args:
        model: Model to be saved when creating a checkpoint.
        output_state_folder: Checkpoint folder location.
        val_enabled: Whether validation was enabled in the Learner. Forwarded to
            `SyneTuneCallback`.
        metric: Monitored metric to decide when to write a new checkpoint. If no metric is
            provided or validation is not enabled, the latest model will be stored.
        mode: `min` or `max`. Whether to minimize or maximize the monitored `metric`.
        use_syne_tune_callback: Whether to use `SyneTuneCallback`.
    """

    def __init__(
        self,
        model: RenateModule,
        output_state_folder: str,
        val_enabled: bool,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        use_syne_tune_callback: bool = True,
    ) -> None:
        every_n_epochs = 1
        save_last = False
        if metric is None or not val_enabled:
            every_n_epochs = 0
            save_last = True
        learner_checkpoint_filename = Path(defaults.learner_state_file("")).stem
        super().__init__(
            dirpath=output_state_folder,
            filename=learner_checkpoint_filename,
            every_n_epochs=every_n_epochs,
            monitor=metric,
            mode=mode,
            save_last=save_last,
            save_weights_only=True,
        )
        self._model = model
        self._output_state_folder = output_state_folder
        self.CHECKPOINT_NAME_LAST = learner_checkpoint_filename
        # Delete the old checkpoint if it exists.
        unlink_file_or_folder(Path(defaults.learner_state_file(self._output_state_folder)))
        # FIXME: Hack to make sure Syne Tune is called after checkpointing.
        # Details: https://github.com/Lightning-AI/lightning/issues/15026
        # If fixed, remove on_train_epoch_end, on_validation_epoch_end and val_enabled, remove
        # the lines below, and add the SyneTuneCallback in ModelUpdater instead.
        if use_syne_tune_callback:
            self._syne_tune_callback = SyneTuneCallback(val_enabled)
        else:
            self._syne_tune_callback = None
    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_train_epoch_end(trainer=trainer, pl_module=pl_module)
        if self._syne_tune_callback is not None:
            self._syne_tune_callback.on_train_epoch_end(trainer=trainer, pl_module=pl_module)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_validation_epoch_end(trainer=trainer, pl_module=pl_module)
        if self._syne_tune_callback is not None:
            self._syne_tune_callback.on_validation_epoch_end(
                trainer=trainer, pl_module=pl_module
            )
    def _load_best_checkpoint_and_save(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        # Reload the best state.
        learner_state_path = Path(defaults.learner_state_file(self._output_state_folder))
        if learner_state_path.exists():
            # There are three steps that Lightning would handle for us if we used the
            # load_from_checkpoint mechanism. Here we do them manually. See for reference
            # https://github.com/Lightning-AI/lightning/blob/1.8.6/src/pytorch_lightning/core/saving.py#L225
            # 1. Load the state_dict from the checkpoint file.
            # 2. Call on_load_checkpoint (a hook on the module).
            # 3. Load the state_dict into the model. Note the strategy.load_model_state_dict
            #    call.
            loaded_state = trainer.strategy.load_checkpoint(learner_state_path)
            pl_module.on_load_checkpoint(loaded_state)
            trainer.strategy.load_model_state_dict(loaded_state)
        # Finalize the model update.
        pl_module.on_model_update_end()
        # Save permanently.
        if trainer.is_global_zero:
            # Save the buffer only on rank zero.
            pl_module.save(self._output_state_folder)
        # Overwrite the checkpoint.
        self._save_checkpoint(trainer, str(learner_state_path))
    def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Implements the separation of learner and model at the end of training.

        There are two cases to handle.

        1. If deepspeed is being used: The learner_state_path (which the checkpointing
           function uses) is a directory, not a file. This directory contains sharded
           state_dicts (of model and optimizers), depending on which deepspeed stage is used.
           There are four steps here:

           a. Combine all the shards into one big state_dict.
           b. The learner_state_path is a directory (learner.ckpt/). This needs to be deleted
              first.
           c. Write the combined state_dict to learner.ckpt as a single file.
           d. Extract the state_dict element from the learner and save that as model.ckpt.

        2. If not deepspeed (say DDP or single device): The steps are much simpler.

           a. Load learner.ckpt and extract the state_dict element.
           b. Sanitize the extracted state_dict: the learner holds the model in a _model
              attribute, so strip the first "_model." from the keys of the state_dict.
           c. Save the sanitized model state_dict to model.ckpt.

        Case 2 needs to be done even for Case 1 (step d). So in Case 1, teardown calls itself
        recursively and automatically proceeds to Case 2, since learner.ckpt is a file by
        then.
        """
        if trainer.is_global_zero and (stage == "fit"):
            learner_state_path = Path(defaults.learner_state_file(self._output_state_folder))
            if learner_state_path.exists() and learner_state_path.is_dir():
                # Deepspeed zero saves everything as folders.
                combined_state_dict = convert_zero_checkpoint_to_fp32_state_dict(
                    learner_state_path
                )
                unlink_file_or_folder(learner_state_path)
                torch.save(combined_state_dict, learner_state_path)
                self.teardown(trainer, pl_module, stage)
            elif learner_state_path.exists() and learner_state_path.is_file():
                # This is a normal file. We strip the model of any wrappers and save that.
                learner_state = torch.load(learner_state_path)
                out_sd = {
                    k.replace("_model.", "", 1): v
                    for k, v in learner_state["state_dict"].items()
                }  # Replace only one instance because we have to load it into a RenateModule.
                torch.save(out_sd, defaults.model_file(self.dirpath))
                # Remove the model from the learner checkpoint.
                learner_state["state_dict"] = {}
                torch.save(learner_state, learner_state_path)
    def on_exception(
        self, trainer: Trainer, pl_module: LightningModule, exception: BaseException
    ) -> None:
        super().on_exception(trainer, pl_module, exception)
        self._load_best_checkpoint_and_save(trainer, pl_module)

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_fit_end(trainer, pl_module)
        self._load_best_checkpoint_and_save(trainer, pl_module)
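
# --- Illustration (not part of the module): the state_dict sanitization in `teardown`. ---
# The learner stores the RenateModule in a `_model` attribute, so a learner checkpoint
# contains keys like "_model.fc.weight". Stripping only the first occurrence of the prefix
# yields keys that load directly into the RenateModule. The key names below are
# hypothetical:
#
#     learner_keys = ["_model.fc.weight", "_model.fc.bias"]
#     model_keys = [k.replace("_model.", "", 1) for k in learner_keys]
#     assert model_keys == ["fc.weight", "fc.bias"]
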
class ModelUpdater(abc.ABC):
    """Updates a learner using the data provided.

    Args:
        model: The potentially pretrained model to be updated with new data.
        learner_class: Class of the learner to be used for the model update.
        learner_kwargs: Arguments either used for creating a new learner (no previous state
            available) or to replace the current arguments of the learner.
        input_state_folder: Folder used by Renate to store files for the current state.
        output_state_folder: Folder used by Renate to store files for the next state.
        max_epochs: The maximum number of epochs used to train the model. For comparability
            between methods, epochs are interpreted as "finetuning-equivalent". That is, one
            epoch is defined as `len(current_task_dataset) / batch_size` update steps.
        train_transform: The transformation applied during training.
        train_target_transform: The target transformation applied during training.
        test_transform: The transformation applied at test time.
        test_target_transform: The target transformation applied at test time.
        buffer_transform: Augmentations applied to the input data coming from the memory.
            Not all updaters require this. If required but not passed, `train_transform`
            will be used.
        buffer_target_transform: Transformations applied to the target. Not all updaters
            require this. If required but not passed, `train_target_transform` will be used.
        metric: Monitored metric to decide when to write a new checkpoint or early-stop the
            optimization. If no metric is provided, the latest model will be stored.
        mode: `min` or `max`. Whether to minimize or maximize the monitored `metric`.
        logged_metrics: Metrics logged in addition to the default ones.
        early_stopping_enabled: Enables early stopping of the optimization.
        logger: Logger used by PyTorch Lightning to log intermediate results.
        accelerator: Accelerator used by PyTorch Lightning to train the model.
        devices: Devices used by PyTorch Lightning to train the model. If the devices flag
            is not defined, it will assume devices to be "auto" and fetch the
            `auto_device_count` from the `accelerator`.
        deterministic_trainer: When set to True, makes the output of the training
            deterministic. The value is passed to the trainer as described
            `here <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#reproducibility>`_.
        gradient_clip_val: Gradient clipping value used in PyTorch Lightning. Defaults to
            not clipping by using a value of None.
        gradient_clip_algorithm: Method to clip gradients (norm or value) used in PyTorch
            Lightning.
    """

    def __init__(
        self,
        model: RenateModule,
        loss_fn: torch.nn.Module,
        optimizer: Callable[[List[Parameter]], Optimizer],
        learner_class: Type[Learner],
        learner_kwargs: Optional[Dict[str, Any]] = None,
        input_state_folder: Optional[str] = None,
        output_state_folder: Optional[str] = None,
        max_epochs: int = defaults.MAX_EPOCHS,
        learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
        learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL,  # noqa: E501
        train_transform: Optional[Callable] = None,
        train_target_transform: Optional[Callable] = None,
        test_transform: Optional[Callable] = None,
        test_target_transform: Optional[Callable] = None,
        buffer_transform: Optional[Callable] = None,
        buffer_target_transform: Optional[Callable] = None,
        metric: Optional[str] = None,
        mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min",
        logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None,
        early_stopping_enabled: bool = False,
        logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS),
        accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR,
        devices: Optional[int] = None,
        strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY,
        precision: str = defaults.PRECISION,
        deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
        gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
        mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
    ):
        self._learner_kwargs = learner_kwargs or {}
        self._learner_kwargs["loss_fn"] = loss_fn
        self._learner_kwargs["optimizer"] = optimizer
        self._learner_kwargs["mask_unused_classes"] = mask_unused_classes
        if learning_rate_scheduler is not None:
            self._learner_kwargs["learning_rate_scheduler"] = learning_rate_scheduler
            self._learner_kwargs["learning_rate_scheduler_interval"] = (
                learning_rate_scheduler_interval
            )
        self._model = model
        self._learner_state_file: Optional[str] = None
        if input_state_folder is not None:
            self._learner_state_file = defaults.learner_state_file(input_state_folder)
        else:
            logging_logger.info(
                "No location for current updater state provided. Updating will start from "
                "scratch."
            )
        if output_state_folder is None:
            logging_logger.info(
                "No location for next updater state provided. No state will be stored."
            )
        elif metric is None:
            logging_logger.info(
                "Metric or mode is not provided. Checkpoint is saved only after training."
            )
        if metric is None and early_stopping_enabled:
            warnings.warn(
                "Early stopping is enabled but no metric is specified. Early stopping will "
                "be ignored."
            )
            early_stopping_enabled = False
        self._input_state_folder = input_state_folder
        self._output_state_folder = output_state_folder
        self._metric = metric
        self._mode = mode
        self._logged_metrics = logged_metrics
        self._early_stopping_enabled = early_stopping_enabled
        self._train_transform = train_transform
        self._train_target_transform = train_target_transform
        self._test_transform = test_transform
        self._test_target_transform = test_target_transform
        self._buffer_transform = buffer_transform or train_transform
        self._buffer_target_transform = buffer_target_transform or train_target_transform
        self._transforms_kwargs = {
            "train_transform": self._train_transform,
            "train_target_transform": self._train_target_transform,
            "test_transform": self._test_transform,
            "test_target_transform": self._test_target_transform,
        }
        if issubclass(learner_class, ReplayLearner):
            self._transforms_kwargs["buffer_transform"] = self._buffer_transform
            self._transforms_kwargs["buffer_target_transform"] = self._buffer_target_transform
        self._max_epochs = max_epochs
        if accelerator not in defaults.SUPPORTED_ACCELERATORS:
            raise ValueError(
                f"Accelerator {accelerator} not supported. "
                f"Supported accelerators are {defaults.SUPPORTED_ACCELERATORS}."
            )
        self._accelerator = accelerator
        self._devices = devices
        self._strategy = strategy
        self._precision = int_or_str(precision)
        self._learner = self._load_learner(learner_class, self._learner_kwargs)
        assert self._learner.is_logged_metric(metric), f"Target metric `{metric}` is not logged."
        self._logger = logger
        self._num_epochs_trained = 0
        self._deterministic_trainer = deterministic_trainer
        self._gradient_clip_algorithm = gradient_clip_algorithm
        self._gradient_clip_val = gradient_clip_val
    @abc.abstractmethod
    def update(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> None:
        """Updates the model using the data passed as input.

        Args:
            train_dataset: The training data.
            val_dataset: The validation data.
            train_dataset_collate_fn: collate_fn used to merge a list of samples to form a
                mini-batch of Tensors for the training data.
            val_dataset_collate_fn: collate_fn used to merge a list of samples to form a
                mini-batch of Tensors for the validation data.
            task_id: The task id.
        """
    def _load_learner(
        self,
        learner_class: Type[Learner],
        learner_kwargs: Dict[str, Any],
    ) -> Learner:
        if self._learner_state_file is None or not Path(self._learner_state_file).is_file():
            logging_logger.warning("No updater state available. Updating from scratch.")
            return learner_class(
                model=self._model,
                **learner_kwargs,
                logged_metrics=self._logged_metrics,
                **self._transforms_kwargs,
            )
        learner = learner_class.load_from_checkpoint(
            self._learner_state_file,
            model=self._model,
            logged_metrics=self._logged_metrics,
            strict=False,
            **self._transforms_kwargs,
            **learner_kwargs,
        )
        learner.load(self._input_state_folder)
        return learner

    def _fit_learner(
        self,
        learner: Learner,
        use_syne_tune_callback: bool = True,
    ) -> None:
        callbacks: List[Callback] = []
        if use_syne_tune_callback:
            callbacks.append(SyneTuneCallback(learner.val_enabled))
        if self._output_state_folder is not None:
            model_checkpoint_callback = RenateModelCheckpoint(
                model=self._model,
                output_state_folder=self._output_state_folder,
                metric=self._metric,
                mode=self._mode,
                val_enabled=learner.val_enabled,
                use_syne_tune_callback=use_syne_tune_callback,
            )
            callbacks = [model_checkpoint_callback]  # FIXME: insert at 0 as soon as PTL is fixed.
        if self._early_stopping_enabled:
            if learner.val_enabled:
                callbacks.insert(0, EarlyStopping(monitor=self._metric, mode=self._mode))
            else:
                warnings.warn(
                    "Early stopping is currently not supported without a validation set. It "
                    "will be ignored."
                )
        strategy = create_strategy(self._devices, self._strategy)
        # Finetuning-equivalent epochs: limit each epoch to ceil(len(dataset) / batch_size)
        # update steps.
        num_batches = len(learner._train_dataset) // learner._batch_size
        num_batches += min(len(learner._train_dataset) % learner._batch_size, 1)
        trainer = Trainer(
            accelerator=self._accelerator,
            devices=self._devices,
            max_epochs=self._max_epochs,
            limit_train_batches=num_batches,
            callbacks=callbacks,
            logger=self._logger,
            enable_progress_bar=False,
            deterministic=self._deterministic_trainer,
            strategy=strategy,
            precision=self._precision,
            gradient_clip_val=self._gradient_clip_val,
            gradient_clip_algorithm=self._gradient_clip_algorithm,
        )
        trainer.fit(learner)
        self._num_epochs_trained = trainer.current_epoch
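
# --- Illustration (not part of the module): "finetuning-equivalent" epoch length. --------
# `_fit_learner` limits each epoch to ceil(len(train_dataset) / batch_size) update steps
# via `limit_train_batches`, even if the learner's effective training set (e.g. including
# a memory buffer) is larger. With hypothetical numbers:
#
#     dataset_size, batch_size = 1000, 32
#     num_batches = dataset_size // batch_size + min(dataset_size % batch_size, 1)
#     assert num_batches == 32  # 31 full batches plus one partial batch
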
class SingleTrainingLoopUpdater(ModelUpdater):
    """Simple ModelUpdater which requires only a single learner to update the model."""
    def update(
        self,
        train_dataset: Dataset,
        val_dataset: Optional[Dataset] = None,
        train_dataset_collate_fn: Optional[Callable] = None,
        val_dataset_collate_fn: Optional[Callable] = None,
        task_id: Optional[str] = None,
    ) -> RenateModule:
        """Updates the model using the data passed as input.

        Args:
            train_dataset: The training data.
            val_dataset: The validation data.
            train_dataset_collate_fn: collate_fn used to merge a list of samples to form a
                mini-batch of Tensors for the training data.
            val_dataset_collate_fn: collate_fn used to merge a list of samples to form a
                mini-batch of Tensors for the validation data.
            task_id: The task id.
        """
        self._learner.on_model_update_start(
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            train_dataset_collate_fn=train_dataset_collate_fn,
            val_dataset_collate_fn=val_dataset_collate_fn,
            task_id=task_id,
        )
        self._fit_learner(self._learner)
        return self._model
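
# --- Usage sketch (illustrative, not part of the module). --------------------------------
# A hypothetical end-to-end update with SingleTrainingLoopUpdater. `my_model`, `train_data`,
# and `val_data` are assumed to exist; `MyLearner` and its `learner_kwargs` are placeholders
# for any concrete `Learner` subclass and its arguments.
#
#     import torch
#
#     updater = SingleTrainingLoopUpdater(
#         model=my_model,
#         loss_fn=torch.nn.CrossEntropyLoss(),
#         optimizer=lambda params: torch.optim.SGD(params, lr=0.01),
#         learner_class=MyLearner,  # hypothetical Learner subclass
#         learner_kwargs={"batch_size": 32},
#         max_epochs=10,
#         output_state_folder="./renate_state",
#     )
#     model = updater.update(train_dataset=train_data, val_dataset=val_data, task_id="task_1")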