flamehaze1115 committed
Commit e8ee7ff • 1 Parent(s): 28dbcfe

Upload 33 files
README.md CHANGED
@@ -3,7 +3,7 @@ title: Wonder3D
 emoji: 🚀
 colorFrom: indigo
 colorTo: pink
-sdk: gradio
+sdk: docker
 sdk_version: 3.43.2
 app_file: app.py
 pinned: false
app.py CHANGED
@@ -15,6 +15,7 @@ from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
 from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
 from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
 from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
+from einops import rearrange
 
 @dataclass
 class TestConfig:
@@ -53,6 +54,13 @@ iret = [
     for x in sorted(os.listdir(iret_base))
 ]
 
+def save_image(tensor):
+    ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+    # pdb.set_trace()
+    im = Image.fromarray(ndarr)
+    return ndarr
+
+weight_dtype = torch.float16
 
 class SAMAPI:
     predictor = None
@@ -62,7 +70,7 @@ class SAMAPI:
     def get_instance(sam_checkpoint=None):
         if SAMAPI.predictor is None:
             if sam_checkpoint is None:
-                sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
+                sam_checkpoint = "sam_pt/sam_vit_h_4b8939.pth"
             if not os.path.exists(sam_checkpoint):
                 os.makedirs('tmp', exist_ok=True)
                 urllib.request.urlretrieve(
@@ -164,6 +172,8 @@ def segment_6imgs(imgs):
     return Image.fromarray(result)
 
 def pack_6imgs(imgs):
+    import pdb
+    # pdb.set_trace()
     result = numpy.concatenate([
         numpy.concatenate([imgs[0], imgs[1]], axis=1),
         numpy.concatenate([imgs[2], imgs[3]], axis=1),
@@ -221,15 +231,15 @@ def check_dependencies():
 
 
 @st.cache_resource
-def load_wonder3d_pipeline(cfg):
+def load_wonder3d_pipeline():
     # Load scheduler, tokenizer and models.
     # noise_scheduler = DDPMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
     image_encoder = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision)
     feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
     vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
     unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
+    unet.enable_xformers_memory_efficient_attention()
 
-    weight_dtype = torch.float16
     # Move text_encode and vae to gpu and cast to weight_dtype
     image_encoder.to(dtype=weight_dtype)
     vae.to(dtype=weight_dtype)
@@ -246,6 +256,53 @@ def load_wonder3d_pipeline(cfg):
     sys.main_lock = threading.Lock()
     return pipeline
 
+from mvdiffusion.data.single_image_dataset import SingleImageDataset
+def prepare_data(single_image):
+    dataset = SingleImageDataset(
+        root_dir = None,
+        num_views = 6,
+        img_wh=[256, 256],
+        bg_color='white',
+        crop_size=crop_size,
+        single_image=single_image
+    )
+    return dataset[0]
+
+
+def run_pipeline(pipeline, batch, guidance_scale, seed):
+
+    pipeline.set_progress_bar_config(disable=True)
+
+    generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed)
+
+    # repeat (2B, Nv, 3, H, W)
+    imgs_in = torch.cat([batch['imgs_in']]*2, dim=0).to(weight_dtype)
+
+    # (2B, Nv, Nce)
+    camera_embeddings = torch.cat([batch['camera_embeddings']]*2, dim=0).to(weight_dtype)
+
+    task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0).to(weight_dtype)
+
+    camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1).to(weight_dtype)
+
+    # (B*Nv, 3, H, W)
+    imgs_in = rearrange(imgs_in, "Nv C H W -> (Nv) C H W")
+    # (B*Nv, Nce)
+    # camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")
+
+    out = pipeline(
+        imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
+        output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
+    ).images
+
+    bsz = out.shape[0] // 2
+    normals_pred = out[:bsz]
+    images_pred = out[bsz:]
+
+    normals_pred = [save_image(normals_pred[i]) for i in range(bsz)]
+    images_pred = [save_image(images_pred[i]) for i in range(bsz)]
+
+    return normals_pred, images_pred
 
 from utils.misc import load_config
 from omegaconf import OmegaConf
@@ -257,26 +314,33 @@ schema = OmegaConf.structured(TestConfig)
 cfg = OmegaConf.merge(schema, cfg)
 
 check_dependencies()
-pipeline = load_wonder3d_pipeline(cfg)
+pipeline = load_wonder3d_pipeline()
 SAMAPI.get_instance()
 torch.set_grad_enabled(False)
 
-st.title("Wonder3D Demo")
+st.title("Wonder3D: Single Image to 3D using Cross-Domain Diffusion")
 # st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
-prog = st.progress(0.0, "Idle")
+
 pic = st.file_uploader("Upload an Image", key='imageinput', type=['png', 'jpg', 'webp'])
 left, right = st.columns(2)
+# with left:
+#     rem_input_bg = st.checkbox("Remove Input Background")
+# with right:
+#     rem_output_bg = st.checkbox("Remove Output Background")
+with left:
+    num_inference_steps = st.slider("Number of Inference Steps", 15, 100, 50)
+    # st.caption("Diffusion Steps. For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
+with right:
+    cfg_scale = st.slider("Classifier Free Guidance Scale", 1.0, 10.0, 3.0)
 with left:
-    rem_input_bg = st.checkbox("Remove Input Background")
+    seed = int(st.text_input("Seed", "42"))
 with right:
-    rem_output_bg = st.checkbox("Remove Output Background")
-num_inference_steps = st.slider("Number of Inference Steps", 15, 100, 75)
-st.caption("Diffusion Steps. For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
-cfg_scale = st.slider("Classifier Free Guidance Scale", 1.0, 10.0, 4.0)
-seed = st.text_input("Seed", "42")
-submit = False
-if st.button("Submit"):
-    submit = True
+    crop_size = int(st.text_input("crop_size", "192"))
+# submit = False
+# if st.button("Submit"):
+#     submit = True
+submit = True
+prog = st.progress(0.0, "Idle")
 results_container = st.container()
 sample_got = image_examples(iret, 4, 'rimageinput')
 if sample_got:
@@ -284,48 +348,47 @@ if sample_got:
 with results_container:
     if sample_got or submit:
         prog.progress(0.03, "Waiting in Queue...")
-        with sys.main_lock:
-            seed = int(seed)
-            torch.manual_seed(seed)
-            img = Image.open(pic)
-            if max(img.size) > 1280:
-                w, h = img.size
-                w = round(1280 / max(img.size) * w)
-                h = round(1280 / max(img.size) * h)
-                img = img.resize((w, h))
-            left, right = st.columns(2)
-            with left:
-                st.image(img)
-                st.caption("Input Image")
-            prog.progress(0.1, "Preparing Inputs")
-            if rem_input_bg:
-                with right:
-                    img = segment_img(img)
-                    st.image(img)
-                    st.caption("Input (Background Removed)")
-            img = expand2square(img, (127, 127, 127, 0))
-            pipeline.set_progress_bar_config(disable=True)
-            result = pipeline(
-                img,
-                num_inference_steps=num_inference_steps,
-                guidance_scale=cfg_scale,
-                generator=torch.Generator(pipeline.device).manual_seed(seed),
-                callback=lambda i, t, latents: prog.progress(0.1 + 0.8 * i / num_inference_steps, "Diffusion Step %d" % i)
-            ).images
-            bsz = result.shape[0] // 2
-            normals_pred = result[:bsz]
-            images_pred = result[bsz:]
-            prog.progress(0.9, "Post Processing")
-            left, right = st.columns(2)
-            with left:
-                st.image(pack_6imgs(normals_pred))
-                st.image(pack_6imgs(images_pred))
-                st.caption("Result")
-            if rem_output_bg:
-                normals_pred = segment_6imgs(normals_pred)
-                images_pred = segment_6imgs(images_pred)
-                with right:
-                    st.image(normals_pred)
-                    st.image(images_pred)
-                    st.caption("Result (Background Removed)")
-            prog.progress(1.0, "Idle")
+
+        seed = int(seed)
+        torch.manual_seed(seed)
+        img = Image.open(pic)
+
+        data = prepare_data(img)
+
+        if max(img.size) > 1280:
+            w, h = img.size
+            w = round(1280 / max(img.size) * w)
+            h = round(1280 / max(img.size) * h)
+            img = img.resize((w, h))
+        left, right = st.columns(2)
+        with left:
+            st.caption("Input Image")
+            st.image(img)
+        prog.progress(0.1, "Preparing Inputs")
+
+        with right:
+            img = segment_img(img)
+            st.caption("Input (Background Removed)")
+            st.image(img)
+
+        img = expand2square(img, (127, 127, 127, 0))
+        # pipeline.set_progress_bar_config(disable=True)
+        prog.progress(0.3, "Run cross-domain diffusion model")
+        normals_pred, images_pred = run_pipeline(pipeline, data, cfg_scale, seed)
+        prog.progress(0.9, "finishing")
+        left, right = st.columns(2)
+        with left:
+            st.caption("Generated Normals")
+            st.image(pack_6imgs(normals_pred))
+
+        with right:
+            st.caption("Generated Color Images")
+            st.image(pack_6imgs(images_pred))
+        # if rem_output_bg:
+        #     normals_pred = segment_6imgs(normals_pred)
+        #     images_pred = segment_6imgs(images_pred)
+        # with right:
+        #     st.image(normals_pred)
+        #     st.image(images_pred)
+        #     st.caption("Result (Background Removed)")
+        prog.progress(1.0, "Idle")
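For orientation, the sketch below strings the newly added helpers together outside of the Streamlit widgets. It is a minimal illustration, assuming app.py's module-level globals (cfg, crop_size, weight_dtype) are already initialised as in the file above; the input filename is made up.

```python
from PIL import Image
import torch

torch.set_grad_enabled(False)                 # app.py disables autograd globally

pipeline = load_wonder3d_pipeline()           # cached via @st.cache_resource
img = Image.open("input.png")                 # hypothetical input path
batch = prepare_data(img)                     # SingleImageDataset(..., single_image=img)[0]

# returns two lists of 6 uint8 numpy arrays: normal maps and color views
normals_pred, images_pred = run_pipeline(pipeline, batch, guidance_scale=3.0, seed=42)

normal_grid = pack_6imgs(normals_pred)        # tile the 6 views into one array for display
color_grid = pack_6imgs(images_pred)
```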
 
mvdiffusion/data/__pycache__/normal_utils.cpython-39.pyc ADDED
Binary file (1.52 kB).
 
mvdiffusion/data/__pycache__/single_image_dataset.cpython-39.pyc ADDED
Binary file (8.06 kB).
 
mvdiffusion/data/single_image_dataset.py CHANGED
@@ -84,6 +84,7 @@ class SingleImageDataset(Dataset):
         img_wh: Tuple[int, int],
         bg_color: str,
         crop_size: int = 224,
+        single_image: Optional[PIL.Image.Image] = None,
         num_validation_samples: Optional[int] = None,
         filepaths: Optional[list] = None,
         cond_type: Optional[str] = None
@@ -92,7 +93,7 @@
         If you pass in a root directory it will be searched for images
         ending in ext (ext can be a list)
         """
-        self.root_dir = Path(root_dir)
+        # self.root_dir = Path(root_dir)
         self.num_views = num_views
         self.img_wh = img_wh
         self.crop_size = crop_size
@@ -110,32 +111,37 @@
 
         self.fix_cam_poses = self.load_fixed_poses()  # world2cam matrix
 
-        if filepaths is None:
-            # Get a list of all files in the directory
-            file_list = os.listdir(self.root_dir)
-        else:
-            file_list = filepaths
-
-        if self.cond_type == None:
-            # Filter the files that end with .png or .jpg
-            self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
-            self.cond_dirs = None
-        else:
-            self.file_list = []
-            self.cond_dirs = []
-            for scene in file_list:
-                self.file_list.append(os.path.join(scene, f"{scene}.png"))
-                if self.cond_type == 'normals':
-                    self.cond_dirs.append(os.path.join(self.root_dir, scene, 'outs'))
-                else:
-                    self.cond_dirs.append(os.path.join(self.root_dir, scene))
+        # if filepaths is None:
+        #     # Get a list of all files in the directory
+        #     file_list = os.listdir(self.root_dir)
+        # else:
+        #     file_list = filepaths
+
+        # if self.cond_type == None:
+        #     # Filter the files that end with .png or .jpg
+        #     self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
+        #     self.cond_dirs = None
+        # else:
+        #     self.file_list = []
+        #     self.cond_dirs = []
+        #     for scene in file_list:
+        #         self.file_list.append(os.path.join(scene, f"{scene}.png"))
+        #         if self.cond_type == 'normals':
+        #             self.cond_dirs.append(os.path.join(self.root_dir, scene, 'outs'))
+        #         else:
+        #             self.cond_dirs.append(os.path.join(self.root_dir, scene))
 
         # load all images
         self.all_images = []
         self.all_alphas = []
         bg_color = self.get_bg_color()
-        for file in self.file_list:
-            image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
+        # for file in self.file_list:
+        #     image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
+        #     self.all_images.append(image)
+        #     self.all_alphas.append(alpha)
+
+        if single_image is not None:
+            image, alpha = self.load_image(None, bg_color, return_type='pt', Image=single_image)
             self.all_images.append(image)
             self.all_alphas.append(alpha)
 
@@ -196,9 +202,12 @@
         return bg_color
 
 
-    def load_image(self, img_path, bg_color, return_type='np'):
+    def load_image(self, img_path, bg_color, return_type='np', Image=None):
         # pil always returns uint8
-        image_input = Image.open(img_path)
+        if Image is None:
+            image_input = Image.open(img_path)
+        else:
+            image_input = Image
         image_size = self.img_wh[0]
 
         if self.crop_size!=-1:
@@ -210,11 +219,11 @@
             h, w = ref_img_.height, ref_img_.width
             scale = self.crop_size / max(h, w)
             h_, w_ = int(scale * h), int(scale * w)
-            ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
+            ref_img_ = ref_img_.resize((w_, h_))
            image_input = add_margin(ref_img_, size=image_size)
         else:
             image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
-            image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
+            image_input = image_input.resize((image_size, image_size))
 
         # img = scale_and_place_object(img, self.scale_ratio)
         img = np.array(image_input)
@@ -256,7 +265,7 @@
 
         image = self.all_images[index%len(self.all_images)]
         alpha = self.all_alphas[index%len(self.all_images)]
-        filename = self.file_list[index%len(self.all_images)].replace(".png", "")
+        # filename = self.file_list[index%len(self.all_images)].replace(".png", "")
 
         if self.cond_type != None:
             conds = self.load_conds(self.cond_dirs[index%len(self.all_images)])
@@ -310,7 +319,7 @@
             'camera_embeddings': camera_embeddings,
             'normal_task_embeddings': normal_task_embeddings,
             'color_task_embeddings': color_task_embeddings,
-            'filename': filename,
+            # 'filename': filename,
         }
 
         if conds is not None:
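The constructor change above is the hook that app.py's prepare_data relies on: with single_image set, the dataset skips the directory scan and wraps the given PIL image directly. A minimal sketch of that entry point, assuming an illustrative input path (the batch keys listed are the ones app.py consumes):

```python
from PIL import Image
from mvdiffusion.data.single_image_dataset import SingleImageDataset

img = Image.open("input.png")          # illustrative path
dataset = SingleImageDataset(
    root_dir=None,                     # no directory scan in single-image mode
    num_views=6,
    img_wh=[256, 256],
    bg_color='white',
    crop_size=192,
    single_image=img,
)
batch = dataset[0]
for key in ('imgs_in', 'camera_embeddings', 'normal_task_embeddings', 'color_task_embeddings'):
    print(key, tuple(batch[key].shape))
```
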
mvdiffusion/models/__pycache__/transformer_mv2d.cpython-39.pyc ADDED
Binary file (22.7 kB).
 
mvdiffusion/models/__pycache__/unet_mv2d_blocks.cpython-39.pyc ADDED
Binary file (14.1 kB).
 
mvdiffusion/models/__pycache__/unet_mv2d_condition.cpython-39.pyc ADDED
Binary file (44.5 kB).
 
mvdiffusion/pipelines/__pycache__/pipeline_mvdiffusion_image.cpython-39.pyc ADDED
Binary file (17.6 kB).
 
mvdiffusion/pipelines/pipeline_mvdiffusion_image.py CHANGED
@@ -155,7 +155,7 @@ class MVDiffusionImagePipeline(DiffusionPipeline):
             # to avoid doing two forward passes
             image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
 
-        image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
+        image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device).to(dtype)
         image_pt = image_pt * 2.0 - 1.0
         image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
         # Note: repeat differently from official pipelines
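This one-line change matters because the VAE is cast to fp16 in load_wonder3d_pipeline while TF.to_tensor always yields float32; running fp32 inputs through fp16 weights typically fails with a dtype mismatch. A toy illustration of the pattern, not the pipeline code itself:

```python
import torch

vae_dtype = torch.float16              # the VAE weights are cast to weight_dtype at load time
image_pt = torch.rand(6, 3, 256, 256)  # TF.to_tensor produces float32
image_pt = image_pt.to(vae_dtype)      # cast the input to match the fp16 VAE before encoding
assert image_pt.dtype == vae_dtype
```
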
requirements.txt CHANGED
@@ -1,4 +1,4 @@
---extra-index-url https://download.pytorch.org/whl/cu113
+--extra-index-url https://download.pytorch.org/whl/cu117
 torch==1.13.1
 torchvision
 diffusers[torch]==0.19.3
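The wheel index moves from cu113 to cu117 while torch stays pinned at 1.13.1. A quick post-install sanity check that the expected CUDA build was picked up (expected values noted in the comments):

```python
import torch

print(torch.__version__)          # expected to start with "1.13.1"
print(torch.version.cuda)         # expected "11.7" with the cu117 index
print(torch.cuda.is_available())  # True when a compatible GPU and driver are present
```
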
utils/__pycache__/misc.cpython-39.pyc ADDED
Binary file (1.62 kB).
 
utils/misc.py CHANGED
@@ -4,13 +4,13 @@ from packaging import version
 
 
 # ============ Register OmegaConf Recolvers ============= #
-OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
-OmegaConf.register_new_resolver('add', lambda a, b: a + b)
-OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
-OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
-OmegaConf.register_new_resolver('div', lambda a, b: a / b)
-OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
-OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
+# OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
+# OmegaConf.register_new_resolver('add', lambda a, b: a + b)
+# OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
+# OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
+# OmegaConf.register_new_resolver('div', lambda a, b: a / b)
+# OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
+# OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
 # ======================================================= #
 
 
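
Commenting the resolvers out sidesteps the error OmegaConf raises when the same resolver name is registered twice, which can happen when the module is imported again (for example on a Streamlit rerun). If the resolvers are still needed, a guarded registration is one alternative; this is a sketch, not part of the commit, and it assumes the duplicate registration surfaces as a ValueError:

```python
import os
from omegaconf import OmegaConf

def _safe_register(name, resolver):
    """Register a resolver once; ignore the error if it already exists."""
    try:
        OmegaConf.register_new_resolver(name, resolver)
    except ValueError:
        pass  # already registered by an earlier import

_safe_register('add', lambda a, b: a + b)
_safe_register('basename', lambda p: os.path.basename(p))
```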