benibraz commited on
Commit
8920e2d
1 Parent(s): 87b6e69

VAE: Check for timesteps parameter in decoder before calling

Browse files
Files changed (1) hide show
  1. xora/models/autoencoders/vae.py +7 -1
xora/models/autoencoders/vae.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Optional, Union
2
 
3
  import torch
 
4
  import math
5
  import torch.nn as nn
6
  from diffusers import ConfigMixin, ModelMixin
@@ -60,6 +61,8 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
60
  self.dims = dims
61
  self.z_sample_size = 1
62
 
 
 
63
  # only relevant if vae tiling is enabled
64
  self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
65
 
@@ -257,7 +260,10 @@ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
257
  timesteps: Optional[torch.Tensor] = None,
258
  ) -> Union[DecoderOutput, torch.FloatTensor]:
259
  z = self.post_quant_conv(z)
260
- dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
 
 
 
261
  return dec
262
 
263
  def decode(
 
1
  from typing import Optional, Union
2
 
3
  import torch
4
+ import inspect
5
  import math
6
  import torch.nn as nn
7
  from diffusers import ConfigMixin, ModelMixin
 
61
  self.dims = dims
62
  self.z_sample_size = 1
63
 
64
+ self.decoder_params = inspect.signature(self.decoder.forward).parameters
65
+
66
  # only relevant if vae tiling is enabled
67
  self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
68
 
 
260
  timesteps: Optional[torch.Tensor] = None,
261
  ) -> Union[DecoderOutput, torch.FloatTensor]:
262
  z = self.post_quant_conv(z)
263
+ if "timesteps" in self.decoder_params:
264
+ dec = self.decoder(z, target_shape=target_shape, timesteps=timesteps)
265
+ else:
266
+ dec = self.decoder(z, target_shape=target_shape)
267
  return dec
268
 
269
  def decode(