FrameRateTech commited on
Commit
e884311
·
verified ·
1 Parent(s): 5e75927

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -232
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py
2
 
3
  import os
4
  import gc
@@ -9,11 +9,9 @@ import transformers
9
  import torch
10
  import gradio as gr
11
  from transformers import (
12
- AutoTokenizer,
13
  AutoModelForCausalLM,
14
- GenerationConfig,
15
- BitsAndBytesConfig,
16
- LlamaTokenizer # Added direct import for LlamaTokenizer
17
  )
18
 
19
  ###############################################################################
@@ -36,26 +34,16 @@ DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.
36
 
37
  If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
38
 
 
 
 
 
 
 
 
39
  ###############################################################################
40
- # Device Configuration and Memory Management
41
  ###############################################################################
42
- def get_device_info():
43
- """Log information about available devices and memory"""
44
- device_info = {
45
- "cuda_available": torch.cuda.is_available(),
46
- "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
47
- "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
48
- }
49
-
50
- if device_info["cuda_available"] and device_info["device_count"] > 0:
51
- device_info["cuda_device_name"] = torch.cuda.get_device_name(0)
52
- device_info["cuda_device_mem_total"] = torch.cuda.get_device_properties(0).total_memory / (1024**3)
53
- device_info["cuda_device_mem_reserved"] = torch.cuda.memory_reserved(0) / (1024**3)
54
- device_info["cuda_device_mem_allocated"] = torch.cuda.memory_allocated(0) / (1024**3)
55
-
56
- logger.info(f"Device information: {device_info}")
57
- return device_info
58
-
59
  def optimize_memory():
60
  """Optimize memory usage by clearing caches and forcing garbage collection"""
61
  if torch.cuda.is_available():
@@ -64,106 +52,78 @@ def optimize_memory():
64
  logger.info("Memory optimized: caches cleared and garbage collected")
65
 
66
  ###############################################################################
67
- # Model Loading with Error Handling
68
  ###############################################################################
69
- def load_model_and_tokenizer():
70
- """Load the model and tokenizer with comprehensive error handling and logging"""
71
- logger.info(f"Loading model: {MODEL_ID}")
72
- logger.info(f"Transformers version: {transformers.__version__}")
73
- logger.info(f"PyTorch version: {torch.__version__}")
74
-
75
- device_info = get_device_info()
76
-
77
- # Determine quantization settings based on available hardware
78
- load_in_4bit = False
79
- load_in_8bit = False
80
-
81
- if device_info["cuda_available"]:
82
- # On ZEROGPU environments, 4-bit quantization helps fit the model in memory
83
- load_in_4bit = True
84
- logger.info("Using 4-bit quantization for CUDA device")
85
-
86
- # Configure quantization if needed
87
- if load_in_4bit:
88
- quantization_config = BitsAndBytesConfig(
89
- load_in_4bit=True,
90
- bnb_4bit_compute_dtype=torch.float16,
91
- bnb_4bit_quant_type="nf4",
92
- bnb_4bit_use_double_quant=True
93
- )
94
- logger.info("Configured 4-bit quantization with NF4 type")
95
- elif load_in_8bit:
96
- quantization_config = BitsAndBytesConfig(
97
- load_in_8bit=True
98
- )
99
- logger.info("Configured 8-bit quantization")
100
- else:
101
- quantization_config = None
102
- logger.info("No quantization configured, using default precision")
103
 
104
- # Step 1: Load tokenizer with direct class instantiation for Llama models
105
- try:
106
- logger.info("Loading tokenizer...")
107
- tokenizer_start = time.time()
 
 
108
 
109
- # First, try loading directly as a LlamaTokenizer instead of using AutoTokenizer
110
- try:
111
- logger.info("Attempting to load as LlamaTokenizer...")
112
- tokenizer = LlamaTokenizer.from_pretrained(
113
- MODEL_ID,
114
- use_fast=False,
115
- trust_remote_code=True
116
- )
117
- logger.info("Successfully loaded tokenizer as LlamaTokenizer")
118
- except Exception as e:
119
- logger.warning(f"Failed to load as LlamaTokenizer: {str(e)}")
120
- logger.info("Falling back to AutoTokenizer...")
121
-
122
- # Try with AutoTokenizer but with strict error checking
123
- tokenizer = AutoTokenizer.from_pretrained(
124
- MODEL_ID,
125
- use_fast=False,
126
- trust_remote_code=True
127
- )
128
 
