renate.memory package

class renate.memory.DataBuffer(max_size=None, seed=0, transform=None, target_transform=None)

Bases: Dataset

A memory buffer storing data points.

The buffer functions as a torch dataset, i.e., it implements __len__ and __getitem__. PyTorch data loaders can be used to sample from or iterate over the buffer.

Data can be added to the buffer via buffer.update(dataset, metadata). dataset is a PyTorch dataset expected to return an arbitrary nested tuple/dict structure containing torch.Tensor objects of fixed size and data type. metadata is a dictionary mapping strings to tensors that hold associated metadata. Which data points remain in the buffer is decided by the individual subclasses.

Indexing the buffer returns a nested tuple, as in data_point, metadata = buffer[i], where data_point is the raw data point and metadata is a dictionary containing the associated metadata as well as a field idx holding the index of the data point in the buffer. Some buffering methods may add further fields to metadata, e.g., instance weights in coreset methods.
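The interface described above (update, plus __len__/__getitem__ returning a (data_point, metadata) pair with an added idx field) can be illustrated with a minimal, torch-free sketch. ToyBuffer is a hypothetical name, not part of renate: unlike DataBuffer it keeps every point it receives, stores plain Python objects instead of tensors, and takes metadata as a dict of per-point lists.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple


class ToyBuffer:
    """Hypothetical stand-in for the DataBuffer interface (not renate's code)."""

    def __init__(self) -> None:
        self._points: List[Any] = []
        self._metadata: List[Dict[str, Any]] = []

    def update(
        self, dataset: Sequence[Any], metadata: Optional[Dict[str, Sequence[Any]]] = None
    ) -> None:
        metadata = metadata or {}
        for i in range(len(dataset)):
            self._points.append(dataset[i])
            # Metadata arrives column-wise (key -> one value per data point);
            # store it row-wise so __getitem__ can return a per-point dict.
            self._metadata.append({key: values[i] for key, values in metadata.items()})

    def __len__(self) -> int:
        return len(self._points)

    def __getitem__(self, idx: int) -> Tuple[Any, Dict[str, Any]]:
        # Return (data_point, metadata) and record the buffer index under "idx".
        meta = dict(self._metadata[idx])
        meta["idx"] = idx
        return self._points[idx], meta


buffer = ToyBuffer()
buffer.update([(1.0, 0), (2.0, 1)], metadata={"loss": [0.3, 0.7]})
data_point, metadata = buffer[1]
print(data_point, metadata)  # (2.0, 1) {'loss': 0.7, 'idx': 1}
```

Because the real buffer subclasses torch.utils.data.Dataset, the same indexing contract is what lets a DataLoader batch over it.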

In addition to passing metadata to the update method, one can also access and replace the metadata in the buffer via the get_metadata and set_metadata methods.
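A hypothetical sketch of column-wise metadata access in the spirit of get_metadata and set_metadata. MetadataStore, the list-based storage, and the length check are assumptions made for illustration; the real buffer stores tensors and its exact validation is not documented here.

```python
from typing import Any, Dict, List


class MetadataStore:
    """Hypothetical: one metadata column per key, aligned with buffer slots."""

    def __init__(self, num_points: int) -> None:
        self.num_points = num_points
        self._columns: Dict[str, List[Any]] = {}

    def get_metadata(self, key: str) -> List[Any]:
        # One value per data point currently in the buffer (returned as a copy).
        return list(self._columns[key])

    def set_metadata(self, key: str, values: List[Any]) -> None:
        # Replace (or create) a whole column; its length must match the buffer.
        if len(values) != self.num_points:
            raise ValueError(f"expected {self.num_points} values, got {len(values)}")
        self._columns[key] = list(values)


store = MetadataStore(num_points=3)
store.set_metadata("loss", [0.9, 0.4, 0.1])
# Example: derive a new column from an existing one and write it back.
store.set_metadata("weight", [1.0 if l > 0.2 else 0.5 for l in store.get_metadata("loss")])
print(store.get_metadata("weight"))  # [1.0, 1.0, 0.5]
```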

Note that, in order to apply transformations, the buffer assumes that the data points are tuples of the form (x, y): transform is applied to x and target_transform to y. Ensure that the transforms accept the correct type, e.g., if x is a dictionary, transform needs to operate on a dictionary.
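A sketch of the (x, y) convention, using a hypothetical helper apply_transforms that mirrors the rule above rather than renate's internal code. The usage shows a dictionary-valued x, where the transform itself must accept a dictionary.

```python
from typing import Any, Callable, Optional, Tuple


def apply_transforms(
    data_point: Tuple[Any, Any],
    transform: Optional[Callable[[Any], Any]] = None,
    target_transform: Optional[Callable[[Any], Any]] = None,
) -> Tuple[Any, Any]:
    """Apply transform to x and target_transform to y of an (x, y) pair."""
    x, y = data_point  # raises ValueError if the data point is not a 2-tuple
    if transform is not None:
        x = transform(x)
    if target_transform is not None:
        y = target_transform(y)
    return x, y


def scale_pixels(x: dict) -> dict:
    # x is a dictionary here, so the transform must operate on a dictionary.
    return {**x, "pixels": [p / 255 for p in x["pixels"]]}


x, y = apply_transforms(
    ({"pixels": [0, 255]}, 3), transform=scale_pixels, target_transform=lambda t: t + 1
)
print(x, y)  # {'pixels': [0.0, 1.0]} 4
```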

update(dataset, metadata=None)

Updates the buffer with a new dataset.

Return type:

None

get_metadata(key)
Return type:

Dict[str, Tensor]

set_metadata(key, values)
Return type:

None

state_dict()
Return type:

Dict

load_state_dict(state_dict)
Return type:

None

save(target_dir)
Return type:

None

load(source_dir)
Return type:

None

class renate.memory.GreedyClassBalancingBuffer(max_size=None, seed=0, transform=None, target_transform=None)

Bases: DataBuffer

A buffer implementing a greedy class-balancing approach.
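The docstring does not spell out the balancing rule. The sketch below assumes one plausible greedy rule: fill the buffer while it has room; once it is full, admit a new point only if its class has fewer stored points than the currently largest class, evicting a random point of that largest class. greedy_class_balancing is a hypothetical name, and renate's exact rule may differ.

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Tuple


def greedy_class_balancing(
    stream: List[Tuple[Any, int]], max_size: int, seed: int = 0
) -> Dict[int, List[Any]]:
    """Assumed greedy rule: evict from the largest class to admit rarer classes."""
    if max_size <= 0:
        return {}
    rng = random.Random(seed)
    buffer: Dict[int, List[Any]] = defaultdict(list)
    size = 0
    for x, y in stream:
        if size < max_size:
            # Buffer not full yet: keep everything.
            buffer[y].append(x)
            size += 1
            continue
        largest = max(buffer, key=lambda c: len(buffer[c]))
        # Admit only if class y is under-represented relative to the largest class.
        if len(buffer[y]) < len(buffer[largest]):
            victims = buffer[largest]
            victims.pop(rng.randrange(len(victims)))  # random eviction
            buffer[y].append(x)
    # Drop classes that ended up empty.
    return {c: xs for c, xs in buffer.items() if xs}


# Six points of class 0 followed by three of class 1, buffer size 4.
stream = [(i, 0) for i in range(6)] + [(i, 1) for i in range(3)]
result = greedy_class_balancing(stream, max_size=4)
print({c: len(xs) for c, xs in result.items()})  # {0: 2, 1: 2}
```

The total size never changes once the buffer is full, so balancing happens purely through evictions from over-represented classes.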

state_dict()
Return type:

Dict

load_state_dict(state_dict)
Return type:

None

class renate.memory.InfiniteBuffer(transform=None, target_transform=None)

Bases: DataBuffer

A data buffer that stores all incoming data.

class renate.memory.ReservoirBuffer(max_size=None, seed=0, transform=None, target_transform=None)

Bases: DataBuffer

A buffer implementing reservoir sampling.
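Reservoir sampling keeps a uniform random sample of fixed size from a stream of unknown length. A minimal sketch of the classic single-pass version (Algorithm R) follows; reservoir_sample is a hypothetical name, not renate's code. A buffer that receives several update calls would also need to carry the seen-count n between calls, which this function does not do.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], max_size: int, seed: int = 0) -> List[T]:
    """Algorithm R: every item seen so far is kept with probability max_size / n_seen."""
    rng = random.Random(seed)
    reservoir: List[T] = []
    for n, item in enumerate(stream):  # n = number of items seen before this one
        if n < max_size:
            reservoir.append(item)
        else:
            # j is uniform over 0..n (inclusive), so the new item is kept
            # with probability max_size / (n + 1), replacing a uniform slot.
            j = rng.randint(0, n)
            if j < max_size:
                reservoir[j] = item
    return reservoir


sample = reservoir_sample(range(100), max_size=10, seed=0)
print(len(sample))  # 10
```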

class renate.memory.SlidingWindowBuffer(max_size=None, seed=0, transform=None, target_transform=None)

Bases: DataBuffer

A buffer implementing a sliding window.
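A sliding window keeps only the max_size most recently added points. A minimal sketch using collections.deque with maxlen, which discards the oldest element when a new one is appended to a full deque; SlidingWindow is a hypothetical class, not renate's implementation.

```python
from collections import deque
from typing import Deque, Generic, Iterable, List, TypeVar

T = TypeVar("T")


class SlidingWindow(Generic[T]):
    """Hypothetical sketch: retain only the max_size newest items."""

    def __init__(self, max_size: int) -> None:
        # A deque with maxlen drops items from the left (oldest) on overflow.
        self._window: Deque[T] = deque(maxlen=max_size)

    def update(self, items: Iterable[T]) -> None:
        self._window.extend(items)

    def contents(self) -> List[T]:
        return list(self._window)


window: SlidingWindow[int] = SlidingWindow(max_size=3)
window.update([1, 2])
window.update([3, 4, 5])
print(window.contents())  # [3, 4, 5]
```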

Submodules