Tonic committed on
Commit 4df27ca
Parent: 11e1b63

Update app.py

Files changed (1)
  1. app.py +31 -195
app.py CHANGED
@@ -9,6 +9,10 @@ import gradio as gr
 import sentencepiece
 
 
+# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:126'
+MAX_MAX_NEW_TOKENS = 160000
+DEFAULT_MAX_NEW_TOKENS = 20000
+MAX_INPUT_TOKEN_LENGTH = 160000
 DESCRIPTION = """
 # Welcome to Tonic's YI-6B-200K
 You can use this Space to test out the current model [01-ai/Yi-6B-200K](https://huggingface.co/01-ai/Yi-6B-200K)
@@ -16,231 +20,63 @@ You can also use YI-200 by cloning this space. Simply click here: <a style="disp
 Join us : TeamTonic is always making cool demos! Join our active builder's community on Discord: [Discord](https://discord.gg/nXx5wbX9) On Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On Github: [Polytonic](https://github.com/tonic-ai) & contribute to [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
 """
 
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:126'
-
-MAX_MAX_NEW_TOKENS = 160000
-DEFAULT_MAX_NEW_TOKENS = 20000
-MAX_INPUT_TOKEN_LENGTH = 160000
-device = "cuda" if torch.cuda.is_available() else "cpu"
 
+# Set up the model and tokenizer
 model_name = "01-ai/Yi-6B-200K"
-
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, device_map="cuda", trust_remote_code=True)
-# tokenizer = YiTokenizer(vocab_file=model_name)
-model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
+device = "cuda" if torch.cuda.is_available() else "cpu"
+tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="cuda", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
     device_map="auto",
     torch_dtype=torch.bfloat16,
     load_in_4bit=True,
     trust_remote_code=True
 )
+model.to(device)
 
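Two details of the new setup are worth flagging: with `device_map="auto"`, accelerate already places the quantized weights, and transformers generally refuses `.to()` on 4-bit models, so the added `model.to(device)` line is at best redundant. A minimal sketch of an equivalent setup using `BitsAndBytesConfig` — assuming a recent `transformers` with `bitsandbytes` and `accelerate` installed, not the code in this commit:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "01-ai/Yi-6B-200K"

# 4-bit quantization; bfloat16 compute mirrors the torch_dtype used above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",               # accelerate places the shards; no model.to(...) needed
    quantization_config=bnb_config,
    trust_remote_code=True,
)
```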
-def run(message, chat_history, max_new_tokens=20000, temperature=1.5, top_p=0.9, top_k=900):
-    prompt = get_prompt(message, chat_history)
-
-    # Encode the prompt to tensor
-    input_ids = tokenizer.encode(prompt, return_tensors='pt')
-
-    # Move input_ids to the same device as the model
-    input_ids = input_ids.to(model.device)
-
-    # Generate a response using the model with adjusted parameters
+def run(prompt, max_new_tokens, temperature, top_p, top_k):
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
     response_ids = model.generate(
         input_ids,
         max_length=max_new_tokens + input_ids.shape[1],
-        temperature=temperature,  # Controls randomness. Lower values make text more deterministic.
-        top_p=top_p,  # Nucleus sampling: higher values allow more diversity.
-        top_k=top_k,  # Top-k sampling: limits the number of top tokens considered.
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
         pad_token_id=tokenizer.eos_token_id,
-        do_sample=True  # Enable sampling-based generation
-
+        do_sample=True
     )
-
-    # Decode the response
     response = tokenizer.decode(response_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
     return response
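The rewritten `run` keeps the old `max_length = max_new_tokens + input_ids.shape[1]` arithmetic. `generate` also accepts `max_new_tokens` directly, which expresses the same intent without the addition; a short sketch (the prompt string and token budget are just examples):

```python
prompt = "As the dawn approached, they leant in and said"  # example input
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

response_ids = model.generate(
    input_ids,
    max_new_tokens=512,  # same effect as max_length=512 + input_ids.shape[1]
    do_sample=True,
    temperature=1.2,
    top_p=0.9,
    top_k=900,
    pad_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens, exactly as run() does.
response = tokenizer.decode(response_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
print(response)
```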
 
-def get_prompt(message, chat_history):
-    texts = []
-
-    do_strip = False
-    for user_input, response in chat_history:
-        user_input = user_input.strip() if do_strip else user_input
-        do_strip = True
-        texts.append(f" {response.strip()} {user_input} ")
-    message = message.strip() if do_strip else message
-    texts.append(f"{message}")
-    return ''.join(texts)
-
-def clear_and_save_textbox(message): return '', message
-
-def display_input(message, history=[]):
-    history.append((message, ''))
-    return history
-
-def delete_prev_fn(history=[]):
-    try:
-        message, _ = history.pop()
-    except IndexError:
-        message = ''
-    return history, message or ''
-
-def generate(message, history_with_input, max_new_tokens, temperature, top_p, top_k):
-    if int(max_new_tokens) > MAX_MAX_NEW_TOKENS:
-        raise ValueError
-
-    history = history_with_input[:-1]
-    response = run(message, history, max_new_tokens, temperature, top_p, top_k)
-    yield history + [(message, response)]
-
-
-def process_example(message):
-    generator = generate(message, [], 4056, 1.9, 0.95, 900)
-    for x in generator:
-        pass
-    return '', x
-
-def check_input_token_length(message, chat_history):
-    input_token_length = len(message) + len(chat_history)
-    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
-        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
+def generate(prompt, max_new_tokens, temperature, top_p, top_k):
+    response = run(prompt, max_new_tokens, temperature, top_p, top_k)
+    return response
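With `get_prompt`, the history helpers, and `check_input_token_length` all removed, the Space becomes single-turn: the new `generate` is a thin pass-through to `run`, and nothing enforces `MAX_INPUT_TOKEN_LENGTH` any more. Calling it is just:

```python
# Single-turn call; sampling values are examples within the slider ranges below.
text = generate(
    "Write a short story about a 200K-token context window.",  # example prompt
    max_new_tokens=200,
    temperature=1.2,
    top_p=0.9,
    top_k=900,
)
print(text)
```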
 
 
+# Gradio Interface
 with gr.Blocks(theme='ParityError/Anime') as demo:
     gr.Markdown(DESCRIPTION)
-
-
 
     with gr.Group():
-        chatbot = gr.Chatbot(label='TonicYi-30B-200K')
         with gr.Row():
-            textbox = gr.Textbox(
-                container=False,
-                show_label=False,
-                placeholder='As the dawn approached, they leant in and said',
-                scale=10
+            prompt = gr.Textbox(
+                label='Enter your prompt',
+                placeholder='Type something...',
+                lines=5
             )
-            submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
-
-        with gr.Row():
-            retry_button = gr.Button('Retry', variant='secondary')
-            undo_button = gr.Button('Undo', variant='secondary')
-            clear_button = gr.Button('Clear', variant='secondary')
-
-    saved_input = gr.State()
+            submit_button = gr.Button('Generate')
 
     with gr.Accordion(label='Advanced options', open=False):
-        # system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
         max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=2.0, step=0.1, value=0.1)
+        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=2.0, step=0.1, value=1.2)
         top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
+        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=900)
 
-    textbox.submit(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name="Generate",
-    )
-
-    button_event_preprocess = submit_button.click(
-        fn=clear_and_save_textbox,
-        inputs=textbox,
-        outputs=[textbox, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=check_input_token_length,
-        inputs=[saved_input, chatbot],
-        api_name=False,
-        queue=False,
-    ).success(
-        fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name="Cgenerate",
-    )
-
-    retry_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=display_input,
-        inputs=[saved_input, chatbot],
-        outputs=chatbot,
-        api_name=False,
-        queue=False,
-    ).then(
+    output = gr.Textbox(label='Generated Text', lines=10, readonly=True)
+
+    submit_button.click(
         fn=generate,
-        inputs=[
-            saved_input,
-            chatbot,
-            max_new_tokens,
-            temperature,
-            top_p,
-            top_k,
-        ],
-        outputs=chatbot,
-        api_name=False,
-    )
-
-    undo_button.click(
-        fn=delete_prev_fn,
-        inputs=chatbot,
-        outputs=[chatbot, saved_input],
-        api_name=False,
-        queue=False,
-    ).then(
-        fn=lambda x: x,
-        inputs=[saved_input],
-        outputs=textbox,
-        api_name=False,
-        queue=False,
-    )
-
-    clear_button.click(
-        fn=lambda: ([], ''),
-        outputs=[chatbot, saved_input],
-        queue=False,
-        api_name=False,
+        inputs=[prompt, max_new_tokens, temperature, top_p, top_k],
+        outputs=output
     )
 
 demo.queue(max_size=5).launch(show_api=True)
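One caveat in the added UI code: `readonly=True` is not a `gr.Textbox` parameter in current Gradio releases (`interactive=False` is the usual spelling), so the output box may need that tweak. Since the app still launches with `show_api=True`, the new single endpoint can also be exercised from Python. A hypothetical client-side sketch — the Space id and endpoint name below are assumptions, not taken from this commit:

```python
from gradio_client import Client  # pip install gradio_client

# Hypothetical Space id; substitute the one actually hosting this app.py.
client = Client("Tonic/Yi-6B-200K")

result = client.predict(
    "As the dawn approached, they leant in and said",  # prompt
    200,   # max_new_tokens
    1.2,   # temperature
    0.9,   # top_p
    900,   # top_k
    api_name="/generate",  # Gradio usually derives this from fn; check the Space's "Use via API" page
)
print(result)
```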