import torch
from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
    """
    A data collator for token classification tasks with extended embeddings.

    This data collator extends the functionality of the
    `DataCollatorForTokenClassification` class by adding support for the
    additional per-token features `per`, `org`, and `loc`, which are padded
    alongside the labels.

    Part of the code is copied from:
    transformers.data.data_collator.DataCollatorForTokenClassification
    """

    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
        org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None

        # Pad only the tokenizer-produced fields here; the labels and the
        # extra features are padded manually below.
        no_labels_features = [
            {k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]}
            for feature in features
        ]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        # Pad labels with `label_pad_token_id` and the extra features with 0,
        # on the same side the tokenizer pads on.
        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch["per"] = [to_list(p) + [0] * (sequence_length - len(p)) for p in per]
            batch["org"] = [to_list(o) + [0] * (sequence_length - len(o)) for o in org]
            batch["loc"] = [to_list(l) + [0] * (sequence_length - len(l)) for l in loc]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            # Note: `to_list` is a local helper, not a method, so it must be
            # called directly (the original called `self.to_list` here).
            batch["per"] = [[0] * (sequence_length - len(p)) + to_list(p) for p in per]
            batch["org"] = [[0] * (sequence_length - len(o)) + to_list(o) for o in org]
            batch["loc"] = [[0] * (sequence_length - len(l)) + to_list(l) for l in loc]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
        return batch
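

# A minimal usage sketch. It assumes each feature dict already carries
# token-aligned "per", "org", and "loc" lists (one value per input id);
# the model name and the all-zero dummy annotations below are illustrative
# assumptions, not part of the original code.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

    encoded = tokenizer(["Alice works at Acme", "Bob lives in Paris"])
    features = []
    for i in range(2):
        ids = encoded["input_ids"][i]
        features.append(
            {
                "input_ids": ids,
                "attention_mask": encoded["attention_mask"][i],
                # Dummy per-token annotations, one value per input id.
                "labels": [0] * len(ids),
                "per": [0] * len(ids),
                "org": [0] * len(ids),
                "loc": [0] * len(ids),
            }
        )

    batch = collator(features)
    # Every field is padded to a common sequence length; labels, per, org,
    # and loc come back as int64 tensors alongside the tokenizer fields.
    print({k: tuple(v.shape) for k, v in batch.items()})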