# domain-labeler / train.py
# Author: davidmezzetti ("Add model", commit 6f6a301)
import warnings
import numpy as np
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, average_precision_score
from scipy.special import softmax
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from txtai import Embeddings
from txtai.pipeline import HFTrainer
def batchlabel(rows):
    """Batched dataset map callback: convert string labels to integer ids.

    Args:
        rows: batch dict with a "label" column of string labels

    Returns:
        dict with "label" replaced by integer ids from config.label2id
    """

    # config is the module-level AutoConfig built later in this script
    ids = []
    for name in rows["label"]:
        ids.append(config.label2id[name])

    return {"label": ids}
def batchtext(rows):
    """Batched dataset map callback: resolve row ids to their stored text.

    Args:
        rows: batch dict with an "id" column

    Returns:
        dict with a "text" column looked up from the embeddings database
    """

    # embeddings is the module-level txtai index loaded later in this script;
    # each id is expected to match exactly one stored document
    return {
        "text": [
            embeddings.search(
                "SELECT text FROM txtai WHERE id=:id", 1, parameters={"id": rowid}
            )[0]["text"]
            for rowid in rows["id"]
        ]
    }
def metrics(pred):
    """Compute evaluation metrics for the trainer.

    Args:
        pred: (logits, label ids) pair produced by the evaluation loop

    Returns:
        dict with accuracy, weighted precision/recall/F1 and weighted PR AUC
    """

    logits, labelids = pred
    predictions = np.argmax(logits, axis=-1)

    # Hard-prediction scores, weighted by class support
    scores = {
        "accuracy": accuracy_score(labelids, predictions),
        "precision": precision_score(labelids, predictions, average="weighted", zero_division=0),
        "recall": recall_score(labelids, predictions, average="weighted", zero_division=0),
        "f1": f1_score(labelids, predictions, average="weighted", zero_division=0),
    }

    # PR AUC needs class probabilities and one-hot encoded true labels
    probs = softmax(logits, axis=-1)
    onehot = np.eye(logits.shape[1])[labelids]

    # average_precision_score doesn't support zero_division parameter
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="No positive class found in y_true")
        scores["prauc"] = average_precision_score(onehot, probs, average="weighted")

    return scores
# Embeddings database: prebuilt txtai Wikipedia index; batchtext pulls article
# text out of it by row id
embeddings = Embeddings()
embeddings.load(provider="huggingface-hub", container="neuml/txtai-wikipedia-slim")
# Training dataset: labels.csv with "id" and "label" columns; keep_default_na=False
# keeps NA-like strings as literal text instead of converting them to NaN
ds = load_dataset("csv", data_files="labels.csv", split="train", keep_default_na=False)
# Sorted unique labels give a stable, reproducible id -> label mapping
labels = dict(enumerate(sorted(ds.unique("label"))))
print(labels)
# Base model
path = "jhu-clsp/ettin-encoder-32m"
# Model configuration: size the classification head and store label mappings
config = AutoConfig.from_pretrained(path)
config.num_labels = len(labels)
config.id2label = labels
config.label2id = {label: uid for uid, label in labels.items()}
# Map label ids (batchlabel reads config.label2id, so config must be set first)
ds = ds.map(batchlabel, batched=True)
# Map text (batchtext reads the embeddings database loaded above)
ds = ds.map(batchtext, batched=True)
# Split into train and test (fixed seed for reproducible splits)
ds = ds.train_test_split(test_size=0.05, seed=42)
training, test = ds["train"], ds["test"]
# Model to train
model = AutoModelForSequenceClassification.from_pretrained(path, config=config)
tokenizer = AutoTokenizer.from_pretrained(path)
# Fine-tune with txtai's trainer pipeline; keyword arguments past metrics are
# forwarded to the underlying transformers TrainingArguments
train = HFTrainer()
train(
    (model, tokenizer), training, test, metrics=metrics, maxlength=512, bf16=True,
    learning_rate=5e-5, per_device_train_batch_size=64, num_train_epochs=3,
    warmup_ratio=0.1, lr_scheduler_type="cosine",
    eval_strategy="steps", eval_steps=500, logging_steps=500,
    tokenizers=True, dataloader_num_workers=20,
    output_dir="domain-labeler"
)