Mark-Lasfar committed on
Commit
ee74cc6
·
1 Parent(s): 61c3a4c

endpoints.py generation.py

Browse files
Files changed (1) hide show
  1. utils/generation.py +42 -39
utils/generation.py CHANGED
@@ -15,6 +15,8 @@ import torchaudio
15
  from PIL import Image
16
  from transformers import CLIPModel, CLIPProcessor, AutoProcessor
17
  from parler_tts import ParlerTTSForConditionalGeneration
 
 
18
  import torch
19
  from diffusers import DiffusionPipeline
20
  from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
@@ -39,6 +41,19 @@ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
39
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
40
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # تعطيل PROVIDER_ENDPOINTS لأننا بنستخدم Hugging Face فقط
43
  PROVIDER_ENDPOINTS = {
44
  "huggingface": API_ENDPOINT
@@ -111,29 +126,6 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
111
  logger.error("No models available. Falling back to default.")
112
  return MODEL_NAME, API_ENDPOINT
113
 
114
- def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
115
- original_prompt = original_prompt.strip()
116
- system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
117
- if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
118
- system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
119
- prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
120
- magic_prompt = "Ultra HD, 4K, cinematic composition"
121
-
122
- try:
123
- client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
124
- polished_prompt = client.chat.completions.create(
125
- model=SECONDARY_MODEL_NAME, # استخدام نموذج متوافق
126
- messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
127
- temperature=0.7,
128
- max_tokens=200
129
- ).choices[0].message.content.strip()
130
- polished_prompt = polished_prompt.replace("\n", " ")
131
- except Exception as e:
132
- logger.error(f"Error during prompt polishing: {e}")
133
- polished_prompt = original_prompt
134
-
135
- return polished_prompt + " " + magic_prompt
136
-
137
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
138
  def request_generation(
139
  api_key: str,
@@ -268,25 +260,16 @@ def request_generation(
268
  del model
269
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
270
 
271
- # معالجة توليد الصور
272
  if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
273
  task_type = "image_generation"
274
  try:
275
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
276
  device = "cuda" if torch.cuda.is_available() else "cpu"
277
- logger.info(f"Using device: {device}, dtype: {dtype}")
278
  if model_name == IMAGE_GEN_MODEL:
279
- pipe = DiffusionPipeline.from_pretrained(
280
- "runwayml/stable-diffusion-v1-5",
281
- torch_dtype=dtype,
282
- use_auth_token=HF_TOKEN if HF_TOKEN else None
283
- ).to(device)
284
  else:
285
- pipe = DiffusionPipeline.from_pretrained(
286
- "black-forest-labs/FLUX.1-dev",
287
- torch_dtype=dtype,
288
- use_auth_token=HF_TOKEN if HF_TOKEN else None
289
- ).to(device)
290
 
291
  polished_prompt = polish_prompt(message)
292
  image_params = {
@@ -348,17 +331,16 @@ def request_generation(
348
 
349
  if deep_search:
350
  try:
351
- from utils.web_search import web_search
352
  search_result = web_search(message)
353
  input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
354
- except (ImportError, Exception) as e:
355
- logger.error(f"Web search failed or not available: {e}")
356
  input_messages.append({"role": "user", "content": message})
357
  else:
358
  input_messages.append({"role": "user", "content": message})
359
 
360
  tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
361
- tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] else "none"
362
 
363
  cached_chunks = []
364
  try:
@@ -682,6 +664,27 @@ def format_final(analysis_text: str, visible_text: str) -> str:
682
  f"{response}" if response else "No final response available."
683
  )
684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
686
  if not message.strip() and not audio_data and not image_data:
687
  yield "Please enter a prompt or upload a file."
 
15
  from PIL import Image
16
  from transformers import CLIPModel, CLIPProcessor, AutoProcessor
17
  from parler_tts import ParlerTTSForConditionalGeneration
18
+ from utils.web_search import web_search
19
+ from huggingface_hub import snapshot_download
20
  import torch
21
  from diffusers import DiffusionPipeline
22
  from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
 
41
  API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
42
  FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
43
 
44
+ # تحميل نموذج FLUX.1-dev مسبقًا إذا لزم الأمر
45
+ model_path = None
46
+ try:
47
+ model_path = snapshot_download(
48
+ repo_id="black-forest-labs/FLUX.1-dev",
49
+ repo_type="model",
50
+ ignore_patterns=["*.md", "*..gitattributes"],
51
+ local_dir="FLUX.1-dev",
52
+ )
53
+ except Exception as e:
54
+ logger.error(f"Failed to download FLUX.1-dev: {e}")
55
+ model_path = None
56
+
57
  # تعطيل PROVIDER_ENDPOINTS لأننا بنستخدم Hugging Face فقط
58
  PROVIDER_ENDPOINTS = {
59
  "huggingface": API_ENDPOINT
 
126
  logger.error("No models available. Falling back to default.")
127
  return MODEL_NAME, API_ENDPOINT
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
130
  def request_generation(
131
  api_key: str,
 
260
  del model
261
  torch.cuda.empty_cache() if torch.cuda.is_available() else None
262
 
263
+ # معالجة توليد الصور أو تحريرها
264
  if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
265
  task_type = "image_generation"
266
  try:
267
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
268
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
269
  if model_name == IMAGE_GEN_MODEL:
270
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
 
 
 
 
271
  else:
272
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
 
 
 
 
273
 
274
  polished_prompt = polish_prompt(message)
275
  image_params = {
 
331
 
332
  if deep_search:
333
  try:
 
334
  search_result = web_search(message)
335
  input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
336
+ except Exception as e:
337
+ logger.error(f"Web search failed: {e}")
338
  input_messages.append({"role": "user", "content": message})
339
  else:
340
  input_messages.append({"role": "user", "content": message})
341
 
342
  tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
343
+ tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else "none"
344
 
345
  cached_chunks = []
346
  try:
 
664
  f"{response}" if response else "No final response available."
665
  )
666
 
667
+ def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
668
+ original_prompt = original_prompt.strip()
669
+ system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
670
+ if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
671
+ system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
672
+ prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
673
+ magic_prompt = "Ultra HD, 4K, cinematic composition"
674
+ try:
675
+ client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
676
+ polished_prompt = client.chat.completions.create(
677
+ model=SECONDARY_MODEL_NAME,
678
+ messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
679
+ temperature=0.7,
680
+ max_tokens=200
681
+ ).choices[0].message.content.strip()
682
+ polished_prompt = polished_prompt.replace("\n", " ")
683
+ except Exception as e:
684
+ logger.error(f"Error during prompt polishing: {e}")
685
+ polished_prompt = original_prompt
686
+ return polished_prompt + " " + magic_prompt
687
+
688
  def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
689
  if not message.strip() and not audio_data and not image_data:
690
  yield "Please enter a prompt or upload a file."