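"""Semantic role labeling (SRL) dataset wrapper built on Hugging Face `datasets` and `transformers`.

`SRLDataset` loads a dataset script from `tasks/srl/datasets/`, appends the predicate token after a
"[SEP]" marker to each sentence, aligns word-level tags with sub-word tokens (masking special and
continuation tokens with -100), and exposes train/validation/test splits, a token-classification
data collator, and a `seqeval`-based `compute_metrics` helper.
"""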
from collections import OrderedDict

import numpy as np
from torch.utils.data import Dataset
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, DataCollatorForTokenClassification


class SRLDataset(Dataset):
    """Token-classification dataset for predicate-aware semantic role labeling."""

    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        super().__init__()
        raw_datasets = load_dataset(f'tasks/srl/datasets/{data_args.dataset_name}.py')
        self.tokenizer = tokenizer

        if training_args.do_train:
            column_names = raw_datasets["train"].column_names
            features = raw_datasets["train"].features
        else:
            column_names = raw_datasets["validation"].column_names
            features = raw_datasets["validation"].features

        # The tag column is a ClassLabel sequence, so the examples carry integer label ids.
        self.label_column_name = "tags"
        self.label_list = features[self.label_column_name].feature.names
        self.label_to_id = {label: i for i, label in enumerate(self.label_list)}
        self.num_labels = len(self.label_list)
        if training_args.do_train:
            train_dataset = raw_datasets['train']
            if data_args.max_train_samples is not None:
                train_dataset = train_dataset.select(range(data_args.max_train_samples))
            self.train_dataset = train_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )

        if training_args.do_eval:
            eval_dataset = raw_datasets['validation']
            if data_args.max_eval_samples is not None:
                eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
            self.eval_dataset = eval_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on validation dataset",
            )
        if training_args.do_predict:
            if data_args.dataset_name == "conll2005":
                # CoNLL-2005 ships two test sets (WSJ and Brown), so keep both for prediction.
                self.predict_dataset = OrderedDict()
                self.predict_dataset['wsj'] = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )
                self.predict_dataset['brown'] = raw_datasets['test_brown'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on Brown test dataset",
                )
            else:
                self.predict_dataset = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )

        self.data_collator = DataCollatorForTokenClassification(
            self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None
        )
        self.metric = load_metric("seqeval")

    def compute_metrics(self, p):
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)

        # Remove ignored index (special tokens)
        true_predictions = [
            [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        results = self.metric.compute(predictions=true_predictions, references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    def tokenize_and_align_labels(self, examples):
        # Mark the predicate by appending "[SEP]" plus a copy of the predicate token to each
        # sentence, e.g. ["He", "ate", "cake"] with index 1 becomes ["He", "ate", "cake", "[SEP]", "ate"].
        for i, tokens in enumerate(examples['tokens']):
            examples['tokens'][i] = tokens + ["[SEP]"] + [tokens[int(examples['index'][i])]]

        tokenized_inputs = self.tokenizer(
            examples['tokens'],
            padding=False,
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )
        labels = []
        for i, label in enumerate(examples['tags']):
            # Rebuild word ids by hand so the appended "[SEP]" marker and predicate copy are
            # ignored in the loss: None for [CLS], one entry per sub-word of the original
            # sentence, then None for the marker, the predicate sub-words, and the final special
            # token. This assumes the sequence is not shortened by truncation above.
            word_ids = [None]
            for j, word in enumerate(examples['tokens'][i][:-2]):
                token = self.tokenizer.encode(word, add_special_tokens=False)
                word_ids += [j] * len(token)
            word_ids += [None]
            verb = examples['tokens'][i][int(examples['index'][i])]
            word_ids += [None] * len(self.tokenizer.encode(verb, add_special_tokens=False))
            word_ids += [None]

            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are
                # automatically ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # The remaining sub-words of a word are also set to -100 and skipped by the loss.
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
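

# A minimal usage sketch, assuming the surrounding project wires `data_args` and `training_args`
# through an HfArgumentParser with the field names used above. The checkpoint name, output
# directory, and the SimpleNamespace stand-in for `data_args` are illustrative assumptions,
# not values taken from this file; "conll2005" must match a script under `tasks/srl/datasets/`.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments

    data_args = SimpleNamespace(dataset_name="conll2005", max_train_samples=None, max_eval_samples=None)
    training_args = TrainingArguments(output_dir="out", do_train=True, do_eval=True)
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    dataset = SRLDataset(tokenizer, data_args, training_args)
    model = AutoModelForTokenClassification.from_pretrained(
        "bert-base-uncased", num_labels=dataset.num_labels
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset.train_dataset,
        eval_dataset=dataset.eval_dataset,
        tokenizer=tokenizer,
        data_collator=dataset.data_collator,
        compute_metrics=dataset.compute_metrics,
    )
    trainer.train()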