pseudotensor committed
Commit 32c203b
1 parent: 65121b5

Update with h2oGPT hash f2a7ba06a6f9e200e59d7e1388fa02b52bd20e8d

Files changed (4)
  1. app.py +1 -1
  2. gradio_runner.py +903 -0
  3. gradio_themes.py +142 -0
  4. utils.py +5 -4
app.py CHANGED
@@ -833,7 +833,7 @@ def evaluate(
     target_func = generate_with_exceptions
     if concurrency_count == 1:
         # otherwise can't do this
-        KThread.kill_threads(target_func.__name__)
+        KThread.kill_threads(target_func.__name__, debug=debug)
     target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
                              raise_generate_gpu_exceptions, **gen_kwargs)
     thread = KThread(target=target)
gradio_runner.py ADDED
@@ -0,0 +1,903 @@
1
+ import functools
2
+ import inspect
3
+ import os
4
+ import sys
5
+
6
+ from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
7
+ from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print
8
+ from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
9
+ from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa
10
+
11
+ import gradio as gr
12
+ from apscheduler.schedulers.background import BackgroundScheduler
13
+
14
+
15
+ def go_gradio(**kwargs):
16
+ allow_api = kwargs['allow_api']
17
+ is_public = kwargs['is_public']
18
+ is_hf = kwargs['is_hf']
19
+ is_low_mem = kwargs['is_low_mem']
20
+ n_gpus = kwargs['n_gpus']
21
+ admin_pass = kwargs['admin_pass']
22
+ model_state0 = kwargs['model_state0']
23
+ score_model_state0 = kwargs['score_model_state0']
24
+
25
+ # easy update of kwargs needed for evaluate() etc.
26
+ kwargs.update(locals())
27
+
28
+ if 'mbart-' in kwargs['model_lower']:
29
+ instruction_label_nochat = "Text to translate"
30
+ else:
31
+ instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
32
+ " use Enter for multiple input lines)"
33
+ if kwargs['input_lines'] > 1:
34
+ instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
35
+ else:
36
+ instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"
37
+
38
+ title = 'h2oGPT'
39
+ if 'h2ogpt-research' in kwargs['base_model']:
40
+ title += " [Research demonstration]"
41
+ if kwargs['verbose']:
42
+ description = f"""Model {kwargs['base_model']} Instruct dataset.
43
+ For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
44
+ Command: {str(' '.join(sys.argv))}
45
+ Hash: {get_githash()}
46
+ """
47
+ else:
48
+ description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).<br>"
49
+ if is_public:
50
+ description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
51
+ description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
52
+ if kwargs['load_8bit']:
53
+ description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
54
+ description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
55
+ if 'h2ogpt-research' in kwargs['base_model']:
56
+ description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
57
+ description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
58
+
59
+ if kwargs['verbose']:
60
+ task_info_md = f"""
61
+ ### Task: {kwargs['task_info']}"""
62
+ else:
63
+ task_info_md = ''
64
+
65
+ if kwargs['h2ocolors']:
66
+ css_code = """footer {visibility: hidden;}
67
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
68
+ body.dark{background:linear-gradient(#000000,#0d0d0d);}
69
+ """
70
+ else:
71
+ css_code = """footer {visibility: hidden}"""
72
+
73
+ if kwargs['gradio_avoid_processing_markdown']:
74
+ from gradio_client import utils as client_utils
75
+ from gradio.components import Chatbot
76
+
77
+ # gradio has issue with taking too long to process input/output for markdown etc.
78
+ # Avoid for now, allow raw html to render, good enough for chatbot.
79
+ def _postprocess_chat_messages(self, chat_message: str):
80
+ if chat_message is None:
81
+ return None
82
+ elif isinstance(chat_message, (tuple, list)):
83
+ filepath = chat_message[0]
84
+ mime_type = client_utils.get_mimetype(filepath)
85
+ filepath = self.make_temp_copy_if_needed(filepath)
86
+ return {
87
+ "name": filepath,
88
+ "mime_type": mime_type,
89
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
90
+ "data": None, # These last two fields are filled in by the frontend
91
+ "is_file": True,
92
+ }
93
+ elif isinstance(chat_message, str):
94
+ return chat_message
95
+ else:
96
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
97
+
98
+ Chatbot._postprocess_chat_messages = _postprocess_chat_messages
99
+
100
+ theme = H2oTheme() if kwargs['h2ocolors'] else SoftTheme()
101
+ demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
102
+ callback = gr.CSVLogger()
103
+
104
+ model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
105
+ if kwargs['base_model'].strip() not in model_options:
106
+ lora_options = [kwargs['base_model'].strip()] + model_options
107
+ lora_options = kwargs['extra_lora_options']
108
+ if kwargs['lora_weights'].strip() not in lora_options:
109
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
110
+ # always add in no lora case
111
+ # add fake space so doesn't go away in gradio dropdown
112
+ no_lora_str = no_model_str = '[None/Remove]'
113
+ lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
114
+ # always add in no model case so can free memory
115
+ # add fake space so doesn't go away in gradio dropdown
116
+ model_options = [no_model_str] + model_options
117
+
118
+ # transcribe, will be detranscribed before use by evaluate()
119
+ if not kwargs['lora_weights'].strip():
120
+ kwargs['lora_weights'] = no_lora_str
121
+
122
+ if not kwargs['base_model'].strip():
123
+ kwargs['base_model'] = no_model_str
124
+
125
+ # transcribe for gradio
126
+ kwargs['gpu_id'] = str(kwargs['gpu_id'])
127
+
128
+ no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
129
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
130
+ 'base_model') else no_model_msg
131
+ output_label0_model2 = no_model_msg
132
+
133
+ with demo:
134
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
135
+ # https://github.com/gradio-app/gradio/issues/3558
136
+ model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']])
137
+ model_state2 = gr.State([None, None, None, None])
138
+ model_options_state = gr.State([model_options])
139
+ lora_options_state = gr.State([lora_options])
140
+ gr.Markdown(f"""
141
+ {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
142
+
143
+ {description}
144
+ {task_info_md}
145
+ """)
146
+ if is_hf:
147
+ gr.HTML(
148
+ '''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
149
+
150
+ # go button visible if
151
+ base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
152
+ go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
153
+ normal_block = gr.Row(visible=not base_wanted)
154
+ with normal_block:
155
+ with gr.Tabs():
156
+ with gr.Row():
157
+ col_nochat = gr.Column(visible=not kwargs['chat'])
158
+ with col_nochat: # FIXME: for model comparison, and check rest
159
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0)
160
+ instruction_nochat = gr.Textbox(
161
+ lines=kwargs['input_lines'],
162
+ label=instruction_label_nochat,
163
+ placeholder=kwargs['placeholder_instruction'],
164
+ )
165
+ iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
166
+ placeholder=kwargs['placeholder_input'])
167
+ submit_nochat = gr.Button("Submit")
168
+ flag_btn_nochat = gr.Button("Flag")
169
+ if not kwargs['auto_score']:
170
+ with gr.Column(visible=kwargs['score_model']):
171
+ score_btn_nochat = gr.Button("Score last prompt & response")
172
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
173
+ else:
174
+ with gr.Column(visible=kwargs['score_model']):
175
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
176
+ col_chat = gr.Column(visible=kwargs['chat'])
177
+ with col_chat:
178
+ with gr.Row():
179
+ text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
180
+ text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
181
+ height=kwargs['height'] or 400)
182
+ with gr.Row():
183
+ with gr.Column(scale=50):
184
+ instruction = gr.Textbox(
185
+ lines=kwargs['input_lines'],
186
+ label=instruction_label,
187
+ placeholder=kwargs['placeholder_instruction'],
188
+ )
189
+ with gr.Row():
190
+ submit = gr.Button(value='Submit').style(full_width=False, size='sm')
191
+ stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
192
+ with gr.Row():
193
+ clear = gr.Button("New Conversation")
194
+ flag_btn = gr.Button("Flag")
195
+ if not kwargs['auto_score']: # FIXME: For checkbox model2
196
+ with gr.Column(visible=kwargs['score_model']):
197
+ with gr.Row():
198
+ score_btn = gr.Button("Score last prompt & response").style(
199
+ full_width=False, size='sm')
200
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
201
+ score_res2 = gr.Row(visible=False)
202
+ with score_res2:
203
+ score_btn2 = gr.Button("Score last prompt & response 2").style(
204
+ full_width=False, size='sm')
205
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
206
+ else:
207
+ with gr.Column(visible=kwargs['score_model']):
208
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
209
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
210
+ retry = gr.Button("Regenerate")
211
+ undo = gr.Button("Undo")
212
+ with gr.TabItem("Input/Output"):
213
+ with gr.Row():
214
+ if 'mbart-' in kwargs['model_lower']:
215
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
216
+ value=kwargs['src_lang'],
217
+ label="Input Language")
218
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
219
+ value=kwargs['tgt_lang'],
220
+ label="Output Language")
221
+ with gr.TabItem("Expert"):
222
+ with gr.Row():
223
+ with gr.Column():
224
+ stream_output = gr.components.Checkbox(label="Stream output",
225
+ value=kwargs['stream_output'])
226
+ prompt_type = gr.Dropdown(prompt_types_strings,
227
+ value=kwargs['prompt_type'], label="Prompt Type",
228
+ visible=not is_public)
229
+ prompt_type2 = gr.Dropdown(prompt_types_strings,
230
+ value=kwargs['prompt_type'], label="Prompt Type Model 2",
231
+ visible=not is_public and False)
232
+ do_sample = gr.Checkbox(label="Sample",
233
+ info="Enable sampler, required for use of temperature, top_p, top_k",
234
+ value=kwargs['do_sample'])
235
+ temperature = gr.Slider(minimum=0.01, maximum=3,
236
+ value=kwargs['temperature'],
237
+ label="Temperature",
238
+ info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
239
+ top_p = gr.Slider(minimum=0, maximum=1,
240
+ value=kwargs['top_p'], label="Top p",
241
+ info="Cumulative probability of tokens to sample from")
242
+ top_k = gr.Slider(
243
+ minimum=0, maximum=100, step=1,
244
+ value=kwargs['top_k'], label="Top k",
245
+ info='Num. tokens to sample from'
246
+ )
247
+ max_beams = 8 if not is_low_mem else 2
248
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
249
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
250
+ info="Number of searches for optimal overall probability. "
251
+ "Uses more GPU memory/compute")
252
+ max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
253
+ max_new_tokens = gr.Slider(
254
+ minimum=1, maximum=max_max_new_tokens, step=1,
255
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
256
+ )
257
+ min_new_tokens = gr.Slider(
258
+ minimum=0, maximum=max_max_new_tokens, step=1,
259
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
260
+ )
261
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
262
+ value=kwargs['early_stopping'])
263
+ max_max_time = 60 * 5 if not is_low_mem else 60
264
+ max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
265
+ value=min(max_max_time, kwargs['max_time']), label="Max. time",
266
+ info="Max. time to search optimal output.")
267
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
268
+ value=kwargs['repetition_penalty'],
269
+ label="Repetition Penalty")
270
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
271
+ value=kwargs['num_return_sequences'],
272
+ label="Number Returns", info="Must be <= num_beams",
273
+ visible=not is_public)
274
+ iinput = gr.Textbox(lines=4, label="Input",
275
+ placeholder=kwargs['placeholder_input'],
276
+ visible=not is_public)
277
+ context = gr.Textbox(lines=3, label="System Pre-Context",
278
+ info="Directly pre-appended without prompt processing",
279
+ visible=not is_public)
280
+ chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
281
+ visible=not is_public)
282
+
283
+ with gr.TabItem("Models"):
284
+ load_msg = "Load-Unload Model/LORA" if not is_public \
285
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
286
+ load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
287
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
288
+ compare_checkbox = gr.components.Checkbox(label="Compare Mode",
289
+ value=False, visible=not is_public)
290
+ with gr.Row():
291
+ n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
292
+ with gr.Column():
293
+ with gr.Row():
294
+ with gr.Column(scale=50):
295
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
296
+ value=kwargs['base_model'])
297
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
298
+ value=kwargs['lora_weights'], visible=kwargs['show_lora'])
299
+ with gr.Column(scale=1):
300
+ load_model_button = gr.Button(load_msg)
301
+ model_load8bit_checkbox = gr.components.Checkbox(
302
+ label="Load 8-bit [requires support]",
303
+ value=kwargs['load_8bit'])
304
+ model_infer_devices_checkbox = gr.components.Checkbox(
305
+ label="Choose Devices [If not Checked, use all GPUs]",
306
+ value=kwargs['infer_devices'])
307
+ model_gpu = gr.Dropdown(n_gpus_list,
308
+ label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
309
+ value=kwargs['gpu_id'])
310
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
311
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
312
+ visible=kwargs['show_lora'])
313
+ with gr.Row():
314
+ with gr.Column(scale=50):
315
+ new_model = gr.Textbox(label="New Model HF name/path")
316
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
317
+ with gr.Column(scale=1):
318
+ add_model_button = gr.Button("Add new model name")
319
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
320
+ col_model2 = gr.Column(visible=False)
321
+ with col_model2:
322
+ with gr.Row():
323
+ with gr.Column(scale=50):
324
+ model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
325
+ value=no_model_str)
326
+ lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
327
+ value=no_lora_str,
328
+ visible=kwargs['show_lora'])
329
+ with gr.Column(scale=1):
330
+ load_model_button2 = gr.Button(load_msg2)
331
+ model_load8bit_checkbox2 = gr.components.Checkbox(
332
+ label="Load 8-bit 2 [requires support]",
333
+ value=kwargs['load_8bit'])
334
+ model_infer_devices_checkbox2 = gr.components.Checkbox(
335
+ label="Choose Devices 2 [If not Checked, use all GPUs]",
336
+ value=kwargs[
337
+ 'infer_devices'])
338
+ model_gpu2 = gr.Dropdown(n_gpus_list,
339
+ label="GPU ID [-1 = all GPUs, if choose is enabled]",
340
+ value=kwargs['gpu_id'])
341
+ # no model/lora loaded ever in model2 by default
342
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
343
+ lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
344
+ visible=kwargs['show_lora'])
345
+ with gr.TabItem("System"):
346
+ admin_row = gr.Row()
347
+ with admin_row:
348
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
349
+ admin_btn = gr.Button(value="Admin Access", visible=is_public)
350
+ system_row = gr.Row(visible=not is_public)
351
+ with system_row:
352
+ with gr.Column():
353
+ with gr.Row():
354
+ system_btn = gr.Button(value='Get System Info')
355
+ system_text = gr.Textbox(label='System Info')
356
+
357
+ with gr.Row():
358
+ zip_btn = gr.Button("Zip")
359
+ zip_text = gr.Textbox(label="Zip file name")
360
+ file_output = gr.File()
361
+ with gr.Row():
362
+ s3up_btn = gr.Button("S3UP")
363
+ s3up_text = gr.Textbox(label='S3UP result')
364
+
365
+ # Get flagged data
366
+ zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
367
+ zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text])
368
+ s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text)
369
+
370
+ def check_admin_pass(x):
371
+ return gr.update(visible=x == admin_pass)
372
+
373
+ def close_admin(x):
374
+ return gr.update(visible=not (x == admin_pass))
375
+
376
+ admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row) \
377
+ .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row)
378
+
379
+ # Get inputs to evaluate()
380
+ all_kwargs = kwargs.copy()
381
+ all_kwargs.update(locals())
382
+ inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
383
+ from functools import partial
384
+ kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
385
+ # ensure present
386
+ for k in inputs_kwargs_list:
387
+ assert k in kwargs_evaluate, "Missing %s" % k
388
+ fun = partial(evaluate,
389
+ **kwargs_evaluate)
390
+ fun2 = partial(evaluate,
391
+ **kwargs_evaluate)
392
+
393
+ dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
394
+ size="sm",
395
+ )
396
+ dark_mode_btn.click(
397
+ None,
398
+ None,
399
+ None,
400
+ _js=get_dark_js(),
401
+ api_name="dark" if allow_api else None,
402
+ )
403
+
404
+ # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
405
+ def col_nochat_fun(x):
406
+ return gr.Column.update(visible=not x)
407
+
408
+ def col_chat_fun(x):
409
+ return gr.Column.update(visible=x)
410
+
411
+ def context_fun(x):
412
+ return gr.Textbox.update(visible=not x)
413
+
414
+ chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
415
+ .then(col_chat_fun, chat, col_chat) \
416
+ .then(context_fun, chat, context)
417
+
418
+ # examples after submit or any other buttons for chat or no chat
419
+ if kwargs['examples'] is not None and kwargs['show_examples']:
420
+ gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
421
+
422
+ # Score
423
+ def score_last_response(*args, nochat=False, model2=False):
424
+ """ Similar to user() """
425
+ args_list = list(args)
426
+
427
+ max_length_tokenize = 512 if is_low_mem else 2048
428
+ cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
429
+ smodel = score_model_state0[0]
430
+ stokenizer = score_model_state0[1]
431
+ sdevice = score_model_state0[2]
432
+ if not nochat:
433
+ history = args_list[-1]
434
+ if history is None:
435
+ if not model2:
436
+ # maybe only doing first model, no need to complain
437
+ print("Bad history in scoring last response, fix for now", flush=True)
438
+ history = []
439
+ if smodel is not None and \
440
+ stokenizer is not None and \
441
+ sdevice is not None and \
442
+ history is not None and len(history) > 0 and \
443
+ history[-1] is not None and \
444
+ len(history[-1]) >= 2:
445
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
446
+
447
+ question = history[-1][0]
448
+
449
+ answer = history[-1][1]
450
+ else:
451
+ return 'Response Score: NA'
452
+ else:
453
+ answer = args_list[-1]
454
+ instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
455
+ question = args_list[instruction_nochat_arg_id]
456
+
457
+ if question is None:
458
+ return 'Response Score: Bad Question'
459
+ if answer is None:
460
+ return 'Response Score: Bad Answer'
461
+ score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len)
462
+ if isinstance(score, str):
463
+ return 'Response Score: NA'
464
+ return 'Response Score: {:.1%}'.format(score)
465
+
466
+ def noop_score_last_response(*args, **kwargs):
467
+ return "Response Score: Disabled"
468
+
469
+ if kwargs['score_model']:
470
+ score_fun = score_last_response
471
+ else:
472
+ score_fun = noop_score_last_response
473
+
474
+ score_args = dict(fn=score_fun,
475
+ inputs=inputs_list + [text_output],
476
+ outputs=[score_text],
477
+ )
478
+ score_args2 = dict(fn=partial(score_fun, model2=True),
479
+ inputs=inputs_list + [text_output2],
480
+ outputs=[score_text2],
481
+ )
482
+
483
+ score_args_nochat = dict(fn=partial(score_fun, nochat=True),
484
+ inputs=inputs_list + [text_output_nochat],
485
+ outputs=[score_text_nochat],
486
+ )
487
+ if not kwargs['auto_score']:
488
+ score_event = score_btn.click(**score_args, queue=stream_output, api_name='score' if allow_api else None) \
489
+ .then(**score_args2, queue=stream_output, api_name='score2' if allow_api else None)
490
+ score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=stream_output,
491
+ api_name='score_nochat' if allow_api else None)
492
+
493
+ def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
494
+ """
495
+ User that fills history for bot
496
+ :param args:
497
+ :param undo:
498
+ :param sanitize_user_prompt:
499
+ :param model2:
500
+ :return:
501
+ """
502
+ args_list = list(args)
503
+ user_message = args_list[0]
504
+ input1 = args_list[1]
505
+ context1 = args_list[2]
506
+ if input1 and not user_message.endswith(':'):
507
+ user_message1 = user_message + ":" + input1
508
+ elif input1:
509
+ user_message1 = user_message + input1
510
+ else:
511
+ user_message1 = user_message
512
+ if sanitize_user_prompt:
513
+ from better_profanity import profanity
514
+ user_message1 = profanity.censor(user_message1)
515
+
516
+ history = args_list[-1]
517
+ if undo and history:
518
+ history.pop()
519
+ args_list = args_list[:-1] # FYI, even if unused currently
520
+ if history is None:
521
+ if not model2:
522
+ # no need to complain so often unless model1
523
+ print("Bad history, fix for now", flush=True)
524
+ history = []
525
+ # ensure elements not mixed across models as output,
526
+ # even if input is currently same source
527
+ history = history.copy()
528
+ if undo:
529
+ return history
530
+ else:
531
+ # FIXME: compare, same history for now
532
+ return history + [[user_message1, None]]
533
+
534
+ def bot(*args, retry=False):
535
+ """
536
+ bot that consumes history for user input
537
+ instruction (from input_list) itself is not consumed by bot
538
+ :param args:
539
+ :param retry:
540
+ :return:
541
+ """
542
+ args_list = list(args).copy()
543
+ history = args_list[-1] # model_state is -2
544
+ if retry and history:
545
+ history.pop()
546
+ if not history:
547
+ print("No history", flush=True)
548
+ return
549
+ # ensure output will be unique to models
550
+ history = history.copy()
551
+ instruction1 = history[-1][0]
552
+ context1 = ''
553
+ if kwargs['chat_history'] > 0:
554
+ prompt_type_arg_id = eval_func_param_names.index('prompt_type')
555
+ prompt_type1 = args_list[prompt_type_arg_id]
556
+ chat_arg_id = eval_func_param_names.index('chat')
557
+ chat1 = args_list[chat_arg_id]
558
+ context1 = ''
559
+ for histi in range(len(history) - 1):
560
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
561
+ context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
562
+ '<br>', '\n')
563
+ if not context1.endswith('\n'):
564
+ context1 += '\n'
565
+ if context1 and not context1.endswith('\n'):
566
+ context1 += '\n' # ensure if terminates abruptly, then human continues on next line
567
+ args_list[0] = instruction1 # override original instruction with history from user
568
+ # only include desired chat history
569
+ args_list[2] = context1[-kwargs['chat_history']:]
570
+ model_state1 = args_list[-2]
571
+ if model_state1[0] is None or model_state1[0] == no_model_str:
572
+ return
573
+ args_list = args_list[:-2]
574
+ fun1 = partial(evaluate,
575
+ model_state1,
576
+ **kwargs_evaluate)
577
+ try:
578
+ for output in fun1(*tuple(args_list)):
579
+ bot_message = output
580
+ history[-1][1] = bot_message
581
+ yield history
582
+ except StopIteration:
583
+ yield history
584
+ except RuntimeError as e:
585
+ if "generator raised StopIteration" in str(e):
586
+ # assume last entry was bad, undo
587
+ history.pop()
588
+ yield history
589
+ raise
590
+ except Exception as e:
591
+ # put error into user input
592
+ history[-1][0] = "Exception: %s" % str(e)
593
+ yield history
594
+ raise
595
+ return
596
+
597
+ # NORMAL MODEL
598
+ user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
599
+ inputs=inputs_list + [text_output],
600
+ outputs=text_output,
601
+ )
602
+ bot_args = dict(fn=bot,
603
+ inputs=inputs_list + [model_state] + [text_output],
604
+ outputs=text_output,
605
+ )
606
+ retry_bot_args = dict(fn=functools.partial(bot, retry=True),
607
+ inputs=inputs_list + [model_state] + [text_output],
608
+ outputs=text_output,
609
+ )
610
+ undo_user_args = dict(fn=functools.partial(user, undo=True),
611
+ inputs=inputs_list + [text_output],
612
+ outputs=text_output,
613
+ )
614
+
615
+ # MODEL2
616
+ user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
617
+ inputs=inputs_list + [text_output2],
618
+ outputs=text_output2,
619
+ )
620
+ bot_args2 = dict(fn=bot,
621
+ inputs=inputs_list + [model_state2] + [text_output2],
622
+ outputs=text_output2,
623
+ )
624
+ retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
625
+ inputs=inputs_list + [model_state2] + [text_output2],
626
+ outputs=text_output2,
627
+ )
628
+ undo_user_args2 = dict(fn=functools.partial(user, undo=True),
629
+ inputs=inputs_list + [text_output2],
630
+ outputs=text_output2,
631
+ )
632
+
633
+ def clear_instruct():
634
+ return gr.Textbox.update(value='')
635
+
636
+ if kwargs['auto_score']:
637
+ # in case 2nd model, consume instruction first, so can clear quickly
638
+ # bot doesn't consume instruction itself, just history from user, so why works
639
+ submit_event = instruction.submit(**user_args, queue=stream_output,
640
+ api_name='instruction' if allow_api else None) \
641
+ .then(**user_args2, queue=stream_output, api_name='instruction2' if allow_api else None) \
642
+ .then(clear_instruct, None, instruction) \
643
+ .then(clear_instruct, None, iinput) \
644
+ .then(**bot_args, api_name='instruction_bot' if allow_api else None) \
645
+ .then(**score_args, api_name='instruction_bot_score' if allow_api else None) \
646
+ .then(**bot_args2, api_name='instruction_bot2' if allow_api else None) \
647
+ .then(**score_args2, api_name='instruction_bot_score2' if allow_api else None) \
648
+ .then(clear_torch_cache)
649
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit' if allow_api else None) \
650
+ .then(**user_args2, queue=stream_output, api_name='submit2' if allow_api else None) \
651
+ .then(clear_instruct, None, instruction) \
652
+ .then(clear_instruct, None, iinput) \
653
+ .then(**bot_args, api_name='submit_bot' if allow_api else None) \
654
+ .then(**score_args, api_name='submit_bot_score' if allow_api else None) \
655
+ .then(**bot_args2, api_name='submit_bot2' if allow_api else None) \
656
+ .then(**score_args2, api_name='submit_bot_score2' if allow_api else None) \
657
+ .then(clear_torch_cache)
658
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry' if allow_api else None) \
659
+ .then(**user_args2, queue=stream_output, api_name='retry2' if allow_api else None) \
660
+ .then(clear_instruct, None, instruction) \
661
+ .then(clear_instruct, None, iinput) \
662
+ .then(**retry_bot_args, api_name='retry_bot' if allow_api else None) \
663
+ .then(**score_args, api_name='retry_bot_score' if allow_api else None) \
664
+ .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None) \
665
+ .then(**score_args2, api_name='retry_bot_score2' if allow_api else None) \
666
+ .then(clear_torch_cache)
667
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo' if allow_api else None) \
668
+ .then(**undo_user_args2, queue=stream_output, api_name='undo2' if allow_api else None) \
669
+ .then(clear_instruct, None, instruction) \
670
+ .then(clear_instruct, None, iinput) \
671
+ .then(**score_args, api_name='undo_score' if allow_api else None) \
672
+ .then(**score_args2, api_name='undo_score2' if allow_api else None)
673
+ else:
674
+ submit_event = instruction.submit(**user_args, queue=stream_output,
675
+ api_name='instruction' if allow_api else None) \
676
+ .then(**user_args2, queue=stream_output, api_name='instruction2' if allow_api else None) \
677
+ .then(clear_instruct, None, instruction) \
678
+ .then(clear_instruct, None, iinput) \
679
+ .then(**bot_args, api_name='instruction_bot' if allow_api else None) \
680
+ .then(**bot_args2, api_name='instruction_bot2' if allow_api else None) \
681
+ .then(clear_torch_cache)
682
+ submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit' if allow_api else None) \
683
+ .then(**user_args2, queue=stream_output, api_name='submit2' if allow_api else None) \
684
+ .then(clear_instruct, None, instruction) \
685
+ .then(clear_instruct, None, iinput) \
686
+ .then(**bot_args, api_name='submit_bot' if allow_api else None) \
687
+ .then(**bot_args2, api_name='submit_bot2' if allow_api else None) \
688
+ .then(clear_torch_cache)
689
+ submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry' if allow_api else None) \
690
+ .then(**user_args2, queue=stream_output, api_name='retry2' if allow_api else None) \
691
+ .then(clear_instruct, None, instruction) \
692
+ .then(clear_instruct, None, iinput) \
693
+ .then(**retry_bot_args, api_name='retry_bot' if allow_api else None) \
694
+ .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None) \
695
+ .then(clear_torch_cache)
696
+ submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo' if allow_api else None) \
697
+ .then(**undo_user_args2, queue=stream_output, api_name='undo2' if allow_api else None)
698
+
699
+ # does both models
700
+ clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
701
+ .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
702
+ # NOTE: clear of instruction/iinput for nochat has to come after score,
703
+ # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
704
+ submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
705
+ outputs=text_output_nochat,
706
+ api_name='submit_nochat' if allow_api else None) \
707
+ .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None) \
708
+ .then(clear_instruct, None, instruction_nochat) \
709
+ .then(clear_instruct, None, iinput_nochat) \
710
+ .then(clear_torch_cache)
711
+
712
+ def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
713
+ # ensure old model removed from GPU memory
714
+ if kwargs['debug']:
715
+ print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
716
+
717
+ model0 = model_state0[0]
718
+ if isinstance(model_state_old[0], str) and model0 is not None:
719
+ # best can do, move model loaded at first to CPU
720
+ model0.cpu()
721
+
722
+ if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
723
+ try:
724
+ model_state_old[0].cpu()
725
+ except Exception as e:
726
+ # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
727
+ print("Unable to put model on CPU: %s" % str(e), flush=True)
728
+ del model_state_old[0]
729
+ model_state_old[0] = None
730
+
731
+ if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
732
+ del model_state_old[1]
733
+ model_state_old[1] = None
734
+
735
+ clear_torch_cache()
736
+ if kwargs['debug']:
737
+ print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True)
738
+
739
+ if model_name is None or model_name == no_model_str:
740
+ # no-op if no model, just free memory
741
+ # no detranscribe needed for model, never go into evaluate
742
+ lora_weights = no_lora_str
743
+ return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
744
+
745
+ all_kwargs1 = all_kwargs.copy()
746
+ all_kwargs1['base_model'] = model_name.strip()
747
+ all_kwargs1['load_8bit'] = load_8bit
748
+ all_kwargs1['infer_devices'] = infer_devices
749
+ all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
750
+ model_lower = model_name.strip().lower()
751
+ if model_lower in inv_prompt_type_to_model_lower:
752
+ prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
753
+ else:
754
+ prompt_type1 = prompt_type_old
755
+
756
+ # detranscribe
757
+ if lora_weights == no_lora_str:
758
+ lora_weights = ''
759
+
760
+ all_kwargs1['lora_weights'] = lora_weights.strip()
761
+ model1, tokenizer1, device1 = get_model(**all_kwargs1)
762
+ clear_torch_cache()
763
+
764
+ if kwargs['debug']:
765
+ print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True)
766
+ return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
767
+
768
+ def dropdown_prompt_type_list(x):
769
+ return gr.Dropdown.update(value=x)
770
+
771
+ def chatbot_list(x, model_used_in):
772
+ return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
773
+
774
+ load_model_args = dict(fn=load_model,
775
+ inputs=[model_choice, lora_choice, model_state, prompt_type,
776
+ model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
777
+ outputs=[model_state, model_used, lora_used, prompt_type])
778
+ prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
779
+ chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
780
+ nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
781
+ if not is_public:
782
+ load_model_event = load_model_button.click(**load_model_args) \
783
+ .then(**prompt_update_args) \
784
+ .then(**chatbot_update_args) \
785
+ .then(**nochat_update_args) \
786
+ .then(clear_torch_cache)
787
+
788
+ load_model_args2 = dict(fn=load_model,
789
+ inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
790
+ model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
791
+ outputs=[model_state2, model_used2, lora_used2, prompt_type2])
792
+ prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
793
+ chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
794
+ if not is_public:
795
+ load_model_event2 = load_model_button2.click(**load_model_args2) \
796
+ .then(**prompt_update_args2) \
797
+ .then(**chatbot_update_args2) \
798
+ .then(clear_torch_cache)
799
+
800
+ def dropdown_model_list(list0, x):
801
+ new_state = [list0[0] + [x]]
802
+ new_options = [*new_state[0]]
803
+ return gr.Dropdown.update(value=x, choices=new_options), \
804
+ gr.Dropdown.update(value=x, choices=new_options), \
805
+ '', new_state
806
+
807
+ add_model_event = add_model_button.click(fn=dropdown_model_list,
808
+ inputs=[model_options_state, new_model],
809
+ outputs=[model_choice, model_choice2, new_model, model_options_state])
810
+
811
+ def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
812
+ new_state = [list0[0] + [x]]
813
+ new_options = [*new_state[0]]
814
+ # don't switch drop-down to added lora if already have model loaded
815
+ x1 = x if model_used1 == no_model_str else lora_used1
816
+ x2 = x if model_used2 == no_model_str else lora_used2
817
+ return gr.Dropdown.update(value=x1, choices=new_options), \
818
+ gr.Dropdown.update(value=x2, choices=new_options), \
819
+ '', new_state
820
+
821
+ add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
822
+ inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
823
+ lora_used2],
824
+ outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
825
+
826
+ go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None) \
827
+ .then(lambda: gr.update(visible=True), None, normal_block) \
828
+ .then(**load_model_args).then(**prompt_update_args)
829
+
830
+ def compare_textbox_fun(x):
831
+ return gr.Textbox.update(visible=x)
832
+
833
+ def compare_column_fun(x):
834
+ return gr.Column.update(visible=x)
835
+
836
+ def compare_prompt_fun(x):
837
+ return gr.Dropdown.update(visible=x)
838
+
839
+ compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2,
840
+ api_name="compare_checkbox" if allow_api else None) \
841
+ .then(compare_column_fun, compare_checkbox, col_model2) \
842
+ .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
843
+ .then(compare_textbox_fun, compare_checkbox, score_text2)
844
+ # FIXME: add score_res2 in condition, but do better
845
+
846
+ # callback for logging flagged input/output
847
+ callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points")
848
+ flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None, preprocess=False,
849
+ api_name='flag' if allow_api else None)
850
+ flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, preprocess=False,
851
+ api_name='flag_nochat' if allow_api else None)
852
+
853
+ def get_system_info():
854
+ return gr.Textbox.update(value=system_info_print())
855
+
856
+ system_event = system_btn.click(get_system_info, outputs=system_text,
857
+ api_name='system_info' if allow_api else None)
858
+
859
+ # don't pass text_output, don't want to clear output, just stop it
860
+ # FIXME: have to click once to stop output and second time to stop GPUs going
861
+ stop_btn.click(lambda: None, None, None,
862
+ cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
863
+ queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache)
864
+ demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
865
+
866
+ demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
867
+ favicon_path = "h2o-logo.svg"
868
+
869
+ scheduler = BackgroundScheduler()
870
+ scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
871
+ scheduler.start()
872
+
873
+ demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
874
+ favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
875
+ print("Started GUI", flush=True)
876
+ if kwargs['block_gradio_exit']:
877
+ demo.block_thread()
878
+
879
+
880
+ input_args_list = ['model_state']
881
+ inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
882
+ 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count']
883
+
884
+
885
+ def get_inputs_list(inputs_dict, model_lower):
886
+ """
887
+ map gradio objects in locals() to inputs for evaluate().
888
+ :param inputs_dict:
889
+ :param model_lower:
890
+ :return:
891
+ """
892
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
893
+ inputs_list = []
894
+ for k in inputs_list_names:
895
+ if k == 'kwargs':
896
+ continue
897
+ if k in input_args_list + inputs_kwargs_list:
898
+ # these are added via partial, not taken as input
899
+ continue
900
+ if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
901
+ continue
902
+ inputs_list.append(inputs_dict[k])
903
+ return inputs_list
gradio_themes.py ADDED
@@ -0,0 +1,142 @@
+from __future__ import annotations
+from gradio.themes.soft import Soft
+from gradio.themes.utils import Color, colors, sizes
+
+h2o_yellow = Color(
+    name="yellow",
+    c50="#fffef2",
+    c100="#fff9e6",
+    c200="#ffecb3",
+    c300="#ffe28c",
+    c400="#ffd659",
+    c500="#fec925",
+    c600="#e6ac00",
+    c700="#bf8f00",
+    c800="#a67c00",
+    c900="#664d00",
+    c950="#403000",
+)
+h2o_gray = Color(
+    name="gray",
+    c50="#f8f8f8",
+    c100="#e5e5e5",
+    c200="#cccccc",
+    c300="#b2b2b2",
+    c400="#999999",
+    c500="#7f7f7f",
+    c600="#666666",
+    c700="#4c4c4c",
+    c800="#333333",
+    c900="#191919",
+    c950="#0d0d0d",
+)
+
+
+class H2oTheme(Soft):
+    def __init__(
+            self,
+            *,
+            primary_hue: colors.Color | str = h2o_yellow,
+            secondary_hue: colors.Color | str = h2o_yellow,
+            neutral_hue: colors.Color | str = h2o_gray,
+            spacing_size: sizes.Size | str = sizes.spacing_md,
+            radius_size: sizes.Size | str = sizes.radius_md,
+            text_size: sizes.Size | str = sizes.text_lg,
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            spacing_size=spacing_size,
+            radius_size=radius_size,
+            text_size=text_size,
+        )
+        super().set(
+            link_text_color="#3344DD",
+            link_text_color_hover="#3344DD",
+            link_text_color_visited="#3344DD",
+            link_text_color_dark="#74abff",
+            link_text_color_hover_dark="#a3c8ff",
+            link_text_color_active_dark="#a3c8ff",
+            link_text_color_visited_dark="#74abff",
+            button_primary_text_color="*neutral_950",
+            button_primary_text_color_dark="*neutral_950",
+            button_primary_background_fill="*primary_500",
+            button_primary_background_fill_dark="*primary_500",
+            block_label_background_fill="*primary_500",
+            block_label_background_fill_dark="*primary_500",
+            block_label_text_color="*neutral_950",
+            block_label_text_color_dark="*neutral_950",
+            block_title_text_color="*neutral_950",
+            block_title_text_color_dark="*neutral_950",
+            block_background_fill_dark="*neutral_950",
+            body_background_fill="*neutral_50",
+            body_background_fill_dark="*neutral_900",
+            background_fill_primary_dark="*block_background_fill",
+            block_radius="0 0 8px 8px",
+        )
+
+
+class SoftTheme(Soft):
+    def __init__(
+            self,
+            *,
+            primary_hue: colors.Color | str = colors.indigo,
+            secondary_hue: colors.Color | str = colors.indigo,
+            neutral_hue: colors.Color | str = colors.gray,
+            spacing_size: sizes.Size | str = sizes.spacing_md,
+            radius_size: sizes.Size | str = sizes.radius_md,
+            text_size: sizes.Size | str = sizes.text_md,
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            spacing_size=spacing_size,
+            radius_size=radius_size,
+            text_size=text_size,
+        )
+
+
+h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
+           ' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
+           '#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
+           'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
+           '47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
+           '82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
+           '.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
+           '/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
+           '76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
+           ',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
+           '85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
+           '69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
+           '62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
+           '62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
+           '12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
+           ' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
+           '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
+
+
+def get_h2o_title(title):
+    return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
+    <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
+    <h1 style="line-height:60px">{title}</h1>
+    </div>
+    <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
+    <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/h2o-qr.png></img>
+    </div>
+    """
+
+
+def get_simple_title(title):
+    return f"""<h1 align="center"> {title}</h1>"""
+
+
+def get_dark_js():
+    return """() => {
+    if (document.querySelectorAll('.dark').length) {
+        document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
+    } else {
+        document.querySelector('body').classList.add('dark');
+    }
+    }"""
utils.py CHANGED
@@ -284,13 +284,14 @@ class KThread(threading.Thread):
             print(thread.name, flush=True)
 
     @staticmethod
-    def kill_threads(name):
+    def kill_threads(name, debug=False):
         for thread in threading.enumerate():
             if name in thread.name:
-                print(thread)
-                print("Trying to kill %s" % thread.ident)
+                if debug:
+                    print("Trying to kill %s %s" % (thread.ident, thread), flush=True)
                 thread.kill()
-                print(thread)
+                if debug:
+                    print(thread, flush=True)
 
 
 def wrapped_partial(func, *args, **kwargs):
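
For orientation, here is a minimal, self-contained sketch of the call pattern that app.py uses with the updated helper: kill any stale generator thread by name, then start a fresh KThread. The kill() mechanism shown (raising SystemExit via ctypes), the daemon flag, and the toy generate_with_exceptions target are illustrative assumptions only; KThread's full implementation is not part of this diff.

import ctypes
import threading
import time


class KThread(threading.Thread):
    # Assumed kill() implementation for this sketch: asynchronously raise
    # SystemExit inside the target thread. utils.py's real body is not shown here.
    def kill(self):
        if self.ident is None or not self.is_alive():
            return
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(self.ident),
                                                   ctypes.py_object(SystemExit))

    @staticmethod
    def kill_threads(name, debug=False):
        # Same logic as the updated helper above: kill every live thread whose
        # name contains the given name; print diagnostics only when debug=True.
        for thread in threading.enumerate():
            if name in thread.name:
                if debug:
                    print("Trying to kill %s %s" % (thread.ident, thread), flush=True)
                thread.kill()
                if debug:
                    print(thread, flush=True)


def generate_with_exceptions():
    # Stand-in for the real generation wrapper used in app.py.
    while True:
        time.sleep(0.1)


# Call pattern as in app.py: clear stale generator threads, then start a new one.
KThread.kill_threads(generate_with_exceptions.__name__, debug=True)
thread = KThread(target=generate_with_exceptions,
                 name=generate_with_exceptions.__name__, daemon=True)
thread.start()
time.sleep(0.5)
KThread.kill_threads(generate_with_exceptions.__name__, debug=True)
thread.join(timeout=2)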