Mark-Lasfar committed
Commit 61c3a4c · 1 Parent(s): 40b8409

endpoints.py generation.py

Files changed (2):
  1. requirements.txt +2 -1
  2. utils/generation.py +120 -121
requirements.txt CHANGED
@@ -50,4 +50,5 @@ accelerate>=0.26.0
  diffusers>=0.30.0
  psutil>=5.9.0
  xformers>=0.0.27
- anyio==4.6.0
+ anyio==4.6.0
+ duckduckgo-search
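Note: `utils/generation.py` below switches to importing `utils.web_search.web_search` lazily, and this file adds the `duckduckgo-search` package that presumably backs it. The helper module itself is not part of this diff, so the following is only a minimal sketch of what it might look like; the function name matches the import, but the body and the digest format are assumptions.

```python
# Hypothetical utils/web_search.py (illustrative only, not from this commit).
from duckduckgo_search import DDGS

def web_search(query: str, max_results: int = 5) -> str:
    """Return a short plain-text digest of the top DuckDuckGo results."""
    with DDGS() as ddgs:
        hits = ddgs.text(query, max_results=max_results)
    # Each hit is a dict with "title", "href", and "body" keys.
    return "\n".join(f"{h['title']}: {h['body']} ({h['href']})" for h in hits)
```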
utils/generation.py CHANGED
@@ -1,4 +1,3 @@
- # utils/generation.py
  import os
  import re
  import json
@@ -16,13 +15,10 @@ import torchaudio
  from PIL import Image
  from transformers import CLIPModel, CLIPProcessor, AutoProcessor
  from parler_tts import ParlerTTSForConditionalGeneration
- from utils.web_search import web_search
- from huggingface_hub import snapshot_download
  import torch
- from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
- from qwenimage.pipeline_qwen_image import QwenImagePipeline
  from diffusers import DiffusionPipeline
  from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
+ 
  logger = logging.getLogger(__name__)
  
  # Cache setup
@@ -36,7 +32,6 @@ LATEX_DELIMS = [
      {"left": "\\(", "right": "\\)", "display": False},
  ]
  
- 
  # Client setup for the Hugging Face API
  HF_TOKEN = os.getenv("HF_TOKEN")
  BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
@@ -44,39 +39,6 @@ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
  
- # Pre-download the FLUX.1-dev model if needed
- model_path = None
- try:
-     model_path = snapshot_download(
-         repo_id="black-forest-labs/FLUX.1-dev",
-         repo_type="model",
-         ignore_patterns=["*.md", "*..gitattributes"],
-         local_dir="FLUX.1-dev",
-     )
- except Exception as e:
-     logger.error(f"Failed to download FLUX.1-dev: {e}")
-     model_path = None
- 
- # FlashAttention-3 support
- # _flash_attn_func = None
- # _kernels_err = None
- # try:
- #     _k = get_kernel("kernels-community/vllm-flash-attn3")
- #     _flash_attn_func = _k.flash_attn_func
- # except Exception as e:
- #     _flash_attn_func = None
- #     _kernels_err = e
- 
- # def _ensure_fa3_available():
- #     if _flash_attn_func is None:
- #         raise ImportError(
- #             "FlashAttention-3 via Hugging Face `kernels` is required. "
- #             f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{_kernels_err}"
- #         )
  # Disable PROVIDER_ENDPOINTS since we only use Hugging Face
  PROVIDER_ENDPOINTS = {
      "huggingface": API_ENDPOINT
@@ -149,7 +111,29 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
      logger.error("No models available. Falling back to default.")
      return MODEL_NAME, API_ENDPOINT
  
- 
+ def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
+     original_prompt = original_prompt.strip()
+     system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
+     if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
+         system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
+     prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
+     magic_prompt = "Ultra HD, 4K, cinematic composition"
+ 
+     try:
+         client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
+         polished_prompt = client.chat.completions.create(
+             model=SECONDARY_MODEL_NAME,  # use a compatible model
+             messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
+             temperature=0.7,
+             max_tokens=200
+         ).choices[0].message.content.strip()
+         polished_prompt = polished_prompt.replace("\n", " ")
+     except Exception as e:
+         logger.error(f"Error during prompt polishing: {e}")
+         polished_prompt = original_prompt
+ 
+     return polished_prompt + " " + magic_prompt
+ 
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
  def request_generation(
      api_key: str,
@@ -223,9 +207,11 @@
      if model_name == TTS_MODEL or output_format == "audio":
          task_type = "text_to_speech"
          try:
-             model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
+             dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
              processor = AutoProcessor.from_pretrained(TTS_MODEL)
-             inputs = processor(text=message, return_tensors="pt")
+             inputs = processor(text=message, return_tensors="pt").to(device)
              audio = model.generate(**inputs)
              audio_file = io.BytesIO()
              torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -239,24 +225,30 @@
              logger.error(f"Text-to-speech failed: {e}")
              yield f"Error: Text-to-speech failed: {e}"
              return
+         finally:
+             if 'model' in locals():
+                 del model
+             torch.cuda.empty_cache() if torch.cuda.is_available() else None
  
      # Image analysis handling
      if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
          task_type = "image_analysis"
          try:
-             model = CLIPModel.from_pretrained(model_name)
+             dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             model = CLIPModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
              processor = CLIPProcessor.from_pretrained(model_name)
              image = Image.open(io.BytesIO(image_data)).convert("RGB")
-             inputs = processor(text=message, images=image, return_tensors="pt", padding=True)
+             inputs = processor(text=message, images=image, return_tensors="pt", padding=True).to(device)
              outputs = model(**inputs)
              logits_per_image = outputs.logits_per_image
              probs = logits_per_image.softmax(dim=1)
              result = f"Image analysis result: {probs.tolist()}"
              logger.debug(f"Image analysis result: {result}")
              if output_format == "audio":
-                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
+                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
                  processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                 inputs = processor(text=result, return_tensors="pt")
+                 inputs = processor(text=result, return_tensors="pt").to(device)
                  audio = model.generate(**inputs)
                  audio_file = io.BytesIO()
                  torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -271,45 +263,60 @@
              logger.error(f"Image analysis failed: {e}")
              yield f"Error: Image analysis failed: {e}"
              return
+         finally:
+             if 'model' in locals():
+                 del model
+             torch.cuda.empty_cache() if torch.cuda.is_available() else None
  
-     # Image generation or editing
- 
-     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
-         task_type = "image_generation"
-         try:
-             dtype = torch.float16
-             device = "cuda" if torch.cuda.is_available() else "cpu"
-             if model_name == IMAGE_GEN_MODEL:
-                 pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
-             else:
-                 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
- 
-             polished_prompt = polish_prompt(message)
-             image_params = {
-                 "prompt": polished_prompt,
-                 "num_inference_steps": 50,
-                 "guidance_scale": 7.5,
-             }
-             if input_type == "image_gen" and image_data:
-                 image = Image.open(io.BytesIO(image_data)).convert("RGB")
-                 image_params["image"] = image
- 
-             output = pipe(**image_params)
-             image_file = io.BytesIO()
-             output.images[0].save(image_file, format="PNG")
-             image_file.seek(0)
-             image_data = image_file.read()
-             logger.debug(f"Generated image data of length: {len(image_data)} bytes")
-             yield image_data
-             cache[cache_key] = [image_data]
-             return
-         except Exception as e:
-             logger.error(f"Image generation failed: {e}")
-             yield f"Error: Image generation failed: {e}"
-             return
- 
-     # Text handling (as in the original code)
+     # Image generation handling
+     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
+         task_type = "image_generation"
+         try:
+             dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             logger.info(f"Using device: {device}, dtype: {dtype}")
+             if model_name == IMAGE_GEN_MODEL:
+                 pipe = DiffusionPipeline.from_pretrained(
+                     "runwayml/stable-diffusion-v1-5",
+                     torch_dtype=dtype,
+                     use_auth_token=HF_TOKEN if HF_TOKEN else None
+                 ).to(device)
+             else:
+                 pipe = DiffusionPipeline.from_pretrained(
+                     "black-forest-labs/FLUX.1-dev",
+                     torch_dtype=dtype,
+                     use_auth_token=HF_TOKEN if HF_TOKEN else None
+                 ).to(device)
+ 
+             polished_prompt = polish_prompt(message)
+             image_params = {
+                 "prompt": polished_prompt,
+                 "num_inference_steps": 50,
+                 "guidance_scale": 7.5,
+             }
+             if input_type == "image_gen" and image_data:
+                 image = Image.open(io.BytesIO(image_data)).convert("RGB")
+                 image_params["image"] = image
+ 
+             output = pipe(**image_params)
+             image_file = io.BytesIO()
+             output.images[0].save(image_file, format="PNG")
+             image_file.seek(0)
+             image_data = image_file.read()
+             logger.debug(f"Generated image data of length: {len(image_data)} bytes")
+             yield image_data
+             cache[cache_key] = [image_data]
+             return
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             yield f"Error: Image generation failed: {e}"
+             return
+         finally:
+             if 'pipe' in locals():
+                 del pipe
+             torch.cuda.empty_cache() if torch.cuda.is_available() else None
  
+     # Text handling
      if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
          task_type = "image"
          enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
@@ -341,16 +348,17 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
  
      if deep_search:
          try:
+             from utils.web_search import web_search
              search_result = web_search(message)
              input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
-         except Exception as e:
-             logger.error(f"Web search failed: {e}")
+         except (ImportError, Exception) as e:
+             logger.error(f"Web search failed or not available: {e}")
              input_messages.append({"role": "user", "content": message})
      else:
          input_messages.append({"role": "user", "content": message})
  
      tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
-     tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else "none"
+     tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] else "none"
  
      cached_chunks = []
      try:
@@ -444,9 +452,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
  
          if output_format == "audio":
              try:
-                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
+                 dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
                  processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                 inputs = processor(text=buffer, return_tensors="pt")
+                 inputs = processor(text=buffer, return_tensors="pt").to(device)
                  audio = model.generate(**inputs)
                  audio_file = io.BytesIO()
                  torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -457,6 +467,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
              except Exception as e:
                  logger.error(f"Text-to-speech conversion failed: {e}")
                  yield f"Error: Text-to-speech conversion failed: {e}"
+             finally:
+                 if 'model' in locals():
+                     del model
+                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
  
          cache[cache_key] = cached_chunks
  
@@ -556,9 +570,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
  
          if buffer and output_format == "audio":
              try:
-                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
+                 dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
                  processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                 inputs = processor(text=buffer, return_tensors="pt")
+                 inputs = processor(text=buffer, return_tensors="pt").to(device)
                  audio = model.generate(**inputs)
                  audio_file = io.BytesIO()
                  torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -569,6 +585,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
              except Exception as e:
                  logger.error(f"Text-to-speech conversion failed: {e}")
                  yield f"Error: Text-to-speech conversion failed: {e}"
+             finally:
+                 if 'model' in locals():
+                     del model
+                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
  
          cache[cache_key] = cached_chunks
  
@@ -620,9 +640,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
                  break
          if buffer and output_format == "audio":
              try:
-                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
+                 dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+                 device = "cuda" if torch.cuda.is_available() else "cpu"
+                 model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
                  processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                 inputs = processor(text=buffer, return_tensors="pt")
+                 inputs = processor(text=buffer, return_tensors="pt").to(device)
                  audio = model.generate(**inputs)
                  audio_file = io.BytesIO()
                  torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -633,6 +655,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
              except Exception as e:
                  logger.error(f"Text-to-speech conversion failed: {e}")
                  yield f"Error: Text-to-speech conversion failed: {e}"
+             finally:
+                 if 'model' in locals():
+                     del model
+                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
          cache[cache_key] = cached_chunks
      except Exception as e3:
          logger.error(f"[Gateway] Streaming failed for tertiary model {TERTIARY_MODEL_NAME}: {e3}")
@@ -642,7 +668,6 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
          yield f"Error: Failed to load model {model_name}: {e}"
          return
  
- 
  def format_final(analysis_text: str, visible_text: str) -> str:
      reasoning_safe = html.escape((analysis_text or "").strip())
      response = (visible_text or "").strip()
@@ -657,32 +682,6 @@ def format_final(analysis_text: str, visible_text: str) -> str:
          f"{response}" if response else "No final response available."
      )
  
- 
- def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
-     original_prompt = original_prompt.strip()
-     system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
-     if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
-         system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
-     prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
-     magic_prompt = "Ultra HD, 4K, cinematic composition"
-     success = False
-     while not success:
-         try:
-             polished_prompt = client.chat.completions.create(
-                 model=MODEL_NAME,
-                 messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
-                 temperature=0.7,
-                 max_tokens=200
-             ).choices[0].message.content.strip()
-             polished_prompt = polished_prompt.replace("\n", " ")
-             success = True
-         except Exception as e:
-             logger.error(f"Error during prompt polishing: {e}")
-             polished_prompt = original_prompt
-             break
-     return polished_prompt + " " + magic_prompt
- 
- 
  def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
      if not message.strip() and not audio_data and not image_data:
          yield "Please enter a prompt or upload a file."
@@ -835,4 +834,4 @@ Response (draft):
  
      except Exception as e:
          logger.exception("Stream failed")
-         yield f"❌ Error: {e}"
+         yield f"❌ Error: {e}"
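Note: the hunks above repeat the same cleanup idiom four times (a `finally:` block that deletes the local `model` and conditionally calls `torch.cuda.empty_cache()`). A possible follow-up refactor, shown here only as a sketch and not part of this commit, would centralize that pattern in a context manager:

```python
# Illustrative refactor of the repeated cleanup blocks (not from this commit).
import contextlib
import torch

@contextlib.contextmanager
def scoped_model(load_fn):
    """Load a model via load_fn, yield it, and release GPU memory on exit."""
    model = load_fn()
    try:
        yield model
    finally:
        # Drop the reference, then return cached CUDA blocks to the allocator.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Usage sketch, with names as in utils/generation.py:
# with scoped_model(lambda: ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)) as model:
#     audio = model.generate(**inputs)
```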
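Note: the relocated `polish_prompt` keeps its Arabic-detection heuristic: any character in U+0600–U+06FF (the main Arabic Unicode block) flags the prompt as Arabic. A standalone illustration of that check follows; `contains_arabic` is a name invented here, and the heuristic ignores supplementary ranges such as the Arabic Presentation Forms blocks.

```python
# Standalone illustration of the range check used inside polish_prompt.
def contains_arabic(text: str) -> bool:
    # True if any character falls in the main Arabic block U+0600..U+06FF.
    return any(0x0600 <= ord(ch) <= 0x06FF for ch in text)

assert contains_arabic("مرحبا")      # Arabic greeting -> True
assert not contains_arabic("hello")  # ASCII only -> False
```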