Tonic committed
Commit 2682883
1 Parent(s): 2ff4ae4

Create app.py

Files changed (1)
  1. app.py +247 -0
app.py ADDED
@@ -0,0 +1,247 @@
+ import os
+
+ import gradio as gr
+ import sentencepiece  # noqa: F401 -- runtime dependency of the sentencepiece-based tokenizer
+ import torch
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+ from tokenization_yi import YiTokenizer
+
+ model_id = "01-ai/Yi-34B-200K"
+
+ # Reduce CUDA memory fragmentation while loading the large model.
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ offload_directory = './model_offload'
+ os.makedirs(offload_directory, exist_ok=True)
+
+ tokenizer = YiTokenizer(vocab_file="./tokenizer.model")
+ # Quantize to 8-bit with fp32 CPU offload so the 34B weights fit on a single GPU.
+ quantization_config = BitsAndBytesConfig(
+     load_in_8bit=True,
+     llm_int8_enable_fp32_cpu_offload=True,
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     quantization_config=quantization_config,
+     offload_folder=offload_directory,
+     trust_remote_code=True,
+ )
+ # model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
+ # model = model.to(device)
+
+ def run(message, chat_history, max_new_tokens=4056, temperature=3.5, top_p=0.9, top_k=800):
+     prompt = get_prompt(message, chat_history)
+     input_ids = tokenizer.encode(prompt, return_tensors='pt')
+     input_ids = input_ids.to(model.device)
+     response_ids = model.generate(
+         input_ids,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+         pad_token_id=tokenizer.eos_token_id,
+         do_sample=True,
+     )
+     # Decode only the newly generated tokens, skipping the prompt.
+     response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
+     return response
+
+ def get_prompt(message, chat_history):
+     # Flatten the chat history and the new message into one plain-text prompt.
+     texts = []
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f" {response.strip()} {user_input} ")
+     message = message.strip() if do_strip else message
+     texts.append(f"{message}")
+     return ''.join(texts)
+
+ DESCRIPTION = """
+ # 👋🏻Welcome to 🙋🏻‍♂️Tonic's🧑🏻‍🚀YI-200K🚀
+ You can use this Space to test out the current model [Tonic/YI](https://huggingface.co/01-ai/Yi-34B)
+ You can also use 🧑🏻‍🚀YI-200K🚀 by cloning this Space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic1/YiTonic?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
+ Join us: 🌟TeamTonic🌟 is always making cool demos! Join our active builders'🛠️community on 👻Discord: [Discord](https://discord.gg/nXx5wbX9) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
+ """
+
+ MAX_MAX_NEW_TOKENS = 4056
+ DEFAULT_MAX_NEW_TOKENS = 1256
+ MAX_INPUT_TOKEN_LENGTH = 120000
+
+ def clear_and_save_textbox(message):
+     # Clear the textbox and stash the submitted message in saved_input.
+     return '', message
+
+ def display_input(message, history=None):
+     history = history if history is not None else []  # avoid a shared mutable default
+     history.append((message, ''))
+     return history
+
+ def delete_prev_fn(history=None):
+     history = history if history is not None else []
+     try:
+         message, _ = history.pop()
+     except IndexError:
+         message = ''
+     return history, message or ''
+
+ def generate(message, history_with_input, max_new_tokens, temperature, top_p, top_k):
+     if int(max_new_tokens) > MAX_MAX_NEW_TOKENS:
+         raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}')
+     history = history_with_input[:-1]
+     response = run(message, history, max_new_tokens, temperature, top_p, top_k)
+     yield history + [(message, response)]
+
+ def process_example(message):
+     generator = generate(message, [], 1024, 2.5, 0.95, 900)
+     for x in generator:
+         pass
+     return '', x
+
+ def check_input_token_length(message, chat_history):
+     # Rough length guard: counts message characters plus history turns, not true tokens.
+     input_token_length = len(message) + len(chat_history)
+     if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+         raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
+
+ with gr.Blocks(theme='ParityError/Anime') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Group():
+         chatbot = gr.Chatbot(label='TonicYi-34B-200K')
+         with gr.Row():
+             textbox = gr.Textbox(
+                 container=False,
+                 show_label=False,
+                 placeholder='As the dawn approached, they leant in and said',
+                 scale=10,
+             )
+             submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button('Retry', variant='secondary')
+         undo_button = gr.Button('Undo', variant='secondary')
+         clear_button = gr.Button('Clear', variant='secondary')
+
+     saved_input = gr.State()
+
+     with gr.Accordion(label='Advanced options', open=False):
+         # system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+         max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+         temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
+         top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+         top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
+
+     textbox.submit(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name="Generate",
+     )
+
+     button_event_preprocess = submit_button.click(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=check_input_token_length,
+         inputs=[saved_input, chatbot],
+         api_name=False,
+         queue=False,
+     ).success(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name="Cgenerate",
+     )
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+         ],
+         outputs=chatbot,
+         api_name=False,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=textbox,
+         api_name=False,
+         queue=False,
+     )
+
+     clear_button.click(
+         fn=lambda: ([], ''),
+         outputs=[chatbot, saved_input],
+         queue=False,
+         api_name=False,
+     )
+
+ demo.queue().launch(show_api=True)