"""
Sam-large-2 Distributed Inference - HEAD NODE

Edit the CONFIG below, then deploy.
"""

CONFIG = {
    "node_id": "head-main",
    "layer_start": 0,
    "layer_end": 6,
    "worker_urls": [],
    "secret_token": "sam2-distributed-secret-change-me",
    "model_repo": "Smilyai-labs/Sam-large-2",
    "cache_dir": "./model_cache",
}
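
# Example of a distributed split (hypothetical worker URLs, for illustration only):
# the head runs layers [layer_start, layer_end) and forwards hidden states to each
# worker in the order listed, e.g.
#   "layer_end": 6,
#   "worker_urls": ["http://worker-1:7860", "http://worker-2:7860"],
# For standalone mode leave "worker_urls" empty and set "layer_end" to 0 so every
# layer (plus the final norm and LM head) loads on this node.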
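

# CPU-only execution: these environment variables must be set before TensorFlow is
# imported below, so this block stays at the top of the file.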
import os

NUM_CORES = os.cpu_count() or 4

os.environ['TF_NUM_INTEROP_THREADS'] = str(NUM_CORES)
os.environ['TF_NUM_INTRAOP_THREADS'] = str(NUM_CORES)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import json
import time
import io
import base64
from typing import Dict, List, Optional, Tuple, Any

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
import keras
from huggingface_hub import hf_hub_download

tf.config.threading.set_inter_op_parallelism_threads(NUM_CORES)
tf.config.threading.set_intra_op_parallelism_threads(NUM_CORES)

print(f"✅ CPU optimized: {NUM_CORES} threads")
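

# Custom Keras layers. They must mirror the architecture of the Sam-large-2
# checkpoint so that load_weights() in load_model() maps cleanly onto them.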
@keras.saving.register_keras_serializable()
class RotaryEmbedding(keras.layers.Layer):
    def __init__(self, dim, max_len=2048, theta=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        self.built_cache = False
        self.cos_cached = None
        self.sin_cached = None

    def build(self, input_shape):
        super().build(input_shape)

    def _build_cache(self):
        if not self.built_cache:
            inv_freq = 1.0 / (self.theta ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))
            t = tf.range(self.max_len, dtype=tf.float32)
            freqs = tf.einsum("i,j->ij", t, inv_freq)
            emb = tf.concat([freqs, freqs], axis=-1)
            self.cos_cached = tf.constant(np.cos(emb.numpy()), dtype=tf.float32)
            self.sin_cached = tf.constant(np.sin(emb.numpy()), dtype=tf.float32)
            self.built_cache = True

    def rotate_half(self, x):
        x1, x2 = tf.split(x, 2, axis=-1)
        return tf.concat([-x2, x1], axis=-1)

    def call(self, q, k, offset=0):
        self._build_cache()
        seq_len = tf.shape(q)[2]
        dtype = q.dtype
        cos = tf.cast(self.cos_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        sin = tf.cast(self.sin_cached[offset:offset + seq_len, :], dtype)[None, None, :, :]
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def get_config(self):
        return {**super().get_config(), "dim": self.dim, "max_len": self.max_len, "theta": self.theta}


@keras.saving.register_keras_serializable()
class RMSNorm(keras.layers.Layer):
    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],), initializer="ones")
        super().build(input_shape)

    def call(self, x):
        variance = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)
        return x * tf.math.rsqrt(variance + self.epsilon) * self.scale

    def get_config(self):
        return {**super().get_config(), "epsilon": self.epsilon}


@keras.saving.register_keras_serializable()
class TransformerBlock(keras.layers.Layer):
    def __init__(self, d_model, n_heads, ff_dim, dropout, max_len, rope_theta, layer_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_heads = n_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout
        self.max_len = max_len
        self.rope_theta = rope_theta
        self.head_dim = d_model // n_heads
        self.layer_idx = layer_idx

    def build(self, input_shape):
        self.pre_attn_norm = RMSNorm(name="pre_attn_norm")
        self.pre_ffn_norm = RMSNorm(name="pre_ffn_norm")
        self.q_proj = keras.layers.Dense(self.d_model, use_bias=False, name="q_proj")
        self.k_proj = keras.layers.Dense(self.d_model, use_bias=False, name="k_proj")
        self.v_proj = keras.layers.Dense(self.d_model, use_bias=False, name="v_proj")
        self.out_proj = keras.layers.Dense(self.d_model, use_bias=False, name="o_proj")
        self.rope = RotaryEmbedding(self.head_dim, max_len=self.max_len, theta=self.rope_theta)
        self.gate_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="gate_proj")
        self.up_proj = keras.layers.Dense(self.ff_dim, use_bias=False, name="up_proj")
        self.down_proj = keras.layers.Dense(self.d_model, use_bias=False, name="down_proj")
        self.dropout = keras.layers.Dropout(self.dropout_rate)
        super().build(input_shape)

    def call(self, x, training=None, past_kv=None, use_cache=False):
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        dtype = x.dtype

        res = x
        y = self.pre_attn_norm(x)

        q = tf.transpose(tf.reshape(self.q_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        k = tf.transpose(tf.reshape(self.k_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])
        v = tf.transpose(tf.reshape(self.v_proj(y), [B, T, self.n_heads, self.head_dim]), [0, 2, 1, 3])

        past_len = tf.shape(past_kv[0])[2] if past_kv is not None else 0
        q, k = self.rope(q, k, offset=past_len)

        if past_kv is not None:
            k = tf.concat([past_kv[0], k], axis=2)
            v = tf.concat([past_kv[1], v], axis=2)

        new_kv = (k, v) if use_cache else None

        scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(tf.cast(self.head_dim, dtype))
        full_len = tf.shape(k)[2]
        q_pos = tf.range(past_len, past_len + T)
        k_pos = tf.range(full_len)
        mask = tf.where(q_pos[:, None] >= k_pos[None, :], 0.0, -1e9)
        scores = scores + tf.cast(mask[None, None, :, :], dtype)

        attn = tf.nn.softmax(scores, axis=-1)
        attn_out = tf.reshape(tf.transpose(tf.matmul(attn, v), [0, 2, 1, 3]), [B, T, self.d_model])
        x = res + self.dropout(self.out_proj(attn_out), training=training)

        res = x
        y = self.pre_ffn_norm(x)
        ffn = self.down_proj(keras.activations.silu(self.gate_proj(y)) * self.up_proj(y))
        return res + self.dropout(ffn, training=training), new_kv

    def get_config(self):
        return {**super().get_config(), "d_model": self.d_model, "n_heads": self.n_heads,
                "ff_dim": self.ff_dim, "dropout": self.dropout_rate, "max_len": self.max_len,
                "rope_theta": self.rope_theta, "layer_idx": self.layer_idx}
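

# Global state shared by the loading, generation, and UI code below.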
class ModelState:
    def __init__(self):
        self.config = None
        self.tokenizer = None
        self.eos_token_id = 50256
        self.embedding = None
        self.blocks: List = []
        self.final_norm = None
        self.lm_head = None
        self.my_block_start = 0
        self.my_block_end = 0


STATE = ModelState()
stop_generation = False
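

# Tensor (de)serialization for worker RPC: tensors travel as base64-encoded .npy buffers.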
def serialize_tensor(tensor: tf.Tensor) -> str:
    buffer = io.BytesIO()
    np.save(buffer, tensor.numpy(), allow_pickle=False)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


def deserialize_tensor(data: str) -> tf.Tensor:
    buffer = io.BytesIO(base64.b64decode(data))
    return tf.constant(np.load(buffer, allow_pickle=False))


def serialize_kv_cache(past_kv):
    if past_kv is None:
        return None
    return [{"k": serialize_tensor(k), "v": serialize_tensor(v)} if k is not None else None for k, v in past_kv]


def deserialize_kv_cache(data):
    if data is None:
        return None
    return [(deserialize_tensor(item["k"]), deserialize_tensor(item["v"])) if item else None for item in data]
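

# HTTP call to a downstream worker node holding the next slice of transformer blocks.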
def call_worker(url: str, hidden_states: tf.Tensor, past_kv=None, use_cache=False) -> Tuple[tf.Tensor, Any]:
    try:
        response = requests.post(
            f"{url.rstrip('/')}/api/forward",
            json={
                "hidden_states": serialize_tensor(hidden_states),
                "past_kv": serialize_kv_cache(past_kv),
                "use_cache": use_cache,
            },
            headers={"Authorization": f"Bearer {CONFIG['secret_token']}"},
            timeout=120
        )

        if response.status_code == 200:
            result = response.json()
            output = deserialize_tensor(result["hidden_states"])
            new_kv = deserialize_kv_cache(result.get("past_kv"))
            return output, new_kv
        else:
            raise RuntimeError(f"Worker returned {response.status_code}")
    except Exception as e:
        raise RuntimeError(f"Worker call failed: {e}")
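

# Download config and weights from the Hub, rebuild the tokenizer, build the full model
# once so load_weights() can populate it, then keep only this node's slice of blocks.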
def load_model():
    print("🚀 Loading model...")

    config_path = hf_hub_download(CONFIG["model_repo"], "config.json", cache_dir=CONFIG["cache_dir"])
    with open(config_path, 'r') as f:
        model_config = json.load(f)
    STATE.config = model_config

    from transformers import AutoTokenizer
    from tokenizers import Tokenizer

    hf_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    hf_tokenizer.add_special_tokens({"additional_special_tokens":
        ["<|im_start|>", "<|im_end|>", "<think>", "</think>", "<CONTINUE>", "<im end for model tun>"]})
    os.makedirs("./temp_tokenizer", exist_ok=True)
    hf_tokenizer.save_pretrained("./temp_tokenizer")
    STATE.tokenizer = Tokenizer.from_file("./temp_tokenizer/tokenizer.json")
    STATE.eos_token_id = model_config.get('eos_token_id', 50256)

    weights_path = hf_hub_download(CONFIG["model_repo"], "ckpt.weights.h5", cache_dir=CONFIG["cache_dir"])

    n_layers = model_config['num_hidden_layers']
    d_model = model_config['hidden_size']
    n_heads = model_config['num_attention_heads']
    ff_dim = model_config['intermediate_size']
    max_len = model_config['max_position_embeddings']
    rope_theta = model_config['rope_theta']
    vocab_size = model_config['vocab_size']

    embedding = keras.layers.Embedding(vocab_size, d_model, name="embed_tokens")
    blocks = [TransformerBlock(d_model, n_heads, ff_dim, 0.0, max_len, rope_theta, i, name=f"block_{i}")
              for i in range(n_layers)]
    final_norm = RMSNorm(name="final_norm")
    lm_head = keras.layers.Dense(vocab_size, use_bias=False, name="lm_head")

    dummy = tf.zeros((1, 16), dtype=tf.int32)
    x = embedding(dummy)
    for block in blocks:
        x, _ = block(x)
    x = final_norm(x)
    _ = lm_head(x)

    class TempModel(keras.Model):
        def __init__(self):
            super().__init__()
            self.embed = embedding
            self.blocks = blocks
            self.norm = final_norm
            self.lm_head = lm_head

        def call(self, x):
            x = self.embed(x)
            for b in self.blocks:
                x, _ = b(x)
            return self.lm_head(self.norm(x))

    temp_model = TempModel()
    temp_model(dummy)
    temp_model.load_weights(weights_path)
    print("✅ Weights loaded")

    STATE.my_block_start = CONFIG["layer_start"]
    STATE.my_block_end = CONFIG["layer_end"] if CONFIG["layer_end"] > 0 else n_layers

    STATE.embedding = embedding
    STATE.blocks = blocks[STATE.my_block_start:STATE.my_block_end]
    print(f"✅ Loaded blocks {STATE.my_block_start} to {STATE.my_block_end - 1}")

    has_workers = len(CONFIG["worker_urls"]) > 0
    if not has_workers:
        STATE.final_norm = final_norm
        STATE.lm_head = lm_head
        print("✅ Loaded final norm and LM head (standalone mode)")

    print("🔥 Warming up...")
    dummy = tf.constant([[1, 2, 3]], dtype=tf.int32)
    x = STATE.embedding(dummy)
    for block in STATE.blocks:
        x, _ = block(x, use_cache=False)
    if STATE.lm_head:
        _ = STATE.lm_head(STATE.final_norm(x))

    print("✅ Model ready!")
    return True
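

# Distributed forward pass: run the local blocks first, then hand the hidden states to
# each worker in CONFIG["worker_urls"] order. The final stage (the last worker, or this
# node when standalone) is expected to apply the final norm and LM head.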
def forward_pass(input_ids: tf.Tensor, past_kv_local=None, past_kv_workers=None, use_cache=False):
    x = STATE.embedding(input_ids)

    new_local_kv = [] if use_cache else None
    for i, block in enumerate(STATE.blocks):
        block_past = past_kv_local[i] if past_kv_local else None
        x, kv = block(x, past_kv=block_past, use_cache=use_cache)
        if use_cache:
            new_local_kv.append(kv)

    new_worker_kv = {} if use_cache else None
    for worker_url in CONFIG["worker_urls"]:
        worker_past = past_kv_workers.get(worker_url) if past_kv_workers else None
        x, worker_kv = call_worker(worker_url, x, worker_past, use_cache)
        if use_cache:
            new_worker_kv[worker_url] = worker_kv

    if STATE.lm_head:
        logits = STATE.lm_head(STATE.final_norm(x))
    else:
        logits = x

    return logits, new_local_kv, new_worker_kv
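

# Sampling: temperature scaling, a frequency-based repetition penalty, then top-k and
# top-p (nucleus) filtering over the remaining logits.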
def sample_token(logits, temperature, top_k, top_p, token_freq, rep_penalty):
    logits = np.array(logits) / temperature

    for tid, freq in token_freq.items():
        if tid < len(logits):
            logits[tid] /= (rep_penalty ** freq)

    if 0 < top_k < len(logits):
        top_k_idx = np.argpartition(logits, -top_k)[-top_k:]
        top_k_logits = logits[top_k_idx]
    else:
        top_k_idx = np.arange(len(logits))
        top_k_logits = logits

    top_k_logits = top_k_logits - np.max(top_k_logits)
    probs = np.exp(top_k_logits)
    probs /= probs.sum()

    if top_p < 1.0:
        sorted_idx = np.argsort(probs)[::-1]
        cumsum = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cumsum, top_p) + 1
        nucleus_idx = sorted_idx[:cutoff]
        nucleus_probs = probs[nucleus_idx]
        nucleus_probs /= nucleus_probs.sum()
        sampled = np.random.choice(len(nucleus_probs), p=nucleus_probs)
        return int(top_k_idx[nucleus_idx[sampled]])

    return int(top_k_idx[np.random.choice(len(probs), p=probs)])
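

# Token-by-token generation. The prompt is prefilled in one forward pass, then each new
# token reuses the cached key/value states (local and per-worker) from the previous step.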
def generate_stream(prompt: str, max_tokens=512, temperature=0.8, top_k=40, top_p=0.9, rep_penalty=1.1):
    global stop_generation
    stop_generation = False

    input_ids = [i for i in STATE.tokenizer.encode(prompt).ids if i != STATE.eos_token_id]
    if not input_ids:
        yield "Error: Empty prompt"
        return

    generated = ""
    token_freq = {}

    stop_ids = {STATE.eos_token_id, STATE.tokenizer.token_to_id("<|im_end|>"),
                STATE.tokenizer.token_to_id("<im end for model tun>")}
    stop_ids.discard(None)

    max_ctx = STATE.config['max_position_embeddings']
    if len(input_ids) > max_ctx - max_tokens:
        input_ids = input_ids[-(max_ctx - max_tokens):]

    start = time.time()

    input_tensor = tf.constant([input_ids], dtype=tf.int32)
    try:
        logits, local_kv, worker_kv = forward_pass(input_tensor, None, None, use_cache=True)
    except Exception as e:
        yield f"Error: {e}"
        return

    next_logits = logits[0, -1, :].numpy()
    prefill_time = time.time() - start
    print(f"⚡ Prefill: {len(input_ids)} tokens in {prefill_time:.2f}s")

    decode_start = time.time()
    tokens_generated = 0

    for _ in range(max_tokens):
        if stop_generation:
            yield generated + "\n\n*[Stopped]*"
            return

        next_id = sample_token(next_logits, temperature, top_k, top_p, token_freq, rep_penalty)

        if next_id in stop_ids:
            break

        token_freq[next_id] = token_freq.get(next_id, 0) + 1
        generated += STATE.tokenizer.decode([next_id])
        tokens_generated += 1
        yield generated

        next_input = tf.constant([[next_id]], dtype=tf.int32)
        try:
            logits, local_kv, worker_kv = forward_pass(next_input, local_kv, worker_kv, use_cache=True)
        except Exception as e:
            yield generated + f"\n\n*[Error: {e}]*"
            return

        next_logits = logits[0, -1, :].numpy()

    if tokens_generated > 0:
        total = time.time() - start
        tps = tokens_generated / (time.time() - decode_start)
        workers = len(CONFIG["worker_urls"])
        mode = f", {workers} workers" if workers else " standalone"
        generated += f"\n\n*[{tokens_generated} tokens in {total:.1f}s ({tps:.1f} tok/s){mode}]*"

    yield generated


def format_prompt(message: str, history: list, reasoning: bool) -> str:
    prompt = ""
    for msg in history:
        if msg["role"] == "user":
            prompt += f"<|im_start|>user\n{msg['content']}<|im_end|>\n"
        elif msg["role"] == "assistant":
            content = msg['content'].split('*[')[0].strip()
            prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    prompt += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    if reasoning:
        prompt += "<think>"
    return prompt


def chat_respond(message, history, max_tokens, temp, top_k, top_p, rep_pen, reasoning):
    if not message.strip():
        yield history
        return

    prompt = format_prompt(message, history, reasoning)

    history = history + [{"role": "user", "content": message}]

    for text in generate_stream(prompt, max_tokens, temp, top_k, top_p, rep_pen):
        display = text

        for tag in ["<|im_end|>", "<im end for model tun>"]:
            if tag in display:
                idx = display.find(tag)
                stats = display.find("\n\n*[")
                display = display[:idx] + (display[stats:] if stats > idx else "")

        if reasoning and '<think>' in display and '</think>' in display:
            s, e = display.find('<think>'), display.find('</think>')
            if s < e:
                thought = display[s+7:e].strip()
                display = display[:s] + f'<details><summary>🧠 Reasoning</summary><p>{thought}</p></details>' + display[e+8:]

        yield history + [{"role": "assistant", "content": display.strip()}]


def stop():
    global stop_generation
    stop_generation = True
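

# Gradio chat UI for the head node.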
def create_ui():
    workers = CONFIG["worker_urls"]
    mode = f"Distributed ({len(workers)} workers)" if workers else "Standalone"

    with gr.Blocks(title="Sam-large-2 HEAD") as app:
        gr.Markdown(f"""
        # 🚀 Sam-large-2 - HEAD NODE
        **Mode:** {mode} | **Blocks:** {CONFIG['layer_start']}-{CONFIG['layer_end']-1} | **ID:** {CONFIG['node_id']}
        """)

        if workers:
            gr.Markdown("**Workers:** " + ", ".join(f"`{w}`" for w in workers))

        reasoning = gr.State(False)

        chatbot = gr.Chatbot(
            height=500,
            type="messages"
        )

        with gr.Row():
            reason_btn = gr.Button("💡", size="sm", scale=0)
            msg = gr.Textbox(placeholder="Type message...", show_label=False, scale=8)
            send = gr.Button("Send", variant="primary", scale=1)
            stop_btn = gr.Button("⏹️", scale=0)

        with gr.Accordion("⚙️ Settings", open=False):
            max_tok = gr.Slider(50, 1024, 512, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, 0.8, label="Temperature")
            topk = gr.Slider(1, 100, 40, label="Top-K")
            topp = gr.Slider(0.1, 1.0, 0.9, label="Top-P")
            rep = gr.Slider(1.0, 2.0, 1.1, label="Repetition Penalty")

        def toggle(r):
            return not r, gr.update(variant="primary" if not r else "secondary")

        reason_btn.click(toggle, [reasoning], [reasoning, reason_btn])

        inputs = [msg, chatbot, max_tok, temp, topk, topp, rep, reasoning]
        submit = msg.submit(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        click = send.click(chat_respond, inputs, chatbot).then(lambda: "", outputs=msg)
        stop_btn.click(stop, cancels=[submit, click])

        gr.Button("🗑️ Clear").click(lambda: [], outputs=[chatbot])

    return app
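

# Startup: load weights, build the UI, and serve on port 7860.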
print("=" * 60) |
|
|
print("π Sam-large-2 HEAD Node Starting") |
|
|
print(f" Blocks: {CONFIG['layer_start']} to {CONFIG['layer_end']}") |
|
|
print(f" Workers: {CONFIG['worker_urls'] or 'None (standalone)'}") |
|
|
print("=" * 60) |
|
|
|
|
|
load_model() |
|
|
app = create_ui() |
|
|
app.queue() |
|
|
app.launch(server_name="0.0.0.0", server_port=7860) |