# NerRoB-czech/extended_embeddings/extended_embeddings_data_collator.py
import torch
from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
"""
A data collator for token classification tasks with extended embeddings.
This data collator extends the functionality of the `DataCollatorForTokenClassification` class
by adding support for additional features such as `per`, `org`, and `loc`.
Part of the code copied from: transformers.data.data_collator.DataCollatorForTokenClassification
"""
    def torch_call(self, features):
        # Labels may be stored under either "label" or "labels".
        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0] else None
        # Collect the extra per-token features (person / organization / location).
        per = [feature["per"] for feature in features] if "per" in features[0] else None
        org = [feature["org"] for feature in features] if "org" in features[0] else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0] else None
        # Strip labels and extra features before padding: the tokenizer can
        # only pad its own model inputs (input_ids, attention_mask, ...).
        no_labels_features = [
            {k: v for k, v in feature.items() if k not in (label_name, "per", "org", "loc")}
            for feature in features
        ]
batch = pad_without_fast_tokenizer_warning(
self.tokenizer,
no_labels_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
        if labels is None:
            return batch

        # Pad labels and the extra features by hand to the padded sequence
        # length, respecting the tokenizer's padding side.
        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side
        def to_list(tensor_or_iterable):
            # Normalize tensors and other iterables to plain lists so padding
            # can be concatenated with the + operator.
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)
        if padding_side == "right":
            # Right padding: append pad values after each sequence.
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch["per"] = [to_list(p) + [0] * (sequence_length - len(p)) for p in per]
            batch["org"] = [to_list(o) + [0] * (sequence_length - len(o)) for o in org]
            batch["loc"] = [to_list(l) + [0] * (sequence_length - len(l)) for l in loc]
        else:
            # Left padding: prepend pad values before each sequence.
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch["per"] = [[0] * (sequence_length - len(p)) + to_list(p) for p in per]
            batch["org"] = [[0] * (sequence_length - len(o)) + to_list(o) for o in org]
            batch["loc"] = [[0] * (sequence_length - len(l)) + to_list(l) for l in loc]

        # Convert the padded lists to tensors so the model receives a uniform batch.
        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
return batch
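

if __name__ == "__main__":
    # Usage sketch (an illustrative addition, not from the original repository):
    # pad a tiny batch that carries the extra `per`/`org`/`loc` features. The
    # checkpoint name and token ids below are arbitrary placeholders.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

    features = [
        {
            "input_ids": [101, 7632, 102],
            "labels": [0, 1, 0],
            "per": [0, 1, 0],
            "org": [0, 0, 0],
            "loc": [0, 0, 0],
        },
        {
            "input_ids": [101, 7592, 2088, 102],
            "labels": [0, 1, 2, 0],
            "per": [0, 0, 0, 0],
            "org": [0, 1, 0, 0],
            "loc": [0, 0, 1, 0],
        },
    ]
    batch = collator(features)
    # The shorter example is right-padded: labels with label_pad_token_id
    # (-100 by default), per/org/loc with 0; every tensor has shape (2, 4).
    print({k: tuple(v.shape) for k, v in batch.items()})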