n42 commited on
Commit
a6a747f
·
1 Parent(s): a5097b6

update config dict

Browse files
Files changed (2) hide show
  1. app.py +58 -39
  2. config.py +10 -19
app.py CHANGED
@@ -30,30 +30,34 @@ from config import *
30
  # - a list of available models from the config file
31
  # - a list of available schedulers from the config file
32
  # - a dict that contains code to for reproduction
33
- initial_config, devices, models, schedulers, code = get_inital_config()
34
-
35
- device = initial_config["device"]
36
- model = initial_config["model"]
37
- scheduler = initial_config["scheduler"]
38
- variant = initial_config["variant"]
39
- allow_tensorfloat32 = initial_config["allow_tensorfloat32"]
40
- use_safetensors = initial_config["use_safetensors"]
41
- data_type = initial_config["data_type"]
42
- safety_checker = initial_config["safety_checker"]
43
- requires_safety_checker = initial_config["requires_safety_checker"]
44
- manual_seed = initial_config["manual_seed"]
45
- inference_steps = initial_config["inference_steps"]
46
- guidance_scale = initial_config["guidance_scale"]
47
- prompt = initial_config["prompt"]
48
- negative_prompt = initial_config["negative_prompt"]
 
 
 
49
 
50
  config_history = []
51
 
52
  def device_change(device):
53
 
54
  code[code_pos_device] = f'''device = "{device}"'''
 
55
 
56
- return get_sorted_code()
57
 
58
  def models_change(model, scheduler):
59
 
@@ -74,16 +78,19 @@ def models_change(model, scheduler):
74
  use_safetensors=use_safetensors,
75
  torch_dtype=data_type,
76
  variant=variant).to(device)'''
 
77
 
78
  safety_checker_change(safety_checker)
79
  requires_safety_checker_change(requires_safety_checker)
80
 
81
- return get_sorted_code(), use_safetensors, scheduler
82
 
83
  def data_type_change(selected_data_type):
84
 
 
 
85
  get_data_type(selected_data_type)
86
- return get_sorted_code()
87
 
88
  def get_data_type(selected_data_type):
89
 
@@ -98,9 +105,11 @@ def get_data_type(selected_data_type):
98
 
99
  def tensorfloat32_change(allow_tensorfloat32):
100
 
 
 
101
  get_tensorfloat32(allow_tensorfloat32)
102
 
103
- return get_sorted_code()
104
 
105
  def get_tensorfloat32(allow_tensorfloat32):
106
 
@@ -110,27 +119,33 @@ def get_tensorfloat32(allow_tensorfloat32):
110
 
111
  def variant_change(variant):
112
 
 
 
113
  if str(variant) == 'None':
114
  code[code_pos_variant] = f'variant = {variant}'
115
  else:
116
  code[code_pos_variant] = f'variant = "{variant}"'
117
 
118
- return get_sorted_code()
119
 
120
  def safety_checker_change(safety_checker):
121
 
 
 
122
  if not safety_checker or str(safety_checker).lower == 'false':
123
  code[code_pos_safety_checker] = f'pipeline.safety_checker = None'
124
  else:
125
  code[code_pos_safety_checker] = ''
126
 
127
- return get_sorted_code()
128
 
129
  def requires_safety_checker_change(requires_safety_checker):
130
 
131
  code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
132
 
133
- return get_sorted_code()
 
 
134
 
135
  def schedulers_change(scheduler):
136
 
@@ -138,11 +153,13 @@ def schedulers_change(scheduler):
138
 
139
  code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
140
 
141
- return get_sorted_code(), scheduler_configs[scheduler]
 
 
142
 
143
  else:
144
 
145
- return get_sorted_code(), ''
146
 
147
  def get_scheduler(scheduler, config):
148
 
@@ -202,7 +219,7 @@ def run_inference(model,
202
 
203
  pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
204
 
205
-
206
  if manual_seed < 0 or manual_seed is None or manual_seed == '':
207
  generator = torch.Generator(device)
208
  else:
@@ -217,8 +234,8 @@ def run_inference(model,
217
  num_inference_steps=int(inference_steps),
218
  guidance_scale=float(guidance_scale)).images[0]
219
 
220
-
221
- return "Done.", image
222
 
223
  else:
224
 
@@ -254,7 +271,7 @@ with gr.Blocks() as demo:
254
  gr.Markdown("### Device specific settings")
255
  with gr.Row():
256
  in_devices = gr.Dropdown(label="Device:", value=device, choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
257
- in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`blfoat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
258
  in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
259
  in_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
260
 
@@ -294,18 +311,18 @@ with gr.Blocks() as demo:
294
  out_image = gr.Image()
295
  out_code = gr.Code(get_sorted_code(), label="Code")
296
  with gr.Row():
297
- out_current_config = gr.Code(value=str(initial_config), label="Current config")
298
  with gr.Row():
299
  out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
300
 
301
- in_devices.change(device_change, inputs=[in_devices], outputs=[out_code])
302
- in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code])
303
- in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code])
304
- in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code])
305
- in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers])
306
- in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code])
307
- in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code])
308
- in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description])
309
  btn_start_pipeline.click(run_inference, inputs=[
310
  in_models,
311
  in_devices,
@@ -320,7 +337,9 @@ with gr.Blocks() as demo:
320
  in_inference_steps,
321
  in_manual_seed,
322
  in_guidance_scale
323
- ], outputs=[out_image])
 
 
324
 
325
  demo.load(fn=init_config, inputs=out_current_config,
326
  outputs=[
 
30
  # - a list of available models from the config file
31
  # - a list of available schedulers from the config file
32
  # - a dict that contains code to for reproduction
33
+ config, devices, model_configs, scheduler_configs, code = get_inital_config()
34
+
35
+ models = list(model_configs.keys())
36
+ schedulers = list(scheduler_configs.keys())
37
+
38
+ device = config["device"]
39
+ model = config["model"]
40
+ scheduler = config["scheduler"]
41
+ variant = config["variant"]
42
+ allow_tensorfloat32 = config["allow_tensorfloat32"]
43
+ use_safetensors = config["use_safetensors"]
44
+ data_type = config["data_type"]
45
+ safety_checker = config["safety_checker"]
46
+ requires_safety_checker = config["requires_safety_checker"]
47
+ manual_seed = config["manual_seed"]
48
+ inference_steps = config["inference_steps"]
49
+ guidance_scale = config["guidance_scale"]
50
+ prompt = config["prompt"]
51
+ negative_prompt = config["negative_prompt"]
52
 
53
  config_history = []
54
 
55
  def device_change(device):
56
 
57
  code[code_pos_device] = f'''device = "{device}"'''
58
+ config['device'] = device
59
 
60
+ return get_sorted_code(), str(config)
61
 
62
  def models_change(model, scheduler):
63
 
 
78
  use_safetensors=use_safetensors,
79
  torch_dtype=data_type,
80
  variant=variant).to(device)'''
