|
from pathlib import Path |
|
import random |
|
import shutil |
|
from datasets import load_dataset, concatenate_datasets, Features, Sequence, ClassLabel, Value, DatasetDict |
|
from transformers import TrainingArguments |
|
from span_marker import SpanMarkerModel, Trainer |
|
from span_marker.model_card import SpanMarkerModelCardData |
|
from huggingface_hub import upload_folder, upload_file |
|
|
|
|
|
""" |
|
FEATURES = Features({"tokens": Sequence(feature=Value(dtype='string')), "ner_tags": Sequence(feature=ClassLabel(names=['O', 'B-ORG', 'I-ORG']))}) |
|
|
|
|
|
def load_fewnerd(): |
|
def mapper(sample): |
|
sample["ner_tags"] = [int(tag == 5) for tag in sample["ner_tags"]] |
|
sample["ner_tags"] = [2 if tag == 1 and idx > 0 and sample["ner_tags"][idx - 1] == 1 else tag for idx, tag in enumerate(sample["ner_tags"])] |
|
return sample |
|
|
|
dataset = load_dataset("DFKI-SLT/few-nerd", "supervised") |
|
dataset = dataset.map(mapper, remove_columns=["id", "fine_ner_tags"]) |
|
dataset = dataset.cast(FEATURES) |
|
return dataset |
|
|
|
|
|
def load_conll(): |
|
label_mapping = {3: 1, 4: 2} |
|
def mapper(sample): |
|
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]] |
|
return sample |
|
|
|
dataset = load_dataset("conll2003") |
|
dataset = dataset.map(mapper, remove_columns=["id", "pos_tags", "chunk_tags"]) |
|
dataset = dataset.cast(FEATURES) |
|
return dataset |
|
|
|
|
|
def load_ontonotes(): |
|
label_mapping = {11: 1, 12: 2} |
|
def mapper(sample): |
|
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]] |
|
return sample |
|
|
|
dataset = load_dataset("tner/ontonotes5") |
|
dataset = dataset.rename_column("tags", "ner_tags") |
|
dataset = dataset.map(mapper) |
|
dataset = dataset.cast(FEATURES) |
|
return dataset |
|
|
|
|
|
def load_multinerd(): |
|
label_mapping = {5: 1, 6: 2} |
|
def mapper(sample): |
|
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]] |
|
return sample |
|
|
|
def lang_filter(sample): |
|
return sample["lang"] == "en" |
|
|
|
dataset = load_dataset("Babelscape/multinerd") |
|
dataset = dataset.filter(lang_filter) |
|
dataset = dataset.map(mapper, remove_columns="lang") |
|
dataset = dataset.cast(FEATURES) |
|
return dataset |
|
|
|
|
|
def preprocess_raw_dataset(raw_dataset): |
|
# Set the number of sentences without an org equal to the number of sentences with an org |
|
def has_org(sample): |
|
return bool(sum(sample["ner_tags"])) |
|
|
|
def has_no_org(sample): |
|
return not has_org(sample) |
|
|
|
dataset_org = raw_dataset.filter(has_org) |
|
dataset_no_org = raw_dataset.filter(has_no_org) |
|
dataset_no_org = dataset_no_org.select(random.sample(range(len(dataset_no_org)), k=len(dataset_org))) |
|
dataset = concatenate_datasets([dataset_org, dataset_no_org]) |
|
return dataset |
|
""" |
|
|
|
|
|
def main() -> None:
    """Fine-tune, evaluate and publish a SpanMarker ORG-only NER model.

    Loads the pre-built ``tomaarsen/ner-orgs`` dataset from the Hugging Face
    Hub and fine-tunes ``bert-base-cased`` with SpanMarker. It then evaluates
    on the test split and saves the final checkpoint (plus a copy of this
    script) under ``models/<model_id>/checkpoint-final``. Finally it pushes the
    model and run artifacts to the private Hub repository ``model_id``.
    """
    labels = ["O", "B-ORG", "I-ORG"]
    # Seed for the validation subsample so every run evaluates on the same sentences.
    seed = 42
    eval_subset_size = 3000

    # One-off step that built and pushed the "ner-orgs" dataset loaded below
    # (presumably under the tomaarsen namespace). It relies on the load_* and
    # preprocess_raw_dataset helpers near the top of this file. Kept for
    # reference only; it is not executed:
    #
    #     fewnerd_dataset = load_fewnerd()
    #     conll_dataset = load_conll()
    #     ontonotes_dataset = load_ontonotes()
    #     multinerd_dataset = load_multinerd()
    #
    #     raw_train_dataset = concatenate_datasets([fewnerd_dataset["train"], conll_dataset["train"], ontonotes_dataset["train"], multinerd_dataset["train"]])
    #     raw_eval_dataset = concatenate_datasets([fewnerd_dataset["validation"], conll_dataset["validation"], ontonotes_dataset["validation"], multinerd_dataset["validation"]])
    #     raw_test_dataset = concatenate_datasets([fewnerd_dataset["test"], conll_dataset["test"], ontonotes_dataset["test"], multinerd_dataset["test"]])
    #
    #     train_dataset = preprocess_raw_dataset(raw_train_dataset)
    #     eval_dataset = preprocess_raw_dataset(raw_eval_dataset)
    #     test_dataset = preprocess_raw_dataset(raw_test_dataset)
    #
    #     dataset_dict = DatasetDict({
    #         "train": train_dataset,
    #         "validation": eval_dataset,
    #         "test": test_dataset,
    #     })
    #     dataset_dict.push_to_hub("ner-orgs", private=True)

    dataset = load_dataset("tomaarsen/ner-orgs")
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    # Periodic evaluation uses a fixed random subset of the validation split to
    # keep it cheap. The subset is capped at the split size so that small
    # splits do not make random.sample raise.
    eval_dataset = dataset["validation"]
    eval_indices = random.Random(seed).sample(
        range(len(eval_dataset)), k=min(eval_subset_size, len(eval_dataset))
    )
    eval_dataset = eval_dataset.select(eval_indices)

    encoder_id = "bert-base-cased"
    model_id = "tomaarsen/span-marker-bert-base-orgs"
    model = SpanMarkerModel.from_pretrained(
        encoder_id,
        labels=labels,
        # SpanMarker length limits
        model_max_length=256,
        marker_max_length=128,
        entity_max_length=8,
        # Metadata used for the generated model card
        model_card_data=SpanMarkerModelCardData(
            model_id=model_id,
            encoder_id=encoder_id,
            dataset_name="FewNERD, CoNLL2003, OntoNotes v5, and MultiNERD",
            language=["en"],
        ),
    )

    output_dir = Path("models") / model_id
    args = TrainingArguments(
        output_dir=str(output_dir),  # TrainingArguments declares output_dir as str
        run_name=model_id,
        # Optimization
        learning_rate=5e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,
        # Logging, evaluation and checkpointing
        logging_first_step=True,
        logging_steps=100,
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; keep in sync with the installed version.
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=3000,
        save_total_limit=1,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()

    # Final evaluation on the full test split; metrics are written to output_dir.
    metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
    trainer.save_metrics("test", metrics)

    # Save the final checkpoint together with the script that produced it.
    final_dir = output_dir / "checkpoint-final"
    trainer.save_model(final_dir)
    shutil.copy2(__file__, final_dir / "train.py")

    # Publish to the Hub. Optional artifacts are skipped when absent, so that a
    # missing file does not abort the upload halfway, after the model is pushed.
    model.push_to_hub(model_id, private=True)
    upload_file(path_or_fileobj=__file__, path_in_repo="train.py", repo_id=model_id)

    runs_dir = output_dir / "runs"
    if runs_dir.is_dir():
        upload_folder(folder_path=runs_dir, path_in_repo="runs", repo_id=model_id)
    else:
        print(f"Skipping upload of missing folder: {runs_dir}")

    for artifact in ("all_results.json", "emissions.csv"):
        artifact_path = output_dir / artifact
        if artifact_path.is_file():
            upload_file(path_or_fileobj=artifact_path, path_in_repo=artifact, repo_id=model_id)
        else:
            print(f"Skipping upload of missing file: {artifact_path}")
|
|
|
|
|
# Script entry point: training starts only when this file is executed
# directly, never when it is merely imported.
if __name__ == "__main__":
    main()