Spaces:
Runtime error
Runtime error
# app.py | |
import time | |
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList, StoppingCriteria | |
# ---- Tunable configuration ------------------------------------------------
MODEL_NAME: str = "Rapnss/VIA-01"  # Hugging Face Hub repo id to load

# Length limits (kept small so responses stay within latency targets)
DEFAULT_MAX_NEW_TOKENS: int = 64  # default generation budget shown in the UI
MAX_PROMPT_TOKENS: int = 512  # prompts are truncated to this many tokens

# Decoding strategy
DO_SAMPLE: bool = False  # greedy decoding: deterministic and usually faster than sampling
NUM_BEAMS: int = 1  # a single beam is the fastest setting
TEMPERATURE: float = 0.3  # sampling knob (default for the UI slider)
TOP_P: float = 0.9  # nucleus-sampling knob

# Tiny prompt for the one-off generation run right after the model loads
WARMUP_PROMPT: str = "Hello."
# Load the tokenizer, then the model:
#   * CUDA available -> try a 4-bit bitsandbytes load, fall back to fp16 on failure.
#   * no CUDA        -> load float32 weights on the CPU.
# Module-level results used by the rest of the app: `tokenizer`, `model`, `device`.
print("Loading tokenizer & model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    try:
        print("CUDA available — attempting 4-bit load with bitsandbytes...")
        # Imported here so that a missing/old bitsandbytes integration only disables
        # the 4-bit path (the except below falls back to fp16) instead of the whole app.
        from transformers import BitsAndBytesConfig

        # 4-bit options must be grouped in a BitsAndBytesConfig and passed as
        # `quantization_config`; bare `bnb_4bit_*` keywords are not from_pretrained() args.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,  # some user repos need it
        )
    except Exception as e:  # deliberate best-effort: any 4-bit failure falls back to fp16
        print("4-bit load failed or not available:", e)
        print("Falling back to fp16/cpu (best-effort).")
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
else:
    print("CUDA not available; loading float32 weights on CPU.")
    # No device_map: weights already load onto the CPU by default, and passing a
    # device_map makes the optional `accelerate` package a hard requirement.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

# Inference only: disable dropout etc.
model.eval()

# Optional torch.compile. Wrapping the whole module (torch.compile(model)) does not
# speed up generation: `.generate` on the wrapper resolves to the original module,
# which then calls its uncompiled forward. Compiling `forward` is what generate()
# actually exercises. It is opt-in (VIA_TORCH_COMPILE=1) because it lengthens startup
# and compilation problems only surface on the first forward call, not here.
if os.environ.get("VIA_TORCH_COMPILE") == "1" and hasattr(torch, "compile"):
    try:
        print("Attempting torch.compile(model.forward) for runtime speedups...")
        model.forward = torch.compile(model.forward)
    except Exception as e:
        print("torch.compile not used:", e)

print(f"Model loaded on {device}")
# Utility: tokenize a prompt and place the tensors on the model's device.
def prepare_inputs(prompt_text):
    """Tokenize ``prompt_text`` into PyTorch tensors for ``model.generate``.

    The prompt is truncated to ``MAX_PROMPT_TOKENS`` tokens (bounding the total
    sequence length during generation) and no padding is added.

    Returns:
        When the module-level ``device`` is ``"cuda"``, a plain dict whose
        tensors were moved with ``.cuda()``; otherwise the tokenizer's output
        object unchanged, with tensors left on the CPU.
    """
    encoded = tokenizer(
        prompt_text,
        max_length=MAX_PROMPT_TOKENS,
        truncation=True,
        padding=False,
        return_tensors="pt",
    )
    if device != "cuda":
        return encoded
    return {name: tensor.cuda() for name, tensor in encoded.items()}
# Optional stopping criterion: end generation at the first blank line ("\n\n").
class StopOnDoubleNewline(StoppingCriteria):
    """Stop a sequence once its decoded text ends with a blank line.

    Each step decodes only the last ``tail_tokens`` tokens of every row and
    checks whether that text ends with ``"\\n\\n"``. Working on decoded text,
    rather than raw token ids, catches both a single merged "\\n\\n" token and
    two separate "\\n" tokens, whatever the tokenizer's vocabulary.

    Args:
        tail_tokens: number of trailing tokens decoded per step (minimum 2).
    """

    def __init__(self, tail_tokens: int = 4):
        super().__init__()
        self.tail_tokens = max(2, int(tail_tokens))

    def __call__(self, input_ids, scores, **kwargs):
        """Return a bool tensor with one entry per row of ``input_ids``; True = stop that row.

        The previous implementation compared the last two token ids to EOS,
        which contradicted the class name. EOS is already handled by passing
        ``eos_token_id`` to ``generate``. It also only looked at row 0.
        """
        finished = [
            tokenizer.decode(row[-self.tail_tokens:], skip_special_tokens=True).endswith("\n\n")
            for row in input_ids
        ]
        return torch.tensor(finished, dtype=torch.bool, device=input_ids.device)


