import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
BASE_MODEL = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct"
ADAPTERS_REPO = "nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA"
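# Fail fast without CUDA: a minimal guard (an assumption, not part of the
# original Space) since the bitsandbytes 4-bit loading below needs a GPU.
if not torch.cuda.is_available():
    raise SystemExit("This demo requires CUDA GPU hardware for 4-bit loading.")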
def load_model():
    # Quantize the 8B base model to 4-bit so it fits on a T4/A10G-class GPU;
    # NF4 with a float16 compute dtype is the usual QLoRA-style setup.
    # (Passing load_in_4bit directly to from_pretrained is deprecated.)
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=quant_config,
    )
    # Attach the PAMAv1 LoRA adapters on top of the quantized base.
    model = PeftModel.from_pretrained(base, ADAPTERS_REPO)
    model.eval()
    return tokenizer, model
tokenizer, model = load_model()
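# A minimal sketch of chat-style prompt construction (not wired into the UI
# below): instruct models such as this one normally ship a chat template, so
# tokenizer.apply_chat_template can replace the manual "User: ... Assistant:"
# formatting used in the example prompt. Assumes the tokenizer defines a
# template; falls back to the raw string otherwise.
def format_chat(user_message: str) -> str:
    messages = [{"role": "user", "content": user_message}]
    try:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        return user_message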
def generate(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9):
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Sampling with temperature=0.0 raises an error in transformers, so the
    # slider's zero position falls back to greedy decoding.
    do_sample = temperature > 0
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            temperature=temperature if do_sample else None,
            top_p=top_p if do_sample else None,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
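# Optional sketch: token streaming via transformers' TextIteratorStreamer, as
# an alternative to the blocking generate() above. Not connected to the UI;
# shown only to illustrate the API. generate_stream is an illustrative name.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream(prompt: str, max_new_tokens: int = 256):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it in a thread and drain the streamer here.
    Thread(target=model.generate, kwargs=kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text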
DESCRIPTION = """
# PAMAv1 — LoRA demo

This Space loads the base model `nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct` and applies the LoRA adapters from `nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA` on top of it.

Notes:
- Requires GPU hardware (A10G/T4 or better); the base model is loaded in 4-bit to reduce VRAM.
- For long contexts, craft prompts carefully and raise `max_new_tokens` as needed.
"""
examples = [
    [
        "You are a helpful assistant.\nUser: Summarize the Lottie JSON schema in two concise bullets.\nAssistant:",
        128,
        0.7,
        0.9,
    ],
]
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8)
            max_new_tokens = gr.Slider(16, 1024, value=256, step=8, label="max_new_tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
            btn = gr.Button("Generate")
        with gr.Column():
            out = gr.Textbox(label="Output", lines=12)
    btn.click(generate, [prompt, max_new_tokens, temperature, top_p], out)
    # cache_examples=False avoids running the model on every example at startup.
    gr.Examples(
        examples=examples,
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[out],
        cache_examples=False,
    )
if __name__ == "__main__":
    demo.launch()