File size: 6,312 Bytes
9ecbf84
 
f0c68fa
 
 
 
 
 
 
 
 
82b1ec4
f0c68fa
 
 
9ecbf84
82b1ec4
ebc86be
9065e38
 
4a12906
9310ba1
43a7079
 
24083d5
43a7079
 
9310ba1
43a7079
 
9310ba1
43a7079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c67fd6f
43a7079
7b75ee1
 
 
 
43a7079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b75ee1
43a7079
 
 
 
 
 
ad9d4f6
 
 
 
 
 
 
 
 
43a7079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import subprocess

# Start-up bootstrap: pip-install flash-attn and pycuda before the heavy imports below.
#
# The build environment is assembled in Python and handed to each pip call:
#   * Running `export VAR=...` through subprocess.run only changes a throwaway
#     child shell, so it never reached the later pip invocations.
#   * Values passed through `env=` are taken literally; "$CPATH" is not
#     expanded by the shell, so the existing value must be merged here.
#   * `env=` replaces the whole environment, so start from a copy of
#     os.environ to keep PATH, HOME, proxy settings, etc. available to pip.
_build_env = os.environ.copy()
for _var, _cuda_dir in (
    ("CPATH", "/usr/local/cuda/include"),
    ("LIBRARY_PATH", "/usr/local/cuda/lib64"),
):
    _existing = _build_env.get(_var)
    # Skip the separator when the variable is unset/empty, so no empty path
    # entry is introduced.
    _build_env[_var] = f"{_existing}:{_cuda_dir}" if _existing else _cuda_dir

# Install flash attention, skipping CUDA build if necessary.
# Return codes are intentionally not checked (best-effort, as before).
subprocess.run(
    "/home/user/.pyenv/shims/pip install flash-attn --no-build-isolation",
    env={**_build_env, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# pycuda is compiled against the CUDA headers/libraries added to the build env above.
subprocess.run(
    "/home/user/.pyenv/shims/pip install pycuda==2023.1",
    env=_build_env,
    shell=True,
)

import os
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Read the Hugging Face access token from the environment (None when HF_TOKEN is unset).
# NOTE(review): HF_TOKEN is not referenced anywhere else in the visible code —
# confirm whether it is still needed (e.g. for a gated model) or can be dropped.
HF_TOKEN = os.environ.get("HF_TOKEN", None)


# HTML header for the page; rendered through gr.Markdown at the top of the Blocks layout.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Meta Llama3 8B</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct"><b>Meta Llama3 8b Chat</b></a>. Meta Llama3 is the new open LLM and comes in two sizes: 8b and 70b. Feel free to play with it, or duplicate to run privately!</p>
<p>🔎 For more details about the Llama3 release and how to use the model with <code>transformers</code>, take a look <a href="https://huggingface.co/blog/llama3">at our blog post</a>.</p>
<p>🦕 Looking for an even more powerful model? Check out the <a href="https://huggingface.co/chat/"><b>Hugging Chat</b></a> integration for Meta Llama 3 70b</p>
</div>
'''

# Footer text; rendered through gr.Markdown after the chat interface.
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""

# HTML passed as the `placeholder` argument of the gr.Chatbot component.
PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <img src="https://ysharma-dummy-chat-app.hf.space/file=/tmp/gradio/8e75e61cc9bab22b7ce3dec85ab0e6db1da5d107/Meta_lockup_positive%20primary_RGB.jpg" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;  "> 
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Meta llama3</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""


# Custom stylesheet passed to gr.Blocks: centers <h1> headings and styles the
# element whose id is "duplicate-button" (the id given to gr.DuplicateButton).
css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Load the tokenizer and model once at import time; chat_llama3_8b reads
# `tokenizer`, `model` and `terminators` as module-level globals.
model_name = "gradientai/Llama-3-8B-Instruct-262k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# torch_dtype="auto" and device_map="auto" are passed instead of an explicit
# device move (the trailing note records the earlier .to("cuda:0") variant).
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")  # to("cuda:0")

# Imported here rather than with the other imports.
# NOTE(review): presumably so it runs after the start-up pip installs — confirm.
from minference import MInference
# Build an MInference patch for this model name and rebind `model` to the
# patched result, so every later generate() call goes through the patched model.
# NOTE(review): presumably this swaps in MInference's long-context attention —
# confirm against the minference documentation.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

# Token ids passed to generate() as eos_token_id: the tokenizer's EOS token
# plus the id of the "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

@spaces.GPU(duration=120)
def chat_llama3_8b(message: str,
                   history: list,
                   temperature: float,
                   max_new_tokens: int
                   ) -> Iterator[str]:
    """
    Stream a response from the module-level llama3-8b model.

    This is a generator: generation runs on a background thread and the
    decoded text is read from a TextIteratorStreamer as it is produced.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface,
            iterated as (user, assistant) pairs.
        temperature (float): The sampling temperature. Values <= 0 select
            greedy decoding.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response accumulated so far (each value extends the
        previous one with the newest chunk).
    """
    conversation = []
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header to the prompt so
    # the model starts a new assistant turn rather than continuing the user's.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    print(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
    )
    # Sampling requires a strictly positive temperature. For temperature 0
    # fall back to greedy generation (do_sample=False) and do not forward the
    # temperature at all, avoiding the crash / invalid-config warning.
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        generate_kwargs["do_sample"] = False

    # generate() blocks until completion, so run it on a worker thread and
    # consume the streamer from this generator.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    response = ""
    for text in streamer:
        response += text
        yield response
        

# Gradio UI: the chat display is built first, then the page layout around it.
chatbot = gr.Chatbot(label='Gradio ChatInterface', placeholder=PLACEHOLDER, height=450)

with gr.Blocks(css=css, fill_height=True) as demo:
    # Page header followed by the "duplicate this Space" button.
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Extra generation controls. They are created with render=False and handed
    # to ChatInterface, which receives them as the additional inputs of
    # chat_llama3_8b (temperature, max_new_tokens — in that order).
    _parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    _temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.95,
        render=False,
    )
    _max_new_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        render=False,
    )

    # Example prompts offered by the chat interface; each becomes a one-item row.
    _example_prompts = (
        'How to setup a human base on Mars? Give short answer.',
        'Explain theory of relativity to me like I’m 8 years old.',
        'What is 9,000 * 9,000?',
        'Write a pun-filled happy birthday message to my friend Alex.',
        'Justify why a penguin might make a good king of the jungle.',
    )

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=_parameters_accordion,
        additional_inputs=[_temperature_slider, _max_new_tokens_slider],
        examples=[[prompt] for prompt in _example_prompts],
        cache_examples=False,
    )

    # Footer.
    gr.Markdown(LICENSE)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()