update config dict
app.py CHANGED
@@ -30,30 +30,34 @@ from config import *
 # - a list of available models from the config file
 # - a list of available schedulers from the config file
 # - a dict that contains code to for reproduction
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+config, devices, model_configs, scheduler_configs, code = get_inital_config()
+
+models = list(model_configs.keys())
+schedulers = list(scheduler_configs.keys())
+
+device = config["device"]
+model = config["model"]
+scheduler = config["scheduler"]
+variant = config["variant"]
+allow_tensorfloat32 = config["allow_tensorfloat32"]
+use_safetensors = config["use_safetensors"]
+data_type = config["data_type"]
+safety_checker = config["safety_checker"]
+requires_safety_checker = config["requires_safety_checker"]
+manual_seed = config["manual_seed"]
+inference_steps = config["inference_steps"]
+guidance_scale = config["guidance_scale"]
+prompt = config["prompt"]
+negative_prompt = config["negative_prompt"]
 
 config_history = []
 
 def device_change(device):
 
     code[code_pos_device] = f'''device = "{device}"'''
+    config['device'] = device
 
-    return get_sorted_code()
+    return get_sorted_code(), str(config)
 
 def models_change(model, scheduler):
 
@@ -74,16 +78,19 @@ def models_change(model, scheduler):
                 use_safetensors=use_safetensors,
                 torch_dtype=data_type,
                 variant=variant).to(device)'''
+    config['model'] = model
 
     safety_checker_change(safety_checker)
     requires_safety_checker_change(requires_safety_checker)
 
-    return get_sorted_code(), use_safetensors, scheduler
+    return get_sorted_code(), use_safetensors, scheduler, str(config)
 
 def data_type_change(selected_data_type):
 
+    config['data_type'] = data_type
+
     get_data_type(selected_data_type)
-    return get_sorted_code()
+    return get_sorted_code(), str(config)
 
 def get_data_type(selected_data_type):
 
@@ -98,9 +105,11 @@ def get_data_type(selected_data_type):
 
 def tensorfloat32_change(allow_tensorfloat32):
 
+    config['allow_tensorfloat32'] = allow_tensorfloat32
+
     get_tensorfloat32(allow_tensorfloat32)
 
-    return get_sorted_code()
+    return get_sorted_code(), str(config)
 
 def get_tensorfloat32(allow_tensorfloat32):
 
@@ -110,27 +119,33 @@ def get_tensorfloat32(allow_tensorfloat32):
 
 def variant_change(variant):
 
+    config['variant'] = variant
+
     if str(variant) == 'None':
         code[code_pos_variant] = f'variant = {variant}'
     else:
         code[code_pos_variant] = f'variant = "{variant}"'
 
-    return get_sorted_code()
+    return get_sorted_code(), str(config)
 
 def safety_checker_change(safety_checker):
 
+    config['safety_checker'] = safety_checker
+
     if not safety_checker or str(safety_checker).lower == 'false':
         code[code_pos_safety_checker] = f'pipeline.safety_checker = None'
     else:
         code[code_pos_safety_checker] = ''
 
-    return get_sorted_code()
+    return get_sorted_code(), str(config)
 
 def requires_safety_checker_change(requires_safety_checker):
 
     code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
 
-
+    config['requires_safety_checker'] = requires_safety_checker
+
+    return get_sorted_code(), str(config)
 
 def schedulers_change(scheduler):
 
@@ -138,11 +153,13 @@ def schedulers_change(scheduler):
 
         code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
 
-
+        config['scheduler'] = scheduler
+
+        return get_sorted_code(), scheduler_configs[scheduler], str(config)
 
     else:
 
-        return get_sorted_code(), ''
+        return get_sorted_code(), '', str(config)
 
 def get_scheduler(scheduler, config):
 
@@ -202,7 +219,7 @@ def run_inference(model,
 
     pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
 
-
+    manual_seed = int(manual_seed)
    if manual_seed < 0 or manual_seed is None or manual_seed == '':
        generator = torch.Generator(device)
    else:
@@ -217,8 +234,8 @@ def run_inference(model,
                         num_inference_steps=int(inference_steps),
                         guidance_scale=float(guidance_scale)).images[0]
 
-
-        return
+        config_history.append(config)
+        return image, dict_list_to_markdown_table(config_history)
 
     else:
 
@@ -254,7 +271,7 @@ with gr.Blocks() as demo:
     gr.Markdown("### Device specific settings")
     with gr.Row():
         in_devices = gr.Dropdown(label="Device:", value=device, choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
-        in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`
+        in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
         in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
         in_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
 
@@ -294,18 +311,18 @@ with gr.Blocks() as demo:
         out_image = gr.Image()
         out_code = gr.Code(get_sorted_code(), label="Code")
     with gr.Row():
-        out_current_config = gr.Code(value=str(
+        out_current_config = gr.Code(value=str(config), label="Current config")
     with gr.Row():
         out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
 
-    in_devices.change(device_change, inputs=[in_devices], outputs=[out_code])
-    in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code])
-    in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code])
-    in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code])
-    in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers])
-    in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code])
-    in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code])
-    in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description])
+    in_devices.change(device_change, inputs=[in_devices], outputs=[out_code, out_current_config])
+    in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code, out_current_config])
+    in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code, out_current_config])
+    in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code, out_current_config])
+    in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers, out_current_config])
+    in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code, out_current_config])
+    in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code, out_current_config])
+    in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description, out_current_config])
     btn_start_pipeline.click(run_inference, inputs=[
         in_models,
         in_devices,
@@ -320,7 +337,9 @@ with gr.Blocks() as demo:
         in_inference_steps,
         in_manual_seed,
         in_guidance_scale
-    ], outputs=[
+    ], outputs=[
+        out_image,
+        out_config_history])
 
     demo.load(fn=init_config, inputs=out_current_config,
               outputs=[
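Taken together, the app.py changes settle on one pattern: a single module-level config dict is the source of truth for the UI state. Every change handler writes its input into the dict and returns str(config) as an extra output, which lands in the new out_current_config code box. (One apparent slip: data_type_change stores the module-level data_type instead of its selected_data_type argument.) Below is a minimal, self-contained sketch of that pattern; the component names, choices, and the returned code string are illustrative, not taken from app.py:

```python
# Minimal sketch of the config-dict pattern (hypothetical names).
import gradio as gr

config = {"device": "cpu", "variant": None}  # single source of truth for UI state

def device_change(device):
    # Write the new value into the dict, then re-render both output panes.
    config["device"] = device
    return f'device = "{device}"', str(config)

with gr.Blocks() as demo:
    in_devices = gr.Dropdown(label="Device:", choices=["cpu", "cuda", "mps"], value="cpu")
    out_code = gr.Code(label="Code")
    out_current_config = gr.Code(value=str(config), label="Current config")
    # Every input is wired to [out_code, out_current_config], as in the diff above.
    in_devices.change(device_change, inputs=[in_devices], outputs=[out_code, out_current_config])

demo.launch()
```

The appeal of the pattern is that str(config) can be re-rendered after any event without tracking which widget actually changed.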
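One detail in the run_inference hunk deserves a flag: the added manual_seed = int(manual_seed) executes before the checks manual_seed is None and manual_seed == '', so those two checks can never be true afterwards, and int(None) or int('') would raise before they run. A sketch of the presumably intended order, validating before converting; make_generator is a hypothetical helper, not part of the commit:

```python
import torch

def make_generator(device: str, manual_seed) -> torch.Generator:
    # Treat None, empty input, and negative values as "no fixed seed";
    # validate *before* converting, since int(None) / int('') raise.
    if manual_seed is None or str(manual_seed).strip() == '' or int(manual_seed) < 0:
        return torch.Generator(device)           # unseeded generator
    generator = torch.Generator(device)
    generator.manual_seed(int(manual_seed))      # reproducible runs
    return generator
```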
config.py CHANGED
@@ -3,6 +3,7 @@ import base64
 import json
 import torch
 
+code = {}
 code_pos_device = '001_code'
 code_pos_data_type = '002_data_type'
 code_pos_tf32 = '003_tf32'
@@ -35,16 +36,8 @@ def get_inital_config():
 
     appConfig = load_app_config()
 
-    # default model is None
     model_configs = appConfig.get("models", {})
-    # default model is None
-    models = list(model_configs.keys())
-    model = None
-
-    # default scheduler is None
     scheduler_configs = appConfig.get("schedulers", {})
-    schedulers = list(scheduler_configs.keys())
-    scheduler = None
 
     # default device
     devices = appConfig.get("devices", [])
@@ -65,7 +58,7 @@ def get_inital_config():
         "model": None,
         "scheduler": None,
         "variant": None,
-        "
+        "allow_tensorfloat32": allow_tensorfloat32,
         "use_safetensors": False,
         "data_type": data_type,
         "safety_checker": False,
@@ -78,21 +71,19 @@ def get_inital_config():
     }
 
     # code output order
-    code = {}
-
     code[code_pos_device] = f'device = "{device}"'
-    code[code_pos_variant] = f'variant = {initial_config[
-    code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config[
+    code[code_pos_variant] = f'variant = {initial_config["variant"]}'
+    code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config["allow_tensorfloat32"]}'
     code[code_pos_data_type] = 'data_type = torch.bfloat16'
     code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
     code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
-    code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config[
+    code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config["requires_safety_checker"]}'
     code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
     code[code_pos_generator] = f'generator = torch.Generator("{device}")'
-    code[code_pos_prompt] = f'prompt = "{initial_config[
-    code[code_pos_negative_prompt] = f'negative_prompt = "{initial_config[
-    code[code_pos_inference_steps] = f'inference_steps = {initial_config[
-    code[code_pos_guidance_scale] = f'guidance_scale = {initial_config[
+    code[code_pos_prompt] = f'prompt = "{initial_config["prompt"]}"'
+    code[code_pos_negative_prompt] = f'negative_prompt = "{initial_config["negative_prompt"]}"'
+    code[code_pos_inference_steps] = f'inference_steps = {initial_config["inference_steps"]}'
+    code[code_pos_guidance_scale] = f'guidance_scale = {initial_config["guidance_scale"]}'
     code[code_pos_run_inference] = f'''image = pipeline(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -100,7 +91,7 @@ def get_inital_config():
         num_inference_steps=inference_steps,
         guidance_scale=guidance_scale).images[0]'''
 
-    return initial_config, devices,
+    return initial_config, devices, model_configs, scheduler_configs, code
 
 def init_config(request: gr.Request, inital_config):
 
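Finally, a pre-existing line this commit leaves untouched: str(safety_checker).lower == 'false' compares the bound lower method itself to a string, which is always False, so the condition silently reduces to not safety_checker. If string values like 'False' can actually arrive from the UI, the comparison presumably wants parentheses; is_falsy below is a hypothetical helper, not code from the repo:

```python
def is_falsy(value) -> bool:
    # Note the parentheses: str.lower must be *called*;
    # `str(value).lower == 'false'` compares a method object and is always False.
    return not value or str(value).lower() == 'false'
```

safety_checker_change could then branch on is_falsy(safety_checker) instead of the current expression.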