flamehaze1115 committed
Commit 9938557
Parent: 0ed298c

Delete app.py

Files changed (1)
  1. app.py +0 -393
app.py DELETED
@@ -1,393 +0,0 @@
import os
import sys
import numpy
import torch
import rembg
import threading
import urllib.request
from PIL import Image
from typing import Dict, Optional, Tuple, List
from dataclasses import dataclass
import streamlit as st
import huggingface_hub
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
from einops import rearrange

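# TestConfig mirrors the fields of the Wonder3D inference YAML; it is used further below as an
# OmegaConf structured schema (OmegaConf.structured(TestConfig)) so the loaded config is
# type-checked when the two are merged.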
@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: str
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int

    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: List[float]
    validation_grid_nrow: int
    camera_embedding_lr_mult: float

    num_views: int
    camera_embedding_type: str

    pred_type: str  # joint, or ablation

    enable_xformers_memory_efficient_attention: bool

    cond_on_normals: bool
    cond_on_colors: bool

img_example_counter = 0
iret_base = 'example_images'
iret = [
    dict(rimageinput=os.path.join(iret_base, x), dispi=os.path.join(iret_base, x))
    for x in sorted(os.listdir(iret_base))
]

def save_image(tensor):
    # Convert a (3, H, W) float tensor in [0, 1] to an (H, W, 3) uint8 numpy array.
    # Despite the name, nothing is written to disk; the array is fed to st.image / numpy.concatenate.
    ndarr = tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr

weight_dtype = torch.float16

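# SAMAPI lazily downloads the ViT-H Segment Anything checkpoint on first use and caches a single
# SamPredictor via st.cache_resource, so the model is only loaded once per Streamlit process.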
class SAMAPI:
    predictor = None

    @staticmethod
    @st.cache_resource
    def get_instance(sam_checkpoint=None):
        if SAMAPI.predictor is None:
            if sam_checkpoint is None:
                sam_checkpoint = "./sam_pt/sam_vit_h_4b8939.pth"
            if not os.path.exists(sam_checkpoint):
                os.makedirs('sam_pt', exist_ok=True)
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    sam_checkpoint
                )
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            model_type = "default"

            from segment_anything import sam_model_registry, SamPredictor

            sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
            sam.to(device=device)

            predictor = SamPredictor(sam)
            SAMAPI.predictor = predictor
        return SAMAPI.predictor

    @staticmethod
    def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
        """
        Parameters
        ----------
        rgb : np.ndarray, (h, w, 3) uint8
        mask : np.ndarray, (h, w) bool, optional
        bbox : optional (x1, y1, x2, y2) box prompt

        Returns
        -------
        mask : np.ndarray, (h, w) bool
            The last of SAM's multimask outputs for the box prompt.
        """
        np = numpy
        predictor = SAMAPI.get_instance(sam_checkpoint)
        predictor.set_image(rgb)
        if mask is None and bbox is None:
            box_input = None
        else:
            # mask to bbox
            if bbox is None:
                y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), \
                                 np.nonzero(mask)[1].min(), np.nonzero(mask)[1].max()
            else:
                x1, y1, x2, y2 = bbox
            box_input = np.array([[x1, y1, x2, y2]])
        masks, scores, logits = predictor.predict(
            box=box_input,
            multimask_output=True,
            return_logits=False,
        )
        mask = masks[-1]
        return mask

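# image_examples renders the example gallery in an expander, ncols thumbnails per row, each with a
# '\+' button; when a button is clicked it returns that entry's `return_key` value (here the image
# path), otherwise False.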
def image_examples(samples, ncols, return_key=None, example_text="Examples"):
    global img_example_counter
    trigger = False
    with st.expander(example_text, True):
        # Ceil division so a trailing partial row of examples is still rendered.
        for i in range((len(samples) + ncols - 1) // ncols):
            cols = st.columns(ncols)
            for j in range(ncols):
                idx = i * ncols + j
                if idx >= len(samples):
                    continue
                entry = samples[idx]
                with cols[j]:
                    st.image(entry['dispi'])
                    img_example_counter += 1
                    with st.columns(5)[2]:
                        this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter)
                    trigger = trigger or this_trigger
                    if this_trigger:
                        trigger = entry[return_key]
    return trigger

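# segment_img: rembg produces a coarse alpha matte whose foreground is turned into a box prompt for
# SAM (see SAMAPI.segment_api); the refined mask is then used to paste the subject onto a transparent canvas.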
def segment_img(img: Image.Image):
    output = rembg.remove(img)
    mask = numpy.array(output)[:, :, 3] > 0
    sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
    segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
    segmented_img.paste(img, mask=Image.fromarray(sam_mask))
    return segmented_img

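# segment_6imgs removes the background from each of the six generated views and tiles them into a
# 3x2 grid (three rows of two views); pack_6imgs does the same tiling without background removal.
# Only pack_6imgs is used in the active code path below.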
def segment_6imgs(imgs):
    segmented_imgs = []
    for i, img in enumerate(imgs):
        output = rembg.remove(img)
        mask = numpy.array(output)[:, :, 3]
        mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
        data = numpy.array(img)[:, :, :3]
        data[mask == 0] = [255, 255, 255]
        segmented_imgs.append(data)
    result = numpy.concatenate([
        numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1),
        numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1),
        numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1)
    ])
    return Image.fromarray(result)

def pack_6imgs(imgs):
    result = numpy.concatenate([
        numpy.concatenate([imgs[0], imgs[1]], axis=1),
        numpy.concatenate([imgs[2], imgs[3]], axis=1),
        numpy.concatenate([imgs[4], imgs[5]], axis=1)
    ])
    return Image.fromarray(result)

def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

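# check_dependencies warns (on stderr) about missing or mismatched diffusers/transformers versions and
# about running PyTorch 1.x without xformers; it does not abort the app.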
@st.cache_data
def check_dependencies():
    reqs = []
    try:
        import diffusers
    except ImportError:
        import traceback
        traceback.print_exc()
        print("Error: `diffusers` not found.", file=sys.stderr)
        reqs.append("diffusers==0.20.2")
    else:
        if not diffusers.__version__.startswith("0.20"):
            print(
                f"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.",
                file=sys.stderr
            )
            print("Recommended version is `diffusers==0.20.2`.", file=sys.stderr)
    try:
        import transformers
    except ImportError:
        import traceback
        traceback.print_exc()
        print("Error: `transformers` not found.", file=sys.stderr)
        reqs.append("transformers==4.29.2")
    if torch.__version__ < '2.0':
        try:
            import xformers
        except ImportError:
            print("Warning: You are using PyTorch 1.x without a working `xformers` installation.", file=sys.stderr)
            print("You may see a significant memory overhead when running the model.", file=sys.stderr)
    if len(reqs):
        print(f"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.")

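# load_wonder3d_pipeline assembles the MVDiffusionImagePipeline from the pretrained CLIP image encoder,
# feature extractor, VAE and multi-view UNet declared in the config, casts them to fp16, and moves the
# pipeline to the GPU when one is available. st.cache_resource keeps a single instance per process.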
@st.cache_resource
def load_wonder3d_pipeline():
    # Load the image encoder, feature extractor, VAE and multi-view UNet from the pretrained checkpoints.
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision)
    feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
    vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
    unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
    unet.enable_xformers_memory_efficient_attention()

    # Cast the image encoder, VAE and UNet to weight_dtype (fp16).
    image_encoder.to(dtype=weight_dtype)
    vae.to(dtype=weight_dtype)
    unet.to(dtype=weight_dtype)

    pipeline = MVDiffusionImagePipeline(
        image_encoder=image_encoder, feature_extractor=feature_extractor, vae=vae, unet=unet, safety_checker=None,
        scheduler=DDIMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler"),
        **cfg.pipe_kwargs
    )

    if torch.cuda.is_available():
        pipeline.to('cuda:0')
    sys.main_lock = threading.Lock()
    return pipeline

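# prepare_data wraps the single input image in a SingleImageDataset (6 views, 256x256, white background)
# and returns its only element. Note that crop_size is a module-level global set by the Streamlit
# text input further below.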
from mvdiffusion.data.single_image_dataset import SingleImageDataset

def prepare_data(single_image):
    dataset = SingleImageDataset(
        root_dir=None,
        num_views=6,
        img_wh=[256, 256],
        bg_color='white',
        crop_size=crop_size,
        single_image=single_image
    )
    return dataset[0]

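# run_pipeline duplicates the conditioning batch along dim 0 and appends the normal/color task embeddings
# to the camera embeddings, so a single diffusion pass predicts both domains: the first half of the output
# batch holds the normal maps, the second half the color images. Outputs are converted to uint8 arrays
# via save_image.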
def run_pipeline(pipeline, batch, guidance_scale, seed):
    pipeline.set_progress_bar_config(disable=True)

    generator = torch.Generator(device=pipeline.unet.device).manual_seed(seed)

    # repeat (2B, Nv, 3, H, W)
    imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0).to(weight_dtype)

    # (2B, Nv, Nce)
    camera_embeddings = torch.cat([batch['camera_embeddings']] * 2, dim=0).to(weight_dtype)

    task_embeddings = torch.cat([batch['normal_task_embeddings'], batch['color_task_embeddings']], dim=0).to(weight_dtype)

    camera_embeddings = torch.cat([camera_embeddings, task_embeddings], dim=-1).to(weight_dtype)

    # (B*Nv, 3, H, W)
    imgs_in = rearrange(imgs_in, "Nv C H W -> (Nv) C H W")
    # (B*Nv, Nce)
    # camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

    out = pipeline(
        imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
        output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
    ).images

    bsz = out.shape[0] // 2
    normals_pred = out[:bsz]
    images_pred = out[bsz:]

    normals_pred = [save_image(normals_pred[i]) for i in range(bsz)]
    images_pred = [save_image(images_pred[i]) for i in range(bsz)]

    return normals_pred, images_pred

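# Module-level startup: load and validate the config, check dependencies, build the diffusion pipeline,
# warm up SAM, and disable autograd (inference only). Streamlit re-runs this file on every interaction,
# so the cached resources above keep these steps cheap after the first run.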
from utils.misc import load_config
from omegaconf import OmegaConf

# Parse the YAML config and validate it against the TestConfig schema.
cfg = load_config("./configs/mvdiffusion-joint-ortho-6views.yaml")
schema = OmegaConf.structured(TestConfig)
cfg = OmegaConf.merge(schema, cfg)

check_dependencies()
pipeline = load_wonder3d_pipeline()
SAMAPI.get_instance()
torch.set_grad_enabled(False)

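# Input widgets: file uploader, inference-step and CFG-scale sliders, seed and crop_size inputs, a
# progress bar, and the clickable example gallery. num_inference_steps is read from the slider but not
# forwarded to run_pipeline in this file, so the step count actually used comes from the pipeline
# defaults (or cfg.pipe_validation_kwargs).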
st.title("Wonder3D: Single Image to 3D using Cross-Domain Diffusion")
# st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")

pic = st.file_uploader("Upload an Image", key='imageinput', type=['png', 'jpg', 'webp', 'jpeg'])
left, right = st.columns(2)
# with left:
#     rem_input_bg = st.checkbox("Remove Input Background")
# with right:
#     rem_output_bg = st.checkbox("Remove Output Background")
with left:
    num_inference_steps = st.slider("Number of Inference Steps", 15, 100, 50)
    # st.caption("Diffusion Steps. For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
with right:
    cfg_scale = st.slider("Classifier Free Guidance Scale", 1.0, 10.0, 3.0)
with left:
    seed = int(st.text_input("Seed", "42"))
with right:
    crop_size = int(st.text_input("crop_size", "192"))
# submit = False
# if st.button("Submit"):
#     submit = True
submit = True
prog = st.progress(0.0, "Idle")
results_container = st.container()
sample_got = image_examples(iret, 5, 'rimageinput')
if sample_got:
    pic = sample_got

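# Main flow once an image is uploaded or an example is clicked: downscale large inputs to at most
# 1280 px, remove the background, pad to a square with a transparent gray border, run the cross-domain
# diffusion model, and display the six predicted normal maps and color images as 3x2 grids.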
with results_container:
    if sample_got or pic is not None:
        prog.progress(0.03, "Waiting in Queue...")

        seed = int(seed)
        torch.manual_seed(seed)
        img = Image.open(pic)

        if max(img.size) > 1280:
            w, h = img.size
            w = round(1280 / max(img.size) * w)
            h = round(1280 / max(img.size) * h)
            img = img.resize((w, h))
        left, right = st.columns(2)
        with left:
            st.caption("Input Image")
            st.image(img)
        prog.progress(0.1, "Preparing Inputs")

        with right:
            img = segment_img(img)
            st.caption("Input (Background Removed)")
            st.image(img)

        img = expand2square(img, (127, 127, 127, 0))
        # pipeline.set_progress_bar_config(disable=True)
        prog.progress(0.3, "Run cross-domain diffusion model")
        data = prepare_data(img)
        normals_pred, images_pred = run_pipeline(pipeline, data, cfg_scale, seed)
        prog.progress(0.9, "Finishing")
        left, right = st.columns(2)
        with left:
            st.caption("Generated Normals")
            st.image(pack_6imgs(normals_pred))

        with right:
            st.caption("Generated Color Images")
            st.image(pack_6imgs(images_pred))
        # if rem_output_bg:
        #     normals_pred = segment_6imgs(normals_pred)
        #     images_pred = segment_6imgs(images_pred)
        #     with right:
        #         st.image(normals_pred)
        #         st.image(images_pred)
        #         st.caption("Result (Background Removed)")
        prog.progress(1.0, "Idle")
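# To run this app locally (assuming the Wonder3D repo layout: the mvdiffusion/ and utils/ packages,
# ./configs/mvdiffusion-joint-ortho-6views.yaml, and an example_images/ directory), the usual Streamlit
# entry point would be:
#
#     streamlit run app.py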