Mark-Lasfar
committed on
Commit
·
e1c2945
1
Parent(s):
07eb745
endpoints.py generation.py
Browse files - utils/generation.py +42 -43
utils/generation.py
CHANGED
|
@@ -172,7 +172,7 @@ def request_generation(
|
|
| 172 |
enhanced_system_prompt = system_prompt
|
| 173 |
buffer = ""
|
| 174 |
|
| 175 |
-
# معالجة الصوت
|
| 176 |
if model_name == ASR_MODEL and audio_data:
|
| 177 |
task_type = "audio_transcription"
|
| 178 |
try:
|
|
@@ -196,7 +196,7 @@ def request_generation(
|
|
| 196 |
yield f"Error: Audio transcription failed: {e}"
|
| 197 |
return
|
| 198 |
|
| 199 |
-
# معالجة تحويل النص إلى صوت
|
| 200 |
if model_name == TTS_MODEL or output_format == "audio":
|
| 201 |
task_type = "text_to_speech"
|
| 202 |
try:
|
|
@@ -223,47 +223,47 @@ def request_generation(
|
|
| 223 |
del model
|
| 224 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 225 |
|
| 226 |
-
# معالجة تحليل الصور
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
| 250 |
else:
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
logger.error(f"Image analysis failed
|
| 256 |
-
yield f"Error: Image analysis failed
|
| 257 |
return
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
del model
|
| 265 |
-
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 266 |
-
# معالجة توليد الصور أو تحريرها
|
| 267 |
if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
|
| 268 |
task_type = "image_generation"
|
| 269 |
try:
|
|
@@ -302,7 +302,7 @@ if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
|
|
| 302 |
del pipe
|
| 303 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 304 |
|
| 305 |
-
# معالجة النصوص
|
| 306 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
|
| 307 |
task_type = "image"
|
| 308 |
enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
|
|
@@ -652,7 +652,6 @@ if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
|
|
| 652 |
else:
|
| 653 |
yield f"Error: Failed to load model {model_name}: {e}"
|
| 654 |
return
|
| 655 |
-
|
| 656 |
def format_final(analysis_text: str, visible_text: str) -> str:
|
| 657 |
reasoning_safe = html.escape((analysis_text or "").strip())
|
| 658 |
response = (visible_text or "").strip()
|
|
|
|
| 172 |
enhanced_system_prompt = system_prompt
|
| 173 |
buffer = ""
|
| 174 |
|
| 175 |
+
# === معالجة الصوت ===
|
| 176 |
if model_name == ASR_MODEL and audio_data:
|
| 177 |
task_type = "audio_transcription"
|
| 178 |
try:
|
|
|
|
| 196 |
yield f"Error: Audio transcription failed: {e}"
|
| 197 |
return
|
| 198 |
|
| 199 |
+
# === معالجة تحويل النص إلى صوت ===
|
| 200 |
if model_name == TTS_MODEL or output_format == "audio":
|
| 201 |
task_type = "text_to_speech"
|
| 202 |
try:
|
|
|
|
| 223 |
del model
|
| 224 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 225 |
|
| 226 |
+
# === معالجة تحليل الصور ===
|
| 227 |
+
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
|
| 228 |
+
task_type = "image_analysis"
|
| 229 |
+
try:
|
| 230 |
+
url = f"{IMAGE_INFERENCE_API}/{model_name}"
|
| 231 |
+
headers = {"Authorization": f"Bearer {api_key}"}
|
| 232 |
+
response = requests.post(url, headers=headers, data=image_data)
|
| 233 |
+
if response.status_code == 200:
|
| 234 |
+
result = response.json()
|
| 235 |
+
caption = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', 'No caption generated')
|
| 236 |
+
logger.debug(f"Image analysis result: {caption}")
|
| 237 |
+
if output_format == "audio":
|
| 238 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 239 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 240 |
+
model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
|
| 241 |
+
processor = AutoProcessor.from_pretrained(TTS_MODEL)
|
| 242 |
+
inputs = processor(text=caption, return_tensors="pt").to(device)
|
| 243 |
+
audio = model.generate(**inputs)
|
| 244 |
+
audio_file = io.BytesIO()
|
| 245 |
+
torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
|
| 246 |
+
audio_file.seek(0)
|
| 247 |
+
audio_data = audio_file.read()
|
| 248 |
+
yield audio_data
|
| 249 |
+
else:
|
| 250 |
+
yield caption
|
| 251 |
+
cache[cache_key] = [caption]
|
| 252 |
+
return
|
| 253 |
else:
|
| 254 |
+
logger.error(f"Image analysis failed with status {response.status_code}: {response.text}")
|
| 255 |
+
yield f"Error: Image analysis failed with status {response.status_code}: {response.text}"
|
| 256 |
+
return
|
| 257 |
+
except Exception as e:
|
| 258 |
+
logger.error(f"Image analysis failed: {e}")
|
| 259 |
+
yield f"Error: Image analysis failed: {e}"
|
| 260 |
return
|
| 261 |
+
finally:
|
| 262 |
+
if 'model' in locals():
|
| 263 |
+
del model
|
| 264 |
+
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 265 |
+
|
| 266 |
+
# === معالجة توليد الصور أو تحريرها ===
|
|
|
|
|
|
|
|
|
|
| 267 |
if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
|
| 268 |
task_type = "image_generation"
|
| 269 |
try:
|
|
|
|
| 302 |
del pipe
|
| 303 |
torch.cuda.empty_cache() if torch.cuda.is_available() else None
|
| 304 |
|
| 305 |
+
# === معالجة النصوص (الدردشة) ===
|
| 306 |
if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
|
| 307 |
task_type = "image"
|
| 308 |
enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
|
|
|
|
| 652 |
else:
|
| 653 |
yield f"Error: Failed to load model {model_name}: {e}"
|
| 654 |
return
|
|
|
|
| 655 |
def format_final(analysis_text: str, visible_text: str) -> str:
|
| 656 |
reasoning_safe = html.escape((analysis_text or "").strip())
|
| 657 |
response = (visible_text or "").strip()
|