81
+ config['model'] = model
82
 
83
  safety_checker_change(safety_checker)
84
  requires_safety_checker_change(requires_safety_checker)
85
 
86
+ return get_sorted_code(), use_safetensors, scheduler, str(config)
87
 
88
  def data_type_change(selected_data_type):
89
 
90
+ config['data_type'] = data_type
91
+
92
  get_data_type(selected_data_type)
93
+ return get_sorted_code(), str(config)
94
 
95
  def get_data_type(selected_data_type):
96
 
 
105
 
106
  def tensorfloat32_change(allow_tensorfloat32):
107
 
108
+ config['allow_tensorfloat32'] = allow_tensorfloat32
109
+
110
  get_tensorfloat32(allow_tensorfloat32)
111
 
112
+ return get_sorted_code(), str(config)
113
 
114
  def get_tensorfloat32(allow_tensorfloat32):
115
 
 
119
 
120
  def variant_change(variant):
121
 
122
+ config['variant'] = variant
123
+
124
  if str(variant) == 'None':
125
  code[code_pos_variant] = f'variant = {variant}'
126
  else:
127
  code[code_pos_variant] = f'variant = "{variant}"'
128
 
129
+ return get_sorted_code(), str(config)
130
 
131
  def safety_checker_change(safety_checker):
132
 
133
+ config['safety_checker'] = safety_checker
134
+
135
  if not safety_checker or str(safety_checker).lower == 'false':
136
  code[code_pos_safety_checker] = f'pipeline.safety_checker = None'
137
  else:
138
  code[code_pos_safety_checker] = ''
139
 
140
+ return get_sorted_code(), str(config)
141
 
142
  def requires_safety_checker_change(requires_safety_checker):
143
 
144
  code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {requires_safety_checker}'
145
 
