Chao Xu commited on
Commit
6c1250a
1 Parent(s): 0e93edd

empty cache

Browse files
Files changed (2) hide show
  1. sam_utils.py +1 -0
  2. zero123_utils.py +8 -9
sam_utils.py CHANGED
@@ -103,5 +103,6 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders):
103
  out_image_bbox = out_image.copy()
104
  out_image[:, :, 3] = mask.astype(np.uint8) * 255
105
  out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
 
106
  return Image.fromarray(out_image_bbox, mode='RGBA')
107
  cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
 
103
  out_image_bbox = out_image.copy()
104
  out_image[:, :, 3] = mask.astype(np.uint8) * 255
105
  out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
106
+ torch.cuda.empty_cache()
107
  return Image.fromarray(out_image_bbox, mode='RGBA')
108
  cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
zero123_utils.py CHANGED
@@ -61,9 +61,9 @@ def init_model(device, ckpt):
61
  return models
62
 
63
  @torch.no_grad()
64
- def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='fp32', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
65
  precision_scope = autocast if precision == 'autocast' else nullcontext
66
- with precision_scope(model.device):
67
  with model.ema_scope():
68
  c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
69
  T = []
@@ -98,7 +98,9 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
98
  print(samples_ddim.shape)
99
  # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
100
  x_samples_ddim = model.decode_first_stage(samples_ddim)
101
- return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
 
 
102
 
103
 
104
  def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
@@ -118,7 +120,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d
118
  for stage1_idx in range(len(x_samples_ddims_8)):
119
  if adjust_set != [] and stage1_idx not in adjust_set:
120
  continue
121
- x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx].cpu().numpy(), 'c h w -> h w c')
122
  Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
123
  del x_samples_ddims_8
124
  del input_im
@@ -148,7 +150,7 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="
148
  for stage1_idx in range(len(delta_x_1_8)):
149
  if adjust_set != [] and stage1_idx not in adjust_set:
150
  continue
151
- x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].cpu().numpy(), 'c h w -> h w c')
152
  out_image = Image.fromarray(x_sample.astype(np.uint8))
153
  ret_imgs.append(out_image)
154
  if save_path:
@@ -177,16 +179,13 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_
177
  input_im_init = input_im_init / 255.0
178
  input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
179
  input_im = input_im * 2 - 1
180
- print("debug input device", input_im.device)
181
- print("debug model device", model.device)
182
  # infer stage 2
183
  sampler = DDIMSampler(model)
184
- print("debug sampler device", sampler.device)
185
  # sampler.to(device)
186
  # stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
187
  x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
188
  for stage2_idx in range(len(delta_x_2)):
189
- x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].cpu().numpy(), 'c h w -> h w c')
190
  Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
191
  del input_im
192
  del sampler
 
61
  return models
62
 
63
  @torch.no_grad()
64
+ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
65
  precision_scope = autocast if precision == 'autocast' else nullcontext
66
+ with precision_scope("cuda"):
67
  with model.ema_scope():
68
  c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
69
  T = []
 
98
  print(samples_ddim.shape)
99
  # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
100
  x_samples_ddim = model.decode_first_stage(samples_ddim)
101
+ ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
102
+ del cond, c, x_samples_ddim, samples_ddim, uc
103
+ return ret_imgs
104
 
105
 
106
  def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
 
120
  for stage1_idx in range(len(x_samples_ddims_8)):
121
  if adjust_set != [] and stage1_idx not in adjust_set:
122
  continue
123
+ x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx].numpy(), 'c h w -> h w c')
124
  Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
125
  del x_samples_ddims_8
126
  del input_im
 
150
  for stage1_idx in range(len(delta_x_1_8)):
151
  if adjust_set != [] and stage1_idx not in adjust_set:
152
  continue
153
+ x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c')
154
  out_image = Image.fromarray(x_sample.astype(np.uint8))
155
  ret_imgs.append(out_image)
156
  if save_path:
 
179
  input_im_init = input_im_init / 255.0
180
  input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
181
  input_im = input_im * 2 - 1
 
 
182
  # infer stage 2
183
  sampler = DDIMSampler(model)
 
184
  # sampler.to(device)
185
  # stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
186
  x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
187
  for stage2_idx in range(len(delta_x_2)):
188
+ x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
189
  Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
190
  del input_im
191
  del sampler