import torch

from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning

class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
    """
    A data collator for token classification tasks with extended embeddings.

    This collator extends `DataCollatorForTokenClassification` with support for the
    additional per-token features `per`, `org`, and `loc`, padding them alongside
    the labels.

    Parts of this code are adapted from
    `transformers.data.data_collator.DataCollatorForTokenClassification`.
    """

    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        # Collect the extra per-token features, if present.
        per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
        org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None

        # Let the tokenizer pad only its native inputs; labels and the extra
        # features are padded manually below.
        no_labels_features = [
            {k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]}
            for feature in features
        ]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            # The extra features are padded with 0 rather than label_pad_token_id.
            batch["per"] = [to_list(p) + [0] * (sequence_length - len(p)) for p in per]
            batch["org"] = [to_list(o) + [0] * (sequence_length - len(o)) for o in org]
            batch["loc"] = [to_list(l) + [0] * (sequence_length - len(l)) for l in loc]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch["per"] = [[0] * (sequence_length - len(p)) + to_list(p) for p in per]
            batch["org"] = [[0] * (sequence_length - len(o)) + to_list(o) for o in org]
            batch["loc"] = [[0] * (sequence_length - len(l)) + to_list(l) for l in loc]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
        return batch
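

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: it illustrates how the
    # collator pads the extra "per"/"org"/"loc" features alongside the labels.
    # The token IDs, tag values, and checkpoint name below are placeholder
    # assumptions; any tokenizer with a pad token works the same way.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

    features = [
        {
            "input_ids": [101, 1287, 102],  # toy IDs; shorter example
            "labels": [-100, 1, -100],
            "per": [0, 1, 0],
            "org": [0, 0, 0],
            "loc": [0, 0, 0],
        },
        {
            "input_ids": [101, 1287, 1759, 3000, 102],  # toy IDs; longer example
            "labels": [-100, 1, 0, 2, -100],
            "per": [0, 1, 0, 0, 0],
            "org": [0, 0, 0, 1, 0],
            "loc": [0, 0, 0, 0, 0],
        },
    ]

    batch = collator(features)
    # All tensors share the padded batch shape (2, 5): the extra features are
    # 0-padded, while labels are padded with label_pad_token_id (-100 by default).
    print(batch["input_ids"].shape, batch["per"].shape, batch["labels"].shape)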