# sam-x-api / app.py
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
import keras
import numpy as np
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download
import json
from abc import ABC, abstractmethod
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, AsyncGenerator
import asyncio
import time
import gradio as gr
# ==============================================================================
# Model Architecture
# ==============================================================================
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
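    """Rotary position embeddings (RoPE): precomputes cos/sin tables up to max_len
    in build() and rotates query/key heads by position in call()."""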
def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.max_len = max_len
self.theta = theta
self.built_cache = False
def build(self, input_shape):
if not self.built_cache:
inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
t = tf.range(self.max_len, dtype=tf.float32)
freqs = tf.einsum("i,j->ij", t, inv_freq)
emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.cos(emb)  # (max_len, dim) cache; cast per-call to match activations
            self.sin_cached = tf.sin(emb)
self.built_cache = True
super().build(input_shape)
def rotate_half(self, x):
x1, x2 = tf.split(x, 2, axis=-1)
return tf.concat([-x2, x1], axis=-1)
def call(self, q, k):
seq_len = tf.shape(q)[2]
dtype = q.dtype
cos = tf.cast(self.cos_cached[:seq_len, :], dtype)[None, None, :, :]
sin = tf.cast(self.sin_cached[:seq_len, :], dtype)[None, None, :, :]
q_rotated = (q * cos) + (self.rotate_half(q) * sin)
k_rotated = (k * cos) + (self.rotate_half(k) * sin)
return q_rotated, k_rotated
def get_config(self):
config = super().get_config()
config.update({"dim": self.dim, "max_len": self.max_len, "theta": self.theta})
return config
@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
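    """Root-mean-square layer norm: x * rsqrt(mean(x^2) + eps) with a learned per-channel scale."""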
def __init__(self, epsilon=1e-5, **kwargs):
super().__init__(**kwargs)
self.epsilon = epsilon
def build(self, input_shape):
self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
def call(self, x):
variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
return x * tf.math.rsqrt(variance + self.epsilon) * self.scale
def get_config(self):
config = super().get_config()
config.update({"epsilon": self.epsilon})
return config
@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
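    """Pre-norm decoder block: causal self-attention with RoPE followed by a
    SwiGLU feed-forward, each wrapped in a residual connection with dropout."""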
def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
super().__init__(**kwargs)
self.d_model = d_model
self.n_heads = n_heads
self.ff_dim = ff_dim
self.dropout_rate = dropout
self.max_len = max_len
self.rope_theta = rope_theta
self.head_dim = d_model // n_heads
self.layer_idx = layer_idx
self.pre_attn_norm = RMSNorm()
self.pre_ffn_norm = RMSNorm()
self.q_proj = keras.layers.Dense(d_model, use_bias=False, name="q_proj")
self.k_proj = keras.layers.Dense(d_model, use_bias=False, name="k_proj")
self.v_proj = keras.layers.Dense(d_model, use_bias=False, name="v_proj")
self.out_proj = keras.layers.Dense(d_model, use_bias=False, name="o_proj")
self.rope = RotaryEmbedding(self.head_dim, max_len=max_len, theta=rope_theta)
self.gate_proj = keras.layers.Dense(ff_dim, use_bias=False, name="gate_proj")
self.up_proj = keras.layers.Dense(ff_dim, use_bias=False, name="up_proj")
self.down_proj = keras.layers.Dense(d_model, use_bias=False, name="down_proj")
self.dropout = keras.layers.Dropout(dropout)
def call(self, x, training=None):
B, T, D = tf.shape(x)[0], tf.shape(x)[1], self.d_model
dtype = x.dtype
res = x
y = self.pre_attn_norm(x)
q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
q, k = self.rope(q, k)
scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
mask = tf.where(
tf.linalg.band_part(tf.ones([T, T], dtype=dtype), -1, 0) == 0,
tf.constant(-1e9, dtype=dtype),
tf.constant(0.0, dtype=dtype)
)
scores += mask
attn = tf.matmul(tf.nn.softmax(scores, axis=-1), v)
attn = tf.reshape(tf.transpose(attn, [0, 2, 1, 3]), [B, T, D])
x = res + self.dropout(self.out_proj(attn), training=training)
res = x
y = self.pre_ffn_norm(x)
ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
return res + self.dropout(ffn, training=training)
def get_config(self):
config = super().get_config()
config.update({
"d_model": self.d_model,
"n_heads": self.n_heads,
"ff_dim": self.ff_dim,
"dropout": self.dropout_rate,
"max_len": self.max_len,
"rope_theta": self.rope_theta,
"layer_idx": self.layer_idx
})
return config
@keras.saving.register_keras_serializable()
class SAM1Model(keras.Model):
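    """Decoder-only language model: token embedding, n_layers TransformerBlocks,
    a final RMSNorm, and a linear LM head over the vocabulary."""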
def __init__(self, **kwargs):
super().__init__()
if 'config' in kwargs and isinstance(kwargs['config'], dict):
self.cfg = kwargs['config']
elif 'vocab_size' in kwargs:
self.cfg = kwargs
else:
self.cfg = kwargs.get('cfg', kwargs)
self.embed = keras.layers.Embedding(self.cfg['vocab_size'], self.cfg['d_model'], name="embed_tokens")
        # Feed-forward width is derived from d_model * ff_mult
        ff_dim = int(self.cfg['d_model'] * self.cfg['ff_mult'])
block_args = {
'd_model': self.cfg['d_model'],
'n_heads': self.cfg['n_heads'],
            'ff_dim': ff_dim,
'dropout': self.cfg['dropout'],
'max_len': self.cfg['max_len'],
'rope_theta': self.cfg['rope_theta']
}
self.blocks = []
for i in range(self.cfg['n_layers']):
block = TransformerBlock(name=f"block_{i}", layer_idx=i, **block_args)
self.blocks.append(block)
self.norm = RMSNorm(name="final_norm")
self.lm_head = keras.layers.Dense(self.cfg['vocab_size'], use_bias=False, name="lm_head")
def call(self, input_ids, training=None):
x = self.embed(input_ids)
for block in self.blocks:
x = block(x, training=training)
return self.lm_head(self.norm(x))
def get_config(self):
base_config = super().get_config()
base_config['config'] = self.cfg
return base_config
# ==============================================================================
# Helper Functions
# ==============================================================================
def count_parameters(model):
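    """Return (total, non-zero) element counts over all weights; their gap measures sparsity."""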
total_params = 0
non_zero_params = 0
for weight in model.weights:
w = weight.numpy()
total_params += w.size
non_zero_params += np.count_nonzero(w)
return total_params, non_zero_params
def format_param_count(count):
if count >= 1e9:
return f"{count/1e9:.2f}B"
elif count >= 1e6:
return f"{count/1e6:.2f}M"
elif count >= 1e3:
return f"{count/1e3:.2f}K"
else:
return str(count)
# ==============================================================================
# Backend Interface
# ==============================================================================
class ModelBackend(ABC):
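    """Abstract inference backend: maps a token-id sequence to next-token logits."""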
@abstractmethod
def predict(self, input_ids): pass
@abstractmethod
def get_name(self): pass
@abstractmethod
def get_info(self): pass
class KerasBackend(ModelBackend):
def __init__(self, model, name, display_name):
self.model = model
self.name = name
self.display_name = display_name
total, non_zero = count_parameters(model)
self.total_params = total
self.non_zero_params = non_zero
self.sparsity = (1 - non_zero / total) * 100 if total > 0 else 0
self.n_heads = model.cfg.get('n_heads', 0)
self.ff_dim = int(model.cfg.get('d_model', 0) * model.cfg.get('ff_mult', 0))
def predict(self, input_ids):
inputs = np.array([input_ids], dtype=np.int32)
logits = self.model(inputs, training=False)
return logits[0, -1, :].numpy()
def get_name(self):
return self.display_name
def get_info(self):
info = f"{self.display_name}\n"
info += f" Total params: {format_param_count(self.total_params)}\n"
info += f" Attention heads: {self.n_heads}\n"
info += f" FFN dimension: {self.ff_dim}\n"
if self.sparsity > 1:
info += f" Sparsity: {self.sparsity:.1f}%\n"
return info
# ==============================================================================
# Load Models & Tokenizer
# ==============================================================================
CONFIG_TOKENIZER_REPO_ID = "Smilyai-labs/Sam-1-large-it-0002"
print("="*60)
print("πŸš€ SAM-X-1 API Server Loading...".center(60))
print("="*60)
# Download config/tokenizer
print(f"πŸ“¦ Fetching config & tokenizer from {CONFIG_TOKENIZER_REPO_ID}")
config_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="config.json")
tokenizer_path = hf_hub_download(repo_id=CONFIG_TOKENIZER_REPO_ID, filename="tokenizer.json")
with open(config_path, 'r') as f:
base_config = json.load(f)
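# Translate HF-style config keys into SAM1Model constructor arguments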
base_model_config = {
'vocab_size': base_config['vocab_size'],
'd_model': base_config['hidden_size'],
'n_heads': base_config['num_attention_heads'],
'ff_mult': base_config['intermediate_size'] / base_config['hidden_size'],
'dropout': base_config.get('dropout', 0.0),
'max_len': base_config['max_position_embeddings'],
'rope_theta': base_config['rope_theta'],
'n_layers': base_config['num_hidden_layers']
}
print("πŸ”€ Building tokenizer...")
tokenizer = Tokenizer.from_pretrained("gpt2")
eos_token = ""
eos_token_id = tokenizer.token_to_id(eos_token)
if eos_token_id is None:
tokenizer.add_special_tokens([eos_token])
eos_token_id = tokenizer.token_to_id(eos_token)
custom_tokens = ["<think>", "<think/>"]
for token in custom_tokens:
if tokenizer.token_to_id(token) is None:
tokenizer.add_special_tokens([token])
tokenizer.no_padding()
tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
print("βœ… Tokenizer ready")
# Model Registry
MODEL_REGISTRY = [
("SAM-X-1-Large", "Smilyai-labs/Sam-1x-instruct", "ckpt.weights.h5", None),
("SAM-X-1-Fast ⚑ (BETA)", "Smilyai-labs/Sam-X-1-fast", "sam1_fast_finetuned.weights.h5", "sam1_fast_finetuned_config.json"),
("SAM-X-1-Mini πŸš€ (BETA)", "Smilyai-labs/Sam-X-1-Mini", "sam1_mini.weights_finetuned.h5", "sam1_mini_finetuned_config.json"),
("SAM-X-1-Nano ⚑⚑ (BETA)", "Smilyai-labs/Sam-X-1-Nano", "sam1_nano_finetuned.weights.h5", "sam1_nano_finetuned_config.json"),
]
available_models = {}
dummy_input = tf.zeros((1, 1), dtype=tf.int32)
for display_name, repo_id, weights_filename, config_filename in MODEL_REGISTRY:
try:
print(f"\nπŸ“₯ Loading {display_name}...")
weights_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
model_config = base_model_config.copy()
if config_filename:
print(f" Custom config: {config_filename}")
custom_config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
with open(custom_config_path, 'r') as f:
model_config.update(json.load(f))
model = SAM1Model(**model_config)
model(dummy_input)
model.load_weights(weights_path)
model.trainable = False
backend = KerasBackend(model, display_name, display_name)
available_models[display_name] = backend
print(f"βœ… Loaded: {display_name}")
print(f" β†’ Params: {format_param_count(backend.total_params)} | Heads: {backend.n_heads}")
except Exception as e:
print(f"❌ Failed to load {display_name}: {e}")
if not available_models:
raise RuntimeError("No models loaded!")
current_backend = list(available_models.values())[0]
print(f"\nπŸŽ‰ Ready! Default model: {current_backend.get_name()}")
# ==============================================================================
# Streaming Generator
# ==============================================================================
async def generate_stream(prompt: str, backend, temperature: float, max_tokens: int = 512) -> AsyncGenerator[str, None]:
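    """Autoregressively sample up to max_tokens tokens, yielding newly decoded text as it appears."""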
encoded_prompt = tokenizer.encode(prompt)
input_ids = [i for i in encoded_prompt.ids if i != eos_token_id]
generated = input_ids.copy()
max_len = backend.model.cfg['max_len']
buffer = ""
    for _ in range(max_tokens):
await asyncio.sleep(0)
current_input = generated[-max_len:]
next_token_logits = backend.predict(current_input)
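        # Temperature-scaled top-50 sampling; temperature == 0 falls back to greedy argmax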
if temperature > 0:
next_token_logits /= temperature
top_k_indices = np.argpartition(next_token_logits, -50)[-50:]
top_k_logits = next_token_logits[top_k_indices]
top_k_probs = np.exp(top_k_logits - np.max(top_k_logits))
top_k_probs /= top_k_probs.sum()
next_token = np.random.choice(top_k_indices, p=top_k_probs)
else:
next_token = int(np.argmax(next_token_logits))
if next_token == eos_token_id:
break
generated.append(int(next_token))
new_text = tokenizer.decode(generated[len(input_ids):])
if len(new_text) > len(buffer):
new_chunk = new_text[len(buffer):]
buffer = new_text
yield new_chunk
# ==============================================================================
# FastAPI Endpoints (OpenAI-style)
# ==============================================================================
class Message(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str = list(available_models.keys())[0]
messages: List[Message]
temperature: float = 0.7
stream: bool = False
max_tokens: int = 512
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
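    """OpenAI-style chat endpoint: flattens messages into a "User:/Sam:" prompt and
    replies either as an SSE stream or as a single JSON completion."""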
if request.model not in available_models:
raise HTTPException(404, f"Model '{request.model}' not found.")
backend = available_models[request.model]
prompt_parts = []
for msg in request.messages:
prefix = "User" if msg.role.lower() == "user" else "Sam"
prompt_parts.append(f"{prefix}: {msg.content}")
prompt_parts.append("Sam: <think>")
prompt = "\n".join(prompt_parts)
async def event_stream():
        async for token in generate_stream(prompt, backend, request.temperature, request.max_tokens):
chunk = {
"id": "chatcmpl-123",
"object": "chat.completion.chunk",
"created": 1677858242,
"model": request.model,
"choices": [{
"index": 0,
"delta": {"content": token},
"finish_reason": None
}]
}
yield f" {json.dumps(chunk)}\n\n"
yield " [DONE]\n\n"
if request.stream:
return StreamingResponse(event_stream(), media_type="text/event-stream")
else:
full = ""
async for token in event_stream():
if "[DONE]" not in token:
                data = json.loads(token.removeprefix("data: ").strip())
full += data["choices"][0]["delta"]["content"]
return {"choices": [{"message": {"content": full}}]}
@app.get("/v1/models")
async def list_models():
return {
"data": [
{"id": name, "object": "model", "owned_by": "SmilyAI"}
for name in available_models.keys()
]
}
# ==============================================================================
# Gradio App (API Info Page)
# ==============================================================================
def get_api_info():
model_info = "\n".join([f"- {name}" for name in available_models.keys()])
return f"""
# 🤖 SAM-X-1 AI API Server
This is a production-grade API server for the SAM-X-1 family of models.
## 🚀 Available Models:
{model_info}
## 🔌 API Endpoints:
- `POST /v1/chat/completions` - Chat completions (OpenAI-style)
- `GET /v1/models` - List available models
## 🌊 Streaming:
Set `"stream": true` in your request to receive real-time token-by-token responses.
## 🧪 Example Request:
```json
{{
"model": "SAM-X-1-Large",
"messages": [
{{"role": "user", "content": "Hello!"}}
],
"stream": true,
"temperature": 0.7
}}
```
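## 🐍 Python Client Sketch:
A minimal streaming client using `requests` (the base URL below is a placeholder; point it at this server's address):
```python
import requests, json

resp = requests.post(
    "http://localhost:7860/v1/chat/completions",  # placeholder URL; adjust for your deployment
    json={{"model": "SAM-X-1-Large",
          "messages": [{{"role": "user", "content": "Hello!"}}],
          "stream": True}},
    stream=True,
)
for line in resp.iter_lines():
    # SSE lines look like `data: {{...}}`; skip blank keep-alives and the [DONE] sentinel
    if line.startswith(b"data: ") and not line.endswith(b"[DONE]"):
        delta = json.loads(line[len(b"data: "):])["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)
```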
"""
# Create the Gradio app
with gr.Blocks(title="SAM-X-1 API") as demo:
gr.Markdown(get_api_info())
# Mount the Gradio info page on the FastAPI app so the API and UI share one server
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)