Files changed (1)
  1. app.py +169 -13
app.py CHANGED
@@ -5,10 +5,60 @@ from transformers import StoppingCriteria, StoppingCriteriaList
 import torch
 import spaces
 import os
+import subprocess
+
+# Install flash attention
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
+theme = gr.themes.Base(
+    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+)
+
+model_name1 = "microsoft/Phi-3-medium-4k-instruct"
+from transformers import AutoModelForCausalLM, AutoTokenizer
+model1 = AutoModelForCausalLM.from_pretrained(model_name1, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
+tokenizer1 = AutoTokenizer.from_pretrained(model_name1)
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+@spaces.GPU(duration=40, queue=False)
+def predict1(message, history, temperature, max_tokens, repetition_penalty, top_p):
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+    messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
+    model_inputs = tokenizer1([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer1, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=max_tokens,
+        do_sample=True,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        temperature=temperature,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model1.generate, kwargs=generate_kwargs)
+    t.start()
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
 
 model_name = "microsoft/Phi-3-medium-128k-instruct"
 from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class StopOnTokens(StoppingCriteria):
@@ -18,9 +68,9 @@ class StopOnTokens(StoppingCriteria):
             if input_ids[0][-1] == stop_id:
                 return True
         return False
-model.to('cuda')
-@spaces.GPU()
-def predict(message, history, temperature, max_tokens, top_p, top_k):
+
+@spaces.GPU(duration=40, queue=False)
+def predict(message, history, temperature, max_tokens, repetition_penalty, top_p):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
     messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
@@ -32,7 +82,7 @@ def predict(message, history, temperature, max_tokens, top_p, top_k):
         max_new_tokens=max_tokens,
         do_sample=True,
         top_p=top_p,
-        top_k=top_k,
+        repetition_penalty=repetition_penalty,
         temperature=temperature,
         stopping_criteria=StoppingCriteriaList([stop])
     )
@@ -44,14 +94,120 @@ def predict(message, history, temperature, max_tokens, top_p, top_k):
             partial_message += new_token
             yield partial_message
 
-demo = gr.ChatInterface(
-    fn=predict,
-    title="Phi-3-medium-128k-instruct",
+max_tokens1 = gr.Slider(
+    minimum=512,
+    maximum=4096,
+    value=4096,
+    step=32,
+    interactive=True,
+    label="Maximum number of new tokens to generate",
+)
+repetition_penalty1 = gr.Slider(
+    minimum=0.01,
+    maximum=5.0,
+    value=1,
+    step=0.01,
+    interactive=True,
+    label="Repetition penalty",
+)
+temperature1 = gr.Slider(
+    minimum=0.0,
+    maximum=1.0,
+    value=0.7,
+    step=0.05,
+    visible=True,
+    interactive=True,
+    label="Temperature",
+)
+top_p1 = gr.Slider(
+    minimum=0.01,
+    maximum=0.99,
+    value=0.9,
+    step=0.01,
+    visible=True,
+    interactive=True,
+    label="Top P",
+)
+
+chatbot1 = gr.Chatbot(
+    label="Phi3-medium-4k",
+    show_copy_button=True,
+    likeable=True,
+    layout="panel"
+)
+
+output = gr.Textbox(label="Prompt")
+
+with gr.Blocks() as min:
+    gr.ChatInterface(
+        fn=predict1,
+        chatbot=chatbot1,
    additional_inputs=[
-        gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
-        gr.Slider(512, 8192, value=4096, label="Max Tokens"),
-        gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
-        gr.Slider(10, 90, value=40, label="top_k"),
-    ]
+            temperature1,
+            max_tokens1,
+            repetition_penalty1,
+            top_p1,
+        ],
+    )
+
+max_tokens = gr.Slider(
+    minimum=64000,
+    maximum=128000,
+    value=100000,
+    step=1000,
+    interactive=True,
+    label="Maximum number of new tokens to generate",
+)
+repetition_penalty = gr.Slider(
+    minimum=0.01,
+    maximum=5.0,
+    value=1,
+    step=0.01,
+    interactive=True,
+    label="Repetition penalty",
 )
+temperature = gr.Slider(
+    minimum=0.0,
+    maximum=1.0,
+    value=0.7,
+    step=0.05,
+    visible=True,
+    interactive=True,
+    label="Temperature",
+)
+top_p = gr.Slider(
+    minimum=0.01,
+    maximum=0.99,
+    value=0.9,
+    step=0.01,
+    visible=True,
+    interactive=True,
+    label="Top P",
+)
+
+chatbot = gr.Chatbot(
+    label="Phi3-medium-128k",
+    show_copy_button=True,
+    likeable=True,
+    layout="panel"
+)
+
+output = gr.Textbox(label="Prompt")
+
+with gr.Blocks() as max:
+    gr.ChatInterface(
+        fn=predict,
+        chatbot=chatbot,
+        additional_inputs=[
+            temperature,
+            max_tokens,
+            repetition_penalty,
+            top_p,
+        ],
+    )
+
+with gr.Blocks(title="Phi 3 Medium DEMO", theme=theme) as demo:
+    gr.Markdown("# Phi3 Medium all in one")
+    gr.TabbedInterface([max, min], ['Phi3 medium 128k','Phi3 medium 4k'])
+
 demo.launch(share=True)
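
For reference, here is a minimal, self-contained sketch of the threaded streaming pattern both predict functions in this diff rely on: model.generate() runs in a background Thread while a TextIteratorStreamer hands decoded chunks to a generator that Gradio can consume. The model id, prompt, and sampling values are illustrative placeholders rather than values taken from app.py, and a CUDA GPU is assumed.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "microsoft/Phi-3-medium-4k-instruct"  # placeholder; any causal LM works
tok = AutoTokenizer.from_pretrained(model_id)
lm = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True
)

def stream_reply(prompt: str):
    # Tokenize the prompt and move it to the GPU the model lives on.
    inputs = tok([prompt], return_tensors="pt").to("cuda")
    # The streamer receives tokens from generate() and yields decoded text chunks.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=256, do_sample=True, temperature=0.7)
    # Run generation in the background so the streamer can be consumed as it fills.
    Thread(target=lm.generate, kwargs=kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # a Gradio ChatInterface renders each partial string as it is yielded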