# VRM-Emotions / app.py
import pandas as pd
from sklearn.utils import shuffle
from googletrans import Translator
from tqdm.asyncio import tqdm_asyncio
from pathlib import Path
import asyncio
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr
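# Load the full boltuix/emotions-dataset parquet from the Hugging Face Hub;
# hf:// paths are resolved through huggingface_hub's fsspec integration.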
data = pd.read_parquet("hf://datasets/boltuix/emotions-dataset/emotions_dataset.parquet")
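# Keep six of the source labels and rename them to the VRM expression-preset
# names (neutral, angry, joy, fun, sorrow, surprised) used by the classifier.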
groups = {
"neutral": "neutral",
"anger": "angry",
"love": "joy",
"happiness": "fun",
"sadness": "sorrow",
"surprise": "surprised"
}
data = data[data['Label'].isin(groups.keys())].copy()
data['Label'] = data['Label'].map(groups)
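# Sanity check (an added guard, not required by the pipeline): sample_all()
# below draws 1000 rows per label, so fail fast if the upstream dataset shrinks.
assert (data['Label'].value_counts() >= 1000).all(), "need >= 1000 rows per label"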
seeds = [1, 2, 3, 4]
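# Four seeds -> four balanced samples and four fine-tuned checkpoints; their
# weights are averaged into a single model at the end of train_model().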
async def translate_all(texts):
    # Translate English sentences to Japanese concurrently; the semaphore
    # caps the number of in-flight requests to the translate endpoint.
    # (The original `seed` parameter was unused and has been dropped.)
    semaphore = asyncio.Semaphore(12)
async def sem_translate_task(text, idx):
async with semaphore:
async with Translator() as translator:
result = await translator.translate(text, src='en', dest='ja')
return result.text, idx
tasks = [asyncio.create_task(sem_translate_task(text, idx)) for idx, text in enumerate(texts)]
translated = [None] * len(texts)
for coro in tqdm_asyncio.as_completed(tasks, total=len(tasks)):
result, index = await coro
translated[index] = result
return translated
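# Illustrative usage only (hits the live Google Translate endpoint):
#   ja = asyncio.run(translate_all(["I am happy.", "I am sad."]))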
async def sample_all(progress=gr.Progress(track_tqdm=True)):
    # Build one translated, balanced CSV per seed; results are cached on disk
    # so repeated clicks do not re-translate. (The old `seed` parameter was
    # shadowed by the loop below and broke the no-input Gradio click handler.)
    files = []
    for seed in seeds:
try:
filename = f"./data/SampledData_{seed}.csv"
if not os.path.exists(filename):
Path("./data").mkdir(parents=True, exist_ok=True)
                # Balanced draw: 1000 rows per emotion label for this seed.
                sampled = data.groupby('Label', group_keys=False).sample(n=1000, random_state=int(seed))
                sampled = shuffle(sampled, random_state=int(seed)).reset_index(drop=True)
texts = sampled["Sentence"].tolist()
                translated = await translate_all(texts)
sampled["Sentence"] = translated
sampled.to_csv(filename, index=False)
files.append(filename)
else:
files.append(filename)
except Exception as e:
            raise gr.Error(str(e))
return files
class Dataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels=None):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
item["labels"] = torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.encodings["input_ids"])
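# Illustrative only: how the wrapper above pairs tokenizer output with labels.
#   enc = tokenizer(["今日は楽しい"], padding=True, truncation=True)
#   ds = Dataset(enc, labels=[3])   # 3 == "joy" in the id2label map below
#   ds[0] -> {"input_ids": ..., "attention_mask": ..., "labels": tensor(3)}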
def prepare_dataset(df, tokenizer, label2id):
    X = list(df["Sentence"])
    # Map the string labels (e.g. "joy") to integer ids; torch.tensor() in
    # Dataset.__getitem__ cannot handle raw strings.
    y = [label2id[label] for label in df["Label"]]
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
return train_dataset, val_dataset
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Trainer may hand back predictions as a tuple of arrays; logits come
    # first. Indexing a plain (N, 6) array with [0] would argmax a single row.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(axis=-1)
accuracy = accuracy_score(labels, preds)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
return {
'accuracy': accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
}
def train_model():
id2label = {
0: "neutral",
1: "angry",
2: "fun",
3: "joy",
4: "sorrow",
5: "surprised"
}
label2id = {
"neutral": 0,
"angry": 1,
"fun": 2,
"joy": 3,
"sorrow": 4,
"surprised": 5
}
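    # LINE's DistilBERT ships a custom Japanese tokenizer, hence trust_remote_code.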
tokenizer = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
model_paths = []
metric_str_all = []
for seed in seeds:
csv_path = f"./data/SampledData_{seed}.csv"
if not os.path.exists(csv_path):
return f"File {csv_path} not found! กรุณาสร้างไฟล์ก่อน.", None
df = pd.read_csv(csv_path)
        train_dataset, val_dataset = prepare_dataset(df, tokenizer, label2id)
if not os.path.exists(f"./output/{seed}/final_model/model.safetensors"):
model = AutoModelForSequenceClassification.from_pretrained(
"line-corporation/line-distilbert-base-japanese",
use_safetensors=True,
num_labels=6,
label2id=label2id,
id2label=id2label
)
training_args = TrainingArguments(
output_dir=f"./output/{seed}",
seed=seed,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
eval_strategy="epoch",
save_strategy="epoch",
num_train_epochs=5,
                fp16=True,  # mixed precision assumes a CUDA GPU; set False on CPU-only runs
logging_dir=f"./logs/{seed}",
logging_steps=100,
load_best_model_at_end=True
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
trainer.train()
trainer.save_model(f"./output/{seed}/final_model")
tokenizer.save_pretrained(f"./output/{seed}/final_model")
model_paths.append(f"./output/{seed}/final_model")
metrics = trainer.evaluate()
metric_str = f"Seed {seed}:\n" + "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
metric_str_all.append(metric_str)
else:
model_paths.append(f"./output/{seed}/final_model")
if not os.path.exists(f"./VRM-Emotions/model.safetensors"):
models = [AutoModelForSequenceClassification.from_pretrained(p) for p in model_paths]
state_dicts = [m.state_dict() for m in models]
avg_state_dict = {}
for key in state_dicts[0].keys():
avg_param = torch.stack([sd[key].float() for sd in state_dicts], dim=0).mean(dim=0)
avg_state_dict[key] = avg_param
avg_model = AutoModelForSequenceClassification.from_pretrained(model_paths[0])
avg_model.load_state_dict(avg_state_dict)
avg_model.save_pretrained("./VRM-Emotions")
tokenizer.save_pretrained("./VRM-Emotions")
return "\n\n".join(metric_str_all), [os.path.join("./VRM-Emotions", file) for file in os.listdir("./VRM-Emotions")]
async def train_model_async(progress=gr.Progress(track_tqdm=True)):
try:
return await asyncio.to_thread(train_model)
    except Exception as e:
        raise gr.Error(str(e))
with gr.Blocks() as demo:
with gr.Tab("Prepare Dataset"):
dataset_files = gr.Files(label="CSV Files")
sample_btn = gr.Button("Get Datasets")
sample_btn.click(sample_all, inputs=None, outputs=dataset_files)
with gr.Tab("Train Model"):
train_results = gr.TextArea(label="Metrics", interactive=False)
models_files = gr.Files(label="Trained Model")
train_btn = gr.Button("Train All")
train_btn.click(train_model_async, inputs=None, outputs=[train_results, models_files])
demo.launch()