tolgacangoz commited on
Commit
a8cd280
1 Parent(s): f7a2a5a

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. unet/matryoshka.py +17 -5
unet/matryoshka.py CHANGED
@@ -660,12 +660,24 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
660
  )
661
 
662
  if variance_noise is None:
663
- variance_noise = randn_tensor(
664
- model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
665
- )
666
- variance = std_dev_t * variance_noise
 
 
 
 
 
 
 
 
 
 
 
 
667
 
668
- prev_sample = prev_sample + variance
669
 
670
  if not return_dict:
671
  return (prev_sample,)
 
660
  )
661
 
662
  if variance_noise is None:
663
+ if len(model_output) > 1:
664
+ variance_noise = []
665
+ for m_o in model_output:
666
+ variance_noise.append(
667
+ randn_tensor(
668
+ m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype
669
+ )
670
+ )
671
+ else:
672
+ variance_noise = randn_tensor(
673
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
674
+ )
675
+ if len(model_output) > 1:
676
+ prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)]
677
+ else:
678
+ variance = std_dev_t * variance_noise
679
 
680
+ prev_sample = prev_sample + variance
681
 
682
  if not return_dict:
683
  return (prev_sample,)