Source code for renate.benchmark.datasets.base
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from abc import ABC
from pathlib import Path
from typing import Optional, Union
from renate import defaults
from renate.data.data_module import RenateDataModule
[docs]
class DataIncrementalDataModule(RenateDataModule, ABC):
    """Base class for :py:class:`~renate.data.data_module.RenateDataModule` implementations that
    are compatible with :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario`.

    This class defines the API that
    :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario` relies on. Every subclass
    must load the dataset matching the current value of ``data_id`` each time ``setup()`` is
    called.

    Args:
        data_path: Path to the folder holding the dataset files.
        data_id: Identifier of the slice of data (e.g., a time slice) to be loaded.
        src_bucket: Name of the S3 bucket. If not provided, the data is downloaded from its
            original source.
        src_object_name: Folder path inside the S3 bucket.
        val_size: Fraction of the training data reserved for validation.
        seed: Seed used to fix random number generation.
    """

    def __init__(
        self,
        data_path: Union[Path, str],
        data_id: Union[int, str],
        src_bucket: Optional[str] = None,
        src_object_name: Optional[str] = None,
        val_size: float = defaults.VALIDATION_SIZE,
        seed: int = defaults.SEED,
    ) -> None:
        # Every constructor argument except ``data_id`` is forwarded, by keyword, to the
        # generic RenateDataModule initializer.
        parent_kwargs = dict(
            data_path=data_path,
            src_bucket=src_bucket,
            src_object_name=src_object_name,
            val_size=val_size,
            seed=seed,
        )
        super().__init__(**parent_kwargs)
        # Stored after the parent is initialized; subclasses use it in ``setup()`` to decide
        # which slice of the data to load.
        self.data_id = data_id