# NOTE(review): the Space page showed "Build error" — most likely the
# deprecated `load_in_4bit=` kwarg in load_model(); see fix below.
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# Hugging Face Hub repo ids: the full base checkpoint and the LoRA adapters
# that get layered on top of it in load_model().
BASE_MODEL = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct"
ADAPTERS_REPO = "nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA"
def load_model():
    """Load the 4-bit-quantized base model and attach the LoRA adapters.

    Returns:
        tuple: ``(tokenizer, model)`` — the tokenizer for ``BASE_MODEL`` and
        the PEFT-wrapped causal LM with ``ADAPTERS_REPO`` applied.
    """
    compute_dtype = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Passing `load_in_4bit=True` directly to from_pretrained() is deprecated
    # and removed in recent transformers releases (a likely cause of the
    # Space's build error) — use an explicit BitsAndBytesConfig instead.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=compute_dtype,
        quantization_config=quant_config,
    )
    # Layer the LoRA adapters on top of the quantized base weights.
    model = PeftModel.from_pretrained(base, ADAPTERS_REPO)
    model.eval()  # inference-only demo; disable dropout etc.
    return tokenizer, model


# Load once at import time so every Gradio request reuses the same weights.
tokenizer, model = load_model()
def generate(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9) -> str:
    """Generate a completion for *prompt* with the module-level model.

    Args:
        prompt: Raw text prompt; blank/whitespace-only input short-circuits.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; 0 (or below) switches to greedy
            decoding — the UI slider allows 0.0, and HF ``generate()`` raises
            when ``do_sample=True`` is combined with ``temperature=0``.
        top_p: Nucleus-sampling cutoff (only used when sampling).

    Returns:
        The decoded text, including the original prompt (full-sequence decode).
    """
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    do_sample = temperature > 0  # greedy fallback avoids a generate() error at temp=0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    with torch.inference_mode():
        outputs = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Markdown shown at the top of the Space (extraction artifacts removed,
# blank lines restored so the header and bullet list render correctly).
DESCRIPTION = """
# PAMAv1 — LoRA demo

This Space loads the base model `nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct` with the LoRA adapters from `nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA`.

Notes:

- Requires GPU hardware (A10G/T4 or better). Uses 4-bit loading to reduce VRAM.
- For long contexts, craft prompts carefully and raise `max_new_tokens`.
"""
# Seed examples for gr.Examples: each row is
# [prompt, max_new_tokens, temperature, top_p], matching the input order
# wired into the click handler below.
examples = [
    [
        "You are a helpful assistant.\nUser: Summarize the Lottie JSON schema in two concise bullets.\nAssistant:",
        128,
        0.7,
        0.9,
    ],
]
# UI layout: prompt + sampling controls on the left, output on the right.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8)
            max_new_tokens = gr.Slider(16, 1024, value=256, step=8, label="max_new_tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
            btn = gr.Button("Generate")
        with gr.Column():
            out = gr.Textbox(label="Output", lines=12)
    # Shared input list: order must match generate()'s signature.
    gen_inputs = [prompt, max_new_tokens, temperature, top_p]
    btn.click(generate, gen_inputs, out)
    gr.Examples(examples=examples, fn=generate, inputs=gen_inputs, outputs=[out])

if __name__ == "__main__":
    # Standard Spaces entry point.
    demo.launch()