Mark-Lasfar committed on
Commit a20530e · 1 Parent(s): dcad397

endpoints.py generation.py

requirements.txt CHANGED
@@ -40,15 +40,15 @@ pymongo==4.10.1
parler-tts @ git+https://github.com/huggingface/parler-tts.git@5d0aca9753ab74ded179732f5bd797f7a8c6f8ee
soupsieve>=2.5
tqdm>=4.66.0
- git+https://github.com/Dao-AILab/flash-attention.git
argon2-cffi>=23.1.0
wsproto>=1.2.0
descript-audiotools>=0.7.2
scipy>=1.15.0
librosa>=0.10.0
matplotlib>=3.10.0
- vllm
- accelerate
- flash-attn
- diffusers
- psutil
+ vllm==0.5.5
+ accelerate>=0.26.0
+ diffusers>=0.30.0
+ psutil>=5.9.0
+ xformers>=0.0.27
+ anyio==4.6.0
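
The hunk replaces the previously unversioned vllm, accelerate, diffusers, and psutil entries with pinned or floored versions, drops flash-attn, and adds xformers and anyio. A quick startup check can confirm the resolved environment matches these specifiers; this is a minimal sketch, with package names and versions copied from the hunk above (the >= entries are treated as floors):

from importlib.metadata import version, PackageNotFoundError

# Pins and floors taken from the updated requirements.txt hunk above.
EXPECTED = {
    "vllm": "0.5.5",        # exact pin
    "anyio": "4.6.0",       # exact pin
    "accelerate": "0.26.0", # floor
    "diffusers": "0.30.0",  # floor
    "psutil": "5.9.0",      # floor
    "xformers": "0.0.27",   # floor
}

for name, expected in EXPECTED.items():
    try:
        print(f"{name}: installed {version(name)}, expected {expected}")
    except PackageNotFoundError:
        print(f"{name}: NOT INSTALLED (expected {expected})")
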
utils/generation.py CHANGED
@@ -21,6 +21,7 @@ from huggingface_hub import snapshot_download
import torch
from qwenimage.pipeline_qwen_image_edit import QwenImageEditPipeline
from qwenimage.pipeline_qwen_image import QwenImagePipeline
+ from diffusers import DiffusionPipeline
from utils.constants import MODEL_ALIASES, MODEL_NAME, SECONDARY_MODEL_NAME, TERTIARY_MODEL_NAME, CLIP_BASE_MODEL, CLIP_LARGE_MODEL, ASR_MODEL, TTS_MODEL, IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL
logger = logging.getLogger(__name__)

@@ -44,6 +45,7 @@ API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")

# Preload the FLUX.1-dev model if needed
+ model_path = None
try:
    model_path = snapshot_download(
        repo_id="black-forest-labs/FLUX.1-dev",

@@ -53,28 +55,28 @@ try:
    )
except Exception as e:
    logger.error(f"Failed to download FLUX.1-dev: {e}")
-
+     model_path = None

# FlashAttention-3 support
- _flash_attn_func = None
- _kernels_err = None
- try:
-     _k = get_kernel("kernels-community/vllm-flash-attn3")
-     _flash_attn_func = _k.flash_attn_func
- except Exception as e:
-     _flash_attn_func = None
-     _kernels_err = e
-
- def _ensure_fa3_available():
-     if _flash_attn_func is None:
-         raise ImportError(
-             "FlashAttention-3 via Hugging Face `kernels` is required. "
-             f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{_kernels_err}"
-         )
+ # _flash_attn_func = None
+ # _kernels_err = None
+ # try:
+ #     _k = get_kernel("kernels-community/vllm-flash-attn3")
+ #     _flash_attn_func = _k.flash_attn_func
+ # except Exception as e:
+ #     _flash_attn_func = None
+ #     _kernels_err = e
+
+ # def _ensure_fa3_available():
+ #     if _flash_attn_func is None:
+ #         raise ImportError(
+ #             "FlashAttention-3 via Hugging Face `kernels` is required. "
+ #             f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{_kernels_err}"
+ #         )
# Disable PROVIDER_ENDPOINTS because we only use Hugging Face
PROVIDER_ENDPOINTS = {
    "huggingface": API_ENDPOINT

@@ -271,48 +273,41 @@ def request_generation(
        return

    # Handle image generation or editing
-     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
-         task_type = "image_generation"
-         try:
-             dtype = torch.float16  # can be adjusted for the hardware
-             device = "cuda" if torch.cuda.is_available() else "cpu"
-             _ensure_fa3_available()  # make sure FlashAttention-3 is available
-             if model_name == IMAGE_GEN_MODEL:
-                 pipe = QwenImagePipeline.from_pretrained(model_name, torch_dtype=dtype).to(device)
-                 pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
-             else:
-                 pipe = QwenImageEditPipeline.from_pretrained(model_name, torch_dtype=dtype).to(device)
-                 pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
-
-             # Set up the image generation parameters
-             polished_prompt = polish_prompt(message)
-             image_params = {
-                 "prompt": polished_prompt,
-                 "seed": 0,
-                 "randomize_seed": True,
-                 "aspect_ratio": "16:9",
-                 "guidance_scale": 4,
-                 "num_inference_steps": 50,
-                 "prompt_enhance": True
-             }
-             if input_type == "image_gen" and image_data:
-                 image = Image.open(io.BytesIO(image_data)).convert("RGB")
-                 image_params["image"] = image
-
-             # Generate the image
-             output = pipe(**image_params)
-             image_file = io.BytesIO()
-             output.images[0].save(image_file, format="PNG")
-             image_file.seek(0)
-             image_data = image_file.read()
-             logger.debug(f"Generated image data of length: {len(image_data)} bytes")
-             yield image_data
-             cache[cache_key] = [image_data]
-             return
-         except Exception as e:
-             logger.error(f"Image generation failed: {e}")
-             yield f"Error: Image generation failed: {e}"
-             return
+
+     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
+         task_type = "image_generation"
+         try:
+             dtype = torch.float16
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             if model_name == IMAGE_GEN_MODEL:
+                 pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=dtype).to(device)
+             else:
+                 pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype).to(device)
+
+             polished_prompt = polish_prompt(message)
+             image_params = {
+                 "prompt": polished_prompt,
+                 "num_inference_steps": 50,
+                 "guidance_scale": 7.5,
+             }
+             if input_type == "image_gen" and image_data:
+                 image = Image.open(io.BytesIO(image_data)).convert("RGB")
+                 image_params["image"] = image
+
+             output = pipe(**image_params)
+             image_file = io.BytesIO()
+             output.images[0].save(image_file, format="PNG")
+             image_file.seek(0)
+             image_data = image_file.read()
+             logger.debug(f"Generated image data of length: {len(image_data)} bytes")
+             yield image_data
+             cache[cache_key] = [image_data]
+             return
+         except Exception as e:
+             logger.error(f"Image generation failed: {e}")
+             yield f"Error: Image generation failed: {e}"
+             return
+

    # Handle text (as in the original code)
    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
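
This commit disables the FlashAttention-3 kernel lookup and drops the Qwen image pipelines from the image branch, routing both image models through diffusers' DiffusionPipeline and serializing the first generated image to PNG bytes before yielding it. Below is a minimal standalone sketch of that new path; it assumes diffusers>=0.30.0 from the updated requirements, reuses the checkpoints and parameters from the committed branch, and swaps polish_prompt and the response cache for a plain placeholder prompt:

import io

import torch
from diffusers import DiffusionPipeline

# Same text-to-image route the commit switches to: DiffusionPipeline instead of
# QwenImagePipeline/QwenImageEditPipeline with FlashAttention-3 processors.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32  # assumption: fall back to fp32 on CPU

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
).to(device)

output = pipe(
    prompt="a lighthouse on a cliff at sunset",  # placeholder; generation.py uses polish_prompt(message)
    num_inference_steps=50,
    guidance_scale=7.5,
)

# Serialize to PNG bytes, mirroring what request_generation() yields to the caller.
buf = io.BytesIO()
output.images[0].save(buf, format="PNG")
png_bytes = buf.getvalue()
print(f"Generated {len(png_bytes)} bytes of PNG data")
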
utils/utils/constants.py CHANGED
@@ -1,11 +1,11 @@
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:cerebras")
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
- TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct:featherless-ai")
+ TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "llama/Llama-3.1-8B-Instruct:featherless-ai")
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")
- IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct")
+ IMAGE_GEN_MODEL = os.getenv("IMAGE_GEN_MODEL", "Qwen/Qwen2.5-VL-7B-Instruct:novita")
SECONDARY_IMAGE_GEN_MODEL = os.getenv("SECONDARY_IMAGE_GEN_MODEL", "black-forest-labs/FLUX.1-dev")

MODEL_ALIASES = {
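
Each of these names resolves through os.getenv, so the committed strings are only defaults and a deployment can repoint a model without editing the file. A minimal sketch of overriding one of them follows; it assumes that import os appears above these assignments and that the module is importable as utils.constants (the path generation.py imports from), and the override value is only an illustration:

import os

# Set the override before the constants module is first imported.
os.environ["IMAGE_GEN_MODEL"] = "black-forest-labs/FLUX.1-dev"

from utils.constants import IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL

print(IMAGE_GEN_MODEL)            # black-forest-labs/FLUX.1-dev (from the environment)
print(SECONDARY_IMAGE_GEN_MODEL)  # black-forest-labs/FLUX.1-dev (committed default)
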