vuvanhung committed
Commit 67ca15b · verified
1 Parent(s): a616cb0

Create app.py

Files changed (1)
  1. app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
+ import os
+ import io
+ import uuid
+ import time
+ import json
+ import logging
+ import tempfile
+ import threading
+
+ from flask import Flask, request, jsonify, send_file
+ from transformers import pipeline
+ from gtts import gTTS
+ from pydub import AudioSegment
+
+ # ================= CONFIG =================
+
+ TEMP_AUDIO_DIR = "/tmp/audio"
+ os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
+
+ STT_MODEL = "openai/whisper-tiny"
+ LLM_MODEL = "google/flan-t5-base"
+
+ MAX_AUDIO_SECONDS = 10
+ MAX_TEXT_LEN = 200
+
+ CLEANUP_INTERVAL = 300  # seconds
+ FILE_EXPIRE_TIME = 600  # seconds
+
+ # ================= LOG =================
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s | %(levelname)s | %(message)s"
+ )
+ logger = logging.getLogger(__name__)
+
+ # ================= APP =================
+
+ app = Flask(__name__)
+ app.config["TEMP_AUDIO_DIR"] = TEMP_AUDIO_DIR
+
+ # ================= LOAD MODELS =================
+
+ logger.info("Loading STT model...")
+ stt_pipeline = pipeline(
+     "automatic-speech-recognition",
+     model=STT_MODEL,
+     device="cpu"
+ )
+
+ logger.info("Loading LLM model...")
+ llm_pipeline = pipeline(
+     "text2text-generation",
+     model=LLM_MODEL,
+     device="cpu"
+ )
+
+ logger.info("Models loaded successfully")
+
+ # ================= UTILS =================
+
+ def generate_tts_audio(text: str) -> bytes:
+     """
+     Generate WAV 16kHz mono audio from text
+     """
+     try:
+         text = text.replace("\n", " ").strip()
+         if not text:
+             text = "I understand."
+
+         text = text[:MAX_TEXT_LEN]
+         logger.info(f"TTS: {text}")
+
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
+             mp3_path = wav_file.name.replace(".wav", ".mp3")
+
+             tts = gTTS(text=text, lang="en")
+             tts.save(mp3_path)
+
+             audio = AudioSegment.from_file(mp3_path)
+             audio = audio.set_frame_rate(16000).set_channels(1)
+             audio.export(wav_file.name, format="wav")
+
+             with open(wav_file.name, "rb") as f:
+                 wav_data = f.read()
+
+         os.remove(mp3_path)
+         os.remove(wav_file.name)
+
+         return wav_data
+
+     except Exception as e:
+         logger.error(f"TTS error: {e}", exc_info=True)
+         return b""
+
+
+ def cleanup_temp_files():
+     """Background loop: delete files in TEMP_AUDIO_DIR older than FILE_EXPIRE_TIME."""
+     while True:
+         try:
+             now = time.time()
+             for filename in os.listdir(TEMP_AUDIO_DIR):
+                 path = os.path.join(TEMP_AUDIO_DIR, filename)
+                 if os.path.isfile(path):
+                     if now - os.path.getmtime(path) > FILE_EXPIRE_TIME:
+                         os.remove(path)
+         except Exception as e:
+             logger.warning(f"Cleanup error: {e}")
+
+         time.sleep(CLEANUP_INTERVAL)
+
+
+ # ================= ROUTES =================
+
+ @app.route("/health", methods=["GET"])
+ def health():
+     return jsonify({
+         "status": "ok",
+         "stt": STT_MODEL,
+         "llm": LLM_MODEL
+     })
+
+
+ @app.route("/process_audio", methods=["POST"])
+ def process_audio():
+     try:
+         if "audio" not in request.files:
+             return jsonify({"error": "No audio file"}), 400
+
+         audio_file = request.files["audio"]
+         raw_audio = audio_file.read()
+
+         if len(raw_audio) < 1000:
+             return jsonify({"error": "Audio too short"}), 400
+
+         # ================= STT =================
+         logger.info("Running STT...")
+         # The ASR pipeline decodes the uploaded file bytes (via ffmpeg) and
+         # resamples them to the model's expected rate, so no sampling_rate
+         # argument is needed here.
+         stt_result = stt_pipeline(raw_audio)
+
+         user_text = stt_result.get("text", "").strip()
+         logger.info(f"User said: {user_text}")
+
+         if not user_text:
+             user_text = "Hello"
+
+         # ================= LLM =================
+         logger.info("Running LLM...")
+         llm_result = llm_pipeline(
+             user_text,
+             max_new_tokens=64,
+             do_sample=False
+         )
+
+         answer = llm_result[0]["generated_text"]
+         logger.info(f"Answer: {answer}")
+
+         # ================= TTS =================
+         audio_response = generate_tts_audio(answer)
+
+         if not audio_response:
+             return jsonify({"error": "TTS failed"}), 500
+
+         file_id = str(uuid.uuid4())
+         filepath = os.path.join(TEMP_AUDIO_DIR, f"{file_id}.wav")
+
+         with open(filepath, "wb") as f:
+             f.write(audio_response)
+
+         return send_file(
+             filepath,
+             mimetype="audio/wav",
+             as_attachment=False,
+             download_name="response.wav"
+         )
+
+     except Exception as e:
+         logger.error(f"Processing error: {e}", exc_info=True)
+         return jsonify({"error": "Internal error"}), 500
+
+
+ # ================= STARTUP =================
+
+ if __name__ == "__main__":
+     threading.Thread(target=cleanup_temp_files, daemon=True).start()
+
+     app.run(
+         host="0.0.0.0",
+         port=7860,
+         threaded=True
+     )
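
For context, a minimal client sketch for exercising the two routes in this file. It assumes the app is reachable at http://localhost:7860 (the port set in app.run); the requests dependency, the BASE_URL value, and the input.wav / response.wav file names are illustrative and not part of this commit.

    import requests

    BASE_URL = "http://localhost:7860"  # adjust to the deployed URL

    # Health check: confirms the service is up and reports the configured models.
    print(requests.get(f"{BASE_URL}/health", timeout=10).json())

    # Send a short WAV recording under the "audio" field expected by /process_audio,
    # then save the spoken reply returned as audio/wav.
    with open("input.wav", "rb") as f:
        resp = requests.post(
            f"{BASE_URL}/process_audio",
            files={"audio": ("input.wav", f, "audio/wav")},
            timeout=120,
        )
    resp.raise_for_status()

    with open("response.wav", "wb") as out:
        out.write(resp.content)
    print("Saved response.wav")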