anaspro committed
Commit 64854b8 · Parent: a0e315d

✨ Major chatbot optimization and enhancement


- 🚀 Improved model loading with error handling and GPU optimization
- 💬 Fixed conversation history handling for better context management
- ⚙️ Enhanced text generation with optimized parameters
- 🛡️ Added comprehensive error handling and recovery
- 🎨 Enhanced UI with better styling and user experience
- 📋 Added configuration file for easy parameter tuning (see the sketch after this list)
- 🙈 Added .gitignore to prevent cache files
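
For reference, tuning now means editing config.json rather than code. A minimal sketch of the round trip (the file and key names come from the config.json added below; the snippet itself is illustrative, not part of the commit):

    import json

    # Read the settings file introduced in this commit
    with open("config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    # Lower the sampling temperature for more deterministic replies;
    # app.py picks the new value up on the next start via load_config()
    config["generation"]["temperature"] = 0.3

    with open("config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)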

Files changed (4)
  1. .gitignore +44 -0
  2. app.py +290 -93
  3. config.json +31 -0
  4. requirements.txt +4 -1
.gitignore ADDED
@@ -0,0 +1,44 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Model cache
+ models/
+ checkpoints/
app.py CHANGED
@@ -2,132 +2,329 @@ import os
  import torch
  import gradio as gr
  import spaces
  from threading import Thread
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import login

  # ======================================================
  # Settings
  # ======================================================
- MODEL_ID = "anaspro/Lahja-iraqi-4B"

  # Load system prompt from external file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     SYSTEM_PROMPT = f.read()

  # Login to Hugging Face
  if os.getenv("HF_TOKEN"):
      login(token=os.getenv("HF_TOKEN"))
-     print("🔐 Logged in to Hugging Face")

  # Global model variables
  model = None
  tokenizer = None

  # ======================================================
  # Chat function (ZeroGPU)
  # ======================================================
  @spaces.GPU(duration=120)
  def chat(message, history):
      global model, tokenizer

-     # Load model once
-     if model is None:
-         print("🔄 Loading model...")
-         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-         model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             dtype=torch.bfloat16,
-             device_map="auto",
          )
-         model.eval()
-         print("✅ Model loaded!")
-     else:
-         print("♻️ Reusing already loaded model in memory.")
-
-     # ======================================================
-     # Build conversation
-     # ======================================================
-     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-
-     # Add conversation history
-     for turn in history:
-         if isinstance(turn, dict):
-             role = turn.get("role")
-             content = turn.get("content")
-             if role and content:
-                 messages.append({"role": role, "content": content})
-         elif isinstance(turn, (list, tuple)) and len(turn) == 2:
-             messages.append({"role": "user", "content": turn[0]})
-             messages.append({"role": "assistant", "content": turn[1]})
-
-     # Add current user message
-     messages.append({"role": "user", "content": message})
-
-     # ======================================================
-     # Tokenize input
-     # ======================================================
-     input_ids = tokenizer.apply_chat_template(
-         messages,
-         return_tensors="pt",
-         add_generation_prompt=True
-     ).to(model.device)
-
-     # ======================================================
-     # Setup text streamer
-     # ======================================================
-     streamer = TextIteratorStreamer(
-         tokenizer,
-         skip_prompt=True,
-         skip_special_tokens=True
-     )
-
-     generation_kwargs = {
-         "input_ids": input_ids,
-         "streamer": streamer,
-         "max_new_tokens": 1024,
-         "temperature": 0.85,
-         "top_p": 0.9,
-         "top_k": 50,
-         "do_sample": True,
-         "repetition_penalty": 1.1,
-         "eos_token_id": None,  # ⬅️ important so generation does not stop early
-     }

-     # ======================================================
-     # Generate output in a separate thread
-     # ======================================================
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()

-     partial_text = ""
-     for new_text in streamer:
-         partial_text += new_text
-         print(new_text, end="", flush=True)
-         yield partial_text

-     thread.join()


  # ======================================================
- # Gradio Interface
  # ======================================================
- demo = gr.ChatInterface(
-     fn=chat,
-     type="messages",
-     title="📞 دعم فني - NB TEL Internet Assistant",
-     description=(
-         "**مساعد ذكي لخدمة الدعم الفني في شبكة النور - NB TEL**\n\n"
-         "تحدث معه كأنك زبون: اشرح مشكلتك، اسأل عن الباقات، أو اطلب تذكرة دعم."
-     ),
-     examples=[
-         ["الإنترنت عندي مقطوع من الصبح، شنو السبب؟"],
-         ["أريد أرقّي الباقة إلى 50 ميج."],
-         ["ضوء الـ LOS في جهاز الفايبر أحمر، شنو معناها؟"],
-     ],
-     theme=gr.themes.Soft(),
-     cache_examples=False,
- )

  if __name__ == "__main__":
      demo.launch()
 
  import torch
  import gradio as gr
  import spaces
+ import json
+ import time
  from threading import Thread
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import login
+ import logging
+
+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # ======================================================
+ # Load Configuration
+ # ======================================================
+ def load_config():
+     """Load configuration from config.json"""
+     try:
+         with open("config.json", "r", encoding="utf-8") as f:
+             return json.load(f)
+     except FileNotFoundError:
+         logger.warning("config.json not found, using default settings")
+         return {
+             "model": {"model_id": "anaspro/Lahja-iraqi-4B"},
+             "generation": {
+                 "max_new_tokens": 1024,
+                 "temperature": 0.7,
+                 "top_p": 0.9,
+                 "top_k": 50,
+                 "do_sample": True,
+                 "repetition_penalty": 1.1,
+                 "timeout_seconds": 60
+             },
+             "interface": {"max_context_length": 4096}
+         }
+
+ config = load_config()

  # ======================================================
  # Settings
  # ======================================================
+ MODEL_ID = config["model"].get("model_id", "anaspro/Lahja-iraqi-4B")

  # Load system prompt from external file
+ try:
+     with open("system_prompt.txt", "r", encoding="utf-8") as f:
+         SYSTEM_PROMPT = f.read()
+ except FileNotFoundError:
+     logger.warning("system_prompt.txt not found, using default prompt")
+     SYSTEM_PROMPT = "أنت مساعد ذكي مفيد. تحدث بالعربية وساعد المستخدم في استفساراته."

  # Login to Hugging Face
  if os.getenv("HF_TOKEN"):
      login(token=os.getenv("HF_TOKEN"))
+     logger.info("🔐 Logged in to Hugging Face")

  # Global model variables
  model = None
  tokenizer = None
+ model_lock = False
+
+ # ======================================================
+ # Model loading function
+ # ======================================================
+ def load_model():
+     """Load the model and tokenizer with proper error handling"""
+     global model, tokenizer, model_lock
+
+     if model_lock:
+         logger.info("Model loading already in progress...")
+         return False
+
+     model_lock = True
+     try:
+         logger.info("🔄 Loading model...")
+
+         # Load tokenizer first
+         tokenizer = AutoTokenizer.from_pretrained(
+             MODEL_ID,
+             trust_remote_code=True,
+             use_fast=True
+         )
+
+         # Add padding token if missing
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         # Load model with optimized settings
+         model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             trust_remote_code=True,
+             attn_implementation="flash_attention_2" if torch.cuda.is_available() else None,
+             low_cpu_mem_usage=True
+         )
+
+         model.eval()
+
+         # Clear cache to free memory
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         logger.info("✅ Model loaded successfully!")
+         return True
+
+     except Exception as e:
+         logger.error(f"❌ Error loading model: {str(e)}")
+         return False
+     finally:
+         model_lock = False

  # ======================================================
  # Chat function (ZeroGPU)
  # ======================================================
  @spaces.GPU(duration=120)
  def chat(message, history):
+     """Main chat function with improved error handling and conversation management"""
      global model, tokenizer

+     # Load model if not already loaded
+     if model is None or tokenizer is None:
+         if not load_model():
+             return "❌ عذراً، حدث خطأ في تحميل النموذج. يرجى المحاولة مرة أخرى."
+
+     try:
+         # ======================================================
+         # Build conversation properly
+         # ======================================================
+         messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+
+         # Process conversation history correctly
+         if history:
+             for exchange in history:
+                 if isinstance(exchange, dict):
+                     # Handle message format from Gradio
+                     if exchange.get("role") == "user":
+                         messages.append({"role": "user", "content": exchange.get("content", "")})
+                     elif exchange.get("role") == "assistant":
+                         messages.append({"role": "assistant", "content": exchange.get("content", "")})
+                 elif isinstance(exchange, (list, tuple)) and len(exchange) >= 2:
+                     # Handle [user_msg, assistant_msg] format
+                     if exchange[0]:  # User message
+                         messages.append({"role": "user", "content": str(exchange[0])})
+                     if exchange[1]:  # Assistant message
+                         messages.append({"role": "assistant", "content": str(exchange[1])})
+
+         # Add current user message
+         if message and message.strip():
+             messages.append({"role": "user", "content": message.strip()})
+         else:
+             return "يرجى كتابة رسالة صحيحة."
+
+         # ======================================================
+         # Tokenize input with error handling
+         # ======================================================
+         try:
+             max_length = config.get("interface", {}).get("max_context_length", 4096)
+             input_ids = tokenizer.apply_chat_template(
+                 messages,
+                 return_tensors="pt",
+                 add_generation_prompt=True,
+                 truncation=True,
+                 max_length=max_length
+             ).to(model.device)
+         except Exception as e:
+             logger.error(f"Tokenization error: {e}")
+             return "❌ خطأ في معالجة الرسالة. يرجى المحاولة مرة أخرى."
+
+         # ======================================================
+         # Setup text streamer
+         # ======================================================
+         streamer = TextIteratorStreamer(
+             tokenizer,
+             skip_prompt=True,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True
          )

+         generation_config = config.get("generation", {})
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": generation_config.get("max_new_tokens", 1024),
+             "temperature": generation_config.get("temperature", 0.7),
+             "top_p": generation_config.get("top_p", 0.9),
+             "top_k": generation_config.get("top_k", 50),
+             "do_sample": generation_config.get("do_sample", True),
+             "repetition_penalty": generation_config.get("repetition_penalty", 1.1),
+             "pad_token_id": tokenizer.pad_token_id,
+             "eos_token_id": tokenizer.eos_token_id,
+             "use_cache": True
+         }

+         # ======================================================
+         # Generate output in a separate thread with timeout
+         # ======================================================
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.daemon = True
+         thread.start()

+         partial_text = ""
+         start_time = time.time()
+         timeout = config.get("generation", {}).get("timeout_seconds", 60)
+
+         try:
+             for new_text in streamer:
+                 if time.time() - start_time > timeout:
+                     logger.warning("Generation timeout reached")
+                     break
+
+                 partial_text += new_text
+                 yield partial_text
+         except Exception as e:
+             logger.error(f"Generation error: {e}")
+             yield "❌ حدث خطأ أثناء توليد الإجابة. يرجى المحاولة مرة أخرى."
+
+         thread.join(timeout=5)  # Give the thread 5 seconds to finish
+
+         # Clear GPU cache after generation
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     except Exception as e:
+         logger.error(f"Chat function error: {e}")
+         return f"❌ حدث خطأ غير متوقع: {str(e)}"


  # ======================================================
+ # Gradio Interface with enhanced styling
  # ======================================================
+ def create_interface():
+     """Create the Gradio interface with enhanced UI"""
+
+     # Custom CSS for better styling
+     custom_css = """
+     .gradio-container {
+         max-width: 1000px !important;
+         margin: auto !important;
+     }
+     .chat-message {
+         padding: 10px !important;
+         margin: 5px 0 !important;
+         border-radius: 10px !important;
+     }
+     .message {
+         font-size: 16px !important;
+         line-height: 1.5 !important;
+     }
+     .title {
+         text-align: center !important;
+         color: #2563eb !important;
+         margin-bottom: 20px !important;
+     }
+     .description {
+         text-align: center !important;
+         margin-bottom: 30px !important;
+         color: #6b7280 !important;
+     }
+     """
+
+     with gr.Blocks(
+         css=custom_css,
+         theme=gr.themes.Soft(
+             primary_hue="blue",
+             secondary_hue="gray",
+             neutral_hue="slate"
+         ),
+         title="دعم فني - NB TEL"
+     ) as demo:
+
+         gr.Markdown(
+             """
+             # 📞 دعم فني - NB TEL Internet Assistant
+
+             **مساعد ذكي لخدمة الدعم الفني في شبكة النور - NB TEL**
+
+             تحدث معه كأنك زبون: اشرح مشكلتك، اسأل عن الباقات، أو اطلب تذكرة دعم.
+             """,
+             elem_classes=["title", "description"]
+         )
+
+         # Chat interface
+         chatbot = gr.ChatInterface(
+             fn=chat,
+             type="messages",
+             examples=[
+                 ["الإنترنت عندي مقطوع من الصبح، شنو السبب؟"],
+                 ["أريد أرقّي الباقة إلى 50 ميج."],
+                 ["ضوء الـ LOS في جهاز الفايبر أحمر، شنو معناها؟"],
+                 ["كم سعر باقة الإنترنت اللامحدود؟"],
+                 ["المودم يفصل ويوصل باستمرار، شنو الحل؟"]
+             ],
+             cache_examples=False,
+             retry_btn="🔄 إعادة المحاولة",
+             undo_btn="↶ تراجع",
+             clear_btn="🗑️ مسح المحادثة",
+             submit_btn="إرسال 📤",
+             textbox=gr.Textbox(
+                 placeholder="اكتب استفسارك هنا... 💬",
+                 container=False,
+                 scale=7
+             )
+         )
+
+         # Footer with information
+         gr.Markdown(
+             """
+             ---
+             **ملاحظة:** هذا مساعد ذكي للمحاكاة. البيانات المعروضة هي للتدريب فقط.
+
+             **الباقات المتاحة:**
+             - 🏠 HOME-10M: 10 Mbps - $9.99/شهر
+             - 🏠 HOME-50M: 50 Mbps - $19.99/شهر
+             - 🏢 BUS-200M: 200 Mbps - $69.99/شهر
+             - ⚡ UNL-1G: 1 Gbps غير محدود - $149.99/شهر
+             """,
+             elem_classes=["description"]
+         )
+
+     return demo
+
+ # Create the interface
+ demo = create_interface()

  if __name__ == "__main__":
      demo.launch()
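
A caveat on the loading guard above: a plain boolean such as model_lock is read and written non-atomically across Gradio's worker threads, so two requests can still race into load_model(). A sketch of the same early-return guard built on threading.Lock (an assumption, not what this commit ships; the load argument is a hypothetical stand-in for the body of load_model()):

    import logging
    import threading

    logger = logging.getLogger(__name__)
    _load_lock = threading.Lock()

    def load_model_safely(load):
        """load: zero-argument callable that performs the actual model load."""
        # Non-blocking acquire mirrors the model_lock early return in the diff
        if not _load_lock.acquire(blocking=False):
            logger.info("Model loading already in progress...")
            return False
        try:
            return load()
        finally:
            _load_lock.release()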
config.json ADDED
@@ -0,0 +1,31 @@
+ {
+     "model": {
+         "model_id": "anaspro/Lahja-iraqi-4B",
+         "torch_dtype": "bfloat16",
+         "device_map": "auto",
+         "trust_remote_code": true,
+         "use_flash_attention": true,
+         "low_cpu_mem_usage": true
+     },
+     "generation": {
+         "max_new_tokens": 1024,
+         "temperature": 0.7,
+         "top_p": 0.9,
+         "top_k": 50,
+         "do_sample": true,
+         "repetition_penalty": 1.1,
+         "timeout_seconds": 60
+     },
+     "interface": {
+         "title": "📞 دعم فني - NB TEL Internet Assistant",
+         "description": "مساعد ذكي لخدمة الدعم الفني في شبكة النور - NB TEL",
+         "max_context_length": 4096,
+         "share": false,
+         "server_name": "0.0.0.0",
+         "server_port": 7860
+     },
+     "logging": {
+         "level": "INFO",
+         "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+     }
+ }
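
Note that the interface block defines share, server_name, and server_port, but app.py's entry point still calls demo.launch() with no arguments, so those three keys are unused as of this commit. A sketch of how they could be wired in (an assumption, not part of the diff; config and demo are the objects app.py already defines):

    # Hypothetical replacement for app.py's __main__ block
    interface_cfg = config.get("interface", {})

    if __name__ == "__main__":
        demo.launch(
            share=interface_cfg.get("share", False),
            server_name=interface_cfg.get("server_name", "0.0.0.0"),
            server_port=interface_cfg.get("server_port", 7860),
        )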
requirements.txt CHANGED
@@ -6,4 +6,7 @@ torch>=2.2.0
  bitsandbytes>=0.42.0
  huggingface_hub>=0.23.0
  xformers>=0.0.27
- triton>=2.1.0
+ triton>=2.1.0
+ flash-attn>=2.5.0
+ sentencepiece>=0.1.99
+ protobuf>=3.20.0
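
One packaging note: flash-attn builds only against CUDA, and transformers raises at model load if attn_implementation="flash_attention_2" is requested while the package is absent — app.py currently gates this on torch.cuda.is_available() alone. A defensive sketch (an assumption, not part of this commit) that falls back to PyTorch's built-in SDPA attention:

    import torch
    from transformers import AutoModelForCausalLM

    def pick_attn_implementation():
        """Use FlashAttention 2 only when a GPU and the flash_attn package
        are both present; otherwise fall back to scaled-dot-product attention."""
        if torch.cuda.is_available():
            try:
                import flash_attn  # noqa: F401 -- presence check only
                return "flash_attention_2"
            except ImportError:
                pass
        return "sdpa"

    model = AutoModelForCausalLM.from_pretrained(
        "anaspro/Lahja-iraqi-4B",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation=pick_attn_implementation(),
    )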