Update app.py
Browse files
app.py
CHANGED
@@ -1,90 +1,250 @@
|
|
1 |
# app.py
|
2 |
|
|
|
|
|
|
|
|
|
|
|
3 |
import transformers
|
4 |
import torch
|
5 |
import gradio as gr
|
6 |
from transformers import (
|
7 |
AutoTokenizer,
|
8 |
AutoModelForCausalLM,
|
9 |
-
GenerationConfig
|
|
|
10 |
)
|
11 |
|
12 |
###############################################################################
|
13 |
-
#
|
14 |
###############################################################################
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
trust_remote_code=True
|
24 |
-
)
|
25 |
-
print("tokenizer_test =", tokenizer_test)
|
26 |
-
print("type(tokenizer_test) =", type(tokenizer_test))
|
27 |
-
except Exception as e:
|
28 |
-
print("AutoTokenizer failed with exception:", e)
|
29 |
-
raise e
|
30 |
-
|
31 |
-
# If it's returning False, bail out early so we don't crash below
|
32 |
-
if tokenizer_test is False:
|
33 |
-
raise ValueError("AutoTokenizer returned False, meaning it failed to load properly.")
|
34 |
|
35 |
###############################################################################
|
36 |
-
#
|
37 |
###############################################################################
|
38 |
-
|
39 |
-
|
40 |
-
MODEL_ID,
|
41 |
-
use_fast=False,
|
42 |
-
trust_remote_code=True
|
43 |
-
)
|
44 |
|
45 |
-
|
46 |
-
if getattr(tokenizer, "pad_token_id", None) is None:
|
47 |
-
tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)
|
48 |
|
49 |
###############################################################################
|
50 |
-
#
|
51 |
###############################################################################
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
)
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
###############################################################################
|
61 |
-
#
|
62 |
###############################################################################
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
###############################################################################
|
72 |
-
#
|
73 |
###############################################################################
|
74 |
-
def
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
for msg in messages:
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
88 |
gen_config = GenerationConfig(
|
89 |
temperature=temperature,
|
90 |
top_p=top_p,
|
@@ -92,52 +252,263 @@ def predict(messages, temperature, top_p, max_new_tokens):
|
|
92 |
repetition_penalty=1.1,
|
93 |
max_new_tokens=max_new_tokens,
|
94 |
)
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
###############################################################################
|
104 |
-
#
|
105 |
###############################################################################
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
-
|
|
|
|
1 |
# app.py
|
2 |
|
3 |
+
import os
|
4 |
+
import gc
|
5 |
+
import logging
|
6 |
+
import traceback
|
7 |
+
import time
|
8 |
import transformers
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
from transformers import (
|
12 |
AutoTokenizer,
|
13 |
AutoModelForCausalLM,
|
14 |
+
GenerationConfig,
|
15 |
+
BitsAndBytesConfig
|
16 |
)
|
17 |
|
18 |
###############################################################################
|
19 |
+
# Configure Logging
|
20 |
###############################################################################
|
21 |
+
logging.basicConfig(
|
22 |
+
level=logging.INFO,
|
23 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
24 |
+
handlers=[
|
25 |
+
logging.StreamHandler()
|
26 |
+
]
|
27 |
+
)
|
28 |
+
logger = logging.getLogger("DamageScan-App")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
###############################################################################
|
31 |
+
# Model Configuration
|
32 |
###############################################################################
|
33 |
+
MODEL_ID = "FrameRateTech/DamageScan-llama-8b-instruct-merged"
|
34 |
+
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
|
|
|
|
37 |
|
38 |
###############################################################################
|
39 |
+
# Device Configuration and Memory Management
|
40 |
###############################################################################
|
41 |
+
def get_device_info():
|
42 |
+
"""Log information about available devices and memory"""
|
43 |
+
device_info = {
|
44 |
+
"cuda_available": torch.cuda.is_available(),
|
45 |
+
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
46 |
+
"mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
47 |
+
}
|
48 |
+
|
49 |
+
if device_info["cuda_available"] and device_info["device_count"] > 0:
|
50 |
+
device_info["cuda_device_name"] = torch.cuda.get_device_name(0)
|
51 |
+
device_info["cuda_device_mem_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
52 |
+
device_info["cuda_device_mem_reserved"] = torch.cuda.memory_reserved(0) / (1024**3)
|
53 |
+
device_info["cuda_device_mem_allocated"] = torch.cuda.memory_allocated(0) / (1024**3)
|
54 |
+
|
55 |
+
logger.info(f"Device information: {device_info}")
|
56 |
+
return device_info
|
57 |
+
|
58 |
+
def optimize_memory():
|
59 |
+
"""Optimize memory usage by clearing caches and forcing garbage collection"""
|
60 |
+
if torch.cuda.is_available():
|
61 |
+
torch.cuda.empty_cache()
|
62 |
+
gc.collect()
|
63 |
+
logger.info("Memory optimized: caches cleared and garbage collected")
|
64 |
|
65 |
###############################################################################
|
66 |
+
# Model Loading with Error Handling
|
67 |
###############################################################################
|
68 |
+
def load_model_and_tokenizer():
|
69 |
+
"""Load the model and tokenizer with comprehensive error handling and logging"""
|
70 |
+
logger.info(f"Loading model: {MODEL_ID}")
|
71 |
+
logger.info(f"Transformers version: {transformers.__version__}")
|
72 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
73 |
+
|
74 |
+
device_info = get_device_info()
|
75 |
+
|
76 |
+
# Determine quantization settings based on available hardware
|
77 |
+
load_in_4bit = False
|
78 |
+
load_in_8bit = False
|
79 |
+
|
80 |
+
if device_info["cuda_available"]:
|
81 |
+
# On ZEROGPU environments, 4-bit quantization helps fit the model in memory
|
82 |
+
load_in_4bit = True
|
83 |
+
logger.info("Using 4-bit quantization for CUDA device")
|
84 |
+
|
85 |
+
# Configure quantization if needed
|
86 |
+
if load_in_4bit:
|
87 |
+
quantization_config = BitsAndBytesConfig(
|
88 |
+
load_in_4bit=True,
|
89 |
+
bnb_4bit_compute_dtype=torch.float16,
|
90 |
+
bnb_4bit_quant_type="nf4",
|
91 |
+
bnb_4bit_use_double_quant=True
|
92 |
+
)
|
93 |
+
logger.info("Configured 4-bit quantization with NF4 type")
|
94 |
+
elif load_in_8bit:
|
95 |
+
quantization_config = BitsAndBytesConfig(
|
96 |
+
load_in_8bit=True
|
97 |
+
)
|
98 |
+
logger.info("Configured 8-bit quantization")
|
99 |
+
else:
|
100 |
+
quantization_config = None
|
101 |
+
logger.info("No quantization configured, using default precision")
|
102 |
+
|
103 |
+
# Step 1: Load tokenizer with detailed error logging
|
104 |
+
try:
|
105 |
+
logger.info("Loading tokenizer...")
|
106 |
+
tokenizer_start = time.time()
|
107 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
108 |
+
MODEL_ID,
|
109 |
+
use_fast=False,
|
110 |
+
trust_remote_code=True
|
111 |
+
)
|
112 |
+
tokenizer_load_time = time.time() - tokenizer_start
|
113 |
+
logger.info(f"Tokenizer loaded successfully in {tokenizer_load_time:.2f} seconds")
|
114 |
+
logger.info(f"Tokenizer type: {type(tokenizer).__name__}")
|
115 |
+
|
116 |
+
# Log important tokenizer properties
|
117 |
+
tokenizer_info = {
|
118 |
+
"vocab_size": len(tokenizer),
|
119 |
+
"model_max_length": tokenizer.model_max_length,
|
120 |
+
"bos_token": tokenizer.bos_token,
|
121 |
+
"eos_token": tokenizer.eos_token,
|
122 |
+
"has_chat_template": hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
|
123 |
+
}
|
124 |
+
logger.info(f"Tokenizer properties: {tokenizer_info}")
|
125 |
+
|
126 |
+
# Set pad token if needed
|
127 |
+
if getattr(tokenizer, "pad_token_id", None) is None:
|
128 |
+
logger.info("Pad token not found, setting pad_token_id to eos_token_id")
|
129 |
+
tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)
|
130 |
+
except Exception as e:
|
131 |
+
logger.error(f"Failed to load tokenizer: {str(e)}")
|
132 |
+
logger.error(traceback.format_exc())
|
133 |
+
raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
|
134 |
+
|
135 |
+
# Step 2: Load model with detailed error logging
|
136 |
+
try:
|
137 |
+
logger.info("Loading model...")
|
138 |
+
model_start = time.time()
|
139 |
+
|
140 |
+
# Determine device map strategy
|
141 |
+
if device_info["cuda_available"]:
|
142 |
+
device_map = "auto"
|
143 |
+
torch_dtype = torch.float16
|
144 |
+
logger.info("Using 'auto' device map for CUDA with float16 precision")
|
145 |
+
elif device_info["mps_available"]:
|
146 |
+
device_map = {"": "mps"}
|
147 |
+
torch_dtype = torch.float16
|
148 |
+
logger.info("Using MPS device with float16 precision")
|
149 |
+
else:
|
150 |
+
device_map = {"": "cpu"}
|
151 |
+
torch_dtype = torch.float32
|
152 |
+
logger.info("Using CPU with float32 precision")
|
153 |
+
|
154 |
+
model = AutoModelForCausalLM.from_pretrained(
|
155 |
+
MODEL_ID,
|
156 |
+
torch_dtype=torch_dtype,
|
157 |
+
device_map=device_map,
|
158 |
+
trust_remote_code=True,
|
159 |
+
quantization_config=quantization_config
|
160 |
+
)
|
161 |
+
model.eval()
|
162 |
+
model_load_time = time.time() - model_start
|
163 |
+
logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
|
164 |
+
|
165 |
+
# Log model info
|
166 |
+
model_info = {
|
167 |
+
"model_type": model.config.model_type,
|
168 |
+
"hidden_size": model.config.hidden_size,
|
169 |
+
"vocab_size": model.config.vocab_size,
|
170 |
+
"num_hidden_layers": model.config.num_hidden_layers
|
171 |
+
}
|
172 |
+
logger.info(f"Model properties: {model_info}")
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Failed to load model: {str(e)}")
|
176 |
+
logger.error(traceback.format_exc())
|
177 |
+
raise RuntimeError(f"Failed to load model: {str(e)}")
|
178 |
+
|
179 |
+
return model, tokenizer
|
180 |
|
181 |
###############################################################################
|
182 |
+
# Chat Formatting and Generation Functions
|
183 |
###############################################################################
|
184 |
+
def format_chat_for_model(messages, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
185 |
+
"""
|
186 |
+
Format chat messages for the model using the tokenizer's chat template if available,
|
187 |
+
or fall back to a manual format for Llama models.
|
188 |
+
"""
|
189 |
+
logger.info(f"Formatting chat with {len(messages)} messages")
|
190 |
+
|
191 |
+
# Prepare messages in the correct format
|
192 |
+
formatted_messages = []
|
193 |
+
|
194 |
+
# Add system message if not already present
|
195 |
+
if messages and messages[0].get("role") != "system":
|
196 |
+
formatted_messages.append({"role": "system", "content": system_prompt})
|
197 |
+
|
198 |
+
# Add user and assistant messages
|
199 |
for msg in messages:
|
200 |
+
role = msg["role"]
|
201 |
+
# Skip system messages if we already added one
|
202 |
+
if role == "system" and formatted_messages and formatted_messages[0]["role"] == "system":
|
203 |
+
continue
|
204 |
+
formatted_messages.append({"role": role, "content": msg["content"]})
|
205 |
+
|
206 |
+
# Use the tokenizer's built-in chat template if available
|
207 |
+
if hasattr(tokenizer, "apply_chat_template") and callable(tokenizer.apply_chat_template):
|
208 |
+
logger.info("Using tokenizer's built-in chat template")
|
209 |
+
try:
|
210 |
+
chat_text = tokenizer.apply_chat_template(
|
211 |
+
formatted_messages,
|
212 |
+
tokenize=False,
|
213 |
+
add_generation_prompt=True
|
214 |
+
)
|
215 |
+
logger.debug(f"Formatted chat using built-in template: {chat_text[:100]}...")
|
216 |
+
return chat_text
|
217 |
+
except Exception as e:
|
218 |
+
logger.warning(f"Failed to apply chat template: {str(e)}")
|
219 |
+
logger.warning("Falling back to manual formatting")
|
220 |
+
|
221 |
+
# Manual fallback format for Llama models
|
222 |
+
logger.info("Using manual chat formatting for Llama model")
|
223 |
+
chat_text = ""
|
224 |
+
for msg in formatted_messages:
|
225 |
+
role = msg["role"]
|
226 |
+
content = msg["content"]
|
227 |
+
if role == "system":
|
228 |
+
chat_text += f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>"
|
229 |
+
elif role == "user":
|
230 |
+
chat_text += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
|
231 |
+
elif role == "assistant":
|
232 |
+
chat_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
|
233 |
+
|
234 |
+
# Add the final assistant header for generation
|
235 |
+
chat_text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
236 |
+
|
237 |
+
logger.debug(f"Manually formatted chat: {chat_text[:100]}...")
|
238 |
+
return chat_text
|
239 |
|
240 |
+
def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
241 |
+
"""Generate a response from the model with retry logic and error handling"""
|
242 |
+
logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
|
243 |
+
|
244 |
+
# Format the messages for the model
|
245 |
+
prompt = format_chat_for_model(messages, tokenizer, system_prompt)
|
246 |
+
|
247 |
+
# Configure generation parameters
|
248 |
gen_config = GenerationConfig(
|
249 |
temperature=temperature,
|
250 |
top_p=top_p,
|
|
|
252 |
repetition_penalty=1.1,
|
253 |
max_new_tokens=max_new_tokens,
|
254 |
)
|
255 |
+
|
256 |
+
# Tokenize the input
|
257 |
+
try:
|
258 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
259 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
260 |
+
logger.info(f"Input tokenized to {inputs['input_ids'].shape[1]} tokens")
|
261 |
+
except Exception as e:
|
262 |
+
logger.error(f"Error during tokenization: {str(e)}")
|
263 |
+
return "I encountered an error while processing your message. Please try again."
|
264 |
+
|
265 |
+
# Generate with retry logic
|
266 |
+
max_retries = 3
|
267 |
+
retry_count = 0
|
268 |
+
|
269 |
+
while retry_count < max_retries:
|
270 |
+
try:
|
271 |
+
# Run the generation
|
272 |
+
generation_start = time.time()
|
273 |
+
with torch.no_grad():
|
274 |
+
output_ids = model.generate(
|
275 |
+
**inputs,
|
276 |
+
generation_config=gen_config,
|
277 |
+
)
|
278 |
+
generation_time = time.time() - generation_start
|
279 |
+
logger.info(f"Generation completed in {generation_time:.2f} seconds")
|
280 |
+
|
281 |
+
# Decode the output
|
282 |
+
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
283 |
+
|
284 |
+
# Extract just the assistant's response
|
285 |
+
assistant_response = ""
|
286 |
+
if hasattr(tokenizer, "apply_chat_template") and callable(tokenizer.apply_chat_template):
|
287 |
+
# Extract assistant's response from the full output
|
288 |
+
if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
|
289 |
+
parts = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")
|
290 |
+
if len(parts) > 1:
|
291 |
+
assistant_part = parts[-1]
|
292 |
+
if "<|eot_id|>" in assistant_part:
|
293 |
+
assistant_response = assistant_part.split("<|eot_id|>")[0].strip()
|
294 |
+
else:
|
295 |
+
assistant_response = assistant_part.strip()
|
296 |
+
else:
|
297 |
+
# Fall back to removing the prompt
|
298 |
+
assistant_response = generated_text[len(prompt):].strip()
|
299 |
+
else:
|
300 |
+
# Simple extraction method
|
301 |
+
assistant_response = generated_text[len(prompt):].strip()
|
302 |
+
|
303 |
+
logger.info(f"Response extracted, length: {len(assistant_response)} chars")
|
304 |
+
|
305 |
+
# Free up memory
|
306 |
+
del inputs, output_ids
|
307 |
+
optimize_memory()
|
308 |
+
|
309 |
+
return assistant_response
|
310 |
+
|
311 |
+
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
312 |
+
retry_count += 1
|
313 |
+
logger.warning(f"Generation attempt {retry_count} failed: {str(e)}")
|
314 |
+
|
315 |
+
if retry_count < max_retries:
|
316 |
+
logger.info(f"Retrying with reduced parameters...")
|
317 |
+
# Reduce parameters to try to fit in memory
|
318 |
+
max_new_tokens = max(64, max_new_tokens // 2)
|
319 |
+
optimize_memory()
|
320 |
+
else:
|
321 |
+
logger.error(f"Failed to generate after {max_retries} attempts")
|
322 |
+
return "I'm sorry, I encountered a resource limitation while generating a response. Please try a shorter message or adjust the generation parameters."
|
323 |
+
|
324 |
+
except Exception as e:
|
325 |
+
logger.error(f"Unexpected error during generation: {str(e)}")
|
326 |
+
logger.error(traceback.format_exc())
|
327 |
+
return "I encountered an unexpected error. Please try again with different parameters."
|
328 |
|
329 |
###############################################################################
|
330 |
+
# Gradio Interface
|
331 |
###############################################################################
|
332 |
+
def build_gradio_interface(model, tokenizer):
|
333 |
+
"""Build and launch the Gradio interface"""
|
334 |
+
logger.info("Building Gradio interface")
|
335 |
+
|
336 |
+
def user_submit(message_history, user_text, temp, top_p, max_tokens, system_message):
|
337 |
+
"""Handle user message submission"""
|
338 |
+
logger.info(f"Received user message: '{user_text[:50]}...' (length: {len(user_text)})")
|
339 |
+
|
340 |
+
if not user_text.strip():
|
341 |
+
logger.warning("Empty user message, skipping processing")
|
342 |
+
return message_history, ""
|
343 |
+
|
344 |
+
try:
|
345 |
+
# Add user message to history
|
346 |
+
if not message_history:
|
347 |
+
# Start with system message if this is the first message
|
348 |
+
message_history = [{"role": "system", "content": system_message}]
|
349 |
+
|
350 |
+
message_history.append({"role": "user", "content": user_text})
|
351 |
+
|
352 |
+
# Generate response
|
353 |
+
assistant_response = generate_response(
|
354 |
+
model,
|
355 |
+
tokenizer,
|
356 |
+
message_history,
|
357 |
+
temperature=temp,
|
358 |
+
top_p=top_p,
|
359 |
+
max_new_tokens=max_tokens,
|
360 |
+
system_prompt=system_message
|
361 |
)
|
362 |
+
|
363 |
+
# Add assistant response to history
|
364 |
+
message_history.append({"role": "assistant", "content": assistant_response})
|
365 |
+
logger.info(f"Added assistant response (length: {len(assistant_response)})")
|
366 |
+
|
367 |
+
# Optimize memory after generation
|
368 |
+
optimize_memory()
|
369 |
+
|
370 |
+
return message_history, ""
|
371 |
+
|
372 |
+
except Exception as e:
|
373 |
+
logger.error(f"Error in user_submit: {str(e)}")
|
374 |
+
logger.error(traceback.format_exc())
|
375 |
+
|
376 |
+
# Return original message history plus error message
|
377 |
+
error_msg = "I encountered an error processing your request. Please try again."
|
378 |
+
if not message_history:
|
379 |
+
message_history = []
|
380 |
+
message_history.append({"role": "user", "content": user_text})
|
381 |
+
message_history.append({"role": "assistant", "content": error_msg})
|
382 |
+
|
383 |
+
return message_history, ""
|
384 |
+
|
385 |
+
def clear_chat():
|
386 |
+
"""Clear the chat history"""
|
387 |
+
logger.info("Clearing chat history")
|
388 |
+
optimize_memory()
|
389 |
+
return [], ""
|
390 |
+
|
391 |
+
# Define the Gradio interface
|
392 |
+
with gr.Blocks(css="footer {visibility: hidden}") as demo:
|
393 |
+
gr.Markdown("<h1 align='center'>DamageScan 8B Instruct Chatbot</h1>")
|
394 |
+
gr.Markdown("<p align='center'>Powered by FrameRateTech/DamageScan-llama-8b-instruct-merged</p>")
|
395 |
+
|
396 |
+
with gr.Row():
|
397 |
+
with gr.Column(scale=3):
|
398 |
+
chatbot = gr.Chatbot(
|
399 |
+
label="Chat History",
|
400 |
+
height=600,
|
401 |
+
avatar_images=(None, "https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/avatar.png"),
|
402 |
+
)
|
403 |
+
|
404 |
+
with gr.Row():
|
405 |
+
with gr.Column(scale=8):
|
406 |
+
user_input = gr.Textbox(
|
407 |
+
lines=3,
|
408 |
+
label="Your Message",
|
409 |
+
placeholder="Type your message here...",
|
410 |
+
show_copy_button=True
|
411 |
+
)
|
412 |
+
with gr.Column(scale=1, min_width=50):
|
413 |
+
submit_btn = gr.Button("Send", variant="primary")
|
414 |
+
clear_btn = gr.Button("Clear Chat")
|
415 |
+
|
416 |
+
with gr.Column(scale=1):
|
417 |
+
gr.Markdown("### System Prompt")
|
418 |
+
system_prompt_input = gr.Textbox(
|
419 |
+
lines=5,
|
420 |
+
label="System Instructions",
|
421 |
+
value=DEFAULT_SYSTEM_PROMPT,
|
422 |
+
show_copy_button=True
|
423 |
+
)
|
424 |
+
|
425 |
+
gr.Markdown("### Generation Settings")
|
426 |
+
temperature_slider = gr.Slider(
|
427 |
+
minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature",
|
428 |
+
info="Higher values make output more random, lower values more deterministic"
|
429 |
+
)
|
430 |
+
top_p_slider = gr.Slider(
|
431 |
+
minimum=0.5, maximum=1.0, value=0.9, step=0.05, label="Top-p",
|
432 |
+
info="Controls diversity via nucleus sampling"
|
433 |
+
)
|
434 |
+
max_tokens_slider = gr.Slider(
|
435 |
+
minimum=64, maximum=1024, value=256, step=64, label="Max New Tokens",
|
436 |
+
info="Maximum length of generated response"
|
437 |
+
)
|
438 |
+
|
439 |
+
gr.Markdown("### Tips")
|
440 |
+
gr.Markdown("""
|
441 |
+
* Lower temperature (0.1-0.3) for factual responses
|
442 |
+
* Higher temperature (0.7-1.0) for creative tasks
|
443 |
+
* Reduce max tokens if responses are too long
|
444 |
+
* Clear chat if the model gets confused
|
445 |
+
""")
|
446 |
|
447 |
+
# Set up event handlers
|
448 |
+
submit_btn.click(
|
449 |
+
user_submit,
|
450 |
+
inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
|
451 |
+
outputs=[chatbot, user_input],
|
452 |
+
)
|
453 |
+
user_input.submit(
|
454 |
+
user_submit,
|
455 |
+
inputs=[chatbot, user_input, temperature_slider, top_p_slider, max_tokens_slider, system_prompt_input],
|
456 |
+
outputs=[chatbot, user_input],
|
457 |
+
)
|
458 |
+
clear_btn.click(
|
459 |
+
clear_chat,
|
460 |
+
outputs=[chatbot, user_input]
|
461 |
+
)
|
462 |
+
|
463 |
+
# Add example prompts
|
464 |
+
gr.Examples(
|
465 |
+
examples=[
|
466 |
+
["Can you explain how the Large Hadron Collider works?"],
|
467 |
+
["Write a short story about a robot who learns to paint"],
|
468 |
+
["What are three ways to improve productivity when working from home?"],
|
469 |
+
["Explain quantum computing to me like I'm 10 years old"],
|
470 |
+
],
|
471 |
+
inputs=user_input,
|
472 |
+
label="Example Prompts"
|
473 |
+
)
|
474 |
+
|
475 |
+
return demo
|
476 |
|
477 |
+
###############################################################################
|
478 |
+
# Main Application Logic
|
479 |
+
###############################################################################
|
480 |
+
def main():
|
481 |
+
"""Main application entry point"""
|
482 |
+
try:
|
483 |
+
logger.info("Starting DamageScan 8B Instruct application")
|
484 |
+
logger.info(f"Environment: CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
|
485 |
+
|
486 |
+
# Load model and tokenizer
|
487 |
+
model, tokenizer = load_model_and_tokenizer()
|
488 |
+
|
489 |
+
# Build and launch Gradio interface
|
490 |
+
demo = build_gradio_interface(model, tokenizer)
|
491 |
+
|
492 |
+
# Launch the app
|
493 |
+
logger.info("Launching Gradio interface")
|
494 |
+
demo.queue().launch(
|
495 |
+
share=False,
|
496 |
+
debug=False,
|
497 |
+
show_error=True,
|
498 |
+
favicon_path="https://huggingface.co/spaces/FrameRateTech/DamageScan-8b-instruct-chat/resolve/main/favicon.ico"
|
499 |
+
)
|
500 |
+
|
501 |
+
except Exception as e:
|
502 |
+
logger.error(f"Application startup failed: {str(e)}")
|
503 |
+
logger.error(traceback.format_exc())
|
504 |
+
|
505 |
+
# Create a minimal fallback UI to show the error
|
506 |
+
with gr.Blocks() as fallback_demo:
|
507 |
+
gr.Markdown("# ⚠️ DamageScan 8B Application Error")
|
508 |
+
gr.Markdown(f"The application encountered an error during startup:\n\n```\n{str(e)}\n```")
|
509 |
+
gr.Markdown("Please check the logs for more details or try again later.")
|
510 |
+
|
511 |
+
fallback_demo.launch()
|
512 |
|
513 |
+
if __name__ == "__main__":
|
514 |
+
main()
|