Mark-Lasfar committed
Commit dcad397 · 1 Parent(s): a1a7a58

endpoints.py generation.py

api/endpoints.py CHANGED
@@ -20,7 +20,9 @@ from motor.motor_asyncio import AsyncIOMotorClient
 from datetime import datetime
 import logging
 from typing import List, Optional
-
+from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
+import psutil
+import time
 router = APIRouter()
 logger = logging.getLogger(__name__)
 
@@ -37,24 +39,7 @@ if not BACKUP_HF_TOKEN:
 ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
 API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
 FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
-MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:cerebras")
-SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
-TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "meta-llama/Llama-3-8b-chat-hf")
-CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
-CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
-ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
-TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
-
-# Model alias mapping for user-friendly names
-MODEL_ALIASES = {
-    "advanced": MODEL_NAME,
-    "standard": SECONDARY_MODEL_NAME,
-    "light": TERTIARY_MODEL_NAME,
-    "image_base": CLIP_BASE_MODEL,
-    "image_advanced": CLIP_LARGE_MODEL,
-    "audio": ASR_MODEL,
-    "tts": TTS_MODEL
-}
+
 
 # MongoDB setup
 MONGO_URI = os.getenv("MONGODB_URI")
@@ -62,6 +47,10 @@ client = AsyncIOMotorClient(MONGO_URI)
 db = client["hager"]
 session_message_counts = db["session_message_counts"]
 
+class ImageGenRequest(BaseModel):
+    prompt: str
+    output_format: str = "image"
+
 # Helper function to handle sessions for non-logged-in users
 async def handle_session(request: Request):
     if not hasattr(request, "session"):
@@ -142,7 +131,7 @@ async def performance_stats():
     return {
         "queue_size": int(os.getenv("QUEUE_SIZE", 80)),
         "concurrency_limit": int(os.getenv("CONCURRENCY_LIMIT", 20)),
-        "uptime": os.popen("uptime").read().strip()
+        "uptime": time.time() - psutil.boot_time()  # system uptime in seconds
     }
 
 @router.post("/api/chat")
@@ -287,6 +276,88 @@ async def chat_endpoint(
 
     return {"response": response}
 
+
+@router.post("/api/image-generation")
+async def image_generation_endpoint(
+    request: Request,
+    req: dict,
+    file: Optional[UploadFile] = File(None),
+    user: User = Depends(current_active_user),
+    db: AsyncSession = Depends(get_db)
+):
+    if not user:
+        await handle_session(request)
+
+    prompt = req.get("prompt", "")
+    output_format = req.get("output_format", "image")
+    if not prompt.strip():
+        raise HTTPException(status_code=400, detail="Prompt is required for image generation.")
+
+    model_name, api_endpoint = select_model(prompt, input_type="image_gen")
+
+    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
+    if not is_available:
+        logger.error(f"Model {model_name} is not available at {api_endpoint}")
+        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")
+
+    image_data = None
+    if file:
+        image_data = await file.read()
+
+    system_prompt = enhance_system_prompt(
+        "You are an expert in generating high-quality images based on detailed prompts. Ensure the output is visually appealing and matches the user's description.",
+        prompt, user
+    )
+
+    stream = request_generation(
+        api_key=api_key,
+        api_base=selected_endpoint,
+        message=prompt,
+        system_prompt=system_prompt,
+        model_name=model_name,
+        temperature=0.7,
+        max_new_tokens=2048,
+        input_type="image_gen",
+        image_data=image_data,
+        output_format=output_format
+    )
+
+    if output_format == "image":
+        image_chunks = []
+        try:
+            for chunk in stream:
+                logger.debug(f"Processing image chunk: {chunk[:100] if isinstance(chunk, str) else 'bytes'}")
+                if isinstance(chunk, bytes):
+                    image_chunks.append(chunk)
+                else:
+                    logger.warning(f"Unexpected non-bytes chunk in image stream: {chunk}")
+            if not image_chunks:
+                logger.error("No image data generated.")
+                raise HTTPException(status_code=500, detail="No image data generated for image generation.")
+            image_data = b"".join(image_chunks)
+            return StreamingResponse(io.BytesIO(image_data), media_type="image/png")
+        except Exception as e:
+            logger.error(f"Image generation failed: {e}")
+            raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
+
+    response_chunks = []
+    try:
+        for chunk in stream:
+            logger.debug(f"Processing text chunk: {chunk[:100]}...")
+            if isinstance(chunk, str) and chunk.strip() and chunk not in ["analysis", "assistantfinal"]:
+                response_chunks.append(chunk)
+            else:
+                logger.warning(f"Skipping chunk: {chunk}")
+        response = "".join(response_chunks)
+        if not response.strip():
+            logger.error("Empty response generated.")
+            raise HTTPException(status_code=500, detail="Empty response generated from model.")
+        return {"response": response}
+    except Exception as e:
+        logger.error(f"Image generation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
+
+
 @router.post("/api/audio-transcription")
 async def audio_transcription_endpoint(
     request: Request,
@@ -824,7 +895,7 @@ async def verify_token(user: User = Depends(current_active_user)):
     raise HTTPException(status_code=401, detail="Invalid or expired token")
     return {"status": "valid"}
 
-
+
 @router.put("/users/me")
 async def update_user_settings(
     settings: UserUpdate,
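
For context, the new /api/image-generation route reads the prompt and output_format from the request body and either streams back PNG bytes or returns a JSON payload. Below is a minimal client sketch, not part of this commit: it assumes the app is served locally on port 8000 and exercises only the JSON-body path (the optional file upload would instead be sent as multipart form data).

```python
# Hypothetical client sketch (not part of this commit): call the new
# /api/image-generation route, assuming a local dev server on port 8000
# and that the handler reads a JSON body via req.get("prompt") / req.get("output_format").
import requests

BASE_URL = "http://localhost:8000"  # assumption: local dev server

def generate_image(prompt: str, output_format: str = "image"):
    resp = requests.post(
        f"{BASE_URL}/api/image-generation",
        json={"prompt": prompt, "output_format": output_format},
        timeout=300,
    )
    resp.raise_for_status()
    # output_format="image" streams PNG bytes; other formats return JSON {"response": ...}.
    return resp.content if output_format == "image" else resp.json()

if __name__ == "__main__":
    png_bytes = generate_image("A lighthouse at sunset, watercolor style")
    with open("generated_image.png", "wb") as f:
        f.write(png_bytes)
```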
generated_image.png ADDED

Git LFS Details

  • SHA256: 0019dfc4b32d63c1392aa264aed2253c1e0c2fb09216f8e2cc269bbfb8bb49b5
  • Pointer size: 126 Bytes
  • Size of remote file: 9 Bytes
requirements.txt CHANGED
@@ -2,6 +2,7 @@ fastapi==0.95.2
 fastapi-users[sqlalchemy,oauth2]==10.4.2
 pydantic==1.10.13
 email-validator==1.3.1
+sqlalchemy[asyncio]
 aiosqlite==0.21.0
 sqlalchemy==2.0.43
 python-jose[cryptography]==3.3.0
@@ -39,9 +40,15 @@ pymongo==4.10.1
 parler-tts @ git+https://github.com/huggingface/parler-tts.git@5d0aca9753ab74ded179732f5bd797f7a8c6f8ee
 soupsieve>=2.5
 tqdm>=4.66.0
+git+https://github.com/Dao-AILab/flash-attention.git
 argon2-cffi>=23.1.0
 wsproto>=1.2.0
 descript-audiotools>=0.7.2
 scipy>=1.15.0
 librosa>=0.10.0
-matplotlib>=3.10.0
+matplotlib>=3.10.0
+vllm
+accelerate
+flash-attn
+diffusers
+psutil
templates/index.html CHANGED
@@ -217,7 +217,7 @@
         <div class="glass p-6">
           <h3 class="text-xl font-semibold mb-2">New AI Features</h3>
           <p>Explore our latest AI updates for smarter code and e-commerce tools.</p>
-          <a href="https://hager-zon.vercel.app/blog" target="_blank" class="text-emerald-300 hover:underline">Read More →</a>
+          <a href="/blog" target="_blank" class="text-emerald-300 hover:underline">Read More →</a>
         </div>
         <div class="glass p-6">
           <h3 class="text-xl font-semibold mb-2">Global Expansion</h3>
utils/generation.py CHANGED
@@ -1,7 +1,4 @@
 # utils/generation.py
-# SPDX-FileCopyrightText: Hadad <hadad@linuxmail.org>
-# SPDX-License-Identifier: Apache-2.0
-
 import os
 import re
 import json
@@ -20,7 +17,11 @@ from PIL import Image
 from transformers import CLIPModel, CLIPProcessor, AutoProcessor
 from parler_tts import ParlerTTSForConditionalGeneration
 from utils.web_search import web_search
-
+from huggingface_hub import snapshot_download
+import torch
+from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
+from qwenimage.pipeline_qwen_image import QwenImagePipeline
+from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
 logger = logging.getLogger(__name__)
 
 # Cache setup
@@ -34,20 +35,46 @@ LATEX_DELIMS = [
     {"left": "\\(", "right": "\\)", "display": False},
 ]
 
+
 # Hugging Face API client setup
 HF_TOKEN = os.getenv("HF_TOKEN")
 BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
 ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
 API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
 FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
-MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:cerebras")
-SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
-TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "meta-llama/Llama-3-8b-chat-hf")
-CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
-CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
-ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
-TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
 
+# Pre-download the FLUX.1-dev model if needed
+try:
+    model_path = snapshot_download(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_type="model",
+        ignore_patterns=["*.md", "*.gitattributes"],
+        local_dir="FLUX.1-dev",
+    )
+except Exception as e:
+    logger.error(f"Failed to download FLUX.1-dev: {e}")
+
+# FlashAttention-3 support
+# NOTE: get_kernel comes from the Hugging Face `kernels` package; its import is not shown in this diff.
+_flash_attn_func = None
+_kernels_err = None
+try:
+    _k = get_kernel("kernels-community/vllm-flash-attn3")
+    _flash_attn_func = _k.flash_attn_func
+except Exception as e:
+    _flash_attn_func = None
+    _kernels_err = e
+
+def _ensure_fa3_available():
+    if _flash_attn_func is None:
+        raise ImportError(
+            "FlashAttention-3 via Hugging Face `kernels` is required. "
+            f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{_kernels_err}"
+        )
 # PROVIDER_ENDPOINTS disabled because only Hugging Face is used
 PROVIDER_ENDPOINTS = {
     "huggingface": API_ENDPOINT
@@ -95,10 +122,18 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
         r"\bimage\b", r"\bpicture\b", r"\bphoto\b", r"\bvisual\b", r"\bصورة\b", r"\bتحليل\s+صورة\b",
         r"\bimage\s+analysis\b", r"\bimage\s+classification\b", r"\bimage\s+description\b"
     ]
+    image_gen_patterns = [
+        r"\bgenerate\s+image\b", r"\bcreate\s+image\b", r"\bimage\s+generation\b", r"\bصورة\s+توليد\b",
+        r"\bimage\s+edit\b", r"\bتحرير\s+صورة\b"
+    ]
     for pattern in image_patterns:
         if re.search(pattern, query_lower, re.IGNORECASE):
             logger.info(f"Selected {CLIP_BASE_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for image-related query: {query[:50]}...")
             return CLIP_BASE_MODEL, FALLBACK_API_ENDPOINT
+    for pattern in image_gen_patterns:
+        if re.search(pattern, query_lower, re.IGNORECASE) or input_type == "image_gen":
+            logger.info(f"Selected {IMAGE_GEN_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for image generation query: {query[:50]}...")
+            return IMAGE_GEN_MODEL, FALLBACK_API_ENDPOINT
     available_models = [
         (MODEL_NAME, API_ENDPOINT),
         (SECONDARY_MODEL_NAME, FALLBACK_API_ENDPOINT),
@@ -112,6 +147,7 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
         logger.error("No models available. Falling back to default.")
     return MODEL_NAME, API_ENDPOINT
 
+
 @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
 def request_generation(
     api_key: str,
@@ -157,6 +193,7 @@ def request_generation(
     enhanced_system_prompt = system_prompt
     buffer = ""
 
+    # Audio processing
     if model_name == ASR_MODEL and audio_data:
         task_type = "audio_transcription"
         try:
@@ -180,6 +217,7 @@ def request_generation(
             yield f"Error: Audio transcription failed: {e}"
             return
 
+    # Text-to-speech processing
     if model_name == TTS_MODEL or output_format == "audio":
         task_type = "text_to_speech"
         try:
@@ -200,6 +238,7 @@ def request_generation(
             yield f"Error: Text-to-speech failed: {e}"
             return
 
+    # Image analysis processing
    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
         task_type = "image_analysis"
         try:
@@ -231,6 +270,51 @@ def request_generation(
             yield f"Error: Image analysis failed: {e}"
             return
 
+    # Image generation / editing
+    if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
+        task_type = "image_generation"
+        try:
+            dtype = torch.float16  # can be adjusted per hardware
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            _ensure_fa3_available()  # make sure FlashAttention-3 is available
+            # NOTE: QwenDoubleStreamAttnProcessorFA3 is expected to come from the qwenimage package; its import is not shown in this diff.
+            if model_name == IMAGE_GEN_MODEL:
+                pipe = QwenImagePipeline.from_pretrained(model_name, torch_dtype=dtype).to(device)
+                pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+            else:
+                pipe = QwenImageEditPipeline.from_pretrained(model_name, torch_dtype=dtype).to(device)
+                pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
+
+            # Set up image generation parameters
+            polished_prompt = polish_prompt(message)
+            image_params = {
+                "prompt": polished_prompt,
+                "seed": 0,
+                "randomize_seed": True,
+                "aspect_ratio": "16:9",
+                "guidance_scale": 4,
+                "num_inference_steps": 50,
+                "prompt_enhance": True
+            }
+            if input_type == "image_gen" and image_data:
+                image = Image.open(io.BytesIO(image_data)).convert("RGB")
+                image_params["image"] = image
+
+            # Generate the image
+            output = pipe(**image_params)
+            image_file = io.BytesIO()
+            output.images[0].save(image_file, format="PNG")
+            image_file.seek(0)
+            image_data = image_file.read()
+            logger.debug(f"Generated image data of length: {len(image_data)} bytes")
+            yield image_data
+            cache[cache_key] = [image_data]
+            return
+        except Exception as e:
+            logger.error(f"Image generation failed: {e}")
+            yield f"Error: Image generation failed: {e}"
+            return
+
+    # Text handling (as in the original code)
     if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
         task_type = "image"
         enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
@@ -259,7 +343,7 @@ def request_generation(
             clean_msg = {"role": msg.get("role"), "content": msg.get("content")}
             if clean_msg["content"]:
                 input_messages.append(clean_msg)
-
+
     if deep_search:
         try:
             search_result = web_search(message)
@@ -563,6 +647,7 @@ def request_generation(
         yield f"Error: Failed to load model {model_name}: {e}"
         return
 
+
 def format_final(analysis_text: str, visible_text: str) -> str:
     reasoning_safe = html.escape((analysis_text or "").strip())
     response = (visible_text or "").strip()
@@ -577,6 +662,32 @@ def format_final(analysis_text: str, visible_text: str) -> str:
         f"{response}" if response else "No final response available."
     )
 
+
+def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
+    original_prompt = original_prompt.strip()
+    system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
+    if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
+        system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
+    prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
+    magic_prompt = "Ultra HD, 4K, cinematic composition"
+    success = False
+    while not success:
+        try:
+            # NOTE: `client` (an OpenAI-compatible client) is defined elsewhere in this module; it is not shown in this diff.
+            polished_prompt = client.chat.completions.create(
+                model=MODEL_NAME,
+                messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
+                temperature=0.7,
+                max_tokens=200
+            ).choices[0].message.content.strip()
+            polished_prompt = polished_prompt.replace("\n", " ")
+            success = True
+        except Exception as e:
+            logger.error(f"Error during prompt polishing: {e}")
+            polished_prompt = original_prompt
+            break
+    return polished_prompt + " " + magic_prompt
+
+
 def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
     if not message.strip() and not audio_data and not image_data:
         yield "Please enter a prompt or upload a file."
utils/utils/constants.py ADDED
@@ -0,0 +1,21 @@
+MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:cerebras")
+SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
+TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:featherless-ai")
+CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
+CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
+ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
+TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
+IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
+SECONDARY_IMAGE_GEN_MODEL = os.getenv("SECONDARY_IMAGE_GEN_MODEL", "black-forest-labs/FLUX.1-dev")
+
+MODEL_ALIASES = {
+    "advanced": MODEL_NAME,
+    "standard": SECONDARY_MODEL_NAME,
+    "light": TERTIARY_MODEL_NAME,
+    "image_base": CLIP_BASE_MODEL,
+    "image_advanced": CLIP_LARGE_MODEL,
+    "audio": ASR_MODEL,
+    "tts": TTS_MODEL,
+    "image_gen": IMAGE_GEN_MODEL,
+    "secondary_image_gen": SECONDARY_IMAGE_GEN_MODEL
+}
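
Note that the new constants module calls os.getenv but the added file contains no `import os` line, and it is created at utils/utils/constants.py while the new imports reference utils.constants; both would need to line up for the module to import. Assuming that is resolved, the illustrative helper below (not part of this commit) shows how the alias map is meant to be used.

```python
# Hypothetical helper (not in this commit): resolve a user-facing alias such as
# "advanced", "audio", or "image_gen" to a concrete model id, falling back to
# the primary chat model for unknown names. Assumes utils/constants.py starts
# with `import os` and is importable as utils.constants.
from utils.constants import MODEL_ALIASES, MODEL_NAME

def resolve_model(alias: str) -> str:
    return MODEL_ALIASES.get(alias, MODEL_NAME)

# Example: resolve_model("image_gen") returns the IMAGE_GEN_MODEL default
# ("Qwen/Qwen2.5-VL-7B-Instruct") unless overridden by the environment.
```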