artelabsuper committed
Commit: daf1ccd
Parent: a2b88d1

copy work from repo

.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+saved_training/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
app.py ADDED
@@ -0,0 +1,71 @@
+import os
+import gradio as gr
+from scipy.io.wavfile import write
+import config
+
+import torch
+from model.htsat import HTSAT_Swin_Transformer
+from sed_model import SEDWrapper
+import librosa
+import numpy as np
+
+example_path = './examples_audio'
+
+class_mapping = ['dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects', 'sheep', 'crow', 'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds', 'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm', 'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing', 'footsteps', 'laughing',
+                 'brushing_teeth', 'snoring', 'drinking_sipping', 'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening', 'washing_machine', 'vacuum_cleaner', 'clock_alarm', 'clock_tick', 'glass_breaking', 'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine', 'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw']
+
+sed_model = HTSAT_Swin_Transformer(
+    spec_size=config.htsat_spec_size,
+    patch_size=config.htsat_patch_size,
+    in_chans=1,
+    num_classes=config.classes_num,
+    window_size=config.htsat_window_size,
+    config=config,
+    depths=config.htsat_depth,
+    embed_dim=config.htsat_dim,
+    patch_stride=config.htsat_stride,
+    num_heads=config.htsat_num_head
+)
+
+model = SEDWrapper(
+    sed_model=sed_model,
+    config=config,
+    dataset=None
+)
+
+ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
+model.load_state_dict(ckpt["state_dict"], strict=False)
+
+def inference(audio):
+    sr, y = audio
+    y = y / 32767.0  # scipy delivers int16 samples; scale to [-1, 1] as librosa expects
+    if len(y.shape) != 1:  # downmix to mono
+        y = y[:, 0]
+    y = librosa.resample(y, orig_sr=sr, target_sr=32000)
+    in_val = np.array([y])
+    result = model.inference(in_val)
+    pred = result['clipwise_output'][0]
+    # pred = np.exp(pred)/np.sum(np.exp(pred))  # softmax
+    return {class_mapping[i]: float(p) for i, p in enumerate(pred)}
+    # win_classes = np.argmax(result['clipwise_output'], axis=1)
+    # win_class_index = win_classes[0]
+    # win_class_name = class_mapping[win_class_index]
+    # return str({win_class_name: result['clipwise_output'][0][win_class_index]})
+
+
+title = "HTS-Audio-Transformer"
+description = "Audio classification with ESC-50."
+# article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1911.13254' target='_blank'>Music Source Separation in the Waveform Domain</a> | <a href='https://github.com/facebookresearch/demucs' target='_blank'>Github Repo</a></p>"
+
+examples = [['test.mp3']]
+gr.Interface(
+    inference,
+    gr.inputs.Audio(type="numpy", label="Input"),
+    # gr.outputs.Textbox(),
+    gr.outputs.Label(),
+    title=title,
+    description=description,
+    # article=article,
+    examples=[[os.path.join(example_path, f)]
+              for f in os.listdir(example_path)]
+).launch(enable_queue=True)
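
For reference, the inference callback above can also be exercised without the Gradio UI. A minimal sketch, assuming the model-loading code from app.py has already been run in the same Python session, and using one of the example clips bundled below (16-bit PCM, which is what gr.inputs.Audio(type="numpy") passes to the callback):

from scipy.io import wavfile

sr, samples = wavfile.read('examples_audio/2-82367-A-10.wav')  # int16 samples, like the Gradio input
scores = inference((sr, samples))                              # dict: ESC-50 class name -> probability
print(max(scores, key=scores.get))                             # most likely class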
config.py ADDED
@@ -0,0 +1,132 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# The configuration for training the model
+
+exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
+workspace = "/home/super/nic/HTS-Audio-Transformer" # the folder of your code
+dataset_path = "/home/super/datasets-nas/ESC-50/" # the dataset path
+desed_folder = "/home/super/nic/HTS-Audio-Transformer/DESED" # the desed folder
+
+dataset_type = "esc-50" # "audioset" "esc-50" "scv2"
+index_type = "full_train" # only works for audioset
+balanced_data = True # only works for audioset
+
+loss_type = "clip_bce"
+# AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
+
+# train from a checkpoint, or evaluate a single model
+# resume_checkpoint = None
+# resume_checkpoint = "/home/super/nic/HTS-Audio-Transformer/saved_training/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt"
+resume_checkpoint = "/home/super/nic/HTS-Audio-Transformer/saved_training/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt"
+# resume_checkpoint = "/home/super/nic/HTS-Audio-Transformer/results/exp_htsat_pretrain/checkpoint_1/lightning_logs/version_9/checkpoints/l-epoch=99-acc=0.490.ckpt"
+
+# "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
+
+esc_fold = 0 # just for the esc dataset, select the fold you need for evaluation and (+1) validation
+
+
+debug = False
+
+random_seed = 970131 # 19970318 970131 12412 127777 1009 34047
+batch_size = 32 * 1 # batch size per GPU x GPU number, default is 32 x 4 = 128
+learning_rate = 1e-3 # 1e-4 also workable
+max_epoch = 100
+num_workers = 3
+
+lr_scheduler_epoch = [10,20,30]
+lr_rate = [0.02, 0.05, 0.1]
+
+# these data preparation optimizations do not bring many improvements, so deprecated
+enable_token_label = False # token label
+class_map_path = "esc-50-data.npy"
+class_filter = None
+retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762,
+    9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]
+token_label_range = [0.2,0.6]
+enable_time_shift = False # shift time
+enable_label_enhance = False # enhance hierarchical label
+enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram
+
+
+
+# for model's design
+enable_tscam = True # enable the token-semantic layer
+
+# for signal processing
+sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50
+clip_samples = sample_rate * 10 # audioset 10-sec clip
+window_size = 1024
+hop_size = 320 # 160 for scv2, 320 for audioset and esc-50
+mel_bins = 64
+fmin = 50
+fmax = 14000
+shift_max = int(clip_samples * 0.5)
+
+# for data collection
+classes_num = 50 # esc: 50 | audioset: 527 | scv2: 35
+patch_size = (25, 4) # deprecated
+crop_size = None # int(clip_samples * 0.5) deprecated
+
+# htsat hyperparameters
+htsat_window_size = 8
+htsat_spec_size = 256
+htsat_patch_size = 4
+htsat_stride = (4, 4)
+htsat_num_head = [4,8,16,32]
+htsat_dim = 96
+htsat_depth = [2,2,6,2]
+
+swin_pretrain_path = None
+# "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth"
+
+# Some deprecated optimizations in the model design, check the model code for details
+htsat_attn_heatmap = False
+htsat_hier_output = False
+htsat_use_max = False
+
+
+# for ensemble test
+
+ensemble_checkpoints = []
+ensemble_strides = []
+
+
+# weight average folder
+wa_folder = "/home/super/nic/HTS-Audio-Transformer/checkpoints/"
+# weight average output filename
+wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt"
+
+esm_model_pathes = [
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_1.ckpt",
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_2.ckpt",
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_3.ckpt",
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_4.ckpt",
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_5.ckpt",
+    "/home/super/nic/HTS-Audio-Transformer/HTSAT_AudioSet_Saved_6.ckpt"
+]
+
+# for framewise localization
+heatmap_dir = "/home/super/nic/HTS-Audio-Transformer/heatmap_output"
+test_file = "htsat-test-ensemble"
+fl_local = False # indicate if we need to use this dataset for the framewise detection
+fl_dataset = "/home/Research/desed/desed_eval.npy"
+fl_class_num = [
+    "Speech", "Frying", "Dishes", "Running_water",
+    "Blender", "Electric_shaver_toothbrush", "Alarm_bell_ringing",
+    "Cat", "Dog", "Vacuum_cleaner"
+]
+
+# map 527 classes into 10 classes
+fl_audioset_mapping = [
+    [0,1,2,3,4,5,6,7],
+    [366, 367, 368],
+    [364],
+    [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],
+    [369],
+    [382],
+    [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],
+    [81, 82, 83, 84, 85],
+    [74, 75, 76, 77, 78, 79],
+    [377]
+]
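
Note that resume_checkpoint, workspace, and the other paths above are absolute paths from the original training machine. A minimal sketch of pointing the checkpoint at the copy shipped with this commit instead (an illustrative override, not part of the committed code):

import os
import config

# saved_training/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt is the LFS file added by this commit
config.resume_checkpoint = os.path.join(
    os.path.dirname(os.path.abspath(config.__file__)),
    'saved_training', 'HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt'
)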
examples_audio/2-82367-A-10.wav ADDED
Binary file (320 kB)
 
examples_audio/4-255371-A-47.wav ADDED
Binary file (320 kB)
 
examples_audio/urban_sound_98223-7-10-0.wav ADDED
Binary file (175 kB)
 
model/htsat.py ADDED
@@ -0,0 +1,836 @@
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # Model Core
5
+ # the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
6
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
+
8
+
9
+ import logging
10
+ import pdb
11
+ import math
12
+ import random
13
+ from numpy.core.fromnumeric import clip, reshape
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.utils.checkpoint as checkpoint
17
+
18
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
19
+ from torchlibrosa.augmentation import SpecAugmentation
20
+
21
+ from itertools import repeat
22
+ from typing import List
23
+ from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, to_2tuple
24
+ from utils import do_mixup, interpolate
25
+
26
+
27
+ # the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
28
+ # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
29
+
30
+ def window_partition(x, window_size):
31
+ """
32
+ Args:
33
+ x: (B, H, W, C)
34
+ window_size (int): window size
35
+ Returns:
36
+ windows: (num_windows*B, window_size, window_size, C)
37
+ """
38
+ B, H, W, C = x.shape
39
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
40
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
41
+ return windows
42
+
43
+
44
+ def window_reverse(windows, window_size, H, W):
45
+ """
46
+ Args:
47
+ windows: (num_windows*B, window_size, window_size, C)
48
+ window_size (int): Window size
49
+ H (int): Height of image
50
+ W (int): Width of image
51
+ Returns:
52
+ x: (B, H, W, C)
53
+ """
54
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
55
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
56
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
57
+ return x
58
+
59
+
60
+ class WindowAttention(nn.Module):
61
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
62
+ It supports both of shifted and non-shifted window.
63
+ Args:
64
+ dim (int): Number of input channels.
65
+ window_size (tuple[int]): The height and width of the window.
66
+ num_heads (int): Number of attention heads.
67
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
68
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
69
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
70
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
71
+ """
72
+
73
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
74
+
75
+ super().__init__()
76
+ self.dim = dim
77
+ self.window_size = window_size # Wh, Ww
78
+ self.num_heads = num_heads
79
+ head_dim = dim // num_heads
80
+ self.scale = qk_scale or head_dim ** -0.5
81
+
82
+ # define a parameter table of relative position bias
83
+ self.relative_position_bias_table = nn.Parameter(
84
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
85
+
86
+ # get pair-wise relative position index for each token inside the window
87
+ coords_h = torch.arange(self.window_size[0])
88
+ coords_w = torch.arange(self.window_size[1])
89
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
90
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
91
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
92
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
93
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
94
+ relative_coords[:, :, 1] += self.window_size[1] - 1
95
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
96
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
97
+ self.register_buffer("relative_position_index", relative_position_index)
98
+
99
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
100
+ self.attn_drop = nn.Dropout(attn_drop)
101
+ self.proj = nn.Linear(dim, dim)
102
+ self.proj_drop = nn.Dropout(proj_drop)
103
+
104
+ trunc_normal_(self.relative_position_bias_table, std=.02)
105
+ self.softmax = nn.Softmax(dim=-1)
106
+
107
+ def forward(self, x, mask=None):
108
+ """
109
+ Args:
110
+ x: input features with shape of (num_windows*B, N, C)
111
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
112
+ """
113
+ B_, N, C = x.shape
114
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
115
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
116
+
117
+ q = q * self.scale
118
+ attn = (q @ k.transpose(-2, -1))
119
+
120
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
121
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
122
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
123
+ attn = attn + relative_position_bias.unsqueeze(0)
124
+
125
+ if mask is not None:
126
+ nW = mask.shape[0]
127
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
128
+ attn = attn.view(-1, self.num_heads, N, N)
129
+ attn = self.softmax(attn)
130
+ else:
131
+ attn = self.softmax(attn)
132
+
133
+ attn = self.attn_drop(attn)
134
+
135
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
136
+ x = self.proj(x)
137
+ x = self.proj_drop(x)
138
+ return x, attn
139
+
140
+ def extra_repr(self):
141
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
142
+
143
+
144
+ # We base the model on the Swin Transformer block, so the Swin Transformer pretrained weights can be reused
145
+ class SwinTransformerBlock(nn.Module):
146
+ r""" Swin Transformer Block.
147
+ Args:
148
+ dim (int): Number of input channels.
149
+ input_resolution (tuple[int]): Input resulotion.
150
+ num_heads (int): Number of attention heads.
151
+ window_size (int): Window size.
152
+ shift_size (int): Shift size for SW-MSA.
153
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
154
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
155
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
156
+ drop (float, optional): Dropout rate. Default: 0.0
157
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
158
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
159
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
160
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
161
+ """
162
+
163
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
164
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
165
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):
166
+ super().__init__()
167
+ self.dim = dim
168
+ self.input_resolution = input_resolution
169
+ self.num_heads = num_heads
170
+ self.window_size = window_size
171
+ self.shift_size = shift_size
172
+ self.mlp_ratio = mlp_ratio
173
+ self.norm_before_mlp = norm_before_mlp
174
+ if min(self.input_resolution) <= self.window_size:
175
+ # if window size is larger than input resolution, we don't partition windows
176
+ self.shift_size = 0
177
+ self.window_size = min(self.input_resolution)
178
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
179
+
180
+ self.norm1 = norm_layer(dim)
181
+ self.attn = WindowAttention(
182
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
183
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
184
+
185
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
186
+ if self.norm_before_mlp == 'ln':
187
+ self.norm2 = nn.LayerNorm(dim)
188
+ elif self.norm_before_mlp == 'bn':
189
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)
190
+ else:
191
+ raise NotImplementedError
192
+ mlp_hidden_dim = int(dim * mlp_ratio)
193
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
194
+
195
+ if self.shift_size > 0:
196
+ # calculate attention mask for SW-MSA
197
+ H, W = self.input_resolution
198
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
199
+ h_slices = (slice(0, -self.window_size),
200
+ slice(-self.window_size, -self.shift_size),
201
+ slice(-self.shift_size, None))
202
+ w_slices = (slice(0, -self.window_size),
203
+ slice(-self.window_size, -self.shift_size),
204
+ slice(-self.shift_size, None))
205
+ cnt = 0
206
+ for h in h_slices:
207
+ for w in w_slices:
208
+ img_mask[:, h, w, :] = cnt
209
+ cnt += 1
210
+
211
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
212
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
213
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
214
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
215
+ else:
216
+ attn_mask = None
217
+
218
+ self.register_buffer("attn_mask", attn_mask)
219
+
220
+ def forward(self, x):
221
+ # pdb.set_trace()
222
+ H, W = self.input_resolution
223
+ # print("H: ", H)
224
+ # print("W: ", W)
225
+ # pdb.set_trace()
226
+ B, L, C = x.shape
227
+ # assert L == H * W, "input feature has wrong size"
228
+
229
+ shortcut = x
230
+ x = self.norm1(x)
231
+ x = x.view(B, H, W, C)
232
+
233
+ # cyclic shift
234
+ if self.shift_size > 0:
235
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
236
+ else:
237
+ shifted_x = x
238
+
239
+ # partition windows
240
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
241
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
242
+
243
+ # W-MSA/SW-MSA
244
+ attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
245
+
246
+ # merge windows
247
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
248
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
249
+
250
+ # reverse cyclic shift
251
+ if self.shift_size > 0:
252
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
253
+ else:
254
+ x = shifted_x
255
+ x = x.view(B, H * W, C)
256
+
257
+ # FFN
258
+ x = shortcut + self.drop_path(x)
259
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
260
+
261
+ return x, attn
262
+
263
+ def extra_repr(self):
264
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
265
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
266
+
267
+
268
+
269
+ class PatchMerging(nn.Module):
270
+ r""" Patch Merging Layer.
271
+ Args:
272
+ input_resolution (tuple[int]): Resolution of input feature.
273
+ dim (int): Number of input channels.
274
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275
+ """
276
+
277
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
278
+ super().__init__()
279
+ self.input_resolution = input_resolution
280
+ self.dim = dim
281
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
282
+ self.norm = norm_layer(4 * dim)
283
+
284
+ def forward(self, x):
285
+ """
286
+ x: B, H*W, C
287
+ """
288
+ H, W = self.input_resolution
289
+ B, L, C = x.shape
290
+ assert L == H * W, "input feature has wrong size"
291
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
292
+
293
+ x = x.view(B, H, W, C)
294
+
295
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
296
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
297
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
298
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
299
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
300
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
301
+
302
+ x = self.norm(x)
303
+ x = self.reduction(x)
304
+
305
+ return x
306
+
307
+ def extra_repr(self):
308
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
309
+
310
+
311
+ class BasicLayer(nn.Module):
312
+ """ A basic Swin Transformer layer for one stage.
313
+ Args:
314
+ dim (int): Number of input channels.
315
+ input_resolution (tuple[int]): Input resolution.
316
+ depth (int): Number of blocks.
317
+ num_heads (int): Number of attention heads.
318
+ window_size (int): Local window size.
319
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
320
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
321
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
322
+ drop (float, optional): Dropout rate. Default: 0.0
323
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
324
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
325
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
326
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
327
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
328
+ """
329
+
330
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
331
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
332
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
333
+ norm_before_mlp='ln'):
334
+
335
+ super().__init__()
336
+ self.dim = dim
337
+ self.input_resolution = input_resolution
338
+ self.depth = depth
339
+ self.use_checkpoint = use_checkpoint
340
+
341
+ # build blocks
342
+ self.blocks = nn.ModuleList([
343
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
344
+ num_heads=num_heads, window_size=window_size,
345
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
346
+ mlp_ratio=mlp_ratio,
347
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
348
+ drop=drop, attn_drop=attn_drop,
349
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
350
+ norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)
351
+ for i in range(depth)])
352
+
353
+ # patch merging layer
354
+ if downsample is not None:
355
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
356
+ else:
357
+ self.downsample = None
358
+
359
+ def forward(self, x):
360
+ attns = []
361
+ for blk in self.blocks:
362
+ if self.use_checkpoint:
363
+ x = checkpoint.checkpoint(blk, x)
364
+ else:
365
+ x, attn = blk(x)
366
+ if not self.training:
367
+ attns.append(attn.unsqueeze(0))
368
+ if self.downsample is not None:
369
+ x = self.downsample(x)
370
+ if not self.training:
371
+ attn = torch.cat(attns, dim = 0)
372
+ attn = torch.mean(attn, dim = 0)
373
+ return x, attn
374
+
375
+ def extra_repr(self):
376
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
377
+
378
+
379
+ # The Core of HTSAT
380
+ class HTSAT_Swin_Transformer(nn.Module):
381
+ r"""HTSAT based on the Swin Transformer
382
+ Args:
383
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
384
+ patch_size (int | tuple(int)): Patch size. Default: 4
385
+ patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
386
+ in_chans (int): Number of input image channels. Default: 1 (mono)
387
+ num_classes (int): Number of classes for classification head. Default: 527
388
+ embed_dim (int): Patch embedding dimension. Default: 96
389
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
390
+ num_heads (tuple(int)): Number of attention heads in different layers.
391
+ window_size (int): Window size. Default: 8
392
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
393
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
394
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
395
+ drop_rate (float): Dropout rate. Default: 0
396
+ attn_drop_rate (float): Attention dropout rate. Default: 0
397
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
398
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
399
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
400
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
401
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
402
+ config (module): The configuration Module from config.py
403
+ """
404
+
405
+ def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4),
406
+ in_chans=1, num_classes=527,
407
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
408
+ window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
409
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
410
+ norm_layer=nn.LayerNorm,
411
+ ape=False, patch_norm=True,
412
+ use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):
413
+ super(HTSAT_Swin_Transformer, self).__init__()
414
+
415
+ self.config = config
416
+ self.spec_size = spec_size
417
+ self.patch_stride = patch_stride
418
+ self.patch_size = patch_size
419
+ self.window_size = window_size
420
+ self.embed_dim = embed_dim
421
+ self.depths = depths
422
+ self.ape = ape
423
+ self.in_chans = in_chans
424
+ self.num_classes = num_classes
425
+ self.num_heads = num_heads
426
+ self.num_layers = len(self.depths)
427
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
428
+
429
+ self.drop_rate = drop_rate
430
+ self.attn_drop_rate = attn_drop_rate
431
+ self.drop_path_rate = drop_path_rate
432
+
433
+ self.qkv_bias = qkv_bias
434
+ self.qk_scale = None
435
+
436
+ self.patch_norm = patch_norm
437
+ self.norm_layer = norm_layer if self.patch_norm else None
438
+ self.norm_before_mlp = norm_before_mlp
439
+ self.mlp_ratio = mlp_ratio
440
+
441
+ self.use_checkpoint = use_checkpoint
442
+
443
+ # process mel-spec ; used only once
444
+ self.freq_ratio = self.spec_size // self.config.mel_bins
445
+ window = 'hann'
446
+ center = True
447
+ pad_mode = 'reflect'
448
+ ref = 1.0
449
+ amin = 1e-10
450
+ top_db = None
451
+ self.interpolate_ratio = 32 # Downsampled ratio
452
+ # Spectrogram extractor
453
+ self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size,
454
+ win_length=config.window_size, window=window, center=center, pad_mode=pad_mode,
455
+ freeze_parameters=True)
456
+ # Logmel feature extractor
457
+ self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size,
458
+ n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db,
459
+ freeze_parameters=True)
460
+ # Spec augmenter
461
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
462
+ freq_drop_width=8, freq_stripes_num=2) # 2 2
463
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
464
+
465
+
466
+ # split the spectrogram into non-overlapping patches
467
+ self.patch_embed = PatchEmbed(
468
+ img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans,
469
+ embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)
470
+
471
+ num_patches = self.patch_embed.num_patches
472
+ patches_resolution = self.patch_embed.grid_size
473
+ self.patches_resolution = patches_resolution
474
+
475
+ # absolute position embedding
476
+ if self.ape:
477
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))
478
+ trunc_normal_(self.absolute_pos_embed, std=.02)
479
+
480
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
481
+
482
+ # stochastic depth
483
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule
484
+
485
+ # build layers
486
+ self.layers = nn.ModuleList()
487
+ for i_layer in range(self.num_layers):
488
+ layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),
489
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
490
+ patches_resolution[1] // (2 ** i_layer)),
491
+ depth=self.depths[i_layer],
492
+ num_heads=self.num_heads[i_layer],
493
+ window_size=self.window_size,
494
+ mlp_ratio=self.mlp_ratio,
495
+ qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,
496
+ drop=self.drop_rate, attn_drop=self.attn_drop_rate,
497
+ drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],
498
+ norm_layer=self.norm_layer,
499
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
500
+ use_checkpoint=use_checkpoint,
501
+ norm_before_mlp=self.norm_before_mlp)
502
+ self.layers.append(layer)
503
+
504
+ # A deprecated optimization for using a hierarchical output from different blocks
505
+ # if self.config.htsat_hier_output:
506
+ # self.norm = nn.ModuleList(
507
+ # [self.norm_layer(
508
+ # min(
509
+ # self.embed_dim * (2 ** (len(self.depths) - 1)),
510
+ # self.embed_dim * (2 ** (i + 1))
511
+ # )
512
+ # ) for i in range(len(self.depths))]
513
+ # )
514
+ # else:
515
+
516
+ self.norm = self.norm_layer(self.num_features)
517
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
518
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
519
+
520
+ # A deprecated optimization for using the max value instead of average value
521
+ # if self.config.htsat_use_max:
522
+ # self.a_avgpool = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
523
+ # self.a_maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
524
+
525
+ if self.config.enable_tscam:
526
+ # if self.config.htsat_hier_output:
527
+ # self.tscam_conv = nn.ModuleList()
528
+ # for i in range(len(self.depths)):
529
+ # zoom_ratio = 2 ** min(len(self.depths) - 1, i + 1)
530
+ # zoom_dim = min(
531
+ # self.embed_dim * (2 ** (len(self.depths) - 1)),
532
+ # self.embed_dim * (2 ** (i + 1))
533
+ # )
534
+ # SF = self.spec_size // zoom_ratio // self.patch_stride[0] // self.freq_ratio
535
+ # self.tscam_conv.append(
536
+ # nn.Conv2d(
537
+ # in_channels = zoom_dim,
538
+ # out_channels = self.num_classes,
539
+ # kernel_size = (SF, 3),
540
+ # padding = (0,1)
541
+ # )
542
+ # )
543
+ # self.head = nn.Linear(num_classes * len(self.depths), num_classes)
544
+ # else:
545
+
546
+ SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio
547
+ self.tscam_conv = nn.Conv2d(
548
+ in_channels = self.num_features,
549
+ out_channels = self.num_classes,
550
+ kernel_size = (SF,3),
551
+ padding = (0,1)
552
+ )
553
+ self.head = nn.Linear(num_classes, num_classes)
554
+ else:
555
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
556
+
557
+ self.apply(self._init_weights)
558
+
559
+ def _init_weights(self, m):
560
+ if isinstance(m, nn.Linear):
561
+ trunc_normal_(m.weight, std=.02)
562
+ if isinstance(m, nn.Linear) and m.bias is not None:
563
+ nn.init.constant_(m.bias, 0)
564
+ elif isinstance(m, nn.LayerNorm):
565
+ nn.init.constant_(m.bias, 0)
566
+ nn.init.constant_(m.weight, 1.0)
567
+
568
+ @torch.jit.ignore
569
+ def no_weight_decay(self):
570
+ return {'absolute_pos_embed'}
571
+
572
+ @torch.jit.ignore
573
+ def no_weight_decay_keywords(self):
574
+ return {'relative_position_bias_table'}
575
+
576
+
577
+ def forward_features(self, x):
578
+ # A deprecated optimization for using a hierarchical output from different blocks
579
+ # if self.config.htsat_hier_output:
580
+ # hier_x = []
581
+ # hier_attn = []
582
+
583
+ frames_num = x.shape[2]
584
+ x = self.patch_embed(x)
585
+ if self.ape:
586
+ x = x + self.absolute_pos_embed
587
+ x = self.pos_drop(x)
588
+ for i, layer in enumerate(self.layers):
589
+ x, attn = layer(x)
590
+ # A deprecated optimization for using a hierarchical output from different blocks
591
+ # if self.config.htsat_hier_output:
592
+ # hier_x.append(x)
593
+ # if i == len(self.layers) - 1:
594
+ # hier_attn.append(attn)
595
+
596
+ # A deprecated optimization for using a hierarchical output from different blocks
597
+ # if self.config.htsat_hier_output:
598
+ # hxs = []
599
+ # fphxs = []
600
+ # for i in range(len(hier_x)):
601
+ # hx = hier_x[i]
602
+ # hx = self.norm[i](hx)
603
+ # B, N, C = hx.shape
604
+ # zoom_ratio = 2 ** min(len(self.depths) - 1, i + 1)
605
+ # SF = frames_num // zoom_ratio // self.patch_stride[0]
606
+ # ST = frames_num // zoom_ratio // self.patch_stride[1]
607
+ # hx = hx.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
608
+ # B, C, F, T = hx.shape
609
+ # c_freq_bin = F // self.freq_ratio
610
+ # hx = hx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
611
+ # hx = hx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
612
+
613
+ # hx = self.tscam_conv[i](hx)
614
+ # hx = torch.flatten(hx, 2)
615
+ # fphx = interpolate(hx.permute(0,2,1).contiguous(), self.spec_size * self.freq_ratio // hx.shape[2])
616
+
617
+ # hx = self.avgpool(hx)
618
+ # hx = torch.flatten(hx, 1)
619
+ # hxs.append(hx)
620
+ # fphxs.append(fphx)
621
+ # hxs = torch.cat(hxs, dim=1)
622
+ # fphxs = torch.cat(fphxs, dim = 2)
623
+ # hxs = self.head(hxs)
624
+ # fphxs = self.head(fphxs)
625
+ # output_dict = {'framewise_output': torch.sigmoid(fphxs),
626
+ # 'clipwise_output': torch.sigmoid(hxs)}
627
+ # return output_dict
628
+
629
+ if self.config.enable_tscam:
630
+ # for x
631
+ x = self.norm(x)
632
+ B, N, C = x.shape
633
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
634
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
635
+ x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)
636
+ B, C, F, T = x.shape
637
+ # group 2D CNN
638
+ c_freq_bin = F // self.freq_ratio
639
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
640
+ x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
641
+
642
+ # get latent_output
643
+ latent_output = self.avgpool(torch.flatten(x,2))
644
+ latent_output = torch.flatten(latent_output, 1)
645
+
646
+ # display the attention map, if needed
647
+ if self.config.htsat_attn_heatmap:
648
+ # for attn
649
+ attn = torch.mean(attn, dim = 1)
650
+ attn = torch.mean(attn, dim = 1)
651
+ attn = attn.reshape(B, SF, ST)
652
+ c_freq_bin = SF // self.freq_ratio
653
+ attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST)
654
+ attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1)
655
+ attn = attn.mean(dim = 1)
656
+ attn_max = torch.max(attn, dim = 1, keepdim = True)[0]
657
+ attn_min = torch.min(attn, dim = 1, keepdim = True)[0]
658
+ attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)
659
+ attn = attn.unsqueeze(dim = 2)
660
+
661
+ x = self.tscam_conv(x)
662
+ x = torch.flatten(x, 2) # B, C, T
663
+
664
+ # A deprecated optimization for using the max value instead of average value
665
+ # if self.config.htsat_use_max:
666
+ # x1 = self.a_maxpool(x)
667
+ # x2 = self.a_avgpool(x)
668
+ # x = x1 + x2
669
+
670
+ if self.config.htsat_attn_heatmap:
671
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1])
672
+ else:
673
+ fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
674
+
675
+ # A deprecated optimization for using the max value instead of average value
676
+ # if self.config.htsat_use_max:
677
+ # x1 = self.avgpool(x)
678
+ # x2 = self.maxpool(x)
679
+ # x = x1 + x2
680
+ # else:
681
+ x = self.avgpool(x)
682
+ x = torch.flatten(x, 1)
683
+
684
+ if self.config.loss_type == "clip_ce":
685
+ output_dict = {
686
+ 'framewise_output': fpx, # already sigmoided
687
+ 'clipwise_output': x,
688
+ 'latent_output': latent_output
689
+ }
690
+ else:
691
+ output_dict = {
692
+ 'framewise_output': fpx, # already sigmoided
693
+ 'clipwise_output': torch.sigmoid(x),
694
+ 'latent_output': latent_output
695
+ }
696
+
697
+ else:
698
+ x = self.norm(x) # B N C
699
+ B, N, C = x.shape
700
+
701
+ fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )
702
+ B, C, F, T = fpx.shape
703
+ c_freq_bin = F // self.freq_ratio
704
+ fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
705
+ fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)
706
+ fpx = torch.sum(fpx, dim = 2)
707
+ fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1])
708
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
709
+ x = torch.flatten(x, 1)
710
+ if self.num_classes > 0:
711
+ x = self.head(x)
712
+ fpx = self.head(fpx)
713
+ output_dict = {'framewise_output': torch.sigmoid(fpx),
714
+ 'clipwise_output': torch.sigmoid(x)}
715
+ return output_dict
716
+
717
+ def crop_wav(self, x, crop_size, spe_pos = None):
718
+ time_steps = x.shape[2]
719
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
720
+ for i in range(len(x)):
721
+ if spe_pos is None:
722
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
723
+ else:
724
+ crop_pos = spe_pos
725
+ tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]
726
+ return tx
727
+
728
+ # Reshape the waveform to an image-like size, if you want to use the pretrained Swin Transformer model
729
+ def reshape_wav2img(self, x):
730
+ B, C, T, F = x.shape
731
+ target_T = int(self.spec_size * self.freq_ratio)
732
+ target_F = self.spec_size // self.freq_ratio
733
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
734
+ # to avoid bicubic zero error
735
+ if T < target_T:
736
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
737
+ if F < target_F:
738
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
739
+ x = x.permute(0,1,3,2).contiguous()
740
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
741
+ # print(x.shape)
742
+ x = x.permute(0,1,3,2,4).contiguous()
743
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
744
+ return x
745
+
746
+ # Repeat the waveform to an image-like size, if you want to use the pretrained Swin Transformer model
747
+ def repeat_wat2img(self, x, cur_pos):
748
+ B, C, T, F = x.shape
749
+ target_T = int(self.spec_size * self.freq_ratio)
750
+ target_F = self.spec_size // self.freq_ratio
751
+ assert T <= target_T and F <= target_F, "the wav size should be less than or equal to the swin input size"
752
+ # to avoid bicubic zero error
753
+ if T < target_T:
754
+ x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True)
755
+ if F < target_F:
756
+ x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
757
+ x = x.permute(0,1,3,2).contiguous() # B C F T
758
+ x = x[:,:,:,cur_pos:cur_pos + self.spec_size]
759
+ x = x.repeat(repeats = (1,1,4,1))
760
+ return x
761
+
762
+ def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):
763
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
764
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
765
+
766
+
767
+ x = x.transpose(1, 3)
768
+ x = self.bn0(x)
769
+ x = x.transpose(1, 3)
770
+ if self.training:
771
+ x = self.spec_augmenter(x)
772
+ if self.training and mixup_lambda is not None:
773
+ x = do_mixup(x, mixup_lambda)
774
+
775
+ if infer_mode:
776
+ # in infer mode, we need to handle audio inputs of different lengths
777
+ frame_num = x.shape[2]
778
+ target_T = int(self.spec_size * self.freq_ratio)
779
+ repeat_ratio = math.floor(target_T / frame_num)
780
+ x = x.repeat(repeats=(1,1,repeat_ratio,1))
781
+ x = self.reshape_wav2img(x)
782
+ output_dict = self.forward_features(x)
783
+ elif self.config.enable_repeat_mode:
784
+ if self.training:
785
+ cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)
786
+ x = self.repeat_wat2img(x, cur_pos)
787
+ output_dict = self.forward_features(x)
788
+ else:
789
+ output_dicts = []
790
+ for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):
791
+ tx = x.clone()
792
+ tx = self.repeat_wat2img(tx, cur_pos)
793
+ output_dicts.append(self.forward_features(tx))
794
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
795
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
796
+ for d in output_dicts:
797
+ clipwise_output += d["clipwise_output"]
798
+ framewise_output += d["framewise_output"]
799
+ clipwise_output = clipwise_output / len(output_dicts)
800
+ framewise_output = framewise_output / len(output_dicts)
801
+
802
+ output_dict = {
803
+ 'framewise_output': framewise_output,
804
+ 'clipwise_output': clipwise_output
805
+ }
806
+ else:
807
+ if x.shape[2] > self.freq_ratio * self.spec_size:
808
+ if self.training:
809
+ x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
810
+ x = self.reshape_wav2img(x)
811
+ output_dict = self.forward_features(x)
812
+ else:
813
+ # Change: Hard code here
814
+ overlap_size = (x.shape[2] - 1) // 4
815
+ output_dicts = []
816
+ crop_size = (x.shape[2] - 1) // 2
817
+ for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
818
+ tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
819
+ tx = self.reshape_wav2img(tx)
820
+ output_dicts.append(self.forward_features(tx))
821
+ clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
822
+ framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
823
+ for d in output_dicts:
824
+ clipwise_output += d["clipwise_output"]
825
+ framewise_output += d["framewise_output"]
826
+ clipwise_output = clipwise_output / len(output_dicts)
827
+ framewise_output = framewise_output / len(output_dicts)
828
+ output_dict = {
829
+ 'framewise_output': framewise_output,
830
+ 'clipwise_output': clipwise_output
831
+ }
832
+ else: # this branch is the typical and simplest case
833
+ x = self.reshape_wav2img(x)
834
+ output_dict = self.forward_features(x)
835
+ # x = self.head(x)
836
+ return output_dict
model/layers.py ADDED
@@ -0,0 +1,195 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# Some layers used in the model
+# the code below is based on and adapted from https://github.com/microsoft/Swin-Transformer
+# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from itertools import repeat
+import collections.abc
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+# from PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+    return parse
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+class PatchEmbed(nn.Module):
+    """ 2D Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patch_stride = to_2tuple(patch_stride)
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
+        self.flatten = flatten
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x)
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x)
+        return x
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+    """
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    # Cut & paste from PyTorch official master until it's in a few official releases - RW
+    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                      "The distribution of values may be incorrect.",
+                      stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    # type: (Tensor, float, float, float, float) -> Tensor
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff value
+        b: the maximum cutoff value
+    Examples:
+        >>> w = torch.empty(3, 5)
+        >>> nn.init.trunc_normal_(w)
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == 'fan_in':
+        denom = fan_in
+    elif mode == 'fan_out':
+        denom = fan_out
+    elif mode == 'fan_avg':
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
requirements.txt ADDED
@@ -0,0 +1,18 @@
+h5py==3.6.0
+librosa==0.8.1
+matplotlib==3.5.1
+museval==0.4.0
+numpy==1.22.0
+pandas==1.4.0
+pytorch_lightning==1.5.9
+scikit_learn==1.0.2
+scipy==1.7.3
+soundfile==0.10.3.post1
+tensorboard==2.8.0
+torch==1.10.2
+torchaudio==0.10.2
+torchcontrib==0.0.2
+torchlibrosa==0.0.9
+tqdm==4.62.3
+
+gradio
saved_training/HTSAT_ESC_exp=1_fold=1_acc=0.985.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02478f008fad5c0fa6cff856c729f1f22deb73c8e254c652785371355c27ce0f
+size 339619927
sed_model.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ke Chen
2
+ # knutchen@ucsd.edu
3
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
+ # The Model Training Wrapper
5
+ import numpy as np
6
+ import librosa
7
+ import os
8
+ import bisect
9
+ from numpy.lib.function_base import average
10
+
11
+ from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
12
+
13
+ from utils import get_loss_func, get_mix_lambda, d_prime
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils.checkpoint as cp
18
+ import torch.optim as optim
19
+ from torch.nn.parameter import Parameter
20
+ import torch.distributed as dist
21
+ import pytorch_lightning as pl
22
+ from utils import do_mixup, get_mix_lambda, do_mixup_label
23
+
24
+
+ class SEDWrapper(pl.LightningModule):
+     def __init__(self, sed_model, config, dataset):
+         super().__init__()
+         self.sed_model = sed_model
+         self.config = config
+         self.dataset = dataset
+         self.loss_func = get_loss_func(config.loss_type)
+
+     def evaluate_metric(self, pred, ans):
+         ap = []
+         if self.config.dataset_type == "audioset":
+             mAP = np.mean(average_precision_score(ans, pred, average = None))
+             mAUC = np.mean(roc_auc_score(ans, pred, average = None))
+             dprime = d_prime(mAUC)
+             return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
+         else:
+             acc = accuracy_score(ans, np.argmax(pred, 1))
+             return {"acc": acc}
+     def forward(self, x, mix_lambda = None):
+         output_dict = self.sed_model(x, mix_lambda)
+         return output_dict["clipwise_output"], output_dict["framewise_output"]
+
+     def inference(self, x):
+         self.device_type = next(self.parameters()).device
+         self.eval()
+         x = torch.from_numpy(x).float().to(self.device_type)
+         print(x.shape)
+         output_dict = self.sed_model(x, None, True)
+         for key in output_dict.keys():
+             output_dict[key] = output_dict[key].detach().cpu().numpy()
+         return output_dict
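d_prime, used in evaluate_metric above, is imported from utils, which is outside this diff. In audio-tagging work it is usually the detection-theory d' derived from AUC via the inverse normal CDF, so a hedged stand-in (d_prime_sketch is a hypothetical name) would be:

import numpy as np
from scipy import stats

def d_prime_sketch(auc):
    # assumed definition: d' = sqrt(2) * Phi^{-1}(AUC)
    return stats.norm.ppf(auc) * np.sqrt(2)

print(round(d_prime_sketch(0.95), 3))   # ~2.326 for an AUC of 0.95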
+
+     def training_step(self, batch, batch_idx):
+         self.device_type = next(self.parameters()).device
+         mix_lambda = torch.from_numpy(get_mix_lambda(0.5, len(batch["waveform"]))).to(self.device_type)
+         # Another choice: also mix up the target, but AudioSet is not perfectly labeled,
+         # so "adding noise" might be better than purely "mixing"
+         # batch["target"] = do_mixup_label(batch["target"])
+         # batch["target"] = do_mixup(batch["target"], mix_lambda)
+         pred, _ = self(batch["waveform"], mix_lambda)
+         loss = self.loss_func(pred, batch["target"])
+         self.log("loss", loss, on_epoch= True, prog_bar=True)
+         return loss
+     def training_epoch_end(self, outputs):
+         # Change: SWA, deprecated
+         # for opt in self.trainer.optimizers:
+         #     if not type(opt) is SWA:
+         #         continue
+         #     opt.swap_swa_sgd()
+         self.dataset.generate_queue()
+
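get_mix_lambda and do_mixup also come from utils and are not shown in this commit; a hedged sketch of a per-example mixup-lambda generator matching the call in training_step above (reading the 0.5 argument as the Beta concentration, which is an assumption):

import numpy as np

def get_mix_lambda_sketch(mixup_alpha, batch_size):
    # one Beta(alpha, alpha) mixing coefficient per example in the batch
    return np.random.beta(mixup_alpha, mixup_alpha, batch_size).astype(np.float32)

lam = get_mix_lambda_sketch(0.5, 4)
print(lam.shape, bool((lam >= 0).all() and (lam <= 1).all()))   # (4,) True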
+
+     def validation_step(self, batch, batch_idx):
+         pred, _ = self(batch["waveform"])
+         return [pred.detach(), batch["target"].detach()]
+
+     def validation_epoch_end(self, validation_step_outputs):
+         self.device_type = next(self.parameters()).device
+         pred = torch.cat([d[0] for d in validation_step_outputs], dim = 0)
+         target = torch.cat([d[1] for d in validation_step_outputs], dim = 0)
+         gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
+         gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
+         dist.barrier()
+         if self.config.dataset_type == "audioset":
+             metric_dict = {
+                 "mAP": 0.,
+                 "mAUC": 0.,
+                 "dprime": 0.
+             }
+         else:
+             metric_dict = {
+                 "acc":0.
+             }
+         dist.all_gather(gather_pred, pred)
+         dist.all_gather(gather_target, target)
+         if dist.get_rank() == 0:
+             gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
+             gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
+             if self.config.dataset_type == "scv2":
+                 gather_target = np.argmax(gather_target, 1)
+             metric_dict = self.evaluate_metric(gather_pred, gather_target)
+             print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
+
+         if self.config.dataset_type == "audioset":
+             self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+         else:
+             self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+         dist.barrier()
+
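Only rank 0 computes the real metrics; the other ranks keep the zero-initialized metric_dict. Multiplying by world_size before logging then cancels the cross-rank averaging that sync_dist=True applies (assuming Lightning's default mean reduction), as this toy arithmetic shows:

world_size = 4
acc_rank0 = 0.985                                               # metric computed from the gathered predictions
logged = [acc_rank0 * world_size] + [0.0] * (world_size - 1)    # what each rank passes to self.log
print(sum(logged) / world_size)                                 # 0.985 -- the mean across ranks restores the value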
+     def time_shifting(self, x, shift_len):
+         shift_len = int(shift_len)
+         new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
+         return new_sample
+
+     def test_step(self, batch, batch_idx):
+         # print(batch['waveform'].shape)  # debug leftover
+         # exit()                          # debug leftover; would abort the test run
+         self.device_type = next(self.parameters()).device
+         preds = []
+         # time shifting optimization
+         if self.config.fl_local or self.config.dataset_type != "audioset":
+             shift_num = 1 # framewise localization cannot allow the time shifting
+         else:
+             shift_num = 10
+         for i in range(shift_num):
+             pred, pred_map = self(batch["waveform"])
+             preds.append(pred.unsqueeze(0))
+             batch["waveform"] = self.time_shifting(batch["waveform"], shift_len = 100 * (i + 1))
+         preds = torch.cat(preds, dim=0)
+         pred = preds.mean(dim = 0)
+         if self.config.fl_local:
+             return [
+                 pred.detach().cpu().numpy(),
+                 pred_map.detach().cpu().numpy(),
+                 batch["audio_name"],
+                 batch["real_len"].cpu().numpy()
+             ]
+         else:
+             return [pred.detach(), batch["target"].detach()]
+
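time_shifting above rolls each waveform circularly along the time axis; the concatenation is equivalent to torch.roll with a negative shift, which makes a quick sanity check possible:

import torch

x = torch.arange(12.).reshape(2, 6)                 # 2 toy waveforms, 6 samples each
shift_len = 2
by_cat = torch.cat([x[:, shift_len:], x[:, :shift_len]], dim=1)
by_roll = torch.roll(x, shifts=-shift_len, dims=1)
print(torch.equal(by_cat, by_roll))                 # True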
+     def test_epoch_end(self, test_step_outputs):
+         self.device_type = next(self.parameters()).device
+         if self.config.fl_local:
+             pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
+             pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
+             audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
+             real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
+             heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
+             save_npy = [
+                 {
+                     "audio_name": audio_name[i],
+                     "heatmap": pred_map[i],
+                     "pred": pred[i],
+                     "real_len":real_len[i]
+                 }
+                 for i in range(len(pred))
+             ]
+             np.save(heatmap_file, save_npy)
+         else:
+             self.device_type = next(self.parameters()).device
+             pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
+             target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
+             gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
+             gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
+             dist.barrier()
+             if self.config.dataset_type == "audioset":
+                 metric_dict = {
+                     "mAP": 0.,
+                     "mAUC": 0.,
+                     "dprime": 0.
+                 }
+             else:
+                 metric_dict = {
+                     "acc":0.
+                 }
+             dist.all_gather(gather_pred, pred)
+             dist.all_gather(gather_target, target)
+             if dist.get_rank() == 0:
+                 gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
+                 gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
+                 if self.config.dataset_type == "scv2":
+                     gather_target = np.argmax(gather_target, 1)
+                 metric_dict = self.evaluate_metric(gather_pred, gather_target)
+                 print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
+             if self.config.dataset_type == "audioset":
+                 self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+                 self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+                 self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             else:
+                 self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             dist.barrier()
+
+
+     def configure_optimizers(self):
+         optimizer = optim.AdamW(
+             filter(lambda p: p.requires_grad, self.parameters()),
+             lr = self.config.learning_rate,
+             betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
+         )
+         # Change: SWA, deprecated
+         # optimizer = SWA(optimizer, swa_start=10, swa_freq=5)
+         def lr_foo(epoch):
+             if epoch < 3:
+                 # warm up lr
+                 lr_scale = self.config.lr_rate[epoch]
+             else:
+                 # warmup schedule
+                 lr_pos = int(-1 - bisect.bisect_left(self.config.lr_scheduler_epoch, epoch))
+                 if lr_pos < -3:
+                     lr_scale = max(self.config.lr_rate[0] * (0.98 ** epoch), 0.03 )
+                 else:
+                     lr_scale = self.config.lr_rate[lr_pos]
+             return lr_scale
+         scheduler = optim.lr_scheduler.LambdaLR(
+             optimizer,
+             lr_lambda=lr_foo
+         )
+
+         return [optimizer], [scheduler]
+
+
+
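The LambdaLR above multiplies the base learning rate by lr_foo(epoch): a three-epoch warm-up through config.lr_rate, step-downs at config.lr_scheduler_epoch, and a floored exponential decay afterwards. A standalone sketch with assumed example values (the real ones live in config.py, outside this section):

import bisect

lr_rate = [0.02, 0.05, 0.1]          # assumed warm-up / plateau scales
lr_scheduler_epoch = [10, 20, 30]    # assumed step-down epochs

def lr_foo(epoch):
    if epoch < 3:
        return lr_rate[epoch]
    lr_pos = int(-1 - bisect.bisect_left(lr_scheduler_epoch, epoch))
    if lr_pos < -3:
        return max(lr_rate[0] * (0.98 ** epoch), 0.03)
    return lr_rate[lr_pos]

print([round(lr_foo(e), 3) for e in (0, 1, 2, 5, 15, 25, 40)])
# [0.02, 0.05, 0.1, 0.1, 0.05, 0.02, 0.03]  -- scales applied on top of config.learning_rate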
+ class Ensemble_SEDWrapper(pl.LightningModule):
+     def __init__(self, sed_models, config, dataset):
+         super().__init__()
+
+         self.sed_models = nn.ModuleList(sed_models)
+         self.config = config
+         self.dataset = dataset
+
+     def evaluate_metric(self, pred, ans):
+         if self.config.dataset_type == "audioset":
+             mAP = np.mean(average_precision_score(ans, pred, average = None))
+             mAUC = np.mean(roc_auc_score(ans, pred, average = None))
+             dprime = d_prime(mAUC)
+             return {"mAP": mAP, "mAUC": mAUC, "dprime": dprime}
+         else:
+             acc = accuracy_score(ans, np.argmax(pred, 1))
+             return {"acc": acc}
+
+     def forward(self, x, sed_index, mix_lambda = None):
+         self.sed_models[sed_index].eval()
+         preds = []
+         pred_maps = []
+         # time shifting optimization
+         if self.config.fl_local or self.config.dataset_type != "audioset":
+             shift_num = 1 # framewise localization cannot allow the time shifting
+         else:
+             shift_num = 10
+         for i in range(shift_num):
+             pred, pred_map = self.sed_models[sed_index](x)
+             pred_maps.append(pred_map.unsqueeze(0))
+             preds.append(pred.unsqueeze(0))
+             x = self.time_shifting(x, shift_len = 100 * (i + 1))
+         preds = torch.cat(preds, dim=0)
+         pred_maps = torch.cat(pred_maps, dim = 0)
+         pred = preds.mean(dim = 0)
+         pred_map = pred_maps.mean(dim = 0)
+         return pred, pred_map
+
+
+     def time_shifting(self, x, shift_len):
+         shift_len = int(shift_len)
+         new_sample = torch.cat([x[:, shift_len:], x[:, :shift_len]], axis = 1)
+         return new_sample
+
+     def test_step(self, batch, batch_idx):
+         self.device_type = next(self.parameters()).device
+         if self.config.fl_local:
+             pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
+             pred_map = torch.zeros(len(batch["waveform"]), 1024, self.config.classes_num).float().to(self.device_type)
+             for j in range(len(self.sed_models)):
+                 temp_pred, temp_pred_map = self(batch["waveform"], j)
+                 pred = pred + temp_pred
+                 pred_map = pred_map + temp_pred_map
+             pred = pred / len(self.sed_models)
+             pred_map = pred_map / len(self.sed_models)
+             return [
+                 pred.detach().cpu().numpy(),
+                 pred_map.detach().cpu().numpy(),
+                 batch["audio_name"],
+                 batch["real_len"].cpu().numpy()
+             ]
+         else:
+             pred = torch.zeros(len(batch["waveform"]), self.config.classes_num).float().to(self.device_type)
+             for j in range(len(self.sed_models)):
+                 temp_pred, _ = self(batch["waveform"], j)
+                 pred = pred + temp_pred
+             pred = pred / len(self.sed_models)
+             return [
+                 pred.detach(),
+                 batch["target"].detach(),
+             ]
+
+     def test_epoch_end(self, test_step_outputs):
+         self.device_type = next(self.parameters()).device
+         if self.config.fl_local:
+             pred = np.concatenate([d[0] for d in test_step_outputs], axis = 0)
+             pred_map = np.concatenate([d[1] for d in test_step_outputs], axis = 0)
+             audio_name = np.concatenate([d[2] for d in test_step_outputs], axis = 0)
+             real_len = np.concatenate([d[3] for d in test_step_outputs], axis = 0)
+             heatmap_file = os.path.join(self.config.heatmap_dir, self.config.test_file + "_" + str(self.device_type) + ".npy")
+             print(pred.shape)
+             print(pred_map.shape)
+             print(real_len.shape)
+             save_npy = [
+                 {
+                     "audio_name": audio_name[i],
+                     "heatmap": pred_map[i],
+                     "pred": pred[i],
+                     "real_len":real_len[i]
+                 }
+                 for i in range(len(pred))
+             ]
+             np.save(heatmap_file, save_npy)
+         else:
+             pred = torch.cat([d[0] for d in test_step_outputs], dim = 0)
+             target = torch.cat([d[1] for d in test_step_outputs], dim = 0)
+             gather_pred = [torch.zeros_like(pred) for _ in range(dist.get_world_size())]
+             gather_target = [torch.zeros_like(target) for _ in range(dist.get_world_size())]
+
+             dist.barrier()
+             if self.config.dataset_type == "audioset":
+                 metric_dict = {
+                     "mAP": 0.,
+                     "mAUC": 0.,
+                     "dprime": 0.
+                 }
+             else:
+                 metric_dict = {
+                     "acc":0.
+                 }
+             dist.all_gather(gather_pred, pred)
+             dist.all_gather(gather_target, target)
+             if dist.get_rank() == 0:
+                 gather_pred = torch.cat(gather_pred, dim = 0).cpu().numpy()
+                 gather_target = torch.cat(gather_target, dim = 0).cpu().numpy()
+                 if self.config.dataset_type == "scv2":
+                     gather_target = np.argmax(gather_target, 1)
+                 metric_dict = self.evaluate_metric(gather_pred, gather_target)
+                 print(self.device_type, dist.get_world_size(), metric_dict, flush = True)
+             if self.config.dataset_type == "audioset":
+                 self.log("mAP", metric_dict["mAP"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+                 self.log("mAUC", metric_dict["mAUC"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+                 self.log("dprime", metric_dict["dprime"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             else:
+                 self.log("acc", metric_dict["acc"] * float(dist.get_world_size()), on_epoch = True, prog_bar=True, sync_dist=True)
+             dist.barrier()
+
+
+
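A minimal sketch (not part of the commit) of how Ensemble_SEDWrapper would be driven; model_a and model_b stand for already-loaded SEDWrapper instances, and eval_loader for a test dataloader. Because the epoch-end hooks call torch.distributed, it has to run under a distributed strategy:

import pytorch_lightning as pl

# model_a, model_b, config and eval_loader are placeholders for objects built elsewhere.
ensemble = Ensemble_SEDWrapper(sed_models=[model_a, model_b], config=config, dataset=None)
trainer = pl.Trainer(gpus=1, strategy="ddp", num_sanity_val_steps=0)
trainer.test(ensemble, dataloaders=eval_loader)   # predictions are averaged over models and time shifts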