129
- # Check if tokenizer is a valid object
130
- if tokenizer is None or isinstance(tokenizer, bool):
131
- logger.error(f"Tokenizer loaded as {type(tokenizer).__name__} (value: {tokenizer})")
132
- logger.info("Attempting to create a basic LlamaTokenizer...")
133
-
134
- # Last resort: Create a basic LlamaTokenizer with default config
135
- tokenizer = LlamaTokenizer.from_pretrained(
136
- "meta-llama/Llama-3.1-8B-Instruct", # Use base model as fallback
137
- use_fast=False
138
- )
139
- logger.info("Created fallback tokenizer from base model")
140
 
141
- tokenizer_load_time = time.time() - tokenizer_start
142
- logger.info(f"Tokenizer loaded successfully in {tokenizer_load_time:.2f} seconds")
143
- logger.info(f"Tokenizer type: {type(tokenizer).__name__}")
 
144
 
145
- # Set pad token if needed
146
- if getattr(tokenizer, "pad_token_id", None) is None:
147
- logger.info("Pad token not found, setting pad_token_id to eos_token_id")
148
- tokenizer.pad_token_id = getattr(tokenizer, "eos_token_id", None)
149
-
150
- # Log important tokenizer properties if possible
151
- try:
152
- tokenizer_info = {
153
- "vocab_size": len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else "unknown",
154
- "model_max_length": tokenizer.model_max_length if hasattr(tokenizer, "model_max_length") else "unknown",
155
- "bos_token": tokenizer.bos_token if hasattr(tokenizer, "bos_token") else "unknown",
156
- "eos_token": tokenizer.eos_token if hasattr(tokenizer, "eos_token") else "unknown",
157
- "has_chat_template": hasattr(tokenizer, "chat_template") and tokenizer.chat_template is not None
158
  }
159
- logger.info(f"Tokenizer properties: {tokenizer_info}")
160
- except Exception as e:
161
- logger.warning(f"Could not log all tokenizer properties: {str(e)}")
162
 
