# VIA-01 / app.py
# Last commit by Invescoz: "Update app.py" (502b633, verified)
# app.py
import time
import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria
# ---------------------------------------------------------------------------
# Model / generation configuration (tune these)
# ---------------------------------------------------------------------------
# Hugging Face Hub repository the tokenizer and model are pulled from.
MODEL_NAME: str = "Rapnss/VIA-01"

# Token budgets: small limits keep per-request latency low.
DEFAULT_MAX_NEW_TOKENS: int = 64  # default number of tokens generated per request
MAX_PROMPT_TOKENS: int = 512      # prompts longer than this are truncated

# Sampling knobs; transformers only applies them when sampling is enabled.
TEMPERATURE: float = 0.3
TOP_P: float = 0.9

# Decoding strategy: no sampling and a single beam means deterministic greedy
# decoding (a single beam is the cheapest beam setting).
DO_SAMPLE: bool = False
NUM_BEAMS: int = 1

# Tiny prompt generated once at startup to warm the model up.
WARMUP_PROMPT: str = "Hello."
# ---------------------------------------------------------------------------
# Tokenizer / model loading
# ---------------------------------------------------------------------------
# Strategy: on CUDA, first try a 4-bit bitsandbytes load. If that raises for
# any reason (no CUDA, bitsandbytes missing, unsupported model, ...), fall back
# to an unquantized load: fp16 on GPU, or fp32 on CPU.
print("Loading tokenizer & model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = None
device = "cpu"
try:
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available; load fallback")
    device = "cuda"
    print("CUDA available — attempting 4-bit load with bitsandbytes...")
    # Imported here so that a transformers build without BitsAndBytesConfig
    # ends up in the fallback path below instead of crashing at import time.
    from transformers import BitsAndBytesConfig

    # Quantization settings go through BitsAndBytesConfig; passing
    # load_in_4bit / bnb_4bit_* straight to from_pretrained is deprecated.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.float16,
        # Some user repos need custom modeling code. NOTE: this executes Python
        # shipped in the model repo — keep it only for repos you trust.
        trust_remote_code=True,
    )
except Exception as e:
    print("4-bit load failed or not available:", e)
    print("Falling back to fp16/cpu (best-effort).")
    if torch.cuda.is_available():
        device = "cuda"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        device = "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map={"": "cpu"},
            trust_remote_code=True,
        )
# Inference only: eval() switches off training-only behavior such as dropout.
model.eval()

# Optional torch.compile, opt-in via the environment: ENABLE_TORCH_COMPILE=1.
#
# Wrapping the whole module (torch.compile(model)) does not speed up generation.
# The returned OptimizedModule forwards attribute lookups like `.generate` to
# the original module, whose generate() then calls the *uncompiled* forward.
# Compiling `model.forward` is what generate() actually uses.
# It is off by default because compilation happens lazily during the first
# request(s), may recompile as sequence lengths change, and may not support
# quantized (bitsandbytes) models — any of which can hurt latency or fail
# mid-request.
if os.environ.get("ENABLE_TORCH_COMPILE", "0") == "1" and hasattr(torch, "compile"):
    try:
        print("Attempting torch.compile(model.forward) for runtime speedups...")
        model.forward = torch.compile(model.forward)
    except Exception as e:
        print("torch.compile not used:", e)
print(f"Model loaded on {device}")
def prepare_inputs(prompt_text):
    """Tokenize *prompt_text* into PyTorch tensors ready for model.generate().

    The prompt is truncated to at most MAX_PROMPT_TOKENS tokens and never
    padded, which bounds the sequence length generation starts from.

    Returns:
        On CPU, the tokenizer output unchanged. On CUDA, a plain dict mapping
        each tokenizer field name to its tensor moved to the current CUDA
        device via ``.cuda()``.
    """
    encoded = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
        padding=False,
    )
    if device != "cuda":
        return encoded
    return {name: tensor.cuda() for name, tensor in encoded.items()}
