File size: 4,150 Bytes
ea4be8b
bed1c6d
2b98c70
 
ea4be8b
bed1c6d
 
7802d6c
3a663e8
 
 
237e017
6be3233
3a663e8
 
2b98c70
3a663e8
2b98c70
 
3a663e8
2b98c70
 
3a663e8
 
 
 
 
 
 
 
 
4a95791
3a663e8
 
 
 
 
ea4be8b
2b98c70
 
 
 
 
 
 
 
 
ea4be8b
 
d106232
ea4be8b
 
e1e6dcd
 
6be3233
ea4be8b
 
 
 
 
2b98c70
 
 
ea4be8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a663e8
ea4be8b
 
 
c7ec932
 
2b98c70
 
 
 
ea4be8b
 
c7ec932
ea4be8b
 
 
e1e6dcd
4a95791
e1e6dcd
 
 
ea4be8b
 
 
 
7d8694d
ea4be8b
 
 
 
 
 
 
179f53f
ea4be8b
 
 
 
 
6be3233
af7b600
2b98c70
ed2abfb
4a95791
af7b600
ea4be8b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# Page title rendered as raw HTML at the top of the Gradio app.
TITLE = "<h1><center>Chat with Llama3-8B-Chinese-Chat-v2.1</center></h1>"

# Subtitle linking to the model card on the Hugging Face Hub.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/shenzhi-wang/Llama3-8B-Chinese-Chat' target='_blank'>our model page</a> for details.</center></h3>"

# System prompt used whenever the user leaves the "System" textbox empty.
DEFAULT_SYSTEM = "You are Llama-3, developed by an independent team. You are a helpful assistant."

# Example system prompt demonstrating tool-calling: describes one Python tool
# and instructs the model to emit an "Action:" block with a JSON call list.
# NOTE: the literal text below is part of the prompt sent to the model — do
# not reformat it.
TOOL_EXAMPLE = '''You have access to the following tools:
```python
def generate_password(length: int, include_symbols: Optional[bool]):
    """
    Generate a random password.

    Args:
        length (int): The length of the password
        include_symbols (Optional[bool]): Include symbols in the password
    """
    pass
```

Write "Action:" followed by a list of actions in JSON that you want to call, e.g.
Action:
```json
[
    {
        "name": "tool name (one of [generate_password])",
        "arguments": "the input to the tool"
    }
]
```
'''

# Custom CSS applied to the Blocks app; styles the "Duplicate Space" button.
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""


# Load tokenizer and model once at import time (downloads from the Hub on
# first run). device_map="auto" places weights on the available accelerator;
# torch_dtype="auto" picks the dtype stored in the checkpoint.
tokenizer = AutoTokenizer.from_pretrained("shenzhi-wang/Llama3-8B-Chinese-Chat")
model = AutoModelForCausalLM.from_pretrained("shenzhi-wang/Llama3-8B-Chinese-Chat", torch_dtype="auto", device_map="auto")


@spaces.GPU
def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
    """Stream a chat completion from the model, yielding the growing reply.

    Args:
        message: The latest user message.
        history: List of ``(user, assistant)`` message pairs from earlier turns
            (Gradio tuple-style chat history).
        system: System prompt; falls back to ``DEFAULT_SYSTEM`` when empty.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated assistant reply after each newly streamed token.
    """
    # Rebuild the full conversation in the chat-template message format.
    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    for user_msg, assistant_msg in history:
        conversation.extend(
            [{"role": "user", "content": user_msg}, {"role": "assistant", "content": assistant_msg}]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    # skip_prompt avoids re-emitting the prompt text; the timeout bounds how
    # long we block waiting for each token from the generation thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
    )
    if temperature == 0:
        # Greedy decoding: drop the sampling knob entirely — recent
        # transformers versions warn about (or reject) temperature=0 when
        # do_sample=False, since temperature must be strictly positive.
        generate_kwargs["do_sample"] = False
        generate_kwargs.pop("temperature")

    # Run generation in a background thread so this generator can consume
    # the streamer and yield partial output to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output
    # The streamer is exhausted only after generation ends; join to release
    # the worker thread deterministically.
    worker.join()


# Chat display component; created outside the Blocks context and handed to
# ChatInterface below (render=False is implied by passing it in).
chatbot = gr.Chatbot(height=450)

# Build the app: title, description, a "Duplicate Space" button, and the
# chat interface wired to stream_chat.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        # Extra inputs live in a collapsed accordion; render=False defers
        # rendering until ChatInterface places them.
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            # Maps to stream_chat's `system` parameter; empty string falls
            # back to DEFAULT_SYSTEM inside stream_chat.
            gr.Text(
                value="",
                label="System",
                render=False,
            ),
            # Maps to `temperature`; 0 means greedy decoding.
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            # Maps to `max_new_tokens`.
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
        # Each example is [message, system prompt]; sliders keep defaults.
        examples=[
            ["我的蓝牙耳机坏了,我该去看牙科还是耳鼻喉科?", ""],
            ["7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?", ""],
            ["我的笔记本找不到了。", "扮演诸葛亮和我对话。"],
            ["我想要一个新的密码,长度为8位,包含特殊符号。", TOOL_EXAMPLE],
            ["How are you today?", "You are Taylor Swift, use beautiful lyrics to answer questions."],
            ["用C++实现KMP算法,并加上中文注释", ""],
        ],
        # Running examples would invoke the GPU model at build time; keep off.
        cache_examples=False,
    )


# Launch the Gradio server only when run as a script (Spaces runs this file
# directly; importing the module elsewhere must not start a server).
if __name__ == "__main__":
    demo.launch()