163
- except Exception as e:
164
- logger.error(f"Failed to load tokenizer: {str(e)}")
165
- logger.error(traceback.format_exc())
166
- raise RuntimeError(f"Failed to load tokenizer: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # Step 2: Load model with detailed error logging
169
  try:
@@ -184,25 +144,28 @@ def load_model_and_tokenizer():
184
  torch_dtype = torch.float32
185
  logger.info("Using CPU with float32 precision")
186
 
 
187
  model = AutoModelForCausalLM.from_pretrained(
188
  MODEL_ID,
189
  torch_dtype=torch_dtype,
190
  device_map=device_map,
191
  trust_remote_code=True,
192
- quantization_config=quantization_config
193
  )
194
  model.eval()
195
  model_load_time = time.time() - model_start
196
  logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
197
 
198
  # Log model info
199
- model_info = {
200
- "model_type": model.config.model_type,
201
- "hidden_size": model.config.hidden_size,
202
- "vocab_size": model.config.vocab_size,
203
- "num_hidden_layers": model.config.num_hidden_layers
204
- }
205
- logger.info(f"Model properties: {model_info}")
 
 
 
206
 
207
  except Exception as e:
208
  logger.error(f"Failed to load model: {str(e)}")
@@ -214,66 +177,33 @@ def load_model_and_tokenizer():
214
  ###############################################################################
215
  # Chat Formatting and Generation Functions
216
  ###############################################################################
217
- def format_chat_for_model(messages, tokenizer, system_prompt=DEFAULT_SYSTEM_PROMPT):
218
- """
219
- Format chat messages for the model using the tokenizer's chat template if available,
220
- or fall back to a manual format for Llama models.
221
- """
222
  logger.info(f"Formatting chat with {len(messages)} messages")
223
 
224
- # Prepare messages in the correct format
225
- formatted_messages = []
226
 
227
  # Add system message if not already present
228
- if messages and messages[0].get("role") != "system":
229
- formatted_messages.append({"role": "system", "content": system_prompt})
230
 
231
- # Add user and assistant messages
232
  for msg in messages:
233
- role = msg["role"]
234
- # Skip system messages if we already added one
235
- if role == "system" and formatted_messages and formatted_messages[0]["role"] == "system":
236
- continue
237
- formatted_messages.append({"role": role, "content": msg["content"]})
238
-
239
- # Try different approaches to format the chat
240
-
241
- # Approach 1: Use the tokenizer's built-in chat template if available
242
- if hasattr(tokenizer, "apply_chat_template") and callable(getattr(tokenizer, "apply_chat_template")):
243
- logger.info("Using tokenizer's built-in chat template")
244
- try:
245
- chat_text = tokenizer.apply_chat_template(
246
- formatted_messages,
247
- tokenize=False,
248
- add_generation_prompt=True
249
- )
250
- logger.debug(f"Formatted chat using built-in template: {chat_text[:100]}...")
251
- return chat_text
252
- except Exception as e:
253
- logger.warning(f"Failed to apply chat template: {str(e)}")
254
- logger.warning("Falling back to manual formatting")
255
-
256
- # Approach 2: Use a Llama 3.1 specific prompt format based on the config files we've seen
257
- # This is based on the special tokens in the model's configuration
258
- logger.info("Using manual chat formatting for Llama model")
259
-
260
- chat_text = "<|begin_of_text|>"
261
-
262
- for msg in formatted_messages:
263
  role = msg["role"]
264
  content = msg["content"]
265
 
266
  if role == "system":
267
- chat_text += f"<|start_header_id|>system<|end_header_id|>\n\n{content}<|eot_id|>"
268
  elif role == "user":
269
- chat_text += f"<|start_header_id|>user<|end_header_id|>\n\n{content}<|eot_id|>"
270
  elif role == "assistant":
271
- chat_text += f"<|start_header_id|>assistant<|end_header_id|>\n\n{content}<|eot_id|>"
272
 
273
- # Add the final assistant header for generation
274
- chat_text += "<|start_header_id|>assistant<|end_header_id|>\n\n"
275
 
276
- logger.debug(f"Manually formatted chat: {chat_text[:100]}...")
277
  return chat_text
278
 
279
  def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
@@ -281,7 +211,7 @@ def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, ma
281
  logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
282
 
283
  # Format the messages for the model
284
- prompt = format_chat_for_model(messages, tokenizer, system_prompt)
285
 
286
  # Configure generation parameters
287
  gen_config = GenerationConfig(
@@ -290,71 +220,75 @@ def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, ma
290
  do_sample=True,
291
  repetition_penalty=1.1,
292
  max_new_tokens=max_new_tokens,
 
 
 
293
  )
294
 
295
- # Tokenize the input
296
- try:
297
- inputs = tokenizer(prompt, return_tensors="pt")
298
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
299
- logger.info(f"Input tokenized to {inputs['input_ids'].shape[1]} tokens")
300
- except Exception as e:
301
- logger.error(f"Error during tokenization: {str(e)}")
302
- return "I encountered an error while processing your message. Please try again."
303
-
304
  # Generate with retry logic
305
  max_retries = 3
306
  retry_count = 0
307
 
308
  while retry_count < max_retries:
309
  try:
 
 
 
 
 
310
  # Run the generation
311
  generation_start = time.time()
312
  with torch.no_grad():
313
- output_ids = model.generate(
314
  **inputs,
315
  generation_config=gen_config,
316
  )
317
  generation_time = time.time() - generation_start
318
  logger.info(f"Generation completed in {generation_time:.2f} seconds")
319
 
320
- # Decode the output
321
- generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
322
 
323
- # Extract just the assistant's response
324
- assistant_response = ""
 
 
 
 
 
 
 
 
 
 
325
 
326
- # Try different extraction methods based on the model format
327
-
328
- # Method 1: Standard extraction for template-based output
329
- if "<|start_header_id|>assistant<|end_header_id|>" in generated_text:
330
- parts = generated_text.split("<|start_header_id|>assistant<|end_header_id|>")
331
- if len(parts) > 1:
332
- assistant_part = parts[-1]
333
- if "<|eot_id|>" in assistant_part:
334
- assistant_response = assistant_part.split("<|eot_id|>")[0].strip()
335
- else:
336
- assistant_response = assistant_part.strip()
337
- # Method 2: Simple extraction based on prompt length
338
- else:
339
- # This is a fallback - not as accurate but should work in most cases
340
- assistant_response = generated_text[len(prompt):].strip()
341
 
342
- # If the assistant response seems to have formatting tokens, clean them up
343
- for token in ["<|eot_id|>", "<|eom_id|>", "<|end_of_text|>"]:
344
- if token in assistant_response:
345
- assistant_response = assistant_response.split(token)[0].strip()
346
-
347
- logger.info(f"Response extracted, length: {len(assistant_response)} chars")
 
 
348
 
349
- # If we got an empty response, return a fallback message
350
- if not assistant_response.strip():
351
- logger.warning("Empty response detected, using fallback message")
352
- assistant_response = "I'm sorry, I couldn't generate a proper response. Please try again with a different question or adjust the generation parameters."
353
 
354
  # Free up memory
355
- del inputs, output_ids
356
  optimize_memory()
357
 
 
 
 
 
358
  return assistant_response
359
 
360
  except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
@@ -523,23 +457,6 @@ def build_gradio_interface(model, tokenizer):
523
 
524
  return demo
525
 
526
- ###############################################################################
527
- # Simple messaging for testing tokenizer
528
- ###############################################################################
529
- def test_tokenize_function(tokenizer):
530
- """Test function to ensure tokenizer works with a simple input"""
531
- try:
532
- logger.info("Testing tokenizer with a simple input")
533
- test_input = "Hello, how are you today?"
534
- encoded = tokenizer(test_input, return_tensors="pt")
535
- logger.info(f"Tokenizer test successful: encoded to {encoded['input_ids'].shape[1]} tokens")
536
- decoded = tokenizer.decode(encoded["input_ids"][0])
537
- logger.info(f"Decoded test: '{decoded}'")
538
- return True
539
- except Exception as e:
540
- logger.error(f"Tokenizer test failed: {str(e)}")
541
- return False
542
-
543
  ###############################################################################
544
  # Main Application Logic
545
  ###############################################################################
@@ -552,11 +469,6 @@ def main():
552
  # Load model and tokenizer
553
  model, tokenizer = load_model_and_tokenizer()
554
 
555
- # Test tokenizer functionality
556
- test_result = test_tokenize_function(tokenizer)
557
- if not test_result:
558
- logger.warning("Tokenizer test failed, but continuing with caution")
559
-
560
  # Build and launch Gradio interface
561
  demo = build_gradio_interface(model, tokenizer)
562
 
 
1
+ # app.py - Minimal Version
2
 
3
  import os
4
  import gc
 
9
  import torch
10
  import gradio as gr
11
  from transformers import (
12
+ PreTrainedTokenizerFast,
13
  AutoModelForCausalLM,
14
+ GenerationConfig
 
 
15
  )
16
 
17
  ###############################################################################
 
34
 
35
  If a question is not clear or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
36
 
37
+ # The special tokens we observed in the model's configuration
38
+ BOS_TOKEN = "<|begin_of_text|>"
39
+ EOS_TOKEN = "<|eot_id|>"
40
+ SYSTEM_START = "<|start_header_id|>system<|end_header_id|>\n\n"
41
+ USER_START = "<|start_header_id|>user<|end_header_id|>\n\n"
42
+ ASSISTANT_START = "<|start_header_id|>assistant<|end_header_id|>\n\n"
43
+
44
  ###############################################################################
45
+ # Memory Management
46
  ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def optimize_memory():
48
  """Optimize memory usage by clearing caches and forcing garbage collection"""
49
  if torch.cuda.is_available():
 
52
  logger.info("Memory optimized: caches cleared and garbage collected")
53
 
54
  ###############################################################################
55
+ # Custom Tokenizer Class
56
  ###############################################################################
57
+ class MinimalTokenizer:
58
+ """A minimal tokenizer implementation that works with basic model I/O"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ def __init__(self):
61
+ logger.info("Initializing MinimalTokenizer")
62
+ # Use a basic set of special tokens based on the model config
63
+ self.bos_token = BOS_TOKEN
64
+ self.eos_token = EOS_TOKEN
65
+ self.pad_token = EOS_TOKEN
66
 
67
+ # Map tokens to ids (using values from the model config)
68
+ self.token_to_id = {
69
+ BOS_TOKEN: 128000, # Based on config.json
70
+ EOS_TOKEN: 128009, # Based on config.json
71
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # For logging
74
+ logger.info(f"MinimalTokenizer initialized with special tokens: {self.token_to_id}")
75
+
76
+ def __call__(self, text, return_tensors=None):
77
+ """Tokenize text using the model directly"""
78
+ logger.info(f"Tokenizing text (length: {len(text)})")
 
 
 
 
 
79
 
80
+ # Create inputs for the model - we'll let the model tokenize internally
81
+ inputs = {
82
+ "text": text,
83
+ }
84
 
85
+ # If return_tensors is specified, create a dummy tensor
86
+ # The model will handle tokenization internally
87
+ if return_tensors == "pt":
88
+ # Create a dummy input_ids tensor with the BOS token
89
+ # The actual tokenization will happen inside the model
90
+ dummy_input_ids = torch.tensor([[self.token_to_id[self.bos_token]]])
91
+ inputs = {
92
+ "input_ids": dummy_input_ids,
93
+ "_text": text, # Store the text for the model to use
 
 
 
 
94
  }
 
 
 
95
 
96
+ return inputs
97
+
98
+ def decode(self, token_ids, skip_special_tokens=True):
99
+ """Dummy decode function - the model will handle decoding"""
100
+ # This is just a placeholder - the model will decode internally
101
+ # For logging purposes
102
+ logger.info(f"Decoding token_ids (shape: {token_ids.shape if hasattr(token_ids, 'shape') else 'N/A'})")
103
+
104
+ # We'll get the raw output from the model and handle it specially
105
+ # in the generation function
106
+ return ""
107
+
108
+ ###############################################################################
109
+ # Model Loading with Error Handling
110
+ ###############################################################################
111
+ def load_model_and_tokenizer():
112
+ """Load the model with comprehensive error handling and logging"""
113
+ logger.info(f"Loading model: {MODEL_ID}")
114
+ logger.info(f"Transformers version: {transformers.__version__}")
115
+ logger.info(f"PyTorch version: {torch.__version__}")
116
+
117
+ # Check available devices
118
+ device_info = {
119
+ "cuda_available": torch.cuda.is_available(),
120
+ "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
121
+ "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
122
+ }
123
+ logger.info(f"Device information: {device_info}")
124
+
125
+ # Create minimal tokenizer
126
+ tokenizer = MinimalTokenizer()
127
 
128
  # Step 2: Load model with detailed error logging
129
  try:
 
144
  torch_dtype = torch.float32
145
  logger.info("Using CPU with float32 precision")
146
 
147
+ # Load the model
148
  model = AutoModelForCausalLM.from_pretrained(
149
  MODEL_ID,
150
  torch_dtype=torch_dtype,
151
  device_map=device_map,
152
  trust_remote_code=True,
 
153
  )
154
  model.eval()
155
  model_load_time = time.time() - model_start
156
  logger.info(f"Model loaded successfully in {model_load_time:.2f} seconds")
157
 
158
  # Log model info
159
+ try:
160
+ model_info = {
161
+ "model_type": model.config.model_type,
162
+ "hidden_size": model.config.hidden_size,
163
+ "vocab_size": model.config.vocab_size,
164
+ "num_hidden_layers": model.config.num_hidden_layers
165
+ }
166
+ logger.info(f"Model properties: {model_info}")
167
+ except Exception as e:
168
+ logger.warning(f"Could not log all model properties: {str(e)}")
169
 
170
  except Exception as e:
171
  logger.error(f"Failed to load model: {str(e)}")
 
177
  ###############################################################################
178
  # Chat Formatting and Generation Functions
179
  ###############################################################################
180
+ def format_chat_for_model(messages, system_prompt=DEFAULT_SYSTEM_PROMPT):
181
+ """Format chat messages using the special tokens from model configuration"""
 
 
 
182
  logger.info(f"Formatting chat with {len(messages)} messages")
183
 
184
+ # Start with BOS token
185
+ chat_text = BOS_TOKEN
186
 
187
  # Add system message if not already present
188
+ if not messages or messages[0].get("role") != "system":
189
+ chat_text += SYSTEM_START + system_prompt + EOS_TOKEN
190
 
191
+ # Add all messages in the correct format
192
  for msg in messages:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  role = msg["role"]
194
  content = msg["content"]
195
 
196
  if role == "system":
197
+ chat_text += SYSTEM_START + content + EOS_TOKEN
198
  elif role == "user":
199
+ chat_text += USER_START + content + EOS_TOKEN
200
  elif role == "assistant":
201
+ chat_text += ASSISTANT_START + content + EOS_TOKEN
202
 
203
+ # Add final assistant header for the model to continue
204
+ chat_text += ASSISTANT_START
205
 
206
+ logger.info(f"Formatted chat text (length: {len(chat_text)})")
207
  return chat_text
208
 
209
  def generate_response(model, tokenizer, messages, temperature=0.7, top_p=0.9, max_new_tokens=256, system_prompt=DEFAULT_SYSTEM_PROMPT):
 
211
  logger.info(f"Generating response with temp={temperature}, top_p={top_p}, max_tokens={max_new_tokens}")
212
 
213
  # Format the messages for the model
214
+ prompt = format_chat_for_model(messages, system_prompt)
215
 
216
  # Configure generation parameters
217
  gen_config = GenerationConfig(
 
220
  do_sample=True,
221
  repetition_penalty=1.1,
222
  max_new_tokens=max_new_tokens,
223
+ pad_token_id=tokenizer.token_to_id[tokenizer.pad_token],
224
+ bos_token_id=tokenizer.token_to_id[tokenizer.bos_token],
225
+ eos_token_id=tokenizer.token_to_id[tokenizer.eos_token],
226
  )
227
 
 
 
 
 
 
 
 
 
 
228
  # Generate with retry logic
229
  max_retries = 3
230
  retry_count = 0
231
 
232
  while retry_count < max_retries:
233
  try:
234
+ # Tokenize with dummy tensors - the model will handle the actual text
235
+ inputs = tokenizer(prompt, return_tensors="pt")
236
+ inputs["text"] = prompt # Store the actual text
237
+ inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
238
+
239
  # Run the generation
240
  generation_start = time.time()
241
  with torch.no_grad():
242
+ outputs = model.generate(
243
  **inputs,
244
  generation_config=gen_config,
245
  )
246
  generation_time = time.time() - generation_start
247
  logger.info(f"Generation completed in {generation_time:.2f} seconds")
248
 
249
+ # Extract just the assistant's response using string operations
250
+ # This is the key part - the model's output is processed as a string, not tokens
251
+ # Split on the last occurrence of our custom beginning of assistant text
252
+ # We trust the model to format the output correctly
253
+ full_text = prompt # Start with our prompt
254
 
255
+ # Extract actual new text from model's output
256
+ # The output might be unpredictable, so we need to be careful here
257
+ try:
258
+ # Try to get string representation of the output
259
+ output_text = "".join([chr(id) for id in outputs[0].tolist()])
260
+ # Remove initial prompt text to get just the model's generation
261
+ # Add this to the full text
262
+ full_text += output_text
263
+ except Exception as e:
264
+ logger.warning(f"Could not process model output as expected: {str(e)}")
265
+ # In case of failure, produce a simple response
266
+ full_text += "I apologize, but I'm having trouble generating a response."
267
 
268
+ # Extract just the final assistant's response
269
+ try:
270
+ parts = full_text.split(ASSISTANT_START)
271
+ assistant_part = parts[-1] # Get the last assistant part
 
 
 
 
 
 
 
 
 
 
 
272
 
273
+ # Remove any trailing EOS token
274
+ if EOS_TOKEN in assistant_part:
275
+ assistant_response = assistant_part.split(EOS_TOKEN)[0].strip()
276
+ else:
277
+ assistant_response = assistant_part.strip()
278
+ except Exception as e:
279
+ logger.warning(f"Error extracting assistant response: {str(e)}")
280
+ assistant_response = "I apologize, but I'm having trouble generating a proper response."
281
 
282
+ logger.info(f"Extracted assistant response (length: {len(assistant_response)})")
 
 
 
283
 
284
  # Free up memory
285
+ del inputs, outputs
286
  optimize_memory()
287
 
288
+ # Fallback if we get an empty response
289
+ if not assistant_response:
290
+ assistant_response = "I apologize, but I couldn't generate a response. Please try again."
291
+
292
  return assistant_response
293
 
294
  except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
 
457
 
458
  return demo
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  ###############################################################################
461
  # Main Application Logic
462
  ###############################################################################
 
469
  # Load model and tokenizer
470
  model, tokenizer = load_model_and_tokenizer()
471
 
 
 
 
 
 
472
  # Build and launch Gradio interface
473
  demo = build_gradio_interface(model, tokenizer)
474