renate.memory.buffer module
- class renate.memory.buffer.DataBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: Dataset

A memory buffer storing data points.

The buffer functions as a torch dataset, i.e., it implements `__len__` and `__getitem__`. PyTorch data loaders can be used to sample from or iterate over the buffer.

Data can be added to the buffer via `buffer.update(dataset, metadata)`. `dataset` is a PyTorch dataset expected to return an arbitrarily nested `tuple`/`dict` structure containing `torch.Tensor`s of _fixed_ size and data type. `metadata` is a dictionary mapping strings to tensors for associated metadata. The logic to decide which data points remain in the buffer is implemented by different subclasses.

Extracting an element from the buffer will return a nested tuple of the form `data_point, metadata = buffer[i]`, where `data_point` is the raw data point and `metadata` is a dictionary containing the associated metadata as well as a field `idx` containing the index of the data point in the buffer. Additional metadata fields might be added by some buffering methods, e.g., instance weights in coreset methods.

In addition to passing metadata to the `update` method, one can also access and replace the metadata in the buffer via the `get_metadata` and `set_metadata` methods.

Note that, in order to apply transformations, the buffer assumes that the data points are tuples of the form `(x, y)`. `transform` is applied to `x` and `target_transform` to `y`. Ensure that the transforms accept the correct type, e.g., if `x` is a dictionary, `transform` needs to operate on a dictionary.
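
The contract described above can be illustrated with a small pure-Python stand-in. This is a sketch under stated assumptions, not Renate's implementation: `ToyBuffer` is a hypothetical name, it keeps every data point, it uses plain lists where the real buffer stores tensors, and it assumes every call to `update` supplies the same metadata keys.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple


class ToyBuffer:
    """Hypothetical stand-in mimicking the documented contract (not Renate code)."""

    def __init__(
        self,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self._data: List[Tuple[Any, Any]] = []
        self._metadata: Dict[str, List[Any]] = {}
        self._transform = transform
        self._target_transform = target_transform

    def update(
        self,
        dataset: Sequence[Tuple[Any, Any]],
        metadata: Dict[str, Sequence[Any]],
    ) -> None:
        # Store every (x, y) pair; a real buffer would decide here what to keep.
        for i in range(len(dataset)):
            self._data.append(dataset[i])
        # Assumption: each metadata value has one entry per data point.
        for key, values in metadata.items():
            self._metadata.setdefault(key, []).extend(values)

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, i: int) -> Tuple[Tuple[Any, Any], Dict[str, Any]]:
        x, y = self._data[i]
        # Transforms are applied on access: transform to x, target_transform to y.
        if self._transform is not None:
            x = self._transform(x)
        if self._target_transform is not None:
            y = self._target_transform(y)
        metadata = {key: values[i] for key, values in self._metadata.items()}
        metadata["idx"] = i  # index of the data point in the buffer
        return (x, y), metadata


buffer = ToyBuffer(transform=lambda x: x * 2)
buffer.update([(1.0, 0), (2.0, 1)], {"loss": [0.3, 0.7]})
data_point, metadata = buffer[1]
```

Here `data_point` holds the transformed `(x, y)` pair and `metadata` holds the `loss` entry for that point together with its `idx`.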
- class renate.memory.buffer.ReservoirBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer

A buffer implementing reservoir sampling.
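
For readers unfamiliar with the technique, reservoir sampling keeps a uniform random sample of fixed size from a stream of unknown length. The sketch below shows the classic single-pass variant (often called Algorithm R) on a plain Python iterable. The function `reservoir_sample` is a hypothetical helper written for illustration; it is not taken from Renate and may differ from how `ReservoirBuffer` is implemented.

```python
import random
from typing import Any, Iterable, List


def reservoir_sample(stream: Iterable[Any], max_size: int, seed: int = 0) -> List[Any]:
    """Return a uniform random sample of at most max_size items from stream."""
    rng = random.Random(seed)
    reservoir: List[Any] = []
    for num_seen, item in enumerate(stream, start=1):
        if len(reservoir) < max_size:
            # The reservoir is not full yet, so every item is kept.
            reservoir.append(item)
        else:
            # Draw a slot in [0, num_seen); the item is kept only if the slot
            # falls inside the reservoir, i.e. with probability max_size / num_seen.
            slot = rng.randrange(num_seen)
            if slot < max_size:
                reservoir[slot] = item
    return reservoir


sample = reservoir_sample(range(1000), max_size=10, seed=0)
```

Every item of the stream ends up in the final sample with equal probability, regardless of its position in the stream.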
- class renate.memory.buffer.SlidingWindowBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer

A buffer implementing a sliding window.
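
Assuming "sliding window" means that only the most recent `max_size` data points are retained, the idea can be sketched with a bounded deque from the standard library. The helper `sliding_window` is hypothetical and only illustrates the retention policy; it is not Renate's code.

```python
from collections import deque
from typing import Any, Iterable, List


def sliding_window(stream: Iterable[Any], max_size: int) -> List[Any]:
    """Keep only the max_size most recently seen items of stream."""
    # A deque with maxlen silently drops the oldest item when a new one is appended.
    window: deque = deque(maxlen=max_size)
    for item in stream:
        window.append(item)
    return list(window)


recent = sliding_window(range(10), max_size=3)
```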
- class renate.memory.buffer.GreedyClassBalancingBuffer(max_size=None, seed=0, transform=None, target_transform=None)
Bases: DataBuffer

A buffer implementing a greedy class-balancing approach.
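
The documentation does not spell out the balancing rule, so the following is only one plausible greedy scheme, given as an assumption: once the buffer is full, an incoming point is accepted only if its class is under-represented, in which case a random point of the currently largest class is evicted. The helper `greedy_class_balanced` is hypothetical and is not Renate's implementation.

```python
import random
from collections import defaultdict
from typing import Any, Dict, Hashable, Iterable, List, Tuple


def greedy_class_balanced(
    stream: Iterable[Tuple[Any, Hashable]], max_size: int, seed: int = 0
) -> Dict[Hashable, List[Any]]:
    """Greedily keep at most max_size points with roughly equal class counts."""
    if max_size <= 0:
        return {}
    rng = random.Random(seed)
    per_class: Dict[Hashable, List[Any]] = defaultdict(list)
    size = 0
    for x, y in stream:
        if size < max_size:
            # The buffer still has room, so every point is stored.
            per_class[y].append(x)
            size += 1
            continue
        # Find the class that currently occupies the most slots.
        largest = max(per_class, key=lambda c: len(per_class[c]))
        # Accept the new point only if its class is smaller than the largest one.
        if len(per_class[y]) < len(per_class[largest]):
            victims = per_class[largest]
            victims.pop(rng.randrange(len(victims)))  # evict a random point
            per_class[y].append(x)
    # Drop classes that were emptied by evictions.
    return {c: xs for c, xs in per_class.items() if xs}


stream = [(i, 0) for i in range(8)] + [(i, 1) for i in range(8, 12)]
balanced = greedy_class_balanced(stream, max_size=4)
```

In this example the buffer first fills up with class `0`, then gives up slots to class `1` until both classes hold the same number of points.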
- class renate.memory.buffer.InfiniteBuffer(transform=None, target_transform=None)
Bases: DataBuffer

A data buffer that stores _all_ incoming data.