File size: 4,098 Bytes
f75fe7f
 
 
 
d6a8ce7
f75fe7f
 
d6a8ce7
f75fe7f
d6a8ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75fe7f
 
 
 
 
 
 
 
 
 
 
d6a8ce7
b5ab4e2
f75fe7f
 
 
d6a8ce7
 
f75fe7f
 
 
d6a8ce7
f75fe7f
d6a8ce7
 
 
f75fe7f
 
 
 
 
d6a8ce7
 
892484f
f75fe7f
d6a8ce7
 
f75fe7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6a8ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f75fe7f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


# HTML heading rendered at the top of the page via gr.HTML.
TITLE = "<h1><center>Chat with Gemma-2-9B-Chinese-Chat</center></h1>"

# HTML sub-heading, rendered under the title, linking to the model page.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/shenzhi-wang/Gemma-2-9B-Chinese-Chat' target='_blank'>our model page</a> for details.</center></h3>"

# System prompt used by stream_chat when the "System" input is left empty.
DEFAULT_SYSTEM = "You are a helpful assistant."

# System prompt for the tool-calling example: it describes one tool
# (generate_password) and the JSON "Action:" format the model should reply with.
# The text is sent to the model verbatim, so it must not be reformatted.
TOOL_EXAMPLE = '''You have access to the following tools:
```python
def generate_password(length: int, include_symbols: Optional[bool]):
    """
    Generate a random password.

    Args:
        length (int): The length of the password
        include_symbols (Optional[bool]): Include symbols in the password
    """
    pass
```

Write "Action:" followed by a list of actions in JSON that you want to call, e.g.
Action:
```json
[
    {
        "name": "tool name (one of [generate_password])",
        "arguments": "the input to the tool"
    }
]
```
'''

# Custom stylesheet passed to gr.Blocks; styles the element class that is
# assigned to the "Duplicate Space" button.
CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""


# Hub repository that provides both the tokenizer and the model weights.
_MODEL_ID = "shenzhi-wang/Gemma-2-9B-Chinese-Chat"

# Loaded once at import time and shared by every call to the chat handler.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)


@spaces.GPU
def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
    """Stream the model's reply to ``message`` as a progressively longer string.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user_prompt, assistant_answer)`` pairs.
        system: System prompt; ``DEFAULT_SYSTEM`` is used when it is empty.
        temperature: Sampling temperature. ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Yields:
        The reply accumulated so far, once for every chunk the streamer emits.
    """
    # Rebuild the whole dialogue: system prompt, past turns, then the new message.
    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])

    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(
        model.device
    )
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
    if temperature == 0:
        # Greedy decoding: do not forward a zero temperature, which is not a
        # valid sampling value and is irrelevant when sampling is disabled.
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs.update(do_sample=True, temperature=temperature)

    # generate() blocks until it is finished, so it runs on a worker thread
    # while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    output = ""
    for new_token in streamer:
        output += new_token
        yield output

    # The streamer is exhausted, so generation should be complete; wait for the
    # worker so it does not outlive this call.
    worker.join()


# Chat history widget handed to the ChatInterface below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, model link and the "duplicate this Space" button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Extra controls are created with render=False and passed to ChatInterface,
    # in the same order as stream_chat's extra parameters
    # (system, temperature, max_new_tokens).
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    system_box = gr.Text(value="", label="System", render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False
    )
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False
    )

    # Each example row is [user message, system prompt].
    example_rows = [
        ["我的蓝牙耳机坏了,我该去看牙科还是耳鼻喉科?", ""],
        ["7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?", ""],
        ["我的笔记本找不到了。", "扮演诸葛亮和我对话。"],
        ["我想要一个新的密码,长度为8位,包含特殊符号。", TOOL_EXAMPLE],
        ["How are you today?", "You are Taylor Swift, use beautiful lyrics to answer questions."],
        ["用C++实现KMP算法,并加上中文注释", ""],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[system_box, temperature_slider, max_tokens_slider],
        examples=example_rows,
        cache_examples=False,
    )


# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()