# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import math
import os
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from warnings import warn

import torch

from renate.types import NestedTensors


def mmap_tensor(
    filename: str, size: Union[int, Tuple[int, ...]], dtype: torch.dtype
) -> torch.Tensor:
    """Creates or accesses a tensor memory-mapped to a file on disk.

    If the file does not exist, it is created. Since the mapping is shared, any write to
    the returned tensor is reflected in the backing file.

    Args:
        filename: Path of the backing file.
        size: Shape of the tensor, given as an int or a tuple of ints.
        dtype: Data type of the tensor.
    """
    t = torch.from_file(
        filename,
        shared=True,
        size=math.prod(size) if isinstance(size, tuple) else size,
        dtype=dtype,
        device="cpu",
    )
    return t.view(size)
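
# A minimal usage sketch for ``mmap_tensor`` (illustrative only; the path below is a
# hypothetical example, not part of this module):
#
#     buf = mmap_tensor("/tmp/buffer.pt", size=(4, 3), dtype=torch.float32)
#     buf[0] = torch.ones(3)  # written through to /tmp/buffer.pt
#     same = mmap_tensor("/tmp/buffer.pt", size=(4, 3), dtype=torch.float32)
#     assert torch.equal(same[0], torch.ones(3))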


class Storage(torch.utils.data.Dataset):
    """An abstract class for permanent storage of datasets."""

    def __init__(self, directory: str) -> None:
        self._directory = directory

    def __len__(self) -> int:
        # Subclasses are expected to set `_length`, e.g., in `dump_dataset` or `load_dataset`.
        return self._length

    def __getitem__(self, idx: int) -> Any:
        raise NotImplementedError()

    def dump_dataset(self, ds: torch.utils.data.Dataset) -> None:
        raise NotImplementedError()

    def load_dataset(self, directory: Union[str, Path]) -> None:
        raise NotImplementedError()
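
# A minimal sketch of a concrete subclass (hypothetical, for illustration only):
#
#     class InMemoryStorage(Storage):
#         def dump_dataset(self, ds):
#             self._items = [ds[i] for i in range(len(ds))]
#             self._length = len(ds)
#
#         def __getitem__(self, idx):
#             return self._items[idx]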


class MemoryMappedTensorStorage(Storage):
    """A class implementing permanent storage of nested tensor datasets.

    This implements storage for a fixed number of data points consisting of nested tensors of
    fixed types and shapes. `Storage` implements `__len__` and `__getitem__` and therefore can
    be used as a torch `Dataset`. To populate the storage, it also implements `dump_dataset`.
    It does _not_ keep track of which slots have or have not been populated.

    `Storage` is given a path to a directory, where it creates (or accesses, if they already
    exist) memory-mapped tensor files. Shapes, dtypes, and the number of items are inferred
    from the dataset passed to `dump_dataset`.

    Args:
        directory: Path to a directory.
    """

    def __init__(self, directory: str) -> None:
        warn(
            f"{self.__class__.__name__} will be deprecated very soon. Use FileTensorStorage "
            f"instead. {self.__class__.__name__} is currently not fully functional, as some "
            "of the necessary parts of the interface have been modified and simplified.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(directory)
        self._storage: Optional[NestedTensors] = None

    @staticmethod
    def _create_mmap_tensors(path: str, data_point: NestedTensors, length: int) -> NestedTensors:
        """Recursively creates memory-mapped tensors mirroring the structure of `data_point`."""
        if isinstance(data_point, torch.Tensor):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            filename = f"{path}.pt"
            return mmap_tensor(filename, size=(length, *data_point.size()), dtype=data_point.dtype)
        elif isinstance(data_point, tuple):
            return tuple(
                MemoryMappedTensorStorage._create_mmap_tensors(
                    os.path.join(path, f"{i}.pt"), data_point[i], length
                )
                for i in range(len(data_point))
            )
        elif isinstance(data_point, dict):
            return {
                key: MemoryMappedTensorStorage._create_mmap_tensors(
                    os.path.join(path, f"{key}.pt"), data_point[key], length
                )
                for key in data_point
            }
        else:
            raise TypeError(f"Expected nested tuple/dict of tensors, found {type(data_point)}.")
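
    # For illustration (hypothetical directory "buf"): a data point of the form
    # ``(x, {"y": y})`` with ``length=100`` is laid out on disk as
    #
    #     buf/0.pt.pt        <- tensor of shape (100, *x.shape)
    #     buf/1.pt/y.pt.pt   <- tensor of shape (100, *y.shape)
    #
    # (the doubled ``.pt`` suffix comes from appending ".pt" both when joining
    # nested paths and when naming the leaf file).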

    @staticmethod
    def _get(storage: NestedTensors, idx: int) -> NestedTensors:
        if isinstance(storage, torch.Tensor):
            return storage[idx]
        elif isinstance(storage, tuple):
            return tuple(MemoryMappedTensorStorage._get(t, idx) for t in storage)
        elif isinstance(storage, dict):
            return {key: MemoryMappedTensorStorage._get(t, idx) for key, t in storage.items()}
        else:
            raise TypeError(f"Expected nested tuple/dict of tensors, found {type(storage)}.")

    def __getitem__(self, idx: int) -> NestedTensors:
        """Read the item stored at index `idx`."""
        return self._get(self._storage, idx)

    @staticmethod
    def _set(storage: NestedTensors, idx: int, data_point: NestedTensors) -> None:
        if isinstance(storage, torch.Tensor):
            assert isinstance(data_point, torch.Tensor)
            assert data_point.dtype is storage.dtype
            storage[idx] = data_point
        elif isinstance(storage, tuple):
            assert isinstance(data_point, tuple)
            assert len(data_point) == len(storage)
            for i in range(len(storage)):
                MemoryMappedTensorStorage._set(storage[i], idx, data_point[i])
        elif isinstance(storage, dict):
            assert isinstance(data_point, dict)
            assert set(data_point.keys()) == set(storage.keys())
            for key in storage:
                MemoryMappedTensorStorage._set(storage[key], idx, data_point[key])
        else:
            raise TypeError(f"Expected nested tuple/dict of tensors, found {type(storage)}.")

    def dump_dataset(self, ds: torch.utils.data.Dataset) -> None:
        """Writes all items of `ds` to the memory-mapped files, inferring shapes from `ds[0]`."""
        self._length = len(ds)
        self._storage = self._create_mmap_tensors(self._directory, ds[0], self._length)
        for idx in range(len(self)):
            self._set(self._storage, idx, ds[idx])
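
# A minimal usage sketch (illustrative; the path is a hypothetical example):
#
#     ds = torch.utils.data.TensorDataset(torch.randn(100, 3))
#     storage = MemoryMappedTensorStorage("/tmp/mmap_storage")  # emits a DeprecationWarning
#     storage.dump_dataset(ds)
#     x, = storage[0]  # equal to ds[0]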


class FileTensorStorage(Storage):
    """A class implementing permanent storage of nested tensor datasets to disk as pickle files.

    This implements storage for a fixed number of data points consisting of nested tensors of
    fixed types and shapes. `Storage` implements `__len__` and `__getitem__` and therefore can
    be used as a torch `Dataset`. To populate the storage, it also implements `dump_dataset`.
    It does _not_ keep track of which slots have or have not been populated.

    `Storage` is given a path to a directory, where it creates (or accesses, if they already
    exist) pickle files, one for each data point in the dataset.

    Args:
        directory: Path to a directory.
    """

    def __init__(self, directory: str) -> None:
        super().__init__(directory)

    def dump_dataset(self, ds: torch.utils.data.Dataset) -> None:
        self._length = len(ds)
        for i in range(len(ds)):
            torch.save(ds[i], self._compose_file_path_from_index(i))

    def __getitem__(self, idx: int) -> Any:
        if not hasattr(self, "_length"):
            # Lazily recover the dataset size when accessing pre-existing storage.
            self.load_dataset(self._directory)
        return torch.load(self._compose_file_path_from_index(idx))

    def load_dataset(self, directory: Union[str, Path]) -> None:
        # The dataset size is recovered by counting the ".pt" files in the storage
        # directory; the `directory` argument is currently unused.
        self._length = len([x for x in os.listdir(self._directory) if x.endswith(".pt")])

    def _compose_file_path_from_index(self, idx: int) -> str:
        return os.path.join(self._directory, f"{idx}.pt")
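
# A minimal usage sketch for FileTensorStorage (illustrative; the path is hypothetical,
# and the directory must exist before dumping):
#
#     os.makedirs("/tmp/file_storage", exist_ok=True)
#     ds = torch.utils.data.TensorDataset(torch.randn(10, 3), torch.arange(10))
#     storage = FileTensorStorage("/tmp/file_storage")
#     storage.dump_dataset(ds)  # writes /tmp/file_storage/0.pt ... 9.pt
#     x, y = storage[4]         # loads /tmp/file_storage/4.pt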