"""trainer.py -- fine-tune a BERT token-classification (NER) model with the
Hugging Face Trainer on the English splits of Babelscape/wikineural."""
import numpy as np
from datasets import load_dataset, load_metric
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)
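# BIO tagging scheme: B-X opens an entity of type X, I-X continues it, and O marks
# tokens outside any entity, e.g. "John lives in New York" -> B-PER O O B-LOC I-LOC.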
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
labels_vocab = {label: i for i, label in enumerate(label_list)}
labels_vocab_reverse = {v: k for k, v in labels_vocab.items()}
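# seqeval scores at the entity level: a predicted entity counts as correct only when
# both its span and its type match. Note: datasets.load_metric is deprecated in
# recent datasets releases; evaluate.load("seqeval") is the drop-in replacement.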
metric = load_metric("seqeval")
def load_datasets(tokenizer):
    """Load the English wikineural splits and align NER labels with subword tokens."""

    def tokenize_and_align_labels(examples):
        label_all_tokens = False
        tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
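        # Example of the alignment the loop below performs (the subword split is
        # illustrative; actual pieces depend on the tokenizer):
        #   words:    ["John", "Smithson"]          ner_tags: [1, 2]  (B-PER, I-PER)
        #   subwords: [CLS] John Smith ##son [SEP]
        #   word_ids: [None, 0, 1, 1, None]
        #   labels:   [-100, 1, 2, -100, -100]      (with label_all_tokens=False)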
        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                # Special tokens have a word id that is None. We set the label to -100
                # so they are automatically ignored in the loss function.
                if word_idx is None:
                    label_ids.append(-100)
                # We set the label for the first token of each word.
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                # For the other tokens in a word, we set the label to either the current
                # label or -100, depending on the label_all_tokens flag.
                else:
                    label_ids.append(label[word_idx] if label_all_tokens else -100)
                previous_word_idx = word_idx
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs
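    # Babelscape/wikineural exposes one split per language (train_en, val_en, test_en,
    # train_de, ...) rather than the usual train/validation/test.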
    datasets = load_dataset("Babelscape/wikineural")
    train_en_dataset = datasets['train_en']
    val_en_dataset = datasets['val_en']
    test_en_dataset = datasets['test_en']
    train_tokenized = train_en_dataset.map(tokenize_and_align_labels, batched=True)
    val_tokenized = val_en_dataset.map(tokenize_and_align_labels, batched=True)
    test_tokenized = test_en_dataset.map(tokenize_and_align_labels, batched=True)
    return train_tokenized, val_tokenized, test_tokenized
def compute_metrics(p):
    """Convert logits to label strings, drop ignored (-100) positions, run seqeval."""
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
def main():
    MODEL_NAME = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    train_tokenized, val_tokenized, test_tokenized = load_datasets(tokenizer)
    model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, num_labels=len(label_list),
                                                            label2id=labels_vocab, id2label=labels_vocab_reverse)
    data_collator = DataCollatorForTokenClassification(tokenizer)
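    # The collator dynamically pads input_ids/attention_mask per batch and pads the
    # "labels" column with -100 (its label_pad_token_id default), so padding positions
    # never contribute to the loss.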
    args = TrainingArguments(
        "wikineural-multilingual-ner",
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=1,
        do_train=True,
        do_eval=True,
        weight_decay=0.01,
        eval_steps=10000,
        save_steps=10000,
    )
    trainer = Trainer(
        model,
        args,
        train_dataset=train_tokenized,
        # Evaluate on the validation split during training; the test split is held
        # out for the final predict() call below.
        eval_dataset=val_tokenized,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    trainer.evaluate()
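    # trainer.predict returns (predictions, label_ids, metrics); predictions are raw
    # logits of shape (n_examples, seq_len, n_labels), hence the argmax over the last axis.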
    predictions, labels, _ = trainer.predict(test_tokenized)
    predictions = np.argmax(predictions, axis=2)
    # Remove ignored index (special tokens)
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    results = metric.compute(predictions=true_predictions, references=true_labels)
    print(results)
    return 0
if __name__ == "__main__":
    main()