# app.py
import os
import time

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
)

MODEL_NAME = "Rapnss/VIA-01"  # your HF repo

# Configs you can tune
DEFAULT_MAX_NEW_TOKENS = 64   # keep low to meet latency targets
MAX_PROMPT_TOKENS = 512       # truncate long prompts
TEMPERATURE = 0.3
TOP_P = 0.9
DO_SAMPLE = False             # deterministic and usually faster than sampling
NUM_BEAMS = 1                 # beam=1 is fastest
WARMUP_PROMPT = "Hello."      # used to warm the model after loading

# Try to load the model in quantized (4-bit) mode if bitsandbytes is available
print("Loading tokenizer & model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

model = None
device = "cpu"

try:
    # If CUDA is available and bitsandbytes exists, load in 4-bit
    if torch.cuda.is_available():
        device = "cuda"
        print("CUDA available - attempting 4-bit load with bitsandbytes...")
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=quant_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,  # some user repos need it
        )
    else:
        raise RuntimeError("CUDA not available; using fallback load")
except Exception as e:
    print("4-bit load failed or not available:", e)
    print("Falling back to fp16/CPU (best-effort).")
    # Fallback: fp16 on GPU, float32 on CPU
    if torch.cuda.is_available():
        device = "cuda"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
    else:
        device = "cpu"
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float32,
            device_map={"": "cpu"},
            trust_remote_code=True,
        )

# Put the model in eval mode & optionally compile it
model.eval()

# Optional: try torch.compile for small speedups (PyTorch 2.x only, may increase startup time)
try:
    if torch.__version__.startswith("2"):
        print("Attempting torch.compile(model) for runtime speedups...")
        model = torch.compile(model)
except Exception as e:
    print("torch.compile not used:", e)

print(f"Model loaded on {device}")


# Utility: tokenize and move tensors to the proper device
def prepare_inputs(prompt_text):
    # Truncate long prompts to limit total tokens at generation time
    inputs = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_PROMPT_TOKENS,
        padding=False,
    )
    if device == "cuda":
        inputs = {k: v.cuda() for k, v in inputs.items()}
    return inputs


# Optional: short stopping-criteria example (stop on a double newline)
class StopOnDoubleNewline(StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs):
        # Stop once the last two generated tokens decode to a blank line.
        # Simple decode-based check; customize as needed.
        if input_ids.shape[-1] >= 2:
            tail = tokenizer.decode(input_ids[0, -2:])
            if tail.endswith("\n\n"):
                return True
        return False


stop_criteria = StoppingCriteriaList([StopOnDoubleNewline()])


# Warm-up: run one tiny generation so kernels/caches are primed
def warm_up_model():
    try:
        inputs = prepare_inputs(WARMUP_PROMPT)
        with torch.inference_mode():
            model.generate(
                **inputs,
                max_new_tokens=8,
                do_sample=False,
                use_cache=True,
            )
        print("Warmup complete.")
    except Exception as e:
        print("Warmup failed:", e)


# Warm up once at startup to reduce first-request latency
warm_up_model()


# The actual chat function used by Gradio
def chat_fn(prompt: str, max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS, temperature: float = TEMPERATURE):
    t0 = time.time()
    prompt = prompt.strip()
    if not prompt:
        return "Please enter a prompt."

    # Safety: clamp max_new_tokens to avoid huge generations
    max_new_tokens = int(max(1, min(max_new_tokens, 256)))
    inputs = prepare_inputs(prompt)

    # Generation arguments tuned for speed
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=DO_SAMPLE,
        num_beams=NUM_BEAMS,
        eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        use_cache=True,
        early_stopping=True,  # only relevant when num_beams > 1
        # stopping_criteria=stop_criteria,  # enable if you want custom stopping
    )
    # temperature / top_p only take effect when sampling; omit them for greedy
    # decoding so transformers does not warn about unused sampling arguments
    if DO_SAMPLE:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(TOP_P)

    # Inference context to reduce overhead
    with torch.inference_mode():
        outputs = model.generate(**gen_kwargs)

    # Decode only the newly generated tokens so the prompt is not echoed back
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    latency = time.time() - t0
    # Return the response plus latency info for debugging
    return f"{response}\n\n---\nLatency: {latency:.2f}s (max_new_tokens={max_new_tokens}, device={device})"


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Rapnss VIA-01")
    with gr.Row():
        txt = gr.Textbox(lines=3, placeholder="Ask VIA-01 something...", label="Prompt")
    with gr.Row():
        max_tokens = gr.Slider(16, 256, value=DEFAULT_MAX_NEW_TOKENS, step=16, label="Max new tokens")
        temp = gr.Slider(0.0, 1.0, value=TEMPERATURE, step=0.05, label="Temperature")
    out = gr.Textbox(label="VIA-01 Response", lines=12)
    submit = gr.Button("Generate")
    submit.click(fn=chat_fn, inputs=[txt, max_tokens, temp], outputs=out)

if __name__ == "__main__":
    demo.launch(share=False, server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))