Spaces:
Paused
Paused
File size: 2,197 Bytes
17c19fa 4916696 f1a856e 17c19fa f1a856e be09389 4916696 011f25d 4916696 7caa665 4916696 afa30ff f1a856e afa30ff 17c19fa 4916696 3baf600 4916696 17c19fa 7caa665 16ac456 7caa665 4916696 7caa665 be09389 7caa665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import torch
import gradio as gr
import pandas as pd
from datasets import Dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer
)
import os
# Hub access token taken from the environment; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")

# Load the training data from a local CSV and wrap it as a `datasets.Dataset`.
# NOTE(review): preprocess() reads a "text" column; the label column the
# Trainer needs is not visible here -- confirm the CSV schema.
df = pd.read_csv("dataset.csv")
dataset = Dataset.from_pandas(df)

# Load the tokenizer and the base checkpoint with a 2-class head.
# `ignore_mismatched_sizes=True` lets loading proceed when the checkpoint's
# classifier head does not match `num_labels=2`.
model_name = "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    ignore_mismatched_sizes=True,
)
# Tokenization helper applied to the dataset in batches.
def preprocess(examples):
    """Tokenize the "text" column of a batch of examples.

    Args:
        examples: Batch mapping handed over by `Dataset.map`; only its
            "text" entry is read.

    Returns:
        The tokenizer output for the batch, with truncation and padding
        enabled.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding=True)
# Run the tokenizer over the whole dataset, one batch of rows at a time.
tokenized_dataset = dataset.map(preprocess, batched=True)

# Hyper-parameters for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./results",
    # Optimisation schedule.
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    # Logging and checkpointing.
    logging_steps=10,
    save_strategy="no",
)

# Build the trainer and fine-tune; this runs at module import time.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
)
trainer.train()
from huggingface_hub import HfApi
# Publish the trained model first, then the tokenizer, to the same Hub repo.
# NOTE(review): recent transformers releases deprecate `use_auth_token` in
# favour of `token` -- confirm against the version pinned for this app.
hub_repo_id = "kwanpon/mdeberta-classify-thai"
for artifact in (model, tokenizer):
    artifact.push_to_hub(hub_repo_id, use_auth_token=token)
# inference function for gradio
def classify(text):
    """Classify one piece of text with the fine-tuned model.

    Args:
        text: The input string typed into the Gradio textbox.

    Returns:
        A dict mapping each of the two class labels to its softmax
        probability, in the shape `gr.Label` accepts.
    """
    # trainer.train() ran earlier in this script and leaves the model in
    # training mode; switch to eval mode so dropout does not perturb
    # predictions. The call is idempotent, so repeating it per request is
    # harmless.
    model.eval()

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # The Trainer may have moved the model to an accelerator; the inputs must
    # live on the same device as the weights.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Bring the result back to the CPU before converting to NumPy.
    probs = torch.softmax(outputs.logits, dim=1)[0].cpu().numpy()
    return {
        "ไม่เกี่ยวข้อง": float(probs[0]),  # label text: "not relevant"
        "จ้างงานรถขนส่ง": float(probs[1]),  # label text: "transport-truck hiring"
    }
# Gradio interface: a multi-line textbox in, a label widget with the class
# probabilities returned by classify() out.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=3, label="ข้อความ"),
    outputs=gr.Label(label="ผลการจำแนก"),
    title="Text Classifier: Zero-Shot NLI",
    description="กรุณาพิมพ์ข้อความเพื่อตรวจสอบว่าเป็นการว่าจ้างงานรถขนส่งหรือไม่",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()