from collections import OrderedDict

import numpy as np
from torch.utils.data import Dataset
from datasets.load import load_dataset, load_metric
from transformers import AutoTokenizer, DataCollatorForTokenClassification


class SRLDataset(Dataset):
    """Semantic role labeling (SRL) dataset wrapper: loads the raw splits, aligns
    word-level tags with sub-word tokens, and exposes the data collator and metric
    needed for token classification."""

    def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
        super().__init__()
        raw_datasets = load_dataset(f'tasks/srl/datasets/{data_args.dataset_name}.py')
        self.tokenizer = tokenizer

        if training_args.do_train:
            column_names = raw_datasets["train"].column_names
            features = raw_datasets["train"].features
        else:
            column_names = raw_datasets["validation"].column_names
            features = raw_datasets["validation"].features

        self.label_column_name = "tags"
        self.label_list = features[self.label_column_name].feature.names
        self.label_to_id = {l: i for i, l in enumerate(self.label_list)}
        self.num_labels = len(self.label_list)
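
        # Build each requested split up front; tokenization and label alignment run
        # once per split via `datasets.map` and are cached afterwards.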
        if training_args.do_train:
            train_dataset = raw_datasets['train']
            if data_args.max_train_samples is not None:
                train_dataset = train_dataset.select(range(data_args.max_train_samples))
            self.train_dataset = train_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on train dataset",
            )

        if training_args.do_eval:
            eval_dataset = raw_datasets['validation']
            if data_args.max_eval_samples is not None:
                eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
            self.eval_dataset = eval_dataset.map(
                self.tokenize_and_align_labels,
                batched=True,
                load_from_cache_file=True,
                desc="Running tokenizer on validation dataset",
            )
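
        # CoNLL-2005 ships two test sets, in-domain WSJ and out-of-domain Brown,
        # so both are kept (keyed by name) to allow separate evaluation.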
        if training_args.do_predict:
            if data_args.dataset_name == "conll2005":
                self.predict_dataset = OrderedDict()
                self.predict_dataset['wsj'] = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )
                self.predict_dataset['brown'] = raw_datasets['test_brown'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on Brown test dataset",
                )
            else:
                self.predict_dataset = raw_datasets['test_wsj'].map(
                    self.tokenize_and_align_labels,
                    batched=True,
                    load_from_cache_file=True,
                    desc="Running tokenizer on WSJ test dataset",
                )
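
        # Pad dynamically per batch; padding to a multiple of 8 helps fp16 tensor-core kernels.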
        self.data_collator = DataCollatorForTokenClassification(
            self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None
        )
        self.metric = load_metric("seqeval")
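
    # Entity-level precision/recall/F1 (plus token accuracy) computed with seqeval,
    # after dropping positions labeled -100 (special and non-first sub-word tokens).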
    def compute_metrics(self, p):
        predictions, labels = p
        predictions = np.argmax(predictions, axis=2)

        # Remove ignored index (special tokens)
        true_predictions = [
            [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]

        results = self.metric.compute(predictions=true_predictions, references=true_labels)
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }
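
    # Each example is a sentence plus the index of the predicate to label. The predicate
    # token is appended after a "[SEP]" marker ("tokens ... [SEP] predicate") so the model
    # knows which verb's arguments to tag; the appended segment is mapped to -100 labels.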
    def tokenize_and_align_labels(self, examples):
        for i, tokens in enumerate(examples['tokens']):
            examples['tokens'][i] = tokens + ["[SEP]"] + [tokens[int(examples['index'][i])]]

        tokenized_inputs = self.tokenizer(
            examples['tokens'],
            padding=False,
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
        )

        labels = []
        for i, label in enumerate(examples['tags']):
            # Rebuild the word ids manually, mirroring the sequence the tokenizer produces:
            # [CLS] + sentence sub-tokens + "[SEP]" marker + predicate sub-tokens + [SEP].
            word_ids = [None]
            for j, word in enumerate(examples['tokens'][i][:-2]):
                token = self.tokenizer.encode(word, add_special_tokens=False)
                word_ids += [j] * len(token)
            word_ids += [None]
            verb = examples['tokens'][i][int(examples['index'][i])]
            word_ids += [None] * len(self.tokenizer.encode(verb, add_special_tokens=False))
            word_ids += [None]
            # word_ids = tokenized_inputs.word_ids(batch_index=i)

            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
                # ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # For the other tokens in a word, we set the label to -100 so only the
                # first sub-token of each word contributes to the loss.
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs
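

# Example usage (a minimal sketch; `data_args` / `training_args` stand in for this
# project's argument dataclasses, assumed to carry fields such as dataset_name,
# max_train_samples, max_eval_samples, do_train, do_eval, do_predict, and fp16):
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   srl = SRLDataset(tokenizer, data_args, training_args)
#   trainer = Trainer(                       # transformers.Trainer
#       model=model,                         # any token-classification model with srl.num_labels labels
#       args=training_args,
#       train_dataset=srl.train_dataset,
#       eval_dataset=srl.eval_dataset,
#       data_collator=srl.data_collator,
#       compute_metrics=srl.compute_metrics,
#   )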