File size: 5,175 Bytes
0fafb5e
6ac164c
650e6f8
37ca5d0
91b03f9
37ca5d0
 
 
 
b17ecc2
a445827
00a2ac7
 
a445827
37ca5d0
 
 
 
 
 
e3f498d
37ca5d0
 
 
 
a9049fb
 
 
 
b17ecc2
dbb94cf
a445827
37ca5d0
77b3a6a
72c7c74
cb1a144
fcb9074
37ca5d0
 
00a2ac7
b440713
906992b
 
77b3a6a
37ca5d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00a2ac7
 
 
 
 
 
37ca5d0
 
 
00a2ac7
 
 
37ca5d0
00a2ac7
 
 
37ca5d0
a9049fb
37ca5d0
00a2ac7
37ca5d0
 
 
 
 
 
 
00a2ac7
6ac164c
37ca5d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ac164c
37ca5d0
 
 
 
e3f498d
37ca5d0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
from threading import Thread
from typing import Iterator

import bitsandbytes
import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)


# Set TensorFlow's oneDNN switch off. NOTE(review): this runs after torch /
# transformers are imported, so it only matters if TensorFlow is imported later.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Access token for the model (None when HF_TOKEN is not set in the environment).
access_token = os.getenv('HF_TOKEN')

# Hub repo id of the model served by this demo.
# A local checkout could be used instead, e.g. "./models/Llama-32-3B-Instruct"
# (the Hub link in DESCRIPTION would then not resolve).
model_id = "nvidia/Llama-3_1-Nemotron-51B-Instruct"

# Page header shown above the chat. It is built from model_id so it always
# names the model that is actually loaded.
DESCRIPTION = f"""\
# {model_id}
This is a chat demo of [`{model_id}`](https://huggingface.co/{model_id}).
"""

# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 6144
DEFAULT_MAX_NEW_TOKENS = 6144
# Prompts longer than this many tokens keep only their most recent tokens
# (see generate()).
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "6144"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and model are loaded once at import time and shared by every
# generate() call.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    # NOTE: trust_remote_code=True executes Python code shipped in the model
    # repository; keep model_id pointed at a trusted repo.
    trust_remote_code=True,
    # 8-bit weights via bitsandbytes. Passing load_in_8bit directly to
    # from_pretrained is deprecated; BitsAndBytesConfig is the supported form.
    # NOTE(review): bitsandbytes 8-bit loading presumably needs a CUDA device,
    # so the CPU fallback for `device` is likely to fail here — confirm.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    token=access_token,
)
model.eval()


@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled chat reply to *message*.

    Builds a system/user/assistant conversation and tokenizes it with the
    tokenizer's chat template. ``model.generate`` runs on a background thread,
    and this function yields the reply accumulated so far each time the
    streamer produces new text.

    Args:
        message: The latest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs.
        system_prompt: Content of the leading system message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The cumulative reply text (not just the newest chunk).

    Raises:
        gr.Error: If ``model.generate`` fails on the worker thread.
    """
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    # Fall back to EOS as the pad token when the tokenizer defines none; it is
    # forwarded to generate() below.
    if tokenizer.pad_token_id is None:
        tokenizer.padding_side = 'right'
        tokenizer.pad_token = tokenizer.eos_token

    # The chat template inserts its own special tokens, and a single
    # conversation needs no padding, so only the token ids are requested.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the start of the conversation is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    input_ids = input_ids.to(model.device)
    # One unpadded sequence attends to every position. The mask is passed
    # explicitly because, when the pad token is the EOS token (set above),
    # generate() cannot infer it from the ids.
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
    )

    errors: list[BaseException] = []

    def _run_generation() -> None:
        # Worker-thread body. On failure, end the streamer so the loop below
        # stops immediately instead of blocking until the streamer timeout,
        # and keep the exception so it can be re-raised on this thread.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    worker.join()
    if errors:
        raise gr.Error(f"Generation failed: {errors[0]}") from errors[0]


def _slider(label, minimum, maximum, step, value):
    """Return a ``gr.Slider`` with the given label, range, step and default."""
    return gr.Slider(
        label=label,
        minimum=minimum,
        maximum=maximum,
        step=step,
        value=value,
    )


# One-click example prompts; each becomes a single-item example row.
_EXAMPLE_PROMPTS = (
    "Hello there! How are you doing?",
    "Can you explain briefly to me what is the Python programming language?",
    "Explain the plot of Cinderella in a sentence.",
    "How many hours does it take a man to eat a Helicopter?",
    "Write a 100-word article on 'Benefits of Open-Source in AI research'",
)

# Chat UI wired to generate(). The extra inputs are passed to generate() after
# (message, chat_history), in the order of its remaining parameters:
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System Prompt", placeholder="Enter system prompt here...", lines=2),
        _slider("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
        _slider("Temperature", 0.1, 4.0, 0.1, 0.6),
        _slider("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
        _slider("Top-k", 1, 1000, 1, 50),
        _slider("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
    ],
    stop_btn=None,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
)

# Page layout: model description, a "duplicate this Space" button, and then
# the chat interface built above, rendered inside this Blocks context.
with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(elem_id="duplicate-button", value="Duplicate Space for private use")
    chat_interface.render()

if __name__ == "__main__":
    # Allow at most 20 queued events, then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()