hiyouga's picture
Update app.py
b5ab4e2 verified
raw
history blame contribute delete
No virus
4.1 kB
from threading import Thread
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# HTML banner shown at the top of the page (rendered with gr.HTML in the UI below).
TITLE = "<h1><center>Chat with Gemma-2-9B-Chinese-Chat</center></h1>"
# Sub-heading linking to the model card on the Hub; also rendered with gr.HTML.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/shenzhi-wang/Gemma-2-9B-Chinese-Chat' target='_blank'>our model page</a> for details.</center></h3>"
# Fallback system prompt: stream_chat uses this whenever the "System" textbox is empty.
DEFAULT_SYSTEM = "You are a helpful assistant."
# Example system prompt demonstrating tool use. It describes one tool,
# `generate_password`, and asks the model to reply with "Action:" followed by a
# JSON list of tool calls. It is only referenced as the system prompt of one of
# the chat examples.
# NOTE(review): this literal is sent to the model verbatim, so any edit to its
# text (including whitespace/indentation inside the code fence) changes the prompt.
TOOL_EXAMPLE = '''You have access to the following tools:
```python
def generate_password(length: int, include_symbols: Optional[bool]):
"""
Generate a random password.
Args:
length (int): The length of the password
include_symbols (Optional[bool]): Include symbols in the password
"""
pass
```
Write "Action:" followed by a list of actions in JSON that you want to call, e.g.
Action:
```json
[
{
"name": "tool name (one of [generate_password])",
"arguments": "the input to the tool"
}
]
```
'''
# Page stylesheet passed to gr.Blocks; targets the "Duplicate Space" button,
# which is tagged with elem_classes="duplicate-button".
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Hugging Face Hub repository the tokenizer and weights are loaded from. Kept in
# one place so the tokenizer and the model can never point at different repos.
MODEL_ID = "shenzhi-wang/Gemma-2-9B-Chinese-Chat"

# Loaded once at import time and shared by every chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# torch_dtype="auto" keeps the dtype stored in the checkpoint config;
# device_map="auto" lets transformers/accelerate choose device placement.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")
@spaces.GPU
def stream_chat(message: str, history: list, system: str, temperature: float, max_new_tokens: int):
    """Stream the model's reply to ``message`` given the earlier chat turns.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user, assistant)`` pairs, as supplied by
            ``gr.ChatInterface``.
        system: Optional system prompt; ``DEFAULT_SYSTEM`` is used when empty.
        temperature: Sampling temperature. A value of ``0`` (or below) selects
            greedy decoding instead of sampling.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Yields:
        The reply accumulated so far, growing by one streamer chunk per yield.

    Raises:
        Exception: Any error raised by ``model.generate`` in the background
            thread is re-raised here once the stream has been closed.
    """
    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # Slider values may not arrive as ints; generate expects an integer budget.
        max_new_tokens=int(max_new_tokens),
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding. Do not forward temperature=0: it is meaningless without
        # sampling and makes transformers' generation-config validation complain.
        generate_kwargs["do_sample"] = False

    errors = []

    def _generate():
        # Runs in the worker thread. On failure, remember the error and close the
        # streamer so the consumer loop below ends now instead of stalling until
        # the streamer timeout and failing with an unrelated queue.Empty.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output

    worker.join()
    if errors:
        raise errors[0]
# Chat transcript widget; it is handed to the ChatInterface created below.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Page header: the title banner, then the model-page link.
    for header_html in (TITLE, DESCRIPTION):
        gr.HTML(header_html)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Collapsed panel for the extra generation controls. Everything here is built
    # with render=False so it is not placed directly; ChatInterface receives it.
    parameters_panel = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    # These follow (message, history) as stream_chat's remaining arguments,
    # in order: system prompt, temperature, max new tokens.
    system_box = gr.Text(value="", label="System", render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False
    )
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False
    )

    # Each example is [user message, system prompt]; an empty system prompt
    # makes stream_chat fall back to DEFAULT_SYSTEM.
    sample_prompts = [
        ["我的蓝牙耳机坏了,我该去看牙科还是耳鼻喉科?", ""],
        ["7年前,妈妈年龄是儿子的6倍,儿子今年12岁,妈妈今年多少岁?", ""],
        ["我的笔记本找不到了。", "扮演诸葛亮和我对话。"],
        ["我想要一个新的密码,长度为8位,包含特殊符号。", TOOL_EXAMPLE],
        ["How are you today?", "You are Taylor Swift, use beautiful lyrics to answer questions."],
        ["用C++实现KMP算法,并加上中文注释", ""],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_panel,
        additional_inputs=[system_box, temperature_slider, max_tokens_slider],
        examples=sample_prompts,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()