hiyouga committed
Commit ea4be8b
Parent(s): db48480

Update app.py

Files changed (1): app.py +88 -8
app.py CHANGED
@@ -1,16 +1,96 @@
 import gradio as gr
 import spaces
 import torch
+from threading import Thread
 
-zero = torch.Tensor([0]).cuda()
-print(zero.device)  # <-- 'cpu' 🤔
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
-@spaces.GPU
-def greet(n):
-    print(zero.device)  # <-- 'cuda:0' 🤗
-    return f"Hello {zero + n} Tensor"
+TITLE = "<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>"
+
+DESCRIPTION = "<h3><center>Visit <a href='' target='_blank'>LLaMA Factory</a> for details.</center></h3>"
+
+CSS = r"""
+.duplicate-button {
+    margin: auto !important;
+    color: white !important;
+    background: black !important;
+    border-radius: 100vh !important;
+}
+"""
+
+
+tokenizer = AutoTokenizer.from_pretrained("shenzhi-wang/Llama3-8B-Chinese-Chat")
+model = AutoModelForCausalLM.from_pretrained("shenzhi-wang/Llama3-8B-Chinese-Chat", device_map="auto")
+
+
+@spaces.GPU(duration=120)
+def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
+    conversation = []
+    for prompt, answer in history:
+        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
+
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=True,
+    )
+    if temperature == 0:
+        generate_kwargs["do_sample"] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    output = ""
+    for new_token in streamer:
+        output += new_token
+        yield output
+
+
+with gr.Blocks(fill_height=True, css=CSS) as demo:
+    gr.HTML(TITLE)
+    gr.HTML(DESCRIPTION)
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+
+    gr.ChatInterface(
+        fn=stream_chat,
+        fill_height=True,
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+        additional_inputs=[
+            gr.Slider(
+                minimum=0,
+                maximum=1,
+                step=0.1,
+                value=0.95,
+                label="Temperature",
+                render=False,
+            ),
+            gr.Slider(
+                minimum=128,
+                maximum=4096,
+                step=1,
+                value=512,
+                label="Max new tokens",
+                render=False,
+            ),
+        ],
+        examples=[
+            ['How to setup a human base on Mars? Give short answer.'],
+            ['Explain theory of relativity to me like I’m 8 years old.'],
+            ['What is 9,000 * 9,000?'],
+            ['Write a pun-filled happy birthday message to my friend Alex.'],
+            ['Justify why a penguin might make a good king of the jungle.']
+        ],
+        cache_examples=False,
+    )
 
 
-demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text())
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
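
The new app.py hinges on a producer–consumer streaming pattern: model.generate runs on a background Thread and pushes decoded text into the TextIteratorStreamer, while stream_chat drains the streamer and yields the reply accumulated so far, which is what gr.ChatInterface expects from a generator callback (each yield repaints the whole bot message). Below is a minimal sketch of driving that generator without the Gradio UI; the prompt, history, and sampling values are illustrative assumptions, not part of this commit:

    # Consume stream_chat() as a plain Python generator (no UI).
    # The message, history, and parameter values are made up for illustration.
    history = []  # gr.ChatInterface passes history as a list of (user, assistant) pairs
    for partial in stream_chat("Who are you?", history, temperature=0.7, max_new_tokens=64):
        print(partial)  # each yield is the full response generated so far

Because stream_chat yields cumulative text rather than per-token deltas, the consumer never has to stitch fragments together; it simply displays the latest snapshot.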