renate.utils.hf_utils module

class renate.utils.hf_utils.DataCollatorWithPaddingForWildTime(tokenizer, padding=True, max_length=None, pad_to_multiple_of=None, return_tensors='pt')

Bases: DataCollatorWithPadding

A data collator that can handle batches of Wild-Time data, which use a non-standard format.

This class extends DataCollatorWithPadding from the Hugging Face Transformers library, which expects data in the standard Hugging Face format. The Wild-Time data format differs slightly: each item is a tuple of a BatchEncoding (a dict) and a class label, and items read from a buffer carry an additional metadata attribute. The original data collator handles neither case. This class only splits the input into a format the original collator can handle and then restores the original structure afterwards; see the code following the call to super().__call__.
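The following is a minimal, framework-free sketch of the split, pad, and repack pattern this class follows. It is not Renate's implementation: `pad_collate` and `wild_time_collate` are hypothetical names, `pad_collate` is a plain-Python stand-in for DataCollatorWithPadding (lists instead of tensors, right padding only), and the item shapes `(encoding, label)` and `((encoding, label), metadata)` are assumptions about how plain and buffer items are nested. Metadata is simply returned as a list here.

```python
from typing import Any, Dict, List, Sequence

Encoding = Dict[str, List[int]]


def pad_collate(features: Sequence[Encoding], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical stand-in for DataCollatorWithPadding: right-pads every key to the longest sequence."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {}
    for key in features[0]:
        # Pad token IDs with pad_id; pad other keys (e.g. attention_mask) with 0.
        fill = pad_id if key == "input_ids" else 0
        batch[key] = [list(f[key]) + [fill] * (max_len - len(f[key])) for f in features]
    return batch


def wild_time_collate(items: Sequence[Any]) -> Any:
    """Sketch of the unpack -> pad -> repack pattern for Wild-Time style batches.

    Assumed (hypothetical) item shapes:
      * plain dataset item: (encoding, label)
      * buffer item:        ((encoding, label), metadata)
    """
    # A buffer item's first element is itself an (encoding, label) tuple.
    from_buffer = isinstance(items[0][0], tuple)
    if from_buffer:
        pairs = [item[0] for item in items]
        metadata = [item[1] for item in items]
    else:
        pairs = list(items)

    # Separate encodings from labels so the padding collator only sees plain encodings.
    encodings = [enc for enc, _ in pairs]
    labels = [label for _, label in pairs]
    padded = pad_collate(encodings)

    # Restore the nesting the caller passed in.
    if from_buffer:
        return (padded, labels), metadata
    return padded, labels


if __name__ == "__main__":
    batch = [
        ({"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1]}, 0),
        ({"input_ids": [101, 102], "attention_mask": [1, 1]}, 1),
    ]
    inputs, labels = wild_time_collate(batch)
    print(inputs["input_ids"])  # [[101, 7, 102], [101, 102, 0]]
    print(labels)  # [0, 1]
```

In this design the padding logic stays entirely in the parent collator, and the subclass only reshapes data around it, so padding options such as `max_length` and `pad_to_multiple_of` would apply unchanged.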

tokenizer: PreTrainedTokenizerBase