renate.memory.buffer module
- class renate.memory.buffer.DataBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: Dataset
A memory buffer storing data points.

The buffer functions as a torch dataset, i.e., it implements `__len__` and `__getitem__`. PyTorch data loaders can be used to sample from or iterate over the buffer.

Data can be added to the buffer via `buffer.update(dataset, metadata)`. `dataset` is a PyTorch dataset expected to return an arbitrary nested `tuple`/`dict` structure containing `torch.Tensor`s of _fixed_ size and data type. `metadata` is a dictionary mapping strings to tensors for associated metadata. The logic that decides which data points remain in the buffer is implemented by the different subclasses.

Extracting an element from the buffer returns a nested tuple of the form `data_point, metadata = buffer[i]`, where `data_point` is the raw data point and `metadata` is a dictionary containing the associated metadata as well as a field `idx` holding the index of the data point in the buffer. Some buffering methods may add further metadata fields, e.g., instance weights in coreset methods.

In addition to passing metadata to the `update` method, one can also access and replace the metadata in the buffer via the `get_metadata` and `set_metadata` methods.

Note that, in order to apply transformations, the buffer assumes that the data points are tuples of the form `(x, y)`. We apply `transform` to `x` and `target_transform` to `y`. Ensure that the transforms accept the correct type, e.g., if `x` is a dictionary, `transform` needs to operate on a dictionary.
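The contract described above (an `update(dataset, metadata)` call, indexing that returns a `(data_point, metadata)` pair with an added `idx` field, and `transform`/`target_transform` applied to `x` and `y` respectively) can be illustrated with a minimal sketch. `ToyBuffer` and `DataPoint` are hypothetical names, not part of Renate. The sketch assumes plain Python lists in place of `torch.Tensor`s so that it runs without PyTorch, stores every point it receives (no eviction policy), and omits `get_metadata`/`set_metadata`.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

DataPoint = Tuple[Any, Any]  # hypothetical alias: an (x, y) pair, as the docstring assumes


class ToyBuffer:
    """Hypothetical, dependency-free stand-in for the DataBuffer contract.

    It keeps every point it receives; Renate's subclasses instead decide
    which points to retain. This is not Renate's implementation.
    """

    def __init__(
        self,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self._points: List[DataPoint] = []
        self._metadata: List[Dict[str, Any]] = []  # one metadata dict per stored point
        self._transform = transform
        self._target_transform = target_transform

    def __len__(self) -> int:
        return len(self._points)

    def update(
        self,
        dataset: Sequence[DataPoint],
        metadata: Optional[Dict[str, Sequence[Any]]] = None,
    ) -> None:
        metadata = metadata or {}
        for i in range(len(dataset)):
            self._points.append(dataset[i])
            # Slice out the i-th entry of every metadata field.
            self._metadata.append({key: values[i] for key, values in metadata.items()})

    def __getitem__(self, idx: int) -> Tuple[DataPoint, Dict[str, Any]]:
        x, y = self._points[idx]
        if self._transform is not None:
            x = self._transform(x)  # transform is applied to x only
        if self._target_transform is not None:
            y = self._target_transform(y)  # target_transform is applied to y only
        meta = dict(self._metadata[idx])
        meta["idx"] = idx  # position of the data point in the buffer
        return (x, y), meta


buffer = ToyBuffer(transform=lambda x: [v * 2 for v in x])
buffer.update([([1, 2], 0), ([3, 4], 1)], metadata={"weight": [0.5, 1.5]})
(x, y), meta = buffer[1]
print(x, y, meta)  # [6, 8] 1 {'weight': 1.5, 'idx': 1}
```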
- class renate.memory.buffer.ReservoirBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer
A buffer implementing reservoir sampling.
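Reservoir sampling keeps a uniform random sample of fixed size from a stream of unknown length: once more than `max_size` points have been offered, each of the `n` points seen so far is in the buffer with probability `max_size / n`. Below is a minimal sketch of the classic scheme (often called Algorithm R) on plain Python values. `ReservoirSketch` and its attributes are hypothetical; Renate's `ReservoirBuffer` may differ in details such as how metadata is stored alongside each point.

```python
import random
from typing import Any, Iterable, List


class ReservoirSketch:
    """Hypothetical reservoir sampler (Algorithm R); not Renate's code."""

    def __init__(self, max_size: int, seed: int = 0) -> None:
        self.max_size = max_size
        self.items: List[Any] = []
        self.num_seen = 0  # total number of points offered so far
        self._rng = random.Random(seed)

    def update(self, stream: Iterable[Any]) -> None:
        for item in stream:
            self.num_seen += 1
            if len(self.items) < self.max_size:
                self.items.append(item)  # fill phase: keep everything
                continue
            # Draw j uniformly from {0, ..., num_seen - 1}; the new item is kept
            # with probability max_size / num_seen and overwrites slot j.
            j = self._rng.randrange(self.num_seen)
            if j < self.max_size:
                self.items[j] = item


buffer = ReservoirSketch(max_size=5, seed=0)
buffer.update(range(100))
buffer.update(range(100, 200))  # a later call continues the same stream
print(len(buffer.items), buffer.num_seen)  # 5 200
```

Because the acceptance probability depends only on the running count `num_seen`, several `update` calls behave exactly like a single call on the concatenated stream.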
- class renate.memory.buffer.SlidingWindowBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer
A buffer implementing a sliding window.
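A sliding window retains only the most recent `max_size` points and evicts the oldest first. A minimal sketch using `collections.deque` with `maxlen`, assuming "sliding window" means first-in, first-out eviction; `SlidingWindowSketch` is a hypothetical name and this is not Renate's implementation. Unlike reservoir sampling, the policy involves no randomness, so the sketch takes no `seed`.

```python
from collections import deque
from typing import Any, Deque, Iterable


class SlidingWindowSketch:
    """Hypothetical sliding-window buffer: keeps only the newest points."""

    def __init__(self, max_size: int) -> None:
        # A deque with maxlen discards items from the left (the oldest) when full.
        self.items: Deque[Any] = deque(maxlen=max_size)

    def update(self, stream: Iterable[Any]) -> None:
        self.items.extend(stream)


buffer = SlidingWindowSketch(max_size=3)
buffer.update(range(10))
print(list(buffer.items))  # [7, 8, 9]
buffer.update([10])
print(list(buffer.items))  # [8, 9, 10]
```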
- class renate.memory.buffer.GreedyClassBalancingBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer
A buffer implementing a greedy class-balancing approach.
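One common greedy class-balancing rule works as follows: while the buffer has free space, every point is stored; once it is full, a point of class `y` is stored only if `y` currently holds fewer points than the largest class, in which case a random point of that largest class is evicted. The sketch below implements this rule on `(x, y)` pairs. `GreedyClassBalancingSketch` is hypothetical, and whether Renate's `GreedyClassBalancingBuffer` follows exactly this rule (for example, how it breaks ties between equally large classes) is an assumption.

```python
import random
from collections import defaultdict
from typing import Any, DefaultDict, Dict, Hashable, Iterable, List, Tuple


class GreedyClassBalancingSketch:
    """Hypothetical greedy class-balancing buffer over (x, y) pairs."""

    def __init__(self, max_size: int, seed: int = 0) -> None:
        assert max_size > 0, "max_size must be positive"
        self.max_size = max_size
        self.by_class: DefaultDict[Hashable, List[Any]] = defaultdict(list)
        self._rng = random.Random(seed)

    def __len__(self) -> int:
        return sum(len(points) for points in self.by_class.values())

    def counts(self) -> Dict[Hashable, int]:
        return {y: len(points) for y, points in self.by_class.items() if points}

    def update(self, stream: Iterable[Tuple[Any, Hashable]]) -> None:
        for x, y in stream:
            if len(self) < self.max_size:
                self.by_class[y].append(x)  # free space: always store
                continue
            # Buffer full: locate the class that currently holds the most points.
            largest = max(self.by_class, key=lambda c: len(self.by_class[c]))
            if len(self.by_class[y]) < len(self.by_class[largest]):
                # Evict a random point of the largest class to make room for y.
                victims = self.by_class[largest]
                victims.pop(self._rng.randrange(len(victims)))
                self.by_class[y].append(x)
            # Otherwise y is already among the largest classes and x is dropped.


stream = [(f"a{i}", 0) for i in range(4)] + [(f"b{i}", 1) for i in range(3)]
buffer = GreedyClassBalancingSketch(max_size=4)
buffer.update(stream)
print(buffer.counts())  # {0: 2, 1: 2}
```

Four class-0 points fill the buffer; the next two class-1 points each evict a class-0 point, and the third is dropped because both classes are then equally large.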
- class renate.memory.buffer.InfiniteBuffer(transform=None, target_transform=None)
Bases: DataBuffer
A data buffer that stores _all_ incoming data.