import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel


BASE_MODEL = "nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct"
ADAPTERS_REPO = "nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA"
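# The base checkpoint carries the full 8B weights; the adapters repo holds
# only the LoRA deltas. Both are downloaded at startup and combined in
# load_model() below.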


def load_model():
    dtype = torch.float16
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    # Quantize the base model to 4-bit so it fits in the VRAM of a single
    # T4/A10G. The bare `load_in_4bit=True` kwarg is deprecated in recent
    # transformers releases, so an explicit BitsAndBytesConfig is passed.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=dtype,
    )
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=dtype,
        quantization_config=quant_config,
    )
    # Attach the LoRA adapters on top of the quantized base weights.
    model = PeftModel.from_pretrained(base, ADAPTERS_REPO)
    model.eval()
    return tokenizer, model
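# A hedged alternative (not used here): with an unquantized base, the LoRA
# deltas could be folded into the base weights for slightly faster inference:
#
#     base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)
#     merged = PeftModel.from_pretrained(base, ADAPTERS_REPO).merge_and_unload()
#
# Recent peft versions can also merge into 4-bit weights, but the round-trip
# through dequantization can be lossy, so the adapters stay separate above.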


tokenizer, model = load_model()


def generate(prompt: str, max_new_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9):
    if not prompt.strip():
        return ""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # transformers rejects temperature=0.0 when sampling is enabled, so fall
    # back to greedy decoding at the slider's minimum.
    if temperature > 0:
        gen_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs = dict(do_sample=False)
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            **gen_kwargs,
        )
    # Decode only the newly generated tokens so the output box does not
    # echo the prompt back.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
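# A possible refinement (not wired in above): instruct-tuned checkpoints
# usually respond best when prompts go through their chat template, e.g.
#
#     messages = [{"role": "user", "content": prompt}]
#     text = tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )
#
# The raw-prompt path is kept here so the example below stays literal.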


DESCRIPTION = """
# PAMAv1 — LoRA demo

This Space loads the base model `nvidia/Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct` with the LoRA adapters from `nishantmulchandani/PAMAv1-Nemotron-8B-1M-LoRA`.

Notes:
- Requires GPU hardware (A10G/T4 or better); the base model is loaded in 4-bit to keep VRAM usage down.
- For long contexts, structure the prompt carefully and raise `max_new_tokens` if the output is cut short.
"""

examples = [
    [
        "You are a helpful assistant.\nUser: Summarize the Lottie JSON schema in two concise bullets.\nAssistant:",
        128,
        0.7,
        0.9,
    ],
]

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=8)
            max_new_tokens = gr.Slider(16, 1024, value=256, step=8, label="max_new_tokens")
            temperature = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
            btn = gr.Button("Generate")
        with gr.Column():
            out = gr.Textbox(label="Output", lines=12)
    btn.click(generate, [prompt, max_new_tokens, temperature, top_p], out)
    gr.Examples(examples=examples, fn=generate, inputs=[prompt, max_new_tokens, temperature, top_p], outputs=[out])

if __name__ == "__main__":
    # Queue requests so concurrent users share the single GPU one at a time.
    demo.queue().launch()