MakiAi committed
Commit b2c8b1e
1 Parent(s): 8897312

Update app.py

Files changed (1): app.py (+42, -29)
app.py CHANGED
@@ -6,18 +6,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import gradio as gr
 from threading import Thread
 
-MODEL = "microsoft/Phi-3.5-mini-instruct"
+MODELS = {
+    "Phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
+    "Borea-Phi-3.5-mini-Jp": "AXCXEPT/Borea-Phi-3.5-mini-Instruct-Jp",
+    "EZO-Common-9B": "HODACHI/EZO-Common-9B-gemma-2-it"
+}
+
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-TITLE = "<h1><center>Phi 3.5 Mini</center></h1>"
+TITLE = "<h1><center>Multi-Model Chat Interface</center></h1>"
 
 PLACEHOLDER = """
 <center>
-<p>Hi, I'm Phi. Ask me anything.</p>
+<p>Hi, I'm an AI assistant. Ask me anything.</p>
 </center>
 """
 
-
 CSS = """
 .duplicate-button {
     margin: auto !important;
@@ -30,20 +34,26 @@ h3 {
 }
 """
 
-device = "cuda" # for GPU usage or "cpu" for CPU usage
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_compute_dtype=torch.bfloat16,
     bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type= "nf4")
+    bnb_4bit_quant_type="nf4")
 
-tokenizer = AutoTokenizer.from_pretrained(MODEL)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    quantization_config=quantization_config)
+model = None
+tokenizer = None
+
+def load_model(model_name):
+    global model, tokenizer
+    model_path = MODELS[model_name]
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+        quantization_config=quantization_config)
 
 @spaces.GPU()
 def stream_chat(
@@ -55,7 +65,13 @@ def stream_chat(
     top_p: float = 1.0,
     top_k: int = 20,
     penalty: float = 1.2,
+    model_name: str = "Phi-3.5-mini"
 ):
+    global model, tokenizer
+
+    if model is None or tokenizer is None or model.name_or_path != MODELS[model_name]:
+        load_model(model_name)
+
     print(f'message: {message}')
     print(f'history: {history}')
 
@@ -76,12 +92,13 @@ def stream_chat(
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        max_new_tokens = max_new_tokens,
-        do_sample = False if temperature == 0 else True,
-        top_p = top_p,
-        top_k = top_k,
-        temperature = temperature,
-        eos_token_id=[128001,128008,128009],
+        max_new_tokens=max_new_tokens,
+        do_sample=False if temperature == 0 else True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        repetition_penalty=penalty,
+        eos_token_id=tokenizer.eos_token_id,
         streamer=streamer,
     )
 
@@ -94,7 +111,6 @@ def stream_chat(
         buffer += new_text
         yield buffer
 
-
 chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
 
 with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
@@ -103,12 +119,15 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
         fn=stream_chat,
         chatbot=chatbot,
         fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
         additional_inputs=[
+            gr.Dropdown(
+                choices=list(MODELS.keys()),
+                value="Phi-3.5-mini",
+                label="Model",
+            ),
             gr.Textbox(
                 value="You are a helpful assistant",
                 label="System Prompt",
-                render=False,
             ),
             gr.Slider(
                 minimum=0,
@@ -116,7 +135,6 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
                 step=0.1,
                 value=0.8,
                 label="Temperature",
-                render=False,
             ),
             gr.Slider(
                 minimum=128,
@@ -124,7 +142,6 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
                 step=1,
                 value=1024,
                 label="Max new tokens",
-                render=False,
             ),
             gr.Slider(
                 minimum=0.0,
@@ -132,7 +149,6 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
                 step=0.1,
                 value=1.0,
                 label="top_p",
-                render=False,
             ),
             gr.Slider(
                 minimum=1,
@@ -140,15 +156,13 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
                 step=1,
                 value=20,
                 label="top_k",
-                render=False,
             ),
             gr.Slider(
-                minimum=0.0,
+                minimum=1.0,
                 maximum=2.0,
                 step=0.1,
                 value=1.2,
                 label="Repetition penalty",
-                render=False,
             ),
         ],
         examples=[
@@ -160,6 +174,5 @@ with gr.Blocks(css=CSS, theme="Nymbo/Nymbo_Theme") as demo:
         cache_examples=False,
    )
 
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
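
The heart of this change is lazy model switching: stream_chat compares the loaded model's name_or_path against the requested checkpoint and calls load_model only on a mismatch, so repeated turns on the same model pay no reload cost. A minimal standalone sketch of that pattern, assuming only transformers and torch; ensure_model is a hypothetical helper name, and the explicit memory release is an extra safeguard the committed code does not include:

import gc
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODELS = {
    "Phi-3.5-mini": "microsoft/Phi-3.5-mini-instruct",
    "EZO-Common-9B": "HODACHI/EZO-Common-9B-gemma-2-it",
}

model = None
tokenizer = None

def ensure_model(model_name: str):
    """Load the requested checkpoint only if it is not already the active one."""
    global model, tokenizer
    model_path = MODELS[model_name]
    # PreTrainedModel.name_or_path records the checkpoint a model was loaded
    # from, which is what makes this cheap identity check possible.
    if model is not None and model.name_or_path == model_path:
        return
    # Drop the previous model before loading the next one; without this,
    # switching between large checkpoints can accumulate GPU memory.
    model = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")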
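
Two smaller fixes in the generation hunk are easy to miss: penalty is now actually wired through as repetition_penalty (it was previously accepted but ignored), and the hardcoded Llama-3 stop ids [128001,128008,128009] are replaced with tokenizer.eos_token_id, which is what lets the same code stop correctly on all three checkpoints. A sketch of the threaded TextIteratorStreamer loop the app is built around, assuming model and tokenizer are already loaded; stream_reply is a hypothetical name:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(prompt: str, max_new_tokens: int = 256):
    # generate() blocks until completion, so it runs in a worker thread while
    # this generator drains partial text from the streamer as it arrives.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,  # per-tokenizer EOS, not fixed ids
        streamer=streamer,
    )).start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer  # Gradio renders each growing buffer as a partial reply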