Mark-Lasfar committed
Commit 509531f · 1 Parent(s): cb937e4

endpoints.py generation.py

Files changed (3)
  1. api/endpoints.py +3 -1
  2. utils/constants.py +5 -1
  3. utils/generation.py +32 -29
api/endpoints.py CHANGED
@@ -20,7 +20,9 @@ from motor.motor_asyncio import AsyncIOMotorClient
 from datetime import datetime
 import logging
 from typing import List, Optional
-from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
+# from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
+from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL, IMAGE_INFERENCE_API
+
 import psutil
 import time
 router = APIRouter()
utils/constants.py CHANGED
@@ -1,6 +1,5 @@
 import os
 
-
 MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:cerebras")
 SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
 TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "llama/Llama-3.1-8B-Instruct:featherless-ai")
@@ -11,6 +10,11 @@ TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
 IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct:novita")
 SECONDARY_IMAGE_GEN_MODEL = os.getenv("SECONDARY_IMAGE_GEN_MODEL", "black-forest-labs/FLUX.1-dev")
 
+ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
+API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
+FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
+IMAGE_INFERENCE_API = os.getenv("IMAGE_INFERENCE_API", "https://api-inference.huggingface.co/models")  # 👈 new addition
+
 MODEL_ALIASES = {
     "advanced": MODEL_NAME,
     "standard": SECONDARY_MODEL_NAME,
utils/generation.py CHANGED
@@ -19,8 +19,8 @@ from utils.web_search import web_search
 from huggingface_hub import snapshot_download
 import torch
 from diffusers import DiffusionPipeline
-from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
-
+# from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
+from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL, IMAGE_INFERENCE_API
 logger = logging.getLogger(__name__)
 
 # Cache setup
@@ -107,8 +107,9 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
     ]
     for pattern in image_patterns:
         if re.search(pattern, query_lower, re.IGNORECASE):
-            logger.info(f"Selected {CLIP_BASE_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for image-related query: {query[:50]}...")
-            return CLIP_BASE_MODEL, FALLBACK_API_ENDPOINT
+            model = CLIP_LARGE_MODEL if preferred_model == "image_advanced" else CLIP_BASE_MODEL
+            logger.info(f"Selected {model} with endpoint {IMAGE_INFERENCE_API} for image-related query: {query[:50]}...")
+            return model, f"{IMAGE_INFERENCE_API}/{model}"
     for pattern in image_gen_patterns:
         if re.search(pattern, query_lower, re.IGNORECASE) or input_type == "image_gen":
             logger.info(f"Selected {IMAGE_GEN_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for image generation query: {query[:50]}...")
@@ -223,24 +224,23 @@ def request_generation(
     torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
     # Image analysis handling
-    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
-        task_type = "image_analysis"
-        try:
-            dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            model = CLIPModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
-            processor = CLIPProcessor.from_pretrained(model_name)
-            image = Image.open(io.BytesIO(image_data)).convert("RGB")
-            inputs = processor(text=message, images=image, return_tensors="pt", padding=True).to(device)
-            outputs = model(**inputs)
-            logits_per_image = outputs.logits_per_image
-            probs = logits_per_image.softmax(dim=1)
-            result = f"Image analysis result: {probs.tolist()}"
-            logger.debug(f"Image analysis result: {result}")
+    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
+        task_type = "image_analysis"
+        try:
+            url = f"{IMAGE_INFERENCE_API}/{model_name}"
+            headers = {"Authorization": f"Bearer {api_key}"}
+            response = requests.post(url, headers=headers, data=image_data)
+            if response.status_code == 200:
+                result = response.json()
+                caption = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', 'No caption generated')
+                logger.debug(f"Image analysis result: {caption}")
                 if output_format == "audio":
+                    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+                    device = "cuda" if torch.cuda.is_available() else "cpu"
                     model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
                     processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                    inputs = processor(text=result, return_tensors="pt").to(device)
+                    inputs = processor(text=caption, return_tensors="pt").to(device)
                     audio = model.generate(**inputs)
                     audio_file = io.BytesIO()
                     torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
@@ -248,18 +248,21 @@ def request_generation(
                     audio_data = audio_file.read()
                     yield audio_data
                 else:
-                    yield result
-                    cache[cache_key] = [result]
+                    yield caption
+                    cache[cache_key] = [caption]
                     return
+            else:
+                logger.error(f"Image analysis failed with status {response.status_code}: {response.text}")
+                yield f"Error: Image analysis failed with status {response.status_code}: {response.text}"
                 return
         except Exception as e:
             logger.error(f"Image analysis failed: {e}")
             yield f"Error: Image analysis failed: {e}"
             return
         finally:
             if 'model' in locals():
                 del model
             torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
     # Image generation or editing handling
     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
         task_type = "image_generation"
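
Review note: the select_model change does two things: preferred_model == "image_advanced" now upgrades analysis from CLIP_BASE_MODEL to CLIP_LARGE_MODEL, and the returned endpoint is a per-model URL built on IMAGE_INFERENCE_API instead of the shared FALLBACK_API_ENDPOINT. A sketch of that branch in isolation (the CLIP model IDs are stand-ins; the real defaults live in utils/constants.py and are not shown in this diff):

from typing import Optional, Tuple

IMAGE_INFERENCE_API = "https://api-inference.huggingface.co/models"
CLIP_BASE_MODEL = "openai/clip-vit-base-patch32"    # stand-in, not from the diff
CLIP_LARGE_MODEL = "openai/clip-vit-large-patch14"  # stand-in, not from the diff

def pick_image_analysis_model(preferred_model: Optional[str]) -> Tuple[str, str]:
    # "image_advanced" selects the large CLIP variant; anything else,
    # including None, falls back to the base model.
    model = CLIP_LARGE_MODEL if preferred_model == "image_advanced" else CLIP_BASE_MODEL
    # The endpoint is now model-specific instead of the generic /v1 endpoint.
    return model, f"{IMAGE_INFERENCE_API}/{model}"

print(pick_image_analysis_model("image_advanced")[1])
# https://api-inference.huggingface.co/models/openai/clip-vit-large-patch14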
 
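Review note: in request_generation, the local CLIPModel/CLIPProcessor path is replaced by a single HTTP POST of the raw image bytes to {IMAGE_INFERENCE_API}/{model_name}; note that parsing generated_text out of the response assumes an image-to-text style payload rather than CLIP's classification output. A standalone sketch of the new call, assuming a Hugging Face token in the HF_TOKEN environment variable and an image-to-text model (the model ID is an illustrative stand-in, and requests must be importable in generation.py):

import os
import requests

IMAGE_INFERENCE_API = "https://api-inference.huggingface.co/models"
model_name = "Salesforce/blip-image-captioning-base"  # illustrative stand-in

with open("example.jpg", "rb") as f:  # any local image file
    image_data = f.read()

response = requests.post(
    f"{IMAGE_INFERENCE_API}/{model_name}",
    headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"},
    data=image_data,
)
if response.status_code == 200:
    result = response.json()
    # Image-to-text models return a list like [{"generated_text": "..."}].
    caption = result[0]["generated_text"] if isinstance(result, list) else result.get("generated_text", "No caption generated")
    print(caption)
else:
    print(f"Error {response.status_code}: {response.text}")

One consequence of going remote: the finally block's del model / empty_cache cleanup now only matters on the TTS branch, since no CLIP weights are loaded locally anymore.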