wyysf commited on
Commit
f5810c5
·
verified ·
1 Parent(s): 891b5b9

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +79 -24
gradio_app.py CHANGED
@@ -14,6 +14,7 @@ from collections import OrderedDict
14
  import trimesh
15
  import gradio as gr
16
  from typing import Any
 
17
 
18
  proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
  sys.path.append(os.path.join(proj_dir))
@@ -58,6 +59,8 @@ If you have any questions, feel free to open a discussion or contact us at <b>we
58
  """
59
  from apps.third_party.CRM.pipelines import TwoStagePipeline
60
  from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline
 
 
61
 
62
  import re
63
  import os
@@ -88,22 +91,25 @@ chmod(f"{parent_dir}/apps/third_party/InstantMeshes", "777")
88
 
89
  model = None
90
  cached_dir = None
91
- stage1_config = OmegaConf.load(f"{parent_dir}/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
92
- stage1_sampler_config = stage1_config.sampler
93
- stage1_model_config = stage1_config.models
94
- stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
95
- stage1_model_config.config = f"{parent_dir}/apps/third_party/CRM/" + stage1_model_config.config
96
  crm_pipeline = None
97
 
98
  sys.path.append(f"apps/third_party/LGM")
99
  imgaedream_pipeline = None
100
 
 
 
 
101
  @spaces.GPU
102
  def gen_mvimg(
103
  mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color
104
  ):
105
  if seed == 0:
106
  seed = np.random.randint(1, 65535)
 
 
107
 
108
  if mvimg_model == "CRM":
109
  global crm_pipeline
@@ -118,7 +124,7 @@ def gen_mvimg(
118
  return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0]
119
 
120
  elif mvimg_model == "ImageDream":
121
- global imagedream_pipeline, generator
122
  background = Image.new("RGBA", image.size, backgroud_color)
123
  image = Image.alpha_composite(background, image)
124
  image = np.array(image).astype(np.float32) / 255.0
@@ -130,9 +136,36 @@ def gen_mvimg(
130
  guidance_scale=guidance_scale,
131
  num_inference_steps=step,
132
  elevation=elevation,
 
133
  )
134
  return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0]
135
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  @spaces.GPU
138
  def image2mesh(view_front: np.ndarray,
@@ -209,24 +242,46 @@ if __name__=="__main__":
209
  "Auto Remove Background": "Auto Remove Background",
210
  "Original Image": "Original Image",
211
  })
212
- mvimg_model_config_list = ["CRM", "ImageDream"]
213
- crm_pipeline = TwoStagePipeline(
214
- stage1_model_config,
215
- stage1_sampler_config,
216
- device=device,
217
- dtype=torch.float16
218
- )
219
- imagedream_pipeline = MVDreamPipeline.from_pretrained(
220
- "ashawkey/imagedream-ipmv-diffusers", # remote weights
221
- torch_dtype=torch.float16,
222
- trust_remote_code=True,
223
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  # for 3D latent set diffusion
226
- ckpt_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt"
227
- config_path = "./ckpts/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml"
228
- # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model.ckpt", repo_type="model")
229
- # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
230
  # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model-300k.ckpt", repo_type="model")
231
  # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
232
  scheluder_dict = OrderedDict({
@@ -266,7 +321,7 @@ if __name__=="__main__":
266
  gr.Markdown('''Try a different <b>seed and MV Model</b> for better results. Good Luck :)''')
267
  with gr.Row():
268
  seed = gr.Number(0, label='Seed', show_label=True)
269
- mvimg_model = gr.Dropdown(value="CRM", label="MV Image Model", choices=list(mvimg_model_config_list))
270
  more = gr.CheckboxGroup(["Remesh", "Symmetry(TBD)"], label="More", show_label=False)
271
  with gr.Row():
272
  # input prompt
 
14
  import trimesh
15
  import gradio as gr
16
  from typing import Any
17
+ from einops import rearrange
18
 
19
  proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20
  sys.path.append(os.path.join(proj_dir))
 
59
  """
60
  from apps.third_party.CRM.pipelines import TwoStagePipeline
61
  from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline
62
+ from apps.third_party.Era3D.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
63
+ from apps.third_party.Era3D.data.single_image_dataset import SingleImageDataset
64
 
65
  import re
66
  import os
 
91
 
92
  model = None
93
  cached_dir = None
94
+ generator = None
95
+
96
+ sys.path.append(f"apps/third_party/CRM")
 
 
97
  crm_pipeline = None
98
 
99
  sys.path.append(f"apps/third_party/LGM")
100
  imgaedream_pipeline = None
101
 
102
+ sys.path.append(f"apps/third_party/Era3D")
103
+ era3d_pipeline = None
104
+
105
  @spaces.GPU
106
  def gen_mvimg(
107
  mvimg_model, image, seed, guidance_scale, step, text, neg_text, elevation, backgroud_color
108
  ):
109
  if seed == 0:
110
  seed = np.random.randint(1, 65535)