# Ready-made list that can be passed as generate(stopping_criteria=stop_criteria).
stop_criteria = StoppingCriteriaList([StopOnDoubleNewline()])
# Warm-up: one tiny generation so kernels/caches are initialized before real traffic.
def warm_up_model():
    """Run a single short greedy generation on ``WARMUP_PROMPT``.

    At most 8 new tokens are generated and the output is discarded. Warm-up is
    best-effort: any exception is printed and swallowed so it can never stop the
    app from starting.
    """
    try:
        warm_inputs = prepare_inputs(WARMUP_PROMPT)
        with torch.inference_mode():
            model.generate(**warm_inputs, max_new_tokens=8, do_sample=False, use_cache=True)
    except Exception as err:
        print("Warmup failed:", err)
    else:
        print("Warmup complete.")


# Executed once at import time to cut first-request latency.
warm_up_model()
# The actual chat function used by Gradio
def chat_fn(prompt: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = TEMPERATURE) -> str:
    """Generate a reply to ``prompt`` and append a latency/debug footer.

    Args:
        prompt: user text from the Gradio textbox. ``None`` or whitespace-only
            input is treated as empty.
        max_new_tokens: generation budget, clamped to the range [1, 256].
        temperature: sampling temperature. It is only applied when ``DO_SAMPLE``
            is True and the value is > 0; otherwise decoding is deterministic
            (greedy, or beam search when ``NUM_BEAMS`` > 1).

    Returns:
        "Please enter a prompt." for empty input. Otherwise the newly generated
        text (the prompt is no longer echoed), followed by a footer with latency,
        the effective token budget and the device.
    """
    t0 = time.perf_counter()  # monotonic clock; time.time() can jump
    prompt = (prompt or "").strip()
    if not prompt:
        return "Please enter a prompt."
    # safety: clamp max_new_tokens to avoid huge generations (sliders may send floats)
    max_new_tokens = int(max(1, min(max_new_tokens, 256)))
    temperature = float(temperature)

    inputs = prepare_inputs(prompt)
    prompt_len = inputs["input_ids"].shape[-1]

    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Generation arguments tuned for speed
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=NUM_BEAMS,
        eos_token_id=eos_id,
        pad_token_id=pad_id,
        use_cache=True,
        # To end at the first blank line, also pass stopping_criteria=stop_criteria.
    )
    # Pass sampling knobs only when actually sampling. With do_sample=False,
    # transformers ignores them and warns. A temperature of 0 (the slider
    # minimum) is not a valid sampling temperature, so it falls back to greedy.
    if DO_SAMPLE and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=float(TOP_P))
    # early_stopping only has an effect in beam search; passing it with one beam triggers a warning.
    if NUM_BEAMS > 1:
        gen_kwargs["early_stopping"] = True

    # Inference context to reduce overhead
    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)

    # A causal LM's generate() output starts with the prompt tokens; decode only the continuation.
    response = tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True).strip()
    latency = time.perf_counter() - t0
    # Return response and latency for debugging
    return f"{response}\n\n---\nLatency: {latency:.2f}s (max_new_tokens={max_new_tokens}, device={device})"
# Gradio UI: prompt box, generation controls, output box and a button wired to chat_fn.
with gr.Blocks() as demo:
    gr.Markdown("# Rapnss VIA-01")

    with gr.Row():
        txt = gr.Textbox(label="Prompt", placeholder="Ask VIA-01 something...", lines=3)

    with gr.Row():
        max_tokens = gr.Slider(
            16,
            256,
            step=16,
            value=DEFAULT_MAX_NEW_TOKENS,
            label="Max new tokens",
        )
        temp = gr.Slider(
            0.0,
            1.0,
            step=0.05,
            value=TEMPERATURE,
            label="Temperature",
        )

    out = gr.Textbox(label="VIA-01 Response", lines=12)
    submit = gr.Button("Generate")
    # Order of inputs matches chat_fn(prompt, max_new_tokens, temperature).
    submit.click(fn=chat_fn, inputs=[txt, max_tokens, temp], outputs=out)

if __name__ == "__main__":
    # Hosting platforms may inject the listening port via $PORT; 7860 is the fallback.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port, share=False)