146
+ config['requires_safety_checker'] = requires_safety_checker
147
+
148
+ return get_sorted_code(), str(config)
149
 
150
  def schedulers_change(scheduler):
151
 
 
153
 
154
  code[code_pos_scheduler] = f'pipeline.scheduler = {scheduler}.from_config(pipeline.scheduler.config)'
155
 
156
+ config['scheduler'] = scheduler
157
+
158
+ return get_sorted_code(), scheduler_configs[scheduler], str(config)
159
 
160
  else:
161
 
162
+ return get_sorted_code(), '', str(config)
163
 
164
  def get_scheduler(scheduler, config):
165
 
 
219
 
220
  pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
221
 
222
+ manual_seed = int(manual_seed)
223
  if manual_seed < 0 or manual_seed is None or manual_seed == '':
224
  generator = torch.Generator(device)
225
  else:
 
234
  num_inference_steps=int(inference_steps),
235
  guidance_scale=float(guidance_scale)).images[0]
236
 
237
+ config_history.append(config)
238
+ return image, dict_list_to_markdown_table(config_history)
239
 
240
  else:
241
 
 
271
  gr.Markdown("### Device specific settings")
272
  with gr.Row():
273
  in_devices = gr.Dropdown(label="Device:", value=device, choices=devices, filterable=True, multiselect=False, allow_custom_value=True)
274
+ in_data_type = gr.Radio(label="Data Type:", value=data_type, choices=["bfloat16", "float16"], info="`bfloat16` is not supported on MPS devices right now; Half-precision weights, will save GPU memory, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16")
275
  in_allow_tensorfloat32 = gr.Radio(label="Allow TensorFloat32:", value=allow_tensorfloat32, choices=[True, False], info="is not supported on MPS devices right now; use TensorFloat-32 is faster, but results in slightly less accurate computations, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
276
  in_variant = gr.Radio(label="Variant:", value=variant, choices=["fp16", None], info="Use half-precision weights will save GPU memory, not all models support that, see https://huggingface.co/docs/diffusers/main/en/optimization/fp16 ")
277
 
 
311
  out_image = gr.Image()
312
  out_code = gr.Code(get_sorted_code(), label="Code")
313
  with gr.Row():
314
+ out_current_config = gr.Code(value=str(config), label="Current config")
315
  with gr.Row():
316
  out_config_history = gr.Markdown(dict_list_to_markdown_table(config_history))
317
 
318
+ in_devices.change(device_change, inputs=[in_devices], outputs=[out_code, out_current_config])
319
+ in_data_type.change(data_type_change, inputs=[in_data_type], outputs=[out_code, out_current_config])
320
+ in_allow_tensorfloat32.change(tensorfloat32_change, inputs=[in_allow_tensorfloat32], outputs=[out_code, out_current_config])
321
+ in_variant.change(variant_change, inputs=[in_variant], outputs=[out_code, out_current_config])
322
+ in_models.change(models_change, inputs=[in_models, in_schedulers], outputs=[out_code, in_use_safetensors, in_schedulers, out_current_config])
323
+ in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker], outputs=[out_code, out_current_config])
324
+ in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker], outputs=[out_code, out_current_config])
325
+ in_schedulers.change(schedulers_change, inputs=[in_schedulers], outputs=[out_code, out_scheduler_description, out_current_config])
326
  btn_start_pipeline.click(run_inference, inputs=[
327
  in_models,
328
  in_devices,
 
337
  in_inference_steps,
338
  in_manual_seed,
339
  in_guidance_scale
340
+ ], outputs=[
341
+ out_image,
342
+ out_config_history])
343
 
