hysts HF Staff commited on
Commit
a2ee6bd
·
1 Parent(s): cfde09c
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -76,17 +76,17 @@ def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float,
76
 
77
 
78
  @torch.inference_mode()
79
- def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
80
- psi0: float, psi1: float,
81
- randomize_noise: bool, model: nn.Module,
82
- device: torch.device) -> np.ndarray:
83
  seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
84
  seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
85
 
86
  z0 = generate_z(model.style_dim, seed0, device)
87
  if num_intermediate == -1:
88
  out = generate_image(model, z0, psi0, randomize_noise)
89
- return out
90
 
91
  z1 = generate_z(model.style_dim, seed1, device)
92
  vec = z1 - z0
@@ -98,8 +98,8 @@ def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
98
  for z, psi in zip(zs, psis):
99
  out = generate_image(model, z, psi, randomize_noise)
100
  res.append(out)
101
- res = np.hstack(res)
102
- return res
103
 
104
 
105
  def main():
@@ -129,7 +129,7 @@ def main():
129
  gr.inputs.Number(default=29703, label='Seed 1'),
130
  gr.inputs.Number(default=55376, label='Seed 2'),
131
  gr.inputs.Slider(-1,
132
- 11,
133
  step=1,
134
  default=3,
135
  label='Number of Intermediate Frames'),
@@ -139,7 +139,11 @@ def main():
139
  0, 2, step=0.05, default=0.7, label='Truncation psi 2'),
140
  gr.inputs.Checkbox(default=False, label='Randomize Noise'),
141
  ],
142
- gr.outputs.Image(type='numpy', label='Output'),
 
 
 
 
143
  examples=examples,
144
  title=TITLE,
145
  description=DESCRIPTION,
 
76
 
77
 
78
  @torch.inference_mode()
79
+ def generate_interpolated_images(
80
+ seed0: int, seed1: int, num_intermediate: int, psi0: float,
81
+ psi1: float, randomize_noise: bool, model: nn.Module,
82
+ device: torch.device) -> tuple[list[np.ndarray], np.ndarray]:
83
  seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
84
  seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
85
 
86
  z0 = generate_z(model.style_dim, seed0, device)
87
  if num_intermediate == -1:
88
  out = generate_image(model, z0, psi0, randomize_noise)
89
+ return [out], None
90
 
91
  z1 = generate_z(model.style_dim, seed1, device)
92
  vec = z1 - z0
 
98
  for z, psi in zip(zs, psis):
99
  out = generate_image(model, z, psi, randomize_noise)
100
  res.append(out)
101
+ concatenated = np.hstack(res)
102
+ return res, concatenated
103
 
104
 
105
  def main():
 
129
  gr.inputs.Number(default=29703, label='Seed 1'),
130
  gr.inputs.Number(default=55376, label='Seed 2'),
131
  gr.inputs.Slider(-1,
132
+ 21,
133
  step=1,
134
  default=3,
135
  label='Number of Intermediate Frames'),
 
139
  0, 2, step=0.05, default=0.7, label='Truncation psi 2'),
140
  gr.inputs.Checkbox(default=False, label='Randomize Noise'),
141
  ],
142
+ [
143
+ gr.outputs.Carousel(gr.outputs.Image(type='numpy'),
144
+ label='Output Images'),
145
+ gr.outputs.Image(type='numpy', label='Concatenated'),
146
+ ],
147
  examples=examples,
148
  title=TITLE,
149
  description=DESCRIPTION,