Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# app.py
|
2 |
|
3 |
import os
|
4 |
import gc
|
@@ -9,11 +9,9 @@ import transformers
|
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
from transformers import (
|
12 |
-
|
13 |
AutoModelForCausalLM,
|
14 |
-
GenerationConfig
|
15 |
-
BitsAndBytesConfig,
|
16 |
-
LlamaTokenizer # Added direct import for LlamaTokenizer
|
17 |
)
|
18 |
|
19 |
###############################################################################
|
@@ -36,26 +34,16 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.
|
|
36 |
|
37 |
If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
###############################################################################
|
40 |
-
#
|
41 |
###############################################################################
|
42 |
-
def get_device_info():
|
43 |
-
"""Log information about available devices and memory"""
|
44 |
-
device_info = {
|
45 |
-
"cuda_available": torch.cuda.is_available(),
|
46 |
-
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
47 |
-
"mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
48 |
-
}
|
49 |
-
|
50 |
-
if device_info["cuda_available"] and device_info["device_count"] > 0:
|
51 |
-
device_info["cuda_device_name"] = torch.cuda.get_device_name(0)
|
52 |
-
device_info["cuda_device_mem_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
|
53 |
-
device_info["cuda_device_mem_reserved"] = torch.cuda.memory_reserved(0) / (1024**3)
|
54 |
-
device_info["cuda_device_mem_allocated"] = torch.cuda.memory_allocated(0) / (1024**3)
|
55 |
-
|
56 |
-
logger.info(f"Device information: {device_info}")
|
57 |
-
return device_info
|
58 |
-
|
59 |
def optimize_memory():
|
60 |
"""Optimize memory usage by clearing caches and forcing garbage collection"""
|
61 |
if torch.cuda.is_available():
|
@@ -64,106 +52,78 @@ def optimize_memory():
|
|
64 |
logger.info("Memory optimized: caches cleared and garbage collected")
|
65 |
|
66 |
###############################################################################
|
67 |
-
#
|
68 |
###############################################################################
|
69 |
-
|
70 |
-
"""
|
71 |
-
logger.info(f"Loading model: {MODEL_ID}")
|
72 |
-
logger.info(f"Transformers version: {transformers.__version__}")
|
73 |
-
logger.info(f"PyTorch version: {torch.__version__}")
|
74 |
-
|
75 |
-
device_info = get_device_info()
|
76 |
-
|
77 |
-
# Determine quantization settings based on available hardware
|
78 |
-
load_in_4bit = False
|
79 |
-
load_in_8bit = False
|
80 |
-
|
81 |
-
if device_info["cuda_available"]:
|
82 |
-
# On ZEROGPU environments, 4-bit quantization helps fit the model in memory
|
83 |
-
load_in_4bit = True
|
84 |
-
logger.info("Using 4-bit quantization for CUDA device")
|
85 |
-
|
86 |
-
# Configure quantization if needed
|
87 |
-
if load_in_4bit:
|
88 |
-
quantization_config = BitsAndBytesConfig(
|
89 |
-
load_in_4bit=True,
|
90 |
-
bnb_4bit_compute_dtype=torch.float16,
|
91 |
-
bnb_4bit_quant_type="nf4",
|
92 |
-
bnb_4bit_use_double_quant=True
|
93 |
-
)
|
94 |
-
logger.info("Configured 4-bit quantization with NF4 type")
|
95 |
-
elif load_in_8bit:
|
96 |
-
quantization_config = BitsAndBytesConfig(
|
97 |
-
load_in_8bit=True
|
98 |
-
)
|
99 |
-
logger.info("Configured 8-bit quantization")
|
100 |
-
else:
|
101 |
-
quantization_config = None
|
102 |
-
logger.info("No quantization configured, using default precision")
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
108 |
|
109 |
-
#
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
use_fast=False,
|
115 |
-
trust_remote_code=True
|
116 |
-
)
|
117 |
-
logger.info("Successfully loaded tokenizer as LlamaTokenizer")
|
118 |
-
except Exception as e:
|
119 |
-
logger.warning(f"Failed to load as LlamaTokenizer: {str(e)}")
|
120 |
-
logger.info("Falling back to AutoTokenizer...")
|
121 |
-
|
122 |
-
# Try with AutoTokenizer but with strict error checking
|
123 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
124 |
-
MODEL_ID,
|
125 |
-
use_fast=False,
|
126 |
-
trust_remote_code=True
|
127 |
-
)
|
128 |
|
129 |
-
#
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
tokenizer = LlamaTokenizer.from_pretrained(
|
136 |
-
"meta-llama/Llama-3.1-8B-Instruct", # Use base model as fallback
|
137 |
-
use_fast=False
|
138 |
-
)
|
139 |
-
logger.info("Created fallback tokenizer from base model")
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
144 |
|
145 |
-
#
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
"
|
154 |
-
"model_max_length": tokenizer.model_max_length if hasattr(tokenizer, "model_max_length") else "unknown",
|
155 |
-
"bos_token": tokenizer.bos_token if hasattr(tokenizer, "bos_token") else "unknown",
|
156 |
-
"eos_token": tokenizer.eos_token if hasattr(tokenizer, "eos_token") else "unknown",
|
157 |
-
"has_chat_template": hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
|
158 |
}
|
159 |
-
logger.info(f"Tokenizer properties: {tokenizer_info}")
|
160 |
-
except Exception as e:
|
161 |
-
logger.warning(f"Could not log all tokenizer properties: {str(e)}")
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
# Step 2: Load model with detailed error logging
|
169 |
try:
|
@@ -184,25 +144,28 @@ def load_model_and_tokenizer():
|
|
184 |
torch_dtype = torch.float32
|
185 |
logger.info("Using CPU with float32 precision")
|
186 |
|
|
|
187 |
model = AutoModelForCausalLM.from_pretrained(
|
188 |
MODEL_ID,
|
189 |
torch_dtype=torch_dtype,
|
190 |
device_map=device_map,
|
191 |
trust_remote_code=True,
|
192 |
-
quantization_config=quantization_config
|
193 |
)
|
194 |
model.eval()
|
195 |
model_load_time = time.time() - model_start
|
196 |
logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
|
197 |
|
198 |
# Log model info
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
|
|
|
|
|
|
206 |
|
207 |
except Exception as e:
|
208 |
logger.error(f"Failed to load model: {str(e)}")
|
@@ -214,66 +177,33 @@ def load_model_and_tokenizer():
|
|
214 |
###############################################################################
|
215 |
# Chat Formatting and Generation Functions
|
216 |
###############################################################################
|
217 |
-
def format_chat_for_model(messages,
|
218 |
-
"""
|
219 |
-
Format chat messages for the model using the tokenizer's chat template if available,
|
220 |
-
or fall back to a manual format for Llama models.
|
221 |
-
"""
|
222 |
logger.info(f"Formatting chat with {len(messages)} messages")
|
223 |
|
224 |
-
#
|
225 |
-
|
226 |
|
227 |
# Add system message if not already present
|
228 |
-
if messages
|
229 |
-
|
230 |
|
231 |
-
# Add
|
232 |
for msg in messages:
|
233 |
-
role = msg["role"]
|
234 |
-
# Skip system messages if we already added one
|
235 |
-
if role == "system" and formatted_messages and formatted_messages[0]["role"] == "system":
|
236 |
-
continue
|
237 |
-
formatted_messages.append({"role": role, "content": msg["content"]})
|
238 |
-
|
239 |
-
# Try different approaches to format the chat
|
240 |
-
|
241 |
-
# Approach 1: Use the tokenizer's built-in chat template if available
|
242 |
-
if hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template")):
|
243 |
-
logger.info("Using tokenizer's built-in chat template")
|
244 |
-
try:
|
245 |
-
chat_text = tokenizer.apply_chat_template(
|
246 |
-
formatted_messages,
|
247 |
-
tokenize=False,
|
248 |
-
add_generation_prompt=True
|
249 |
-
)
|
250 |
-
logger.debug(f"Formatted chat using built-in template: {chat_text[:100]}...")
|
251 |
-
return chat_text
|
252 |
-
except Exception as e:
|
253 |
-
logger.warning(f"Failed to apply chat template: {str(e)}")
|
254 |
-
logger.warning("Falling back to manual formatting")
|
255 |
-
|
256 |
-
# Approach 2: Use a Llama 3.1 specific prompt format based on the config files we've seen
|
257 |
-
# This is based on the special tokens in the model's configuration
|
258 |
-
logger.info("Using manual chat formatting for Llama model")
|
259 |
-
|
260 |
-
chat_text = "<|begin_of_text|>"
|
261 |
-
|
262 |
-
for msg in formatted_messages:
|
263 |
role = msg["role"]
|
264 |
content = msg["content"]
|
265 |
|
266 |
if role == "system":
|
267 |
-
chat_text +=
|
268 |
elif role == "user":
|
269 |
-
chat_text +=
|
270 |
elif role == "assistant":
|
271 |
-
chat_text +=
|
272 |
|
273 |
-
# Add
|
274 |
-
chat_text +=
|
275 |
|
276 |
-
logger.
|
277 |
return chat_text
|
278 |
|
279 |
def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
@@ -281,7 +211,7 @@ def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, ma
|
|
281 |
logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
|
282 |
|
283 |
# Format the messages for the model
|
284 |
-
prompt = format_chat_for_model(messages,
|
285 |
|
286 |
# Configure generation parameters
|
287 |
gen_config = GenerationConfig(
|
@@ -290,71 +220,75 @@ def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, ma
|
|
290 |
do_sample=True,
|
291 |
repetition_penalty=1.1,
|
292 |
max_new_tokens=max_new_tokens,
|
|
|
|
|
|
|
293 |
)
|
294 |
|
295 |
-
# Tokenize the input
|
296 |
-
try:
|
297 |
-
inputs = tokenizer(prompt, return_tensors="pt")
|
298 |
-
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
299 |
-
logger.info(f"Input tokenized to {inputs['input_ids'].shape[1]} tokens")
|
300 |
-
except Exception as e:
|
301 |
-
logger.error(f"Error during tokenization: {str(e)}")
|
302 |
-
return "I encountered an error while processing your message. Please try again."
|
303 |
-
|
304 |
# Generate with retry logic
|
305 |
max_retries = 3
|
306 |
retry_count = 0
|
307 |
|
308 |
while retry_count < max_retries:
|
309 |
try:
|
|
|
|
|
|
|
|
|
|
|
310 |
# Run the generation
|
311 |
generation_start = time.time()
|
312 |
with torch.no_grad():
|
313 |
-
|
314 |
**inputs,
|
315 |
generation_config=gen_config,
|
316 |
)
|
317 |
generation_time = time.time() - generation_start
|
318 |
logger.info(f"Generation completed in {generation_time:.2f} seconds")
|
319 |
|
320 |
-
#
|
321 |
-
|
|
|
|
|
|
|
322 |
|
323 |
-
# Extract
|
324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
|
326 |
-
#
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
parts = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")
|
331 |
-
if len(parts) > 1:
|
332 |
-
assistant_part = parts[-1]
|
333 |
-
if "<|eot_id|>" in assistant_part:
|
334 |
-
assistant_response = assistant_part.split("<|eot_id|>")[0].strip()
|
335 |
-
else:
|
336 |
-
assistant_response = assistant_part.strip()
|
337 |
-
# Method 2: Simple extraction based on prompt length
|
338 |
-
else:
|
339 |
-
# This is a fallback - not as accurate but should work in most cases
|
340 |
-
assistant_response = generated_text[len(prompt):].strip()
|
341 |
|
342 |
-
#
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
|
|
|
|
348 |
|
349 |
-
|
350 |
-
if not assistant_response.strip():
|
351 |
-
logger.warning("Empty response detected, using fallback message")
|
352 |
-
assistant_response = "I'm sorry, I couldn't generate a proper response. Please try again with a different question or adjust the generation parameters."
|
353 |
|
354 |
# Free up memory
|
355 |
-
del inputs,
|
356 |
optimize_memory()
|
357 |
|
|
|
|
|
|
|
|
|
358 |
return assistant_response
|
359 |
|
360 |
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
@@ -523,23 +457,6 @@ def build_gradio_interface(model, tokenizer):
|
|
523 |
|
524 |
return demo
|
525 |
|
526 |
-
###############################################################################
|
527 |
-
# Simple messaging for testing tokenizer
|
528 |
-
###############################################################################
|
529 |
-
def test_tokenize_function(tokenizer):
|
530 |
-
"""Test function to ensure tokenizer works with a simple input"""
|
531 |
-
try:
|
532 |
-
logger.info("Testing tokenizer with a simple input")
|
533 |
-
test_input = "Hello, how are you today?"
|
534 |
-
encoded = tokenizer(test_input, return_tensors="pt")
|
535 |
-
logger.info(f"Tokenizer test successful: encoded to {encoded['input_ids'].shape[1]} tokens")
|
536 |
-
decoded = tokenizer.decode(encoded["input_ids"][0])
|
537 |
-
logger.info(f"Decoded test: '{decoded}'")
|
538 |
-
return True
|
539 |
-
except Exception as e:
|
540 |
-
logger.error(f"Tokenizer test failed: {str(e)}")
|
541 |
-
return False
|
542 |
-
|
543 |
###############################################################################
|
544 |
# Main Application Logic
|
545 |
###############################################################################
|
@@ -552,11 +469,6 @@ def main():
|
|
552 |
# Load model and tokenizer
|
553 |
model, tokenizer = load_model_and_tokenizer()
|
554 |
|
555 |
-
# Test tokenizer functionality
|
556 |
-
test_result = test_tokenize_function(tokenizer)
|
557 |
-
if not test_result:
|
558 |
-
logger.warning("Tokenizer test failed, but continuing with caution")
|
559 |
-
|
560 |
# Build and launch Gradio interface
|
561 |
demo = build_gradio_interface(model, tokenizer)
|
562 |
|
|
|
1 |
+
# app.py - Minimal Version
|
2 |
|
3 |
import os
|
4 |
import gc
|
|
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
from transformers import (
|
12 |
+
PreTrainedTokenizerFast,
|
13 |
AutoModelForCausalLM,
|
14 |
+
GenerationConfig
|
|
|
|
|
15 |
)
|
16 |
|
17 |
###############################################################################
|
|
|
34 |
|
35 |
If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
|
36 |
|
37 |
+
# The special tokens we observed in the model's configuration
|
38 |
+
BOS_TOKEN = "<|begin_of_text|>"
|
39 |
+
EOS_TOKEN = "<|eot_id|>"
|
40 |
+
SYSTEM_START = "<|start_header_id|>system<|end_header_id|>\n\n"
|
41 |
+
USER_START = "<|start_header_id|>user<|end_header_id|>\n\n"
|
42 |
+
ASSISTANT_START = "<|start_header_id|>assistant<|end_header_id|>\n\n"
|
43 |
+
|
44 |
###############################################################################
|
45 |
+
# Memory Management
|
46 |
###############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
def optimize_memory():
|
48 |
"""Optimize memory usage by clearing caches and forcing garbage collection"""
|
49 |
if torch.cuda.is_available():
|
|
|
52 |
logger.info("Memory optimized: caches cleared and garbage collected")
|
53 |
|
54 |
###############################################################################
|
55 |
+
# Custom Tokenizer Class
|
56 |
###############################################################################
|
57 |
+
class MinimalTokenizer:
|
58 |
+
"""A minimal tokenizer implementation that works with basic model I/O"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
def __init__(self):
|
61 |
+
logger.info("Initializing MinimalTokenizer")
|
62 |
+
# Use a basic set of special tokens based on the model config
|
63 |
+
self.bos_token = BOS_TOKEN
|
64 |
+
self.eos_token = EOS_TOKEN
|
65 |
+
self.pad_token = EOS_TOKEN
|
66 |
|
67 |
+
# Map tokens to ids (using values from the model config)
|
68 |
+
self.token_to_id = {
|
69 |
+
BOS_TOKEN: 128000, # Based on config.json
|
70 |
+
EOS_TOKEN: 128009, # Based on config.json
|
71 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
+
# For logging
|
74 |
+
logger.info(f"MinimalTokenizer initialized with special tokens: {self.token_to_id}")
|
75 |
+
|
76 |
+
def __call__(self, text, return_tensors=None):
|
77 |
+
"""Tokenize text using the model directly"""
|
78 |
+
logger.info(f"Tokenizing text (length: {len(text)})")
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
+
# Create inputs for the model - we'll let the model tokenize internally
|
81 |
+
inputs = {
|
82 |
+
"text": text,
|
83 |
+
}
|
84 |
|
85 |
+
# If return_tensors is specified, create a dummy tensor
|
86 |
+
# The model will handle tokenization internally
|
87 |
+
if return_tensors == "pt":
|
88 |
+
# Create a dummy input_ids tensor with the BOS token
|
89 |
+
# The actual tokenization will happen inside the model
|
90 |
+
dummy_input_ids = torch.tensor([[self.token_to_id[self.bos_token]]])
|
91 |
+
inputs = {
|
92 |
+
"input_ids": dummy_input_ids,
|
93 |
+
"_text": text, # Store the text for the model to use
|
|
|
|
|
|
|
|
|
94 |
}
|
|
|
|
|
|
|
95 |
|
96 |
+
return inputs
|
97 |
+
|
98 |
+
def decode(self, token_ids, skip_special_tokens=True):
|
99 |
+
"""Dummy decode function - the model will handle decoding"""
|
100 |
+
# This is just a placeholder - the model will decode internally
|
101 |
+
# For logging purposes
|
102 |
+
logger.info(f"Decoding token_ids (shape: {token_ids.shape if hasattr(token_ids, 'shape') else 'N/A'})")
|
103 |
+
|
104 |
+
# We'll get the raw output from the model and handle it specially
|
105 |
+
# in the generation function
|
106 |
+
return ""
|
107 |
+
|
108 |
+
###############################################################################
|
109 |
+
# Model Loading with Error Handling
|
110 |
+
###############################################################################
|
111 |
+
def load_model_and_tokenizer():
|
112 |
+
"""Load the model with comprehensive error handling and logging"""
|
113 |
+
logger.info(f"Loading model: {MODEL_ID}")
|
114 |
+
logger.info(f"Transformers version: {transformers.__version__}")
|
115 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
116 |
+
|
117 |
+
# Check available devices
|
118 |
+
device_info = {
|
119 |
+
"cuda_available": torch.cuda.is_available(),
|
120 |
+
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
|
121 |
+
"mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
122 |
+
}
|
123 |
+
logger.info(f"Device information: {device_info}")
|
124 |
+
|
125 |
+
# Create minimal tokenizer
|
126 |
+
tokenizer = MinimalTokenizer()
|
127 |
|
128 |
# Step 2: Load model with detailed error logging
|
129 |
try:
|
|
|
144 |
torch_dtype = torch.float32
|
145 |
logger.info("Using CPU with float32 precision")
|
146 |
|
147 |
+
# Load the model
|
148 |
model = AutoModelForCausalLM.from_pretrained(
|
149 |
MODEL_ID,
|
150 |
torch_dtype=torch_dtype,
|
151 |
device_map=device_map,
|
152 |
trust_remote_code=True,
|
|
|
153 |
)
|
154 |
model.eval()
|
155 |
model_load_time = time.time() - model_start
|
156 |
logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
|
157 |
|
158 |
# Log model info
|
159 |
+
try:
|
160 |
+
model_info = {
|
161 |
+
"model_type": model.config.model_type,
|
162 |
+
"hidden_size": model.config.hidden_size,
|
163 |
+
"vocab_size": model.config.vocab_size,
|
164 |
+
"num_hidden_layers": model.config.num_hidden_layers
|
165 |
+
}
|
166 |
+
logger.info(f"Model properties: {model_info}")
|
167 |
+
except Exception as e:
|
168 |
+
logger.warning(f"Could not log all model properties: {str(e)}")
|
169 |
|
170 |
except Exception as e:
|
171 |
logger.error(f"Failed to load model: {str(e)}")
|
|
|
177 |
###############################################################################
|
178 |
# Chat Formatting and Generation Functions
|
179 |
###############################################################################
|
180 |
+
def format_chat_for_model(messages, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
181 |
+
"""Format chat messages using the special tokens from model configuration"""
|
|
|
|
|
|
|
182 |
logger.info(f"Formatting chat with {len(messages)} messages")
|
183 |
|
184 |
+
# Start with BOS token
|
185 |
+
chat_text = BOS_TOKEN
|
186 |
|
187 |
# Add system message if not already present
|
188 |
+
if not messages or messages[0].get("role") != "system":
|
189 |
+
chat_text += SYSTEM_START + system_prompt + EOS_TOKEN
|
190 |
|
191 |
+
# Add all messages in the correct format
|
192 |
for msg in messages:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
role = msg["role"]
|
194 |
content = msg["content"]
|
195 |
|
196 |
if role == "system":
|
197 |
+
chat_text += SYSTEM_START + content + EOS_TOKEN
|
198 |
elif role == "user":
|
199 |
+
chat_text += USER_START + content + EOS_TOKEN
|
200 |
elif role == "assistant":
|
201 |
+
chat_text += ASSISTANT_START + content + EOS_TOKEN
|
202 |
|
203 |
+
# Add final assistant header for the model to continue
|
204 |
+
chat_text += ASSISTANT_START
|
205 |
|
206 |
+
logger.info(f"Formatted chat text (length: {len(chat_text)})")
|
207 |
return chat_text
|
208 |
|
209 |
def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
|
|
211 |
logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
|
212 |
|
213 |
# Format the messages for the model
|
214 |
+
prompt = format_chat_for_model(messages, system_prompt)
|
215 |
|
216 |
# Configure generation parameters
|
217 |
gen_config = GenerationConfig(
|
|
|
220 |
do_sample=True,
|
221 |
repetition_penalty=1.1,
|
222 |
max_new_tokens=max_new_tokens,
|
223 |
+
pad_token_id=tokenizer.token_to_id[tokenizer.pad_token],
|
224 |
+
bos_token_id=tokenizer.token_to_id[tokenizer.bos_token],
|
225 |
+
eos_token_id=tokenizer.token_to_id[tokenizer.eos_token],
|
226 |
)
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
# Generate with retry logic
|
229 |
max_retries = 3
|
230 |
retry_count = 0
|
231 |
|
232 |
while retry_count < max_retries:
|
233 |
try:
|
234 |
+
# Tokenize with dummy tensors - the model will handle the actual text
|
235 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
236 |
+
inputs["text"] = prompt # Store the actual text
|
237 |
+
inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
|
238 |
+
|
239 |
# Run the generation
|
240 |
generation_start = time.time()
|
241 |
with torch.no_grad():
|
242 |
+
outputs = model.generate(
|
243 |
**inputs,
|
244 |
generation_config=gen_config,
|
245 |
)
|
246 |
generation_time = time.time() - generation_start
|
247 |
logger.info(f"Generation completed in {generation_time:.2f} seconds")
|
248 |
|
249 |
+
# Extract just the assistant's response using string operations
|
250 |
+
# This is the key part - the model's output is processed as a string, not tokens
|
251 |
+
# Split on the last occurrence of our custom beginning of assistant text
|
252 |
+
# We trust the model to format the output correctly
|
253 |
+
full_text = prompt # Start with our prompt
|
254 |
|
255 |
+
# Extract actual new text from model's output
|
256 |
+
# The output might be unpredictable, so we need to be careful here
|
257 |
+
try:
|
258 |
+
# Try to get string representation of the output
|
259 |
+
output_text = "".join([chr(id) for id in outputs[0].tolist()])
|
260 |
+
# Remove initial prompt text to get just the model's generation
|
261 |
+
# Add this to the full text
|
262 |
+
full_text += output_text
|
263 |
+
except Exception as e:
|
264 |
+
logger.warning(f"Could not process model output as expected: {str(e)}")
|
265 |
+
# In case of failure, produce a simple response
|
266 |
+
full_text += "I apologize, but I'm having trouble generating a response."
|
267 |
|
268 |
+
# Extract just the final assistant's response
|
269 |
+
try:
|
270 |
+
parts = full_text.split(ASSISTANT_START)
|
271 |
+
assistant_part = parts[-1] # Get the last assistant part
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
+
# Remove any trailing EOS token
|
274 |
+
if EOS_TOKEN in assistant_part:
|
275 |
+
assistant_response = assistant_part.split(EOS_TOKEN)[0].strip()
|
276 |
+
else:
|
277 |
+
assistant_response = assistant_part.strip()
|
278 |
+
except Exception as e:
|
279 |
+
logger.warning(f"Error extracting assistant response: {str(e)}")
|
280 |
+
assistant_response = "I apologize, but I'm having trouble generating a proper response."
|
281 |
|
282 |
+
logger.info(f"Extracted assistant response (length: {len(assistant_response)})")
|
|
|
|
|
|
|
283 |
|
284 |
# Free up memory
|
285 |
+
del inputs, outputs
|
286 |
optimize_memory()
|
287 |
|
288 |
+
# Fallback if we get an empty response
|
289 |
+
if not assistant_response:
|
290 |
+
assistant_response = "I apologize, but I couldn't generate a response. Please try again."
|
291 |
+
|
292 |
return assistant_response
|
293 |
|
294 |
except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
|
|
|
457 |
|
458 |
return demo
|
459 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
###############################################################################
|
461 |
# Main Application Logic
|
462 |
###############################################################################
|
|
|
469 |
# Load model and tokenizer
|
470 |
model, tokenizer = load_model_and_tokenizer()
|
471 |
|
|
|
|
|
|
|
|
|
|
|
472 |
# Build and launch Gradio interface
|
473 |
demo = build_gradio_interface(model, tokenizer)
|
474 |
|