Keeby-smilyai committed on
Commit
590f0df
·
verified ·
1 Parent(s): 08810a8

Create app.py

Files changed (1)
  1. app.py +351 -0
app.py ADDED
@@ -0,0 +1,351 @@
+ # -------------------------------
+ # app.py
+ #
+ # This file contains the backend logic and Gradio UI for the chatbot.
+ # Now using Sam-3.0-3 from Smilyai-labs/Sam-3.0-3, a model that thinks, reasons, and responds with clarity.
+ # -------------------------------
+
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from pathlib import Path
+ from safetensors.torch import load_file, safe_open
+ from transformers import AutoTokenizer
+ from dataclasses import dataclass
+ import gradio as gr
+ import os
+ from huggingface_hub import hf_hub_download
+
+ # -------------------------------
+ # 1) Sam-3.0-3 Architecture
+ # -------------------------------
+ @dataclass
+ class Sam3Config:
+     vocab_size: int = 50257
+     d_model: int = 384
+     n_layers: int = 10
+     n_heads: int = 6
+     ff_mult: float = 4.0
+     dropout: float = 0.1
+     input_modality: str = "text"
+     head_type: str = "causal_lm"
+     version: str = "0.1"
+
+     # Explicit __init__ (overriding the dataclass-generated one) so unknown config keys
+     # are accepted and silently ignored via **kwargs.
+     def __init__(self, vocab_size=50257, d_model=384, n_layers=10, n_heads=6, ff_mult=4.0, dropout=0.1, input_modality="text", head_type="causal_lm", version="0.1", **kwargs):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.n_layers = n_layers
+         self.n_heads = n_heads
+         self.ff_mult = ff_mult
+         self.dropout = dropout
+         self.input_modality = input_modality
+         self.head_type = head_type
+         self.version = version
+
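+ # Illustrative note: thanks to **kwargs above, a config dict loaded from JSON can be
+ # splatted directly even if it carries extra fields, e.g. (sketch):
+ #   cfg = Sam3Config(**{"vocab_size": 50263, "d_model": 384, "some_unused_key": 1})
+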
+ class RMSNorm(nn.Module):
+     def __init__(self, d, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(d))
+     def forward(self, x):
+         return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()
+
+ class MHA(nn.Module):
+     def __init__(self, d_model, n_heads, dropout=0.0):
+         super().__init__()
+         assert d_model % n_heads == 0
+         self.n_heads = n_heads
+         self.head_dim = d_model // n_heads
+         self.q_proj = nn.Linear(d_model, d_model, bias=False)
+         self.k_proj = nn.Linear(d_model, d_model, bias=False)
+         self.v_proj = nn.Linear(d_model, d_model, bias=False)
+         self.out_proj = nn.Linear(d_model, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+     def forward(self, x, attn_mask=None):
+         B, T, C = x.shape
+         q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
+         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
+         scores = scores.masked_fill(causal, float("-inf"))
+         if attn_mask is not None:
+             scores = scores.masked_fill(~attn_mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))
+         attn = torch.softmax(scores, dim=-1)
+         out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
+         return self.out_proj(out)
+
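+ # Note (optional alternative): on PyTorch >= 2.0 the explicit mask/softmax/matmul above can
+ # usually be replaced by F.scaled_dot_product_attention with an equivalent combined mask,
+ # which is typically faster; the hand-written version is kept as-is here.
+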
+ class SwiGLU(nn.Module):
+     def __init__(self, d_model, d_ff, dropout=0.0):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_model, d_ff, bias=False)
+         self.w3 = nn.Linear(d_ff, d_model, bias=False)
+         self.dropout = nn.Dropout(dropout)
+     def forward(self, x):
+         return self.w3(self.dropout(F.silu(self.w1(x)) * self.w2(x)))
+
+ class Block(nn.Module):
+     def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
+         super().__init__()
+         self.norm1 = RMSNorm(d_model)
+         self.attn = MHA(d_model, n_heads, dropout=dropout)
+         self.norm2 = RMSNorm(d_model)
+         self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
+         self.drop = nn.Dropout(dropout)
+     def forward(self, x, attn_mask=None):
+         x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
+         x = x + self.drop(self.ff(self.norm2(x)))
+         return x
+
+ class Sam3(nn.Module):
+     def __init__(self, config: Sam3Config):
+         super().__init__()
+         self.config = config
+         self.embed = nn.Embedding(config.vocab_size, config.d_model)
+         self.blocks = nn.ModuleList([Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout) for _ in range(config.n_layers)])
+         self.norm = RMSNorm(config.d_model)
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+         self.lm_head.weight = self.embed.weight  # weight tying between embedding and LM head
+     def forward(self, input_ids, attention_mask=None):
+         x = self.embed(input_ids)
+         for blk in self.blocks:
+             x = blk(x, attn_mask=attention_mask)
+         x = self.norm(x)
+         return self.lm_head(x)
+
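+ # Quick shape sketch (illustrative, not executed by the app): the module above is a plain
+ # decoder-only LM, so a forward pass maps (B, T) token ids to (B, T, vocab_size) logits:
+ #   m = Sam3(Sam3Config(vocab_size=50257))
+ #   logits = m(torch.randint(0, 50257, (1, 8)))   # -> torch.Size([1, 8, 50257])
+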
+ # -------------------------------
+ # 2) Load tokenizer & special tokens (Sam-3.0-3 style)
+ # -------------------------------
+ SPECIAL_TOKENS = {
+     "bos": "<|bos|>",
+     "eot": "<|eot|>",
+     "user": "<|user|>",
+     "assistant": "<|assistant|>",
+     "system": "<|system|>",
+     "think": "<|think|>",
+ }
+
+ # Use the GPT-2 tokenizer and register the chat special tokens
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.add_special_tokens({"additional_special_tokens": list(SPECIAL_TOKENS.values())})
+
+ # Resolve the end-of-turn token id, falling back to EOS if the lookup fails
+ EOT_ID = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS["eot"]) or tokenizer.eos_token_id
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
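+ # Illustrative check: after add_special_tokens the vocabulary should grow from 50257 to
+ # 50263 (six added tokens); the model config below is built from len(tokenizer), so the
+ # embedding matrix matches this enlarged size.
+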
+ # -------------------------------
+ # 3) Download model weights from Hugging Face Hub
+ # -------------------------------
+ hf_repo = "Smilyai-labs/Sam-3.0-3"
+ weights_filename = "model.safetensors"
+
+ print(f"Loading model '{hf_repo}' from Hugging Face Hub...")
+
+ try:
+     # Download weights
+     weights_path = hf_hub_download(repo_id=hf_repo, filename=weights_filename)
+     print(f"✅ Downloaded weights to: {weights_path}")
+
+     # Verify the file exists and report its size
+     if not os.path.exists(weights_path):
+         raise FileNotFoundError(f"Downloaded file not found at {weights_path}")
+     file_size = os.path.getsize(weights_path)
+     print(f"📄 File size: {file_size} bytes")
+
+ except Exception as e:
+     raise RuntimeError(f"❌ Failed to download model weights: {e}") from e
+
+ # Initialize model with correct vocab size
+ cfg = Sam3Config(vocab_size=len(tokenizer))
+ model = Sam3(cfg).to(device)
+
+ # Load state dict safely
+ print("Loading state dict...")
+ try:
+     # Try safe_open first (preferred)
+     state_dict = {}
+     with safe_open(weights_path, framework="pt", device="cpu") as f:
+         for key in f.keys():
+             state_dict[key] = f.get_tensor(key)
+     print("✅ Loaded via safe_open")
+
+ except Exception as e:
+     print(f"⚠️ safe_open failed: {e}. Falling back to torch.load...")
+     try:
+         state_dict = torch.load(weights_path, map_location="cpu")
+         print("✅ Loaded via torch.load")
+     except Exception as torch_e:
+         raise RuntimeError(f"❌ Could not load model weights: {torch_e}") from torch_e
+
+ # Filter state_dict to match model keys
+ model_state_dict = model.state_dict()
+ filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
+
+ # Warn about missing/extra keys
+ missing_keys = set(model_state_dict.keys()) - set(filtered_state_dict.keys())
+ extra_keys = set(state_dict.keys()) - set(model_state_dict.keys())
+ if missing_keys:
+     print(f"⚠️ Missing keys in loaded state dict: {missing_keys}")
+ if extra_keys:
+     print(f"⚠️ Extra keys in loaded state dict: {extra_keys}")
+
+ model.load_state_dict(filtered_state_dict, strict=False)
+ model.eval()
+ print("✅ Model loaded successfully!")
+
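+ # Optional sanity check (sketch): with d_model=384, n_layers=10 and the enlarged vocab,
+ # sum(p.numel() for p in model.parameters()) should land around ~43M (tied lm_head counted
+ # once); a very different figure would point to a config/checkpoint mismatch.
+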
+ # -------------------------------
+ # 4) Sampling function (unchanged from Sam-3.0-3 code)
+ # -------------------------------
+ def sample_next_token(
+     logits,
+     past_tokens,
+     temperature=0.8,
+     top_k=60,
+     top_p=0.9,
+     repetition_penalty=1.1,
+     max_repeat=5,
+     no_repeat_ngram_size=3
+ ):
+     # Work on the logits of the last position only
+     if logits.dim() == 3:
+         logits = logits[:, -1, :].clone()
+     else:
+         logits = logits.clone()
+     batch_size, vocab_size = logits.size(0), logits.size(1)
+     orig_logits = logits.clone()
+
+     if temperature != 1.0:
+         logits = logits / float(temperature)
+
+     past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)
+
+     # Repetition penalty on every token already generated
+     for token_id in set(past_list):
+         if 0 <= token_id < vocab_size:
+             logits[:, token_id] /= repetition_penalty
+
+     # Hard-ban a token that has just been repeated max_repeat times in a row
+     if len(past_list) >= max_repeat:
+         last_token = past_list[-1]
+         count = 1
+         for i in reversed(past_list[:-1]):
+             if i == last_token:
+                 count += 1
+             else:
+                 break
+         if count >= max_repeat:
+             if 0 <= last_token < vocab_size:
+                 logits[:, last_token] = -float("inf")
+
+     # No-repeat n-grams: ban any token whose addition would recreate an n-gram already seen
+     if no_repeat_ngram_size > 0 and len(past_list) >= no_repeat_ngram_size:
+         prefix = tuple(past_list[-(no_repeat_ngram_size - 1):])
+         for i in range(len(past_list) - no_repeat_ngram_size + 1):
+             ngram = tuple(past_list[i : i + no_repeat_ngram_size])
+             if ngram[:-1] == prefix and 0 <= ngram[-1] < vocab_size:
+                 logits[:, ngram[-1]] = -float("inf")
+
+     # Top-k filtering
+     if top_k is not None and top_k > 0:
+         tk = min(max(1, int(top_k)), vocab_size)
+         topk_vals, topk_indices = torch.topk(logits, tk, dim=-1)
+         min_topk = topk_vals[:, -1].unsqueeze(-1)
+         logits[logits < min_topk] = -float("inf")
+
+     # Top-p (nucleus) filtering
+     if top_p is not None and 0.0 < top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+         sorted_probs = F.softmax(sorted_logits, dim=-1)
+         cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+         for b in range(batch_size):
+             sorted_mask = cumulative_probs[b] > top_p
+             if sorted_mask.numel() > 0:
+                 sorted_mask[0] = False
+             tokens_to_remove = sorted_indices[b][sorted_mask]
+             logits[b, tokens_to_remove] = -float("inf")
+
+     # If everything got masked out, fall back to the unfiltered logits
+     for b in range(batch_size):
+         if torch.isneginf(logits[b]).all():
+             logits[b] = orig_logits[b]
+
+     probs = F.softmax(logits, dim=-1)
+     if torch.isnan(probs).any():
+         probs = torch.ones_like(logits) / logits.size(1)
+
+     next_token = torch.multinomial(probs, num_samples=1)
+     return next_token.to(device)
+
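+ # Usage sketch (illustrative): one decoding step with the helper above, given a running
+ # tensor of prompt + generated ids:
+ #   logits = model(input_ids, attention_mask=attention_mask)
+ #   next_id = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9)
+ #   input_ids = torch.cat([input_ids, next_id], dim=1)
+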
+ # -------------------------------
+ # 5) Gradio Chat UI and API Logic
+ # -------------------------------
+ SPECIAL_TOKENS_CHAT = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>"}
+
+ def predict(message, history):
+     # Construct the chat history with special tokens
+     chat_history = []
+     for human, assistant in history:
+         chat_history.append(f"{SPECIAL_TOKENS_CHAT['user']} {human} {SPECIAL_TOKENS_CHAT['eot']}")
+         if assistant:
+             chat_history.append(f"{SPECIAL_TOKENS_CHAT['assistant']} {assistant} {SPECIAL_TOKENS_CHAT['eot']}")
+
+     chat_history.append(f"{SPECIAL_TOKENS_CHAT['user']} {message} {SPECIAL_TOKENS_CHAT['eot']}")
+
+     system_prompt = "You are Sam-3, an advanced reasoning AI. You think step by step, analyze deeply, and answer with precision. You do not guess; you deduce. Avoid medical or legal advice."
+     prompt = f"{SPECIAL_TOKENS_CHAT['system']} {system_prompt} {SPECIAL_TOKENS_CHAT['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS_CHAT['assistant']}"
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+
+     # Token-by-token decoding loop, streaming partial output back to the UI
+     generated_text = ""
+     for _ in range(256):
+         with torch.no_grad():
+             logits = model(input_ids, attention_mask=attention_mask)
+         next_token = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1)
+
+         token_id = int(next_token.squeeze().item())
+         token_str = tokenizer.decode([token_id], skip_special_tokens=True)
+
+         input_ids = torch.cat([input_ids, next_token], dim=1)
+         attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=device, dtype=attention_mask.dtype)], dim=1)
+
+         generated_text += token_str
+         yield generated_text
+
+         if token_id == EOT_ID:
+             break
+
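+ # For reference, the prompt assembled in predict() looks like this for a single-turn chat
+ # (illustrative):
+ #   <|system|> You are Sam-3, an advanced reasoning AI. ... <|eot|>
+ #   <|user|> Hello! <|eot|>
+ #   <|assistant|>
+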
+ # Gradio interface
+ demo = gr.ChatInterface(
+     fn=predict,
+     title="🌟 Sam-3: The Reasoning AI",
+     description="""
+ Sam-3 is not just a language model: it **thinks before it speaks**.
+ Built with deep architectural integrity, it analyzes problems step by step, uncovers hidden patterns, and delivers precise, logical answers.
+ No fluff. No guessing. Just reasoning.
+
+ Try asking it:
+ → “If I have 3 apples and give away half of them, then buy 5 more, how many do I have?”
+ → “Explain quantum entanglement like I’m 10.”
+ → “What’s the flaw in this argument: ‘All birds fly; penguins are birds; therefore penguins can fly’?”
+ """,
+     theme=gr.themes.Soft(
+         primary_hue="indigo",
+         secondary_hue="blue"
+     ),
+     chatbot=gr.Chatbot(
+         label="Sam-3 🤔",
+         bubble_full_width=False,
+         height=600,
+     ),
+     examples=[
+         "What is the capital of France?",
+         "Explain why the sky is blue.",
+         "If a train leaves at 2 PM going 60 mph, and another leaves 30 minutes later at 80 mph, when does the second catch up?",
+         "What are the ethical implications of AI making medical diagnoses?"
+     ],
+     cache_examples=False
+ )
+
+ demo.launch(show_api=True)
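+
+ # To run locally (sketch): install gradio, torch, transformers, safetensors and
+ # huggingface_hub, then start the app with `python app.py` and open the printed URL.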