# address-extraction / train.py
# Last commit 696ac96 by duoquote: "Refactor labels and update model configuration"
# (file-viewer residue: raw | history | blame | 5.68 kB)
import io
import requests
import json
import time
import torch
import orjson
import zipfile
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments, BertConfig
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Base URL of the REST API that load_data() talks to, and the project to export.
API_URL = "http://dockerbase.duo:8000"
PROJECT_ID = 1
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_data(project_id=PROJECT_ID, username="admin", password="123", *,
              poll_interval=1.0, max_wait=None, request_timeout=60):
    """Export approved annotations and span types for one project.

    Logs in, starts an asynchronous JSONL export of approved examples,
    polls the export task until it reports ready, downloads the resulting
    zip archive, reads the JSONL member from it, and finally fetches the
    project's span types.

    Args:
        project_id: Project to export. Defaults to ``PROJECT_ID``. The
            export request previously hard-coded project 1.
        username: Login name sent to ``/v1/auth/login/``.
        password: Login password. NOTE(review): hard-coded default
            credentials; prefer passing real ones from config/env.
        poll_interval: Seconds to sleep between task-status polls.
        max_wait: Optional cap, in seconds, on how long to wait for the
            export task. ``None`` waits forever (the original behavior).
        request_timeout: Timeout in seconds passed to every ``requests`` call.

    Returns:
        ``(labels, records)``. ``labels`` is the decoded span-types response.
        ``records`` is a list of decoded JSONL objects, with blank lines skipped.

    Raises:
        requests.HTTPError: if any API call returns an error status.
        TimeoutError: if ``max_wait`` elapses before the export is ready.
    """
    project_url = f"{API_URL}/v1/projects/{project_id}"

    res = requests.post(
        API_URL + "/v1/auth/login/",
        json={"username": username, "password": password},
        timeout=request_timeout,
    )
    res.raise_for_status()
    headers = {"Authorization": "Token " + res.json()["key"]}

    # Kick off the export; the server replies with a task id to poll.
    res = requests.post(
        project_url + "/download",
        json={"format": "JSONL", "exportApproved": True},
        headers=headers,
        timeout=request_timeout,
    )
    res.raise_for_status()
    task_id = res.json()["task_id"]

    deadline = None if max_wait is None else time.monotonic() + max_wait
    print("Waiting for export task to be ready.", end="")
    while True:
        res = requests.get(
            f"{API_URL}/v1/tasks/status/{task_id}",
            headers=headers,
            timeout=request_timeout,
        )
        res.raise_for_status()
        if res.json()["ready"]:
            break
        if deadline is not None and time.monotonic() >= deadline:
            print("")
            raise TimeoutError(f"Export task {task_id} not ready after {max_wait} s")
        time.sleep(poll_interval)
        # Flush so the progress dots appear even though no newline is printed.
        print(".", end="", flush=True)
    print("")

    res = requests.get(
        project_url + "/download",
        params={"taskId": task_id},
        headers=headers,
        timeout=request_timeout,
    )
    res.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(res.content), "r") as archive:
        # The original read "admin.jsonl" while logged in as admin. The member
        # is presumably named after the exporting user; confirm for other accounts.
        raw = archive.read(f"{username}.jsonl").decode("utf-8")

    res = requests.get(
        project_url + "/span-types",
        headers=headers,
        timeout=request_timeout,
    )
    res.raise_for_status()
    labels = res.json()

    # Split on "\n" only. str.splitlines() would also break on U+2028 or U+0085,
    # which may legitimately appear unescaped inside JSON string values.
    # Whitespace-only lines (e.g. a stray "\r") are skipped rather than parsed.
    records = [orjson.loads(line) for line in raw.split("\n") if line.strip()]
    return labels, records
# Fetch the span types and the annotated records once, at script start.
labels, data = load_data()

# Tag vocabulary:
#   - span types take ids 1..n, in the order the server returned them;
#   - "[PAD]" is 0 and marks special/padding tokens;
#   - "[UNK]" is n + 1 and marks tokens outside every span.
# A B-/I- prefixed (BIO) variant of this mapping is not used.
label_to_id = {}
for tag_id, span_type in enumerate(labels, start=1):
    label_to_id[span_type["text"]] = tag_id
label_to_id["[PAD]"] = 0
label_to_id["[UNK]"] = len(label_to_id)
id_to_label = dict(zip(label_to_id.values(), label_to_id.keys()))

