hansyan commited on
Commit
55b50fa
β€’
1 Parent(s): c5df60d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -14,6 +14,12 @@ from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orienta
14
  from src.scheduler_perflow import PeRFlowScheduler
15
  from diffusers import StableDiffusionPipeline, UNet2DConditionModel
16
 
 
 
 
 
 
 
17
  def merge_delta_weights_into_unet(pipe, delta_weights, org_alpha = 1.0):
18
  unet_weights = pipe.unet.state_dict()
19
  for key in delta_weights.keys():
@@ -72,7 +78,7 @@ def generate(text, seed):
72
  return image
73
 
74
  setup_seed(int(seed))
75
- prompt_prefix = "high quality, best quality, masterpiece; "
76
  neg_prompt = "EasyNegative, drawn by bad-artist, sketch by bad-artist-anime, (bad_prompt:0.8), (artist name, signature, watermark:1.4), (ugly:1.2), (worst quality, poor details:1.4), bad-hands-5, badhandv4, blurry"
77
  text = prompt_prefix + text
78
  samples = pipe_t2i(
@@ -86,18 +92,20 @@ def generate(text, seed):
86
  guidance_scale = 7.5,
87
  output_type = 'pt',
88
  ).images
89
- samples = torch.nn.functional.interpolate(samples, size=768, mode='bilinear')
90
  samples = samples.squeeze(0).permute(1, 2, 0).cpu().numpy()*255.
91
  samples = samples.astype(np.uint8)
92
  samples = Image.fromarray(samples[:, :, :3])
 
93
 
94
- image = remove_background(samples, rembg_session)
95
- image = resize_foreground(image, 0.85)
96
- image = fill_background(image)
97
- return image
98
 
99
  @spaces.GPU
100
  def render(image, mc_resolution=256, formats=["obj"]):
 
 
 
 
 
 
101
  scene_codes = model(image, device=device)
102
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
103
  mesh = to_gradio_3d_orientation(mesh)
 
14
  from src.scheduler_perflow import PeRFlowScheduler
15
  from diffusers import StableDiffusionPipeline, UNet2DConditionModel
16
 
17
+ def fill_background(img):
18
+ img = np.array(img).astype(np.float32) / 255.0
19
+ img = img[:, :, :3] * img[:, :, 3:4] + (1 - img[:, :, 3:4]) * 0.5
20
+ img = Image.fromarray((img * 255.0).astype(np.uint8))
21
+ return img
22
+
23
  def merge_delta_weights_into_unet(pipe, delta_weights, org_alpha = 1.0):
24
  unet_weights = pipe.unet.state_dict()
25
  for key in delta_weights.keys():
 
78
  return image
79
 
80
  setup_seed(int(seed))
81
+ prompt_prefix = "high quality, best quality, highly detailed, masterpiece; "
82
  neg_prompt = "EasyNegative, drawn by bad-artist, sketch by bad-artist-anime, (bad_prompt:0.8), (artist name, signature, watermark:1.4), (ugly:1.2), (worst quality, poor details:1.4), bad-hands-5, badhandv4, blurry"
83
  text = prompt_prefix + text
84
  samples = pipe_t2i(
 
92
  guidance_scale = 7.5,
93
  output_type = 'pt',
94
  ).images
 
95
  samples = samples.squeeze(0).permute(1, 2, 0).cpu().numpy()*255.
96
  samples = samples.astype(np.uint8)
97
  samples = Image.fromarray(samples[:, :, :3])
98
+ return samples
99
 
 
 
 
 
100
 
101
  @spaces.GPU
102
  def render(image, mc_resolution=256, formats=["obj"]):
103
+ image = Image.fromarray(image)
104
+ image = image.resize((768, 768))
105
+ image = remove_background(image, rembg_session)
106
+ image = resize_foreground(image, 0.85)
107
+ image = fill_background(image)
108
+
109
  scene_codes = model(image, device=device)
110
  mesh = model.extract_mesh(scene_codes, resolution=mc_resolution)[0]
111
  mesh = to_gradio_3d_orientation(mesh)