NeoPy committed
Commit 30f8290 · verified · 1 Parent(s): 75e8179
infer/lib/predictors/DJCM/DJCM.py ADDED
@@ -0,0 +1,153 @@
+ import os
+ import sys
+ import torch
+
+ import numpy as np
+
+ from scipy.signal import medfilt
+
+ sys.path.append(os.getcwd())
+
+ from main.library.predictors.DJCM.spec import Spectrogram
+
+ SAMPLE_RATE, WINDOW_LENGTH, N_CLASS = 16000, 1024, 360
+
+ class DJCM:
+     def __init__(
+         self,
+         model_path,
+         device="cpu",
+         is_half=False,
+         onnx=False,
+         svs=False,
+         providers=["CPUExecutionProvider"],
+         batch_size=1,
+         segment_len=5.12,
+         kernel_size=3
+     ):
+         # SVS checkpoints use a 2048-sample analysis window; pitch-only
+         # checkpoints keep the 1024-sample module default. A local name is
+         # used here because assigning to WINDOW_LENGTH inside __init__ would
+         # shadow the module constant and break the references below.
+         window_length = 2048 if svs else WINDOW_LENGTH
+         self.onnx = onnx
+
+         if self.onnx:
+             import onnxruntime as ort
+
+             sess_options = ort.SessionOptions()
+             sess_options.log_severity_level = 3
+             self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
+         else:
+             from main.library.predictors.DJCM.model import DJCMM
+
+             model = DJCMM(1, 1, 1, svs=svs, window_length=window_length, n_class=N_CLASS)
+             model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
+             model.eval()
+             if is_half: model = model.half()
+             self.model = model.to(device)
+
+         self.batch_size = batch_size
+         self.seg_len = int(segment_len * SAMPLE_RATE)
+         self.seg_frames = int(self.seg_len // int(SAMPLE_RATE // 100))
+
+         self.device = device
+         self.is_half = is_half
+         self.kernel_size = kernel_size
+
+         # Hop length of SAMPLE_RATE // 100 yields one spectrogram frame per 10 ms.
+         self.spec_extractor = Spectrogram(int(SAMPLE_RATE // 100), window_length).to(device)
+         cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
+         self.cents_mapping = np.pad(cents_mapping, (4, 4))
+
+     def spec2hidden(self, spec):
+         if self.onnx:
+             spec = spec.cpu().numpy().astype(np.float32)
+
+             hidden = torch.as_tensor(
+                 self.model.run(
+                     [self.model.get_outputs()[0].name],
+                     {self.model.get_inputs()[0].name: spec}
+                 )[0],
+                 device=self.device
+             )
+         else:
+             if self.is_half: spec = spec.half()
+             hidden = self.model(spec)
+
+         return hidden
+
+     def infer_from_audio(self, audio, thred=0.03):
+         if torch.is_tensor(audio): audio = audio.cpu().numpy()
+         if audio.ndim > 1: audio = audio.squeeze()
+
+         with torch.no_grad():
+             padded_audio = self.pad_audio(audio)
+             hidden = self.inference(padded_audio)[:(audio.shape[-1] // int(SAMPLE_RATE // 100) + 1)]
+
+         f0 = self.decode(hidden.squeeze(0).cpu().numpy(), thred)
+         if self.kernel_size is not None: f0 = medfilt(f0, kernel_size=self.kernel_size)
+
+         return f0
+
+     def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
+         f0 = self.infer_from_audio(audio, thred)
+         f0[(f0 < f0_min) | (f0 > f0_max)] = 0
+
+         return f0
+
+     def to_local_average_cents(self, salience, thred=0.05):
+         # Weighted average of cents over a 9-bin window centred on the argmax bin.
+         center = np.argmax(salience, axis=1)
+         salience = np.pad(salience, ((0, 0), (4, 4)))
+         center += 4
+         todo_salience, todo_cents_mapping = [], []
+         starts = center - 4
+         ends = center + 5
+
+         for idx in range(salience.shape[0]):
+             todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
+             todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
+
+         todo_salience = np.array(todo_salience)
+         divided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
+         divided[np.max(salience, axis=1) <= thred] = 0
+
+         return divided
+
+     def decode(self, hidden, thred=0.03):
+         # Cents to Hz; unvoiced frames decode to 0 cents (10 Hz), so zero them out.
+         f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
+         f0[f0 == 10] = 0
+         return f0
+
+     def pad_audio(self, audio):
+         audio_len = audio.shape[-1]
+
+         seg_nums = int(np.ceil(audio_len / self.seg_len)) + 1
+         pad_len = int(seg_nums * self.seg_len - audio_len + self.seg_len // 2)
+
+         left_pad = np.zeros(int(self.seg_len // 4), dtype=np.float32)
+         right_pad = np.zeros(int(pad_len - self.seg_len // 4), dtype=np.float32)
+         padded_audio = np.concatenate([left_pad, audio, right_pad], axis=-1)
+
+         # Half-overlapping segments of seg_len samples each.
+         segments = [
+             padded_audio[start: start + int(self.seg_len)]
+             for start in range(0, len(padded_audio) - int(self.seg_len) + 1, int(self.seg_len // 2))
+         ]
+
+         segments = np.stack(segments, axis=0)
+         segments = torch.from_numpy(segments).unsqueeze(1).to(self.device)
+
+         return segments
+
+     def inference(self, segments):
+         hidden_segments = torch.cat([
+             self.spec2hidden(self.spec_extractor(segments[i:i + self.batch_size].float()))
+             for i in range(0, len(segments), self.batch_size)
+         ], dim=0)
+
+         # Keep the middle half of each half-overlapping segment, then stitch.
+         hidden = torch.cat([
+             seg[self.seg_frames // 4: int(self.seg_frames * 0.75)]
+             for seg in hidden_segments
+         ], dim=0)
+
+         return hidden
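
A minimal usage sketch for the predictor above, assuming the package resolves as main.library.predictors.DJCM (as the file's own imports do) and that a trained checkpoint exists at the hypothetical path assets/djcm.pt:

    import numpy as np

    from main.library.predictors.DJCM.DJCM import DJCM, SAMPLE_RATE

    predictor = DJCM("assets/djcm.pt", device="cpu")  # hypothetical checkpoint path

    # One second of a 220 Hz sine wave as dummy input.
    t = np.arange(SAMPLE_RATE, dtype=np.float32) / SAMPLE_RATE
    audio = np.sin(2 * np.pi * 220.0 * t)

    f0 = predictor.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
    print(f0.shape)  # one f0 estimate per 10 ms frame
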
infer/lib/predictors/DJCM/decoder.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import sys
+ import torch
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ sys.path.append(os.getcwd())
+
+ from main.library.predictors.DJCM.encoder import ResEncoderBlock
+ from main.library.predictors.DJCM.utils import ResConvBlock, BiGRU, init_bn, init_layer
+
+ class ResDecoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, n_blocks, stride):
+         super(ResDecoderBlock, self).__init__()
+         # Transposed conv upsamples by `stride`; the first ResConvBlock takes
+         # out_channels * 2 because the encoder skip tensor is concatenated in.
+         self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, stride, stride, (0, 0), bias=False)
+         self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)
+         self.conv = nn.ModuleList([ResConvBlock(out_channels * 2, out_channels)])
+
+         for _ in range(n_blocks - 1):
+             self.conv.append(ResConvBlock(out_channels, out_channels))
+
+         self.init_weights()
+
+     def init_weights(self):
+         init_bn(self.bn1)
+         init_layer(self.conv1)
+
+     def forward(self, x, concat):
+         x = self.conv1(F.relu_(self.bn1(x)))
+         x = torch.cat((x, concat), dim=1)
+
+         for each_layer in self.conv:
+             x = each_layer(x)
+
+         return x
+
+ class Decoder(nn.Module):
+     def __init__(self, n_blocks):
+         super(Decoder, self).__init__()
+         self.de_blocks = nn.ModuleList([
+             ResDecoderBlock(384, 384, n_blocks, (1, 2)),
+             ResDecoderBlock(384, 384, n_blocks, (1, 2)),
+             ResDecoderBlock(384, 256, n_blocks, (1, 2)),
+             ResDecoderBlock(256, 128, n_blocks, (1, 2)),
+             ResDecoderBlock(128, 64, n_blocks, (1, 2)),
+             ResDecoderBlock(64, 32, n_blocks, (1, 2))
+         ])
+
+     def forward(self, x, concat_tensors):
+         # Skip connections are consumed deepest-first to mirror the encoder.
+         for i, layer in enumerate(self.de_blocks):
+             x = layer(x, concat_tensors[-1 - i])
+
+         return x
+
+ class PE_Decoder(nn.Module):
+     def __init__(self, n_blocks, seq_layers=1, window_length=1024, n_class=360):
+         super(PE_Decoder, self).__init__()
+         self.de_blocks = Decoder(n_blocks)
+         self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
+         self.after_conv2 = nn.Conv2d(32, 1, (1, 1))
+         self.fc = nn.Sequential(
+             BiGRU((1, window_length // 2), 1, seq_layers),
+             nn.Linear(window_length // 2, n_class),
+             nn.Sigmoid()
+         )
+         init_layer(self.after_conv2)
+
+     def forward(self, x, concat_tensors):
+         return self.fc(self.after_conv2(self.after_conv1(self.de_blocks(x, concat_tensors)))).squeeze(1)
+
+ class SVS_Decoder(nn.Module):
+     def __init__(self, in_channels, n_blocks):
+         super(SVS_Decoder, self).__init__()
+         self.de_blocks = Decoder(n_blocks)
+         self.after_conv1 = ResEncoderBlock(32, 32, n_blocks, None)
+         self.after_conv2 = nn.Conv2d(32, in_channels * 4, (1, 1))
+         self.init_weights()
+
+     def init_weights(self):
+         init_layer(self.after_conv2)
+
+     def forward(self, x, concat_tensors):
+         return self.after_conv2(self.after_conv1(self.de_blocks(x, concat_tensors)))
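
A shape sketch of one decoder stage, assuming the same main.library.predictors.DJCM import path used throughout the commit; the tensor sizes are illustrative. The transposed convolution with stride (1, 2) doubles the frequency axis before the encoder skip tensor is concatenated:

    import torch

    from main.library.predictors.DJCM.decoder import ResDecoderBlock

    block = ResDecoderBlock(128, 64, n_blocks=1, stride=(1, 2))
    x = torch.randn(2, 128, 10, 32)    # (batch, channels, time, freq)
    skip = torch.randn(2, 64, 10, 64)  # encoder skip at the matching scale
    print(block(x, skip).shape)        # torch.Size([2, 64, 10, 64])
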
infer/lib/predictors/DJCM/encoder.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import sys
+
+ import torch.nn as nn
+
+ sys.path.append(os.getcwd())
+
+ from main.library.predictors.DJCM.utils import ResConvBlock
+
+ class ResEncoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, n_blocks, kernel_size):
+         super(ResEncoderBlock, self).__init__()
+         self.conv = nn.ModuleList([ResConvBlock(in_channels, out_channels)])
+
+         for _ in range(n_blocks - 1):
+             self.conv.append(ResConvBlock(out_channels, out_channels))
+
+         # kernel_size=None skips pooling (used by the latent and decoder blocks).
+         self.pool = nn.MaxPool2d(kernel_size) if kernel_size is not None else None
+
+     def forward(self, x):
+         for each_layer in self.conv:
+             x = each_layer(x)
+
+         if self.pool is not None: return x, self.pool(x)
+         return x
+
+ class Encoder(nn.Module):
+     def __init__(self, in_channels, n_blocks):
+         super(Encoder, self).__init__()
+         # Six stages, each halving the frequency axis via MaxPool2d((1, 2)).
+         self.en_blocks = nn.ModuleList([
+             ResEncoderBlock(in_channels, 32, n_blocks, (1, 2)),
+             ResEncoderBlock(32, 64, n_blocks, (1, 2)),
+             ResEncoderBlock(64, 128, n_blocks, (1, 2)),
+             ResEncoderBlock(128, 256, n_blocks, (1, 2)),
+             ResEncoderBlock(256, 384, n_blocks, (1, 2)),
+             ResEncoderBlock(384, 384, n_blocks, (1, 2))
+         ])
+
+     def forward(self, x):
+         concat_tensors = []
+
+         for layer in self.en_blocks:
+             skip, x = layer(x)
+             concat_tensors.append(skip)
+
+         return x, concat_tensors
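
A shape sketch of the full encoder under the same import-path assumption; the input sizes are illustrative, with the frequency axis divisible by 2**6 so all six pooling stages apply cleanly:

    import torch

    from main.library.predictors.DJCM.encoder import Encoder

    encoder = Encoder(in_channels=1, n_blocks=1)
    x = torch.randn(2, 1, 10, 512)          # (batch, channels, time, freq)
    latent, skips = encoder(x)
    print(latent.shape)                     # torch.Size([2, 384, 10, 8]) -- freq / 2**6
    print([s.shape[1] for s in skips])      # [32, 64, 128, 256, 384, 384]
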
infer/lib/predictors/DJCM/model.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ import sys
+
+ import torch.nn as nn
+
+ sys.path.append(os.getcwd())
+
+ from main.library.predictors.DJCM.utils import init_bn
+ from main.library.predictors.DJCM.decoder import PE_Decoder, SVS_Decoder
+ from main.library.predictors.DJCM.encoder import ResEncoderBlock, Encoder
+
+ class LatentBlocks(nn.Module):
+     def __init__(self, n_blocks, latent_layers):
+         super(LatentBlocks, self).__init__()
+         self.latent_blocks = nn.ModuleList([
+             ResEncoderBlock(384, 384, n_blocks, None)
+             for _ in range(latent_layers)
+         ])
+
+     def forward(self, x):
+         for layer in self.latent_blocks:
+             x = layer(x)
+
+         return x
+
+ class DJCMM(nn.Module):
+     def __init__(self, in_channels, n_blocks, latent_layers, svs=False, window_length=1024, n_class=360):
+         super(DJCMM, self).__init__()
+         self.bn = nn.BatchNorm2d(window_length // 2 + 1, momentum=0.01)
+         self.pe_encoder = Encoder(in_channels, n_blocks)
+         self.pe_latent = LatentBlocks(n_blocks, latent_layers)
+         self.pe_decoder = PE_Decoder(n_blocks, window_length=window_length, n_class=n_class)
+
+         self.svs = svs
+
+         # With svs=True an extra separation U-Net cleans the spectrogram
+         # before it reaches the pitch-estimation branch.
+         if svs:
+             self.svs_encoder = Encoder(in_channels, n_blocks)
+             self.svs_latent = LatentBlocks(n_blocks, latent_layers)
+             self.svs_decoder = SVS_Decoder(in_channels, n_blocks)
+
+         init_bn(self.bn)
+
+     def spec(self, x, spec_m):
+         # The SVS decoder emits 4 maps per input channel: map 0 is a sigmoid
+         # mask over the input magnitude and map 3 is an additive residual.
+         bs, c, time_steps, freqs_steps = x.shape
+         x = x.reshape(bs, c // 4, 4, time_steps, freqs_steps)
+
+         mask_spec = x[:, :, 0, :, :].sigmoid()
+         linear_spec = x[:, :, 3, :, :]
+
+         out_spec = (spec_m.detach() * mask_spec + linear_spec).relu()
+
+         return out_spec
+
+     def forward(self, spec):
+         # BatchNorm over the frequency axis, then drop the last bin so the
+         # frequency dimension is divisible by the encoder's six poolings.
+         x = self.bn(spec.transpose(1, 3)).transpose(1, 3)[..., :-1]
+
+         if self.svs:
+             x, concat_tensors = self.svs_encoder(x)
+             x = self.svs_decoder(self.svs_latent(x), concat_tensors)
+             x = self.spec(nn.functional.pad(x, pad=(0, 1)), spec)[..., :-1]
+
+         x, concat_tensors = self.pe_encoder(x)
+         pe_out = self.pe_decoder(self.pe_latent(x), concat_tensors)
+
+         return pe_out
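
A forward-pass sketch for the pitch-only configuration, again assuming the main.library.predictors.DJCM import path; the frame count is illustrative, and the input spectrogram has window_length // 2 + 1 frequency bins as expected by the batch-norm layer:

    import torch

    from main.library.predictors.DJCM.model import DJCMM

    model = DJCMM(1, 1, 1, svs=False, window_length=1024, n_class=360).eval()

    spec = torch.rand(1, 1, 128, 513)  # (batch, channel, frames, 1024 // 2 + 1 bins)
    with torch.no_grad():
        salience = model(spec)

    print(salience.shape)  # torch.Size([1, 128, 360]) -- per-frame pitch salience
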
infer/lib/predictors/DJCM/spec.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import sys
+ import torch
+
+ import numpy as np
+ import torch.nn as nn
+
+ sys.path.append(os.getcwd())
+
+ class Spectrogram(nn.Module):
+     def __init__(self, hop_length, win_length, n_fft=None, clamp=1e-10):
+         super(Spectrogram, self).__init__()
+         self.n_fft = win_length if n_fft is None else n_fft
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.clamp = clamp
+         self.register_buffer("window", torch.hann_window(win_length), persistent=False)
+
+     def forward(self, audio, center=True):
+         bs, c, segment_samples = audio.shape
+         audio = audio.reshape(bs * c, segment_samples)
+
+         # Devices exposed as "ocl"/"privateuseone" use the project's own STFT
+         # implementation instead of torch.stft; it is built lazily and cached.
+         if str(audio.device).startswith(("ocl", "privateuseone")):
+             if not hasattr(self, "stft"):
+                 from main.library.backends.utils import STFT
+
+                 self.stft = STFT(filter_length=self.n_fft, hop_length=self.hop_length, win_length=self.win_length).to(audio.device)
+
+             magnitude = self.stft.transform(audio, 1e-9)
+         else:
+             fft = torch.stft(
+                 audio,
+                 n_fft=self.n_fft,
+                 hop_length=self.hop_length,
+                 win_length=self.win_length,
+                 window=self.window,
+                 center=center,
+                 pad_mode="reflect",
+                 return_complex=True
+             )
+
+             magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt()
+
+         mag = magnitude.transpose(1, 2).clamp(self.clamp, np.inf)
+         mag = mag.reshape(bs, c, mag.shape[1], mag.shape[2])
+
+         return mag
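
A usage sketch with the hop and window sizes the DJCM predictor passes in (160 and 1024 at 16 kHz); the batch shape is illustrative:

    import torch

    from main.library.predictors.DJCM.spec import Spectrogram

    spec_extractor = Spectrogram(hop_length=160, win_length=1024)
    audio = torch.randn(2, 1, 16000)  # (batch, channel, samples): 1 s at 16 kHz
    mag = spec_extractor(audio)
    print(mag.shape)  # torch.Size([2, 1, 101, 513]) -- (batch, channel, frames, bins)
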
infer/lib/predictors/DJCM/utils.py ADDED
@@ -0,0 +1,115 @@
+ import torch
+
+ from torch import nn
+ from einops.layers.torch import Rearrange
+
+ def init_layer(layer):
+     nn.init.xavier_uniform_(layer.weight)
+
+     if hasattr(layer, "bias") and layer.bias is not None:
+         layer.bias.data.fill_(0.0)
+
+ def init_bn(bn):
+     bn.bias.data.fill_(0.0)
+     bn.weight.data.fill_(1.0)
+     bn.running_mean.data.fill_(0.0)
+     bn.running_var.data.fill_(1.0)
+
+ class BiGRU(nn.Module):
+     def __init__(self, patch_size, channels, depth):
+         super(BiGRU, self).__init__()
+         patch_width, patch_height = patch_size
+         patch_dim = channels * patch_height * patch_width
+
+         # Flatten (channels, patch) windows into a sequence of patch embeddings.
+         self.to_patch_embedding = nn.Sequential(
+             Rearrange(
+                 'b c (w p1) (h p2) -> b (w h) (p1 p2 c)',
+                 p1=patch_width,
+                 p2=patch_height
+             )
+         )
+
+         self.gru = nn.GRU(
+             patch_dim,
+             patch_dim // 2,
+             num_layers=depth,
+             batch_first=True,
+             bidirectional=True
+         )
+
+     def forward(self, x):
+         x = self.to_patch_embedding(x)
+
+         try:
+             return self.gru(x)[0]
+         except Exception:
+             # cuDNN's GRU kernels can fail on some configurations; retry with
+             # cuDNN disabled rather than aborting inference.
+             torch.backends.cudnn.enabled = False
+             return self.gru(x)[0]
+
+ class ResConvBlock(nn.Module):
+     def __init__(self, in_planes, out_planes):
+         super(ResConvBlock, self).__init__()
+         self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
+         self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.01)
+         self.act1 = nn.PReLU()
+         self.act2 = nn.PReLU()
+         self.conv1 = nn.Conv2d(in_planes, out_planes, (3, 3), padding=(1, 1), bias=False)
+         self.conv2 = nn.Conv2d(out_planes, out_planes, (3, 3), padding=(1, 1), bias=False)
+         self.is_shortcut = False
+
+         # A 1x1 projection aligns channels when the residual shapes differ.
+         if in_planes != out_planes:
+             self.shortcut = nn.Conv2d(in_planes, out_planes, (1, 1))
+             self.is_shortcut = True
+
+         self.init_weights()
+
+     def init_weights(self):
+         init_bn(self.bn1)
+         init_bn(self.bn2)
+
+         init_layer(self.conv1)
+         init_layer(self.conv2)
+
+         if self.is_shortcut: init_layer(self.shortcut)
+
+     def forward(self, x):
+         out = self.conv1(self.act1(self.bn1(x)))
+         out = self.conv2(self.act2(self.bn2(out)))
+
+         if self.is_shortcut: return self.shortcut(x) + out
+         else: return out + x
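
A shape sketch for the two building blocks, under the same import-path assumption and with illustrative sizes:

    import torch

    from main.library.predictors.DJCM.utils import BiGRU, ResConvBlock

    block = ResConvBlock(32, 64)      # channel change, so the 1x1 shortcut is used
    x = torch.randn(2, 32, 10, 128)
    print(block(x).shape)             # torch.Size([2, 64, 10, 128])

    gru = BiGRU(patch_size=(1, 128), channels=1, depth=1)
    seq = torch.randn(2, 1, 10, 128)  # rearranged to (batch, 10, 128) internally
    print(gru(seq).shape)             # torch.Size([2, 10, 128])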