Files changed (1)
  1. app.py +163 -12
app.py CHANGED
@@ -13,9 +13,54 @@ subprocess.run(
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
+
+theme = gr.themes.Base(
+    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+)
+
+model_name = "microsoft/Phi-3-medium-4k-instruct"
+from transformers import AutoModelForCausalLM, AutoTokenizer
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        stop_ids = [29, 0]
+        for stop_id in stop_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+@spaces.GPU(queue=False)
+def predict1(message, history, temperature1, max_tokens1, repetition_penalty1, top_p1):
+    history_transformer_format = history + [[message, ""]]
+    stop = StopOnTokens()
+    messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
+    model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        model_inputs,
+        streamer=streamer,
+        max_new_tokens=max_tokens1,
+        do_sample=True,
+        top_p=top_p1,
+        repetition_penalty=repetition_penalty1,
+        temperature=temperature1,
+        stopping_criteria=StoppingCriteriaList([stop])
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+    partial_message = ""
+    for new_token in streamer:
+        if new_token != '<':
+            partial_message += new_token
+            yield partial_message
+
+
+
 model_name = "microsoft/Phi-3-medium-128k-instruct"
 from transformers import AutoModelForCausalLM, AutoTokenizer
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', _attn_implementation="flash_attention_2", torch_dtype=torch.float16, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, _attn_implementation="flash_attention_2", trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 class StopOnTokens(StoppingCriteria):
@@ -25,9 +70,9 @@ class StopOnTokens(StoppingCriteria):
             if input_ids[0][-1] == stop_id:
                 return True
         return False
-model.to('cuda')
+
 @spaces.GPU(queue=False)
-def predict(message, history, temperature, max_tokens, top_p, top_k):
+def predict(message, history, temperature, max_tokens, repetition_penalty, top_p):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
     messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
@@ -39,7 +84,7 @@ def predict(message, history, temperature, max_tokens, top_p, top_k):
         max_new_tokens=max_tokens,
         do_sample=True,
         top_p=top_p,
-        top_k=top_k,
+        repetition_penalty=repetition_penalty,
         temperature=temperature,
         stopping_criteria=StoppingCriteriaList([stop])
     )
@@ -51,14 +96,120 @@ def predict(message, history, temperature, max_tokens, top_p, top_k):
             partial_message += new_token
             yield partial_message
 
-demo = gr.ChatInterface(
-    fn=predict,
-    title="Phi-3-medium-128k-instruct",
+max_tokens1 = gr.Slider(
+    minimum=512,
+    maximum=4096,
+    value=4000,
+    step=32,
+    interactive=True,
+    label="Maximum number of new tokens to generate",
+)
+repetition_penalty1 = gr.Slider(
+    minimum=0.01,
+    maximum=5.0,
+    value=1,
+    step=0.01,
+    interactive=True,
+    label="Repetition penalty",
+)
+temperature1 = gr.Slider(
+    minimum=0.0,
+    maximum=1.0,
+    value=0.7,
+    step=0.05,
+    visible=True,
+    interactive=True,
+    label="Temperature",
+)
+top_p1 = gr.Slider(
+    minimum=0.01,
+    maximum=0.99,
+    value=0.9,
+    step=0.01,
+    visible=True,
+    interactive=True,
+    label="Top P",
+)
+
+chatbot1 = gr.Chatbot(
+    label="Phi3-medium-4k",
+    show_copy_button=True,
+    likeable=True,
+    layout="panel"
+)
+
+output=gr.Textbox(label="Prompt")
+
+with gr.Blocks() as min:
+    gr.ChatInterface(
+        fn=predict1,
+        chatbot=chatbot1,
         additional_inputs=[
-        gr.Slider(0.1, 0.9, step=0.1, value=0.7, label="Temperature"),
-        gr.Slider(512, 8192, value=4096, label="Max Tokens"),
-        gr.Slider(0.1, 0.9, step=0.1, value=0.7, label="top_p"),
-        gr.Slider(10, 90, step=10, value=40, label="top_k"),
-    ]
+            temperature1,
+            max_tokens1,
+            repetition_penalty1,
+            top_p1,
+        ],
+    )
+
+max_tokens = gr.Slider(
+    minimum=64000,
+    maximum=128000,
+    value=100000,
+    step=1000,
+    interactive=True,
+    label="Maximum number of new tokens to generate",
+)
+repetition_penalty = gr.Slider(
+    minimum=0.01,
+    maximum=5.0,
+    value=1,
+    step=0.01,
+    interactive=True,
+    label="Repetition penalty",
+)
+temperature = gr.Slider(
+    minimum=0.0,
+    maximum=1.0,
+    value=0.7,
+    step=0.05,
+    visible=True,
+    interactive=True,
+    label="Temperature",
+)
+top_p = gr.Slider(
+    minimum=0.01,
+    maximum=0.99,
+    value=0.9,
+    step=0.01,
+    visible=True,
+    interactive=True,
+    label="Top P",
 )
+
+chatbot = gr.Chatbot(
+    label="Phi3-medium-128k",
+    show_copy_button=True,
+    likeable=True,
+    layout="panel"
+)
+
+output=gr.Textbox(label="Prompt")
+
+with gr.Blocks() as max:
+    gr.ChatInterface(
+        fn=predict,
+        chatbot=chatbot,
+        additional_inputs=[
+            temperature,
+            max_tokens,
+            repetition_penalty,
+            top_p,
+        ],
+    )
+
+with gr.Blocks(title="Phi 3 Medium DEMO", theme=theme) as demo:
+    gr.Markdown("# Phi3 Medium all in one")
+    gr.TabbedInterface([max, min], ['Phi3 medium 128k','Phi3 medium 4k'])
+
 demo.launch(share=True)
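
For reference, both predict and predict1 flatten the Gradio chat history into Phi-3's <|user|>/<|assistant|> markup with a plain string join. A small worked example of that line (the history and message values here are invented purely for illustration):

history = [["What is the capital of France?", "Paris."]]   # prior chat turns from Gradio
message = "And of Italy?"                                   # current user message
history_transformer_format = history + [[message, ""]]
messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
# messages is now (shown split across two comment lines):
#   "\n<|end|>\n<|user|>\nWhat is the capital of France?\n<|end|>\n<|assistant|>\nParis."
#   "\n<|end|>\n<|user|>\nAnd of Italy?\n<|end|>\n<|assistant|>\n"
# i.e. the prompt ends with an open <|assistant|> turn for the model to complete.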
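
The generation itself uses the usual transformers threaded-streaming pattern: model.generate() blocks, so it runs on a worker thread while a TextIteratorStreamer feeds partial text back to the Gradio generator. Below is a minimal, self-contained sketch of that pattern as the two predict functions use it, pulled out of the Gradio wiring; stream_reply is a hypothetical helper name and the sampling values are illustrative defaults, not the app's slider values:

import torch
from threading import Thread
from transformers import (AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer,
                          StoppingCriteria, StoppingCriteriaList)

model_name = "microsoft/Phi-3-medium-4k-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda",
                                             torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

class StopOnTokens(StoppingCriteria):
    # Stop as soon as the last generated token id is one of the stop ids used in the diff.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return int(input_ids[0][-1]) in (29, 0)

def stream_reply(prompt: str):
    # Hypothetical helper: yields the growing reply string, like predict/predict1 do.
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True,
                                    skip_special_tokens=True)
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512,
                           do_sample=True, temperature=0.7, top_p=0.9,
                           repetition_penalty=1.0,
                           stopping_criteria=StoppingCriteriaList([StopOnTokens()]))
    # generate() blocks until finished, so it runs on a worker thread while the
    # streamer hands decoded text fragments back to this generator as they arrive.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message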