ThanhNguyen1811 commited on
Commit
b7f8cd0
·
verified ·
1 Parent(s): dc302ab

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -39
app.py CHANGED
@@ -4,10 +4,7 @@ import torchaudio
4
  import pandas as pd
5
  import os
6
  import torch.nn as nn
7
- from transformers import (
8
- Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer,
9
- WhisperProcessor, WhisperForConditionalGeneration
10
- )
11
 
12
  # Import các class mô hình từ file models.py
13
  from models import MultimodalClassifier, TextClassifier
@@ -24,18 +21,12 @@ LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường"
24
  # Đường dẫn (Tương đối với thư mục gốc của Space)
25
  MODEL_A_PATH = "saved_models/best_model_A.pth"
26
  MODEL_B_PATH = "saved_models/best_model_B.pth"
27
- FUZZY_RULES_PATH = "data/datafuzzy29d.csv"
28
-
29
- # --- Tải hình STT (ĐÃ THAY ĐỔI SANG WHISPER) ---
30
- # === SỬA LỖI LẦN 2: Đã cập nhật tên model chính xác ===
31
- STT_MODEL_ID = "vinai/vinai-whisper-base"
32
- print(f"Đang tải mô hình STT Whisper: {STT_MODEL_ID}...")
33
- # Cần 'language' và 'task' để bộ xử lý biết cách hoạt động
34
- audio_processor = WhisperProcessor.from_pretrained(STT_MODEL_ID, language="vi", task="transcribe")
35
- stt_model = WhisperForConditionalGeneration.from_pretrained(STT_MODEL_ID).to(device)
36
-
37
- # --- Tải các mô hình nền khác ---
38
- print("Đang tải mô hình PhoBERT...")
39
  text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
40
  text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)
41
 
@@ -59,6 +50,7 @@ try:
59
  fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
60
  fuzzy_rules = {}
61
  for _, row in fuzzy_rules_df.iterrows():
 
62
  fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
63
  print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
64
  except Exception as e:
@@ -67,47 +59,48 @@ except Exception as e:
67
 
68
  print("Tất cả mô hình đã sẵn sàng.")
69
 
70
- # --- 2. Định nghĩa Hàm Dự đoán (ĐÃ CẬP NHẬT) ---
 
71
  def predict_sentiment(audio_input):
72
  if audio_input is None:
73
  return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"
74
 
75
  sample_rate, waveform_numpy = audio_input
 
 
76
  waveform = torch.from_numpy(waveform_numpy).float()
77
 
 
78
  if waveform.ndim > 1:
79
  waveform = waveform[0]
80
 
81
- # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio (Logic của Whisper) ---
 
 
 
82
  try:
83
- # 1a. Resample (Whisper yêu cầu 16000 Hz)
84
  if sample_rate != 16000:
85
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
86
  waveform = resampler(waveform)
87
 
88
- # 1b. Chuẩn bị input audio cho Whisper
89
- # Không cần unsqueeze(0) processor tự xử lý
90
- inputs = audio_processor(waveform, sampling_rate=16000, return_tensors="pt")
91
- input_features = inputs.input_features.to(device)
92
 
93
- # 2a. Trích xuất Đặc trưng Audio (cho Model A)
94
- # Chúng ta cần chạy encoder của Whisper để lấy hidden states
95
  with torch.no_grad():
96
- encoder_outputs = stt_model.model.encoder(input_features)
97
- # Lấy hidden state cuối cùng và tính trung bình
98
- audio_feat_A = torch.mean(encoder_outputs.last_hidden_state, dim=1)
99
 
100
- # 2b. Trích xuất Văn bản (STT)
101
- # Chạy hàm generate() để tạo token ID
102
- with torch.no_grad():
103
- predicted_ids = stt_model.generate(input_features, language="vi")
104
-
105
- # Giải mã token ID thành văn bản
106
- transcribed_text = audio_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].lower()
107
 
108
  if not transcribed_text:
109
  transcribed_text = "[Không nhận diện được giọng nói]"
110
 
 
 
 
111
  except Exception as e:
112
  return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"
113
 
@@ -158,8 +151,10 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
158
 
159
  with gr.Row():
160
  with gr.Column(scale=2):
 
 
161
  audio_in = gr.Audio(
162
- sources=["upload", "microphone"],
163
  type="numpy",
164
  label="Tải lên tệp âm thanh hoặc Ghi âm"
165
  )
