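"""Gradio demo: restaurant reservation intent recognition (餐廳訂位意圖識別).

Speech or text input is transcribed with a Whisper model
(Jingmiao/whisper-small-zh_tw), classified as reservation / not-reservation
with an ALBERT classifier or Qwen2.5 (via outlines), and a short reply is
synthesized with Kokoro TTS.
"""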
import spaces  # Import the ZeroGPU helper
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
from torch.nn.functional import softmax
import numpy as np
import soundfile as sf
import io
import tempfile
import outlines  # For Qwen integration via outlines
import kokoro     # For TTS synthesis
import re
from pathlib import Path
from functools import lru_cache
import warnings

# Suppress FutureWarnings (e.g. about using `inputs` vs. `input_features`)
warnings.filterwarnings("ignore", category=FutureWarning)

# ------------------- Model Identifiers -------------------
whisper_model_id = "Jingmiao/whisper-small-zh_tw"
qwen_model_id = "Qwen/Qwen2.5-0.5B-Instruct"

available_models = {
    "ALBERT-tiny (Chinese)": "Luigi/albert-tiny-chinese-dinercall-intent",
    "ALBERT-base (Chinese)": "Luigi/albert-base-chinese-dinercall-intent",
    "Qwen (via Transformers - outlines)": "qwen"
}

# ------------------- Caching and Loading Functions -------------------
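# lru_cache keeps each loaded model/pipeline in process memory, so repeated
# requests reuse the same weights instead of reloading them from disk.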
@lru_cache(maxsize=1)
def load_whisper_pipeline():
    pipe = pipeline("automatic-speech-recognition", 
                    model=whisper_model_id,
                    chunk_length_s=30)
    # Move model to GPU if available for faster inference
    if torch.cuda.is_available():
        pipe.model.to("cuda")
    return pipe

@lru_cache(maxsize=2)
def load_transformers_model(model_id: str):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    if torch.cuda.is_available():
        model.to("cuda")
    return tokenizer, model

@lru_cache(maxsize=1)
def load_qwen_model():
    return outlines.models.transformers(qwen_model_id)

@lru_cache(maxsize=1)
def get_tts_pipeline():
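    # lang_code "z" selects Kokoro's Mandarin Chinese voice set.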
    return kokoro.KPipeline(lang_code="z")

# ------------------- Inference Functions -------------------
def predict_with_qwen(text: str):
    model = load_qwen_model()
    prompt = f"""
<|im_start|>system
You are an expert in classification of restaurant customers' messages.
You must decide between the following two intents:
RESERVATION: Inquiries and requests highly related to table reservations and seating.
NOT_RESERVATION: All other messages.
Respond with *only* the intent label: RESERVATION or NOT_RESERVATION.
<|im_end|>

<|im_start|>user
Classify the following message: "{text}"
<|im_end|>

<|im_start|>assistant
"""
    generator = outlines.generate.choice(model, ["RESERVATION", "NOT_RESERVATION"])
    prediction = generator(prompt)
    if prediction == "RESERVATION":
        return "📞 訂位意圖 (Reservation intent)"
    elif prediction == "NOT_RESERVATION":
        return "❌ 無訂位意圖 (Not Reservation intent)"
    else:
        return f"未知回應: {prediction}"

def predict_intent(text: str, model_id: str):
    tokenizer, model = load_transformers_model(model_id)
    inputs = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = softmax(logits, dim=-1)
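        # Index 1 is treated as the reservation class (assumed label order for these checkpoints).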
        confidence = probs[0, 1].item()
    if confidence >= 0.7:
        return f"📞 訂位意圖 (Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"
    else:
        return f"❌ 無訂位意圖 (Not Reservation intent)(訂位信心度 Confidence: {confidence:.2%})"

def get_tts_message(intent_result: str):
    if intent_result and "訂位意圖" in intent_result and "無" not in intent_result:
        return "稍後您將會從簡訊收到訂位連結"
    elif intent_result:
        return "我們將會將您的回饋傳達給負責人,謝謝您"
    else:
        return "未能判斷意圖"

def tts_audio_output(message: str, voice: str = 'af_heart'):
    pipeline_tts = get_tts_pipeline()
    generator = pipeline_tts(message, voice=voice)
    audio_chunks = []
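    # Kokoro yields audio in chunks; collect them and concatenate into one waveform.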
    for _, _, audio in generator:
        audio_chunks.append(audio)
    if audio_chunks:
        audio_concat = np.concatenate(audio_chunks)
        # Return as tuple (sample_rate, numpy_array) for gr.Audio (using 24000 Hz)
        return (24000, audio_concat)
    else:
        return None

def transcribe_audio(audio_input):
    whisper_pipe = load_whisper_pipeline()
    # Both the upload and microphone components use type="filepath", so the
    # input is normally a path string pointing at the recorded/uploaded file.
    if isinstance(audio_input, str):
        result = whisper_pipe(audio_input)
        return result["text"]
    # Any other input type (e.g. a (sample_rate, array) tuple) is not expected
    # with the current configuration; return an empty transcription.
    return ""

# ------------------- Main Processing Function -------------------
@spaces.GPU  # ZeroGPU: a GPU is allocated for the duration of this call.
def classify_intent(mode, mic_audio, text_input, file_audio, model_choice):
    # Determine input based on selected mode.
    if mode == "Microphone" and mic_audio is not None:
        # mic_audio is a file path.
        transcription = transcribe_audio(mic_audio)
    elif mode == "Text" and text_input:
        transcription = text_input
    elif mode == "File" and file_audio is not None:
        transcription = transcribe_audio(file_audio)
    else:
        return "請提供語音或文字輸入", "", None

    # Classify the transcribed or provided text.
    if available_models[model_choice] == "qwen":
        classification = predict_with_qwen(transcription)
    else:
        classification = predict_intent(transcription, available_models[model_choice])
    # Generate TTS message and corresponding audio.
    tts_msg = get_tts_message(classification)
    tts_audio = tts_audio_output(tts_msg)
    return transcription, classification, tts_audio

# ------------------- Gradio Blocks Interface Setup -------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🍽️ 餐廳訂位意圖識別")
    gr.Markdown("錄音、上傳語音檔案或輸入文字,自動判斷是否具有訂位意圖。")
    
    with gr.Row():
        mode = gr.Radio(choices=["Microphone", "Text", "File"], label="選擇輸入模式", value="Microphone")
    
    with gr.Row():
        # Microphone input uses type="filepath" so the handler always receives a file path.
        mic_audio = gr.Audio(sources=["microphone"], type="filepath", label="語音輸入 (點擊錄音)")
        # Only the microphone input is visible initially; the mode selector toggles the others.
        text_input = gr.Textbox(lines=2, placeholder="請輸入文字", label="文字輸入", visible=False)
        file_audio = gr.Audio(sources=["upload"], type="filepath", label="上傳語音檔案", visible=False)

    # Set visibility based on selected mode.
    def update_visibility(selected_mode):
        if selected_mode == "Microphone":
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
        elif selected_mode == "Text":
            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
        else:  # File
            return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
    mode.change(fn=update_visibility, inputs=mode, outputs=[mic_audio, text_input, file_audio])
    
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=list(available_models.keys()),
                                     value="ALBERT-tiny (Chinese)", label="選擇模型")
    
    with gr.Row():
        classify_btn = gr.Button("執行辨識")
    
    with gr.Row():
        transcription_output = gr.Textbox(label="轉換文字")
    with gr.Row():
        classification_output = gr.Textbox(label="意圖判斷結果")
    with gr.Row():
        tts_output = gr.Audio(type="numpy", label="TTS 語音輸出")
    
    # Button event triggers the classification.
    classify_btn.click(fn=classify_intent, 
                       inputs=[mode, mic_audio, text_input, file_audio, model_dropdown],
                       outputs=[transcription_output, classification_output, tts_output])

demo.launch()