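"""Fine-tune bert-base-cased for named-entity recognition on the English
splits of Babelscape/wikineural, evaluating with seqeval."""
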
import numpy as np
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

# BIO tag set used by the WikiNEural NER annotations.
label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
labels_vocab = {label: i for i, label in enumerate(label_list)}
labels_vocab_reverse = {v: k for k, v in labels_vocab.items()}

# seqeval scores predictions at the entity level over BIO-tagged spans.
# Note that load_metric is deprecated in recent datasets releases; the
# equivalent call there is evaluate.load("seqeval").
metric = load_metric("seqeval")

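# WordPiece tokenization splits words into subword pieces, so the word-level
# NER tags have to be realigned to token positions before training; the inner
# tokenize_and_align_labels helper below handles that.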
def load_datasets(tokenizer):
    def tokenize_and_align_labels(examples):
        label_all_tokens = False
        tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label
                # to -100 so they are automatically ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # For the other tokens in a word, we set the label to either the
                # current label or -100, depending on the label_all_tokens flag.
                else:
                    label_ids.append(label[word_idx] if label_all_tokens else -100)
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    datasets = load_dataset("Babelscape/wikineural")
    train_en_dataset = datasets["train_en"]
    val_en_dataset = datasets["val_en"]
    test_en_dataset = datasets["test_en"]

    train_tokenized = train_en_dataset.map(tokenize_and_align_labels, batched=True)
    val_tokenized = val_en_dataset.map(tokenize_and_align_labels, batched=True)
    test_tokenized = test_en_dataset.map(tokenize_and_align_labels, batched=True)
    return train_tokenized, val_tokenized, test_tokenized

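# Convert logits to BIO tag sequences, drop the -100 positions (special
# tokens and non-first subwords), and score entity-level metrics with seqeval.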
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

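# End-to-end run: tokenize the English WikiNEural splits, fine-tune the
# model, and report entity-level scores on the held-out test split.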
def main():
    MODEL_NAME = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    train_tokenized, val_tokenized, test_tokenized = load_datasets(tokenizer)

    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(label_list),
        label2id=labels_vocab,
        id2label=labels_vocab_reverse,
    )
    data_collator = DataCollatorForTokenClassification(tokenizer)

    args = TrainingArguments(
        "wikineural-multilingual-ner",  # output directory
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=1,
        do_train=True,
        do_eval=True,
        weight_decay=0.01,
        eval_steps=10000,
        save_steps=10000,
    )

    trainer = Trainer(
        model,
        args,
        train_dataset=train_tokenized,
        # Evaluate on the validation split during training; the test split is
        # held out for the final report below.
        eval_dataset=val_tokenized,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.evaluate()

    # Final evaluation on the held-out test split.
    predictions, labels, _ = trainer.predict(test_tokenized)
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    print(results)
    return 0


if __name__ == "__main__":
    main()
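
# A minimal inference sketch (not part of the training run above), assuming
# the fine-tuned model was saved, e.g. with trainer.save_model(), to the
# "wikineural-multilingual-ner" output directory:
#
#     from transformers import pipeline
#
#     ner = pipeline(
#         "token-classification",
#         model="wikineural-multilingual-ner",
#         aggregation_strategy="simple",  # merge subword pieces into entity spans
#     )
#     print(ner("George Washington lived in Virginia."))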