prithivMLmods committed on
Commit
32d8e74
·
verified ·
1 Parent(s): 37aaee6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -55
app.py CHANGED
@@ -1,31 +1,17 @@
1
  import os
2
- from collections.abc import Iterator
3
- from threading import Thread
4
  import gradio as gr
5
- import spaces
6
  import torch
7
- import edge_tts
8
  import asyncio
 
 
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  DESCRIPTION = """
12
- # QwQ Tiny
13
  """
14
 
15
- css ='''
16
- h1 {
17
- text-align: center;
18
- display: block;
19
- }
20
-
21
- #duplicate-button {
22
- margin: auto;
23
- color: #fff;
24
- background: #1565c0;
25
- border-radius: 100vh;
26
- }
27
- '''
28
-
29
  MAX_MAX_NEW_TOKENS = 2048
30
  DEFAULT_MAX_NEW_TOKENS = 1024
31
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -41,16 +27,14 @@ model = AutoModelForCausalLM.from_pretrained(
41
  )
42
  model.eval()
43
 
 
 
 
 
 
 
 
44
 
45
- async def text_to_speech(text: str, output_file="output.mp3"):
46
- """Convert text to speech using Edge TTS and save as MP3"""
47
- voice = "en-US-JennyNeural" # Change this to your preferred voice
48
- communicate = edge_tts.Communicate(text, voice)
49
- await communicate.save(output_file)
50
- return output_file
51
-
52
-
53
- @spaces.GPU
54
  def generate(
55
  message: str,
56
  chat_history: list[dict],
@@ -59,47 +43,55 @@ def generate(
59
  top_p: float = 0.9,
60
  top_k: int = 50,
61
  repetition_penalty: float = 1.2,
62
- ):
63
- """Generates chatbot response and handles TTS requests"""
64
- is_tts = message.strip().lower().startswith("@tts")
65
- message = message.replace("@tts", "").strip()
66
 
67
- conversation = [*chat_history, {"role": "user", "content": message}]
 
 
 
 
68
 
 
69
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
70
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
71
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
72
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
73
  input_ids = input_ids.to(model.device)
74
 
75
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
76
- generate_kwargs = dict(
77
- {"input_ids": input_ids},
78
- streamer=streamer,
79
- max_new_tokens=max_new_tokens,
80
- do_sample=True,
81
- top_p=top_p,
82
- top_k=top_k,
83
- temperature=temperature,
84
- num_beams=1,
85
- repetition_penalty=repetition_penalty,
86
- )
87
  t = Thread(target=model.generate, kwargs=generate_kwargs)
88
  t.start()
89
 
90
  outputs = []
91
  for text in streamer:
92
  outputs.append(text)
93
- yield "".join(outputs)
94
 
95
- final_response = "".join(outputs)
96
 
 
97
  if is_tts:
98
- output_file = asyncio.run(text_to_speech(final_response))
99
- return output_file # Return MP3 file
100
-
101
- return final_response # Return text response
102
 
 
103
 
104
  demo = gr.ChatInterface(
105
  fn=generate,
@@ -113,15 +105,13 @@ demo = gr.ChatInterface(
113
  stop_btn=None,
114
  examples=[
115
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
116
- ["Write a Python function to check if a number is prime."],
117
- ["What causes rainbows to form?"],
118
- ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
119
- ["@tts What is the capital of France?"],
120
  ],
121
  cache_examples=False,
122
  type="messages",
123
  description=DESCRIPTION,
124
- css=css,
125
  fill_height=True,
126
  )
127
 
 
1
  import os
 
 
2
  import gradio as gr
 
3
  import torch
4
+ import tempfile
5
  import asyncio
6
+ import edge_tts
7
+ from threading import Thread
8
+ from collections.abc import Iterator
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  DESCRIPTION = """
12
+ # QwQ Tiny with Edge TTS
13
  """
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  MAX_MAX_NEW_TOKENS = 2048
16
  DEFAULT_MAX_NEW_TOKENS = 1024
17
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
27
  )
28
  model.eval()
29
 
30
async def text_to_speech(text: str, voice: str | None = None) -> str:
    """Convert *text* to speech with Edge TTS and return the audio file path.

    Args:
        text: The text to synthesize.
        voice: Optional Edge TTS voice name (e.g. "en-US-JennyNeural").
            When omitted, the library's default voice is used, matching
            the previous behavior.

    Returns:
        Filesystem path of the generated MP3 file.
    """
    communicate = (
        edge_tts.Communicate(text, voice) if voice else edge_tts.Communicate(text)
    )
    # edge-tts emits MP3 audio by default, so the temp file must carry a
    # ".mp3" suffix — the previous ".wav" suffix mislabelled the container.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    await communicate.save(tmp_path)
    return tmp_path
37
 
 
 
 
 
 
 
 
 
 
38
  def generate(
39
  message: str,
40
  chat_history: list[dict],
 
43
  top_p: float = 0.9,
44
  top_k: int = 50,
45
  repetition_penalty: float = 1.2,
46
+ ) -> Iterator[str] | str:
47
+
48
+ is_tts = message.strip().startswith("@tts")
49
+ is_text_only = message.strip().startswith("@text")
50
 
51
+ # Remove special tags
52
+ if is_tts:
53
+ message = message.replace("@tts", "").strip()
54
+ elif is_text_only:
55
+ message = message.replace("@text", "").strip()
56
 
57
+ conversation = [*chat_history, {"role": "user", "content": message}]
58
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
59
+
60
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
61
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
62
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
63
+
64
  input_ids = input_ids.to(model.device)
65
 
66
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
67
+ generate_kwargs = {
68
+ "input_ids": input_ids,
69
+ "streamer": streamer,
70
+ "max_new_tokens": max_new_tokens,
71
+ "do_sample": True,
72
+ "top_p": top_p,
73
+ "top_k": top_k,
74
+ "temperature": temperature,
75
+ "num_beams": 1,
76
+ "repetition_penalty": repetition_penalty,
77
+ }
78
  t = Thread(target=model.generate, kwargs=generate_kwargs)
79
  t.start()
80
 
81
  outputs = []
82
  for text in streamer:
83
  outputs.append(text)
 
84
 
85
+ final_output = "".join(outputs)
86
 
87
+ # If TTS requested, generate speech and return audio file
88
  if is_tts:
89
+ loop = asyncio.new_event_loop()
90
+ asyncio.set_event_loop(loop)
91
+ audio_path = loop.run_until_complete(text_to_speech(final_output))
92
+ return audio_path # Returning audio file path
93
 
94
+ return final_output # Returning text output
95
 
96
  demo = gr.ChatInterface(
97
  fn=generate,
 
105
  stop_btn=None,
106
  examples=[
107
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
108
+ ["@text What causes rainbows to form?"],
109
+ ["@tts Explain Newton's third law of motion."],
110
+ ["@text Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
 
111
  ],
112
  cache_examples=False,
113
  type="messages",
114
  description=DESCRIPTION,
 
115
  fill_height=True,
116
  )
117