Mark-Lasfar committed
Commit e1c2945 · 1 Parent(s): 07eb745

endpoints.py generation.py

Files changed (1)
  1. utils/generation.py +42 -43
utils/generation.py CHANGED
@@ -172,7 +172,7 @@ def request_generation(
     enhanced_system_prompt = system_prompt
     buffer = ""
 
-    # Audio processing
+    # === Audio processing ===
     if model_name == ASR_MODEL and audio_data:
         task_type = "audio_transcription"
         try:
@@ -196,7 +196,7 @@ def request_generation(
             yield f"Error: Audio transcription failed: {e}"
             return
 
-    # Text-to-speech processing
+    # === Text-to-speech processing ===
    if model_name == TTS_MODEL or output_format == "audio":
        task_type = "text_to_speech"
        try:
@@ -223,47 +223,47 @@ def request_generation(
                 del model
                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-# Image analysis processing
-# Image analysis processing
-if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
-    task_type = "image_analysis"
-    try:
-        url = f"{IMAGE_INFERENCE_API}/{model_name}"
-        headers = {"Authorization": f"Bearer {api_key}"}
-        response = requests.post(url, headers=headers, data=image_data)
-        if response.status_code == 200:
-            result = response.json()
-            caption = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', 'No caption generated')
-            logger.debug(f"Image analysis result: {caption}")
-            if output_format == "audio":
-                dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-                device = "cuda" if torch.cuda.is_available() else "cpu"
-                model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
-                processor = AutoProcessor.from_pretrained(TTS_MODEL)
-                inputs = processor(text=caption, return_tensors="pt").to(device)
-                audio = model.generate(**inputs)
-                audio_file = io.BytesIO()
-                torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
-                audio_file.seek(0)
-                audio_data = audio_file.read()
-                yield audio_data
+    # === Image analysis processing ===
+    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
+        task_type = "image_analysis"
+        try:
+            url = f"{IMAGE_INFERENCE_API}/{model_name}"
+            headers = {"Authorization": f"Bearer {api_key}"}
+            response = requests.post(url, headers=headers, data=image_data)
+            if response.status_code == 200:
+                result = response.json()
+                caption = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', 'No caption generated')
+                logger.debug(f"Image analysis result: {caption}")
+                if output_format == "audio":
+                    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+                    device = "cuda" if torch.cuda.is_available() else "cpu"
+                    model = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL, torch_dtype=dtype).to(device)
+                    processor = AutoProcessor.from_pretrained(TTS_MODEL)
+                    inputs = processor(text=caption, return_tensors="pt").to(device)
+                    audio = model.generate(**inputs)
+                    audio_file = io.BytesIO()
+                    torchaudio.save(audio_file, audio[0], sample_rate=22050, format="wav")
+                    audio_file.seek(0)
+                    audio_data = audio_file.read()
+                    yield audio_data
+                else:
+                    yield caption
+                    cache[cache_key] = [caption]
+                return
             else:
-            yield caption
-            cache[cache_key] = [caption]
-            return
-        else:
-            logger.error(f"Image analysis failed with status {response.status_code}: {response.text}")
-            yield f"Error: Image analysis failed with status {response.status_code}: {response.text}"
+                logger.error(f"Image analysis failed with status {response.status_code}: {response.text}")
+                yield f"Error: Image analysis failed with status {response.status_code}: {response.text}"
+                return
+        except Exception as e:
+            logger.error(f"Image analysis failed: {e}")
+            yield f"Error: Image analysis failed: {e}"
             return
-    except Exception as e:
-        logger.error(f"Image analysis failed: {e}")
-        yield f"Error: Image analysis failed: {e}"
-        return
-    finally:
-        if 'model' in locals():
-            del model
-            torch.cuda.empty_cache() if torch.cuda.is_available() else None
-    # Image generation or editing
+        finally:
+            if 'model' in locals():
+                del model
+                torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+    # === Image generation or editing ===
     if model_name in [IMAGE_GEN_MODEL, SECONDARY_IMAGE_GEN_MODEL] or input_type == "image_gen":
         task_type = "image_generation"
         try:
@@ -302,7 +302,7 @@ if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
                 del pipe
                 torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-    # Text processing
+    # === Text (chat) processing ===
     if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL]:
         task_type = "image"
         enhanced_system_prompt = f"{system_prompt}\nYou are an expert in image analysis and description. Provide detailed descriptions, classifications, or analysis of images based on the query."
@@ -652,7 +652,6 @@ if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
         else:
             yield f"Error: Failed to load model {model_name}: {e}"
             return
-
 def format_final(analysis_text: str, visible_text: str) -> str:
     reasoning_safe = html.escape((analysis_text or "").strip())
     response = (visible_text or "").strip()
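Below is a minimal, self-contained sketch of the captioning call that the new image_analysis branch performs: raw image bytes are POSTed with a Bearer token and 'generated_text' is read from either response shape. The base URL and model id are placeholders, not values taken from this repository (IMAGE_INFERENCE_API and the CLIP_* constants are defined elsewhere in the codebase).

# Sketch only: the endpoint base URL and model id below are assumptions.
import requests

IMAGE_INFERENCE_API = "https://api-inference.huggingface.co/models"  # assumed base URL
MODEL_NAME = "Salesforce/blip-image-captioning-base"                 # placeholder captioning model

def caption_image(image_bytes: bytes, api_key: str) -> str:
    # POST the raw bytes, as the new branch does, and accept either a
    # list payload ([{"generated_text": ...}]) or a dict payload.
    url = f"{IMAGE_INFERENCE_API}/{MODEL_NAME}"
    headers = {"Authorization": f"Bearer {api_key}"}
    response = requests.post(url, headers=headers, data=image_bytes)
    response.raise_for_status()
    result = response.json()
    if isinstance(result, list):
        return result[0].get("generated_text", "No caption generated")
    return result.get("generated_text", "No caption generated")

# Usage (hypothetical token and file):
# with open("cat.jpg", "rb") as f:
#     print(caption_image(f.read(), api_key="hf_..."))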
 
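The audio path of the same branch relies on torchaudio writing WAV data into an in-memory buffer so the generator can yield raw bytes instead of a file path. A small sketch of that round trip, using a silent one-second waveform as a stand-in for the Parler-TTS output (assumes a torchaudio backend that accepts file-like objects):

# Stand-in waveform instead of a real TTS result; the save/read pattern mirrors the diff.
import io

import torch
import torchaudio

waveform = torch.zeros(1, 22050)           # 1 channel, 1 second at 22.05 kHz
buf = io.BytesIO()
torchaudio.save(buf, waveform, sample_rate=22050, format="wav")
buf.seek(0)
wav_bytes = buf.read()                     # this is what the generator yields

# A consumer can decode the yielded bytes straight back into a tensor:
restored, sr = torchaudio.load(io.BytesIO(wav_bytes))
print(restored.shape, sr)                  # torch.Size([1, 22050]) 22050

Because the same generator yields plain strings for captions and error messages, a caller that requests output_format == "audio" should check isinstance(chunk, bytes) before treating a chunk as WAV data.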