@@ -167,6 +162,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
167
 
168
  with gr.Column(scale=3):
169
  gr.Markdown("### Kết quả Phân tích")
 
170
  text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
171
  final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
172
 
@@ -174,14 +170,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
174
  pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
175
  pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")
176
 
 
177
  submit_btn.click(
178
  fn=predict_sentiment,
179
  inputs=audio_in,
180
  outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
181
  )
182
 
183
- gr.Markdown("Lưu ý: Mô hình STT hiện đang sử dụng `vinai/vinai-whisper-base`.")
184
 
185
  print("Đang khởi chạy demo...")
186
- demo.launch()
187
-
 
4
  import pandas as pd
5
  import os
6
  import torch.nn as nn
7
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer
 
 
 
8
 
9
  # Import các class mô hình từ file models.py
10
  from models import MultimodalClassifier, TextClassifier
 
21
  # Đường dẫn (Tương đối với thư mục gốc của Space)
22
  MODEL_A_PATH = "saved_models/best_model_A.pth"
23
  MODEL_B_PATH = "saved_models/best_model_B.pth"
24
+ FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác
25
+
26
+ # Tải các hình nền (từ Hugging Face Hub)
27
+ print("Đang tải các hình nền (STT, PhoBERT)...")
28
+ audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
29
+ stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
 
 
 
 
 
 
30
  text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
31
  text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)
32
 
 
50
  fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
51
  fuzzy_rules = {}
52
  for _, row in fuzzy_rules_df.iterrows():
53
+ # Đảm bảo tên cột khớp với file CSV của bạn
54
  fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
55
  print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
56
  except Exception as e:
 
59
 
60
  print("Tất cả mô hình đã sẵn sàng.")
61
 
62
+ # --- 2. Định nghĩa Hàm Dự đoán ---
63
+ # Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit"
64
  def predict_sentiment(audio_input):
65
  if audio_input is None:
66
  return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"
67
 
68
  sample_rate, waveform_numpy = audio_input
69
+
70
+ # Đảm bảo waveform là tensor float
71
  waveform = torch.from_numpy(waveform_numpy).float()
72
 
73
+ # Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo
74
  if waveform.ndim > 1:
75
  waveform = waveform[0]
76
 
77
+ # Thêm chiều batch (1,)
78
+ waveform = waveform.unsqueeze(0)
79
+
80
+ # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio ---
81
  try:
82
+ # 1a. Resample
83
  if sample_rate != 16000:
84
  resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
85
  waveform = resampler(waveform)
86
 
87
+ # 1b. Chuẩn bị input audio
88
+ input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device)
 
 
89
 
 
 
90
  with torch.no_grad():
91
+ audio_outputs = stt_model(input_values, output_hidden_states=True)
 
 
92
 
93
+ # 2a. Trích xuất Văn bản (STT)
94
+ logits = audio_outputs.logits
95
+ predicted_ids = torch.argmax(logits, dim=-1)
96
+ transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower()
 
 
 
97
 
98
  if not transcribed_text:
99
  transcribed_text = "[Không nhận diện được giọng nói]"
100
 
101
+ # 2b. Trích xuất Đặc trưng Audio (cho Model A)
102
+ audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1)
103
+
104
  except Exception as e:
105
  return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"
106
 
 
151
 
152
  with gr.Row():
153
  with gr.Column(scale=2):
154
+ # === BỔ SUNG TÍNH NĂNG ===
155
+ # Thêm "microphone" vào sources để cho phép ghi âm
156
  audio_in = gr.Audio(
157
+ sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm
158
  type="numpy",
159
  label="Tải lên tệp âm thanh hoặc Ghi âm"
160
  )
 
162
 
163
  with gr.Column(scale=3):
164
  gr.Markdown("### Kết quả Phân tích")
165
+ # Các ô output
166
  text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
167
  final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
168
 
 
170
  pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
171
  pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")
172
 
173
+ # Liên kết nút bấm với hàm dự đoán
174
  submit_btn.click(
175
  fn=predict_sentiment,
176
  inputs=audio_in,
177
  outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
178
  )
179
 
180
+ gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.")
181
 
182
  print("Đang khởi chạy demo...")
183
+ demo.launch() # Không cần (share=True) khi chạy trên Spaces