FrameRateTech committed
Commit 89d86b2 · verified · 1 Parent(s): 1a42bbb

Update app.py

Files changed (1):
  1. app.py +473 -102
app.py CHANGED
@@ -1,90 +1,250 @@
 # app.py

 import transformers
 import torch
 import gradio as gr
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
-    GenerationConfig
 )

 ###############################################################################
-# Debug Print Section
 ###############################################################################
-MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
-print("Transformers version:", transformers.__version__)
-
-# Attempt to load the tokenizer once just to see what happens
-try:
-    tokenizer_test = AutoTokenizer.from_pretrained(
-        MODEL_ID,
-        use_fast=False,
-        trust_remote_code=True
-    )
-    print("tokenizer_test =", tokenizer_test)
-    print("type(tokenizer_test) =", type(tokenizer_test))
-except Exception as e:
-    print("AutoTokenizer failed with exception:", e)
-    raise e
-
-# If it's returning False, bail out early so we don't crash below
-if tokenizer_test is False:
-    raise ValueError("AutoTokenizer returned False, meaning it failed to load properly.")

 ###############################################################################
-# 1. Load Tokenizer
 ###############################################################################
-# Now load the real tokenizer for your app
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_ID,
-    use_fast=False,
-    trust_remote_code=True
-)

-# If `tokenizer` is not False, set pad_token_id if needed
-if getattr(tokenizer, "pad_token_id", None) is None:
-    tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)

 ###############################################################################
-# 2. Load Model
 ###############################################################################
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_ID,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    trust_remote_code=True
-)
-model.eval()

 ###############################################################################
-# 3. Default Generation Settings
 ###############################################################################
-default_gen_config = GenerationConfig(
-    temperature=0.7,
-    top_p=0.9,
-    do_sample=True,
-    repetition_penalty=1.1,
-    max_new_tokens=256,
-)

 ###############################################################################
-# 4. Helper: Convert Chatbot Messages to Prompt
 ###############################################################################
-def messages_to_prompt(messages):
-    conversation = ""
     for msg in messages:
-        if msg["role"] == "user":
-            conversation += f"User: {msg['content']}\n"
-        elif msg["role"] == "assistant":
-            conversation += f"Assistant: {msg['content']}\n"
-    return conversation

-###############################################################################
-# 5. Generation Function
-###############################################################################
-def predict(messages, temperature, top_p, max_new_tokens):
-    prompt_text = messages_to_prompt(messages) + "Assistant:"
     gen_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
@@ -92,52 +252,263 @@ def predict(messages, temperature, top_p, max_new_tokens):
         repetition_penalty=1.1,
         max_new_tokens=max_new_tokens,
     )
-    with torch.no_grad():
-        inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
-        outputs = model.generate(**inputs, generation_config=gen_config)
-    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    generated_reply = full_text[len(prompt_text):].strip()
-    messages.append({"role": "assistant", "content": generated_reply})
-    return messages

 ###############################################################################
-# 6. Build the Gradio Interface
 ###############################################################################
-with gr.Blocks() as demo:
-    gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
-
-    with gr.Row():
-        with gr.Column():
-            chatbot = gr.Chatbot(label="Chat History", type="messages")
-        with gr.Column():
-            gr.Markdown("### Generation Settings")
-            temperature_slider = gr.Slider(
-                minimum=0.0, maximum=1.5, value=0.7, step=0.1, label="Temperature"
-            )
-            top_p_slider = gr.Slider(
-                minimum=0.5, maximum=1.0, value=0.9, step=0.05, label="Top-p"
-            )
-            max_tokens_slider = gr.Slider(
-                minimum=64, maximum=2048, value=256, step=64, label="Max New Tokens"
             )

-    user_input = gr.Textbox(lines=1, label="Your Message", placeholder="Type here...")
-    send_btn = gr.Button("Send")

-    def user_submit(message_history, user_text, temp, top_p, max_tokens):
-        message_history.append({"role": "user", "content": user_text})
-        updated_messages = predict(message_history, temp, top_p, max_tokens)
-        return updated_messages, ""
-
-    send_btn.click(
-        user_submit,
-        inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider],
-        outputs=[chatbot, user_input],
-    )
-    user_input.submit(
-        user_submit,
-        inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider],
-        outputs=[chatbot, user_input],
-    )

-demo.queue().launch()

 # app.py

+import os
+import gc
+import logging
+import traceback
+import time
 import transformers
 import torch
 import gradio as gr
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
+    GenerationConfig,
+    BitsAndBytesConfig
 )

 ###############################################################################
+# Configure Logging
 ###############################################################################
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler()
+    ]
+)
+logger = logging.getLogger("DamageScan-App")

 ###############################################################################
+# Model Configuration
 ###############################################################################
+MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
+DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

 ###############################################################################
+# Device Configuration and Memory Management
 ###############################################################################
+def get_device_info():
+    """Log information about available devices and memory"""
+    device_info = {
+        "cuda_available": torch.cuda.is_available(),
+        "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
+        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+    }
+
+    if device_info["cuda_available"] and device_info["device_count"] > 0:
+        device_info["cuda_device_name"] = torch.cuda.get_device_name(0)
+        device_info["cuda_device_mem_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+        device_info["cuda_device_mem_reserved"] = torch.cuda.memory_reserved(0) / (1024**3)
+        device_info["cuda_device_mem_allocated"] = torch.cuda.memory_allocated(0) / (1024**3)
+
+    logger.info(f"Device information: {device_info}")
+    return device_info
+
+def optimize_memory():
+    """Optimize memory usage by clearing caches and forcing garbage collection"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
+    logger.info("Memory optimized: caches cleared and garbage collected")

 ###############################################################################
+# Model Loading with Error Handling
 ###############################################################################
+def load_model_and_tokenizer():
+    """Load the model and tokenizer with comprehensive error handling and logging"""
+    logger.info(f"Loading model: {MODEL_ID}")
+    logger.info(f"Transformers version: {transformers.__version__}")
+    logger.info(f"PyTorch version: {torch.__version__}")
+
+    device_info = get_device_info()
+
+    # Determine quantization settings based on available hardware
+    load_in_4bit = False
+    load_in_8bit = False
+
+    if device_info["cuda_available"]:
+        # On ZeroGPU environments, 4-bit quantization helps fit the model in memory
+        load_in_4bit = True
+        logger.info("Using 4-bit quantization for CUDA device")
+
+    # Configure quantization if needed
+    if load_in_4bit:
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_use_double_quant=True
+        )
+        logger.info("Configured 4-bit quantization with NF4 type")
+    elif load_in_8bit:
+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=True
+        )
+        logger.info("Configured 8-bit quantization")
+    else:
+        quantization_config = None
+        logger.info("No quantization configured, using default precision")
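+    # Note: BitsAndBytesConfig loading relies on the bitsandbytes package and a CUDA
+    # GPU. NF4 4-bit weights take roughly a quarter of the memory that float16 would
+    # need for this 8B model, and bnb_4bit_compute_dtype sets the dtype used for the
+    # actual matrix multiplications.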
+
+    # Step 1: Load tokenizer with detailed error logging
+    try:
+        logger.info("Loading tokenizer...")
+        tokenizer_start = time.time()
+        tokenizer = AutoTokenizer.from_pretrained(
+            MODEL_ID,
+            use_fast=False,
+            trust_remote_code=True
+        )
+        tokenizer_load_time = time.time() - tokenizer_start
+        logger.info(f"Tokenizer loaded successfully in {tokenizer_load_time:.2f} seconds")
+        logger.info(f"Tokenizer type: {type(tokenizer).__name__}")
+
+        # Log important tokenizer properties
+        tokenizer_info = {
+            "vocab_size": len(tokenizer),
+            "model_max_length": tokenizer.model_max_length,
+            "bos_token": tokenizer.bos_token,
+            "eos_token": tokenizer.eos_token,
+            "has_chat_template": hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
+        }
+        logger.info(f"Tokenizer properties: {tokenizer_info}")
+
+        # Set pad token if needed
+        if getattr(tokenizer, "pad_token_id", None) is None:
+            logger.info("Pad token not found, setting pad_token_id to eos_token_id")
+            tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)
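+        # Note: Llama-family tokenizers usually ship without a pad token, so padding
+        # reuses the EOS token id; this avoids the "pad_token_id is not set" warning
+        # when generate() runs.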
+    except Exception as e:
+        logger.error(f"Failed to load tokenizer: {str(e)}")
+        logger.error(traceback.format_exc())
+        raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
+
+    # Step 2: Load model with detailed error logging
+    try:
+        logger.info("Loading model...")
+        model_start = time.time()
+
+        # Determine device map strategy
+        if device_info["cuda_available"]:
+            device_map = "auto"
+            torch_dtype = torch.float16
+            logger.info("Using 'auto' device map for CUDA with float16 precision")
+        elif device_info["mps_available"]:
+            device_map = {"": "mps"}
+            torch_dtype = torch.float16
+            logger.info("Using MPS device with float16 precision")
+        else:
+            device_map = {"": "cpu"}
+            torch_dtype = torch.float32
+            logger.info("Using CPU with float32 precision")
+
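+        # Note: when a BitsAndBytesConfig is passed below, the quantized linear layers
+        # are placed on the GPU by bitsandbytes; torch_dtype mainly affects the modules
+        # that stay unquantized, while the matmul dtype comes from bnb_4bit_compute_dtype.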
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch_dtype,
+            device_map=device_map,
+            trust_remote_code=True,
+            quantization_config=quantization_config
+        )
+        model.eval()
+        model_load_time = time.time() - model_start
+        logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
+
+        # Log model info
+        model_info = {
+            "model_type": model.config.model_type,
+            "hidden_size": model.config.hidden_size,
+            "vocab_size": model.config.vocab_size,
+            "num_hidden_layers": model.config.num_hidden_layers
+        }
+        logger.info(f"Model properties: {model_info}")
+
+    except Exception as e:
+        logger.error(f"Failed to load model: {str(e)}")
+        logger.error(traceback.format_exc())
+        raise RuntimeError(f"Failed to load model: {str(e)}")
+
+    return model, tokenizer

 ###############################################################################
+# Chat Formatting and Generation Functions
 ###############################################################################
+def format_chat_for_model(messages, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
+    """
+    Format chat messages for the model using the tokenizer's chat template if available,
+    or fall back to a manual format for Llama models.
+    """
+    logger.info(f"Formatting chat with {len(messages)} messages")
+
+    # Prepare messages in the correct format
+    formatted_messages = []
+
+    # Add system message if not already present
+    if messages and messages[0].get("role") != "system":
+        formatted_messages.append({"role": "system", "content": system_prompt})
+
+    # Add user and assistant messages
     for msg in messages:
+        role = msg["role"]
+        # Skip system messages if we already added one
+        if role == "system" and formatted_messages and formatted_messages[0]["role"] == "system":
+            continue
+        formatted_messages.append({"role": role, "content": msg["content"]})
+
+    # Use the tokenizer's built-in chat template if available
+    if hasattr(tokenizer, "apply_chat_template") and callable(tokenizer.apply_chat_template):
+        logger.info("Using tokenizer's built-in chat template")
+        try:
+            chat_text = tokenizer.apply_chat_template(
+                formatted_messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+ logger.debug(f"Formatted chat using built-in template: {chat_text[:100]}...")
216
+ return chat_text
217
+ except Exception as e:
218
+ logger.warning(f"Failed to apply chat template: {str(e)}")
219
+ logger.warning("Falling back to manual formatting")
220
+
221
+ # Manual fallback format for Llama models
222
+ logger.info("Using manual chat formatting for Llama model")
223
+ chat_text = ""
224
+ for msg in formatted_messages:
225
+ role = msg["role"]
226
+ content = msg["content"]
227
+ if role == "system":
228
+ chat_text += f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>"
229
+ elif role == "user":
230
+ chat_text += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
231
+ elif role == "assistant":
232
+ chat_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
233
+
234
+ # Add the final assistant header for generation
235
+ chat_text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
236
+
237
+ logger.debug(f"Manually formatted chat: {chat_text[:100]}...")
238
+ return chat_text
239
 
240
+ def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
241
+ """Generate a response from the model with retry logic and error handling"""
242
+ logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
243
+
244
+ # Format the messages for the model
245
+ prompt = format_chat_for_model(messages, tokenizer, system_prompt)
246
+
247
+ # Configure generation parameters
248
  gen_config = GenerationConfig(
249
  temperature=temperature,
250
  top_p=top_p,
 
252
  repetition_penalty=1.1,
253
  max_new_tokens=max_new_tokens,
254
  )
255
+
256
+ # Tokenize the input
257
+ try:
258
+ inputs = tokenizer(prompt, return_tensors="pt")
259
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
260
+ logger.info(f"Input tokenized to {inputs['input_ids'].shape[1]} tokens")
261
+ except Exception as e:
262
+ logger.error(f"Error during tokenization: {str(e)}")
263
+ return "I encountered an error while processing your message. Please try again."
264
+
265
+ # Generate with retry logic
266
+ max_retries = 3
267
+ retry_count = 0
268
+
269
+ while retry_count < max_retries:
270
+ try:
271
+ # Run the generation
272
+ generation_start = time.time()
273
+ with torch.no_grad():
274
+ output_ids = model.generate(
275
+ **inputs,
276
+ generation_config=gen_config,
277
+ )
278
+ generation_time = time.time() - generation_start
279
+ logger.info(f"Generation completed in {generation_time:.2f} seconds")
280
+
281
+ # Decode the output
282
+ generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
283
+
284
+ # Extract just the assistant's response
285
+ assistant_response = ""
286
+ if hasattr(tokenizer, "apply_chat_template") and callable(tokenizer.apply_chat_template):
287
+ # Extract assistant's response from the full output
288
+ if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
289
+ parts = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")
290
+ if len(parts) > 1:
291
+ assistant_part = parts[-1]
292
+ if "<|eot_id|>" in assistant_part:
293
+ assistant_response = assistant_part.split("<|eot_id|>")[0].strip()
294
+ else:
295
+ assistant_response = assistant_part.strip()
296
+ else:
297
+ # Fall back to removing the prompt
298
+ assistant_response = generated_text[len(prompt):].strip()
299
+ else:
300
+ # Simple extraction method
301
+ assistant_response = generated_text[len(prompt):].strip()
302
+
303
+ logger.info(f"Response extracted, length: {len(assistant_response)} chars")
304
+
305
+ # Free up memory
306
+ del inputs, output_ids
307
+ optimize_memory()
308
+
309
+ return assistant_response
310
+
311
+ except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
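+            # Note: torch.cuda.OutOfMemoryError is itself a subclass of RuntimeError,
+            # so this branch also catches other runtime failures and retries them with
+            # a smaller token budget.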
+            retry_count += 1
+            logger.warning(f"Generation attempt {retry_count} failed: {str(e)}")
+
+            if retry_count < max_retries:
+                logger.info("Retrying with reduced parameters...")
+                # Reduce parameters to try to fit in memory, and update the generation
+                # config so the retry actually uses the smaller budget
+                max_new_tokens = max(64, max_new_tokens // 2)
+                gen_config.max_new_tokens = max_new_tokens
+                optimize_memory()
+            else:
+                logger.error(f"Failed to generate after {max_retries} attempts")
+                return "I'm sorry, I encountered a resource limitation while generating a response. Please try a shorter message or adjust the generation parameters."
+
+        except Exception as e:
+            logger.error(f"Unexpected error during generation: {str(e)}")
+            logger.error(traceback.format_exc())
+            return "I encountered an unexpected error. Please try again with different parameters."

 ###############################################################################
+# Gradio Interface
 ###############################################################################
+def build_gradio_interface(model, tokenizer):
+    """Build and launch the Gradio interface"""
+    logger.info("Building Gradio interface")
+
+    def user_submit(message_history, user_text, temp, top_p, max_tokens, system_message):
+        """Handle user message submission"""
+        logger.info(f"Received user message: '{user_text[:50]}...' (length: {len(user_text)})")
+
+        if not user_text.strip():
+            logger.warning("Empty user message, skipping processing")
+            return message_history, ""
+
+        try:
+            # Add user message to history
+            if not message_history:
+                # Start with system message if this is the first message
+                message_history = [{"role": "system", "content": system_message}]
+
+            message_history.append({"role": "user", "content": user_text})
+
+            # Generate response
+            assistant_response = generate_response(
+                model,
+                tokenizer,
+                message_history,
+                temperature=temp,
+                top_p=top_p,
+                max_new_tokens=max_tokens,
+                system_prompt=system_message
             )
+
+            # Add assistant response to history
+            message_history.append({"role": "assistant", "content": assistant_response})
+            logger.info(f"Added assistant response (length: {len(assistant_response)})")
+
+            # Optimize memory after generation
+            optimize_memory()
+
+            return message_history, ""
+
+        except Exception as e:
+            logger.error(f"Error in user_submit: {str(e)}")
+            logger.error(traceback.format_exc())
+
+            # Return original message history plus error message
+            error_msg = "I encountered an error processing your request. Please try again."
+            if not message_history:
+                message_history = []
+            message_history.append({"role": "user", "content": user_text})
+            message_history.append({"role": "assistant", "content": error_msg})
+
+            return message_history, ""
+
+    def clear_chat():
+        """Clear the chat history"""
+        logger.info("Clearing chat history")
+        optimize_memory()
+        return [], ""
+
+    # Define the Gradio interface
+    with gr.Blocks(css="footer {visibility: hidden}") as demo:
+        gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
+        gr.Markdown("<p align='center'>Powered by FrameRateTech/DamageScan-llama-8b-instruct-merged</p>")
+
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot = gr.Chatbot(
+                    label="Chat History",
+                    type="messages",  # the handlers above pass the history as role/content dicts
+                    height=600,
+                    avatar_images=(None, "https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/avatar.png"),
+                )
+
+                with gr.Row():
+                    with gr.Column(scale=8):
+                        user_input = gr.Textbox(
+                            lines=3,
+                            label="Your Message",
+                            placeholder="Type your message here...",
+                            show_copy_button=True
+                        )
+                    with gr.Column(scale=1, min_width=50):
+                        submit_btn = gr.Button("Send", variant="primary")
+                        clear_btn = gr.Button("Clear Chat")
+
+            with gr.Column(scale=1):
+                gr.Markdown("### System Prompt")
+                system_prompt_input = gr.Textbox(
+                    lines=5,
+                    label="System Instructions",
+                    value=DEFAULT_SYSTEM_PROMPT,
+                    show_copy_button=True
+                )
+
+                gr.Markdown("### Generation Settings")
+                temperature_slider = gr.Slider(
+                    minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature",
+                    info="Higher values make output more random, lower values more deterministic"
+                )
+                top_p_slider = gr.Slider(
+                    minimum=0.5, maximum=1.0, value=0.9, step=0.05, label="Top-p",
+                    info="Controls diversity via nucleus sampling"
+                )
+                max_tokens_slider = gr.Slider(
+                    minimum=64, maximum=1024, value=256, step=64, label="Max New Tokens",
+                    info="Maximum length of generated response"
+                )
+
+                gr.Markdown("### Tips")
+                gr.Markdown("""
+                * Lower temperature (0.1-0.3) for factual responses
+                * Higher temperature (0.7-1.0) for creative tasks
+                * Reduce max tokens if responses are too long
+                * Clear chat if the model gets confused
+                """)

+        # Set up event handlers
+        submit_btn.click(
+            user_submit,
+            inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
+            outputs=[chatbot, user_input],
+        )
+        user_input.submit(
+            user_submit,
+            inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
+            outputs=[chatbot, user_input],
+        )
+        clear_btn.click(
+            clear_chat,
+            outputs=[chatbot, user_input]
+        )
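+        # Note: the Send button and the textbox's submit event are wired to the same
+        # user_submit handler, and both return "" as the second output so the textbox
+        # is cleared after each turn.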
+
+        # Add example prompts
+        gr.Examples(
+            examples=[
+                ["Can you explain how the Large Hadron Collider works?"],
+                ["Write a short story about a robot who learns to paint"],
+                ["What are three ways to improve productivity when working from home?"],
+                ["Explain quantum computing to me like I'm 10 years old"],
+            ],
+            inputs=user_input,
+            label="Example Prompts"
+        )
+
+    return demo

+###############################################################################
+# Main Application Logic
+###############################################################################
+def main():
+    """Main application entry point"""
+    try:
+        logger.info("Starting DamageScan 8B Instruct application")
+        logger.info(f"Environment: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
+
+        # Load model and tokenizer
+        model, tokenizer = load_model_and_tokenizer()
+
+        # Build and launch Gradio interface
+        demo = build_gradio_interface(model, tokenizer)
+
+        # Launch the app
+        logger.info("Launching Gradio interface")
+        demo.queue().launch(
+            share=False,
+            debug=False,
+            show_error=True,
+            favicon_path="https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/favicon.ico"
+        )
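+        # Note: queue() routes requests through Gradio's request queue, so concurrent
+        # users wait their turn instead of running generation on the single loaded
+        # model at the same time.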
+
+    except Exception as e:
+        logger.error(f"Application startup failed: {str(e)}")
+        logger.error(traceback.format_exc())
+
+        # Create a minimal fallback UI to show the error
+        with gr.Blocks() as fallback_demo:
+            gr.Markdown("# ⚠️ DamageScan 8B Application Error")
+            gr.Markdown(f"The application encountered an error during startup:\n\n```\n{str(e)}\n```")
+            gr.Markdown("Please check the logs for more details or try again later.")
+
+        fallback_demo.launch()

+if __name__ == "__main__":
+    main()