nuojohnchen committed · verified
Commit e77dcbd · 1 Parent(s): d46d2bb

Update app.py

Files changed (1)
  1. app.py +282 -50
app.py CHANGED
@@ -1,64 +1,296 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
  if __name__ == "__main__":
      demo.launch()

  import gradio as gr
+ import os
+ import spaces
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from threading import Thread
+ import torch
+ import time
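+ # (Thread and time are currently unused: the app generates responses
+ # non-streaming; they would only be needed for a streaming variant)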

+ # Set environment variables
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+ # Apollo system prompt
+ SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."
+
+ # Apollo model options
+ APOLLO_MODELS = {
+     "Apollo": [
+         "FreedomIntelligence/Apollo-7B",
+         "FreedomIntelligence/Apollo-6B",
+         "FreedomIntelligence/Apollo-2B",
+         "FreedomIntelligence/Apollo-0.5B",
+     ],
+     "Apollo2": [
+         "FreedomIntelligence/Apollo2-7B",
+         "FreedomIntelligence/Apollo2-3.8B",
+         "FreedomIntelligence/Apollo2-2B",
+     ],
+     "Apollo-MoE": [
+         "FreedomIntelligence/Apollo-MoE-7B",
+         "FreedomIntelligence/Apollo-MoE-1.5B",
+         "FreedomIntelligence/Apollo-MoE-0.5B",
+     ]
+ }
+
+ # CSS styles
+ css = """
+ h1 {
+     text-align: center;
+     display: block;
+ }
+ .gradio-container {
+     max-width: 1200px;
+     margin: auto;
+ }
+ """
+
+ # Global variables to store the currently loaded model and tokenizer
+ current_model = None
+ current_tokenizer = None
+ current_model_path = None
+
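+ # On ZeroGPU Spaces, @spaces.GPU requests a GPU for at most `duration` seconds
+ # per call; on ordinary GPU hardware the decorator is effectively a no-op.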
+ @spaces.GPU(duration=120)
+ def load_model(model_path, progress=gr.Progress()):
+     """Load the selected model and tokenizer"""
+     global current_model, current_tokenizer, current_model_path
+
+     # If the same model is already loaded, don't reload it
+     if current_model_path == model_path and current_model is not None:
+         return "Model already loaded, no need to reload."
+
+     # Clean up the previously loaded model (if any)
+     if current_model is not None:
+         del current_model
+         del current_tokenizer
+         torch.cuda.empty_cache()
+
+     progress(0.1, desc=f"Starting to load model {model_path}...")
+
+     try:
+         progress(0.3, desc="Loading tokenizer...")
+         current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+         progress(0.5, desc="Loading model...")
+         current_model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             device_map="auto",
+             torch_dtype=torch.float16
+         )
+
+         current_model_path = model_path
+         progress(1.0, desc="Model loading complete!")
+         return f"Model {model_path} successfully loaded."
+     except Exception as e:
+         progress(1.0, desc="Model loading failed!")
+         return f"Model loading failed: {str(e)}"
+
+ @spaces.GPU(duration=120)
+ def generate_response_non_streaming(instruction, model_name, temperature=0.7, max_tokens=1024):
+     """Generate a response from the Apollo model (non-streaming)"""
+     global current_model, current_tokenizer, current_model_path
+
+     # If the model is not yet loaded, load it first
+     if current_model_path != model_name or current_model is None:
+         load_message = load_model(model_name)
+         if "failed" in load_message.lower():
+             return load_message
+
+     try:
+         # Check whether the model has a chat template
+         if hasattr(current_tokenizer, 'chat_template') and current_tokenizer.chat_template:
+             # Use the model's chat template
+             messages = [
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": instruction}
+             ]
+
+             # Format the input with the model's chat template
+             chat_input = current_tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=True,
+                 return_tensors="pt"
+             ).to(current_model.device)
+         else:
+             # Use the fixed prompt format
+             prompt = f"User:{instruction}\nAssistant:"
+             chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)
+
+         # Get the token id of <|endoftext|>, used to stop generation
+         eos_token_id = current_tokenizer.eos_token_id
+
+         # Generate the response
+         output = current_model.generate(
+             input_ids=chat_input,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             do_sample=(temperature > 0),
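+             # (a temperature of 0 disables sampling, i.e. greedy decoding)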
+             eos_token_id=eos_token_id  # use <|endoftext|> as the stop marker
+         )
+
+         # Decode and return the generated text, skipping the prompt tokens
+         generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
+         return generated_text
+     except Exception as e:
+         return f"Error while generating response: {str(e)}"
+
+ def update_chat_with_response(chatbot, instruction, model_name, temperature, max_tokens):
+     """Update the chatbot with a non-streaming response"""
+     global current_model, current_tokenizer, current_model_path
+
+     # The textbox is cleared by user_message_submitted before this handler
+     # runs, so recover the question from the last chat entry if needed
+     if not instruction and chatbot:
+         instruction = chatbot[-1][0]
+
+     # If the model is not yet loaded, load it first
+     if current_model_path != model_name or current_model is None:
+         load_result = load_model(model_name)
+         if "failed" in load_result.lower():
+             new_chat = list(chatbot)
+             new_chat[-1] = (instruction, load_result)
+             return new_chat
+
+     # Generate the response using the non-streaming function
+     response = generate_response_non_streaming(instruction, model_name, temperature, max_tokens)
+
+     # Create a copy of the current chat history and fill in the response
+     new_chat = list(chatbot)
+     new_chat[-1] = (instruction, response)
+
+     return new_chat

+ def on_model_series_change(model_series):
+     """Update available model list based on selected model series"""
+     if model_series in APOLLO_MODELS:
+         return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
+     return gr.update(choices=[], value=None)

+ # Create Gradio interface
+ with gr.Blocks(css=css) as demo:
+     # Title and description
+     favicon = "🩺"
+     gr.Markdown(
+         f"""# {favicon} Apollo Playground
+ This is a demo of the multilingual medical model series **[Apollo](https://huggingface.co/FreedomIntelligence/Apollo-7B-GGUF)**.
+ [Apollo1](https://arxiv.org/abs/2403.03640) supports 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) supports 50 languages.
+ """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Model selection controls
+             model_series = gr.Dropdown(
+                 choices=list(APOLLO_MODELS.keys()),
+                 value="Apollo",
+                 label="Select Model Series",
+                 info="First choose Apollo, Apollo2 or Apollo-MoE"
+             )
+
+             model_name = gr.Dropdown(
+                 choices=APOLLO_MODELS["Apollo"],
+                 value=APOLLO_MODELS["Apollo"][0],
+                 label="Select Model Size",
+                 info="Select the specific model size based on the chosen model series"
+             )
+
+             # Parameter settings
+             with gr.Accordion("Generation Parameters", open=False):
+                 temperature = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.7,
+                     step=0.05,
+                     label="Temperature"
+                 )
+                 max_tokens = gr.Slider(
+                     minimum=128,
+                     maximum=2048,
+                     value=1024,
+                     step=32,
+                     label="Maximum Tokens"
+                 )
+
+             # Load model button
+             load_button = gr.Button("Load Model")
+             model_status = gr.Textbox(label="Model Status", value="No model loaded yet")
+
+         with gr.Column(scale=2):
+             # Chat interface
+             chatbot = gr.Chatbot(label="Conversation", height=500, value=[])  # Initialize with empty list
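+             # (each history entry is a (user, assistant) tuple; a None
+             # assistant slot is rendered as a pending message)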
+             user_input = gr.Textbox(
+                 label="Input Medical Question",
+                 placeholder="Example: What are the symptoms of hypertension? 高血压有哪些症状?",
+                 lines=3
+             )
+             submit_button = gr.Button("Submit")
+             clear_button = gr.Button("Clear Chat")
+
+     # Event handling
+     # Update model selection when the model series changes
+     model_series.change(
+         fn=on_model_series_change,
+         inputs=model_series,
+         outputs=model_name
+     )
+
+     # Load model
+     load_button.click(
+         fn=load_model,
+         inputs=model_name,
+         outputs=model_status
+     )
+
+     # Handle message submission
+     def user_message_submitted(message, chat_history):
+         """Handle a user-submitted message"""
+         # Ensure chat_history is a list
+         if chat_history is None:
+             chat_history = []
+
+         if message.strip() == "":
+             return "", chat_history
+
+         # Add the user message to the chat history
+         chat_history = list(chat_history)
+         chat_history.append((message, None))
+         return "", chat_history
+
+     # Bind message submission (Enter key and Submit button)
+     submit_event = user_input.submit(
+         fn=user_message_submitted,
+         inputs=[user_input, chatbot],
+         outputs=[user_input, chatbot]
+     ).then(
+         fn=update_chat_with_response,
+         inputs=[chatbot, user_input, model_name, temperature, max_tokens],
+         outputs=chatbot
+     )
+
+     submit_button.click(
+         fn=user_message_submitted,
+         inputs=[user_input, chatbot],
+         outputs=[user_input, chatbot]
+     ).then(
+         fn=update_chat_with_response,
+         inputs=[chatbot, user_input, model_name, temperature, max_tokens],
+         outputs=chatbot
+     )
+
+     # Clear chat
+     clear_button.click(
+         fn=lambda: [],
+         outputs=chatbot
+     )
+
+     examples = [
+         ["Últimamente tengo la tensión un poco alta, ¿cómo debo adaptar mis hábitos?"],
+         ["What are the common side effects of metformin?"],
+         ["中医和西医在治疗高血压方面有什么不同的观点?"],
+         ["मेरा सिर दर्द कर रहा है, मुझे क्या करना चाहिए?"],
+         ["Comment savoir si je suis diabétique ?"],
+         ["ما الدواء الذي يمكنني تناوله إذا لم أستطع النوم ليلاً؟"]
+     ]
+     gr.Examples(
+         examples=examples,
+         inputs=user_input
+     )

  if __name__ == "__main__":
      demo.launch()