erwann committed
Commit f2e0f04
1 Parent(s): f9d605e

fix memory issue:

Files changed (2)
  1. ImageState.py +3 -14
  2. backend.py +1 -0
ImageState.py CHANGED
@@ -1,4 +1,5 @@
 # from align import align_from_path
+import gc
 import imageio
 import glob
 import uuid
@@ -21,8 +22,6 @@ from edit import blend_paths
 from img_processing import *
 from img_processing import custom_to_pil
 from loaders import load_default
-# from app import vqgan
-global vqgan
 num = 0
 class PromptTransformHistory():
     def __init__(self, iterations) -> None:
@@ -42,7 +41,6 @@ class ImageState:
         self.attn_mask = None
         self.prompt_optim = prompt_optimizer
         self.state_id = None
-        # print("NEW INSTANCE")
         print(self.state_id)
         self._load_vectors()
         self.init_transforms()
@@ -65,8 +63,6 @@ class ImageState:
             if file_name.endswith('.png'):
                 file_path = os.path.join(folder, file_name)
                 images.append(imageio.imread(file_path))
-        # images[0] = images[0].set_meta_data({'duration': 1})
-        # images[-1] = images[-1].set_meta_data({'duration': 1})
         imageio.mimsave(gif_name, images, duration=durations)
         return gif_name
     def init_transforms(self):
@@ -85,12 +81,8 @@ class ImageState:
         new_latent = torch.lerp(src, src + vector, 1)
         return new_latent
     def _decode_latent_to_pil(self, latent):
-        # global vqgan
         current_im = self.vqgan.decode(latent.to(self.device))[0]
         return custom_to_pil(current_im)
-    # def _get_current_vector_transforms(self):
-    # current_vector_transforms = (self.blue_eyes, self.lip_size, self.hair_gp, self.asian_transform, sum(self.current_prompt_transforms))
-    # return (self.blend_latent, current_vector_transforms)
     def _get_mask(self, img, mask=None):
         if img and "mask" in img and img["mask"] is not None:
             attn_mask = torchvision.transforms.ToTensor()(img["mask"])
@@ -180,11 +172,6 @@ class ImageState:
         print(latent_index)
         self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index].to(self.device)
         return self._render_all_transformations()
-    # def rescale_mask(self, mask):
-    # rep = mask.clone()
-    # rep[mask < 0.03] = -1000000
-    # rep[mask >= 0.03] = 1
-    # return rep
     def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
         if self.state_id is None:
             self.state_id = "./" + str(uuid.uuid4())
@@ -217,6 +204,8 @@ class ImageState:
         wandb.finish()
         self.attn_mask = None
         self.transform_history.append(transform_log)
+        gc.collect()
+        torch.cuda.empty_cache()
         # transform = self.prompt_optim.optimize(self.blend_latent,
         # positive_prompts,
         # negative_prompts)
backend.py CHANGED
@@ -17,6 +17,7 @@ from img_processing import *
 from img_processing import custom_to_pil
 from loaders import load_default
 import glob
+import gc

 global log
 log=False
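
Note: the fix follows a common PyTorch cleanup pattern rather than anything specific to this repo's classes: after a heavy optimization pass, drop Python references to large intermediate tensors, run the garbage collector, and then release unused blocks held by the CUDA caching allocator. A minimal, self-contained sketch of that pattern is below; `heavy_step` and the tensor size are illustrative placeholders, not code from ImageState.py.

# Sketch of the cleanup pattern added at the end of apply_prompts():
# free Python-side references, collect garbage, then release cached CUDA memory.
import gc

import torch


def heavy_step() -> None:
    # `heavy_step` is a hypothetical stand-in for a latent-optimization pass
    # that allocates large tensors.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scratch = torch.randn(2048, 2048, device=device)
    loss = (scratch * scratch).mean()
    print(f"loss-like value: {loss.item():.4f}")
    # Deleting references alone is not enough on CUDA: PyTorch's caching
    # allocator keeps the freed blocks reserved for reuse by the same process.
    del scratch, loss
    gc.collect()                  # reclaim objects, including reference cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand unoccupied cached blocks back to the driver


if __name__ == "__main__":
    heavy_step()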