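# Gradio app: builds a Japanese emotion-classification dataset by machine-translating
# samples from the English boltuix/emotions-dataset, fine-tunes LINE DistilBERT once
# per random seed, and averages the resulting checkpoints into a single model.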
import pandas as pd
from sklearn.utils import shuffle
from googletrans import Translator
from tqdm.asyncio import tqdm_asyncio
from pathlib import Path
import asyncio
import os
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import TrainingArguments, Trainer, AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr
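# Load the source dataset and collapse its labels into the six target emotion classes.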
data = pd.read_parquet("hf://datasets/boltuix/emotions-dataset/emotions_dataset.parquet")

groups = {
    "neutral": "neutral",
    "anger": "angry",
    "love": "joy",
    "happiness": "fun",
    "sadness": "sorrow",
    "surprise": "surprised"
}
data = data[data['Label'].isin(groups.keys())].copy()
data['Label'] = data['Label'].map(groups)

seeds = [1, 2, 3, 4]
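# Translate a list of English sentences to Japanese concurrently, capping the number of
# in-flight googletrans requests with a semaphore; results keep their original order.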
async def translate_all(seed, texts):
    semaphore = asyncio.Semaphore(12)

    async def sem_translate_task(text, idx):
        async with semaphore:
            async with Translator() as translator:
                result = await translator.translate(text, src='en', dest='ja')
                return result.text, idx

    tasks = [asyncio.create_task(sem_translate_task(text, idx)) for idx, text in enumerate(texts)]
    translated = [None] * len(texts)
    for coro in tqdm_asyncio.as_completed(tasks, total=len(tasks)):
        result, index = await coro
        translated[index] = result
    return translated
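# For each seed, draw a balanced sample (1000 sentences per class), translate it to
# Japanese, and cache the result as a CSV; existing CSVs are reused as-is.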
async def sample_all(progress=gr.Progress(track_tqdm=True)):
    files = []
    for seed in seeds:
        try:
            filename = f"./data/SampledData_{seed}.csv"
            if not os.path.exists(filename):
                Path("./data").mkdir(parents=True, exist_ok=True)
                # Balanced sample: 1000 sentences per emotion class, seeded for reproducibility.
                sampled = (
                    data.groupby('Label', group_keys=False)
                    .apply(lambda x: x.sample(n=1000, random_state=int(seed)))
                )
                sampled = shuffle(sampled, random_state=seed).reset_index(drop=True)
                texts = sampled["Sentence"].tolist()
                translated = await translate_all(seed, texts)
                sampled["Sentence"] = translated
                sampled.to_csv(filename, index=False)
                files.append(filename)
            else:
                files.append(filename)
        except Exception as e:
            raise gr.Error(str(e))
    return files
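# Thin torch Dataset wrapper around tokenizer output and (optional) integer labels.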
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
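# Split a sampled CSV into stratified train/validation sets and tokenize both.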
def prepare_dataset(df, tokenizer, label2id):
    X = list(df["Sentence"])
    # Convert string labels to integer ids so they can be turned into tensors.
    y = [label2id[label] for label in df["Label"]]
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
    X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
    X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
    train_dataset = Dataset(X_train_tokenized, y_train)
    val_dataset = Dataset(X_val_tokenized, y_val)
    return train_dataset, val_dataset
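# Standard classification metrics for the Trainer's evaluation loop.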
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Some models return a tuple of outputs; the classification logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(axis=-1)
    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }
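# Fine-tune one classifier per seed (skipping seeds that already have a saved model),
# then average the per-seed weights into a single model saved under ./VRM-Emotions.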
def train_model():
    id2label = {
        0: "neutral",
        1: "angry",
        2: "fun",
        3: "joy",
        4: "sorrow",
        5: "surprised"
    }
    label2id = {
        "neutral": 0,
        "angry": 1,
        "fun": 2,
        "joy": 3,
        "sorrow": 4,
        "surprised": 5
    }
    tokenizer = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
    model_paths = []
    metric_str_all = []
    for seed in seeds:
        csv_path = f"./data/SampledData_{seed}.csv"
        if not os.path.exists(csv_path):
            return f"File {csv_path} not found! Please create the file first.", None
        df = pd.read_csv(csv_path)
        train_dataset, val_dataset = prepare_dataset(df, tokenizer, label2id)
        if not os.path.exists(f"./output/{seed}/final_model/model.safetensors"):
            model = AutoModelForSequenceClassification.from_pretrained(
                "line-corporation/line-distilbert-base-japanese",
                use_safetensors=True,
                num_labels=6,
                label2id=label2id,
                id2label=id2label
            )
            training_args = TrainingArguments(
                output_dir=f"./output/{seed}",
                seed=seed,
                per_device_train_batch_size=8,
                per_device_eval_batch_size=8,
                eval_strategy="epoch",
                save_strategy="epoch",
                num_train_epochs=5,
                fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
                logging_dir=f"./logs/{seed}",
                logging_steps=100,
                load_best_model_at_end=True
            )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                tokenizer=tokenizer,
                compute_metrics=compute_metrics
            )
            trainer.train()
            trainer.save_model(f"./output/{seed}/final_model")
            tokenizer.save_pretrained(f"./output/{seed}/final_model")
            model_paths.append(f"./output/{seed}/final_model")
            metrics = trainer.evaluate()
            metric_str = f"Seed {seed}:\n" + "\n".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
            metric_str_all.append(metric_str)
        else:
            model_paths.append(f"./output/{seed}/final_model")
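    # Average the per-seed model weights parameter-by-parameter into one model
    # (each parameter tensor becomes the element-wise mean across the checkpoints).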
    if not os.path.exists("./VRM-Emotions/model.safetensors"):
        models = [AutoModelForSequenceClassification.from_pretrained(p) for p in model_paths]
        state_dicts = [m.state_dict() for m in models]
        avg_state_dict = {}
        for key in state_dicts[0].keys():
            avg_param = torch.stack([sd[key].float() for sd in state_dicts], dim=0).mean(dim=0)
            avg_state_dict[key] = avg_param
        avg_model = AutoModelForSequenceClassification.from_pretrained(model_paths[0])
        avg_model.load_state_dict(avg_state_dict)
        avg_model.save_pretrained("./VRM-Emotions")
        tokenizer.save_pretrained("./VRM-Emotions")
    return "\n\n".join(metric_str_all), [os.path.join("./VRM-Emotions", file) for file in os.listdir("./VRM-Emotions")]
async def train_model_async(progress=gr.Progress(track_tqdm=True)):
    try:
        # Run the blocking training loop in a worker thread so the UI stays responsive.
        return await asyncio.to_thread(train_model)
    except Exception as e:
        raise gr.Error(str(e))
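# Two-tab Gradio UI: one tab prepares the translated CSVs, the other runs training.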
with gr.Blocks() as demo:
    with gr.Tab("Prepare Dataset"):
        dataset_files = gr.Files(label="CSV Files")
        sample_btn = gr.Button("Get Datasets")
        sample_btn.click(sample_all, inputs=None, outputs=dataset_files)
    with gr.Tab("Train Model"):
        train_results = gr.TextArea(label="Metrics", interactive=False)
        models_files = gr.Files(label="Trained Model")
        train_btn = gr.Button("Train All")
        train_btn.click(train_model_async, inputs=None, outputs=[train_results, models_files])

demo.launch()