Mark-Lasfar
commited on
Commit
·
61c3a4c
1
Parent(s):
40b8409
endpoints.py generation.py
Browse files- requirements.txt +2 -1
- utils/generation.py +120 -121
requirements.txt
CHANGED
|
@@ -50,4 +50,5 @@ accelerate>=0.26.0
|
|
| 50 |
diffusers>=0.30.0
|
| 51 |
psutil>=5.9.0
|
| 52 |
xformers>=0.0.27
|
| 53 |
-
anyio==4.6.0
|
|
|
|
|
|
| 50 |
diffusers>=0.30.0
|
| 51 |
psutil>=5.9.0
|
| 52 |
xformers>=0.0.27
|
| 53 |
+
anyio==4.6.0
|
| 54 |
+
duckduckgo-search
|
utils/generation.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# utils/generation.py
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
import json
|
|
@@ -16,13 +15,10 @@ import torchaudio
|
|
| 16 |
from PIL import Image
|
| 17 |
from transformers import CLIPModel, CLIPProcessor, AutoProcessor
|
| 18 |
from parler_tts import ParlerTTSForConditionalGeneration
|
| 19 |
-
from utils.web_search import web_search
|
| 20 |
-
from huggingface_hub import snapshot_download
|
| 21 |
import torch
|
| 22 |
-
from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
|
| 23 |
-
from qwenimage.pipeline_qwen_image import QwenImagePipeline
|
| 24 |
from diffusers import DiffusionPipeline
|
| 25 |
from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
|
|
|
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
# إعداد Cache
|
|
@@ -36,7 +32,6 @@ LATEX_DELIMS = [
|
|
| 36 |
{"left": "\\(", "right": "\\)", "display": False},
|
| 37 |
]
|
| 38 |
|
| 39 |
-
|
| 40 |
# إعداد العميل لـ Hugging Face API
|
| 41 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 42 |
BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
|
|
@@ -44,39 +39,6 @@ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
|
|
| 44 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
|
| 45 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
|
| 46 |
|
| 47 |
-
# تحميل نموذج FLUX.1-dev مسبقًا إذا لزم الأمر
|
| 48 |
-
model_path = None
|
| 49 |
-
try:
|
| 50 |
-
model_path = snapshot_download(
|
| 51 |
-
repo_id="black-forest-labs/FLUX.1-dev",
|
| 52 |
-
repo_type="model",
|
| 53 |
-
ignore_patterns=["*.md", "*..gitattributes"],
|
| 54 |
-
local_dir="FLUX.1-dev",
|
| 55 |
-
)
|
| 56 |
-
except Exception as e:
|
| 57 |
-
logger.error(f"Failed to download FLUX.1-dev: {e}")
|
| 58 |
-
model_path = None
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# دعم FlashAttention-3
|
| 65 |
-
# _flash_attn_func = None
|
| 66 |
-
# _kernels_err = None
|
| 67 |
-
# try:
|
| 68 |
-
# _k = get_kernel("kernels-community/vllm-flash-attn3")
|
| 69 |
-
# _flash_attn_func = _k.flash_attn_func
|
| 70 |
-
# except Exception as e:
|
| 71 |
-
# _flash_attn_func = None
|
| 72 |
-
# _kernels_err = e
|
| 73 |
-
|
| 74 |
-
# def _ensure_fa3_available():
|
| 75 |
-
# if _flash_attn_func is None:
|
| 76 |
-
# raise ImportError(
|
| 77 |
-
# "FlashAttention-3 via Hugging Face `kernels` is required. "
|
| 78 |
-
# f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{_kernels_err}"
|
| 79 |
-
# )
|
| 80 |
# تعطيل PROVIDER_ENDPOINTS لأننا بنستخدم Hugging Face فقط
|
| 81 |
PROVIDER_ENDPOINTS = {
|
| 82 |
"huggingface": API_ENDPOINT
|
|
@@ -149,7 +111,29 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
|
|
| 149 |
logger.error("No models available. Falling back to default.")
|
| 150 |
return MODEL_NAME, API_ENDPOINT
|
| 151 |
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
|
| 154 |
def request_generation(
|
| 155 |
api_key: str,
|
|
@@ -223,9 +207,11 @@ def request_generation(
|
|
| 223 |
if model_name == TTS_MODEL or output_format == "audio":
|
| 224 |
task_type = "text_to_speech"
|
| 225 |
try:
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 228 |
-
inputs = processor(text=message, return_tensors="pt")
|
| 229 |
audio = model.generate(**inputs)
|
| 230 |
audio_file = io.BytesIO()
|
| 231 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
@@ -239,24 +225,30 @@ def request_generation(
|
|
| 239 |
logger.error(f"Text-to-speech failed: {e}")
|
| 240 |
yield f"Error: Text-to-speech failed: {e}"
|
| 241 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
# معالجة تحليل الصور
|
| 244 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
|
| 245 |
task_type = "image_analysis"
|
| 246 |
try:
|
| 247 |
-
|
|
|
|
|
|
|
| 248 |
processor = CLIPProcessor.from_pretrained(model_name)
|
| 249 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 250 |
-
inputs = processor(text=message, images=image, return_tensors="pt", padding=True)
|
| 251 |
outputs = model(**inputs)
|
| 252 |
logits_per_image = outputs.logits_per_image
|
| 253 |
probs = logits_per_image.softmax(dim=1)
|
| 254 |
result = f"Image analysis result: {probs.tolist()}"
|
| 255 |
logger.debug(f"Image analysis result: {result}")
|
| 256 |
if output_format == "audio":
|
| 257 |
-
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL)
|
| 258 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 259 |
-
inputs = processor(text=result, return_tensors="pt")
|
| 260 |
audio = model.generate(**inputs)
|
| 261 |
audio_file = io.BytesIO()
|
| 262 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
@@ -271,45 +263,60 @@ def request_generation(
|
|
| 271 |
logger.error(f"Image analysis failed: {e}")
|
| 272 |
yield f"Error: Image analysis failed: {e}"
|
| 273 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
# معالجة
|
| 276 |
-
|
| 277 |
-
if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
|
| 278 |
-
task_type = "image_generation"
|
| 279 |
-
try:
|
| 280 |
-
dtype = torch.float16
|
| 281 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 282 |
-
if model_name == IMAGE_GEN_MODEL:
|
| 283 |
-
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
|
| 284 |
-
else:
|
| 285 |
-
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
|
| 286 |
-
|
| 287 |
-
polished_prompt = polish_prompt(message)
|
| 288 |
-
image_params = {
|
| 289 |
-
"prompt": polished_prompt,
|
| 290 |
-
"num_inference_steps": 50,
|
| 291 |
-
"guidance_scale": 7.5,
|
| 292 |
-
}
|
| 293 |
-
if input_type == "image_gen" and image_data:
|
| 294 |
-
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 295 |
-
image_params["image"] = image
|
| 296 |
-
|
| 297 |
-
output = pipe(**image_params)
|
| 298 |
-
image_file = io.BytesIO()
|
| 299 |
-
output.images[0].save(image_file, format="PNG")
|
| 300 |
-
image_file.seek(0)
|
| 301 |
-
image_data = image_file.read()
|
| 302 |
-
logger.debug(f"Generated image data of length: {len(image_data)} bytes")
|
| 303 |
-
yield image_data
|
| 304 |
-
cache[cache_key] = [image_data]
|
| 305 |
-
return
|
| 306 |
-
except Exception as e:
|
| 307 |
-
logger.error(f"Image generation failed: {e}")
|
| 308 |
-
yield f"Error: Image generation failed: {e}"
|
| 309 |
-
return
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
# معالجة النصوص (كما هو موجود في الكود الأصلي)
|
| 313 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
|
| 314 |
task_type = "image"
|
| 315 |
enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
|
|
@@ -341,16 +348,17 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 341 |
|
| 342 |
if deep_search:
|
| 343 |
try:
|
|
|
|
| 344 |
search_result = web_search(message)
|
| 345 |
input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
|
| 346 |
-
except Exception as e:
|
| 347 |
-
logger.error(f"Web search failed: {e}")
|
| 348 |
input_messages.append({"role": "user", "content": message})
|
| 349 |
else:
|
| 350 |
input_messages.append({"role": "user", "content": message})
|
| 351 |
|
| 352 |
tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
|
| 353 |
-
tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"]
|
| 354 |
|
| 355 |
cached_chunks = []
|
| 356 |
try:
|
|
@@ -444,9 +452,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 444 |
|
| 445 |
if output_format == "audio":
|
| 446 |
try:
|
| 447 |
-
|
|
|
|
|
|
|
| 448 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 449 |
-
inputs = processor(text=buffer, return_tensors="pt")
|
| 450 |
audio = model.generate(**inputs)
|
| 451 |
audio_file = io.BytesIO()
|
| 452 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
@@ -457,6 +467,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 457 |
except Exception as e:
|
| 458 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 459 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
|
| 461 |
cache[cache_key] = cached_chunks
|
| 462 |
|
|
@@ -556,9 +570,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 556 |
|
| 557 |
if buffer and output_format == "audio":
|
| 558 |
try:
|
| 559 |
-
|
|
|
|
|
|
|
| 560 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 561 |
-
inputs = processor(text=buffer, return_tensors="pt")
|
| 562 |
audio = model.generate(**inputs)
|
| 563 |
audio_file = io.BytesIO()
|
| 564 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
@@ -569,6 +585,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 569 |
except Exception as e:
|
| 570 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 571 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
cache[cache_key] = cached_chunks
|
| 574 |
|
|
@@ -620,9 +640,11 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 620 |
break
|
| 621 |
if buffer and output_format == "audio":
|
| 622 |
try:
|
| 623 |
-
|
|
|
|
|
|
|
| 624 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 625 |
-
inputs = processor(text=buffer, return_tensors="pt")
|
| 626 |
audio = model.generate(**inputs)
|
| 627 |
audio_file = io.BytesIO()
|
| 628 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
@@ -633,6 +655,10 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 633 |
except Exception as e:
|
| 634 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 635 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
cache[cache_key] = cached_chunks
|
| 637 |
except Exception as e3:
|
| 638 |
logger.error(f"[Gateway] Streaming failed for tertiary model {TERTIARY_MODEL_NAME}: {e3}")
|
|
@@ -642,7 +668,6 @@ if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "
|
|
| 642 |
yield f"Error: Failed to load model {model_name}: {e}"
|
| 643 |
return
|
| 644 |
|
| 645 |
-
|
| 646 |
def format_final(analysis_text: str, visible_text: str) -> str:
|
| 647 |
reasoning_safe = html.escape((analysis_text or "").strip())
|
| 648 |
response = (visible_text or "").strip()
|
|
@@ -657,32 +682,6 @@ def format_final(analysis_text: str, visible_text: str) -> str:
|
|
| 657 |
f"{response}" if response else "No final response available."
|
| 658 |
)
|
| 659 |
|
| 660 |
-
|
| 661 |
-
def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
|
| 662 |
-
original_prompt = original_prompt.strip()
|
| 663 |
-
system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
|
| 664 |
-
if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
|
| 665 |
-
system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
|
| 666 |
-
prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
|
| 667 |
-
magic_prompt = "Ultra HD, 4K, cinematic composition"
|
| 668 |
-
success = False
|
| 669 |
-
while not success:
|
| 670 |
-
try:
|
| 671 |
-
polished_prompt = client.chat.completions.create(
|
| 672 |
-
model=MODEL_NAME,
|
| 673 |
-
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
|
| 674 |
-
temperature=0.7,
|
| 675 |
-
max_tokens=200
|
| 676 |
-
).choices[0].message.content.strip()
|
| 677 |
-
polished_prompt = polished_prompt.replace("\n", " ")
|
| 678 |
-
success = True
|
| 679 |
-
except Exception as e:
|
| 680 |
-
logger.error(f"Error during prompt polishing: {e}")
|
| 681 |
-
polished_prompt = original_prompt
|
| 682 |
-
break
|
| 683 |
-
return polished_prompt + " " + magic_prompt
|
| 684 |
-
|
| 685 |
-
|
| 686 |
def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
|
| 687 |
if not message.strip() and not audio_data and not image_data:
|
| 688 |
yield "Please enter a prompt or upload a file."
|
|
@@ -835,4 +834,4 @@ Response (draft):
|
|
| 835 |
|
| 836 |
except Exception as e:
|
| 837 |
logger.exception("Stream failed")
|
| 838 |
-
yield f"❌ Error: {e}"
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import json
|
|
|
|
| 15 |
from PIL import Image
|
| 16 |
from transformers import CLIPModel, CLIPProcessor, AutoProcessor
|
| 17 |
from parler_tts import ParlerTTSForConditionalGeneration
|
|
|
|
|
|
|
| 18 |
import torch
|
|
|
|
|
|
|
| 19 |
from diffusers import DiffusionPipeline
|
| 20 |
from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
|
| 21 |
+
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
# إعداد Cache
|
|
|
|
| 32 |
{"left": "\\(", "right": "\\)", "display": False},
|
| 33 |
]
|
| 34 |
|
|
|
|
| 35 |
# إعداد العميل لـ Hugging Face API
|
| 36 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 37 |
BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
|
|
|
|
| 39 |
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
|
| 40 |
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# تعطيل PROVIDER_ENDPOINTS لأننا بنستخدم Hugging Face فقط
|
| 43 |
PROVIDER_ENDPOINTS = {
|
| 44 |
"huggingface": API_ENDPOINT
|
|
|
|
| 111 |
logger.error("No models available. Falling back to default.")
|
| 112 |
return MODEL_NAME, API_ENDPOINT
|
| 113 |
|
| 114 |
+
def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
|
| 115 |
+
original_prompt = original_prompt.strip()
|
| 116 |
+
system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
|
| 117 |
+
if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
|
| 118 |
+
system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
|
| 119 |
+
prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
|
| 120 |
+
magic_prompt = "Ultra HD, 4K, cinematic composition"
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
|
| 124 |
+
polished_prompt = client.chat.completions.create(
|
| 125 |
+
model=SECONDARY_MODEL_NAME, # استخدام نموذج متوافق
|
| 126 |
+
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
|
| 127 |
+
temperature=0.7,
|
| 128 |
+
max_tokens=200
|
| 129 |
+
).choices[0].message.content.strip()
|
| 130 |
+
polished_prompt = polished_prompt.replace("\n", " ")
|
| 131 |
+
except Exception as e:
|
| 132 |
+
logger.error(f"Error during prompt polishing: {e}")
|
| 133 |
+
polished_prompt = original_prompt
|
| 134 |
+
|
| 135 |
+
return polished_prompt + " " + magic_prompt
|
| 136 |
+
|
| 137 |
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
|
| 138 |
def request_generation(
|
| 139 |
api_key: str,
|
|
|
|
| 207 |
if model_name == TTS_MODEL or output_format == "audio":
|
| 208 |
task_type = "text_to_speech"
|
| 209 |
try:
|
| 210 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 211 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 212 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 213 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 214 |
+
inputs = processor(text=message, return_tensors="pt").to(device)
|
| 215 |
audio = model.generate(**inputs)
|
| 216 |
audio_file = io.BytesIO()
|
| 217 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
|
|
| 225 |
logger.error(f"Text-to-speech failed: {e}")
|
| 226 |
yield f"Error: Text-to-speech failed: {e}"
|
| 227 |
return
|
| 228 |
+
finally:
|
| 229 |
+
if 'model' in locals():
|
| 230 |
+
del model
|
| 231 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 232 |
|
| 233 |
# معالجة تحليل الصور
|
| 234 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
|
| 235 |
task_type = "image_analysis"
|
| 236 |
try:
|
| 237 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 238 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 239 |
+
model = CLIPModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
|
| 240 |
processor = CLIPProcessor.from_pretrained(model_name)
|
| 241 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 242 |
+
inputs = processor(text=message, images=image, return_tensors="pt", padding=True).to(device)
|
| 243 |
outputs = model(**inputs)
|
| 244 |
logits_per_image = outputs.logits_per_image
|
| 245 |
probs = logits_per_image.softmax(dim=1)
|
| 246 |
result = f"Image analysis result: {probs.tolist()}"
|
| 247 |
logger.debug(f"Image analysis result: {result}")
|
| 248 |
if output_format == "audio":
|
| 249 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 250 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 251 |
+
inputs = processor(text=result, return_tensors="pt").to(device)
|
| 252 |
audio = model.generate(**inputs)
|
| 253 |
audio_file = io.BytesIO()
|
| 254 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
|
|
| 263 |
logger.error(f"Image analysis failed: {e}")
|
| 264 |
yield f"Error: Image analysis failed: {e}"
|
| 265 |
return
|
| 266 |
+
finally:
|
| 267 |
+
if 'model' in locals():
|
| 268 |
+
del model
|
| 269 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 270 |
+
|
| 271 |
+
# معالجة توليد الصور
|
| 272 |
+
if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
|
| 273 |
+
task_type = "image_generation"
|
| 274 |
+
try:
|
| 275 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 276 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 277 |
+
logger.info(f"Using device: {device}, dtype: {dtype}")
|
| 278 |
+
if model_name == IMAGE_GEN_MODEL:
|
| 279 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 280 |
+
"runwayml/stable-diffusion-v1-5",
|
| 281 |
+
torch_dtype=dtype,
|
| 282 |
+
use_auth_token=HF_TOKEN if HF_TOKEN else None
|
| 283 |
+
).to(device)
|
| 284 |
+
else:
|
| 285 |
+
pipe = DiffusionPipeline.from_pretrained(
|
| 286 |
+
"black-forest-labs/FLUX.1-dev",
|
| 287 |
+
torch_dtype=dtype,
|
| 288 |
+
use_auth_token=HF_TOKEN if HF_TOKEN else None
|
| 289 |
+
).to(device)
|
| 290 |
+
|
| 291 |
+
polished_prompt = polish_prompt(message)
|
| 292 |
+
image_params = {
|
| 293 |
+
"prompt": polished_prompt,
|
| 294 |
+
"num_inference_steps": 50,
|
| 295 |
+
"guidance_scale": 7.5,
|
| 296 |
+
}
|
| 297 |
+
if input_type == "image_gen" and image_data:
|
| 298 |
+
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
| 299 |
+
image_params["image"] = image
|
| 300 |
+
|
| 301 |
+
output = pipe(**image_params)
|
| 302 |
+
image_file = io.BytesIO()
|
| 303 |
+
output.images[0].save(image_file, format="PNG")
|
| 304 |
+
image_file.seek(0)
|
| 305 |
+
image_data = image_file.read()
|
| 306 |
+
logger.debug(f"Generated image data of length: {len(image_data)} bytes")
|
| 307 |
+
yield image_data
|
| 308 |
+
cache[cache_key] = [image_data]
|
| 309 |
+
return
|
| 310 |
+
except Exception as e:
|
| 311 |
+
logger.error(f"Image generation failed: {e}")
|
| 312 |
+
yield f"Error: Image generation failed: {e}"
|
| 313 |
+
return
|
| 314 |
+
finally:
|
| 315 |
+
if 'pipe' in locals():
|
| 316 |
+
del pipe
|
| 317 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 318 |
|
| 319 |
+
# معالجة النصوص
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
|
| 321 |
task_type = "image"
|
| 322 |
enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
|
|
|
|
| 348 |
|
| 349 |
if deep_search:
|
| 350 |
try:
|
| 351 |
+
from utils.web_search import web_search
|
| 352 |
search_result = web_search(message)
|
| 353 |
input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
|
| 354 |
+
except (ImportError, Exception) as e:
|
| 355 |
+
logger.error(f"Web search failed or not available: {e}")
|
| 356 |
input_messages.append({"role": "user", "content": message})
|
| 357 |
else:
|
| 358 |
input_messages.append({"role": "user", "content": message})
|
| 359 |
|
| 360 |
tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
|
| 361 |
+
tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] else "none"
|
| 362 |
|
| 363 |
cached_chunks = []
|
| 364 |
try:
|
|
|
|
| 452 |
|
| 453 |
if output_format == "audio":
|
| 454 |
try:
|
| 455 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 456 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 457 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 458 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 459 |
+
inputs = processor(text=buffer, return_tensors="pt").to(device)
|
| 460 |
audio = model.generate(**inputs)
|
| 461 |
audio_file = io.BytesIO()
|
| 462 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
|
|
| 467 |
except Exception as e:
|
| 468 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 469 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
| 470 |
+
finally:
|
| 471 |
+
if 'model' in locals():
|
| 472 |
+
del model
|
| 473 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 474 |
|
| 475 |
cache[cache_key] = cached_chunks
|
| 476 |
|
|
|
|
| 570 |
|
| 571 |
if buffer and output_format == "audio":
|
| 572 |
try:
|
| 573 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 574 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 575 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 576 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 577 |
+
inputs = processor(text=buffer, return_tensors="pt").to(device)
|
| 578 |
audio = model.generate(**inputs)
|
| 579 |
audio_file = io.BytesIO()
|
| 580 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
|
|
| 585 |
except Exception as e:
|
| 586 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 587 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
| 588 |
+
finally:
|
| 589 |
+
if 'model' in locals():
|
| 590 |
+
del model
|
| 591 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 592 |
|
| 593 |
cache[cache_key] = cached_chunks
|
| 594 |
|
|
|
|
| 640 |
break
|
| 641 |
if buffer and output_format == "audio":
|
| 642 |
try:
|
| 643 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 644 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 645 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 646 |
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 647 |
+
inputs = processor(text=buffer, return_tensors="pt").to(device)
|
| 648 |
audio = model.generate(**inputs)
|
| 649 |
audio_file = io.BytesIO()
|
| 650 |
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
|
|
|
| 655 |
except Exception as e:
|
| 656 |
logger.error(f"Text-to-speech conversion failed: {e}")
|
| 657 |
yield f"Error: Text-to-speech conversion failed: {e}"
|
| 658 |
+
finally:
|
| 659 |
+
if 'model' in locals():
|
| 660 |
+
del model
|
| 661 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 662 |
cache[cache_key] = cached_chunks
|
| 663 |
except Exception as e3:
|
| 664 |
logger.error(f"[Gateway] Streaming failed for tertiary model {TERTIARY_MODEL_NAME}: {e3}")
|
|
|
|
| 668 |
yield f"Error: Failed to load model {model_name}: {e}"
|
| 669 |
return
|
| 670 |
|
|
|
|
| 671 |
def format_final(analysis_text: str, visible_text: str) -> str:
|
| 672 |
reasoning_safe = html.escape((analysis_text or "").strip())
|
| 673 |
response = (visible_text or "").strip()
|
|
|
|
| 682 |
f"{response}" if response else "No final response available."
|
| 683 |
)
|
| 684 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 685 |
def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
|
| 686 |
if not message.strip() and not audio_data and not image_data:
|
| 687 |
yield "Please enter a prompt or upload a file."
|
|
|
|
| 834 |
|
| 835 |
except Exception as e:
|
| 836 |
logger.exception("Stream failed")
|
| 837 |
+
yield f"❌ Error: {e}"
|