import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
import time
import threading
import hashlib
import sqlite3
from datetime import datetime, timedelta
import pytz
# ==============================================================================
# Performance Optimizations for CPU
# ==============================================================================
tf.config.threading.set_inter_op_parallelism_threads(1)
tf.config.threading.set_intra_op_parallelism_threads(2)
tf.config.optimizer.set_jit(True)
tf.config.run_functions_eagerly(False)
# The two GPU-related settings below are harmless no-ops on CPU-only hosts;
# they only take effect if a CUDA GPU is visible.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Australian timezone
AUSTRALIA_TZ = pytz.timezone('Australia/Sydney')

# ==============================================================================
# Database Setup
# ==============================================================================
def init_database():
    """Initialize SQLite database for users and subscriptions."""
    conn = sqlite3.connect('sam_users.db', check_same_thread=False)
    c = conn.cursor()
    # Users table
    c.execute('''CREATE TABLE IF NOT EXISTS users
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  username TEXT UNIQUE NOT NULL,
                  password_hash TEXT NOT NULL,
                  email TEXT,
                  plan TEXT DEFAULT 'free',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  is_admin BOOLEAN DEFAULT 0,
                  rate_limit_start TIMESTAMP,
                  messages_used_nano INTEGER DEFAULT 0,
                  messages_used_mini INTEGER DEFAULT 0,
                  messages_used_fast INTEGER DEFAULT 0,
                  messages_used_large INTEGER DEFAULT 0)''')
    # Upgrade requests table
    c.execute('''CREATE TABLE IF NOT EXISTS upgrade_requests
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  requested_plan TEXT,
                  reason TEXT,
                  status TEXT DEFAULT 'pending',
                  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Usage tracking
    c.execute('''CREATE TABLE IF NOT EXISTS usage_logs
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  user_id INTEGER,
                  tokens_used INTEGER,
                  model_used TEXT,
                  timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                  FOREIGN KEY (user_id) REFERENCES users(id))''')
    # Create admin account if not exists
    admin_pass = hashlib.sha256("admin123".encode()).hexdigest()
    try:
        c.execute("INSERT INTO users (username, password_hash, email, plan, is_admin) VALUES (?, ?, ?, ?, ?)",
                  ("admin", admin_pass, "admin@samx1.ai", "pro", 1))
        conn.commit()
        print("✅ Admin account created (username: admin, password: admin123)")
    except sqlite3.IntegrityError:
        print("✅ Admin account already exists")
    conn.commit()
    return conn

# Global database connection
db_conn = init_database()
db_lock = threading.Lock()
# Plan limits with a rolling reset window (window length varies by plan; -1 = unlimited)
PLAN_LIMITS = {
    'free': {
        'nano_messages': 100,
        'mini_messages': 4,
        'fast_messages': 7,
        'large_messages': 5,
        'can_choose_model': False,
        'max_tokens': 256,
        'reset_hours': 5
    },
    'explore': {
        'nano_messages': 200,
        'mini_messages': 8,
        'fast_messages': 14,
        'large_messages': 10,
        'can_choose_model': True,
        'max_tokens': 512,
        'reset_hours': 3
    },
    'plus': {
        'nano_messages': 500,
        'mini_messages': 20,
        'fast_messages': 17,
        'large_messages': 9,
        'can_choose_model': True,
        'max_tokens': 384,
        'reset_hours': 2
    },
    'pro': {
        'nano_messages': 10000000,
        'mini_messages': 100,
        'fast_messages': 50,
        'large_messages': 20,
        'can_choose_model': True,
        'max_tokens': 512,
        'reset_hours': 3
    },
    'Research': {
        'nano_messages': 10000000,
        'mini_messages': 1000,
        'fast_messages': 500,
        'large_messages': 200,
        'can_choose_model': True,
        'max_tokens': 1024,
        'reset_hours': 5
    },
    'VIP': {
        'nano_messages': 100000000000000,
        'mini_messages': 1000,
        'fast_messages': 5000,
        'large_messages': 200,
        'can_choose_model': True,
        'max_tokens': 1024,
        'reset_hours': 2
    },
    # Key spelling ("speacil") kept as-is: it is referenced by stored user rows
    # and by the matching plan-badge CSS class in the UI below.
    'Sam-X-1-Mini-release-speacil-plan': {
        'nano_messages': -1,
        'mini_messages': -1,
        'fast_messages': -1,
        'large_messages': -1,
        'can_choose_model': True,
        'max_tokens': 900000,
        'reset_hours': 0.2
    }
}
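
# PLAN_LIMITS is indexed directly elsewhere (PLAN_LIMITS[plan]), which raises
# KeyError for a plan name that is missing from the dict (e.g. a row written by
# an older schema). A minimal, hypothetical safety wrapper could fall back to
# the 'free' tier; this sketch is illustrative and not used by the code below.
def get_plan_limits(plan):
    """Return the limits dict for `plan`, falling back to 'free' if unknown."""
    return PLAN_LIMITS.get(plan, PLAN_LIMITS['free'])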
def get_model_type(model_name):
    """Get model type from model name."""
    if 'Nano' in model_name:
        return 'nano'
    elif 'Mini' in model_name:
        return 'mini'
    elif 'Fast' in model_name:
        return 'fast'
    elif 'Large' in model_name:
        return 'large'
    return 'nano'

# ==============================================================================
# User Management Functions
# ==============================================================================
def hash_password(password):
    return hashlib.sha256(password.encode()).hexdigest()
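
# Unsalted SHA-256 is fast to brute-force and not recommended for passwords.
# Below is a minimal sketch of a salted KDF alternative using only the standard
# library; `hash_password_pbkdf2` / `verify_password_pbkdf2` are hypothetical
# names, and adopting them would require migrating existing stored hashes.
def hash_password_pbkdf2(password, salt=None, iterations=200_000):
    """Return 'salt$hexdigest' using PBKDF2-HMAC-SHA256 (sketch, unused below)."""
    import secrets as _secrets
    salt = salt or _secrets.token_hex(16)
    dk = hashlib.pbkdf2_hmac('sha256', password.encode(), bytes.fromhex(salt), iterations)
    return f"{salt}${dk.hex()}"

def verify_password_pbkdf2(password, stored, iterations=200_000):
    """Check `password` against a 'salt$hexdigest' string from the helper above."""
    salt, _ = stored.split('$', 1)
    return hash_password_pbkdf2(password, salt, iterations) == stored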
def create_user(username, password, email=""):
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("INSERT INTO users (username, password_hash, email, rate_limit_start) VALUES (?, ?, ?, ?)",
                      (username, hash_password(password), email, now))
            db_conn.commit()
            return True, "Account created successfully!"
        except sqlite3.IntegrityError:
            return False, "Username already exists!"

def authenticate_user(username, password):
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT id, password_hash, plan, is_admin FROM users WHERE username = ?", (username,))
        result = c.fetchone()
        if result and result[1] == hash_password(password):
            return True, {"id": result[0], "username": username, "plan": result[2], "is_admin": bool(result[3])}
        return False, None
def check_and_reset_limits(user_id):
    """Reset usage counters if the plan's rolling window has elapsed."""
    with db_lock:
        c = db_conn.cursor()
        c.execute("SELECT rate_limit_start, plan FROM users WHERE id = ?", (user_id,))
        result = c.fetchone()
        if not result:
            return
        rate_limit_start_str, plan = result
        reset_hours = PLAN_LIMITS[plan]['reset_hours']
        if rate_limit_start_str:
            rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
            now = datetime.now(AUSTRALIA_TZ)
            if now - rate_limit_start >= timedelta(hours=reset_hours):
                new_start = now.isoformat()
                c.execute("""UPDATE users
                             SET rate_limit_start = ?,
                                 messages_used_nano = 0,
                                 messages_used_mini = 0,
                                 messages_used_fast = 0,
                                 messages_used_large = 0
                             WHERE id = ?""", (new_start, user_id))
                db_conn.commit()
def get_user_limits_info(user_id):
    """Get user's current usage and limits with reset time."""
    check_and_reset_limits(user_id)
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT plan, rate_limit_start,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large
                     FROM users WHERE id = ?""", (user_id,))
        result = c.fetchone()
        if not result:
            return None
        plan, rate_limit_start_str, nano_used, mini_used, fast_used, large_used = result
        limits = PLAN_LIMITS[plan]
        if rate_limit_start_str:
            rate_limit_start = datetime.fromisoformat(rate_limit_start_str)
            reset_time = rate_limit_start + timedelta(hours=limits['reset_hours'])
            now = datetime.now(AUSTRALIA_TZ)
            time_until_reset = reset_time - now
            hours, remainder = divmod(int(time_until_reset.total_seconds()), 3600)
            minutes, seconds = divmod(remainder, 60)
            reset_str = f"{hours}h {minutes}m"
        else:
            reset_str = "N/A"
        return {
            'plan': plan,
            'nano_used': nano_used,
            'mini_used': mini_used,
            'fast_used': fast_used,
            'large_used': large_used,
            'nano_limit': limits['nano_messages'],
            'mini_limit': limits['mini_messages'],
            'fast_limit': limits['fast_messages'],
            'large_limit': limits['large_messages'],
            'can_choose_model': limits['can_choose_model'],
            'max_tokens': limits['max_tokens'],
            'reset_in': reset_str
        }

def can_use_model(user_id, model_name):
    """Check if user can use a specific model."""
    info = get_user_limits_info(user_id)
    if not info:
        return False, "User not found"
    model_type = get_model_type(model_name)
    used_key = f"{model_type}_used"
    limit_key = f"{model_type}_limit"
    used = info[used_key]
    limit = info[limit_key]
    if limit == -1:
        return True, "OK"
    if used >= limit:
        return False, f"Limit reached for {model_type.upper()} model ({used}/{limit}). Resets in {info['reset_in']}"
    return True, "OK"
def increment_model_usage(user_id, model_name):
    """Increment usage counter for a model."""
    model_type = get_model_type(model_name)
    # Safe to interpolate: `column` is derived from get_model_type's fixed set
    # ('nano'/'mini'/'fast'/'large'), never from user input.
    column = f"messages_used_{model_type}"
    with db_lock:
        c = db_conn.cursor()
        c.execute(f"UPDATE users SET {column} = {column} + 1 WHERE id = ?", (user_id,))
        db_conn.commit()
def get_available_models_for_user(user_id):
    """Get list of models user can currently use."""
    info = get_user_limits_info(user_id)
    if not info:
        return []
    available = []
    for model_type in ['nano', 'mini', 'fast', 'large']:
        used = info[f'{model_type}_used']
        limit = info[f'{model_type}_limit']
        if limit == -1 or used < limit:
            for model_name in available_models.keys():
                if get_model_type(model_name) == model_type:
                    available.append(model_name)
                    break
    return available

def log_usage(user_id, tokens, model):
    with db_lock:
        c = db_conn.cursor()
        c.execute("INSERT INTO usage_logs (user_id, tokens_used, model_used) VALUES (?, ?, ?)",
                  (user_id, tokens, model))
        db_conn.commit()

def request_upgrade(user_id, plan, reason):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("INSERT INTO upgrade_requests (user_id, requested_plan, reason) VALUES (?, ?, ?)",
                      (user_id, plan, reason))
            db_conn.commit()
            return True, "Upgrade request submitted! Admin will review soon."
        except Exception as e:
            return False, f"Error: {str(e)}"

def get_all_users():
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT id, username, email, plan, created_at, is_admin,
                            messages_used_nano, messages_used_mini,
                            messages_used_fast, messages_used_large,
                            rate_limit_start
                     FROM users ORDER BY created_at DESC""")
        return c.fetchall()

def get_pending_requests():
    with db_lock:
        c = db_conn.cursor()
        c.execute("""SELECT r.id, u.username, r.requested_plan, r.reason, r.created_at
                     FROM upgrade_requests r
                     JOIN users u ON r.user_id = u.id
                     WHERE r.status = 'pending'
                     ORDER BY r.created_at DESC""")
        return c.fetchall()

def update_user_plan(username, new_plan):
    with db_lock:
        try:
            c = db_conn.cursor()
            now = datetime.now(AUSTRALIA_TZ).isoformat()
            c.execute("""UPDATE users
                         SET plan = ?,
                             rate_limit_start = ?,
                             messages_used_nano = 0,
                             messages_used_mini = 0,
                             messages_used_fast = 0,
                             messages_used_large = 0
                         WHERE username = ?""", (new_plan, now, username))
            db_conn.commit()
            return True, f"User {username} upgraded to {new_plan}!"
        except Exception as e:
            return False, f"Error: {str(e)}"

def approve_request(request_id):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("SELECT user_id, requested_plan FROM upgrade_requests WHERE id = ?", (request_id,))
            result = c.fetchone()
            if result:
                user_id, plan = result
                now = datetime.now(AUSTRALIA_TZ).isoformat()
                c.execute("""UPDATE users
                             SET plan = ?,
                                 rate_limit_start = ?,
                                 messages_used_nano = 0,
                                 messages_used_mini = 0,
                                 messages_used_fast = 0,
                                 messages_used_large = 0
                             WHERE id = ?""", (plan, now, user_id))
                c.execute("UPDATE upgrade_requests SET status = 'approved' WHERE id = ?", (request_id,))
                db_conn.commit()
                return True, "Request approved!"
            return False, "Request not found"
        except Exception as e:
            return False, f"Error: {str(e)}"

def deny_request(request_id):
    with db_lock:
        try:
            c = db_conn.cursor()
            c.execute("UPDATE upgrade_requests SET status = 'denied' WHERE id = ?", (request_id,))
            db_conn.commit()
            return True, "Request denied"
        except Exception as e:
            return False, f"Error: {str(e)}"
# ==============================================================================
# Model Architecture
# ==============================================================================
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False

    def build(self, input_shape):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(tf.cos(emb), dtype=tf.float32)
            self.sin_cached = tf.constant(tf.sin(emb), dtype=tf.float32)
            self.built_cache = True
        super().build(input_shape)

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k):
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)
        return q_rotated, k_rotated

    def get_config(self):
        config = super().get_config()
        config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
        return config
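
# Rotary embeddings apply a pure rotation to each (x_i, x_{i+dim/2}) pair, so
# they should preserve vector norms along the head dimension. A minimal
# sanity-check sketch (hypothetical helper, not called anywhere in the app):
def _rope_norm_check(dim=8, seq_len=4):
    rope = RotaryEmbedding(dim, max_len=16)
    q = tf.random.normal([1, 1, seq_len, dim])  # [batch, heads, seq, head_dim]
    k = tf.random.normal([1, 1, seq_len, dim])
    q_rot, k_rot = rope(q, k)
    # Norms before and after rotation should match to float32 precision.
    np.testing.assert_allclose(tf.norm(q, axis=-1).numpy(),
                               tf.norm(q_rot, axis=-1).numpy(),
                               rtol=1e-5, atol=1e-6)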
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx
        self.pre_attn_norm = RMSNorm()
        self.pre_ffn_norm = RMSNorm()
        self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
        self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(dropout)

    def call(self, x, training=None):
        B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
        dtype = x.dtype
        res = x
        y = self.pre_attn_norm(x)
        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        q, k = self.rope(q, k)
        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        mask = tf.where(tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
                        tf.constant(-1e9, dtype=dtype), tf.constant(0.0, dtype=dtype))
        scores += mask
        attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
        attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
        x = res + self.dropout(self.out_proj(attn), training=training)
        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training)
    def get_config(self):
        config = super().get_config()
        config.update({"d_model": self.d_model, "n_heads": self.n_heads, "ff_dim": self.ff_dim,
                       "dropout": self.dropout_rate, "max_len": self.max_len,
                       "rope_theta": self.rope_theta, "layer_idx": self.layer_idx})
        return config
class SAM1Model(keras.Model):
    def __init__(self, **kwargs):
        super().__init__()
        if 'config' in kwargs and isinstance(kwargs['config'], dict):
            self.cfg = kwargs['config']
        elif 'vocab_size' in kwargs:
            self.cfg = kwargs
        else:
            self.cfg = kwargs.get('cfg', kwargs)
        self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
        block_args = {'d_model': self.cfg['d_model'], 'n_heads': self.cfg['n_heads'], 'ff_dim': ff_dim,
                      'dropout': self.cfg['dropout'], 'max_len': self.cfg['max_len'],
                      'rope_theta': self.cfg['rope_theta']}
        self.blocks = []
        for i in range(self.cfg['n_layers']):
            block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
            self.blocks.append(block)
        self.norm = RMSNorm(name="final_norm")
        self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")

    def call(self, input_ids, training=None):
        x = self.embed(input_ids)
        for block in self.blocks:
            x = block(x, training=training)
        return self.lm_head(self.norm(x))

    def get_config(self):
        base_config = super().get_config()
        base_config['config'] = self.cfg
        return base_config

def count_parameters(model):
    total_params = 0
    non_zero_params = 0
    for weight in model.weights:
        w = weight.numpy()
        total_params += w.size
        non_zero_params += np.count_nonzero(w)
    return total_params, non_zero_params

def format_param_count(count):
    if count >= 1e9:
        return f"{count/1e9:.2f}B"
    elif count >= 1e6:
        return f"{count/1e6:.2f}M"
    elif count >= 1e3:
        return f"{count/1e3:.2f}K"
    else:
        return str(count)
# ============================================================================
# QUANTIZATION UTILITIES
# ============================================================================
def quantize_model_int8(model):
    """
    Compute symmetric INT8 copies of the model weights (per-tensor scales).
    This shrinks the stored copies to roughly a quarter of float32 size.
    NOTE: KerasBackend below stores these copies but the forward pass still
    runs the original float weights; the INT8 tensors are not wired into
    the matmuls.
    """
    print(" 🔧 Applying INT8 quantization...")
    quantized_weights = []
    scales = []
    for weight in model.weights:
        w = weight.numpy()
        # Calculate per-tensor scale factor
        w_max = np.abs(w).max()
        if w_max > 0:
            scale = w_max / 127.0
            # Quantize to int8
            w_quantized = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
        else:
            scale = 1.0
            w_quantized = w.astype(np.int8)
        quantized_weights.append(w_quantized)
        scales.append(scale)
    print(" ✅ INT8 copies computed (≈75% smaller than float32)")
    return quantized_weights, scales
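
# The scheme above is symmetric per-tensor quantization: w ≈ scale * w_int8.
# A minimal round-trip sketch (hypothetical helper, unused by the app) showing
# how the stored copies would be restored to float32:
def dequantize_int8(quantized_weights, scales):
    """Invert quantize_model_int8: returns float32 arrays ≈ the original weights."""
    return [w.astype(np.float32) * s for w, s in zip(quantized_weights, scales)]
# Worst-case reconstruction error per tensor is half a quantization step,
# i.e. scale / 2 = |w|_max / 254.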
class ModelBackend(ABC):
    @abstractmethod
    def predict(self, input_ids):
        pass

    @abstractmethod
    def get_name(self):
        pass

    @abstractmethod
    def get_info(self):
        pass
# ============================================================================
# OPTIMIZED KERAS BACKEND WITH QUANTIZATION
# ============================================================================
class KerasBackend(ModelBackend):
    def __init__(self, model, name, display_name, use_quantization=True):
        self.model = model
        self.name = name
        self.display_name = display_name
        self.use_quantization = use_quantization
        # Store INT8 copies of the weights (memory savings only; inference
        # below still runs the live float Keras weights).
        if use_quantization:
            self.quantized_weights, self.scales = quantize_model_int8(model)

        # Trace the forward pass once so later calls hit the compiled graph;
        # reduce_retracing relaxes input shapes so growing sequence lengths
        # do not retrace on every step.
        @tf.function(reduce_retracing=True)
        def fast_predict(inputs):
            with tf.device('/CPU:0'):
                return model(inputs, training=False)
        self.fast_predict = fast_predict

        print(f" 🔥 Warming up {display_name}...")
        dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
        _ = self.fast_predict(dummy)
        print(f" ✅ Compilation complete!")
        total, non_zero = count_parameters(model)
        self.total_params = total
        self.non_zero_params = non_zero
        self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
        self.n_heads = model.cfg.get('n_heads', 0)
        self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
    def predict(self, input_ids):
        inputs = tf.constant([input_ids], dtype=tf.int32)
        logits = self.fast_predict(inputs)
        # Return logits for the last position only (next-token distribution).
        return logits[0, -1, :].numpy()

    def get_name(self):
        return self.display_name

    def get_info(self):
        info = f"{self.display_name}\n"
        info += f" Total params: {format_param_count(self.total_params)}\n"
        info += f" Attention heads: {self.n_heads}\n"
        info += f" FFN dimension: {self.ff_dim}\n"
        if self.use_quantization:
            info += f" Quantization: INT8 copies stored ⚡\n"
        if self.sparsity > 1:
            info += f" Sparsity: {self.sparsity:.1f}%\n"
        return info
MODEL_REGISTRY = [
    ("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
    ("SAM-X-1-Fast ⚡ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast_finetuned.weights.h5", "sam1_fast_finetuned_config.json"),
    ("SAM-X-1-Mini 🚀 (ADVANCED!)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini_finetuned.weights.h5", "sam1_mini_finetuned_config.json"),
    ("SAM-X-1-Nano ⚡⚡", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
]

def estimate_prompt_complexity(prompt):
    prompt_lower = prompt.lower()
    complexity_score = 0
    word_count = len(prompt.split())
    if word_count > 100:
        complexity_score += 3
    elif word_count > 50:
        complexity_score += 2
    elif word_count > 20:
        complexity_score += 1
    hard_keywords = ['analyze', 'explain', 'compare', 'evaluate', 'prove', 'derive', 'calculate', 'solve',
                     'reason', 'why', 'how does', 'complex', 'algorithm', 'mathematics', 'philosophy',
                     'theory', 'logic', 'detailed', 'comprehensive', 'thorough', 'in-depth']
    for keyword in hard_keywords:
        if keyword in prompt_lower:
            complexity_score += 2
    medium_keywords = ['write', 'create', 'generate', 'summarize', 'describe', 'list', 'what is',
                       'tell me', 'explain briefly']
    for keyword in medium_keywords:
        if keyword in prompt_lower:
            complexity_score += 1
    if any(word in prompt_lower for word in ['code', 'function', 'program', 'debug', 'implement']):
        complexity_score += 2
    if any(word in prompt_lower for word in ['first', 'then', 'next', 'finally', 'step']):
        complexity_score += 1
    question_marks = prompt.count('?')
    if question_marks > 1:
        complexity_score += 1
    return complexity_score
def select_model_auto(prompt, available_models_dict, user_available_models):
    complexity = estimate_prompt_complexity(prompt)
    accessible = {k: v for k, v in available_models_dict.items() if k in user_available_models}
    if not accessible:
        return None
    if complexity <= 2:
        preferred = "SAM-X-1-Nano ⚡⚡"
        fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
    elif complexity <= 5:
        preferred = "SAM-X-1-Mini 🚀 (ADVANCED!)"
        fallback_order = ["SAM-X-1-Nano ⚡⚡", "SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Large"]
    elif complexity <= 8:
        preferred = "SAM-X-1-Fast ⚡ (BETA)"
        fallback_order = ["SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Large", "SAM-X-1-Nano ⚡⚡"]
    else:
        preferred = "SAM-X-1-Large"
        fallback_order = ["SAM-X-1-Fast ⚡ (BETA)", "SAM-X-1-Mini 🚀 (ADVANCED!)", "SAM-X-1-Nano ⚡⚡"]
    if preferred in accessible:
        return accessible[preferred]
    for model_name in fallback_order:
        if model_name in accessible:
            return accessible[model_name]
    return list(accessible.values())[0]
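
# Routing at a glance (illustrative, derived directly from the two functions
# above): complexity <= 2 routes to Nano, <= 5 to Mini, <= 8 to Fast, else
# Large. A short prompt like "hi there" scores 0 and stays on Nano, while a
# prompt containing hard keywords such as 'explain' and 'algorithm' picks up
# +2 for each match and climbs toward the larger models.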
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"

print("="*80)
print("🤖 SAM-X-1 Multi-Model Chat Interface".center(80))
print("="*80)
print(f"\n📦 Downloading config/tokenizer from: {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
with open(config_path, 'r') as f:
    base_config = json.load(f)
print(f"✅ Base config loaded")
base_model_config = {'vocab_size': base_config['vocab_size'],
                     'd_model': base_config['hidden_size'],
                     'n_heads': base_config['num_attention_heads'],
                     'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
                     'dropout': base_config.get('dropout', 0.0),
                     'max_len': base_config['max_position_embeddings'],
                     'rope_theta': base_config['rope_theta'],
                     'n_layers': base_config['num_hidden_layers']}

# The tokenizer is rebuilt from the GPT-2 vocabulary rather than loaded from
# the downloaded tokenizer.json; the special tokens are re-added below.
print("\n🔤 Recreating tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
eos_token = "<|endoftext|>"
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
    tokenizer.add_special_tokens([eos_token])
    eos_token_id = tokenizer.token_to_id(eos_token)
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
    if tokenizer.token_to_id(token) is None:
        tokenizer.add_special_tokens([token])
tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print(f"✅ Tokenizer ready (vocab size: {tokenizer.get_vocab_size()})")
print(f"   EOS token: '{eos_token}' (ID: {eos_token_id})")
if eos_token_id is None:
    raise ValueError("❌ Failed to set EOS token ID!")

print("\n" + "="*80)
print("📦 LOADING MODELS".center(80))
print("="*80)
available_models = {}
dummy_input = tf.zeros((1, 1), dtype=tf.int32)

# Mixed precision halves activation memory; on CPU the speed benefit varies
# with the instruction set, so treat this primarily as a memory optimization.
print("\n⚡ Enabling mixed precision...")
tf.keras.mixed_precision.set_global_policy('mixed_float16')
for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
    try:
        print(f"\n⏳ Loading: {display_name}")
        print(f"   Repo: {repo_id}")
        print(f"   Weights: {weights_filename}")
        weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
        if config_filename:
            print(f"   Config: {config_filename}")
            custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
            with open(custom_config_path, 'r') as f:
                model_config = json.load(f)
            print(f"   📐 Custom architecture: {model_config['n_heads']} heads")
        else:
            model_config = base_model_config.copy()
        model = SAM1Model(**model_config)
        model(dummy_input)  # build the weights before loading
        model.load_weights(weights_path)
        model.trainable = False
        # Backend stores INT8 weight copies and a traced forward pass
        backend = KerasBackend(model, display_name, display_name, use_quantization=True)
        available_models[display_name] = backend
        print(f"   ✅ Loaded successfully!")
        print(f"   📊 Parameters: {format_param_count(backend.total_params)}")
        print(f"   ⚡ INT8 weight copies stored")
    except Exception as e:
        print(f"   ⚠️ Failed to load: {e}")

if not available_models:
    raise RuntimeError("❌ No models loaded!")

print(f"\n✅ Successfully loaded {len(available_models)} model(s)")
current_backend = list(available_models.values())[0]
stop_generation = threading.Event()

# ============================================================================
# OPTIMIZED GENERATION LOOP
# ============================================================================
def generate_response_stream(prompt, temperature=0.7, backend=None, max_tokens=256):
    """Yield (chunk, in_thinking, tok/s, tok/s, stopped) tuples as tokens are sampled."""
    global stop_generation
    stop_generation.clear()
    if backend is None:
        backend = current_backend
    encoded_prompt = tokenizer.encode(prompt)
    input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
    generated = input_ids.copy()
    current_text = ""
    in_thinking = False
    max_len = backend.model.cfg['max_len']
    start_time = time.time()
    tokens_generated = 0
    # OPTIMIZATION 1: Decode in batches of 15 tokens to cut tokenizer overhead
    decode_buffer = []
    decode_every = 15
    # OPTIMIZATION 2: Use a smaller context window for faster inference
    context_window = min(512, max_len)  # 512 is plenty for most cases
    # OPTIMIZATION 3: Pre-compute sampling parameters
    top_k = 5
    for step in range(max_tokens):
        if stop_generation.is_set():
            elapsed = time.time() - start_time
            final_speed = tokens_generated / elapsed if elapsed > 0 else 0
            yield "", False, -1, final_speed, True
            return
        # OPTIMIZATION 4: Only feed the most recent context (big speedup)
        current_input = generated[-context_window:]
        # Get next-token logits
        next_token_logits = backend.predict(current_input)
        # OPTIMIZATION 5: Fast sampling with numpy
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
            # Top-k selection without a full sort
            top_k_indices = np.argpartition(next_token_logits, -top_k)[-top_k:]
            top_k_logits = next_token_logits[top_k_indices]
            # Numerically stable softmax
            max_logit = np.max(top_k_logits)
            exp_logits = np.exp(top_k_logits - max_logit)
            probs = exp_logits / exp_logits.sum()
            next_token = top_k_indices[np.random.choice(top_k, p=probs)]
        else:
            next_token = np.argmax(next_token_logits)
        if next_token == eos_token_id:
            break
        generated.append(int(next_token))
        decode_buffer.append(int(next_token))
        tokens_generated += 1
        # OPTIMIZATION 6: Decode only when necessary
        should_decode = (
            len(decode_buffer) >= decode_every or
            step == max_tokens - 1 or
            step % 30 == 0  # force a UI update every 30 tokens
        )
        if should_decode:
            new_text = tokenizer.decode(generated[len(input_ids):])
            if len(new_text) > len(current_text):
                new_chunk = new_text[len(current_text):]
                current_text = new_text
                if "<think>" in new_chunk:
                    in_thinking = True
                elif "</think>" in new_chunk or "<think/>" in new_chunk:
                    in_thinking = False
                elapsed = time.time() - start_time
                tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
                yield new_chunk, in_thinking, tokens_per_sec, tokens_per_sec, False
            decode_buffer = []
    # Final decode
    final_text = tokenizer.decode(generated[len(input_ids):])
    if len(final_text) > len(current_text):
        final_chunk = final_text[len(current_text):]
        elapsed = time.time() - start_time
        final_tokens_per_sec = tokens_generated / elapsed if elapsed > 0 else 0
        yield final_chunk, False, final_tokens_per_sec, final_tokens_per_sec, False
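
# The sampling step above is standard temperature + top-k: keep the k largest
# logits, softmax over them, draw one index. A self-contained numpy sketch of
# the same math (hypothetical helper, not used by the streaming loop):
def top_k_sample(logits, temperature=0.7, top_k=5, rng=np.random):
    """Sample a token id from a 1-D `logits` array using temperature + top-k."""
    if temperature <= 0:
        return int(np.argmax(logits))          # greedy decoding
    scaled = logits / temperature
    idx = np.argpartition(scaled, -top_k)[-top_k:]   # unordered top-k ids
    z = scaled[idx] - scaled[idx].max()              # numerically stable softmax
    p = np.exp(z) / np.exp(z).sum()
    return int(idx[rng.choice(top_k, p=p)])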
# ==============================================================================
# Production-Grade Multi-Page UI
# ==============================================================================
import secrets

# Global session storage
active_sessions = {}
session_lock = threading.Lock()

def generate_session_code():
    with session_lock:
        while True:
            code = ''.join([str(secrets.randbelow(10)) for _ in range(4)])
            if code not in active_sessions:
                return code

def create_session(user_data):
    code = generate_session_code()
    with session_lock:
        normalized_data = {
            'user_id': user_data.get('id') or user_data.get('user_id'),
            'username': user_data.get('username'),
            'plan': user_data.get('plan'),
            'is_admin': user_data.get('is_admin', False)
        }
        active_sessions[code] = normalized_data
    return code

def validate_session(code):
    with session_lock:
        return active_sessions.get(code, None)

def invalidate_session(code):
    with session_lock:
        if code in active_sessions:
            del active_sessions[code]
            return True
        return False
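
# A 4-digit code gives only 10,000 possible sessions, which is fine for a demo
# but guessable in production. A minimal sketch of a higher-entropy variant
# (hypothetical drop-in with the same uniqueness check, unused below):
def generate_session_token(nbytes=16):
    """Return a URL-safe session token with roughly 128 bits of entropy."""
    with session_lock:
        while True:
            token = secrets.token_urlsafe(nbytes)
            if token not in active_sessions:
                return token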
if __name__ == "__main__":
    import gradio as gr

    custom_css = """
/* Modern Production-Grade Styling */
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');
* { font-family: 'Inter', sans-serif; }
.app-container { max-width: 1600px; margin: 0 auto; }
/* Dark Mode Support */
.dark-mode { background: #1a1a1a; color: #e5e5e5; }
.dark-mode .nav-bar { background: linear-gradient(135deg, #4a5568 0%, #2d3748 100%); }
.dark-mode .chat-container { background: #2d3748; border-color: #4a5568; }
.dark-mode .assistant-message { background: #374151; border-color: #10a37f; }
/* Navigation Bar */
.nav-bar {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 14px 28px;
    border-radius: 12px;
    margin-bottom: 20px;
    display: flex;
    justify-content: space-between;
    align-items: center;
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
    position: sticky;
    top: 0;
    z-index: 100;
}
.nav-left { display: flex; align-items: center; gap: 20px; }
.nav-brand {
    font-size: 22px;
    font-weight: 700;
    color: white;
    display: flex;
    align-items: center;
    gap: 8px;
}
.nav-right { display: flex; align-items: center; gap: 12px; }
.user-greeting {
    color: white;
    font-weight: 500;
    font-size: 14px;
    display: flex;
    align-items: center;
    gap: 8px;
    padding: 6px 12px;
    background: rgba(255,255,255,0.15);
    border-radius: 20px;
}
/* Plan Badge */
.plan-badge {
    display: inline-block;
    padding: 4px 10px;
    border-radius: 12px;
    font-size: 10px;
    font-weight: 700;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    animation: badge-glow 2s ease-in-out infinite;
}
@keyframes badge-glow {
    0%, 100% { box-shadow: 0 0 5px rgba(255,255,255,0.3); }
    50% { box-shadow: 0 0 15px rgba(255,255,255,0.5); }
}
.plan-free { background: linear-gradient(135deg, #e0e7ff 0%, #c7d2fe 100%); color: #3730a3; }
.plan-plus { background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%); color: #1e40af; }
.plan-pro { background: linear-gradient(135deg, #fef3c7 0%, #fde68a 100%); color: #92400e; }
.plan-explore { background: linear-gradient(135deg, #d8b4fe 0%, #c4b5fd 100%); color: #7e22ce; }
.plan-research { background: linear-gradient(135deg, #a5f3fc 0%, #67e8f9 100%); color: #0e7490; }
.plan-vip { background: linear-gradient(135deg, #fbbf24 0%, #f59e0b 100%); color: #78350f; }
/* Lowercased to match the badge renderer, which lowercases plan names */
.plan-sam-x-1-mini-release-speacil-plan { background: linear-gradient(135deg, #fbbf24 0%, #f59e0b 100%); color: #78350f; }
/* Auth Page */
.auth-container {
    max-width: 440px;
    margin: 40px auto;
    background: white;
    padding: 36px;
    border-radius: 16px;
    box-shadow: 0 10px 40px rgba(0,0,0,0.1);
    animation: slideUp 0.4s ease-out;
}
@keyframes slideUp {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
.auth-title {
    font-size: 28px;
    font-weight: 700;
    text-align: center;
    margin-bottom: 6px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.auth-subtitle {
    text-align: center;
    color: #6b7280;
    margin-bottom: 28px;
    font-size: 13px;
}
/* Chat Interface */
.chat-layout { display: flex; gap: 20px; }
.chat-main { flex: 1; min-width: 0; }
.chat-sidebar { width: 320px; flex-shrink: 0; }
.chat-container {
    height: 520px;
    overflow-y: auto;
    padding: 20px;
    background: #f9fafb;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    margin-bottom: 12px;
    scroll-behavior: smooth;
}
.chat-container::-webkit-scrollbar { width: 6px; }
.chat-container::-webkit-scrollbar-track { background: transparent; }
.chat-container::-webkit-scrollbar-thumb { background: #cbd5e1; border-radius: 3px; }
.chat-container::-webkit-scrollbar-thumb:hover { background: #94a3b8; }
.user-message {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 14px 18px;
    margin: 10px 0;
    border-radius: 16px 16px 4px 16px;
    max-width: 75%;
    margin-left: auto;
    box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
    animation: messageSlideIn 0.3s ease-out;
}
.assistant-message {
    background: white;
    padding: 14px 18px;
    margin: 10px 0;
    border-radius: 16px 16px 16px 4px;
    border-left: 3px solid #10a37f;
    max-width: 75%;
    box-shadow: 0 2px 8px rgba(0,0,0,0.06);
    animation: messageSlideIn 0.3s ease-out;
    position: relative;
}
@keyframes messageSlideIn {
    from { opacity: 0; transform: translateY(10px); }
    to { opacity: 1; transform: translateY(0); }
}
.message-content {
    color: #353740;
    line-height: 1.6;
    font-size: 14px;
    word-wrap: break-word;
}
.user-message .message-content { color: white; }
/* Markdown Styling */
.message-content code {
    background: #f3f4f6;
    padding: 2px 6px;
    border-radius: 4px;
    font-family: 'Courier New', monospace;
    font-size: 13px;
}
.message-content pre {
    background: #1f2937;
    color: #e5e7eb;
    padding: 12px;
    border-radius: 8px;
    overflow-x: auto;
    margin: 8px 0;
    position: relative;
}
.message-content pre code {
    background: transparent;
    padding: 0;
    color: inherit;
}
.message-content ul, .message-content ol {
    margin: 8px 0;
    padding-left: 20px;
}
.message-content li { margin: 4px 0; }
.message-content strong { font-weight: 600; }
.message-content em { font-style: italic; }
.message-content a { color: #667eea; text-decoration: underline; }
/* Code Copy Button */
.copy-button {
    position: absolute;
    top: 8px;
    right: 8px;
    background: rgba(255,255,255,0.1);
    border: 1px solid rgba(255,255,255,0.2);
    color: white;
    padding: 4px 8px;
    border-radius: 4px;
    font-size: 11px;
    cursor: pointer;
    opacity: 0;
    transition: all 0.2s;
}
.assistant-message:hover .copy-button { opacity: 1; }
.copy-button:hover { background: rgba(255,255,255,0.2); }
.thinking-content {
    color: #6b7280;
    font-style: italic;
    border-left: 3px solid #d1d5db;
    padding-left: 12px;
    margin: 10px 0;
    background: #f9fafb;
    padding: 10px 12px;
    border-radius: 6px;
    font-size: 13px;
}
/* Message Actions */
.message-actions {
    display: flex;
    gap: 8px;
    margin-top: 8px;
    opacity: 0;
    transition: opacity 0.2s;
}
.assistant-message:hover .message-actions { opacity: 1; }
.action-btn {
    background: #f3f4f6;
    border: 1px solid #e5e7eb;
    padding: 4px 10px;
    border-radius: 6px;
    font-size: 12px;
    cursor: pointer;
    transition: all 0.2s;
    color: #6b7280;
}
.action-btn:hover {
    background: #e5e7eb;
    color: #374151;
    transform: translateY(-1px);
}
/* Input Area */
.input-container {
    background: white;
    border: 2px solid #e5e7eb;
    border-radius: 12px;
    padding: 4px;
    transition: all 0.2s;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.input-container:focus-within {
    border-color: #667eea;
    box-shadow: 0 4px 16px rgba(102, 126, 234, 0.2);
}
.input-row {
    display: flex;
    gap: 8px;
    align-items: flex-end;
}
.circular-btn {
    width: 46px !important;
    height: 46px !important;
    min-width: 46px !important;
    border-radius: 50% !important;
    padding: 0 !important;
    font-size: 20px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.15) !important;
    transition: all 0.2s ease !important;
    border: none !important;
}
.circular-btn:hover:not(:disabled) {
    transform: scale(1.08) !important;
    box-shadow: 0 6px 16px rgba(0,0,0,0.25) !important;
}
.circular-btn:active:not(:disabled) {
    transform: scale(0.95) !important;
}
.send-btn {
    background: linear-gradient(135deg, #10a37f 0%, #0d8c6c 100%) !important;
}
.stop-btn {
    background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%) !important;
}
/* Token Counter */
.token-counter {
    font-size: 11px;
    color: #9ca3af;
    text-align: right;
    padding: 4px 8px;
}
/* Sidebar/Limits Panel */
.limits-panel {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    padding: 18px;
    margin-bottom: 14px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.05);
}
.limit-header {
    font-weight: 700;
    margin-bottom: 14px;
    font-size: 16px;
    color: #1f2937;
    display: flex;
    justify-content: space-between;
    align-items: center;
}
.limit-item {
    display: flex;
    justify-content: space-between;
    padding: 10px 0;
    border-bottom: 1px solid #f3f4f6;
    align-items: center;
}
.limit-item:last-child { border-bottom: none; }
.limit-label {
    font-size: 13px;
    color: #6b7280;
    font-weight: 500;
}
.limit-value {
    font-size: 13px;
    font-weight: 600;
}
.limit-exceeded { color: #dc2626; }
.limit-ok { color: #059669; }
.limit-warning { color: #f59e0b; }
/* Progress Bar */
.progress-bar {
    height: 6px;
    background: #f3f4f6;
    border-radius: 3px;
    overflow: hidden;
    margin-top: 6px;
}
.progress-fill {
    height: 100%;
    background: linear-gradient(90deg, #10a37f 0%, #059669 100%);
    transition: width 0.3s ease;
    border-radius: 3px;
}
.progress-fill.warning { background: linear-gradient(90deg, #f59e0b 0%, #ea580c 100%); }
.progress-fill.danger { background: linear-gradient(90deg, #ef4444 0%, #dc2626 100%); }
/* Plans Section */
.plans-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(260px, 1fr));
    gap: 20px;
    margin-top: 20px;
}
.plan-card {
    background: white;
    border: 2px solid #e5e7eb;
    border-radius: 14px;
    padding: 24px;
    transition: all 0.3s;
    position: relative;
    overflow: hidden;
}
.plan-card:hover {
    transform: translateY(-6px);
    box-shadow: 0 12px 28px rgba(0,0,0,0.15);
    border-color: #667eea;
}
.plan-card.featured {
    border: 3px solid #667eea;
    box-shadow: 0 8px 24px rgba(102, 126, 234, 0.25);
    transform: scale(1.02);
}
.plan-card.featured::before {
    content: '⭐ POPULAR';
    position: absolute;
    top: 14px;
    right: -28px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 3px 36px;
    font-size: 10px;
    font-weight: 700;
    letter-spacing: 1px;
    transform: rotate(45deg);
}
.plan-name {
    font-size: 22px;
    font-weight: 700;
    margin-bottom: 6px;
    color: #1f2937;
}
.plan-price {
    font-size: 13px;
    color: #6b7280;
    margin-bottom: 18px;
}
.plan-features {
    list-style: none;
    padding: 0;
    margin: 16px 0;
}
.plan-features li {
    padding: 6px 0;
    color: #4b5563;
    font-size: 13px;
}
.plan-features li::before {
    content: '✓ ';
    color: #10a37f;
    font-weight: 700;
    margin-right: 6px;
}
/* Speed Indicator */
.speed-indicator {
    text-align: center;
    padding: 10px;
    background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%);
    border-radius: 8px;
    font-weight: 600;
    color: #166534;
    margin-bottom: 10px;
    font-size: 13px;
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 8px;
}
.speed-indicator.generating {
    background: linear-gradient(135deg, #dbeafe 0%, #bfdbfe 100%);
    color: #1e40af;
    animation: pulse 2s ease-in-out infinite;
}
@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.8; }
}
.speed-indicator.error {
    background: linear-gradient(135deg, #fee2e2 0%, #fecaca 100%);
    color: #991b1b;
}
/* Toast Notifications */
.toast {
    position: fixed;
    top: 80px;
    right: 20px;
    background: white;
    padding: 12px 18px;
    border-radius: 8px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
    display: flex;
    align-items: center;
    gap: 10px;
    z-index: 1000;
    animation: toastSlide 0.3s ease-out;
}
@keyframes toastSlide {
    from { transform: translateX(400px); opacity: 0; }
    to { transform: translateX(0); opacity: 1; }
}
.toast.success { border-left: 4px solid #10a37f; }
.toast.error { border-left: 4px solid #ef4444; }
.toast.info { border-left: 4px solid #3b82f6; }
/* Settings Panel */
.settings-panel {
    background: white;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    padding: 16px;
    margin-bottom: 14px;
}
/* Keyboard Shortcuts */
.kbd {
    display: inline-block;
    padding: 2px 6px;
    background: #f3f4f6;
    border: 1px solid #d1d5db;
    border-radius: 4px;
    font-family: monospace;
    font-size: 11px;
    color: #4b5563;
}
/* Empty State */
.empty-state {
    text-align: center;
    padding: 60px 20px;
    color: #9ca3af;
}
.empty-state-icon { font-size: 48px; margin-bottom: 16px; opacity: 0.5; }
.empty-state-title { font-size: 18px; font-weight: 600; color: #6b7280; margin-bottom: 8px; }
.empty-state-subtitle { font-size: 14px; color: #9ca3af; }
/* Skeleton Loader */
.skeleton {
    background: linear-gradient(90deg, #f3f4f6 25%, #e5e7eb 50%, #f3f4f6 75%);
    background-size: 200% 100%;
    animation: shimmer 1.5s infinite;
    border-radius: 4px;
}
@keyframes shimmer {
    0% { background-position: 200% 0; }
    100% { background-position: -200% 0; }
}
/* Responsive */
@media (max-width: 1024px) {
    .chat-layout { flex-direction: column; }
    .chat-sidebar { width: 100%; }
}
@media (max-width: 768px) {
    .nav-bar { flex-direction: column; gap: 12px; padding: 12px 16px; }
    .nav-left, .nav-right { width: 100%; justify-content: center; }
    .chat-container { height: 400px; }
    .plans-grid { grid-template-columns: 1fr; }
    .user-message, .assistant-message { max-width: 90%; }
}
/* Smooth Transitions */
* { transition: background-color 0.2s, border-color 0.2s, color 0.2s; }
button { transition: all 0.2s !important; }
"""
    # Greeting variations
    def get_greeting(username):
        import random
        greetings = [
            f"Hey {username}! 👋",
            f"Welcome back, {username}! ✨",
            f"Hi {username}! 🚀",
            f"Hello {username}! 😊",
            f"Great to see you, {username}! 🎉",
            f"What's up, {username}? 💫",
            f"Howdy, {username}! 🤠",
            f"Yo {username}! 🔥"
        ]
        return random.choice(greetings)
    # Markdown rendering (simple version)
    def render_markdown(text):
        """Simple markdown rendering for common patterns"""
        import re
        # Code blocks
        text = re.sub(r'```(\w+)?\n(.*?)```',
                      r'<pre><code class="\1">\2</code><button class="copy-button" onclick="copyCode(this)">Copy</button></pre>',
                      text, flags=re.DOTALL)
        # Inline code
        text = re.sub(r'`([^`]+)`', r'<code>\1</code>', text)
        # Bold
        text = re.sub(r'\*\*(.+?)\*\*', r'<strong>\1</strong>', text)
        text = re.sub(r'__(.+?)__', r'<strong>\1</strong>', text)
        # Italic
        text = re.sub(r'\*(.+?)\*', r'<em>\1</em>', text)
        text = re.sub(r'_(.+?)_', r'<em>\1</em>', text)
        # Links
        text = re.sub(r'\[([^\]]+)\]\(([^\)]+)\)', r'<a href="\2" target="_blank">\1</a>', text)
        # Lists
        text = re.sub(r'^- (.+)$', r'<li>\1</li>', text, flags=re.MULTILINE)
        text = re.sub(r'^(\d+)\. (.+)$', r'<li>\2</li>', text, flags=re.MULTILINE)
        text = re.sub(r'(<li>.*?</li>\n?)+', r'<ul>\g<0></ul>', text, flags=re.DOTALL)
        # Line breaks
        text = text.replace('\n', '<br>')
        return text
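
    # A quick illustrative check of the renderer above (hypothetical helper,
    # not wired into the UI; both expected strings follow directly from the
    # regex substitutions):
    def _markdown_smoke_test():
        assert render_markdown("**hi** `x`") == "<strong>hi</strong> <code>x</code>"
        assert render_markdown("a\nb") == "a<br>b"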
    # Format message HTML with markdown and actions
    def format_message_html(role, content, show_thinking=True, message_id=None):
        role_class = "user-message" if role == "user" else "assistant-message"
        thinking = ""
        answer = ""
        # Extract thinking
        if "<think>" in content:
            parts = content.split("<think>", 1)
            before_think = parts[0].strip()
            if len(parts) > 1:
                after_think = parts[1]
                if "</think>" in after_think:
                    think_parts = after_think.split("</think>", 1)
                    thinking = think_parts[0].strip()
                    answer = (before_think + " " + think_parts[1]).strip()
                elif "<think/>" in after_think:
                    think_parts = after_think.split("<think/>", 1)
                    thinking = think_parts[0].strip()
                    answer = (before_think + " " + think_parts[1]).strip()
                else:
                    thinking = after_think.strip()
                    answer = before_think
            else:
                answer = before_think
        else:
            answer = content
        # Render markdown
        answer = render_markdown(answer)
        html = f'<div class="{role_class}" id="msg-{message_id}"><div class="message-content">'
        if thinking and show_thinking:
            html += f'<div class="thinking-content">💭 {render_markdown(thinking)}</div>'
        if answer:
            html += f'<div>{answer}</div>'
        # Add message actions for assistant messages
        if role == "assistant":
            html += '''
            <div class="message-actions">
                <button class="action-btn" onclick="copyMessage(this)">📋 Copy</button>
                <button class="action-btn" onclick="regenerateResponse(this)">🔄 Regenerate</button>
            </div>
            '''
        html += '</div></div>'
        return html
    def render_history(history, show_thinking):
        if not history:
            return '''
            <div class="empty-state">
                <div class="empty-state-icon">💬</div>
                <div class="empty-state-title">No messages yet</div>
                <div class="empty-state-subtitle">Start a conversation by typing below</div>
            </div>
            '''
        html = ""
        for idx, msg in enumerate(history):
            html += format_message_html(msg["role"], msg["content"], show_thinking, idx)
        return html
    def render_limits_panel(user_data):
        if not user_data or 'user_id' not in user_data:
            return ""
        info = get_user_limits_info(user_data['user_id'])
        if not info:
            return ""
        # Plan names are lowercased here, so the CSS badge classes above must
        # also be lowercase.
        plan_class = f"plan-{info['plan'].lower()}"
        html = f'''<div class="limits-panel">
        <div class="limit-header">
            <span>Usage Limits <span class="plan-badge {plan_class}">{info["plan"]}</span></span>
        </div>
        <div style="font-size: 12px; color: #6b7280; margin-bottom: 14px; padding: 8px; background: #f9fafb; border-radius: 6px; text-align: center;">
            ⏰ <strong>{info["reset_in"]}</strong> until reset
        </div>'''
        models_info = [
            ('Nano ⚡', info['nano_used'], info['nano_limit']),
            ('Mini 🚀', info['mini_used'], info['mini_limit']),
            ('Fast ⚡', info['fast_used'], info['fast_limit']),
            ('Large 💎', info['large_used'], info['large_limit'])
        ]
        for model_name, used, limit in models_info:
            if limit == -1:
                percentage = 0
                status_class = "limit-ok"
                status_text = f'{used} / ∞'
                bar_class = ""
            else:
                percentage = min((used / limit * 100), 100)
                remaining = limit - used
                if remaining <= 0:
                    status_class = "limit-exceeded"
                    status_text = f'{used}/{limit}'
                    bar_class = "danger"
                elif remaining <= 2:
                    status_class = "limit-warning"
                    status_text = f'{used}/{limit}'
                    bar_class = "warning"
                else:
                    status_class = "limit-ok"
                    status_text = f'{used}/{limit}'
                    bar_class = ""
            html += f'''
            <div class="limit-item">
                <span class="limit-label">{model_name}</span>
                <span class="limit-value {status_class}">{status_text}</span>
            </div>
            <div class="progress-bar">
                <div class="progress-fill {bar_class}" style="width: {percentage}%"></div>
            </div>
            '''
        html += '</div>'
        return html
| with gr.Blocks(css=custom_css, title="SAM-X-1 AI Chat", theme=gr.themes.Soft(primary_hue="slate")) as demo: | |
| # JavaScript for interactive features | |
| gr.HTML(""" | |
| <script> | |
| function copyCode(button) { | |
| const pre = button.parentElement; | |
| const code = pre.querySelector('code').textContent; | |
| navigator.clipboard.writeText(code).then(() => { | |
| button.textContent = 'Copied!'; | |
| setTimeout(() => button.textContent = 'Copy', 2000); | |
| }); | |
| } | |
| function copyMessage(button) { | |
| const messageDiv = button.closest('.assistant-message'); | |
| const content = messageDiv.querySelector('.message-content').textContent; | |
| navigator.clipboard.writeText(content).then(() => { | |
| showToast('Message copied!', 'success'); | |
| }); | |
| } | |
| function regenerateResponse(button) { | |
| showToast('Regeneration feature coming soon!', 'info'); | |
| } | |
| function showToast(message, type = 'info') { | |
| const toast = document.createElement('div'); | |
| toast.className = `toast ${type}`; | |
| toast.innerHTML = ` | |
| <span>${type === 'success' ? '✓' : type === 'error' ? '✗' : 'ℹ'}</span> | |
| <span>${message}</span> | |
| `; | |
| document.body.appendChild(toast); | |
| setTimeout(() => toast.remove(), 3000); | |
| } | |
| // Keyboard shortcuts | |
| document.addEventListener('keydown', function(e) { | |
| // Ctrl/Cmd + K for search (future feature) | |
| if ((e.ctrlKey || e.metaKey) && e.key === 'k') { | |
| e.preventDefault(); | |
| showToast('Search coming soon!', 'info'); | |
| } | |
| // Esc to stop generation | |
| if (e.key === 'Escape') { | |
| const stopBtn = document.querySelector('.stop-btn'); | |
| if (stopBtn && !stopBtn.disabled) stopBtn.click(); | |
| } | |
| }); | |
| // Auto-scroll chat to bottom | |
| function scrollChatToBottom() { | |
| const chatContainer = document.querySelector('.chat-container'); | |
| if (chatContainer) { | |
| chatContainer.scrollTop = chatContainer.scrollHeight; | |
| } | |
| } | |
| // Poll every 500ms to keep the chat pinned to the bottom | |
| setInterval(scrollChatToBottom, 500); | |
| </script> | |
| """) | |
| # State management | |
| session_code = gr.State("") | |
| user_data = gr.State(None) | |
| chat_history = gr.State([]) | |
| # Navigation Bar | |
| with gr.Row(elem_classes="nav-bar"): | |
| with gr.Column(scale=1, elem_classes="nav-left"): | |
| gr.HTML('<div class="nav-brand">🤖 SAM-X-1 <span style="font-size: 12px; opacity: 0.8; font-weight: 400;">v3.0</span></div>') | |
| with gr.Column(scale=2, elem_classes="nav-right"): | |
| user_greeting = gr.HTML('<div class="user-greeting">Please sign in</div>') | |
| with gr.Row(): | |
| upgrade_nav_btn = gr.Button("⭐ Upgrade", size="sm", visible=False) | |
| logout_nav_btn = gr.Button("🚪 Logout", size="sm", visible=False) | |
| # AUTH PAGE | |
| with gr.Group(visible=True) as auth_page: | |
| with gr.Column(elem_classes="auth-container"): | |
| gr.HTML('<div class="auth-title">Welcome to SAM-X-1</div>') | |
| gr.HTML('<div class="auth-subtitle">Sign in or create account • Auto-detects new users</div>') | |
| auth_username = gr.Textbox( | |
| label="Username", | |
| placeholder="Enter your username", | |
| elem_id="auth-username" | |
| ) | |
| auth_password = gr.Textbox( | |
| label="Password", | |
| type="password", | |
| placeholder="Enter your password", | |
| elem_id="auth-password" | |
| ) | |
| auth_email = gr.Textbox( | |
| label="Email (optional, for new accounts)", | |
| placeholder="your@email.com" | |
| ) | |
| auth_btn = gr.Button("Continue →", variant="primary", size="lg") | |
| auth_msg = gr.Markdown("") | |
| gr.Markdown(""" | |
| <div style="text-align: center; margin-top: 20px; font-size: 12px; color: #9ca3af;"> | |
| <p>🔐 Secure authentication • 🆓 Free tier available</p> | |
| <p>Press <span class="kbd">Enter</span> to continue</p> | |
| </div> | |
| """) | |
| # CHAT PAGE | |
| with gr.Group(visible=False) as chat_page: | |
| with gr.Row(elem_classes="chat-layout"): | |
| # Main Chat Area | |
| with gr.Column(elem_classes="chat-main"): | |
| chat_html = gr.HTML(value='') | |
| speed_display = gr.HTML('<div class="speed-indicator">⚡ Ready to chat</div>') | |
| with gr.Column(elem_classes="input-container"): | |
| with gr.Row(elem_classes="input-row"): | |
| msg_input = gr.Textbox( | |
| placeholder="Ask me anything... (Shift+Enter for new line)", | |
| show_label=False, | |
| scale=10, | |
| lines=1, | |
| max_lines=5, | |
| elem_id="chat-input" | |
| ) | |
| send_btn = gr.Button("▶", elem_classes=["circular-btn", "send-btn"]) | |
| stop_btn = gr.Button("⏹", elem_classes=["circular-btn", "stop-btn"], visible=False) | |
| token_counter = gr.HTML('<div class="token-counter">0 / 256 tokens</div>') | |
| with gr.Row(): | |
| clear_btn = gr.Button("🗑️ Clear Chat", size="sm") | |
| new_chat_btn = gr.Button("➕ New Chat", size="sm", variant="primary") | |
| export_btn = gr.Button("📥 Export", size="sm") | |
| # Sidebar | |
| with gr.Column(elem_classes="chat-sidebar"): | |
| limits_display = gr.HTML("") | |
| with gr.Accordion("⚙️ Settings", open=False, elem_classes="settings-panel"): | |
| model_selector = gr.Dropdown( | |
| choices=["🤖 Auto (Recommended)"], | |
| value="🤖 Auto (Recommended)", | |
| label="Model Selection", | |
| info="AI picks the best model" | |
| ) | |
| max_tokens_slider = gr.Slider( | |
| minimum=64, maximum=512, value=256, step=64, | |
| label="Max Tokens", | |
| info="Response length limit" | |
| ) | |
| temperature_slider = gr.Slider( | |
| minimum=0.0, maximum=2.0, value=0.7, step=0.1, | |
| label="Temperature", | |
| info="Creativity level" | |
| ) | |
| show_thinking_checkbox = gr.Checkbox( | |
| label="💭 Show Thinking Process", | |
| value=True, | |
| info="See AI's reasoning" | |
| ) | |
| with gr.Accordion("ℹ️ Tips & Shortcuts", open=False): | |
| gr.Markdown(""" | |
| ### Keyboard Shortcuts | |
| - <span class="kbd">Enter</span> - Send message | |
| - <span class="kbd">Shift+Enter</span> - New line | |
| - <span class="kbd">Esc</span> - Stop generation | |
| - <span class="kbd">Ctrl+K</span> - Search (soon) | |
| ### Tips | |
| - Be specific in your questions | |
| - Use markdown for formatting | |
| - Auto mode picks the best model | |
| - Check limits panel regularly | |
| """) | |
| # UPGRADE PAGE | |
| with gr.Group(visible=False) as upgrade_page: | |
| gr.HTML(''' | |
| <div style="text-align: center; margin-bottom: 32px;"> | |
| <div style="font-size: 32px; font-weight: 700; margin-bottom: 8px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent;"> | |
| Choose Your Plan | |
| </div> | |
| <div style="font-size: 16px; color: #6b7280;"> | |
| Unlock more power and flexibility | |
| </div> | |
| </div> | |
| ''') | |
| gr.HTML(''' | |
| <div class="plans-grid"> | |
| <div class="plan-card"> | |
| <div class="plan-name">Free 🆓</div> | |
| <div class="plan-price">Perfect for getting started</div> | |
| <ul class="plan-features"> | |
| <li>Nano: Unlimited messages</li> | |
| <li>Mini: Unlimited messages</li> | |
| <li>Fast: 10 messages/3h</li> | |
| <li>Large: 8 messages/3h</li> | |
| <li>Auto model selection</li> | |
| <li>256 max tokens</li> | |
| <li>Community support</li> | |
| </ul> | |
| </div> | |
| <div class="plan-card featured"> | |
| <div class="plan-name">Plus ⭐</div> | |
| <div class="plan-price">Great for power users</div> | |
| <ul class="plan-features"> | |
| <li>Everything in Free</li> | |
| <li>Fast: Unlimited messages</li> | |
| <li>Large: 20 messages/3h</li> | |
| <li>Manual model selection</li> | |
| <li>384 max tokens</li> | |
| <li>Priority support</li> | |
| <li>Advanced settings</li> | |
| </ul> | |
| </div> | |
| <div class="plan-card"> | |
| <div class="plan-name">Explore 🔍</div> | |
| <div class="plan-price">For curious learners</div> | |
| <ul class="plan-features"> | |
| <li>Everything in Free</li> | |
| <li>Nano & Mini: Unlimited</li> | |
| <li>Fast: 14 messages/3h</li> | |
| <li>Large: 10 messages/3h</li> | |
| <li>Manual model selection</li> | |
| <li>512 max tokens</li> | |
| <li>Extended support</li> | |
| </ul> | |
| </div> | |
| <div class="plan-card featured"> | |
| <div class="plan-name">Pro 💎</div> | |
| <div class="plan-price">For professionals</div> | |
| <ul class="plan-features"> | |
| <li>Everything in Plus</li> | |
| <li>All models unlimited</li> | |
| <li>512 max tokens</li> | |
| <li>Fastest reset (3h)</li> | |
| <li>24/7 premium support</li> | |
| <li>Early feature access</li> | |
| <li>API access (soon)</li> | |
| </ul> | |
| </div> | |
| <div class="plan-card featured"> | |
| <div class="plan-name">Sam-X-1-Mini-release-speacil-plan 🎉</div> | |
| <div class="plan-price">For everyone! Apply for this plan for a 100% success rate! Help us celebrate the release of Sam-Mini-X-1!!!</div> | |
| <ul class="plan-features"> | |
| <li>Everything unlimited!🎉</li> | |
| <li>All models unlimited🎉</li> | |
| <li>Infinite (almost) max tokens🎉</li> | |
| <li>Fastest reset (3h)🎉</li> | |
| <li>24/7 premium support🎉</li> | |
| <li>Early feature access🎉</li> | |
| <li>API access (soon)🎉</li> | |
| <li>Only Limited(2 weeks. ends on of Nov 1 approx. Subject to change)😅</li> | |
| </ul> | |
| </div> | |
| <div class="plan-card"> | |
| <div class="plan-name">Research 🔬</div> | |
| <div class="plan-price">For researchers & educators</div> | |
| <ul class="plan-features"> | |
| <li>Everything in Pro</li> | |
| <li>Extended limits (1000+ msgs)</li> | |
| <li>1024 max tokens</li> | |
| <li>Batch processing</li> | |
| <li>Custom fine-tuning</li> | |
| <li>Dedicated support</li> | |
| <li>Academic discount</li> | |
| </ul> | |
| </div> | |
| </div> | |
| ''') | |
| gr.Markdown("### 📝 Request an Upgrade") | |
| gr.Markdown("Fill out the form below and an admin will review your request within 24 hours.") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| upgrade_plan_choice = gr.Radio( | |
| choices=["plus", "pro", "explore", "Research", "Sam-X-1-Mini-release-speacil-plan"], | |
| label="Select Plan", | |
| value="plus" | |
| ) | |
| upgrade_reason = gr.Textbox( | |
| label="Why do you need this upgrade?", | |
| placeholder="Tell us about your use case, what you're building, or why you need more access...", | |
| lines=4 | |
| ) | |
| with gr.Row(): | |
| submit_upgrade_btn = gr.Button("Submit Request 📨", variant="primary", size="lg", scale=2) | |
| back_to_chat_btn = gr.Button("← Back to Chat", size="lg", scale=1) | |
| upgrade_msg = gr.Markdown("") | |
| # ADMIN PAGE | |
| with gr.Group(visible=False) as admin_page: | |
| gr.HTML(''' | |
| <div style="text-align: center; margin-bottom: 24px;"> | |
| <div style="font-size: 28px; font-weight: 700; color: #1f2937;"> | |
| 👨‍💼 Admin Dashboard | |
| </div> | |
| </div> | |
| ''') | |
| with gr.Tabs(): | |
| with gr.Tab("👥 User Management"): | |
| with gr.Row(): | |
| refresh_users_btn = gr.Button("🔄 Refresh Users", size="sm") | |
| users_table = gr.Dataframe( | |
| headers=["ID", "Username", "Email", "Plan", "Created", "Admin"], | |
| wrap=True | |
| ) | |
| gr.Markdown("### ✏️ Update User Plan") | |
| with gr.Row(): | |
| admin_username = gr.Textbox(label="Username", scale=2, placeholder="username") | |
| admin_new_plan = gr.Dropdown( | |
| choices=["free", "plus", "pro", "explore", "Research", "VIP", "Sam-X-1-Mini-release-speacil-plan"], | |
| label="New Plan", | |
| value="free", | |
| scale=1 | |
| ) | |
| update_plan_btn = gr.Button("Update Plan", variant="primary", scale=1) | |
| admin_msg = gr.Markdown("") | |
| with gr.Tab("📋 Upgrade Requests"): | |
| with gr.Row(): | |
| refresh_requests_btn = gr.Button("🔄 Refresh Requests", size="sm") | |
| requests_table = gr.Dataframe( | |
| headers=["ID", "Username", "Plan", "Reason", "Date"], | |
| wrap=True | |
| ) | |
| gr.Markdown("### 🔍 Review Request") | |
| request_id_input = gr.Number( | |
| label="Request ID (from table above)", | |
| precision=0, | |
| minimum=1, | |
| info="Enter the ID number from the first column" | |
| ) | |
| with gr.Row(): | |
| approve_req_btn = gr.Button("✅ Approve Request", variant="primary", size="lg") | |
| deny_req_btn = gr.Button("❌ Deny Request", variant="stop", size="lg") | |
| request_msg = gr.Markdown("") | |
| with gr.Tab("📊 Analytics (Coming Soon)"): | |
| gr.Markdown(""" | |
| ### 📈 Usage Statistics | |
| - Total users: Coming soon | |
| - Active users (24h): Coming soon | |
| - Total messages: Coming soon | |
| - Most used model: Coming soon | |
| - Average tokens/message: Coming soon | |
| """) | |
| # ==================== EVENT HANDLERS ==================== | |
| def handle_auth(username, password, email): | |
| """Unified auth - auto signup if new, FIX: Handle both 'id' and 'user_id'""" | |
| if len(username) < 3: | |
| return ( | |
| None, None, "❌ Username must be at least 3 characters", | |
| gr.update(), gr.update(), gr.update(), gr.update(), | |
| gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "" | |
| ) | |
| if len(password) < 6: | |
| return ( | |
| None, None, "❌ Password must be at least 6 characters", | |
| gr.update(), gr.update(), gr.update(), gr.update(), | |
| gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "" | |
| ) | |
| # Try login first | |
| success, data = authenticate_user(username, password) | |
| if not success: | |
| # Try signup | |
| success, message = create_user(username, password, email) | |
| if success: | |
| # Auto-login after signup | |
| success, data = authenticate_user(username, password) | |
| if not success: | |
| return ( | |
| None, None, "❌ Account created but login failed", | |
| gr.update(), gr.update(), gr.update(), gr.update(), | |
| gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "" | |
| ) | |
| else: | |
| return ( | |
| None, None, f"❌ {message}", | |
| gr.update(), gr.update(), gr.update(), gr.update(), | |
| gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "" | |
| ) | |
| # FIX: Normalize data to always have 'user_id' | |
| if 'id' in data and 'user_id' not in data: | |
| data['user_id'] = data['id'] | |
| # Generate session | |
| code = create_session(data) | |
| # Get user info | |
| info = get_user_limits_info(data['user_id']) | |
| if not info: | |
| return ( | |
| None, None, "❌ Could not load user info", | |
| gr.update(), gr.update(), gr.update(), gr.update(), | |
| gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "" | |
| ) | |
| plan_class = f"plan-{info['plan'].lower()}" | |
| greeting_html = f'<div class="user-greeting">{get_greeting(username)} <span class="plan-badge {plan_class}">{info["plan"]}</span></div>' | |
| # Set model choices | |
| if info['can_choose_model']: | |
| model_choices = ["🤖 Auto (Recommended)"] + list(available_models.keys()) | |
| else: | |
| model_choices = ["🤖 Auto (Recommended)"] | |
| limits_html = render_limits_panel(data) | |
| empty_chat = render_history([], True) | |
| return ( | |
| code, | |
| data, | |
| f"✅ Welcome, **{username}**! Your session is active.", | |
| gr.update(visible=False), # auth_page | |
| gr.update(visible=True), # chat_page | |
| gr.update(visible=data.get('is_admin', False)), # admin_page | |
| greeting_html, | |
| gr.update(visible=True), # upgrade_nav_btn | |
| gr.update(visible=True), # logout_nav_btn | |
| gr.update(choices=model_choices, value="🤖 Auto (Recommended)"), | |
| gr.update(maximum=info['max_tokens'], value=min(256, info['max_tokens'])), | |
| limits_html, | |
| empty_chat | |
| ) | |
| def show_upgrade_page(): | |
| return gr.update(visible=False), gr.update(visible=True) | |
| def back_to_chat(): | |
| return gr.update(visible=True), gr.update(visible=False) | |
| def handle_logout(code): | |
| if code: | |
| invalidate_session(code) | |
| return ( | |
| "", | |
| None, | |
| [], | |
| gr.update(visible=True), # auth_page | |
| gr.update(visible=False), # chat_page | |
| gr.update(visible=False), # admin_page | |
| gr.update(visible=False), # upgrade_page | |
| '<div class="user-greeting">Please sign in</div>', | |
| gr.update(visible=False), # upgrade_nav_btn | |
| gr.update(visible=False), # logout_nav_btn | |
| "", | |
| "" | |
| ) | |
| def send_message_handler(message, history, show_thinking, temperature, model_choice, max_tokens, code): | |
| global stop_generation | |
| stop_generation.clear() | |
| # send_message_handler is a generator, so early exits must yield their | |
| # updates before returning (a `return value` inside a generator is dropped). | |
| if not code: | |
| error_html = '<div class="speed-indicator error">❌ Session expired - please sign in again</div>' | |
| yield "", history, "", error_html, gr.update(), gr.update(), "" | |
| return | |
| data = validate_session(code) | |
| if not data: | |
| error_html = '<div class="speed-indicator error">❌ Session expired - please sign in again</div>' | |
| yield "", history, "", error_html, gr.update(), gr.update(), "" | |
| return | |
| if not message.strip(): | |
| yield "", history, "", '<div class="speed-indicator">⚡ Ready to chat</div>', gr.update(), gr.update(), render_limits_panel(data) | |
| return | |
| info = get_user_limits_info(data['user_id']) | |
| # Model selection | |
| if model_choice == "🤖 Auto (Recommended)" or not info['can_choose_model']: | |
| user_available = get_available_models_for_user(data['user_id']) | |
| if not user_available: | |
| error_html = '<div class="speed-indicator error">❌ No models available (limits reached)</div>' | |
| yield "", history, "", error_html, gr.update(), gr.update(), render_limits_panel(data) | |
| return | |
| backend = select_model_auto(message, available_models, user_available) | |
| if not backend: | |
| error_html = '<div class="speed-indicator error">❌ Could not select model</div>' | |
| yield "", history, "", error_html, gr.update(), gr.update(), render_limits_panel(data) | |
| return | |
| model_name = backend.get_name() | |
| else: | |
| model_name = model_choice | |
| backend = available_models[model_name] | |
| # Single limit check covers both auto and manual selection (the original | |
| # ran the same check twice on the manual path). | |
| can_use, msg = can_use_model(data['user_id'], model_name) | |
| if not can_use: | |
| error_html = f'<div class="speed-indicator error">❌ {msg}</div>' | |
| yield "", history, "", error_html, gr.update(), gr.update(), render_limits_panel(data) | |
| return | |
| # Increment usage | |
| increment_model_usage(data['user_id'], model_name) | |
| # Add user message | |
| history.append({"role": "user", "content": message}) | |
| yield "", history, render_history(history, show_thinking), f'<div class="speed-indicator generating">⚡ Using {model_name}...</div>', gr.update(interactive=False), gr.update(visible=True), render_limits_panel(data) | |
| # Start generation | |
| prompt = f"User: {message}\nSam: <think>" | |
| history.append({"role": "assistant", "content": "<think>"}) | |
| actual_max_tokens = min(max_tokens, info['max_tokens']) | |
| last_speed = 0 | |
| was_stopped = False | |
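| # generate_response_stream yields 5-tuples: (text_delta, in_thinking_flag, | |
| # instantaneous tok/s, rolling-average tok/s, stop_flag) — unpacked below. | |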
| for chunk_data in generate_response_stream(prompt, temperature, backend, actual_max_tokens): | |
| if len(chunk_data) == 5: | |
| new_chunk, in_thinking, tokens_per_sec, avg_speed, stopped = chunk_data | |
| if stopped: | |
| was_stopped = True | |
| break | |
| if new_chunk: | |
| history[-1]["content"] += new_chunk | |
| last_speed = avg_speed | |
| yield "", history, render_history(history, show_thinking), f'<div class="speed-indicator generating">⚡ {tokens_per_sec:.1f} tok/s</div>', gr.update(interactive=False), gr.update(visible=True), render_limits_panel(data) | |
| if was_stopped: | |
| final_html = f'<div class="speed-indicator error">🛑 Stopped - {last_speed:.1f} tok/s</div>' | |
| else: | |
| final_html = f'<div class="speed-indicator">✅ Done - {last_speed:.1f} tok/s</div>' | |
| yield "", history, render_history(history, show_thinking), final_html, gr.update(interactive=True), gr.update(visible=False), render_limits_panel(data) | |
| def stop_generation_handler(): | |
| global stop_generation | |
| stop_generation.set() | |
| return '<div class="speed-indicator error">🛑 Stopping...</div>', gr.update(interactive=False), gr.update(visible=False) | |
| def clear_chat(history): | |
| empty = render_history([], True) | |
| return [], empty, '<div class="speed-indicator">⚡ Ready to chat</div>', gr.update(interactive=True), gr.update(visible=False) | |
| def export_chat(history): | |
| """Write the transcript to a temp .txt file and return its path.""" | |
| import tempfile | |
| text = "" | |
| for msg in history: | |
| role = "You" if msg["role"] == "user" else "SAM-X-1" | |
| text += f"{role}: {msg['content']}\n\n" | |
| path = os.path.join(tempfile.gettempdir(), f"sam_chat_{int(time.time())}.txt") | |
| with open(path, "w", encoding="utf-8") as f: | |
| f.write(text) | |
| return path | |
| def submit_upgrade_request(code, plan, reason): | |
| if not code: | |
| return "❌ Session expired" | |
| data = validate_session(code) | |
| if not data: | |
| return "❌ Session expired" | |
| if not reason.strip(): | |
| return "❌ Please provide a reason for your upgrade request" | |
| success, msg = request_upgrade(data['user_id'], plan, reason) | |
| if success: | |
| return f"✅ {msg}\n\nAn admin will review your request within 24 hours. You'll be notified via email if provided." | |
| return f"❌ {msg}" | |
| def load_all_users(): | |
| users = get_all_users() | |
| formatted = [] | |
| for user in users: | |
| formatted.append([ | |
| user[0], | |
| user[1], | |
| user[2] or "N/A", | |
| user[3], | |
| user[4][:10] if user[4] else "N/A", | |
| "Yes" if user[5] else "No" | |
| ]) | |
| return formatted | |
| def load_pending_requests(): | |
| requests = get_pending_requests() | |
| formatted = [] | |
| for req in requests: | |
| formatted.append([ | |
| req[0], | |
| req[1], | |
| req[2], | |
| req[3][:100] + "..." if len(req[3]) > 100 else req[3], | |
| req[4][:10] if req[4] else "N/A" | |
| ]) | |
| return formatted | |
| def admin_update_plan_handler(username, new_plan): | |
| if not username or not new_plan: | |
| return "❌ Please fill all fields" | |
| success, msg = update_user_plan(username, new_plan) | |
| if success: | |
| return f"✅ {msg}\n\nThe user's limits have been reset and they now have access to {new_plan} features." | |
| return f"❌ {msg}" | |
| def admin_approve_request_handler(request_id): | |
| if not request_id: | |
| return "❌ Please enter a request ID" | |
| success, msg = approve_request(int(request_id)) | |
| if success: | |
| return f"✅ {msg}\n\nThe user has been upgraded and can now access their new plan features." | |
| return f"❌ {msg}" | |
| def admin_deny_request_handler(request_id): | |
| if not request_id: | |
| return "❌ Please enter a request ID" | |
| success, msg = deny_request(int(request_id)) | |
| if success: | |
| return f"✅ {msg}\n\nThe request has been marked as denied." | |
| return f"❌ {msg}" | |
| # ==================== WIRE UP EVENTS ==================== | |
| # Auth | |
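| # Order matters: these 13 output components must line up positionally with | |
| # the 13 values returned by handle_auth. | |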
| auth_outputs = [ | |
| session_code, user_data, auth_msg, auth_page, chat_page, admin_page, | |
| user_greeting, upgrade_nav_btn, logout_nav_btn, | |
| model_selector, max_tokens_slider, limits_display, chat_html | |
| ] | |
| auth_btn.click(handle_auth, [auth_username, auth_password, auth_email], auth_outputs) | |
| auth_password.submit(handle_auth, [auth_username, auth_password, auth_email], auth_outputs) | |
| # Navigation | |
| upgrade_nav_btn.click(show_upgrade_page, outputs=[chat_page, upgrade_page]) | |
| back_to_chat_btn.click(back_to_chat, outputs=[chat_page, upgrade_page]) | |
| logout_outputs = [ | |
| session_code, user_data, chat_history, auth_page, chat_page, admin_page, upgrade_page, | |
| user_greeting, upgrade_nav_btn, logout_nav_btn, chat_html, limits_display | |
| ] | |
| logout_nav_btn.click(handle_logout, [session_code], logout_outputs) | |
| # Chat | |
| send_outputs = [msg_input, chat_history, chat_html, speed_display, send_btn, stop_btn, limits_display] | |
| send_btn.click( | |
| send_message_handler, | |
| [msg_input, chat_history, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, session_code], | |
| send_outputs | |
| ) | |
| msg_input.submit( | |
| send_message_handler, | |
| [msg_input, chat_history, show_thinking_checkbox, temperature_slider, model_selector, max_tokens_slider, session_code], | |
| send_outputs | |
| ) | |
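| # Token counter — a small sketch, not in the original wiring: token_counter is | |
| # rendered once ("0 / 256 tokens") and never updated. The ~4-chars-per-token | |
| # heuristic below is an assumption, not the model's real tokenizer count. | |
| def update_token_counter(text, max_tokens): | |
| approx = len(text or "") // 4 | |
| return f'<div class="token-counter">{approx} / {int(max_tokens)} tokens</div>' | |
| msg_input.change(update_token_counter, [msg_input, max_tokens_slider], [token_counter]) | |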
| stop_btn.click(stop_generation_handler, outputs=[speed_display, send_btn, stop_btn]) | |
| clear_btn.click(clear_chat, [chat_history], [chat_history, chat_html, speed_display, send_btn, stop_btn]) | |
| new_chat_btn.click(clear_chat, [chat_history], [chat_history, chat_html, speed_display, send_btn, stop_btn]) | |
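| # Export — missing in the original wiring, so the button did nothing. gr.File | |
| # is a new (assumed) output component; export_chat (above) returns a file | |
| # path, which gr.File renders as a download link. | |
| export_file = gr.File(label="Exported chat", visible=False) | |
| export_btn.click( | |
| lambda h: gr.update(value=export_chat(h), visible=True), | |
| [chat_history], | |
| [export_file] | |
| ) | |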
| # Upgrade | |
| submit_upgrade_btn.click( | |
| submit_upgrade_request, | |
| [session_code, upgrade_plan_choice, upgrade_reason], | |
| [upgrade_msg] | |
| ) | |
| # Admin | |
| refresh_users_btn.click(load_all_users, outputs=[users_table]) | |
| refresh_requests_btn.click(load_pending_requests, outputs=[requests_table]) | |
| update_plan_btn.click(admin_update_plan_handler, [admin_username, admin_new_plan], [admin_msg]) | |
| approve_req_btn.click(admin_approve_request_handler, [request_id_input], [request_msg]) | |
| deny_req_btn.click(admin_deny_request_handler, [request_id_input], [request_msg]) | |
| demo.launch( | |
| debug=True, | |
| share=False, | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| favicon_path=None, | |
| show_error=True | |
| ) | |