Keeby-smilyai committed
Commit a459d20 · verified · 1 Parent(s): 07e759b

Update app.py

Files changed (1)
  1. app.py +78 -39
app.py CHANGED
@@ -1,8 +1,8 @@
 # -------------------------------
 # app.py
 #
-# This file contains the backend logic and Gradio UI for the chatbot.
-# Now using Sam-3.0-3 from Smilyai-labs/Sam-3.0-3 — a model that thinks, reasons, and responds with clarity.
+# Sam-3: The Reasoning AI — Now Showing Its Thought Process!
+# Powered by Smilyai-labs/Sam-3.0-3. Trained to think before speaking.
 # -------------------------------

 import math
@@ -18,7 +18,7 @@ import os
 from huggingface_hub import hf_hub_download

 # -------------------------------
-# 1) Sam-3.0-3 Architecture (from your second code)
+# 1) Sam-3.0-3 Architecture
 # -------------------------------
 @dataclass
 class Sam3Config:
@@ -116,7 +116,7 @@ class Sam3(nn.Module):
         return self.lm_head(x)

 # -------------------------------
-# 2) Load tokenizer & special tokens (Sam-3.0-3 style)
+# 2) Load Tokenizer & Special Tokens
 # -------------------------------
 SPECIAL_TOKENS = {
     "bos": "<|bos|>",
@@ -127,19 +127,18 @@ SPECIAL_TOKENS = {
     "think": "<|think|>",
 }

-# Use GPT-2 tokenizer and add special tokens
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
 tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})

-EOT_ID = SPECIAL_TOKENS["eot"]
-EOT_ID = tokenizer.convert_tokens_to_ids(EOT_ID) or tokenizer.eos_token_id
+EOT_ID = tokenizer.convert_tokens_to_ids("<|eot|>") or tokenizer.eos_token_id
+THINK_ID = tokenizer.convert_tokens_to_ids("<|think|>")

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # -------------------------------
-# 3) Download model weights from Hugging Face Hub
+# 3) Download Model Weights from Hugging Face Hub
 # -------------------------------
 hf_repo = "Smilyai-labs/Sam-3.0-3"
 weights_filename = "model.safetensors"
@@ -147,11 +146,9 @@ weights_filename = "model.safetensors"
 print(f"Loading model '{hf_repo}' from Hugging Face Hub...")

 try:
-    # Download weights
     weights_path = hf_hub_download(repo_id=hf_repo, filename=weights_filename)
     print(f"✅ Downloaded weights to: {weights_path}")

-    # Verify file size
     if not os.path.exists(weights_path):
         raise FileNotFoundError(f"Downloaded file not found at {weights_path}")
     file_size = os.path.getsize(weights_path)
@@ -160,20 +157,18 @@ try:
 except Exception as e:
     raise RuntimeError(f"❌ Failed to download model weights: {e}")

-# Initialize model with correct vocab size
+# Initialize model
 cfg = Sam3Config(vocab_size=len(tokenizer))
 model = Sam3(cfg).to(device)

 # Load state dict safely
 print("Loading state dict...")
 try:
-    # Try safe_open first (preferred)
     state_dict = {}
     with safe_open(weights_path, framework="pt", device="cpu") as f:
         for key in f.keys():
             state_dict[key] = f.get_tensor(key)
     print("✅ Loaded via safe_open")
-
 except Exception as e:
     print(f"⚠️ safe_open failed: {e}. Falling back to torch.load...")
     try:
@@ -182,24 +177,22 @@ except Exception as e:
     except Exception as torch_e:
         raise RuntimeError(f"❌ Could not load model weights: {torch_e}")

-# Filter state_dict to match model keys
+# Filter and load
 model_state_dict = model.state_dict()
 filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
-
-# Warn about missing/extra keys
 missing_keys = set(model_state_dict.keys()) - set(filtered_state_dict.keys())
 extra_keys = set(state_dict.keys()) - set(model_state_dict.keys())
 if missing_keys:
-    print(f"⚠️ Missing keys in loaded state dict: {missing_keys}")
+    print(f"⚠️ Missing keys: {missing_keys}")
 if extra_keys:
-    print(f"⚠️ Extra keys in loaded state dict: {extra_keys}")
+    print(f"⚠️ Extra keys: {extra_keys}")

 model.load_state_dict(filtered_state_dict, strict=False)
 model.eval()
 print("✅ Model loaded successfully!")

 # -------------------------------
-# 4) Sampling function (unchanged from Sam-3.0-3 code)
+# 4) Sampling Function (Unchanged)
 # -------------------------------
 def sample_next_token(
     logits,
@@ -277,12 +270,12 @@ def sample_next_token(
     return next_token.to(device)

 # -------------------------------
-# 5) Gradio Chat UI and API Logic (Updated with truthful, compelling UI)
+# 5) Gradio Chat Interface — WITH STYLED THINKING STEPS
 # -------------------------------
-SPECIAL_TOKENS_CHAT = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>"}
+SPECIAL_TOKENS_CHAT = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>", "think": "<|think|>"}

 def predict(message, history):
-    # Construct the chat history with special tokens
+    # Build prompt with <|think|> to trigger internal reasoning
     chat_history = []
     for human, assistant in history:
         chat_history.append(f"{SPECIAL_TOKENS_CHAT['user']} {human} {SPECIAL_TOKENS_CHAT['eot']}")
@@ -291,44 +284,85 @@ def predict(message, history):

     chat_history.append(f"{SPECIAL_TOKENS_CHAT['user']} {message} {SPECIAL_TOKENS_CHAT['eot']}")

-    system_prompt = "You are Sam-3, an advanced reasoning AI. You think step by step, analyze deeply, and answer with precision. You do not guess — you deduce. Avoid medical or legal advice."
-    prompt = f"{SPECIAL_TOKENS_CHAT['system']} {system_prompt} {SPECIAL_TOKENS_CHAT['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS_CHAT['assistant']}"
+    system_prompt = "You are Sam-3, an advanced reasoning AI. You think step-by-step, analyze deeply, and respond with precision. You do not guess — you deduce. Avoid medical or legal advice."
+    prompt = f"{SPECIAL_TOKENS_CHAT['system']} {system_prompt} {SPECIAL_TOKENS_CHAT['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS_CHAT['assistant']} {SPECIAL_TOKENS_CHAT['think']}"

     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
-
+
     generated_text = ""
+    thinking_mode = False
+    thinking_buffer = ""
+
     for _ in range(256):
         with torch.no_grad():
             logits = model(input_ids, attention_mask=attention_mask)
         next_token = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1)

         token_id = int(next_token.squeeze().item())
-        token_str = tokenizer.decode([token_id], skip_special_tokens=True)
-
+        token_str = tokenizer.decode([token_id], skip_special_tokens=False)  # Keep special tokens!
+
         input_ids = torch.cat([input_ids, next_token], dim=1)
         attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=device, dtype=attention_mask.dtype)], dim=1)
-
-        generated_text += token_str
-        yield generated_text
-
-        if token_id == EOT_ID:
+
+        # Detect if we're entering/exiting thinking mode
+        if not thinking_mode and token_str == "<|think|>":
+            thinking_mode = True
+            thinking_buffer = ""  # Start capturing thoughts
+            continue  # Don't yield <|think|> itself
+
+        if thinking_mode:
+            if token_str == "<|eot|>":
+                # End of thought — now yield the full thinking block
+                thinking_buffer = thinking_buffer.strip()
+                if thinking_buffer:
+                    # Yield as styled markdown block
+                    yield f"<div style='background-color:#f8f9fa; padding:12px; border-left:4px solid #ccc; border-radius:0 8px 8px 0; margin:10px 0; font-style:italic; color:#555;'>💡 Thinking: {thinking_buffer}</div>"
+                thinking_mode = False
+                continue
+            else:
+                thinking_buffer += token_str
+                continue  # Don't yield yet — buffer until <|eot|>
+
+        # Normal response output
+        if not thinking_mode:
+            generated_text += token_str
+            yield generated_text
+
+        # Stop on final EOT
+        if token_id == EOT_ID and not thinking_mode:
             break

-# Gradio Interface — Now Truthfully Representing the Model’s Capabilities
+# Custom CSS for styling thinking blocks
+CSS = """
+.gradio-container .message-bubble {
+    border-radius: 12px !important;
+}
+.gradio-container .message-bubble.user {
+    background-color: #1f7bff !important;
+    color: white !important;
+}
+.gradio-container .message-bubble.assistant {
+    background-color: #e9ecef !important;
+    color: #212529 !important;
+}
+"""
+
+# Gradio Interface
 demo = gr.ChatInterface(
     fn=predict,
     title="🌟 Sam-3: The Reasoning AI",
     description="""
-    Sam-3 is not just a language model — it **thinks before it speaks**.
-    Built with deep architectural integrity, it analyzes problems step-by-step, uncovers hidden patterns, and delivers precise, logical answers.
-    No fluff. No guessing. Just reasoning.
+    Sam-3 doesn’t just answer — it **thinks first**.
+    Watch its internal reasoning unfold in real time — step by step, clearly shown.
+    No guessing. No fluff. Just pure deduction.

-    Try asking it:
-    → “If I have 3 apples and give away half of them, then buy 5 more, how many do I have?”
+    Try asking:
+    → “Why does a mirror reverse left and right but not up and down?”
+    → “If I have 3 apples and give away half, then buy 5 more, how many do I have?”
     → “Explain quantum entanglement like I’m 10.”
-    → “What’s the flaw in this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
+    → “What’s wrong with this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
     """,
     theme=gr.themes.Soft(
         primary_hue="indigo",
@@ -338,6 +372,10 @@ demo = gr.ChatInterface(
         label="Sam-3 🤔",
         bubble_full_width=False,
         height=600,
+        avatar_images=(
+            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-bot.jpg",
+            "https://huggingface.co/datasets/huggingface/branding/resolve/main/avatar-user.jpg"
+        )
     ),
     examples=[
         "What is the capital of France?",
@@ -345,6 +383,7 @@ demo = gr.ChatInterface(
         "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
         "What are the ethical implications of AI making medical diagnoses?"
     ],
+    css=CSS,
     cache_examples=False
 ).launch(
     show_api=True
 
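A note on the new EOT_ID line, offered as a hypothetical hardening rather than anything this commit does: "x or y" falls back whenever x is falsy, so a special token whose id happened to be 0 would be silently replaced by eos_token_id. An explicit check sidesteps that edge case:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["<|eot|>"]})

# Explicit None/unk check instead of `... or tokenizer.eos_token_id`,
# so a legitimate token id of 0 would not be treated as "missing".
eot_id = tokenizer.convert_tokens_to_ids("<|eot|>")
if eot_id is None or eot_id == tokenizer.unk_token_id:
    eot_id = tokenizer.eos_token_id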
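For reference, the prompt template now ends with the assistant tag followed by <|think|>, which is the cue for the model to open a reasoning span before the visible reply. A minimal sketch of the string predict() builds for a one-turn chat; build_prompt is an illustrative helper, not part of app.py, and the system prompt is abbreviated here:

SPECIAL = {"eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>",
           "system": "<|system|>", "think": "<|think|>"}

def build_prompt(message, system_prompt="You are Sam-3, an advanced reasoning AI."):
    # Mirrors the construction in predict(): system turn, user turn,
    # then an open assistant turn that starts in thinking mode.
    return (
        f"{SPECIAL['system']} {system_prompt} {SPECIAL['eot']}\n"
        f"{SPECIAL['user']} {message} {SPECIAL['eot']}\n"
        f"{SPECIAL['assistant']} {SPECIAL['think']}"
    )

print(build_prompt("What is the capital of France?"))
# <|system|> You are Sam-3, an advanced reasoning AI. <|eot|>
# <|user|> What is the capital of France? <|eot|>
# <|assistant|> <|think|>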
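The heart of the change is the streaming state machine in predict(): buffer every token between <|think|> and the first <|eot|>, flush the buffer as one styled block, then stream the visible reply until a second <|eot|>. A model-free sketch of that logic follows; stream_with_thinking and the token strings are made up for illustration, and the real code yields an HTML div rather than a bracketed marker:

def stream_with_thinking(token_strs, think="<|think|>", eot="<|eot|>"):
    """Yield UI chunks: one 'thinking' block, then the growing reply."""
    generated, buffer, thinking = "", "", False
    for tok in token_strs:
        if not thinking and tok == think:
            thinking, buffer = True, ""   # enter thinking mode, start buffering
            continue
        if thinking:
            if tok == eot:                # end of thought: flush the buffer once
                if buffer.strip():
                    yield f"[Thinking: {buffer.strip()}]"
                thinking = False
            else:
                buffer += tok             # keep buffering, show nothing yet
            continue
        if tok == eot:                    # final end-of-turn: stop streaming
            break
        generated += tok
        yield generated                   # visible reply, streamed cumulatively

# Example: a short thought followed by the visible answer.
tokens = ["<|think|>", "2 + 2", " = 4", "<|eot|>", "The", " answer", " is 4.", "<|eot|>"]
for chunk in stream_with_thinking(tokens):
    print(chunk)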