liuyuan-pal committed
Commit 8e2f608
1 Parent(s): df916e6
Files changed (2)
  1. .gitignore +1 -1
  2. app.py +117 -104
.gitignore CHANGED
@@ -2,4 +2,4 @@
 training_examples
 objaverse_examples
 ldm/__pycache__/
-
+__pycache__/
app.py CHANGED
@@ -12,11 +12,6 @@ from ldm.util import add_margin, instantiate_from_config
 from sam_utils import sam_init, sam_out_nosave
 
 import torch
-print(f"Is CUDA available: {torch.cuda.is_available()}")
-# True
-print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
-# Tesla T4
-
 _TITLE = '''SyncDreamer: Generating Multiview-consistent Images from a Single-view Image'''
 _DESCRIPTION = '''
 <div>
@@ -26,18 +21,24 @@ _DESCRIPTION = '''
 </div>
 Given a single-view image, SyncDreamer is able to generate multiview-consistent images, which enables direct 3D reconstruction with NeuS or NeRF without SDS loss
 
-1. Upload the image.
-2. Predict the mask for the foreground object.
-3. Crop the foreground object.
-4. Generate multiview images.
+Procedure:
+**Step 0**. Upload an image or select an example. ==> The foreground is masked out by SAM.
+**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized.
+**Step 2**. Select "Elevation angle "and click "Run generation". ==> Generate multiview images. (This costs about 2 min.)
+To reconstruct a NeRF or a 3D mesh from the generated images, please refer to our [github repository](https://github.com/liuyuan-pal/SyncDreamer).
 '''
-_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example above). We use alpha values as object masks if given."
-_USER_GUIDE1 = "Step1: Please select a crop size using the glider."
-_USER_GUIDE2 = "Step2: Please choose a suitable elevation angle and then click the Generate button."
+_USER_GUIDE0 = "Step0: Please upload an image in the block above (or choose an example shown in the left)."
+_USER_GUIDE1 = "Step1: Please select a **Crop size** and click **Crop it**."
+_USER_GUIDE2 = "Step2: Please choose a **Elevation angle** and click **Run Generate**. This costs about 2 min."
 _USER_GUIDE3 = "Generated multiview images are shown below!"
 
 deployed = True
 
+if deployed:
+    print(f"Is CUDA available: {torch.cuda.is_available()}")
+    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+
 class BackgroundRemoval:
     def __init__(self, device='cuda'):
         from carvekit.api.high import HiInterface
@@ -74,73 +75,74 @@ def resize_inputs(image_input, crop_size):
     return results
 
 def generate(model, batch_view_num, sample_num, cfg_scale, seed, image_input, elevation_input):
-    seed=int(seed)
-    torch.random.manual_seed(seed)
-    np.random.seed(seed)
-
-    # prepare data
-    image_input = np.asarray(image_input)
-    image_input = image_input.astype(np.float32) / 255.0
-    alpha_values = image_input[:,:, 3:]
-    image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background
-    image_input = image_input[:, :, :3] * 2.0 - 1.0
-    image_input = torch.from_numpy(image_input.astype(np.float32))
-    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
-    data = {"input_image": image_input, "input_elevation": elevation_input}
-    for k, v in data.items():
+    if deployed:
+        seed=int(seed)
+        torch.random.manual_seed(seed)
+        np.random.seed(seed)
+
+        # prepare data
+        image_input = np.asarray(image_input)
+        image_input = image_input.astype(np.float32) / 255.0
+        alpha_values = image_input[:,:, 3:]
+        image_input[:, :, :3] = alpha_values * image_input[:,:, :3] + 1 - alpha_values # white background
+        image_input = image_input[:, :, :3] * 2.0 - 1.0
+        image_input = torch.from_numpy(image_input.astype(np.float32))
+        elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
+        data = {"input_image": image_input, "input_elevation": elevation_input}
+        for k, v in data.items():
+            if deployed:
+                data[k] = v.unsqueeze(0).cuda()
+            else:
+                data[k] = v.unsqueeze(0)
+            data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
+
         if deployed:
-            data[k] = v.unsqueeze(0).cuda()
+            x_sample = model.sample(data, cfg_scale, batch_view_num)
         else:
-            data[k] = v.unsqueeze(0)
-        data[k] = torch.repeat_interleave(data[k], sample_num, dim=0)
-
-    if deployed:
-        x_sample = model.sample(data, cfg_scale, batch_view_num)
+            x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
+
+        B, N, _, H, W = x_sample.shape
+        x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
+        x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
+        x_sample = x_sample.astype(np.uint8)
+
+        results = []
+        for bi in range(B):
+            results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
+        results = np.concatenate(results, 0)
+        return Image.fromarray(results)
     else:
-        x_sample = torch.zeros(sample_num, 16, 3, 256, 256)
-
-    B, N, _, H, W = x_sample.shape
-    x_sample = (torch.clamp(x_sample,max=1.0,min=-1.0) + 1) * 0.5
-    x_sample = x_sample.permute(0,1,3,4,2).cpu().numpy() * 255
-    x_sample = x_sample.astype(np.uint8)
-
-    results = []
-    for bi in range(B):
-        results.append(np.concatenate([x_sample[bi,ni] for ni in range(N)], 1))
-    results = np.concatenate(results, 0)
-    return Image.fromarray(results)
+        return Image.fromarray(np.zeros([sample_num*256,16*256,3],np.uint8))
 
-def white_background(img):
-    img = np.asarray(img,np.float32)/255
-    rgb = img[:,:,3:] * img[:,:,:3] + 1 - img[:,:,3:]
-    rgb = (rgb*255).astype(np.uint8)
-    return Image.fromarray(rgb)
 
 def sam_predict(predictor, removal, raw_im):
-    raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
-    image_nobg = removal(raw_im.convert('RGB'))
-    arr = np.asarray(image_nobg)[:, :, -1]
-    x_nonzero = np.nonzero(arr.sum(axis=0))
-    y_nonzero = np.nonzero(arr.sum(axis=1))
-    x_min = int(x_nonzero[0].min())
-    y_min = int(y_nonzero[0].min())
-    x_max = int(x_nonzero[0].max())
-    y_max = int(y_nonzero[0].max())
-    # image_nobg.save('./nobg.png')
-
-    image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
-    image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
-
-    # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
-    image_sam = np.asarray(image_sam, np.float32) / 255
-    out_mask = image_sam[:, :, 3:]
-    out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
-    out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
-
-    image_sam = Image.fromarray(out_img, mode='RGBA')
-    # image_sam.save('./output.png')
-    torch.cuda.empty_cache()
-    return image_sam
+    if deployed:
+        raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
+        image_nobg = removal(raw_im.convert('RGB'))
+        arr = np.asarray(image_nobg)[:, :, -1]
+        x_nonzero = np.nonzero(arr.sum(axis=0))
+        y_nonzero = np.nonzero(arr.sum(axis=1))
+        x_min = int(x_nonzero[0].min())
+        y_min = int(y_nonzero[0].min())
+        x_max = int(x_nonzero[0].max())
+        y_max = int(y_nonzero[0].max())
+        # image_nobg.save('./nobg.png')
+
+        image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)
+        image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), (x_min, y_min, x_max, y_max))
+
+        # imsave('./mask.png', np.asarray(image_sam)[:,:,3]*255)
+        image_sam = np.asarray(image_sam, np.float32) / 255
+        out_mask = image_sam[:, :, 3:]
+        out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
+        out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
+
+        image_sam = Image.fromarray(out_img, mode='RGBA')
+        # image_sam.save('./output.png')
+        torch.cuda.empty_cache()
+        return image_sam
+    else:
+        return raw_im
 
 def run_demo():
     # device = f"cuda:0" if torch.cuda.is_available() else "cpu"
@@ -156,21 +158,28 @@ def run_demo():
         model.load_state_dict(ckpt['state_dict'], strict=True)
         model = model.cuda().eval()
        del ckpt
+        mask_predictor = sam_init()
+        removal = BackgroundRemoval()
     else:
         model = None
-
-    # init sam model
-    mask_predictor = sam_init()
-    removal = BackgroundRemoval()
-
-    # with open('instructions_12345.md', 'r') as f:
-    #     article = f.read()
+        mask_predictor = None
+        removal = None
 
     # NOTE: Examples must match inputs
-    example_folder = os.path.join(os.path.dirname(__file__), 'hf_demo', 'examples')
-    example_fns = os.listdir(example_folder)
-    example_fns.sort()
-    examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]
+    examples_full = [
+        ['hf_demo/examples/basket.png',30,200],
+        ['hf_demo/examples/cat.png',30,200],
+        ['hf_demo/examples/crab.png',30,200],
+        ['hf_demo/examples/elephant.png',30,200],
+        ['hf_demo/examples/flower.png',0,200],
+        ['hf_demo/examples/forest.png',30,200],
+        ['hf_demo/examples/monkey.png',30,200],
+        ['hf_demo/examples/teapot.png',0,200],
+    ]
+
+    image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
+    elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
+    crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
 
     # Compose demo layout & data flow.
     with gr.Blocks(title=_TITLE, css="hf_demo/style.css") as demo:
@@ -182,34 +191,38 @@ def run_demo():
         gr.Markdown(_DESCRIPTION)
 
         with gr.Row(variant='panel'):
-            with gr.Column(scale=1):
-                image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
-                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
+            with gr.Column(scale=1.2):
                 gr.Examples(
                     examples=examples_full, # NOTE: elements must match inputs list!
-                    inputs=[image_block],
-                    outputs=[image_block],
+                    inputs=[image_block, elevation, crop_size],
+                    outputs=[image_block, elevation, crop_size],
                     cache_examples=False,
                    label='Examples (click one of the images below to start)',
-                    examples_per_page=40
+                    examples_per_page=5,
                 )
 
+            with gr.Column(scale=0.8):
+                image_block.render()
+                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
+                fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
-            with gr.Column(scale=1):
+
+            with gr.Column(scale=0.8):
                 sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
-                crop_size_slider = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)
-                crop_btn = gr.Button('Crop the image', variant='primary', interactive=True)
-                fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
+                crop_size.render()
+                crop_btn = gr.Button('Crop it', variant='primary', interactive=True)
+                fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
 
-            with gr.Column(scale=1):
+            with gr.Column(scale=0.8):
                 input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to SyncDreamer", height=256, interactive=False)
-                elevation = gr.Slider(-10, 40, 30, step=5, label='Elevation angle', interactive=True)
-                cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
-                sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
-                batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
-                seed = gr.Number(6033, label='Random seed', interactive=True)
-                run_btn = gr.Button('Run Generation', variant='primary', interactive=True)
-                fig1 = gr.Image(value=Image.open('assets/elevation.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)
+                elevation.render()
+                with gr.Accordion('Advanced options', open=False):
+                    cfg_scale = gr.Slider(1.0, 5.0, 2.0, step=0.1, label='Classifier free guidance', interactive=True)
+                    sample_num = gr.Slider(1, 2, 1, step=1, label='Sample num', interactive=True, info='How many instance (16 images per instance)')
+                    batch_view_num = gr.Slider(1, 16, 16, step=1, label='Batch num', interactive=True)
+                    seed = gr.Number(6033, label='Random seed', interactive=True)
+                run_btn = gr.Button('Run generation', variant='primary', interactive=True)
+
 
         output_block = gr.Image(type='pil', image_mode='RGB', label="Outputs of SyncDreamer", height=256, interactive=False)
 
@@ -217,9 +230,9 @@ def run_demo():
         image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block], outputs=[sam_block], queue=False)\
                    .success(fn=partial(update_guide, _USER_GUIDE1), outputs=[guide_text], queue=False)
 
-        crop_size_slider.change(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
+        crop_size.change(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
                  .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
-        crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size_slider], outputs=[input_block], queue=False)\
+        crop_btn.click(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=False)\
                 .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
 
         run_btn.click(partial(generate, model), inputs=[batch_view_num, sample_num, cfg_scale, seed, input_block, elevation], outputs=[output_block], queue=False)\
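The patched generate() keeps the same image preparation as before, now guarded by the deployed flag: the RGBA input is composited onto a white background using its alpha channel and then rescaled from [0, 1] to [-1, 1] before being batched for the model. The snippet below restates only that preprocessing step in isolation; it is a simplified sketch, and the helper name preprocess_rgba is illustrative rather than part of app.py.

import numpy as np
from PIL import Image

def preprocess_rgba(image):
    # Composite the RGBA input onto a white background using its alpha channel,
    # then rescale RGB from [0, 1] to [-1, 1], mirroring the prepare-data block in generate().
    rgba = np.asarray(image).astype(np.float32) / 255.0   # H x W x 4, values in [0, 1]
    alpha = rgba[:, :, 3:]                                 # keep the last axis for broadcasting
    rgb = alpha * rgba[:, :, :3] + (1.0 - alpha)           # transparent pixels become white
    return rgb * 2.0 - 1.0                                 # value range expected by the model

# A fully transparent 256x256 image maps to an all-white (value 1.0) input.
blank = Image.new('RGBA', (256, 256), (0, 0, 0, 0))
out = preprocess_rgba(blank)
assert out.shape == (256, 256, 3) and np.allclose(out, 1.0)

The same compositing convention is reused in sam_predict when it rebuilds the white-background RGBA output from the SAM mask.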