liuyuan-pal committed
Commit 36a325d
1 Parent(s): ab287b7
Files changed (4)
  1. app.py +7 -9
  2. ckpt/sam_vit_h_4b8939.pth +3 -0
  3. requirements.txt +1 -0
  4. sam_utils.py +50 -0
app.py CHANGED
@@ -9,6 +9,7 @@ import fire
 from omegaconf import OmegaConf
 
 from ldm.util import add_margin, instantiate_from_config
+from sam_utils import sam_init, sam_out_nosave
 
 _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 _DESCRIPTION = '''
@@ -31,12 +32,6 @@ _USER_GUIDE3 = "Generated multiview images are shown below!"
 
 deployed = True
 
-def mask_prediction(mask_predictor, image_in: Image.Image):
-    if image_in.mode=='RGBA':
-        return image_in
-    else:
-        raise NotImplementedError
-
 def resize_inputs(image_input, crop_size):
     alpha_np = np.asarray(image_input)[:, :, 3]
     coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
@@ -58,6 +53,8 @@ def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, el
     # prepare data
     image_input = np.asarray(image_input)
     image_input = image_input.astype(np.float32) / 255.0
+    alpha_values = image_input[:, :, 3:]
+    image_input[:, :, :3] = alpha_values * image_input[:, :, :3] + 1 - alpha_values # white background
     image_input = image_input[:, :, :3] * 2.0 - 1.0
     image_input = torch.from_numpy(image_input.astype(np.float32))
     elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
@@ -103,7 +100,8 @@ def run_demo():
         model = None
 
     # init sam model
-    mask_predictor = None # sam_init(device_idx)
+    mask_predictor = sam_init()
+    mask_predict_fn = lambda x: sam_out_nosave(mask_predictor, x)
 
     # with open('instructions_12345.md', 'r') as f:
     #     article = f.read()
@@ -144,7 +142,7 @@
                 fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
             with gr.Column(scale=1):
-                input_block = gr.Image(type='pil', image_mode='RGB', label="Input to SyncDreamer", height=256, interactive=False)
+                input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
                 elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
                 cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
                 # sample_num = gr.Slider(1, 2, 2, step=1, label='Sample Num', interactive=True, info='How many instance (16 images per instance)')
@@ -156,7 +154,7 @@
             output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
         update_guide = lambda GUIDE_TEXT: gr.update(value=GUIDE_TEXT)
-        image_block.change(fn=partial(mask_prediction, mask_predictor), inputs=[image_block], outputs=[sam_block], queue=False)\
+        image_block.change(fn=mask_predict_fn, inputs=[image_block], outputs=[sam_block], queue=False)\
                    .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
 
         crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
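
For readers skimming the diff: the two lines added to generate composite the RGBA input onto a white background before normalising to [-1, 1], so transparent pixels become white rather than black. A minimal standalone sketch of that step (the helper name composite_on_white is illustrative, not part of the repo; it assumes an RGBA image such as the one produced by sam_out_nosave):

import numpy as np
from PIL import Image

def composite_on_white(image_rgba):
    # Blend the RGBA image onto a white background, then scale RGB to [-1, 1].
    img = np.asarray(image_rgba).astype(np.float32) / 255.0  # H x W x 4, values in [0, 1]
    alpha = img[:, :, 3:]                                     # trailing axis kept for broadcasting
    rgb = alpha * img[:, :, :3] + (1.0 - alpha)               # fully transparent pixels -> pure white
    return rgb * 2.0 - 1.0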
ckpt/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
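
The checkpoint itself is stored in Git LFS, so the repository only tracks the pointer above (spec version, SHA-256 of the blob, byte size). As an aside, a downloaded copy can be checked against those values with a small illustrative script:

import hashlib
import os

def sha256_of(path, chunk_size=1 << 20):
    # Hash the file in chunks so the ~2.5 GB checkpoint is never held in memory at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

path = "ckpt/sam_vit_h_4b8939.pth"  # path as laid out in this commit
assert os.path.getsize(path) == 2564550879
assert sha256_of(path) == "a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e"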
requirements.txt CHANGED
@@ -20,4 +20,5 @@ easydict
 nerfacc
 imageio-ffmpeg==0.4.7
 fire
+segment_anything
 git+https://github.com/openai/CLIP.git
sam_utils.py ADDED
@@ -0,0 +1,50 @@
+import os
+import numpy as np
+import torch
+from PIL import Image
+import time
+
+from segment_anything import sam_model_registry, SamPredictor
+
+def sam_init(device_id=0):
+    sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpt/sam_vit_h_4b8939.pth")
+    model_type = "vit_h"
+
+    device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu"
+
+    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
+    predictor = SamPredictor(sam)
+    return predictor
+
+def sam_out_nosave(predictor, input_image, bbox_sliders=(0,0,255,255)):
+    bbox = np.array(bbox_sliders)
+    image = np.asarray(input_image)
+
+    start_time = time.time()
+    predictor.set_image(image)
+
+    h, w, _ = image.shape
+    input_point = np.array([[h//2, w//2]])
+    input_label = np.array([1])
+
+    masks, scores, logits = predictor.predict(
+        point_coords=input_point,
+        point_labels=input_label,
+        multimask_output=True,
+    )
+
+    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
+        box=bbox,
+        multimask_output=True
+    )
+
+    print(f"SAM Time: {time.time() - start_time:.3f}s")
+    opt_idx = np.argmax(scores)
+    mask = masks[opt_idx]
+    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
+    out_image[:, :, :3] = image
+    out_image_bbox = out_image.copy()
+    out_image[:, :, 3] = mask.astype(np.uint8) * 255
+    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
+    torch.cuda.empty_cache()
+    return Image.fromarray(out_image_bbox, mode='RGBA')
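
For orientation, this is roughly how the two helpers are wired together in app.py above. The file name example.png is a placeholder, and note that the default bbox_sliders=(0, 0, 255, 255) implies an input of roughly 256x256 pixels:

from PIL import Image
from sam_utils import sam_init, sam_out_nosave

predictor = sam_init()                               # loads ckpt/sam_vit_h_4b8939.pth onto cuda:0 if available
image = Image.open("example.png").convert("RGB")     # placeholder input; any RGB object photo
rgba = sam_out_nosave(predictor, image)              # RGBA output; alpha channel is the box-prompted SAM mask
rgba.save("example_rgba.png")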