from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import whisper
import torch
from gtts import gTTS
import os
import re
import io
import yt_dlp
import numpy as np
import scipy.io.wavfile as wav

hf_token = os.getenv("HF_TOKEN")

app = FastAPI()

# Load the Qwen model (CPU, full precision)
model_name = "Qwen/Qwen3-4B-Instruct-2507"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=hf_token,
    device_map={"": "cpu"},
    dtype=torch.float32,
)

# Load the Whisper model
whisper_model = whisper.load_model("base")

# Conversation history, shared across all requests.
# System prompt: "You are an AI assistant. Answer briefly and concisely, in at most 2 sentences."
conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]

# Extract a song name from free-form text
# (keywords: "bài"/"bài hát" = song, "nghe nhạc"/"mở nhạc" = play music)
def extract_song_name(text):
    match = re.search(r"(bài|bài hát|nghe nhạc|mở nhạc)\s+(.*)", text.lower())
    if match:
        return match.group(2).strip()
    return None

def download_youtube_as_wav(song_name, output_path="song.wav"):
    # Take the first YouTube search hit and extract its audio track as WAV
    search_query = f"ytsearch1:{song_name}"
    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'temp_audio.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
            'preferredquality': '192',
        }],
        'quiet': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([search_query])
    if os.path.exists("temp_audio.wav"):
        os.rename("temp_audio.wav", output_path)
        return output_path
    return None
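
# Hypothetical usage (not in the original app): download_youtube_as_wav("tên bài hát")
# should leave a "song.wav" next to app.py and return its path, or return None
# if the download or the FFmpeg conversion fails.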

class ChatRequest(BaseModel):
    message: str

@app.get("/")
def read_root():
    return {"message": "Ứng dụng đang chạy!"}  # "The app is running!"

# Text chat endpoint
@app.post("/chat")
async def chat(request: ChatRequest):
    conversation.append({"role": "user", "content": request.message})
    text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    response_text = generate_full_response(model_inputs)
    conversation.append({"role": "assistant", "content": response_text})
    return {"response": response_text}

# Voice chat endpoint: raw PCM in, transcription + LLM reply + TTS out
@app.post("/voice_chat")
async def voice_chat(request: Request):
    try:
        raw_audio = await request.body()
        sample_rate = 16000
        # Decode 24-bit big-endian PCM: 3 bytes -> int32
        # (assumes the body length is a multiple of 3 bytes)
        audio_np = np.frombuffer(raw_audio, dtype=np.uint8).reshape(-1, 3)
        audio_int = (audio_np[:, 0].astype(np.int32) << 16) | \
                    (audio_np[:, 1].astype(np.int32) << 8) | \
                    audio_np[:, 2].astype(np.int32)
        # Drop the low byte to get int16 for the WAV file; the int16 cast
        # truncates, which also recovers the sign of negative 24-bit samples
        audio_int16 = (audio_int >> 8).astype(np.int16)
        # Write the samples to a WAV file for Whisper
        wav_io = io.BytesIO()
        wav.write(wav_io, sample_rate, audio_int16)
        wav_io.seek(0)
        with open("temp_audio.wav", "wb") as f:
            f.write(wav_io.read())
        # Transcribe with Whisper
        result = whisper_model.transcribe("temp_audio.wav", language="vi")
        user_text = result["text"]
        # Check for a play-music request
        if any(kw in user_text.lower() for kw in ["nghe nhạc", "mở bài hát", "bài hát", "bài"]):
            song_name = extract_song_name(user_text)
            if song_name:
                wav_path = download_youtube_as_wav(song_name)
                if wav_path:
                    return FileResponse(wav_path, media_type="audio/wav", filename="song.wav")
                else:
                    # "Song not found or could not be downloaded."
                    return JSONResponse({"error": "Không tìm thấy hoặc tải được bài hát."}, status_code=404)
        # Normal conversation turn
        conversation.append({"role": "user", "content": user_text})
        text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        response_text = generate_full_response(model_inputs)
        conversation.append({"role": "assistant", "content": response_text})
        # Synthesize the reply with gTTS
        tts = gTTS(response_text, lang="vi")
        audio_file = "response.mp3"
        tts.save(audio_file)
        return {
            "user_text": user_text,
            "response": response_text,
            "audio_url": "/get_audio",
        }
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
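
# A minimal client sketch (an assumption, not part of the app): pack int16
# microphone samples into the 24-bit big-endian layout /voice_chat expects.
#
#   import numpy as np, requests
#   samples = np.zeros(16000, dtype=np.int16)      # 1 s of silence at 16 kHz
#   s24 = samples.astype(np.int32) << 8            # widen int16 -> 24-bit range
#   raw = np.stack([(s24 >> 16) & 0xFF,            # big-endian: high byte first
#                   (s24 >> 8) & 0xFF,
#                   s24 & 0xFF], axis=1).astype(np.uint8).tobytes()
#   requests.post("http://localhost:7860/voice_chat", data=raw)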

# Endpoint returning the last synthesized reply
@app.get("/get_audio")
async def get_audio():
    return FileResponse("response.mp3", media_type="audio/mpeg")
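
# Example, under the same localhost:7860 assumption as above:
#   curl http://localhost:7860/get_audio -o response.mp3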

# Generate a complete model response for pre-tokenized chat inputs
def generate_full_response(model_inputs, max_new_tokens=64):
    with torch.inference_mode():
        generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, skipping the prompt
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    response_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response_text.strip()
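
# Local entry point, as a sketch: this assumes uvicorn is installed; on Hugging
# Face Spaces the server is typically launched by the Space runtime instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)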