tolgacangoz commited on
Commit
5e1ad73
1 Parent(s): 05a9a16

Upload matryoshka.py

Browse files
Files changed (1) hide show
  1. scheduler/matryoshka.py +6 -4
scheduler/matryoshka.py CHANGED
@@ -642,16 +642,16 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
642
  if self.config.thresholding:
643
  if len(model_output) > 1:
644
  pred_original_sample = [
645
- self._threshold_sample(p_o_s * scale) / scale
646
- for p_o_s, scale in zip(pred_original_sample, self.scales)
647
  ]
648
  else:
649
  pred_original_sample = self._threshold_sample(pred_original_sample)
650
  elif self.config.clip_sample:
651
  if len(model_output) > 1:
652
  pred_original_sample = [
653
- (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
654
- for p_o_s, scale in zip(pred_original_sample, self.scales)
655
  ]
656
  else:
657
  pred_original_sample = pred_original_sample.clamp(
@@ -3846,12 +3846,14 @@ class MatryoshkaPipeline(
3846
  ).to(self.device)
3847
  self.config.nesting_level = 1
3848
  self.scheduler.scales = self.unet.nest_ratio + [1]
 
3849
  elif nesting_level == 2:
3850
  self.unet = NestedUNet2DConditionModel.from_pretrained(
3851
  "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
3852
  ).to(self.device)
3853
  self.config.nesting_level = 2
3854
  self.scheduler.scales = self.unet.nest_ratio + [1]
 
3855
  else:
3856
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3857
 
 
642
  if self.config.thresholding:
643
  if len(model_output) > 1:
644
  pred_original_sample = [
645
+ self._threshold_sample(p_o_s)
646
+ for p_o_s in pred_original_sample
647
  ]
648
  else:
649
  pred_original_sample = self._threshold_sample(pred_original_sample)
650
  elif self.config.clip_sample:
651
  if len(model_output) > 1:
652
  pred_original_sample = [
653
+ p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
654
+ for p_o_s in pred_original_sample
655
  ]
656
  else:
657
  pred_original_sample = pred_original_sample.clamp(
 
3846
  ).to(self.device)
3847
  self.config.nesting_level = 1
3848
  self.scheduler.scales = self.unet.nest_ratio + [1]
3849
+ self.scheduler.schedule_shifted_power = 1.0
3850
  elif nesting_level == 2:
3851
  self.unet = NestedUNet2DConditionModel.from_pretrained(
3852
  "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
3853
  ).to(self.device)
3854
  self.config.nesting_level = 2
3855
  self.scheduler.scales = self.unet.nest_ratio + [1]
3856
+ self.scheduler.schedule_shifted_power = 2.0
3857
  else:
3858
  raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
3859