n42 committed
Commit 15347cd · Parent: d9d145c

add cpu offload option

Files changed (2):
  1. app.py (+15 -1)
  2. config.py (+6 -1)
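Both files route a new `cpu_offload` setting from the UI into the generated diffusers code. The setting maps to diffusers' model CPU offload; below is a minimal sketch of that call, assuming the `diffusers` and `accelerate` packages are installed (the model id is an illustrative placeholder, not taken from the app's config):

```python
# Minimal sketch of what the new option enables: diffusers' model CPU offload.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative model id
    torch_dtype=torch.float16,
)

# Enable offload *before* running inference: each sub-model is moved to the
# GPU only while it executes, which lowers peak VRAM at some cost in speed.
# Do not combine with pipeline.to("cuda"); accelerate manages placement.
pipeline.enable_model_cpu_offload()

image = pipeline(prompt="an astronaut riding a horse").images[0]
```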
app.py CHANGED
```diff
@@ -17,6 +17,12 @@ def model_refiner_change(refiner, config)
 
     return config, str(config), assemble_code(config)
 
+def cpu_offload_change(cpu_offload, config):
+
+    config = set_config(config, 'cpu_offload', cpu_offload)
+
+    return config, str(config), assemble_code(config)
+
 def models_change(model, scheduler, config):
 
     config = set_config(config, 'model', model)
@@ -171,6 +177,9 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
         num_inference_steps = int(config["inference_steps"]),
         guidance_scale = float(config["guidance_scale"])).images
 
+    if str(config["cpu_offload"]).lower() == 'true':
+        pipeline.enable_model_cpu_offload()
+
     if config['refiner'] != '':
         image = refiner(
             prompt = config["prompt"],
@@ -178,6 +187,9 @@ def run_inference(config, config_history, progress=gr.Progress(track_tqdm=True))
             image=image,
             ).images
 
+    if str(config["cpu_offload"]).lower() == 'true':
+        refiner.enable_model_cpu_offload()
+
     config_history.append(config.copy())
 
     return image[0], dict_list_to_markdown_table(config_history), config_history
@@ -216,6 +228,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
     with gr.Row():
         with gr.Column(scale=1):
             in_use_safetensors = gr.Radio(label="Use safe tensors:", choices=["True", "False"], interactive=False)
+            in_cpu_offload = gr.Radio(label="CPU Offload:", choices=["True", "False"], info="Offloads idle model components to the CPU to reduce GPU memory usage, usually at some cost in speed. Compare running times and outputs to verify that this setting helps your setup.")
            in_model_refiner = gr.Dropdown(value="", choices=[""], label="Refiner", allow_custom_value=True, multiselect=False)
         with gr.Column(scale=1):
             in_safety_checker = gr.Radio(label="Enable safety checker:", value=config.value["safety_checker"], choices=["True", "False"])
@@ -223,7 +236,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
 
     gr.Markdown("### Scheduler")
     with gr.Row():
-        in_schedulers = gr.Dropdown(choices=list(schedulers.keys()), label="Scheduler/Solver", info="the scheduler controls parameter adaption between each inference step, depending on the right scheduler for your model, it may only take 10 or 20 steps to achieve very good results, see https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers" )
+        in_schedulers = gr.Dropdown(choices=list(schedulers.keys()), label="Scheduler/Solver", info="Schedulers employ various strategies for noise control and govern how parameters adapt between inference steps; with the right scheduler for your model, 10 or 20 steps can already give very good results. See https://huggingface.co/docs/diffusers/using-diffusers/loading#schedulers")
         out_scheduler_description = gr.Textbox(value="", label="Description")
 
     gr.Markdown("### Adapters")
@@ -258,6 +271,7 @@ with gr.Blocks(analytics_enabled=False) as demo:
     in_variant.change(variant_change, inputs=[in_variant, config], outputs=[config, out_config, out_code])
     in_models.change(models_change, inputs=[in_models, in_schedulers, config], outputs=[out_model_description, in_model_refiner, in_use_safetensors, in_schedulers, config, out_config, out_code])
     in_model_refiner.change(model_refiner_change, inputs=[in_model_refiner, config], outputs=[config, out_config, out_code])
+    in_cpu_offload.change(cpu_offload_change, inputs=[in_cpu_offload, config], outputs=[config, out_config, out_code])
     in_safety_checker.change(safety_checker_change, inputs=[in_safety_checker, config], outputs=[config, out_config, out_code])
     in_requires_safety_checker.change(requires_safety_checker_change, inputs=[in_requires_safety_checker, config], outputs=[config, out_config, out_code])
     in_schedulers.change(schedulers_change, inputs=[in_schedulers, config], outputs=[out_scheduler_description, config, out_config, out_code])
```
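The handler and event wiring above follow the app's existing pattern: each control's `.change` event writes through a shared config state and re-renders the config dump and generated code. A self-contained sketch of that pattern, with simplified stand-ins for the app's `set_config` and `assemble_code` helpers:

```python
# Stand-alone sketch of app.py's Radio -> change handler -> shared state
# pattern. set_config and assemble_code are simplified stand-ins here.
import gradio as gr

def set_config(config, key, value):
    config[key] = value
    return config

def assemble_code(config):
    if str(config["cpu_offload"]).lower() == 'true':
        return "pipeline.enable_model_cpu_offload()"
    return ""

def cpu_offload_change(cpu_offload, config):
    config = set_config(config, 'cpu_offload', cpu_offload)
    return config, str(config), assemble_code(config)

with gr.Blocks() as demo:
    config = gr.State({"cpu_offload": "False"})
    in_cpu_offload = gr.Radio(label="CPU Offload:", choices=["True", "False"])
    out_config = gr.Textbox(label="Config")
    out_code = gr.Code(label="Generated code", language="python")
    in_cpu_offload.change(cpu_offload_change,
                          inputs=[in_cpu_offload, config],
                          outputs=[config, out_config, out_code])

demo.launch()
```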
config.py CHANGED
```diff
@@ -37,6 +37,7 @@ def get_initial_config():
     config = {
         "device": device,
         "model": None,
+        "cpu_offload": "False",
         "scheduler": None,
         "variant": None,
         "allow_tensorfloat32": allow_tensorfloat32,
@@ -141,8 +142,10 @@ def assemble_code(str_config):
         torch_dtype=data_type,
         variant=variant).to(device)'''
 
+    if str(config["cpu_offload"]).lower() == 'true': code['051_cpu_offload'] = "pipeline.enable_model_cpu_offload()"
+
     if config['refiner'] != '':
-        code['051_refiner'] = f'''refiner = DiffusionPipeline.from_pretrained(
+        code['052_refiner'] = f'''refiner = DiffusionPipeline.from_pretrained(
             "{config['refiner']}",
             text_encoder_2 = base.text_encoder_2,
             vae = base.vae,
@@ -151,6 +154,8 @@ def assemble_code(str_config):
             variant=variant,
             ).to(device)'''
 
+    if str(config["cpu_offload"]).lower() == 'true': code['053_cpu_offload'] = "refiner.enable_model_cpu_offload()"
+
     code['054_requires_safety_checker'] = f'pipeline.requires_safety_checker = {config["requires_safety_checker"]}'
 
     if str(config["safety_checker"]).lower() == 'false':
```
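`assemble_code` composes the generated script from snippets keyed by zero-padded ordinals and emitted in sorted order, which is why the refiner snippet is renumbered from `051_refiner` to `052_refiner` to make room for the offload line. A small sketch of that pattern (keys and snippet contents are illustrative):

```python
# Ordered-snippet pattern behind assemble_code: fragments keyed by zero-padded
# strings are emitted in sorted order, so a new step is spliced in by giving
# it the next free key and renumbering any neighbors.
config = {"cpu_offload": "True", "refiner": "some/refiner-model"}  # example values

code = {}
code['050_pipeline'] = 'pipeline = DiffusionPipeline.from_pretrained(model).to(device)'

if str(config["cpu_offload"]).lower() == 'true':
    code['051_cpu_offload'] = 'pipeline.enable_model_cpu_offload()'

if config['refiner'] != '':
    code['052_refiner'] = 'refiner = DiffusionPipeline.from_pretrained(refiner_model)'

print('\n'.join(code[key] for key in sorted(code)))
```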