JadenFK commited on
Commit
f71eb42
1 Parent(s): 0166058

Gpu mem stuff

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -6,6 +6,7 @@ from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_con
6
  from omegaconf import OmegaConf
7
  from StableDiffuser import StableDiffuser
8
  from diffusers import UNet2DConditionModel
 
9
 
10
  ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
11
  config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
@@ -18,10 +19,16 @@ class Demo:
18
 
19
  self.training = False
20
  self.generating = False
 
 
 
 
 
 
21
 
22
  with gr.Blocks() as demo:
23
  self.layout()
24
- demo.queue(concurrency_count=10).launch()
25
 
26
  def disable(self):
27
  return [gr.update(interactive=False), gr.update(interactive=False)]
@@ -131,6 +138,8 @@ class Demo:
131
  else:
132
  self.training = True
133
 
 
 
134
  model_orig, model_edited = train_esd(prompt,
135
  train_method,
136
  3,
@@ -146,8 +155,16 @@ class Demo:
146
  original_config = OmegaConf.load(config_path)
147
  original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
148
  unet_config = create_unet_diffusers_config(original_config, image_size=512)
149
- model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
150
- model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
 
 
 
 
 
 
 
 
151
 
152
  self.init_inference(model_edited_sd, model_orig_sd, unet_config)
153
 
@@ -155,16 +172,17 @@ class Demo:
155
 
156
  def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
157
 
 
 
 
 
158
  self.model_edited_sd = model_edited_sd
159
  self.model_orig_sd = model_orig_sd
160
 
161
- self.diffuser = StableDiffuser(42)
162
-
163
- self.diffuser.unet = UNet2DConditionModel(**unet_config)
164
  self.diffuser.to('cuda')
165
 
166
  self.training = False
167
-
168
 
169
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
170
 
@@ -185,6 +203,8 @@ class Demo:
185
 
186
  orig_image = images[0][0]
187
 
 
 
188
  self.diffuser.unet.load_state_dict(self.model_edited_sd)
189
 
190
  images = self.diffuser(
@@ -197,6 +217,8 @@ class Demo:
197
 
198
  self.generating = False
199
 
 
 
200
  return edited_image, orig_image
201
 
202
 
 
6
  from omegaconf import OmegaConf
7
  from StableDiffuser import StableDiffuser
8
  from diffusers import UNet2DConditionModel
9
+ import torch
10
 
11
  ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
12
  config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
 
19
 
20
  self.training = False
21
  self.generating = False
22
+ self.model_edited_sd = None
23
+ self.model_orig_sd = None
24
+
25
+ self.diffuser = StableDiffuser(42)
26
+ self.diffuser.to('cpu')
27
+ self.diffuser = self.diffuser.half()
28
 
29
  with gr.Blocks() as demo:
30
  self.layout()
31
+ demo.queue(concurrency_count=1).launch()
32
 
33
  def disable(self):
34
  return [gr.update(interactive=False), gr.update(interactive=False)]
 
138
  else:
139
  self.training = True
140
 
141
+ self.diffuser.to('cpu')
142
+
143
  model_orig, model_edited = train_esd(prompt,
144
  train_method,
145
  3,
 
155
  original_config = OmegaConf.load(config_path)
156
  original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
157
  unet_config = create_unet_diffusers_config(original_config, image_size=512)
158
+ _model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
159
+ _model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
160
+
161
+ model_edited_sd = {key: value.cpu() for key, value in _model_edited_sd.items()}
162
+ model_orig_sd = {key: value.cpu() for key, value in _model_orig_sd.items()}
163
+
164
+ del model_orig, _model_orig_sd
165
+ del model_edited, _model_edited_sd
166
+
167
+ torch.cuda.empty_cache()
168
 
169
  self.init_inference(model_edited_sd, model_orig_sd, unet_config)
170
 
 
172
 
173
  def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
174
 
175
+ del self.model_edited_sd, self.model_orig_sd
176
+
177
+ torch.cuda.empty_cache()
178
+
179
  self.model_edited_sd = model_edited_sd
180
  self.model_orig_sd = model_orig_sd
181
 
 
 
 
182
  self.diffuser.to('cuda')
183
 
184
  self.training = False
185
+
186
 
187
  def inference(self, prompt, seed, pbar = gr.Progress(track_tqdm=True)):
188
 
 
203
 
204
  orig_image = images[0][0]
205
 
206
+ torch.cuda.empty_cache()
207
+
208
  self.diffuser.unet.load_state_dict(self.model_edited_sd)
209
 
210
  images = self.diffuser(
 
217
 
218
  self.generating = False
219
 
220
+ torch.cuda.empty_cache()
221
+
222
  return edited_image, orig_image
223
 
224