# Optional custom stopping criterion: end generation at the first blank line.
# It only takes effect when passed to generate() via `stopping_criteria=`.
class StopOnDoubleNewline(StoppingCriteria):
    """Stop generation once the generated text ends with a blank line ("\\n\\n").

    Newline tokenization differs between tokenizers (a blank line may be one
    token or two), so instead of matching token ids, the last few tokens are
    decoded back to text and checked.

    Args:
        prompt_length: number of prompt tokens at the start of ``input_ids``.
            Those tokens are never inspected, so a prompt that itself ends in
            a blank line cannot trigger an immediate stop. The default of 0
            inspects the whole sequence, which keeps ``StopOnDoubleNewline()``
            usable without arguments.
        tail_tokens: how many trailing tokens to decode on each check.

    Only the first sequence of the batch (row 0) is inspected. EOS needs no
    handling here: generate() already stops on ``eos_token_id``.
    """

    def __init__(self, prompt_length: int = 0, tail_tokens: int = 4):
        self.prompt_length = max(0, int(prompt_length))
        self.tail_tokens = max(1, int(tail_tokens))

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0, self.prompt_length:]
        if generated.shape[-1] == 0:
            # Nothing generated yet (or prompt_length past the end).
            return False
        tail_text = tokenizer.decode(generated[-self.tail_tokens:], skip_special_tokens=True)
        return tail_text.endswith("\n\n")


stop_criteria = StoppingCriteriaList([StopOnDoubleNewline()])
def warm_up_model():
    """Run one tiny greedy generation (8 new tokens) on WARMUP_PROMPT.

    This lets one-time startup work happen before the first user request.
    Failures are reported with print() and swallowed, so a failed warm-up
    never prevents the app from starting.
    """
    try:
        warmup_inputs = prepare_inputs(WARMUP_PROMPT)
        with torch.inference_mode():
            model.generate(
                **warmup_inputs,
                max_new_tokens=8,
                do_sample=False,
                use_cache=True,
            )
        print("Warmup complete.")
    except Exception as e:
        print("Warmup failed:", e)


# Pay the warm-up cost at import time rather than on the first request.
warm_up_model()
# The chat function wired to the Gradio "Generate" button.
def chat_fn(prompt: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = TEMPERATURE):
    """Generate a reply to *prompt* and append a latency footer.

    Args:
        prompt: user text; surrounding whitespace is stripped and ``None`` is
            treated as empty.
        max_new_tokens: generation budget, clamped to the range [1, 256].
        temperature: sampling temperature. It is used only when DO_SAMPLE is
            True and the value is > 0; otherwise decoding is greedy.

    Returns:
        The generated continuation (without the echoed prompt), followed by a
        footer with wall-clock latency, the effective token budget and the
        device. Returns "Please enter a prompt." for an empty prompt.
    """
    t0 = time.time()
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."
    # Clamp to avoid huge generations; int() also accepts float slider values.
    max_new_tokens = int(max(1, min(max_new_tokens, 256)))
    inputs = prepare_inputs(prompt)
    prompt_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=NUM_BEAMS,
        eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        # stopping_criteria=stop_criteria,  # enable for custom stopping
    )
    # temperature/top_p only apply to sampling. Passing them to a greedy run
    # just triggers transformers warnings, and sampling with temperature == 0
    # is rejected, so treat temperature <= 0 as "greedy".
    temperature = float(temperature)
    if DO_SAMPLE and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(TOP_P))
    else:
        gen_kwargs["do_sample"] = False
    # early_stopping only affects beam search; omit it for single-beam decoding.
    if NUM_BEAMS > 1:
        gen_kwargs["early_stopping"] = True

    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)
    # For decoder-only models, generate() returns prompt + continuation.
    # Decode only the new tokens so the reply does not repeat the user's prompt.
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    latency = time.time() - t0
    return f"{response}\n\n---\nLatency: {latency:.2f}s (max_new_tokens={max_new_tokens}, device={device})"
# ---------------------------------------------------------------------------
# Gradio UI: prompt box, generation controls, response box, Generate button.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Rapnss VIA-01")
    with gr.Row():
        prompt_box = gr.Textbox(
            lines=3,
            placeholder="Ask VIA-01 something...",
            label="Prompt",
        )
    with gr.Row():
        max_tokens_slider = gr.Slider(
            16, 256, value=DEFAULT_MAX_NEW_TOKENS, step=16, label="Max new tokens"
        )
        temperature_slider = gr.Slider(
            0.0, 1.0, value=TEMPERATURE, step=0.05, label="Temperature"
        )
    response_box = gr.Textbox(label="VIA-01 Response", lines=12)
    generate_button = gr.Button("Generate")
    # The input order must match chat_fn's positional parameters.
    generate_button.click(
        fn=chat_fn,
        inputs=[prompt_box, max_tokens_slider, temperature_slider],
        outputs=response_box,
    )
if __name__ == "__main__":
    # Bind to all interfaces; use $PORT when set, otherwise Gradio's usual 7860.
    port = int(os.environ.get("PORT", 7860))
    demo.launch(share=False, server_name="0.0.0.0", server_port=port)