_checkpoint = "dbmdz/bert-base-turkish-cased"
tokenizer = AutoTokenizer.from_pretrained(_checkpoint)
model = AutoModelForTokenClassification.from_pretrained(
    _checkpoint,
    num_labels=len(label_to_id),
).to(device)
# Record the id <-> label mappings on the model config.
model.config.id2label = id_to_label
model.config.label2id = label_to_id
from datasets import DatasetDict, Dataset
def preprocess_data(item, tokenizer, label_to_id, max_length=128):
    """Tokenize one annotated record and align its span labels to tokens.

    Args:
        item: A record with ``"text"`` and optionally ``"label"``, a list of
            ``[start, end, label_text]`` character spans. A missing
            ``"label"`` is treated as no spans.
        tokenizer: A fast tokenizer, since offset mappings are requested.
            Called with ``return_tensors="pt"``.
        label_to_id: Maps every span label text, plus ``"[PAD]"`` and
            ``"[UNK]"``, to an id.
        max_length: Length to truncate/pad to. Defaults to 128; the label
            buffer previously hard-coded this value.

    Returns:
        A dict with 1-D ``input_ids`` and ``attention_mask`` tensors, and a
        ``labels`` list with one id per token position:
        - ``"[PAD]"`` for tokens with empty offsets (special tokens and padding);
        - the label of the first span that fully contains the token;
        - ``"[UNK]"`` otherwise, including tokens that straddle a span boundary.

    Raises:
        KeyError: if a span label text is missing from ``label_to_id``.
    """
    inputs = tokenizer(
        item["text"],
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    spans = item.get("label", [])

    token_labels = []
    # .tolist() turns the offsets into plain ints, avoiding per-element tensor comparisons.
    for off_start, off_end in inputs["offset_mapping"][0].tolist():
        if off_start == off_end:
            token_labels.append("[PAD]")
            continue
        token_labels.append(next(
            (label for start, end, label in spans
             if start <= off_start and off_end <= end),
            "[UNK]",
        ))

    return {
        "input_ids": inputs["input_ids"].flatten(),
        "attention_mask": inputs["attention_mask"].flatten(),
        "labels": [label_to_id[label] for label in token_labels],
    }
class AddressDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset that converts each row's fields to tensors.

    The bare name ``Dataset`` is rebound by the ``from datasets import
    DatasetDict, Dataset`` line earlier in the file. Subclassing it would
    silently inherit from Hugging Face's ``datasets.Dataset``, so the base
    is spelled out via ``torch.utils.data``.
    """

    def __init__(self, dataset):
        # Presumably any sized, indexable collection of dict rows
        # (e.g. a Hugging Face split). Confirm at the call site.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        # as_tensor builds tensors from lists just like torch.tensor. It reuses
        # values that are already tensors instead of copying them (and warning).
        return {key: torch.as_tensor(val) for key, val in item.items()}
# Tokenize every exported record into a Hugging Face dataset. ``Dataset``
# resolves to ``datasets.Dataset`` here because of the ``from datasets import``
# line earlier in the file.
dataset = Dataset.from_generator(
    lambda: (preprocess_data(item, tokenizer, label_to_id) for item in data),
)
# Hold out 20% for evaluation. train_test_split already returns a DatasetDict
# with "train" and "test" splits, so no re-wrapping is needed. The fixed seed
# makes the split, and hence the reported eval metrics, reproducible.
SPLIT_SEED = 42
dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)
# Trainer hyper-parameters. Training and evaluation share one batch size, and
# loss is logged once per epoch. Evaluation, checkpointing and best-model
# strategies are not configured here. Optional extras that could be enabled:
# logging_dir="./logs", logging_first_step=True, evaluation_strategy="epoch",
# save_strategy="epoch", load_best_model_at_end=True.
_batch_size = 32
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=35,
    per_device_train_batch_size=_batch_size,
    per_device_eval_batch_size=_batch_size,
    logging_strategy="epoch",
)
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
def compute_metrics(pred, id_to_label):
    """Compute token-level accuracy and micro precision/recall/F1.

    Positions are excluded from scoring when their gold label is
    ``"[PAD]"`` (special and padding tokens) or negative (e.g. an ignore
    index of -100).

    ``accuracy`` is the share of scored tokens whose predicted id equals
    the gold id. Precision, recall and F1 are micro-averaged over real span
    types only: ``"[UNK]"`` (tokens outside every span) is the negative
    class. This keeps the ever-present ``"[PAD]"``/``"[UNK]"`` tags from
    inflating the scores, which the earlier per-sequence set comparison
    allowed; that comparison also ignored token positions entirely. A ratio
    with an empty denominator is reported as 0.0.

    Args:
        pred: Object with ``label_ids`` (gold ids) and ``predictions``
            (scores whose last axis indexes labels, reduced with argmax),
            as handed over by ``Trainer``. Presumably numpy arrays of shape
            ``(batch, seq)`` and ``(batch, seq, num_labels)``; confirm.
        id_to_label: Maps label ids to names. Must contain ``"[PAD]"``;
            ``"[UNK]"`` is optional.

    Returns:
        Dict with float ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    name_to_id = {name: idx for idx, name in id_to_label.items()}
    pad_id = name_to_id["[PAD]"]
    unk_id = name_to_id.get("[UNK]")

    gold = pred.label_ids
    guess = pred.predictions.argmax(-1)

    def _is_span_type(ids):
        # True where an id names a real span type (neither "[PAD]" nor "[UNK]").
        mask = ids != pad_id
        if unk_id is not None:
            mask = mask & (ids != unk_id)
        return mask

    def _ratio(num, den):
        return num / den if den else 0.0

    scored = (gold >= 0) & (gold != pad_id)
    hit = scored & (guess == gold)
    gold_spans = scored & _is_span_type(gold)
    guess_spans = scored & _is_span_type(guess)
    true_pos = int((hit & gold_spans).sum())

    precision = _ratio(true_pos, int(guess_spans.sum()))
    recall = _ratio(true_pos, int(gold_spans.sum()))
    return {
        "accuracy": _ratio(int(hit.sum()), int(scored.sum())),
        "precision": precision,
        "recall": recall,
        "f1": _ratio(2 * precision * recall, precision + recall),
    }
# Fine-tune on the training split and score the held-out split. Then save
# the final model to ./model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    compute_metrics=lambda p: compute_metrics(p, id_to_label),
)
trainer.train()
# Keep the final evaluation scores instead of discarding evaluate()'s return
# value: print them, and write them under training_args.output_dir.
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
trainer.save_metrics("eval", eval_metrics)
trainer.save_model("./model")