sachit-menon commited on
Commit
b207f23
1 Parent(s): 8fed5b5

Update model/sd_model.py

Browse files
Files changed (1) hide show
  1. model/sd_model.py +8 -5
model/sd_model.py CHANGED
@@ -59,14 +59,14 @@ class SDModelConfig(BaseModelConfig):
59
  pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
60
  conditioning_dropout_prob: float = 0.05
61
  use_ema: bool = True
62
- concat_all_steps: bool = II("dataset.concat_all_steps")
63
  positional_encoding_type: Optional[str] = "sinusoidal"
64
  positional_encoding_length: Optional[int] = None
65
  image_positional_encoding_type: Optional[str] = None #"sinusoidal"
66
  image_positional_encoding_length: Optional[int] = None
67
  broadcast_positional_encoding: bool = True
68
- sequence_length: Optional[int] = II("dataset.sequence_length") # TODO consider changing interp on next line to this +1?
69
- text_sequence_length: Optional[int] = II("dataset.text_sequence_length")
70
  use_lora: bool = False
71
  # lora_cfg: Any = LoraConfig()
72
  zero_snr: bool = True
@@ -76,9 +76,12 @@ class SDModelConfig(BaseModelConfig):
76
 
77
 
78
  class SDModel(ModelMixin, ConfigMixin, PushToHubMixin):
79
- def __init__(self, cfg: SDModelConfig) -> None:
80
  super().__init__()
81
- self.cfg = cfg
 
 
 
82
  self.noise_scheduler = DDPMScheduler.from_pretrained(
83
  self.cfg.pretrained_model_name_or_path,
84
  subfolder="scheduler",
 
59
  pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
60
  conditioning_dropout_prob: float = 0.05
61
  use_ema: bool = True
62
+ concat_all_steps: bool = False
63
  positional_encoding_type: Optional[str] = "sinusoidal"
64
  positional_encoding_length: Optional[int] = None
65
  image_positional_encoding_type: Optional[str] = None #"sinusoidal"
66
  image_positional_encoding_length: Optional[int] = None
67
  broadcast_positional_encoding: bool = True
68
+ sequence_length: Optional[int] = 6
69
+ text_sequence_length: Optional[int] = 7
70
  use_lora: bool = False
71
  # lora_cfg: Any = LoraConfig()
72
  zero_snr: bool = True
 
76
 
77
 
78
  class SDModel(ModelMixin, ConfigMixin, PushToHubMixin):
79
+ def __init__(self, cfg: SDModelConfig = None) -> None:
80
  super().__init__()
81
+ if cfg is None: # workaround for default
82
+ cfg = SDModelConfig()
83
+ else:
84
+ self.cfg = cfg
85
  self.noise_scheduler = DDPMScheduler.from_pretrained(
86
  self.cfg.pretrained_model_name_or_path,
87
  subfolder="scheduler",