import sys

from datasets import load_dataset
from transformers import TrainingArguments

from span_marker import SpanMarkerModel, Trainer


# Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
dataset = load_dataset("gwlms/germeval2014")
labels = dataset["train"].features["ner_tags"].feature.names

# Initialize a SpanMarker model using a pretrained BERT-style encoder
model_name = sys.argv[1]
model = SpanMarkerModel.from_pretrained(
    model_name,
    labels=labels,
    # SpanMarker hyperparameters:
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)

args = TrainingArguments(
    output_dir="/tmp",
    per_device_eval_batch_size=64,
)

# Initialize the trainer using our model, training args & dataset, and train
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)
trainer.train()

print("Evaluating on development set...")
dev_metrics = trainer.evaluate(dataset["validation"], metric_key_prefix="eval")
print(dev_metrics)

print("Evaluating on test set...")
test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
print(test_metrics)
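
# Usage sketch (an addition, not part of the original script): persist the trained
# model and run a quick sanity-check prediction with SpanMarker's `predict` API.
# The save path and the example sentence below are purely illustrative.
model.save_pretrained("/tmp/span-marker-germeval")

loaded_model = SpanMarkerModel.from_pretrained("/tmp/span-marker-germeval")
# `predict` returns a list of dicts with "span", "label", "score" and character offsets.
print(loaded_model.predict("Angela Merkel besuchte im Juli Paris."))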