344
  demo.load(fn=init_config, inputs=out_current_config,
345
  outputs=[
config.py CHANGED
@@ -3,6 +3,7 @@ import base64
3
  import json
4
  import torch
5
 
 
6
  code_pos_device = '001_code'
7
  code_pos_data_type = '002_data_type'
8
  code_pos_tf32 = '003_tf32'
@@ -35,16 +36,8 @@ def get_inital_config():
35
 
36
  appConfig = load_app_config()
37
 
38
- # default model is None
39
  model_configs = appConfig.get("models", {})
40
- # default model is None
41
- models = list(model_configs.keys())
42
- model = None
43
-
44
- # default scheduler is None
45
  scheduler_configs = appConfig.get("schedulers", {})
46
- schedulers = list(scheduler_configs.keys())
47
- scheduler = None
48
 
49
  # default device
50
  devices = appConfig.get("devices", [])
@@ -65,7 +58,7 @@ def get_inital_config():
65
  "model": None,
66
  "scheduler": None,
67
  "variant": None,
68
- "allow_tensorflow32": allow_tensorfloat32,
69
  "use_safetensors": False,
70
  "data_type": data_type,
71
  "safety_checker": False,
@@ -78,21 +71,19 @@ def get_inital_config():
78
  }
79
 
80
  # code output order
81
- code = {}
82
-
83
  code[code_pos_device] = f'device = "{device}"'
84
- code[code_pos_variant] = f'variant = {initial_config['variant']}'
85
- code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config['allow_tensorfloat32']}'
86
  code[code_pos_data_type] = 'data_type = torch.bfloat16'
87
  code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
88
  code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
89
- code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config['requires_safety_checker']}'
90
  code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
91
  code[code_pos_generator] = f'generator = torch.Generator("{device}")'
92
- code[code_pos_prompt] = f'prompt = "{initial_config['prompt']}"'
93
- code[code_pos_negative_prompt] = f'negative_prompt = "{initial_config['negative_prompt']}"'
94
- code[code_pos_inference_steps] = f'inference_steps = {initial_config['inference_steps']}'
95
- code[code_pos_guidance_scale] = f'guidance_scale = {initial_config['guidance_scale']}'
96
  code[code_pos_run_inference] = f'''image = pipeline(
97
  prompt=prompt,
98
  negative_prompt=negative_prompt,
@@ -100,7 +91,7 @@ def get_inital_config():
100
  num_inference_steps=inference_steps,
101
  guidance_scale=guidance_scale).images[0]'''
102
 
103
- return initial_config, devices, models, schedulers, code
104
 
105
  def init_config(request: gr.Request, inital_config):
106
 
 
3
  import json
4
  import torch
5
 
6
+ code = {}
7
  code_pos_device = '001_code'
8
  code_pos_data_type = '002_data_type'
9
  code_pos_tf32 = '003_tf32'
 
36
 
37
  appConfig = load_app_config()
38
 
 
39
  model_configs = appConfig.get("models", {})
 
 
 
 
 
40
  scheduler_configs = appConfig.get("schedulers", {})
 
 
41
 
42
  # default device
43
  devices = appConfig.get("devices", [])
 
58
  "model": None,
59
  "scheduler": None,
60
  "variant": None,
61
+ "allow_tensorfloat32": allow_tensorfloat32,
62
  "use_safetensors": False,
63
  "data_type": data_type,
64
  "safety_checker": False,
 
71
  }
72
 
73
  # code output order
 
 
74
  code[code_pos_device] = f'device = "{device}"'
75
+ code[code_pos_variant] = f'variant = {initial_config["variant"]}'
76
+ code[code_pos_tf32] = f'torch.backends.cuda.matmul.allow_tf32 = {initial_config["allow_tensorfloat32"]}'
77
  code[code_pos_data_type] = 'data_type = torch.bfloat16'
78
  code[code_pos_init_pipeline] = 'sys.exit("No model selected!")'
79
  code[code_pos_safety_checker] = 'pipeline.safety_checker = None'
80
+ code[code_pos_requires_safety_checker] = f'pipeline.requires_safety_checker = {initial_config["requires_safety_checker"]}'
81
  code[code_pos_scheduler] = 'sys.exit("No scheduler selected!")'
82
  code[code_pos_generator] = f'generator = torch.Generator("{device}")'
83
+ code[code_pos_prompt] = f'prompt = "{initial_config["prompt"]}"'
84
+ code[code_pos_negative_prompt] = f'negative_prompt = "{initial_config["negative_prompt"]}"'
85
+ code[code_pos_inference_steps] = f'inference_steps = {initial_config["inference_steps"]}'
86
+ code[code_pos_guidance_scale] = f'guidance_scale = {initial_config["guidance_scale"]}'
87
  code[code_pos_run_inference] = f'''image = pipeline(
88
  prompt=prompt,
89
  negative_prompt=negative_prompt,
 
91
  num_inference_steps=inference_steps,
92
  guidance_scale=guidance_scale).images[0]'''
93
 
94
+ return initial_config, devices, model_configs, scheduler_configs, code
95
 
96
  def init_config(request: gr.Request, inital_config):
97