csukuangfj committed
Commit cbd589e • 1 Parent(s): 9baf00d

first commit

Files changed (6)
  1. README.md +6 -5
  2. __init.py +0 -0
  3. app.py +288 -0
  4. requirements.txt +6 -0
  5. separate.py +198 -0
  6. unet.py +150 -0
README.md CHANGED
@@ -1,10 +1,11 @@
  ---
- title: Music Source Separation
+ title: Music source separation
- emoji: 💻
+ emoji: 🌖
- colorFrom: green
+ colorFrom: yellow
- colorTo: indigo
+ colorTo: green
  sdk: gradio
+ python_version: 3.8.9
- sdk_version: 3.41.0
+ sdk_version: 3.0.26
  app_file: app.py
  pinned: false
  license: apache-2.0
__init.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,288 @@
+ #!/usr/bin/env python3
+ #
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+ #
+ # See LICENSE for clarification regarding multiple authors
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # References:
+ # https://gradio.app/docs/#dropdown
+
+ import logging
+ import tempfile
+ import time
+ import urllib.request
+ from datetime import datetime
+
+ import gradio as gr
+ import torch
+
+ from separate import load_audio, load_model, separate
+
+
+ def build_html_output(s: str, style: str = "result_item_success"):
+     return f"""
+     <div class='result'>
+         <div class='result_item {style}'>
+           {s}
+         </div>
+     </div>
+     """
+
+
+ def process_url(url: str):
+     logging.info(f"Processing URL: {url}")
+     with tempfile.NamedTemporaryFile() as f:
+         try:
+             urllib.request.urlretrieve(url, f.name)
+             return process(in_filename=f.name)
+         except Exception as e:
+             logging.info(str(e))
+             return None, None, build_html_output(str(e), "result_item_error")
+
+
+ def process_uploaded_file(in_filename: str):
+     if in_filename is None or in_filename == "":
+         return (
+             None,
+             None,
+             build_html_output(
+                 "Please first upload a file and then click "
+                 'the button "Submit for separation"',
+                 "result_item_error",
+             ),
+         )
+
+     logging.info(f"Processing uploaded file: {in_filename}")
+     try:
+         return process(in_filename=in_filename)
+     except Exception as e:
+         logging.info(str(e))
+         return None, None, build_html_output(str(e), "result_item_error")
+
+
+ def process_microphone(in_filename: str):
+     if in_filename is None or in_filename == "":
+         return (
+             None,
+             None,
+             build_html_output(
+                 "Please first click 'Record from microphone', speak, "
+                 "click 'Stop recording', and then "
+                 "click the button 'Submit for separation'",
+                 "result_item_error",
+             ),
+         )
+
+     logging.info(f"Processing microphone: {in_filename}")
+     try:
+         return process(in_filename=in_filename)
+     except Exception as e:
+         logging.info(str(e))
+         return None, None, build_html_output(str(e), "result_item_error")
+
+
+ @torch.no_grad()
+ def process(in_filename: str):
+     logging.info(f"in_filename: {in_filename}")
+
+     # load_audio() decodes the file with ffmpeg and resamples it to 44.1 kHz,
+     # returning a float32 tensor of shape (num_samples, 2)
+     waveform = load_audio(in_filename)
+     duration = waveform.shape[0] / 44100  # in seconds
+
+     vocals = load_model("vocals.pt")
+     accompaniment = load_model("accompaniment.pt")
+
+     now = datetime.now()
+     date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
+     logging.info(f"Started at {date_time}")
+
+     start = time.time()
+
+     vocals_wave, accompaniment_wave = separate(vocals, accompaniment, waveform)
+
+     date_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
+     end = time.time()
+
+     rtf = (end - start) / duration
+
+     logging.info(f"Finished at {date_time}. Elapsed: {end - start: .3f} s")
+
+     info = f"""
+     Wave duration  : {duration: .3f} s <br/>
+     Processing time: {end - start: .3f} s <br/>
+     RTF: {end - start: .3f}/{duration: .3f} = {rtf:.3f} <br/>
+     """
+     if rtf > 1:
+         info += (
+             "<br/>We are loading the model for the first run. "
+             "Please run again to measure the real RTF.<br/>"
+         )
+
+     logging.info(info)
+
+     # gr.Audio accepts (sample_rate, samples), where samples has shape
+     # (num_samples, num_channels); convert the float waveforms to int16
+     vocals_out = (44100, (vocals_wave.t() * 32768).to(torch.int16).numpy())
+     accompaniment_out = (
+         44100,
+         (accompaniment_wave.t() * 32768).to(torch.int16).numpy(),
+     )
+
+     return vocals_out, accompaniment_out, build_html_output(info)
+
+
+ title = "# Music source separation with Next-gen Kaldi"
+ description = """
+ This space shows how to do music source separation with Next-gen Kaldi.
+ It splits the input audio into two stems: vocals and accompaniment.
+
+ Please visit
+ <https://huggingface.co/spaces/k2-fsa/streaming-automatic-speech-recognition>
+ for streaming speech recognition with **Next-gen Kaldi**.
+
+ It is running on CPU within a docker container provided by Hugging Face.
+
+ See more information by visiting the following links:
+
+ - <https://github.com/k2-fsa/icefall>
+ - <https://github.com/k2-fsa/sherpa>
+ - <https://github.com/k2-fsa/k2>
+ - <https://github.com/lhotse-speech/lhotse>
+
+ If you want to deploy it locally, please see
+ <https://k2-fsa.github.io/sherpa/>
+ """
+
+ # css style is copied from
+ # https://huggingface.co/spaces/alphacep/asr/blob/main/app.py#L113
+ css = """
+ .result {display:flex;flex-direction:column}
+ .result_item {padding:15px;margin-bottom:8px;border-radius:15px;width:100%}
+ .result_item_success {background-color:mediumaquamarine;color:white;align-self:start}
+ .result_item_error {background-color:#ff7070;color:white;align-self:start}
+ """
+
+ demo = gr.Blocks(css=css)
+
+
+ with demo:
+     gr.Markdown(title)
+
+     with gr.Tabs():
+         with gr.TabItem("Upload from disk"):
+             uploaded_file = gr.Audio(
+                 source="upload",  # Choose between "microphone", "upload"
+                 type="filepath",
+                 optional=False,
+                 label="Upload from disk",
+             )
+             upload_button = gr.Button("Submit for separation")
+             uploaded_vocals = gr.Audio(label="Vocals")
+             uploaded_accompaniment = gr.Audio(label="Accompaniment")
+             uploaded_html_info = gr.HTML(label="Info")
+
+         with gr.TabItem("Record from microphone"):
+             microphone = gr.Audio(
+                 source="microphone",  # Choose between "microphone", "upload"
+                 type="filepath",
+                 optional=False,
+                 label="Record from microphone",
+             )
+
+             record_button = gr.Button("Submit for separation")
+             recorded_vocals = gr.Audio(label="Vocals")
+             recorded_accompaniment = gr.Audio(label="Accompaniment")
+             recorded_html_info = gr.HTML(label="Info")
+
+         with gr.TabItem("From URL"):
+             url_textbox = gr.Textbox(
+                 max_lines=1,
+                 placeholder="URL to an audio file",
+                 label="URL",
+                 interactive=True,
+             )
+
+             url_button = gr.Button("Submit for separation")
+             url_vocals = gr.Audio(label="Vocals")
+             url_accompaniment = gr.Audio(label="Accompaniment")
+             url_html_info = gr.HTML(label="Info")
+
+     upload_button.click(
+         process_uploaded_file,
+         inputs=[uploaded_file],
+         outputs=[uploaded_vocals, uploaded_accompaniment, uploaded_html_info],
+     )
+
+     record_button.click(
+         process_microphone,
+         inputs=[microphone],
+         outputs=[recorded_vocals, recorded_accompaniment, recorded_html_info],
+     )
+
+     url_button.click(
+         process_url,
+         inputs=[url_textbox],
+         outputs=[url_vocals, url_accompaniment, url_html_info],
+     )
+
+     gr.Markdown(description)
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._set_graph_executor_optimize(False)
+
+ if __name__ == "__main__":
+     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+     logging.basicConfig(format=formatter, level=logging.INFO)
+
+     demo.launch()
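
A note on the RTF figure reported by process(): the real-time factor is simply the processing time divided by the duration of the input audio. A worked example with hypothetical, illustrative numbers (not measured on this Space):

# Hypothetical timings, only to illustrate the RTF formula used in process()
duration = 180.0  # seconds of audio in the uploaded clip
elapsed = 45.0    # seconds spent inside separate()
rtf = elapsed / duration
print(f"RTF: {elapsed:.3f}/{duration:.3f} = {rtf:.3f}")  # RTF: 45.000/180.000 = 0.250

An RTF below 1 means the clip is separated faster than real time; the app suggests re-running when RTF > 1, since the first run also pays the cost of downloading and loading the models.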
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ https://download.pytorch.org/whl/cpu/torch-1.13.1%2Bcpu-cp38-cp38-linux_x86_64.whl
+ https://download.pytorch.org/whl/cpu/torchaudio-0.13.1%2Bcpu-cp38-cp38-linux_x86_64.whl
+
+ numpy
+
+ huggingface_hub
+
+ # imported by separate.py for decoding/encoding audio
+ ffmpeg-python
+ pydub
+ soundfile
separate.py ADDED
@@ -0,0 +1,198 @@
+ #!/usr/bin/env python3
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ # Please see ./run.sh for usage
+
+ from typing import Optional, Tuple
+
+ from huggingface_hub import hf_hub_download
+
+ import ffmpeg
+ import numpy as np
+ import torch
+ import soundfile as sf
+ import torchaudio
+ from functools import lru_cache
+ from pydub import AudioSegment
+
+
+ from unet import UNet
+
+
+ def load_audio(filename):
+     probe = ffmpeg.probe(filename)
+     if "streams" not in probe or len(probe["streams"]) == 0:
+         raise ValueError("No stream was found with ffprobe")
+
+     metadata = next(
+         stream for stream in probe["streams"] if stream["codec_type"] == "audio"
+     )
+     n_channels = metadata["channels"]
+
+     sample_rate = 44100
+
+     process = (
+         ffmpeg.input(filename)
+         .output("pipe:", format="f32le", ar=sample_rate)
+         .run_async(pipe_stdout=True, pipe_stderr=True)
+     )
+     buffer, _ = process.communicate()
+     waveform = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels)
+
+     waveform = torch.from_numpy(waveform).to(torch.float32)
+     if n_channels == 1:
+         waveform = waveform.tile(1, 2)
+
+     if n_channels > 2:
+         waveform = waveform[:, :2]
+
+     return waveform
+
+
+ def separate(
+     vocals: torch.nn.Module,
+     accompaniment: torch.nn.Module,
+     waveform: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))
+
+     # torch.stft requires a 2-D input of shape (N, T), so we transpose waveform
+     stft = torch.stft(
+         waveform.t(),
+         n_fft=4096,
+         hop_length=1024,
+         window=torch.hann_window(4096, periodic=True),
+         center=False,
+         onesided=True,
+         return_complex=True,
+     )
+     # stft: (2, 2049, 465)
+     # stft is a complex tensor
+
+     y = stft.permute(2, 1, 0)
+     # (465, 2049, 2)
+
+     y = y[:, :1024, :]
+     # (465, 1024, 2)
+
+     tensor_size = y.shape[0] - int(y.shape[0] / 512) * 512
+     pad_size = 512 - tensor_size
+     y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
+     # (512, 1024, 2)
+
+     num_splits = int(y.shape[0] / 512)
+     y = y.reshape([num_splits, 512] + list(y.shape[1:]))
+     # y: (1, 512, 1024, 2)
+
+     y = y.abs()
+     y = y.permute(0, 3, 1, 2)
+     # (1, 2, 512, 1024)
+
+     vocals_spec = vocals(y)
+     accompaniment_spec = accompaniment(y)
+
+     sum_spec = (vocals_spec**2 + accompaniment_spec**2) + 1e-10
+
+     vocals_spec = (vocals_spec**2 + 1e-10 / 2) / sum_spec
+     # (1, 2, 512, 1024)
+
+     accompaniment_spec = (accompaniment_spec**2 + 1e-10 / 2) / sum_spec
+     # (1, 2, 512, 1024)
+
+     ans = []
+     for spec in [vocals_spec, accompaniment_spec]:
+         spec = torch.nn.functional.pad(spec, (0, 2049 - 1024, 0, 0, 0, 0, 0, 0))
+         # (1, 2, 512, 2049)
+
+         spec = spec.permute(0, 2, 3, 1)
+         # (1, 512, 2049, 2)
+
+         spec = spec.reshape(-1, spec.shape[2], spec.shape[3])
+         # (512, 2049, 2)
+
+         spec = spec[: stft.shape[2], :, :]
+         # (465, 2049, 2)
+
+         spec = spec.permute(2, 1, 0)
+         # (2, 2049, 465)
+
+         masked_stft = spec * stft
+
+         wave = torch.istft(
+             masked_stft,
+             4096,
+             1024,
+             window=torch.hann_window(4096, periodic=True),
+             onesided=True,
+         ) * (2 / 3)
+
+         # sf.write(f"{name}.wav", wave.t(), 44100)
+
+         # wave = (wave.t() * 32768).to(torch.int16)
+         # sound = AudioSegment(
+         #     data=wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
+         # )
+         # sound.export(f"{name}.mp3", format="mp3", bitrate="128k")
+         ans.append(wave)
+
+     return ans[0], ans[1]
+
+
+ @lru_cache(maxsize=10)
+ def get_nn_model_filename(
+     repo_id: str,
+     filename: str,
+     subfolder: str = "2stems",
+ ) -> str:
+     nn_model_filename = hf_hub_download(
+         repo_id=repo_id,
+         filename=filename,
+         subfolder=subfolder,
+     )
+     return nn_model_filename
+
+
+ @lru_cache(maxsize=10)
+ def load_model(name: str):
+     net = UNet()
+     net.eval()
+     filename = get_nn_model_filename(
+         "csukuangfj/spleeter-torch", name, subfolder="2stems"
+     )
+
+     state_dict = torch.load(filename, map_location="cpu")
+     net.load_state_dict(state_dict)
+
+     return net
+
+
+ @torch.no_grad()
+ def main():
+     vocals = load_model("vocals.pt")
+     accompaniment = load_model("accompaniment.pt")
+
+     filename = "./yesterday-once-more-carpenters.mp3"
+
+     waveform = load_audio(filename)
+     assert waveform.shape[1] == 2, waveform.shape
+
+     vocals_wave, accompaniment_wave = separate(vocals, accompaniment, waveform)
+     vocals_wave = (vocals_wave.t() * 32768).to(torch.int16)
+     accompaniment_wave = (accompaniment_wave.t() * 32768).to(torch.int16)
+
+     vocals_sound = AudioSegment(
+         data=vocals_wave.numpy().tobytes(), sample_width=2, frame_rate=44100, channels=2
+     )
+     vocals_sound.export(f"vocals.mp3", format="mp3", bitrate="128k")
+
+     accompaniment_sound = AudioSegment(
+         data=accompaniment_wave.numpy().tobytes(),
+         sample_width=2,
+         frame_rate=44100,
+         channels=2,
+     )
+     accompaniment_sound.export(f"accompaniment.mp3", format="mp3", bitrate="128k")
+
+
+ if __name__ == "__main__":
+     main()
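
The shape comments inside separate() can be checked without a real song. The following sketch (not part of the commit) pushes 10 seconds of random stereo noise through the same STFT front end; the constants (44.1 kHz, n_fft=4096, hop length 1024, the lowest 1024 bins, 512-frame chunks) are copied from the function above, and the printed shapes have the same structure as the commented ones (only the frame count depends on clip length).

import torch

sample_rate = 44100
waveform = torch.randn(10 * sample_rate, 2)  # (num_samples, 2): 10 s of stereo noise

# Same padding/STFT settings as separate()
waveform = torch.nn.functional.pad(waveform, (0, 0, 0, 4096))
stft = torch.stft(
    waveform.t(),
    n_fft=4096,
    hop_length=1024,
    window=torch.hann_window(4096, periodic=True),
    center=False,
    onesided=True,
    return_complex=True,
)
print(stft.shape)  # (2, 2049, T); T = 431 for this 10 s clip

y = stft.permute(2, 1, 0)[:, :1024, :]  # keep the lowest 1024 frequency bins
pad_size = 512 - (y.shape[0] - y.shape[0] // 512 * 512)  # pad frames to a multiple of 512
y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
y = y.reshape(-1, 512, 1024, 2).abs().permute(0, 3, 1, 2)
print(y.shape)  # (num_chunks, 2, 512, 1024): the magnitude input expected by UNet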
unet.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ import torch
+
+
+ class UNet(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn = torch.nn.BatchNorm2d(
+             16, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+         #
+         self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn1 = torch.nn.BatchNorm2d(
+             32, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn2 = torch.nn.BatchNorm2d(
+             64, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn3 = torch.nn.BatchNorm2d(
+             128, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn4 = torch.nn.BatchNorm2d(
+             256, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)
+
+         self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
+         self.bn5 = torch.nn.BatchNorm2d(
+             256, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
+         self.bn6 = torch.nn.BatchNorm2d(
+             128, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
+         self.bn7 = torch.nn.BatchNorm2d(
+             64, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
+         self.bn8 = torch.nn.BatchNorm2d(
+             32, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
+         self.bn9 = torch.nn.BatchNorm2d(
+             16, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
+         self.bn10 = torch.nn.BatchNorm2d(
+             1, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         # output logit is False, so we need self.up7
+         self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
+
+     def forward(self, x):
+         in_x = x
+         # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
+         x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
+         conv1 = self.conv(x)
+         batch1 = self.bn(conv1)
+         rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)
+
+         x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
+         conv2 = self.conv1(x)  # (3, 32, 128, 256)
+         batch2 = self.bn1(conv2)
+         rel2 = torch.nn.functional.leaky_relu(
+             batch2, negative_slope=0.2
+         )  # (3, 32, 128, 256)
+
+         x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
+         conv3 = self.conv2(x)  # (3, 64, 64, 128)
+         batch3 = self.bn2(conv3)
+         rel3 = torch.nn.functional.leaky_relu(
+             batch3, negative_slope=0.2
+         )  # (3, 64, 64, 128)
+
+         x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
+         conv4 = self.conv3(x)  # (3, 128, 32, 64)
+         batch4 = self.bn3(conv4)
+         rel4 = torch.nn.functional.leaky_relu(
+             batch4, negative_slope=0.2
+         )  # (3, 128, 32, 64)
+
+         x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
+         conv5 = self.conv4(x)  # (3, 256, 16, 32)
+         batch5 = self.bn4(conv5)
+         rel6 = torch.nn.functional.leaky_relu(
+             batch5, negative_slope=0.2
+         )  # (3, 256, 16, 32)
+
+         x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
+         conv6 = self.conv5(x)  # (3, 512, 8, 16)
+
+         up1 = self.up1(conv6)
+         up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
+         up1 = torch.nn.functional.relu(up1)
+         batch7 = self.bn5(up1)
+         merge1 = torch.cat([conv5, batch7], axis=1)  # (3, 512, 16, 32)
+
+         up2 = self.up2(merge1)
+         up2 = up2[:, :, 1:-2, 1:-2]
+         up2 = torch.nn.functional.relu(up2)
+         batch8 = self.bn6(up2)
+
+         merge2 = torch.cat([conv4, batch8], axis=1)  # (3, 256, 32, 64)
+
+         up3 = self.up3(merge2)
+         up3 = up3[:, :, 1:-2, 1:-2]
+         up3 = torch.nn.functional.relu(up3)
+         batch9 = self.bn7(up3)
+
+         merge3 = torch.cat([conv3, batch9], axis=1)  # (3, 128, 64, 128)
+
+         up4 = self.up4(merge3)
+         up4 = up4[:, :, 1:-2, 1:-2]
+         up4 = torch.nn.functional.relu(up4)
+         batch10 = self.bn8(up4)
+
+         merge4 = torch.cat([conv2, batch10], axis=1)  # (3, 64, 128, 256)
+
+         up5 = self.up5(merge4)
+         up5 = up5[:, :, 1:-2, 1:-2]
+         up5 = torch.nn.functional.relu(up5)
+         batch11 = self.bn9(up5)
+
+         merge5 = torch.cat([conv1, batch11], axis=1)  # (3, 32, 256, 512)
+
+         up6 = self.up6(merge5)
+         up6 = up6[:, :, 1:-2, 1:-2]
+         up6 = torch.nn.functional.relu(up6)
+         batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)
+
+         up7 = self.up7(batch12)
+         up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)
+
+         return up7 * in_x
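
To make the input/output contract in the shape comments concrete, here is a small smoke test (not part of the commit). It uses a randomly initialised UNet, not the trained vocals.pt/accompaniment.pt checkpoints, on a single 512-frame chunk; run it from the repository root so that unet.py is importable.

import torch

from unet import UNet

net = UNet()
net.eval()

x = torch.rand(1, 2, 512, 1024)  # one chunk: 2 channels, 512 frames, 1024 frequency bins
with torch.no_grad():
    y = net(x)

print(y.shape)  # torch.Size([1, 2, 512, 1024]): same shape as the input
print((y >= 0).all().item())  # True: the output is a sigmoid mask multiplied by the input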