Chao Xu committed
Commit c0c3e1b
1 Parent(s): 6c1250a
Files changed (2)
  1. sam_utils.py +3 -57
  2. zero123_utils.py +4 -4
sam_utils.py CHANGED
@@ -1,14 +1,10 @@
  import os
  import numpy as np
  import torch
- # import matplotlib.pyplot as plt
- import cv2
  from PIL import Image
- # from PIL import Image
  import time
- from utils import find_image_file

- from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
+ from segment_anything import sam_model_registry, SamPredictor

  def sam_init(device_id=0):
      import inspect
@@ -22,60 +18,11 @@ def sam_init(device_id=0):
      sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
      sam.to(device=device)
      predictor = SamPredictor(sam)
-     # mask_generator = SamAutomaticMaskGenerator(sam)
      return predictor

- def sam_out(predictor, shape_dir):
-     image_path = os.path.join(shape_dir, find_image_file(shape_dir))
-     save_path = os.path.join(shape_dir, "image_sam.png")
-     bbox_path = os.path.join(shape_dir, "bbox.txt")
-     bbox = np.loadtxt(bbox_path, delimiter=',')
-     image = cv2.imread(image_path)
-     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-     start_time = time.time()
-     predictor.set_image(image)
-
-     h, w, _ = image.shape
-     input_point = np.array([[h//2, w//2]])
-     input_label = np.array([1])
-
-     masks, scores, logits = predictor.predict(
-         point_coords=input_point,
-         point_labels=input_label,
-         multimask_output=True,
-     )
-
-     masks_bbox, scores_bbox, logits_bbox = predictor.predict(
-         box=bbox,
-         multimask_output=True
-     )
-
-     print(f"SAM Time: {time.time() - start_time:.3f}s")
-     opt_idx = np.argmax(scores)
-     mask = masks[opt_idx]
-     out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
-     out_image[:, :, :3] = image
-     out_image_bbox = out_image.copy()
-     out_image[:, :, 3] = mask.astype(np.uint8) * 255
-     out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
-     cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
-
-
- def convert_from_cv2_to_image(img: np.ndarray) -> Image:
-     return Image.fromarray(img)
-     # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
-
- def convert_from_image_to_cv2(img: Image) -> np.ndarray:
-     return np.asarray(img)
-     # return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
  def sam_out_nosave(predictor, input_image, *bbox_sliders):
-     # save_path = os.path.join(shape_dir, "image_sam.png")
-     # bbox_path = os.path.join(shape_dir, "bbox.txt")
-     # bbox = np.loadtxt(bbox_path, delimiter=',')
      bbox = np.array(bbox_sliders)
-     image = convert_from_image_to_cv2(input_image)
+     image = np.asarray(input_image)

      start_time = time.time()
      predictor.set_image(image)
@@ -104,5 +51,4 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders):
      out_image[:, :, 3] = mask.astype(np.uint8) * 255
      out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
      torch.cuda.empty_cache()
-     return Image.fromarray(out_image_bbox, mode='RGBA')
-     cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
+     return Image.fromarray(out_image_bbox, mode='RGBA')
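
For context, a minimal usage sketch of the trimmed module after this commit: sam_out_nosave now takes a PIL image directly and returns an RGBA PIL image whose alpha channel holds the SAM mask, with no cv2 round-trip. The checkpoint path baked into sam_init, the file names below, and the assumed slider order (x_min, y_min, x_max, y_max) are illustrative assumptions, not part of the diff.

from PIL import Image

from sam_utils import sam_init, sam_out_nosave

predictor = sam_init(device_id=0)                        # loads the SAM checkpoint onto cuda:0
input_image = Image.open("example.png").convert("RGB")   # placeholder input image

# Assumed slider order: x_min, y_min, x_max, y_max in pixel coordinates.
rgba = sam_out_nosave(predictor, input_image, 64, 64, 192, 192)
rgba.save("example_sam.png")                             # SAM mask is stored in the alpha channel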
 
zero123_utils.py CHANGED
@@ -76,7 +76,7 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
      cond = {}
      cond['c_crossattn'] = [c]
      # c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
-     cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach()
+     cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach()
                          .repeat(n_samples, 1, 1, 1)]
      if scale != 1.0:
          uc = {}
@@ -99,7 +99,8 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
      # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
      x_samples_ddim = model.decode_first_stage(samples_ddim)
      ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-     del cond, c, x_samples_ddim, samples_ddim, uc
+     del cond, c, x_samples_ddim, samples_ddim, uc, input_im
+     torch.cuda.empty_cache()
      return ret_imgs


@@ -126,6 +127,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d
      del input_im
      torch.cuda.empty_cache()

+ @torch.no_grad()
  def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="cuda", ddim_steps=75, scale=3.0):
      # raw_im = raw_im.resize([256, 256], Image.LANCZOS)
      # input_im_init = preprocess_image(models, raw_im, preprocess=False)
@@ -157,7 +159,6 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="
          out_image.save(os.path.join(save_path, '%d.png'%(stage1_idx)))
          sample_idx += 1
      del x_samples_ddims_8
-     del input_im
      del sampler
      torch.cuda.empty_cache()
      return ret_imgs
@@ -188,7 +189,6 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_
          x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
          Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
      del input_im
-     del sampler
      del x_samples_ddims_stage2
      torch.cuda.empty_cache()

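
The zero123_utils.py changes follow one pattern: run sampling under torch.no_grad() and explicitly release GPU references (del plus torch.cuda.empty_cache()) once the outputs have been moved to the CPU. A rough sketch of that pattern, with a placeholder model call standing in for the real Zero123 sampling pipeline:

import torch

@torch.no_grad()                      # no autograd graph is recorded during inference
def run_inference(model, input_im):
    samples = model(input_im)         # stand-in for the DDIM sampling + decode step
    result = samples.float().cpu()    # copy the result off the GPU before cleanup
    del samples, input_im             # drop GPU tensor references explicitly...
    torch.cuda.empty_cache()          # ...so PyTorch can release its cached blocks
    return result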