yamildiego commited on
Commit
15a8194
1 Parent(s): 471adc0
Files changed (1) hide show
  1. handler.py +10 -21
handler.py CHANGED
@@ -31,29 +31,18 @@ class EndpointHandler():
31
  #self.pipe.enable_vae_tiling()
32
  self.generator = torch.Generator(device=device.type).manual_seed(3)
33
 
34
-
35
- from typing import Optional
36
- from torch import Tensor
37
- from torch.nn import functional as F
38
- from torch.nn import Conv2d
39
- from torch.nn.modules.utils import _pair
40
-
41
- def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
42
- self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
43
- self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
44
- working = F.pad(input, self.paddingX, mode='circular')
45
- working = F.pad(working, self.paddingY, mode='constant')
46
- return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
47
-
48
- targets = [pipe.vae, pipe.text_encoder, pipe.unet,]
49
- conv_layers = []
50
  for target in targets:
51
  for module in target.modules():
52
- if isinstance(module, torch.nn.Conv2d):
53
- conv_layers.append(module)
54
-
55
- for cl in conv_layers:
56
- cl._conv_forward = asymmetricConv2DConvForward.__get__(cl, Conv2d)
57
 
58
 
59
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
 
31
  #self.pipe.enable_vae_tiling()
32
  self.generator = torch.Generator(device=device.type).manual_seed(3)
33
 
34
+ targets = [
35
+ self.pipe.vae,
36
+ self.pipe.text_encoder,
37
+ self.pipe.unet,
38
+ ]
39
+ self.conv_layers = []
40
+ self.conv_layers_original_paddings = []
 
 
 
 
 
 
 
 
 
41
  for target in targets:
42
  for module in target.modules():
43
+ if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.ConvTranspose2d):
44
+ self.conv_layers.append(module)
45
+ self.conv_layers_original_paddings.append(module.padding_mode)
 
 
46
 
47
 
48
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]: