Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
a0e2e84
1
Parent(s):
dc1d7fe
Update with h2oGPT hash f121dcf534b7c7da96e22fdfb00a7436503f167e
Browse files- app.py +72 -16
- finetune.py +1 -0
- utils.py +64 -0
app.py
CHANGED
@@ -2,10 +2,11 @@ import functools
|
|
2 |
import inspect
|
3 |
import sys
|
4 |
import os
|
|
|
5 |
import traceback
|
6 |
import typing
|
7 |
-
|
8 |
-
from utils import set_seed, flatten_list, clear_torch_cache, system_info_print
|
9 |
|
10 |
SEED = 1236
|
11 |
set_seed(SEED)
|
@@ -27,10 +28,11 @@ from finetune import get_loaders, example_data_points, generate_prompt, get_gith
|
|
27 |
human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
|
28 |
from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
|
29 |
|
30 |
-
is_hf = os.getenv("HUGGINGFACE_SPACES")
|
31 |
-
is_gpth2oai = os.getenv("GPT_H2O_AI")
|
32 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
33 |
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
|
|
34 |
|
35 |
|
36 |
def main(
|
@@ -58,6 +60,7 @@ def main(
|
|
58 |
|
59 |
llama_type: bool = None,
|
60 |
debug: bool = False,
|
|
|
61 |
share: bool = True,
|
62 |
local_files_only: bool = False,
|
63 |
resume_download: bool = True,
|
@@ -111,6 +114,7 @@ def main(
|
|
111 |
if is_hf:
|
112 |
# must override share if in spaces
|
113 |
share = False
|
|
|
114 |
|
115 |
# get defaults
|
116 |
model_lower = base_model.lower()
|
@@ -178,7 +182,7 @@ def main(
|
|
178 |
if not eval_sharegpt_as_output:
|
179 |
model, tokenizer, device = get_model(**locals())
|
180 |
model_state = [model, tokenizer, device, base_model]
|
181 |
-
fun = partial(evaluate, model_state, debug=debug, chat=chat)
|
182 |
else:
|
183 |
assert eval_sharegpt_prompts_only > 0
|
184 |
|
@@ -542,8 +546,9 @@ def go_gradio(**kwargs):
|
|
542 |
if is_public:
|
543 |
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
|
544 |
if kwargs['load_8bit']:
|
545 |
-
description += """<i><li> Model is loaded in 8-bit and
|
546 |
-
description += """<i><li>
|
|
|
547 |
|
548 |
if kwargs['verbose']:
|
549 |
task_info_md = f"""
|
@@ -551,14 +556,43 @@ def go_gradio(**kwargs):
|
|
551 |
else:
|
552 |
task_info_md = ''
|
553 |
|
554 |
-
css_code = """footer {visibility: hidden}
|
555 |
-
body{background-
|
|
|
556 |
|
557 |
-
from gradio.themes.utils import colors, fonts, sizes
|
558 |
if kwargs['h2ocolors']:
|
559 |
-
|
560 |
-
|
561 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
spacing_size=sizes.spacing_md,
|
563 |
radius_size=sizes.radius_md,
|
564 |
text_size=sizes.text_md,
|
@@ -635,7 +669,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
|
|
635 |
|
636 |
# go button visible if
|
637 |
base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
|
638 |
-
go_btn = gr.Button(value="
|
639 |
normal_block = gr.Row(visible=not base_wanted)
|
640 |
with normal_block:
|
641 |
with gr.Tabs():
|
@@ -770,12 +804,27 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
|
|
770 |
add_model_button = gr.Button("Add new model name")
|
771 |
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
772 |
with gr.TabItem("System"):
|
773 |
-
|
|
|
|
|
|
|
774 |
with gr.Column():
|
775 |
system_text = gr.Textbox(label='System Info')
|
776 |
system_btn = gr.Button(value='Get System Info')
|
777 |
|
|
|
|
|
778 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
779 |
inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
|
780 |
from functools import partial
|
781 |
all_kwargs = kwargs.copy()
|
@@ -1094,7 +1143,7 @@ body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en
|
|
1094 |
|
1095 |
|
1096 |
input_args_list = ['model_state']
|
1097 |
-
inputs_kwargs_list = ['debug', 'chat', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
|
1098 |
|
1099 |
|
1100 |
def get_inputs_list(inputs_dict, model_lower):
|
@@ -1157,6 +1206,7 @@ def evaluate(
|
|
1157 |
src_lang=None,
|
1158 |
tgt_lang=None,
|
1159 |
debug=False,
|
|
|
1160 |
chat=False,
|
1161 |
hard_stop_list=None,
|
1162 |
sanitize_bot_response=True,
|
@@ -1369,6 +1419,8 @@ def evaluate(
|
|
1369 |
raise StopIteration
|
1370 |
yield prompter.get_response(decoded_output, prompt=inputs_decoded,
|
1371 |
sanitize_bot_response=sanitize_bot_response)
|
|
|
|
|
1372 |
return
|
1373 |
else:
|
1374 |
outputs = model.generate(**gen_kwargs)
|
@@ -1585,5 +1637,9 @@ if __name__ == "__main__":
|
|
1585 |
|
1586 |
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
1587 |
|
|
|
|
|
|
|
|
|
1588 |
""", flush=True)
|
1589 |
fire.Fire(main)
|
|
|
2 |
import inspect
|
3 |
import sys
|
4 |
import os
|
5 |
+
import time
|
6 |
import traceback
|
7 |
import typing
|
8 |
+
import filelock
|
9 |
+
from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
|
10 |
|
11 |
SEED = 1236
|
12 |
set_seed(SEED)
|
|
|
28 |
human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
|
29 |
from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
|
30 |
|
31 |
+
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
32 |
+
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
33 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
34 |
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
35 |
+
admin_pass = os.getenv("ADMIN_PASS")
|
36 |
|
37 |
|
38 |
def main(
|
|
|
60 |
|
61 |
llama_type: bool = None,
|
62 |
debug: bool = False,
|
63 |
+
save_path: str = None,
|
64 |
share: bool = True,
|
65 |
local_files_only: bool = False,
|
66 |
resume_download: bool = True,
|
|
|
114 |
if is_hf:
|
115 |
# must override share if in spaces
|
116 |
share = False
|
117 |
+
save_path = os.getenv('SAVE_PATH')
|
118 |
|
119 |
# get defaults
|
120 |
model_lower = base_model.lower()
|
|
|
182 |
if not eval_sharegpt_as_output:
|
183 |
model, tokenizer, device = get_model(**locals())
|
184 |
model_state = [model, tokenizer, device, base_model]
|
185 |
+
fun = partial(evaluate, model_state, debug=debug, chat=chat, save_path=save_path)
|
186 |
else:
|
187 |
assert eval_sharegpt_prompts_only > 0
|
188 |
|
|
|
546 |
if is_public:
|
547 |
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The data used to train this model include The Pile and other sources. These may contain objectionable content, so the model may reproduce that material. Use application and responses at own risk.</i></li>"""
|
548 |
if kwargs['load_8bit']:
|
549 |
+
description += """<i><li> Model is loaded in 8-bit, model loading-unloading is disabled, and other limitations exist in order to fit on GPUs with lower amounts of VRAM, so UX can be worse than non-hosted version.</i></li>"""
|
550 |
+
description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
|
551 |
+
description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
|
552 |
|
553 |
if kwargs['verbose']:
|
554 |
task_info_md = f"""
|
|
|
556 |
else:
|
557 |
task_info_md = ''
|
558 |
|
559 |
+
css_code = """footer {visibility: hidden;}
|
560 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
561 |
+
body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
562 |
|
563 |
+
from gradio.themes.utils import Color, colors, fonts, sizes
|
564 |
if kwargs['h2ocolors']:
|
565 |
+
h2o_yellow = Color(
|
566 |
+
name="yellow",
|
567 |
+
c50="#fffef2",
|
568 |
+
c100="#fff9e6",
|
569 |
+
c200="#ffecb3",
|
570 |
+
c300="#ffe28c",
|
571 |
+
c400="#ffd659",
|
572 |
+
c500="#fec925",
|
573 |
+
c600="#e6ac00",
|
574 |
+
c700="#bf8f00",
|
575 |
+
c800="#a67c00",
|
576 |
+
c900="#664d00",
|
577 |
+
c950="#403000",
|
578 |
+
)
|
579 |
+
h2o_gray = Color(
|
580 |
+
name="gray",
|
581 |
+
c50="#f2f2f2",
|
582 |
+
c100="#e5e5e5",
|
583 |
+
c200="#cccccc",
|
584 |
+
c300="#b2b2b2",
|
585 |
+
c400="#999999",
|
586 |
+
c500="#7f7f7f",
|
587 |
+
c600="#666666",
|
588 |
+
c700="#4c4c4c",
|
589 |
+
c800="#333333",
|
590 |
+
c900="#191919",
|
591 |
+
c950="#0d0d0d",
|
592 |
+
)
|
593 |
+
colors_dict = dict(primary_hue=h2o_yellow,
|
594 |
+
secondary_hue=h2o_yellow,
|
595 |
+
neutral_hue=h2o_gray,
|
596 |
spacing_size=sizes.spacing_md,
|
597 |
radius_size=sizes.radius_md,
|
598 |
text_size=sizes.text_md,
|
|
|
669 |
|
670 |
# go button visible if
|
671 |
base_wanted = bool(kwargs['base_model']) and kwargs['login_mode_if_model0']
|
672 |
+
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
|
673 |
normal_block = gr.Row(visible=not base_wanted)
|
674 |
with normal_block:
|
675 |
with gr.Tabs():
|
|
|
804 |
add_model_button = gr.Button("Add new model name")
|
805 |
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
806 |
with gr.TabItem("System"):
|
807 |
+
system_row = gr.Row(visible=not is_public)
|
808 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
809 |
+
admin_btn = gr.Button(value="admin", visible=is_public)
|
810 |
+
with system_row:
|
811 |
with gr.Column():
|
812 |
system_text = gr.Textbox(label='System Info')
|
813 |
system_btn = gr.Button(value='Get System Info')
|
814 |
|
815 |
+
zip_btn = gr.Button("Zip")
|
816 |
+
file_output = gr.File()
|
817 |
|
818 |
+
# Get flagged data
|
819 |
+
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_path']])
|
820 |
+
zip_btn.click(zip_data1, inputs=None, outputs=file_output)
|
821 |
+
|
822 |
+
def check_admin_pass(x):
|
823 |
+
return gr.update(visible=x == admin_pass)
|
824 |
+
|
825 |
+
admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)
|
826 |
+
|
827 |
+
# Get inputs to evaluate()
|
828 |
inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
|
829 |
from functools import partial
|
830 |
all_kwargs = kwargs.copy()
|
|
|
1143 |
|
1144 |
|
1145 |
input_args_list = ['model_state']
|
1146 |
+
inputs_kwargs_list = ['debug', 'chat', 'save_path', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
|
1147 |
|
1148 |
|
1149 |
def get_inputs_list(inputs_dict, model_lower):
|
|
|
1206 |
src_lang=None,
|
1207 |
tgt_lang=None,
|
1208 |
debug=False,
|
1209 |
+
save_path=None,
|
1210 |
chat=False,
|
1211 |
hard_stop_list=None,
|
1212 |
sanitize_bot_response=True,
|
|
|
1419 |
raise StopIteration
|
1420 |
yield prompter.get_response(decoded_output, prompt=inputs_decoded,
|
1421 |
sanitize_bot_response=sanitize_bot_response)
|
1422 |
+
if save_path:
|
1423 |
+
save_generate_output(output=decoded_output, base_model=base_model, json_file_path=save_path)
|
1424 |
return
|
1425 |
else:
|
1426 |
outputs = model.generate(**gen_kwargs)
|
|
|
1637 |
|
1638 |
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
1639 |
|
1640 |
+
must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
|
1641 |
+
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
1642 |
+
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
|
1643 |
+
|
1644 |
""", flush=True)
|
1645 |
fire.Fire(main)
|
finetune.py
CHANGED
@@ -73,6 +73,7 @@ prompt_type_to_model_name = {
|
|
73 |
'decapoda-research/llama-7b-hf',
|
74 |
'decapoda-research/llama-13b-hf',
|
75 |
'decapoda-research/llama-30b-hf',
|
|
|
76 |
'facebook/mbart-large-50-many-to-many-mmt',
|
77 |
'philschmid/bart-large-cnn-samsum',
|
78 |
'philschmid/flan-t5-base-samsum',
|
|
|
73 |
'decapoda-research/llama-7b-hf',
|
74 |
'decapoda-research/llama-13b-hf',
|
75 |
'decapoda-research/llama-30b-hf',
|
76 |
+
'decapoda-research/llama-65b-hf',
|
77 |
'facebook/mbart-large-50-many-to-many-mmt',
|
78 |
'philschmid/bart-large-cnn-samsum',
|
79 |
'philschmid/flan-t5-base-samsum',
|
utils.py
CHANGED
@@ -1,7 +1,13 @@
|
|
|
|
1 |
import os
|
2 |
import gc
|
3 |
import random
|
|
|
4 |
import time
|
|
|
|
|
|
|
|
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
7 |
import torch
|
@@ -87,3 +93,61 @@ def system_info_print():
|
|
87 |
return df.to_markdown()
|
88 |
except Exception as e:
|
89 |
return "Error: %s" % str(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
import os
|
3 |
import gc
|
4 |
import random
|
5 |
+
import shutil
|
6 |
import time
|
7 |
+
import traceback
|
8 |
+
import zipfile
|
9 |
+
|
10 |
+
import filelock
|
11 |
import numpy as np
|
12 |
import pandas as pd
|
13 |
import torch
|
|
|
93 |
return df.to_markdown()
|
94 |
except Exception as e:
|
95 |
return "Error: %s" % str(e)
|
96 |
+
|
97 |
+
|
98 |
+
def zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
|
99 |
+
try:
|
100 |
+
return _zip_data(zip_path=zip_path, base_dir=base_dir, root_dirs=root_dirs)
|
101 |
+
except Exception as e:
|
102 |
+
traceback.print_exc()
|
103 |
+
print('Exception in zipping: %s' % str(e))
|
104 |
+
|
105 |
+
|
106 |
+
def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
|
107 |
+
assert root_dirs is not None
|
108 |
+
with zipfile.ZipFile(zip_path, "w") as expt_zip:
|
109 |
+
for root_dir in root_dirs:
|
110 |
+
if root_dir is None:
|
111 |
+
continue
|
112 |
+
for root, d, files in os.walk(root_dir):
|
113 |
+
for file in files:
|
114 |
+
file_to_archive = os.path.join(root, file)
|
115 |
+
assert os.path.exists(file_to_archive)
|
116 |
+
path_to_archive = os.path.relpath(file_to_archive, base_dir)
|
117 |
+
expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
|
118 |
+
return "data.zip"
|
119 |
+
|
120 |
+
|
121 |
+
def save_generate_output(output=None, base_model=None, json_file_path=None):
|
122 |
+
try:
|
123 |
+
return _save_generate_output(output=output, base_model=base_model, json_file_path=json_file_path)
|
124 |
+
except Exception as e:
|
125 |
+
traceback.print_exc()
|
126 |
+
print('Exception in saving: %s' % str(e))
|
127 |
+
|
128 |
+
|
129 |
+
def _save_generate_output(output=None, base_model=None, json_file_path=None):
|
130 |
+
"""
|
131 |
+
Save conversation to .json, row by row
|
132 |
+
Appends if file exists
|
133 |
+
"""
|
134 |
+
assert isinstance(json_file_path, str), "must provide save_path"
|
135 |
+
as_file = os.path.normpath(json_file_path)
|
136 |
+
if os.path.isfile(as_file):
|
137 |
+
# protection if had file there before
|
138 |
+
os.remove(as_file)
|
139 |
+
os.makedirs(json_file_path, exist_ok=True)
|
140 |
+
json_file_file = os.path.join(json_file_path, 'save.json')
|
141 |
+
import json
|
142 |
+
if output[-10:] == '\n\n<human>:':
|
143 |
+
# remove trailing <human>:
|
144 |
+
output = output[:-10]
|
145 |
+
with filelock.FileLock("save_path.lock"):
|
146 |
+
# lock logging in case have concurrency
|
147 |
+
with open(json_file_file, "a") as f:
|
148 |
+
# just add [ at start, and ] at end, and have proper JSON dataset
|
149 |
+
f.write(
|
150 |
+
" " + json.dumps(
|
151 |
+
dict(text=output, time=time.ctime(), base_model=base_model)
|
152 |
+
) + ",\n"
|
153 |
+
)
|