vnanhtuan committed on
Commit
718d45a
·
verified ·
1 Parent(s): 67083e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
5
+ from sklearn.preprocessing import LabelEncoder
6
+ from sklearn.model_selection import train_test_split
7
+ from datasets import Dataset
8
+ from underthesea import word_tokenize
9
+ import os
10
+ import pickle
11
+
12
# ---- PhoBERT checkpoint name and its (slow) tokenizer ----
MODEL_NAME = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# ---- Label encoder ----
# Starts as None; train_model() assigns a fitted encoder and
# predict_sentiment() loads it from label_encoder.pkl when it is still None.
label_encoder = None
18
+
19
+ # ---- Tokenization ----
20
def preprocess_function(example):
    """Word-segment one dataset row and encode it for PhoBERT.

    Args:
        example: A mapping with a "comment" entry holding the raw text.

    Returns:
        The tokenizer output for the word-segmented comment.
    """
    comment = example["comment"]
    # A missing comment (None) would make word_tokenize raise; encode it as
    # empty text instead of aborting the whole dataset.map() pass.
    text = "" if comment is None else str(comment)

    # PhoBERT expects Vietnamese text that is already word-segmented
    # (multi-syllable words joined by underscores).
    segmented = word_tokenize(text, format="text")

    # Pin max_length explicitly: when the tokenizer config carries no model
    # maximum, truncation=True alone is a no-op, and sequences longer than
    # PhoBERT's 256 usable positions fail inside the model.
    return tokenizer(segmented, truncation=True, max_length=256)
24
+
25
+ # ---- Train function ----
26
def train_model(file):
    """Fine-tune PhoBERT on an uploaded CSV and save the resulting artifacts.

    The CSV must contain a "comment" column (text) and a "label" column
    (class name). On success the model and tokenizer are written to
    ``finetuned_phobert/`` and the fitted label encoder to
    ``label_encoder.pkl``.

    Args:
        file: The upload from the Gradio File component. An object exposing
            ``.name`` (path on disk), a plain path string, or a readable
            buffer are all accepted.

    Returns:
        A user-facing status message (success or a description of what is
        wrong with the input).
    """
    global label_encoder

    if file is None:
        return "❌ Vui lòng tải lên file CSV trước khi huấn luyện."

    # Depending on the Gradio version the upload is a tempfile-like object
    # with ``.name`` or already a path string; fall back to the value itself.
    csv_source = getattr(file, "name", file)
    df = pd.read_csv(csv_source)

    missing_columns = {"comment", "label"} - set(df.columns)
    if missing_columns:
        return f"❌ File CSV thiếu cột: {', '.join(sorted(missing_columns))}"

    # Rows without text or label cannot be tokenized/encoded; drop them and
    # reset the index so Dataset.from_pandas does not carry an index column.
    df = df[["comment", "label"]].dropna().reset_index(drop=True)
    if len(df) < 2:
        # At least one row is needed on each side of the train/test split.
        return "❌ Không đủ dữ liệu hợp lệ để huấn luyện."
    df["comment"] = df["comment"].astype(str)

    # Fit into a local encoder first: the global and the pickle are only
    # replaced once training has succeeded, so a failed run cannot leave an
    # encoder that disagrees with the previously saved model.
    encoder = LabelEncoder()
    df["label"] = encoder.fit_transform(df["label"])
    num_labels = len(encoder.classes_)

    # Convert to a Hugging Face Dataset and tokenize every row.
    dataset = Dataset.from_pandas(df[["comment", "label"]])
    tokenized_dataset = dataset.map(preprocess_function)

    # Hold out 20% of the rows for per-epoch evaluation.
    tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2)

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, num_labels=num_labels
    )

    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the pinned transformers version.
    args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        save_strategy="no",
        logging_steps=10,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
    )

    trainer.train()

    # Persist model, tokenizer and encoder together so they stay in sync.
    model.save_pretrained("finetuned_phobert")
    tokenizer.save_pretrained("finetuned_phobert")
    with open("label_encoder.pkl", "wb") as f:
        pickle.dump(encoder, f)
    label_encoder = encoder

    return "✅ Huấn luyện thành công!"
73
+
74
# ---- Prediction ----
def predict_sentiment(text):
    """Classify one piece of text with the fine-tuned PhoBERT model.

    Args:
        text: The raw sentence entered by the user.

    Returns:
        A user-facing string with the predicted label and its probability,
        or a message explaining why no prediction could be made.
    """
    global label_encoder

    # Nothing to classify: skip the model load entirely.
    if not text or not text.strip():
        return "⚠️ Vui lòng nhập câu cần dự đoán."

    if not os.path.exists("finetuned_phobert"):
        return "❌ Chưa có mô hình được huấn luyện."

    if label_encoder is None:
        if not os.path.exists("label_encoder.pkl"):
            return "❌ Không tìm thấy label_encoder.pkl. Vui lòng huấn luyện lại mô hình."
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # this is only acceptable because the file is written by train_model.
        with open("label_encoder.pkl", "rb") as f:
            label_encoder = pickle.load(f)

    # The model is reloaded on every call, so a freshly retrained model is
    # picked up without restarting the app.
    model = AutoModelForSequenceClassification.from_pretrained("finetuned_phobert")
    local_tokenizer = AutoTokenizer.from_pretrained("finetuned_phobert", use_fast=False)

    # Same word segmentation as at training time.
    segmented = word_tokenize(text, format="text")
    # Pin max_length: PhoBERT only has positions for 256 tokens, and
    # truncation=True alone does nothing when the tokenizer has no maximum.
    inputs = local_tokenizer(
        segmented, return_tensors="pt", truncation=True, max_length=256
    )

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        pred = torch.argmax(probs, dim=1).item()

    label = label_encoder.inverse_transform([pred])[0]
    confidence = probs[0][pred].item()
    return f"🔎 Dự đoán: {label} (xác suất: {confidence:.2f})"
96
+
97
# ---- Gradio interface ----
with gr.Blocks() as demo:
    gr.Markdown("# 🔥 Fine-tune cảm xúc tiếng Việt với PhoBERT")

    # Tab 1: upload a CSV and run fine-tuning.
    with gr.Tab("1️⃣ Huấn luyện"):
        file_input = gr.File(label="Tải lên file CSV")
        train_button = gr.Button("Huấn luyện mô hình")
        train_output = gr.Textbox(label="Kết quả")

        train_button.click(train_model, file_input, train_output)

    # Tab 2: classify a single sentence with the fine-tuned model.
    with gr.Tab("2️⃣ Dự đoán"):
        text_input = gr.Textbox(label="Nhập câu đánh giá")
        predict_button = gr.Button("Dự đoán")
        predict_output = gr.Textbox(label="Kết quả")

        predict_button.click(predict_sentiment, text_input, predict_output)

demo.launch()