111
+ global generator
112
+ generator.manual_seed(seed)
113
 
114
  if mvimg_model == "CRM":
115
  global crm_pipeline
 
124
  return mv_imgs[5], mv_imgs[3], mv_imgs[2], mv_imgs[0]
125
 
126
  elif mvimg_model == "ImageDream":
127
+ global imagedream_pipeline
128
  background = Image.new("RGBA", image.size, backgroud_color)
129
  image = Image.alpha_composite(background, image)
130
  image = np.array(image).astype(np.float32) / 255.0
 
136
  guidance_scale=guidance_scale,
137
  num_inference_steps=step,
138
  elevation=elevation,
139
+ generator=generator,
140
  )
141
  return mv_imgs[1], mv_imgs[2], mv_imgs[3], mv_imgs[0]
142
+
143
+ elif mvimg_model == "Era3D":
144
+ global era3d_pipeline
145
+ crop_size = 420
146
+ batch = SingleImageDataset(root_dir='', num_views=6, img_wh=[512, 512], bg_color='white',
147
+ crop_size=crop_size, single_image=image, prompt_embeds_path='apps/third_party/Era3D/data/fixed_prompt_embeds_6view')[0]
148
+ imgs_in = torch.cat([batch['imgs_in']]*2, dim=0)
149
+ imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")# (B*Nv, 3, H, W)
150
+
151
+ normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
152
+ prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
153
+ prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
154
+
155
+ imgs_in = imgs_in.to(device=device, dtype=torch.float16)
156
+ prompt_embeddings = prompt_embeddings.to(device=device, dtype=torch.float16)
157
+
158
+ mv_imgs = era3d_pipeline(
159
+ imgs_in,
160
+ None,
161
+ prompt_embeds=prompt_embeddings,
162
+ generator=generator,
163
+ guidance_scale=guidance_scale,
164
+ num_inference_steps=step,
165
+ num_images_per_prompt=1,
166
+ **{'eta': 1.0}
167
+ ).images
168
+ return mv_imgs[6], mv_imgs[8], mv_imgs[9], mv_imgs[10]
169
 
170
  @spaces.GPU
171
  def image2mesh(view_front: np.ndarray,
 
242
  "Auto Remove Background": "Auto Remove Background",
243
  "Original Image": "Original Image",
244
  })
245
+ mvimg_model_config_list = [
246
+ "Era3D",
247
+ # "CRM",
248
+ # "ImageDream"
249
+ ]
250
+ if "Era3D" in mvimg_model_config_list:
251
+ # cfg = load_config("apps/third_party/Era3D/configs/test_unclip-512-6view.yaml")
252
+ # schema = OmegaConf.structured(TestConfig)
253
+ # cfg = OmegaConf.merge(schema, cfg)
254
+ era3d_pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
255
+ 'pengHTYX/MacLab-Era3D-512-6view',
256
+ torch_dtype=torch.float16
257
+ )
258
+ # enable xformers
259
+ era3d_pipeline.unet.enable_xformers_memory_efficient_attention()
260
+ era3d_pipeline.to(device)
261
+ elif "CRM" in mvimg_model_config_list:
262
+ stage1_config = OmegaConf.load(f"apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml").config
263
+ stage1_sampler_config = stage1_config.sampler
264
+ stage1_model_config = stage1_config.models
265
+ stage1_model_config.resume = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth", repo_type="model")
266
+ stage1_model_config.config = f"apps/third_party/CRM/" + stage1_model_config.config
267
+ crm_pipeline = TwoStagePipeline(
268
+ stage1_model_config,
269
+ stage1_sampler_config,
270
+ device=device,
271
+ dtype=torch.float16
272
+ )
273
+ elif "ImageDream" in mvimg_model_config_list:
274
+ imagedream_pipeline = MVDreamPipeline.from_pretrained(
275
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
276
+ torch_dtype=torch.float16,
277
+ trust_remote_code=True,
278
+ )
279
+
280
+ generator = torch.Generator(device)
281
 
282
  # for 3D latent set diffusion
283
+ ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/model.ckpt", repo_type="model")
284
+ config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6-aligned-vae/config.yaml", repo_type="model")
 
 
285
  # ckpt_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/model-300k.ckpt", repo_type="model")
286
  # config_path = hf_hub_download(repo_id="wyysf/CraftsMan", filename="image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6/config.yaml", repo_type="model")
287
  scheluder_dict = OrderedDict({
 
321
  gr.Markdown('''Try a different <b>seed and MV Model</b> for better results. Good Luck :)''')
322
  with gr.Row():
323
  seed = gr.Number(0, label='Seed', show_label=True)
324
+ mvimg_model = gr.Dropdown(value="Era3D", label="MV Image Model", choices=list(mvimg_model_config_list))
325
  more = gr.CheckboxGroup(["Remesh", "Symmetry(TBD)"], label="More", show_label=False)
326
  with gr.Row():
327
  # input prompt