artelabsuper committed on
Commit
0d4ce65
1 Parent(s): 4e2283a
Files changed (1)
  1. utils.py +226 -0
utils.py ADDED
@@ -0,0 +1,226 @@
+ # Ke Chen
+ # knutchen@ucsd.edu
+ # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+ # Some Useful Common Methods
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import logging
+ import os
+ import h5py
+ import json
+ import librosa
+ from datetime import datetime
+ from tqdm import tqdm
+ from scipy import stats
+
+
+ # imported from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py
+ class AsymmetricLoss(nn.Module):
+     def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
+         super(AsymmetricLoss, self).__init__()
+
+         self.gamma_neg = gamma_neg
+         self.gamma_pos = gamma_pos
+         self.clip = clip
+         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+         self.eps = eps
+
+     def forward(self, x, y):
+         """
+         Parameters
+         ----------
+         x: input probabilities (sigmoid already applied upstream)
+         y: targets (multi-label binarized vector)
+         """
+
+         # Calculating probabilities
+         # x_sigmoid = torch.sigmoid(x)
+         x_sigmoid = x  # no sigmoid here, since it has already been applied
+         xs_pos = x_sigmoid
+         xs_neg = 1 - x_sigmoid
+
+         # Asymmetric clipping
+         if self.clip is not None and self.clip > 0:
+             xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+         # Basic CE calculation
+         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+         loss = los_pos + los_neg
+
+         # Asymmetric focusing
+         if self.gamma_neg > 0 or self.gamma_pos > 0:
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(False)
+             pt0 = xs_pos * y
+             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1 - p
+             pt = pt0 + pt1
+             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+             if self.disable_torch_grad_focal_loss:
+                 torch.set_grad_enabled(True)
+             loss *= one_sided_w
+
+         return -loss.mean()
+
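+ # Usage sketch (illustrative only): the forward pass expects post-sigmoid
+ # probabilities rather than raw logits, so apply torch.sigmoid upstream.
+ #   criterion = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05)
+ #   probs = torch.sigmoid(logits)                # (batch_size, classes_num)
+ #   loss = criterion(probs, multi_hot_targets)   # targets in {0, 1}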
+
+ def get_mix_lambda(mixup_alpha, batch_size):
+     mixup_lambdas = [np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)]
+     return np.array(mixup_lambdas).astype(np.float32)
+
+ def create_folder(fd):
+     if not os.path.exists(fd):
+         os.makedirs(fd)
+
+ def dump_config(config, filename, include_time=False):
+     save_time = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+     config_json = {}
+     for key in dir(config):
+         if not key.startswith("_"):
+             config_json[key] = getattr(config, key)  # safer than eval("config." + key)
+     if include_time:
+         filename = filename + "_" + save_time
+     with open(filename + ".json", "w") as f:
+         json.dump(config_json, f, indent=4)
+
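+ # Usage sketch (illustrative; assumes every public attribute of `config` is
+ # JSON-serializable):
+ #   import config
+ #   dump_config(config, "checkpoints/config_backup", include_time=True)
+ #   # -> checkpoints/config_backup_<timestamp>.json
+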
+ def int16_to_float32(x):
+     return (x / 32767.).astype(np.float32)
+
+ def float32_to_int16(x):
+     x = np.clip(x, a_min=-1., a_max=1.)
+     return (x * 32767.).astype(np.int16)
+
+
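+ # Roundtrip sketch: float32 waveforms in [-1, 1] map to int16 and back with at
+ # most ~1/32767 quantization error.
+ #   wav = np.random.uniform(-1, 1, 32000).astype(np.float32)
+ #   assert np.allclose(wav, int16_to_float32(float32_to_int16(wav)), atol=1e-4)
+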
+ # index for each class
+ def process_idc(index_path, classes_num, filename):
+     # load data
+     logging.info("Load Data...............")
+     idc = [[] for _ in range(classes_num)]
+     with h5py.File(index_path, "r") as f:
+         for i in tqdm(range(len(f["target"]))):
+             t_class = np.where(f["target"][i])[0]
+             for t in t_class:
+                 idc[t].append(i)
+     np.save(filename, np.array(idc, dtype=object))  # ragged list -> object array
+     logging.info("Load Data Succeed...............")
+
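+ # Usage sketch (path and class count are hypothetical): builds a per-class
+ # sample index from an HDF5 file whose "target" dataset holds multi-hot labels.
+ #   process_idc("workspace/indexes/train.h5", 527, "idc_train.npy")
+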
+ def clip_bce(pred, target):
+     """Binary cross-entropy loss.
+     """
+     return F.binary_cross_entropy(pred, target)
+
+
+ def clip_ce(pred, target):
+     return F.cross_entropy(pred, target)
+
+ def d_prime(auc):
+     d_prime = stats.norm().ppf(auc) * np.sqrt(2.0)
+     return d_prime
+
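+ # Example: d_prime converts an AUC score into a detectability index, e.g.
+ # d_prime(0.9) ~= 1.812, since stats.norm().ppf(0.9) ~= 1.2816 and sqrt(2) ~= 1.4142.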
+
+ def get_loss_func(loss_type):
+     if loss_type == 'clip_bce':
+         return clip_bce
+     if loss_type == 'clip_ce':
+         return clip_ce
+     if loss_type == 'asl_loss':
+         return AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05)
+     raise ValueError("unknown loss_type: " + loss_type)
+
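+ # Usage sketch: resolve the loss from a config string.
+ #   loss_func = get_loss_func("clip_bce")
+ #   loss = loss_func(pred, target)   # pred, target: (batch_size, classes_num)
+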
+ def do_mixup_label(x):
+     out = torch.logical_or(x, torch.flip(x, dims=[0])).float()
+     return out
+
+ def do_mixup(x, mixup_lambda):
+     """
+     Args:
+         x: (batch_size, ...)
+         mixup_lambda: (batch_size,)
+
+     Returns:
+         out: (batch_size, ...)
+     """
+     out = (x.transpose(0, -1) * mixup_lambda + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)).transpose(0, -1)
+     return out
+
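+ # Mixup sketch (illustrative): each sample is mixed with its batch-reversed
+ # partner; labels are combined with a logical OR instead of being weighted.
+ #   lam = torch.from_numpy(get_mix_lambda(0.5, x.shape[0]))  # (batch_size,)
+ #   x_mix = do_mixup(x, lam)
+ #   y_mix = do_mixup_label(y)
+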
+ def interpolate(x, ratio):
+     """Interpolate data in the time domain. This is used to compensate for the
+     resolution reduction caused by downsampling in a CNN.
+
+     Args:
+         x: (batch_size, time_steps, classes_num)
+         ratio: int, upsampling ratio
+
+     Returns:
+         upsampled: (batch_size, time_steps * ratio, classes_num)
+     """
+     (batch_size, time_steps, classes_num) = x.shape
+     upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
+     upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
+     return upsampled
+
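+ # Example: with x of shape (4, 100, 527) and ratio=8, interpolate returns a
+ # (4, 800, 527) tensor by repeating each frame 8 times (nearest-neighbor upsampling).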
+
+ def pad_framewise_output(framewise_output, frames_num):
+     """Pad framewise_output to the same length as the input frames. The pad value
+     is the value of the last frame.
+
+     Args:
+         framewise_output: (batch_size, frames_num, classes_num)
+         frames_num: int, target number of frames after padding
+
+     Outputs:
+         output: (batch_size, frames_num, classes_num)
+     """
+     # repeat the last frame to fill the remaining frames
+     pad = framewise_output[:, -1:, :].repeat(1, frames_num - framewise_output.shape[1], 1)
+
+     # (batch_size, frames_num, classes_num)
+     output = torch.cat((framewise_output, pad), dim=1)
+
+     return output
+
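+ # Typical pairing (sketch): upsample the CNN output, then pad the tail so the
+ # framewise prediction matches the original frame count.
+ #   framewise = interpolate(x, ratio=8)                       # (B, T * 8, C)
+ #   framewise = pad_framewise_output(framewise, frames_num)   # (B, frames_num, C)
+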
+ # convert the audio into the format that can be fed into the model:
+ # resample -> downmix to mono -> output the audio
+ # track: [n_sample, n_channel]
+ def prepprocess_audio(track, ofs, rfs, mono_type="mix"):
+     if track.shape[-1] > 1:
+         # stereo
+         if mono_type == "mix":
+             track = np.transpose(track, (1, 0))
+             track = librosa.to_mono(track)
+         elif mono_type == "left":
+             track = track[:, 0]
+         elif mono_type == "right":
+             track = track[:, 1]
+     else:
+         track = track[:, 0]
+     # track: [n_sample]
+     if ofs != rfs:
+         track = librosa.resample(track, orig_sr=ofs, target_sr=rfs)
+     return track
+
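+ # Usage sketch (rates illustrative): downmix a stereo 44.1 kHz track and
+ # resample it to the model's rate.
+ #   mono = prepprocess_audio(track, ofs=44100, rfs=32000, mono_type="mix")
+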
+ def init_hier_head(class_map, num_class):
+     class_map = np.load(class_map, allow_pickle=True)
+
+     head_weight = torch.zeros(num_class, num_class).float()
+     head_bias = torch.zeros(num_class).float()
+
+     for i in range(len(class_map)):
+         for d in class_map[i][1]:
+             head_weight[d][i] = 1.0
+         for d in class_map[i][2]:
+             head_weight[d][i] = 1.0 / len(class_map[i][2])
+         head_weight[i][i] = 1.0
+     return head_weight, head_bias
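+
+ # Usage sketch (file name hypothetical; class_map semantics inferred from the
+ # loops above): entry i holds the class plus two index lists from the label
+ # hierarchy, and the result can initialize a hierarchy-aware linear head.
+ #   head_weight, head_bias = init_hier_head("class_hier_map.npy", num_class=527)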