KingNish commited on
Commit
669951f
1 Parent(s): 87a7979

Complete Overhaul

Browse files

- increased context length
- changed theme
- added 4k model
- Made it one place for all phi 3 medium models.

Files changed (1) hide show
  1. app.py +74 -11
app.py CHANGED
@@ -6,6 +6,48 @@ import torch
6
  import spaces
7
  import os
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  model_name = "microsoft/Phi-3-medium-128k-instruct"
10
  from transformers import AutoModelForCausalLM, AutoTokenizer
11
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
@@ -19,7 +61,7 @@ class StopOnTokens(StoppingCriteria):
19
  return True
20
  return False
21
 
22
- @spaces.GPU(duration=120)
23
  def predict(message, history, temperature, max_tokens, top_p, top_k):
24
  history_transformer_format = history + [[message, ""]]
25
  stop = StopOnTokens()
@@ -44,14 +86,35 @@ def predict(message, history, temperature, max_tokens, top_p, top_k):
44
  partial_message += new_token
45
  yield partial_message
46
 
47
- demo = gr.ChatInterface(
48
- fn=predict,
49
- title="Phi-3-medium-128k-instruct",
50
- additional_inputs=[
51
- gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
52
- gr.Slider(512, 8192, value=4096, label="Max Tokens"),
53
- gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
54
- gr.Slider(10, 90, value=40, label="top_k"),
55
- ]
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  demo.launch(share=True)
 
6
  import spaces
7
  import os
8
 
9
+ theme = gr.themes.Base(
10
+ font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
11
+ )
12
+
13
+ model_name1 = "microsoft/Phi-3-medium-4k-instruct"
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+ model1 = AutoModelForCausalLM.from_pretrained(model_name1, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
16
+ tokenizer = AutoTokenizer.from_pretrained(model_name1)
17
+
18
+ class StopOnTokens(StoppingCriteria):
19
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
+ stop_ids = [29, 0]
21
+ for stop_id in stop_ids:
22
+ if input_ids[0][-1] == stop_id:
23
+ return True
24
+ return False
25
+
26
+ @spaces.GPU(duration=40)
27
+ def predict1(message, history, temperature, max_tokens, top_p, top_k):
28
+ history_transformer_format = history + [[message, ""]]
29
+ stop = StopOnTokens()
30
+ messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
31
+ model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
32
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
33
+ generate_kwargs = dict(
34
+ model_inputs,
35
+ streamer=streamer,
36
+ max_new_tokens=max_tokens,
37
+ do_sample=True,
38
+ top_p=top_p,
39
+ top_k=top_k,
40
+ temperature=temperature,
41
+ stopping_criteria=StoppingCriteriaList([stop])
42
+ )
43
+ t = Thread(target=model1.generate, kwargs=generate_kwargs)
44
+ t.start()
45
+ partial_message = ""
46
+ for new_token in streamer:
47
+ if new_token != '<':
48
+ partial_message += new_token
49
+ yield partial_message
50
+
51
  model_name = "microsoft/Phi-3-medium-128k-instruct"
52
  from transformers import AutoModelForCausalLM, AutoTokenizer
53
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
 
61
  return True
62
  return False
63
 
64
+ @spaces.GPU(duration=60)
65
  def predict(message, history, temperature, max_tokens, top_p, top_k):
66
  history_transformer_format = history + [[message, ""]]
67
  stop = StopOnTokens()
 
86
  partial_message += new_token
87
  yield partial_message
88
 
89
+ with gr.Blocks() as min:
90
+ gr.ChatInterface(
91
+ fn=predict1,
92
+ title="Phi-3-medium-4k-instruct",
93
+ additional_inputs=[
94
+ gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
95
+ gr.Slider(512, 4096, value=4096, label="Max Tokens"),
96
+ gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
97
+ gr.Slider(10, 90, value=40, label="top_k"),
98
+ ]
99
+ )
100
+
101
+
102
+ with gr.Blocks() as max:
103
+ gr.ChatInterface(
104
+ fn=predict,
105
+ title="Phi-3-medium-128k-instruct",
106
+ additional_inputs=[
107
+ gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
108
+ gr.Slider(64000, 128000, value=100000, label="Max Tokens"),
109
+ gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
110
+ gr.Slider(10, 90, value=40, label="top_k"),
111
+ ]
112
+ )
113
+
114
+
115
+
116
+ with gr.Blocks(theme=theme, title="Phi 3 Medium DEMO") as demo:
117
+ gr.Markdown("# Phi3 Medium all in one")
118
+ gr.TabbedInterface([max, min], ['Phi3 medium 128k','Phi3 medium 4k'])
119
+
120
  demo.launch(share=True)