File size: 3,010 Bytes
85ce65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch.nn as nn
from torch.nn.utils import remove_weight_norm, weight_norm


class Depthwise_Separable_Conv1D(nn.Module):
    """Depthwise-separable 1-D convolution.

    Two layers are applied in sequence:

    * ``depth_conv`` -- an ``nn.Conv1d`` with ``groups=in_channels`` that maps
      ``in_channels`` to ``in_channels`` using ``kernel_size``, ``stride``,
      ``padding``, ``dilation`` and ``padding_mode``.
    * ``point_conv`` -- an ``nn.Conv1d`` with ``kernel_size=1`` that maps
      ``in_channels`` to ``out_channels``.

    ``bias``, ``device`` and ``dtype`` are forwarded to both layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride = 1,
        padding = 0,
        dilation = 1,
        bias = True,
        padding_mode = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        # Per-channel convolution: one filter group per input channel.
        self.depth_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        # Kernel-size-1 convolution that changes the channel count.
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            **factory_kwargs,
        )

    def forward(self, input):
        """Run ``input`` through ``depth_conv`` and then ``point_conv``."""
        depthwise_out = self.depth_conv(input)
        return self.point_conv(depthwise_out)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both sub-layers."""
        for layer_name in ('depth_conv', 'point_conv'):
            wrapped = weight_norm(getattr(self, layer_name), name='weight')
            setattr(self, layer_name, wrapped)

    def remove_weight_norm(self):
        """Undo the weight normalization applied by :meth:`weight_norm`."""
        for layer_name in ('depth_conv', 'point_conv'):
            restored = remove_weight_norm(getattr(self, layer_name), name='weight')
            setattr(self, layer_name, restored)

class Depthwise_Separable_TransposeConv1D(nn.Module):
    """Depthwise-separable 1-D transposed convolution.

    Two layers are applied in sequence:

    * ``depth_conv`` -- an ``nn.ConvTranspose1d`` with ``groups=in_channels``
      that maps ``in_channels`` to ``in_channels`` using ``kernel_size``,
      ``stride``, ``padding``, ``output_padding``, ``dilation`` and
      ``padding_mode``.
    * ``point_conv`` -- an ``nn.Conv1d`` with ``kernel_size=1`` that maps
      ``in_channels`` to ``out_channels``.

    ``bias``, ``device`` and ``dtype`` are forwarded to both layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride = 1,
        padding = 0, 
        output_padding = 0,
        bias = True,
        dilation = 1,
        padding_mode = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        # Per-channel transposed convolution: one filter group per input channel.
        self.depth_conv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=in_channels,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
            **factory_kwargs,
        )
        # Kernel-size-1 convolution that changes the channel count.
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            **factory_kwargs,
        )

    def forward(self, input):
        """Run ``input`` through ``depth_conv`` and then ``point_conv``."""
        upsampled = self.depth_conv(input)
        return self.point_conv(upsampled)

    def weight_norm(self):
        """Apply weight normalization to the ``weight`` of both sub-layers."""
        for layer_name in ('depth_conv', 'point_conv'):
            wrapped = weight_norm(getattr(self, layer_name), name='weight')
            setattr(self, layer_name, wrapped)

    def remove_weight_norm(self):
        """Undo the weight normalization applied by :meth:`weight_norm`."""
        for layer in (self.depth_conv, self.point_conv):
            remove_weight_norm(layer, name='weight')


def weight_norm_modules(module, name = 'weight', dim = 0):
    """Apply weight normalization to ``module`` and return it.

    Plain modules are handed straight to ``torch.nn.utils.weight_norm``.
    For ``Depthwise_Separable_Conv1D`` and
    ``Depthwise_Separable_TransposeConv1D`` both sub-convolutions
    (``depth_conv`` and ``point_conv``) are normalized instead.

    Args:
        module: The module to normalize.
        name: Name of the weight parameter to reparameterize.
        dim: Dimension over which the norm is computed.

    Returns:
        The same ``module``, with weight normalization applied.
    """
    if isinstance(module, (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)):
        # Wrap the sub-layers here rather than calling module.weight_norm():
        # that method takes no arguments, so ``name`` and ``dim`` would be
        # silently dropped. With the default arguments the result is the same.
        module.depth_conv = weight_norm(module.depth_conv, name, dim)
        module.point_conv = weight_norm(module.point_conv, name, dim)
        return module
    return weight_norm(module, name, dim)

def remove_weight_norm_modules(module, name = 'weight'):
    """Remove weight normalization from ``module`` in place.

    Plain modules are handed straight to
    ``torch.nn.utils.remove_weight_norm``. For
    ``Depthwise_Separable_Conv1D`` and
    ``Depthwise_Separable_TransposeConv1D`` the normalization is removed
    from both sub-convolutions (``depth_conv`` and ``point_conv``).

    Args:
        module: The module to strip weight normalization from.
        name: Name of the weight parameter that was reparameterized.

    Returns:
        None.
    """
    if isinstance(module, (Depthwise_Separable_Conv1D, Depthwise_Separable_TransposeConv1D)):
        # Strip the sub-layers here rather than calling
        # module.remove_weight_norm(): that method takes no arguments, so
        # ``name`` would be silently dropped. With the default argument the
        # result is the same.
        remove_weight_norm(module.depth_conv, name)
        remove_weight_norm(module.point_conv, name)
    else:
        remove_weight_norm(module, name)