Mark-Lasfar committed · ee74cc6
1 Parent(s): 61c3a4c
endpoints.py generation.py

utils/generation.py  CHANGED  (+42 -39)
@@ -15,6 +15,8 @@ import torchaudio
 from PIL import Image
 from transformers import CLIPModel, CLIPProcessor, AutoProcessor
 from parler_tts import ParlerTTSForConditionalGeneration
+from utils.web_search import web_search
+from huggingface_hub import snapshot_download
 import torch
 from diffusers import DiffusionPipeline
 from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
@@ -39,6 +41,19 @@ ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
 API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
 FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
 
+# Pre-download the FLUX.1-dev model if needed
+model_path = None
+try:
+    model_path = snapshot_download(
+        repo_id="black-forest-labs/FLUX.1-dev",
+        repo_type="model",
+        ignore_patterns=["*.md", "*.gitattributes"],
+        local_dir="FLUX.1-dev",
+    )
+except Exception as e:
+    logger.error(f"Failed to download FLUX.1-dev: {e}")
+    model_path = None
+
 # Disable PROVIDER_ENDPOINTS because we only use Hugging Face
 PROVIDER_ENDPOINTS = {
     "huggingface": API_ENDPOINT
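For context on the hunk above: snapshot_download places a local copy of FLUX.1-dev in the FLUX.1-dev directory, and DiffusionPipeline.from_pretrained can load from either a local directory or a hub repo id. The sketch below shows how the downloaded model_path could be reused; the helper name load_flux_pipeline is hypothetical and not part of this commit.

# Hypothetical helper (not in this commit): reuse the locally downloaded
# FLUX.1-dev snapshot when snapshot_download succeeded, otherwise fall back
# to resolving the hub repo id.
from typing import Optional

import torch
from diffusers import DiffusionPipeline

def load_flux_pipeline(model_path: Optional[str] = None) -> DiffusionPipeline:
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    source = model_path or "black-forest-labs/FLUX.1-dev"
    # from_pretrained accepts either a local directory or a hub repo id
    return DiffusionPipeline.from_pretrained(source, torch_dtype=dtype).to(device)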
@@ -111,29 +126,6 @@ def select_model(query: str, input_type: str = "text", preferred_model: Optional
     logger.error("No models available. Falling back to default.")
     return MODEL_NAME, API_ENDPOINT
 
-def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
-    original_prompt = original_prompt.strip()
-    system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
-    if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
-        system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
-    prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
-    magic_prompt = "Ultra HD, 4K, cinematic composition"
-
-    try:
-        client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
-        polished_prompt = client.chat.completions.create(
-            model=SECONDARY_MODEL_NAME,  # use a compatible model
-            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
-            temperature=0.7,
-            max_tokens=200
-        ).choices[0].message.content.strip()
-        polished_prompt = polished_prompt.replace("\n", " ")
-    except Exception as e:
-        logger.error(f"Error during prompt polishing: {e}")
-        polished_prompt = original_prompt
-
-    return polished_prompt + " " + magic_prompt
-
 @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=2, min=4, max=60))
 def request_generation(
     api_key: str,
@@ -268,25 +260,16 @@ def request_generation(
         del model
         torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-    # Handle image generation
+    # Handle image generation or editing
     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
         task_type = "image_generation"
         try:
             dtype = torch.float16 if torch.cuda.is_available() else torch.float32
             device = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Using device: {device}, dtype: {dtype}")
             if model_name == IMAGE_GEN_MODEL:
-                pipe = DiffusionPipeline.from_pretrained(
-                    "runwayml/stable-diffusion-v1-5",
-                    torch_dtype=dtype,
-                    use_auth_token=HF_TOKEN if HF_TOKEN else None
-                ).to(device)
+                pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
             else:
-                pipe = DiffusionPipeline.from_pretrained(
-                    "black-forest-labs/FLUX.1-dev",
-                    torch_dtype=dtype,
-                    use_auth_token=HF_TOKEN if HF_TOKEN else None
-                ).to(device)
+                pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
 
             polished_prompt = polish_prompt(message)
             image_params = {
@@ -348,17 +331,16 @@
 
     if deep_search:
        try:
-            from utils.web_search import web_search
             search_result = web_search(message)
             input_messages.append({"role": "user", "content": f"User query: {message}\nWeb search context: {search_result}"})
-        except
-            logger.error(f"Web search failed
+        except Exception as e:
+            logger.error(f"Web search failed: {e}")
             input_messages.append({"role": "user", "content": message})
     else:
         input_messages.append({"role": "user", "content": message})
 
     tools = tools if tools and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else []
-    tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] else "none"
+    tool_choice = tool_choice if tool_choice in ["auto", "none", "any", "required"] and model_name in [MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME] else "none"
 
     cached_chunks = []
     try:
@@ -682,6 +664,27 @@ def format_final(analysis_text: str, visible_text: str) -> str:
         f"{response}" if response else "No final response available."
     )
 
+def polish_prompt(original_prompt: str, image: Optional[Image.Image] = None) -> str:
+    original_prompt = original_prompt.strip()
+    system_prompt = "You are an expert in generating high-quality prompts for image generation. Rewrite the user input to be clear, descriptive, and optimized for creating visually appealing images."
+    if any(0x0600 <= ord(char) <= 0x06FF for char in original_prompt):
+        system_prompt += "\nRespond in Arabic with a polished prompt suitable for image generation."
+    prompt = f"{system_prompt}\n\nUser Input: {original_prompt}\n\nRewritten Prompt:"
+    magic_prompt = "Ultra HD, 4K, cinematic composition"
+    try:
+        client = OpenAI(api_key=HF_TOKEN, base_url=FALLBACK_API_ENDPOINT, timeout=120.0)
+        polished_prompt = client.chat.completions.create(
+            model=SECONDARY_MODEL_NAME,
+            messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
+            temperature=0.7,
+            max_tokens=200
+        ).choices[0].message.content.strip()
+        polished_prompt = polished_prompt.replace("\n", " ")
+    except Exception as e:
+        logger.error(f"Error during prompt polishing: {e}")
+        polished_prompt = original_prompt
+    return polished_prompt + " " + magic_prompt
+
 def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None, output_format="text"):
     if not message.strip() and not audio_data and not image_data:
         yield "Please enter a prompt or upload a file."
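For orientation, a minimal sketch (not from the commit) of how the relocated polish_prompt is meant to feed the diffusion pipelines used in request_generation. It assumes HF_TOKEN and SECONDARY_MODEL_NAME are configured as in utils/generation.py and that the module layout matches the file shown above.

# Sketch only: exercise polish_prompt together with a diffusers pipeline,
# mirroring the image-generation branch in request_generation.
import torch
from diffusers import DiffusionPipeline
from utils.generation import polish_prompt  # assumes this module layout

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)

prompt = polish_prompt("a cat reading a book under moonlight")  # falls back to the raw prompt if polishing fails
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("output.png")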