ironjr committed on
Commit 3913e77
1 Parent(s): 96fa1ab

Update model.py

Files changed (1)
  1. model.py +41 -14
model.py CHANGED
@@ -295,8 +295,12 @@ class StreamMultiDiffusion(nn.Module):
     def reset_latent(self) -> None:
         # initialize x_t_latent (it can be any random tensor)
         b = (self.denoising_steps_num - 1) * self.frame_bff_size
-        self.x_t_latent_buffer = torch.zeros(
-            (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device)
+        if not hasattr(self, 'x_t_latent_buffer'):
+            self.register_buffer('x_t_latent_buffer', torch.zeros(
+                (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device))
+        else:
+            self.x_t_latent_buffer = torch.zeros(
+                (b, 4, self.latent_height, self.latent_width), dtype=self.dtype, device=self.device)
 
     def reset_state(self) -> None:
         # TODO Reset states for context switch between multiple users.
@@ -305,24 +309,35 @@ class StreamMultiDiffusion(nn.Module):
     def prepare(self) -> None:
         # make sub timesteps list based on the indices in the t_list list and the values in the timesteps list
         self.timesteps = self.scheduler.timesteps.to(self.device)
-        self.sub_timesteps = []
+        sub_timesteps = []
         for t in self.t_list:
-            self.sub_timesteps.append(self.timesteps[t])
-        sub_timesteps_tensor = torch.tensor(self.sub_timesteps, dtype=torch.long, device=self.device)
-        self.sub_timesteps_tensor = sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0)
+            sub_timesteps.append(self.timesteps[t])
+        sub_timesteps_tensor = torch.tensor(sub_timesteps, dtype=torch.long, device=self.device)
+        if not hasattr(self, 'sub_timesteps_tensor'):
+            self.register_buffer('sub_timesteps_tensor', sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.sub_timesteps_tensor = sub_timesteps_tensor.repeat_interleave(self.frame_bff_size, dim=0)
 
         c_skip_list = []
         c_out_list = []
-        for timestep in self.sub_timesteps:
+        for timestep in sub_timesteps:
             c_skip, c_out = self.scheduler.get_scalings_for_boundary_condition_discrete(timestep)
             c_skip_list.append(c_skip)
             c_out_list.append(c_out)
-        self.c_skip = torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
-        self.c_out = torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
+        c_skip = torch.stack(c_skip_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
+        c_out = torch.stack(c_out_list).view(len(self.t_list), 1, 1, 1).to(dtype=self.dtype, device=self.device)
+        if not hasattr(self, 'c_skip'):
+            self.register_buffer('c_skip', c_skip)
+        else:
+            self.c_skip = c_skip
+        if not hasattr(self, 'c_out'):
+            self.register_buffer('c_out', c_out)
+        else:
+            self.c_out = c_out
 
         alpha_prod_t_sqrt_list = []
         beta_prod_t_sqrt_list = []
-        for timestep in self.sub_timesteps:
+        for timestep in sub_timesteps:
             alpha_prod_t_sqrt = self.scheduler.alphas_cumprod[timestep].sqrt()
             beta_prod_t_sqrt = (1 - self.scheduler.alphas_cumprod[timestep]).sqrt()
             alpha_prod_t_sqrt_list.append(alpha_prod_t_sqrt)
@@ -331,12 +346,24 @@ class StreamMultiDiffusion(nn.Module):
             .to(dtype=self.dtype, device=self.device))
         beta_prod_t_sqrt = (torch.stack(beta_prod_t_sqrt_list).view(len(self.t_list), 1, 1, 1)
             .to(dtype=self.dtype, device=self.device))
-        self.alpha_prod_t_sqrt = alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
-        self.beta_prod_t_sqrt = beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
+        if not hasattr(self, 'alpha_prod_t_sqrt'):
+            self.register_buffer('alpha_prod_t_sqrt', alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.alpha_prod_t_sqrt = alpha_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
+        if not hasattr(self, 'beta_prod_t_sqrt'):
+            self.register_buffer('beta_prod_t_sqrt', beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0))
+        else:
+            self.beta_prod_t_sqrt = beta_prod_t_sqrt.repeat_interleave(self.frame_bff_size, dim=0)
 
         noise_lvs = ((1 - self.scheduler.alphas_cumprod.to(self.device)[self.sub_timesteps_tensor]) ** 0.5)
-        self.noise_lvs = noise_lvs[None, :, None, None, None]
-        self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
+        if not hasattr(self, 'noise_lvs'):
+            self.register_buffer('noise_lvs', noise_lvs[None, :, None, None, None])
+        else:
+            self.noise_lvs = noise_lvs[None, :, None, None, None]
+        if not hasattr(self, 'next_noise_lvs'):
+            self.register_buffer('next_noise_lvs', torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None])
+        else:
+            self.next_noise_lvs = torch.cat([noise_lvs[1:], noise_lvs.new_zeros(1)])[None, :, None, None, None]
 
     @torch.no_grad()
     def get_text_prompts(self, image: Image.Image) -> str:
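
Every change in this commit is the same register-or-assign pattern: a tensor previously stored as a plain attribute is registered as a module buffer the first time prepare() or reset_latent() runs, and reassigned in place on later runs, because nn.Module.register_buffer raises if the name already exists. Registered buffers follow the module through .to(device) and .to(dtype) and appear in state_dict(), which plain tensor attributes do not. Below is a minimal sketch of the pattern under standard PyTorch semantics; the LazyBufferExample class and _set_buffer helper are illustrative names, not code from this repository.

import torch
import torch.nn as nn

class LazyBufferExample(nn.Module):
    def _set_buffer(self, name: str, tensor: torch.Tensor) -> None:
        # Hypothetical helper, not in the commit. The first call registers
        # the buffer so it tracks module.to(...) and shows up in
        # state_dict(); later calls plainly reassign, which keeps setup
        # methods such as prepare() safe to run more than once.
        if not hasattr(self, name):
            self.register_buffer(name, tensor)
        else:
            setattr(self, name, tensor)  # nn.Module routes this to _buffers

    def prepare(self, steps: int = 4) -> None:
        self._set_buffer('sub_timesteps_tensor', torch.arange(steps, dtype=torch.long))

m = LazyBufferExample()
m.prepare()
m.prepare()  # second call reassigns instead of raising KeyError
print(list(m.state_dict()))  # ['sub_timesteps_tensor']

The commit inlines this branch at every site rather than factoring out a helper, but the effect is the same: x_t_latent_buffer, sub_timesteps_tensor, c_skip, c_out, alpha_prod_t_sqrt, beta_prod_t_sqrt, noise_lvs, and next_noise_lvs all become buffers, so moving the StreamMultiDiffusion module between devices or dtypes after preparation keeps them in sync. The repeat_interleave(self.frame_bff_size, dim=0) calls simply tile each per-step value across the frame buffer, e.g. tensor([32, 45]) with frame_bff_size=2 becomes tensor([32, 32, 45, 45]).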