huynhkimthien committed on
Commit e95508c · verified · 1 Parent(s): 882ca13

Update app.py

Files changed (1):
  1. app.py +46 -9
app.py CHANGED
@@ -1,13 +1,16 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import FileResponse
 from pydantic import BaseModel
 from transformers import AutoModelForCausalLM, AutoTokenizer
+import whisper
 import torch
+from gtts import gTTS
 import os
 
 app = FastAPI()
-model_name = "Qwen/Qwen3-4B-Instruct-2507"
 
-# Load tokenizer and model (CPU, for the free Spaces tier)
+# Load Qwen model
+model_name = "Qwen/Qwen3-4B-Instruct-2507"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
@@ -15,7 +18,11 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.float32
 )
 
-conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]  # Conversation history
+# Load Whisper model
+whisper_model = whisper.load_model("base")
+
+# Conversation history
+conversation = [{"role": "system", "content": "Bạn là một trợ lý AI. Hãy trả lời ngắn gọn, súc tích, tối đa 2 câu."}]
 
 class ChatRequest(BaseModel):
     message: str
@@ -24,25 +31,55 @@ class ChatRequest(BaseModel):
 def read_root():
     return {"message": "Ứng dụng đang chạy!"}
 
+# Text chat endpoint
 @app.post("/chat")
 async def chat(request: ChatRequest):
     conversation.append({"role": "user", "content": request.message})
-
-    # Apply the chat template
     text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    response_text = generate_full_response(model_inputs)
+    conversation.append({"role": "assistant", "content": response_text})
+    return {"response": response_text}
+
+# Voice chat endpoint (speech-to-text + TTS)
+@app.post("/voice_chat")
+async def voice_chat(file: UploadFile = File(...)):
+    # Save the upload to a temporary file
+    file_location = f"temp_{file.filename}"
+    with open(file_location, "wb") as f:
+        f.write(await file.read())
 
+    # Transcribe the audio to text
+    result = whisper_model.transcribe(file_location, language="vi")
+    user_text = result["text"]
+
+    # Ask the Qwen model for a reply
+    conversation.append({"role": "user", "content": user_text})
+    text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
     response_text = generate_full_response(model_inputs)
     conversation.append({"role": "assistant", "content": response_text})
 
-    return {"response": response_text}
+    # Synthesize the reply as audio
+    tts = gTTS(response_text, lang="vi")
+    audio_file = "response.mp3"
+    tts.save(audio_file)
 
+    return {
+        "user_text": user_text,
+        "response": response_text,
+        "audio_url": "/get_audio"
+    }
 
+# Endpoint that serves the reply audio
+@app.get("/get_audio")
+async def get_audio():
+    return FileResponse("response.mp3", media_type="audio/mpeg")
+
+# Response generation helper
 def generate_full_response(model_inputs, max_new_tokens=64):
     with torch.inference_mode():
         generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
     output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
     response_text = tokenizer.decode(output_ids, skip_special_tokens=True)
-
-
    return response_text.strip()
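
For a quick smoke test of the two endpoints this commit adds, a client along the following lines should work. This is a minimal sketch, not part of the commit: the base URL, port, and sample.wav are assumptions to adjust for your own Space.

# test_client.py: hypothetical smoke test for the endpoints added above.
# Assumptions: the app is reachable at BASE, and sample.wav is a short
# Vietnamese audio clip you supply yourself.
import requests

BASE = "http://localhost:7860"  # assumed local port; adjust for your Space

# 1. Text chat: POST a JSON body matching the ChatRequest schema.
r = requests.post(f"{BASE}/chat", json={"message": "Xin chào!"})
print(r.json()["response"])

# 2. Voice chat: upload audio as multipart form data. The field name
#    "file" matches the UploadFile parameter of voice_chat.
with open("sample.wav", "rb") as audio:
    r = requests.post(f"{BASE}/voice_chat",
                      files={"file": ("sample.wav", audio, "audio/wav")})
reply = r.json()
print(reply["user_text"], "->", reply["response"])

# 3. Download the synthesized reply from the URL the server returned.
mp3 = requests.get(BASE + reply["audio_url"]).content
with open("reply.mp3", "wb") as out:
    out.write(mp3)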
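
One loose end in the diff: voice_chat writes temp_<filename> to disk but never deletes it, so the os import goes unused. A possible follow-up, sketched here rather than taken from the commit, would delete the upload once Whisper has read it:

    # Sketch (not in this commit): remove the temporary upload after
    # transcription, using the existing os import, so repeated calls
    # don't accumulate temp_* files on disk.
    result = whisper_model.transcribe(file_location, language="vi")
    user_text = result["text"]
    os.remove(file_location)

Note also that conversation and response.mp3 are module-level state shared by all clients, a design that suits a single-user demo Space.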