ibrahimlasfar committed
Commit 67fb0f6 · 1 Parent(s): 41c40b7

Update chatbot with audio/image support and fixed models

Files changed (1)
  1. utils/generation.py +31 -16
utils/generation.py CHANGED
@@ -13,15 +13,13 @@ import pydub
 import io
 import torchaudio
 from PIL import Image
-import numpy as np
 from transformers import CLIPModel, CLIPProcessor, AutoProcessor
 from parler_tts import ParlerTTSForConditionalGeneration
-from utils.web_search import web_search  # direct import
 
 logger = logging.getLogger(__name__)
 
 # Cache setup
-cache = TTLCache(maxsize=100, ttl=600)
+cache = TTLCache(maxsize=100, ttl=600)  # 100 entries, 10-minute TTL
 
 # LATEX_DELIMS definition
 LATEX_DELIMS = [
@@ -33,18 +31,19 @@ LATEX_DELIMS = [
 
 # Client setup for the Hugging Face Inference API
 HF_TOKEN = os.getenv("HF_TOKEN")
-BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
-API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co")
-FALLBACK_API_ENDPOINT = "https://api-inference.huggingface.co"
-MODEL_NAME = os.getenv("MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
-SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
-TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "mistralai/Mixtral-8x22B-Instruct-v0.1")
+BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")  # backup token
+API_ENDPOINT = os.getenv("API_ENDPOINT", "https://router.huggingface.co/v1")
+FALLBACK_API_ENDPOINT = "https://api-inference.huggingface.co/v1"
+MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b:fireworks-ai")
+SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
+TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
 CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "openai/clip-vit-base-patch32")
 CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
-ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
+ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3-turbo")
 TTS_MODEL = os.getenv("TTS_MODEL", "parler-tts/parler-tts-mini-v1")
 
 def check_model_availability(model_name: str, api_base: str, api_key: str) -> tuple[bool, str]:
+    """Check model availability via the API, with backup-token support."""
     try:
         response = requests.get(
             f"{api_base}/models/{model_name}",
@@ -67,12 +66,15 @@ def check_model_availability(model_name: str, api_base: str, api_key: str) -> tu
 
 def select_model(query: str, input_type: str = "text") -> tuple[str, str]:
     query_lower = query.lower()
+    # Audio support
     if input_type == "audio" or any(keyword in query_lower for keyword in ["voice", "audio", "speech", "صوت", "تحويل صوت"]):
         logger.info(f"Selected {ASR_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for audio input")
         return ASR_MODEL, FALLBACK_API_ENDPOINT
+    # Text-to-speech support
     if any(keyword in query_lower for keyword in ["text-to-speech", "tts", "تحويل نص إلى صوت"]):
         logger.info(f"Selected {TTS_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for text-to-speech")
         return TTS_MODEL, FALLBACK_API_ENDPOINT
+    # CLIP models for image-related queries
     image_patterns = [
         r"\bimage\b", r"\bpicture\b", r"\bphoto\b", r"\bvisual\b", r"\bصورة\b", r"\bتحليل\s+صورة\b",
         r"\bimage\s+analysis\b", r"\bimage\s+classification\b", r"\bimage\s+description\b"
@@ -81,6 +83,16 @@ def select_model(query: str, input_type: str = "text") -> tuple[str, str]:
         if re.search(pattern, query_lower, re.IGNORECASE):
             logger.info(f"Selected {CLIP_BASE_MODEL} with endpoint {FALLBACK_API_ENDPOINT} for image-related query: {query}")
             return CLIP_BASE_MODEL, FALLBACK_API_ENDPOINT
+    # DeepSeek model for MGZon-related queries
+    mgzon_patterns = [
+        r"\bmgzon\b", r"\bmgzon\s+(products|services|platform|features|mission|technology|solutions|oauth)\b",
+        r"\bميزات\s+mgzon\b", r"\bخدمات\s+mgzon\b", r"\boauth\b"
+    ]
+    for pattern in mgzon_patterns:
+        if re.search(pattern, query_lower, re.IGNORECASE):
+            logger.info(f"Selected {SECONDARY_MODEL_NAME} with endpoint {FALLBACK_API_ENDPOINT} for MGZon-related query: {query}")
+            return SECONDARY_MODEL_NAME, FALLBACK_API_ENDPOINT
+    # Default model for general queries
     logger.info(f"Selected {MODEL_NAME} with endpoint {API_ENDPOINT} for general query: {query}")
     return MODEL_NAME, API_ENDPOINT
 
@@ -102,7 +114,9 @@ def request_generation(
     audio_data: Optional[bytes] = None,
     image_data: Optional[bytes] = None,
 ) -> Generator[bytes | str, None, None]:
-    # Check model availability
+    from utils.web_search import web_search  # deferred import
+
+    # Check model availability, with backup-token support
     is_available, selected_api_key = check_model_availability(model_name, api_base, api_key)
     if not is_available:
         yield f"Error: Model {model_name} is not available. Please check the model endpoint or token."
@@ -129,10 +143,10 @@
     enhanced_system_prompt = system_prompt
 
     # Audio handling (ASR)
-    if model_name == ASR_MODEL and audio_data is not None:
+    if model_name == ASR_MODEL and audio_data:
         task_type = "audio_transcription"
         try:
-            audio_file = io.BytesIO(audio_data if isinstance(audio_data, bytes) else audio_data.tobytes())
+            audio_file = io.BytesIO(audio_data)
             audio = pydub.AudioSegment.from_file(audio_file)
             audio = audio.set_channels(1).set_frame_rate(16000)
             audio_file = io.BytesIO()
@@ -171,12 +185,12 @@
         return
 
     # Image handling
-    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data is not None:
+    if model_name in [CLIP_BASE_MODEL, CLIP_LARGE_MODEL] and image_data:
         task_type = "image_analysis"
         try:
             model = CLIPModel.from_pretrained(model_name)
             processor = CLIPProcessor.from_pretrained(model_name)
-            image = Image.fromarray(np.uint8(image_data)) if isinstance(image_data, np.ndarray) else Image.open(io.BytesIO(image_data)).convert("RGB")
+            image = Image.open(io.BytesIO(image_data)).convert("RGB")
             inputs = processor(text=message, images=image, return_tensors="pt", padding=True)
             outputs = model(**inputs)
             logits_per_image = outputs.logits_per_image
@@ -208,6 +222,7 @@
     else:
         enhanced_system_prompt = f"{system_prompt}\nFor general queries, provide comprehensive, detailed responses with examples and explanations where applicable. Continue generating content until the query is fully answered, leveraging the full capacity of the model."
 
+    # For short queries, encourage a detailed answer
     if len(message.split()) < 5:
         enhanced_system_prompt += "\nEven for short or general queries, provide a detailed, in-depth response with examples, explanations, and additional context to ensure completeness."
 
@@ -487,7 +502,7 @@ def format_final(analysis_text: str, visible_text: str) -> str:
 
 def generate(message, history, system_prompt, temperature, reasoning_effort, enable_browsing, max_new_tokens, input_type="text", audio_data=None, image_data=None):
     if not message.strip() and not audio_data and not image_data:
-        yield "Please enter a prompt, record audio, or capture an image."
+        yield "Please enter a prompt or upload a file."
         return
 
     model_name, api_endpoint = select_model(message, input_type=input_type)
 
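A note on the re-commented cache line: TTLCache evicts entries 600 seconds after insertion, so repeated identical lookups within ten minutes can be served without another API round-trip. A minimal sketch (the cache key shown is hypothetical, not from this file):

import requests  # noqa: F401  (matching the module's own imports)
from cachetools import TTLCache

cache = TTLCache(maxsize=100, ttl=600)  # at most 100 entries, each kept 10 minutes

# Hypothetical usage: memoize an availability result per model name.
cache["availability:openai/whisper-large-v3-turbo"] = True
print(cache.get("availability:openai/whisper-large-v3-turbo"))  # True until the TTL lapses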
 
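The docstring added to check_model_availability mentions backup-token support, but most of the function body sits outside this diff. The following is only a sketch of the primary-then-backup pattern the docstring implies; the status handling and return shape are assumptions:

import os
import requests

def check_availability_sketch(model_name: str, api_base: str) -> tuple[bool, str]:
    # Try the primary token, then the backup; return whichever token works.
    for token in filter(None, (os.getenv("HF_TOKEN"), os.getenv("BACKUP_HF_TOKEN"))):
        try:
            response = requests.get(
                f"{api_base}/models/{model_name}",
                headers={"Authorization": f"Bearer {token}"},
                timeout=10,
            )
            if response.status_code == 200:
                return True, token
        except requests.RequestException:
            continue  # network error: fall through to the next token
    return False, ""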
 
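select_model now routes in a fixed order: audio input, TTS keywords, image patterns, the new MGZon patterns, then the general default. A condensed, self-contained sketch of that precedence (model names abbreviated; the real function returns a (model, endpoint) pair):

import re

def route_sketch(query: str, input_type: str = "text") -> str:
    q = query.lower()
    if input_type == "audio" or "voice" in q or "speech" in q:
        return "ASR_MODEL"  # checked first, so audio input always wins
    if "tts" in q or "text-to-speech" in q:
        return "TTS_MODEL"
    if re.search(r"\bimage\b|\bpicture\b|\bphoto\b", q):
        return "CLIP_BASE_MODEL"
    if re.search(r"\bmgzon\b|\boauth\b", q):
        return "SECONDARY_MODEL_NAME"  # branch added in this commit
    return "MODEL_NAME"

assert route_sketch("describe this image") == "CLIP_BASE_MODEL"
assert route_sketch("what does mgzon offer?") == "SECONDARY_MODEL_NAME"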
 
 
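The ASR branch now assumes audio_data is raw bytes and normalizes it before transcription. Isolated from the surrounding generator, the preprocessing amounts to the sketch below (pydub needs ffmpeg on the PATH to decode compressed formats; the WAV export step is inferred from the second io.BytesIO() in the hunk):

import io
import pydub

def to_wav_16k_mono(audio_bytes: bytes) -> bytes:
    # Decode whatever container the bytes hold, downmix to mono at 16 kHz,
    # and re-export as WAV, the layout Whisper-style ASR models expect.
    audio = pydub.AudioSegment.from_file(io.BytesIO(audio_bytes))
    audio = audio.set_channels(1).set_frame_rate(16000)
    out = io.BytesIO()
    audio.export(out, format="wav")
    return out.getvalue()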
 
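The image branch likewise assumes image_data is bytes and runs CLIP locally. The hunk ends at logits_per_image; presumably the scores come from softmaxing it over the candidate texts. A self-contained sketch of that flow under those assumptions:

import io
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

def clip_scores_sketch(image_bytes: bytes, prompts: list[str]) -> list[float]:
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # One row per image; softmax over the candidate prompts.
    return outputs.logits_per_image.softmax(dim=1)[0].tolist()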
 