Feature Extraction
PyTorch
Bioacoustics
ilyassmoummad committed aab5975 (1 parent: 2cc9f37)
README.md ADDED
# ProtoCLR

This repository contains a CvT-13 ([Convolutional Vision Transformer](https://arxiv.org/abs/2103.15808)) model trained from scratch on the [Xeno-Canto dataset](https://huggingface.co/datasets/ilyassmoummad/Xeno-Canto-6s-16khz), specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using ProtoCLR ([Prototypical Contrastive Loss](https://arxiv.org/abs/2409.08589)) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.

## Files

- `cvt.py`: Defines the CvT-13 model architecture.
- `protoclr.pth`: Pre-trained ProtoCLR model weights (300 epochs).
- `config/`: Configuration files for the CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms, the format expected by the model.

## Setup

1. **Install dependencies**:
   Ensure you have the required Python packages, including `torch` and the other dependencies listed in `requirements.txt`:

   ```bash
   pip install -r requirements.txt
   ```

2. **Prepare the audio**:
   - **Sample rate**: Ensure your audio is sampled at 16 kHz.
   - **Padding**: For audio shorter than 6 seconds, pad with zeros or repeat the audio to reach 6 seconds.
   - **Chunking**: For audio longer than 6 seconds, split it into 6-second chunks (see the sketch below).

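A minimal sketch of this preparation step, assuming `torchaudio` (listed in `requirements.txt`) is used for loading and resampling; the helper name and file path are illustrative and not part of this repository:

```python
import torch
import torch.nn.functional as F
import torchaudio

TARGET_SR = 16000
SEGMENT_SAMPLES = 6 * TARGET_SR  # 6-second segments at 16 kHz


def prepare_segments(file_path):
    # Load the recording and resample to 16 kHz if necessary
    waveform, sr = torchaudio.load(file_path)      # (channels, samples)
    waveform = waveform.mean(dim=0)                 # mix down to mono: (samples,)
    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)

    # Pad short audio with zeros up to 6 seconds
    if waveform.numel() < SEGMENT_SAMPLES:
        waveform = F.pad(waveform, (0, SEGMENT_SAMPLES - waveform.numel()))

    # Split long audio into non-overlapping 6-second chunks,
    # zero-padding the last chunk if it is incomplete
    chunks = list(torch.split(waveform, SEGMENT_SAMPLES))
    chunks[-1] = F.pad(chunks[-1], (0, SEGMENT_SAMPLES - chunks[-1].numel()))
    return torch.stack(chunks)                      # (num_chunks, 96000)
```

Each row of the returned tensor is a 6-second, 16 kHz mono segment that can be fed to the `MelSpectrogramProcessor` as in the example below.
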
## Usage

To use the model, convert your audio into Mel spectrograms with the `MelSpectrogramProcessor`, then pass the resulting spectrograms to the CvT-13 feature extractor.

## Example Code

The following example demonstrates loading, preprocessing, and running inference on an audio file:

```python
import torch
import torchaudio
from cvt import cvt13  # Model architecture
from melspectrogram import MelSpectrogramProcessor  # Mel spectrogram preprocessing

# Initialize the preprocessor and the pre-trained model
preprocessor = MelSpectrogramProcessor()
model = cvt13()
model.load_state_dict(torch.load("protoclr.pth", map_location="cpu"))
model.eval()

# Load and preprocess a sample audio waveform
def load_waveform(file_path):
    # Replace this with your specific audio loading function;
    # here, torchaudio is used to load and resample to 16 kHz
    waveform, sr = torchaudio.load(file_path)
    waveform = waveform.mean(dim=0, keepdim=True)  # mix down to mono: (1, samples)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    return waveform

waveform = load_waveform("path/to/audio.wav")  # Load your audio file here

# Ensure the waveform is sampled at 16 kHz, then pad/chunk as needed to reach a 6 s length
input_tensor = preprocessor.process(waveform).unsqueeze(0)  # Add batch dimension: (1, 1, 128, T)

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)
print("Model output shape:", output.shape)
```

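Each 6-second segment yields a single embedding vector (384-dimensional, matching the last `DIM_EMBED` entry in `config/cvt-13-224x224.yaml`). For longer recordings, here is a hedged sketch, reusing the illustrative `prepare_segments` helper from the Setup section, of batching segments and pooling their embeddings into one clip-level feature:

```python
import torch

# Assumes `prepare_segments`, `preprocessor`, and `model` from the snippets above
chunks = prepare_segments("path/to/long_recording.wav")            # (num_chunks, 96000)
specs = torch.stack([preprocessor.process(c) for c in chunks])     # (num_chunks, 128, 301)
specs = specs.unsqueeze(1)                                         # add channel dim: (num_chunks, 1, 128, 301)

with torch.no_grad():
    embeddings = model(specs)                                      # (num_chunks, 384) for num_chunks > 1

clip_embedding = embeddings.mean(dim=0)                            # simple average pooling over segments
```
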
## Citation

If you use our model in your research, please cite the following paper:

```bibtex
@misc{moummad2024dirlbs,
      title={Domain-Invariant Representation Learning of Bird Sounds},
      author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
      year={2024},
      eprint={2409.08589},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2409.08589},
}
```
config/__init__.py ADDED
from .default import _C as config
from .default import update_config
from .default import _update_config_from_file
from .default import save_config
config/comm.py ADDED
1
+ import pickle
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
+ class Comm(object):
8
+ def __init__(self, local_rank=0):
9
+ self.local_rank = 0
10
+
11
+ @property
12
+ def world_size(self):
13
+ if not dist.is_available():
14
+ return 1
15
+ if not dist.is_initialized():
16
+ return 1
17
+ return dist.get_world_size()
18
+
19
+ @property
20
+ def rank(self):
21
+ if not dist.is_available():
22
+ return 0
23
+ if not dist.is_initialized():
24
+ return 0
25
+ return dist.get_rank()
26
+
27
+ @property
28
+ def local_rank(self):
29
+ if not dist.is_available():
30
+ return 0
31
+ if not dist.is_initialized():
32
+ return 0
33
+ return self._local_rank
34
+
35
+ @local_rank.setter
36
+ def local_rank(self, value):
37
+ if not dist.is_available():
38
+ self._local_rank = 0
39
+ if not dist.is_initialized():
40
+ self._local_rank = 0
41
+ self._local_rank = value
42
+
43
+ @property
44
+ def head(self):
45
+ return 'Rank[{}/{}]'.format(self.rank, self.world_size)
46
+
47
+ def is_main_process(self):
48
+ return self.rank == 0
49
+
50
+ def synchronize(self):
51
+ """
52
+ Helper function to synchronize (barrier) among all processes when
53
+ using distributed training
54
+ """
55
+ if self.world_size == 1:
56
+ return
57
+ dist.barrier()
58
+
59
+
60
+ comm = Comm()
61
+
62
+
63
+ def all_gather(data):
64
+ """
65
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
66
+ Args:
67
+ data: any picklable object
68
+ Returns:
69
+ list[data]: list of data gathered from each rank
70
+ """
71
+ world_size = comm.world_size
72
+ if world_size == 1:
73
+ return [data]
74
+
75
+ # serialized to a Tensor
76
+ buffer = pickle.dumps(data)
77
+ storage = torch.ByteStorage.from_buffer(buffer)
78
+ tensor = torch.ByteTensor(storage).to("cuda")
79
+
80
+ # obtain Tensor size of each rank
81
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
82
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
83
+ dist.all_gather(size_list, local_size)
84
+ size_list = [int(size.item()) for size in size_list]
85
+ max_size = max(size_list)
86
+
87
+ # receiving Tensor from all ranks
88
+ # we pad the tensor because torch all_gather does not support
89
+ # gathering tensors of different shapes
90
+ tensor_list = []
91
+ for _ in size_list:
92
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
93
+ if local_size != max_size:
94
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
95
+ tensor = torch.cat((tensor, padding), dim=0)
96
+ dist.all_gather(tensor_list, tensor)
97
+
98
+ data_list = []
99
+ for size, tensor in zip(size_list, tensor_list):
100
+ buffer = tensor.cpu().numpy().tobytes()[:size]
101
+ data_list.append(pickle.loads(buffer))
102
+
103
+ return data_list
104
+
105
+
106
+ def reduce_dict(input_dict, average=True):
107
+ """
108
+ Args:
109
+ input_dict (dict): all the values will be reduced
110
+ average (bool): whether to do average or sum
111
+ Reduce the values in the dictionary from all processes so that process with rank
112
+ 0 has the averaged results. Returns a dict with the same fields as
113
+ input_dict, after reduction.
114
+ """
115
+ world_size = comm.world_size
116
+ if world_size < 2:
117
+ return input_dict
118
+ with torch.no_grad():
119
+ names = []
120
+ values = []
121
+ # sort the keys so that they are consistent across processes
122
+ for k in sorted(input_dict.keys()):
123
+ names.append(k)
124
+ values.append(input_dict[k])
125
+ values = torch.stack(values, dim=0)
126
+ dist.reduce(values, dst=0)
127
+ if dist.get_rank() == 0 and average:
128
+ # only main process gets accumulated, so only divide by
129
+ # world_size in this case
130
+ values /= world_size
131
+ reduced_dict = {k: v for k, v in zip(names, values)}
132
+ return reduced_dict
config/cvt-13-224x224.yaml ADDED
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
  ENABLED: true

MODEL:
  NAME: cls_cvt
  SPEC:
    INIT: 'trunc_norm'
    NUM_STAGES: 3
    PATCH_SIZE: [7, 3, 3]
    PATCH_STRIDE: [4, 2, 2]
    PATCH_PADDING: [2, 1, 1]
    DIM_EMBED: [64, 192, 384]
    NUM_HEADS: [1, 3, 6]
    DEPTH: [1, 2, 10]
    MLP_RATIO: [4.0, 4.0, 4.0]
    ATTN_DROP_RATE: [0.0, 0.0, 0.0]
    DROP_RATE: [0.0, 0.0, 0.0]
    DROP_PATH_RATE: [0.0, 0.0, 0.1]
    QKV_BIAS: [True, True, True]
    CLS_TOKEN: [False, False, True]
    POS_EMBED: [False, False, False]
    QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
    KERNEL_QKV: [3, 3, 3]
    PADDING_KV: [1, 1, 1]
    STRIDE_KV: [2, 2, 2]
    PADDING_Q: [1, 1, 1]
    STRIDE_Q: [1, 1, 1]
AUG:
  MIXUP_PROB: 1.0
  MIXUP: 0.8
  MIXCUT: 1.0
  TIMM_AUG:
    USE_LOADER: true
    RE_COUNT: 1
    RE_MODE: pixel
    RE_SPLIT: false
    RE_PROB: 0.25
    AUTO_AUGMENT: rand-m9-mstd0.5-inc1
    HFLIP: 0.5
    VFLIP: 0.0
    COLOR_JITTER: 0.4
    INTERPOLATION: bicubic
LOSS:
  LABEL_SMOOTHING: 0.1
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATASET:
  DATASET: 'imagenet'
  DATA_FORMAT: 'jpg'
  ROOT: 'DATASET/imagenet/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'
TEST:
  BATCH_SIZE_PER_GPU: 32
  IMAGE_SIZE: [224, 224]
  MODEL_FILE: ''
  INTERPOLATION: 3
TRAIN:
  BATCH_SIZE_PER_GPU: 256
  LR: 0.00025
  IMAGE_SIZE: [224, 224]
  BEGIN_EPOCH: 0
  END_EPOCH: 300
  LR_SCHEDULER:
    METHOD: 'timm'
    ARGS:
      sched: 'cosine'
      warmup_epochs: 5
      warmup_lr: 0.000001
      min_lr: 0.00001
      cooldown_epochs: 10
      decay_rate: 0.1
  OPTIMIZER: adamW
  WD: 0.05
  WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
  SHUFFLE: true
DEBUG:
  DEBUG: false
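Only the `MODEL.SPEC` section above is consumed by `cvt13()` in `cvt.py`; a minimal sketch of reading it directly (assuming the repository root as the working directory):

```python
import yaml

with open("config/cvt-13-224x224.yaml", "r") as f:
    cfg = yaml.safe_load(f)

spec = cfg["MODEL"]["SPEC"]
print(spec["NUM_STAGES"])       # 3 stages
print(spec["DIM_EMBED"][-1])    # 384: dimensionality of the extracted features
```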
config/default.py ADDED
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import os.path as op
6
+ import yaml
7
+ from yacs.config import CfgNode as CN
8
+
9
+ from config import comm
10
+
11
+
12
+ _C = CN()
13
+
14
+ _C.BASE = ['']
15
+ _C.NAME = ''
16
+ _C.DATA_DIR = ''
17
+ _C.DIST_BACKEND = 'nccl'
18
+ _C.GPUS = (0,)
19
+ # _C.LOG_DIR = ''
20
+ _C.MULTIPROCESSING_DISTRIBUTED = True
21
+ _C.OUTPUT_DIR = ''
22
+ _C.PIN_MEMORY = True
23
+ _C.PRINT_FREQ = 20
24
+ _C.RANK = 0
25
+ _C.VERBOSE = True
26
+ _C.WORKERS = 4
27
+ _C.MODEL_SUMMARY = False
28
+
29
+ _C.AMP = CN()
30
+ _C.AMP.ENABLED = False
31
+ _C.AMP.MEMORY_FORMAT = 'nchw'
32
+
33
+ # Cudnn related params
34
+ _C.CUDNN = CN()
35
+ _C.CUDNN.BENCHMARK = True
36
+ _C.CUDNN.DETERMINISTIC = False
37
+ _C.CUDNN.ENABLED = True
38
+
39
+ # common params for NETWORK
40
+ _C.MODEL = CN()
41
+ _C.MODEL.NAME = 'cls_hrnet'
42
+ _C.MODEL.INIT_WEIGHTS = True
43
+ _C.MODEL.PRETRAINED = ''
44
+ _C.MODEL.PRETRAINED_LAYERS = ['*']
45
+ _C.MODEL.NUM_CLASSES = 1000
46
+ _C.MODEL.SPEC = CN(new_allowed=True)
47
+
48
+ _C.LOSS = CN(new_allowed=True)
49
+ _C.LOSS.LABEL_SMOOTHING = 0.0
50
+ _C.LOSS.LOSS = 'softmax'
51
+
52
+ # DATASET related params
53
+ _C.DATASET = CN()
54
+ _C.DATASET.ROOT = ''
55
+ _C.DATASET.DATASET = 'imagenet'
56
+ _C.DATASET.TRAIN_SET = 'train'
57
+ _C.DATASET.TEST_SET = 'val'
58
+ _C.DATASET.DATA_FORMAT = 'jpg'
59
+ _C.DATASET.LABELMAP = ''
60
+ _C.DATASET.TRAIN_TSV_LIST = []
61
+ _C.DATASET.TEST_TSV_LIST = []
62
+ _C.DATASET.SAMPLER = 'default'
63
+
64
+ _C.DATASET.TARGET_SIZE = -1
65
+
66
+ # training data augmentation
67
+ _C.INPUT = CN()
68
+ _C.INPUT.MEAN = [0.485, 0.456, 0.406]
69
+ _C.INPUT.STD = [0.229, 0.224, 0.225]
70
+
71
+ # data augmentation
72
+ _C.AUG = CN()
73
+ _C.AUG.SCALE = (0.08, 1.0)
74
+ _C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
75
+ _C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
76
+ _C.AUG.GRAY_SCALE = 0.0
77
+ _C.AUG.GAUSSIAN_BLUR = 0.0
78
+ _C.AUG.DROPBLOCK_LAYERS = [3, 4]
79
+ _C.AUG.DROPBLOCK_KEEP_PROB = 1.0
80
+ _C.AUG.DROPBLOCK_BLOCK_SIZE = 7
81
+ _C.AUG.MIXUP_PROB = 0.0
82
+ _C.AUG.MIXUP = 0.0
83
+ _C.AUG.MIXCUT = 0.0
84
+ _C.AUG.MIXCUT_MINMAX = []
85
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
86
+ _C.AUG.MIXUP_MODE = 'batch'
87
+ _C.AUG.MIXCUT_AND_MIXUP = False
88
+ _C.AUG.INTERPOLATION = 2
89
+ _C.AUG.TIMM_AUG = CN(new_allowed=True)
90
+ _C.AUG.TIMM_AUG.USE_LOADER = False
91
+ _C.AUG.TIMM_AUG.USE_TRANSFORM = False
92
+
93
+ # train
94
+ _C.TRAIN = CN()
95
+
96
+ _C.TRAIN.AUTO_RESUME = True
97
+ _C.TRAIN.CHECKPOINT = ''
98
+ _C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
99
+ _C.TRAIN.SCALE_LR = True
100
+ _C.TRAIN.LR = 0.001
101
+
102
+ _C.TRAIN.OPTIMIZER = 'sgd'
103
+ _C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
104
+ _C.TRAIN.MOMENTUM = 0.9
105
+ _C.TRAIN.WD = 0.0001
106
+ _C.TRAIN.WITHOUT_WD_LIST = []
107
+ _C.TRAIN.NESTEROV = True
108
+ # for adam
109
+ _C.TRAIN.GAMMA1 = 0.99
110
+ _C.TRAIN.GAMMA2 = 0.0
111
+
112
+ _C.TRAIN.BEGIN_EPOCH = 0
113
+ _C.TRAIN.END_EPOCH = 100
114
+
115
+ _C.TRAIN.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
116
+ _C.TRAIN.BATCH_SIZE_PER_GPU = 32
117
+ _C.TRAIN.SHUFFLE = True
118
+
119
+ _C.TRAIN.EVAL_BEGIN_EPOCH = 0
120
+
121
+ _C.TRAIN.DETECT_ANOMALY = False
122
+
123
+ _C.TRAIN.CLIP_GRAD_NORM = 0.0
124
+ _C.TRAIN.SAVE_ALL_MODELS = False
125
+
126
+ # testing
127
+ _C.TEST = CN()
128
+
129
+ # size of images for each device
130
+ _C.TEST.BATCH_SIZE_PER_GPU = 32
131
+ _C.TEST.CENTER_CROP = True
132
+ _C.TEST.IMAGE_SIZE = [224, 224] # width * height, ex: 192 * 256
133
+ _C.TEST.INTERPOLATION = 2
134
+ _C.TEST.MODEL_FILE = ''
135
+ _C.TEST.REAL_LABELS = False
136
+ _C.TEST.VALID_LABELS = ''
137
+
138
+ _C.FINETUNE = CN()
139
+ _C.FINETUNE.FINETUNE = False
140
+ _C.FINETUNE.USE_TRAIN_AUG = False
141
+ _C.FINETUNE.BASE_LR = 0.003
142
+ _C.FINETUNE.BATCH_SIZE = 512
143
+ _C.FINETUNE.EVAL_EVERY = 3000
144
+ _C.FINETUNE.TRAIN_MODE = True
145
+ # _C.FINETUNE.MODEL_FILE = ''
146
+ _C.FINETUNE.FROZEN_LAYERS = []
147
+ _C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
148
+ _C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'
149
+
150
+ # debug
151
+ _C.DEBUG = CN()
152
+ _C.DEBUG.DEBUG = False
153
+
154
+
155
+ def _update_config_from_file(config, cfg_file):
156
+ config.defrost()
157
+ with open(cfg_file, 'r') as f:
158
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
159
+
160
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
161
+ if cfg:
162
+ _update_config_from_file(
163
+ config, op.join(op.dirname(cfg_file), cfg)
164
+ )
165
+ print('=> merge config from {}'.format(cfg_file))
166
+ config.merge_from_file(cfg_file)
167
+ config.freeze()
168
+
169
+
170
+ def update_config(config, args):
171
+ _update_config_from_file(config, args.cfg)
172
+
173
+ config.defrost()
174
+ config.merge_from_list(args.opts)
175
+ if config.TRAIN.SCALE_LR:
176
+ config.TRAIN.LR *= comm.world_size
177
+ file_name, _ = op.splitext(op.basename(args.cfg))
178
+ config.NAME = file_name + config.NAME
179
+ config.RANK = comm.rank
180
+
181
+ if 'timm' == config.TRAIN.LR_SCHEDULER.METHOD:
182
+ config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH
183
+
184
+ if 'timm' == config.TRAIN.OPTIMIZER:
185
+ config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR
186
+
187
+ aug = config.AUG
188
+ if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
189
+ aug.MIXUP_PROB = 1.0
190
+ config.freeze()
191
+
192
+
193
+ def save_config(cfg, path):
194
+ if comm.is_main_process():
195
+ with open(path, 'w') as f:
196
+ f.write(cfg.dump())
197
+
198
+
199
+ if __name__ == '__main__':
200
+ import sys
201
+ with open(sys.argv[1], 'w') as f:
202
+ print(_C, file=f)
cvt.py ADDED
1
+ from functools import partial
2
+ from itertools import repeat
3
+ #from torch._six import container_abcs
4
+ import collections.abc as container_abcs
5
+
6
+ import logging
7
+ import os
8
+ from collections import OrderedDict
9
+
10
+ import numpy as np
11
+ import scipy
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from einops import rearrange
16
+ from einops.layers.torch import Rearrange
17
+
18
+ from timm.models.layers import DropPath, trunc_normal_
19
+
20
+ #from .registry import register_model
21
+ from config import config
22
+
23
+ from torchinfo import summary
24
+ import yaml
25
+
26
+ _model_entrypoints = {}
27
+
28
+
29
+ def register_model(fn):
30
+ module_name_split = fn.__module__.split('.')
31
+ model_name = module_name_split[-1]
32
+
33
+ _model_entrypoints[model_name] = fn
34
+
35
+ return fn
36
+
37
+
38
+ def model_entrypoints(model_name):
39
+ return _model_entrypoints[model_name]
40
+
41
+
42
+ def is_model(model_name):
43
+ return model_name in _model_entrypoints
44
+
45
+
46
+ # From PyTorch internals
47
+ def _ntuple(n):
48
+ def parse(x):
49
+ if isinstance(x, container_abcs.Iterable):
50
+ return x
51
+ return tuple(repeat(x, n))
52
+
53
+ return parse
54
+
55
+
56
+ to_1tuple = _ntuple(1)
57
+ to_2tuple = _ntuple(2)
58
+ to_3tuple = _ntuple(3)
59
+ to_4tuple = _ntuple(4)
60
+ to_ntuple = _ntuple
61
+
62
+
63
+ class LayerNorm(nn.LayerNorm):
64
+ """Subclass torch's LayerNorm to handle fp16."""
65
+
66
+ def forward(self, x: torch.Tensor):
67
+ orig_type = x.dtype
68
+ ret = super().forward(x.type(torch.float32))
69
+ return ret.type(orig_type)
70
+
71
+
72
+ class QuickGELU(nn.Module):
73
+ def forward(self, x: torch.Tensor):
74
+ return x * torch.sigmoid(1.702 * x)
75
+
76
+
77
+ class Mlp(nn.Module):
78
+ def __init__(self,
79
+ in_features,
80
+ hidden_features=None,
81
+ out_features=None,
82
+ act_layer=nn.GELU,
83
+ drop=0.):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+ self.fc1 = nn.Linear(in_features, hidden_features)
88
+ self.act = act_layer()
89
+ self.fc2 = nn.Linear(hidden_features, out_features)
90
+ self.drop = nn.Dropout(drop)
91
+
92
+ def forward(self, x):
93
+ x = self.fc1(x)
94
+ x = self.act(x)
95
+ x = self.drop(x)
96
+ x = self.fc2(x)
97
+ x = self.drop(x)
98
+ return x
99
+
100
+
101
+ class Attention(nn.Module):
102
+ def __init__(self,
103
+ dim_in,
104
+ dim_out,
105
+ num_heads,
106
+ qkv_bias=False,
107
+ attn_drop=0.,
108
+ proj_drop=0.,
109
+ method='dw_bn',
110
+ kernel_size=3,
111
+ stride_kv=1,
112
+ stride_q=1,
113
+ padding_kv=1,
114
+ padding_q=1,
115
+ with_cls_token=True,
116
+ **kwargs
117
+ ):
118
+ super().__init__()
119
+ self.stride_kv = stride_kv
120
+ self.stride_q = stride_q
121
+ self.dim = dim_out
122
+ self.num_heads = num_heads
123
+ # head_dim = self.qkv_dim // num_heads
124
+ self.scale = dim_out ** -0.5
125
+ self.with_cls_token = with_cls_token
126
+
127
+ self.conv_proj_q = self._build_projection(
128
+ dim_in, dim_out, kernel_size, padding_q,
129
+ stride_q, 'linear' if method == 'avg' else method
130
+ )
131
+ self.conv_proj_k = self._build_projection(
132
+ dim_in, dim_out, kernel_size, padding_kv,
133
+ stride_kv, method
134
+ )
135
+ self.conv_proj_v = self._build_projection(
136
+ dim_in, dim_out, kernel_size, padding_kv,
137
+ stride_kv, method
138
+ )
139
+
140
+ self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
141
+ self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
142
+ self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)
143
+
144
+ self.attn_drop = nn.Dropout(attn_drop)
145
+ self.proj = nn.Linear(dim_out, dim_out)
146
+ self.proj_drop = nn.Dropout(proj_drop)
147
+
148
+ def _build_projection(self,
149
+ dim_in,
150
+ dim_out,
151
+ kernel_size,
152
+ padding,
153
+ stride,
154
+ method):
155
+ if method == 'dw_bn':
156
+ proj = nn.Sequential(OrderedDict([
157
+ ('conv', nn.Conv2d(
158
+ dim_in,
159
+ dim_in,
160
+ kernel_size=kernel_size,
161
+ padding=padding,
162
+ stride=stride,
163
+ bias=False,
164
+ groups=dim_in
165
+ )),
166
+ ('bn', nn.BatchNorm2d(dim_in)),
167
+ ('rearrage', Rearrange('b c h w -> b (h w) c')),
168
+ ]))
169
+ elif method == 'avg':
170
+ proj = nn.Sequential(OrderedDict([
171
+ ('avg', nn.AvgPool2d(
172
+ kernel_size=kernel_size,
173
+ padding=padding,
174
+ stride=stride,
175
+ ceil_mode=True
176
+ )),
177
+ ('rearrage', Rearrange('b c h w -> b (h w) c')),
178
+ ]))
179
+ elif method == 'linear':
180
+ proj = None
181
+ else:
182
+ raise ValueError('Unknown method ({})'.format(method))
183
+
184
+ return proj
185
+
186
+ def forward_conv(self, x, h, w):
187
+ if self.with_cls_token:
188
+ cls_token, x = torch.split(x, [1, h*w], 1)
189
+
190
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
191
+
192
+ if self.conv_proj_q is not None:
193
+ q = self.conv_proj_q(x)
194
+ else:
195
+ q = rearrange(x, 'b c h w -> b (h w) c')
196
+
197
+ if self.conv_proj_k is not None:
198
+ k = self.conv_proj_k(x)
199
+ else:
200
+ k = rearrange(x, 'b c h w -> b (h w) c')
201
+
202
+ if self.conv_proj_v is not None:
203
+ v = self.conv_proj_v(x)
204
+ else:
205
+ v = rearrange(x, 'b c h w -> b (h w) c')
206
+
207
+ if self.with_cls_token:
208
+ q = torch.cat((cls_token, q), dim=1)
209
+ k = torch.cat((cls_token, k), dim=1)
210
+ v = torch.cat((cls_token, v), dim=1)
211
+
212
+ return q, k, v
213
+
214
+ def forward(self, x, h, w):
215
+ if (
216
+ self.conv_proj_q is not None
217
+ or self.conv_proj_k is not None
218
+ or self.conv_proj_v is not None
219
+ ):
220
+ q, k, v = self.forward_conv(x, h, w)
221
+
222
+ q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
223
+ k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
224
+ v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)
225
+
226
+ attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
227
+ attn = F.softmax(attn_score, dim=-1)
228
+ attn = self.attn_drop(attn)
229
+
230
+ x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
231
+ x = rearrange(x, 'b h t d -> b t (h d)')
232
+
233
+ x = self.proj(x)
234
+ x = self.proj_drop(x)
235
+
236
+ return x
237
+
238
+ @staticmethod
239
+ def compute_macs(module, input, output):
240
+ # T: num_token
241
+ # S: num_token
242
+ input = input[0]
243
+ flops = 0
244
+
245
+ _, T, C = input.shape
246
+ H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))
247
+
248
+ H_Q = H / module.stride_q
249
+ W_Q = H / module.stride_q
250
+ T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q
251
+
252
+ H_KV = H / module.stride_kv
253
+ W_KV = W / module.stride_kv
254
+ T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV
255
+
256
+ # C = module.dim
257
+ # S = T
258
+ # Scaled-dot-product macs
259
+ # [B x T x C] x [B x C x T] --> [B x T x S]
260
+ # multiplication-addition is counted as 1 because operations can be fused
261
+ flops += T_Q * T_KV * module.dim
262
+ # [B x T x S] x [B x S x C] --> [B x T x C]
263
+ flops += T_Q * module.dim * T_KV
264
+
265
+ if (
266
+ hasattr(module, 'conv_proj_q')
267
+ and hasattr(module.conv_proj_q, 'conv')
268
+ ):
269
+ params = sum(
270
+ [
271
+ p.numel()
272
+ for p in module.conv_proj_q.conv.parameters()
273
+ ]
274
+ )
275
+ flops += params * H_Q * W_Q
276
+
277
+ if (
278
+ hasattr(module, 'conv_proj_k')
279
+ and hasattr(module.conv_proj_k, 'conv')
280
+ ):
281
+ params = sum(
282
+ [
283
+ p.numel()
284
+ for p in module.conv_proj_k.conv.parameters()
285
+ ]
286
+ )
287
+ flops += params * H_KV * W_KV
288
+
289
+ if (
290
+ hasattr(module, 'conv_proj_v')
291
+ and hasattr(module.conv_proj_v, 'conv')
292
+ ):
293
+ params = sum(
294
+ [
295
+ p.numel()
296
+ for p in module.conv_proj_v.conv.parameters()
297
+ ]
298
+ )
299
+ flops += params * H_KV * W_KV
300
+
301
+ params = sum([p.numel() for p in module.proj_q.parameters()])
302
+ flops += params * T_Q
303
+ params = sum([p.numel() for p in module.proj_k.parameters()])
304
+ flops += params * T_KV
305
+ params = sum([p.numel() for p in module.proj_v.parameters()])
306
+ flops += params * T_KV
307
+ params = sum([p.numel() for p in module.proj.parameters()])
308
+ flops += params * T
309
+
310
+ module.__flops__ += flops
311
+
312
+
313
+ class Block(nn.Module):
314
+
315
+ def __init__(self,
316
+ dim_in,
317
+ dim_out,
318
+ num_heads,
319
+ mlp_ratio=4.,
320
+ qkv_bias=False,
321
+ drop=0.,
322
+ attn_drop=0.,
323
+ drop_path=0.,
324
+ act_layer=nn.GELU,
325
+ norm_layer=nn.LayerNorm,
326
+ **kwargs):
327
+ super().__init__()
328
+
329
+ self.with_cls_token = kwargs['with_cls_token']
330
+
331
+ self.norm1 = norm_layer(dim_in)
332
+ self.attn = Attention(
333
+ dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
334
+ **kwargs
335
+ )
336
+
337
+ self.drop_path = DropPath(drop_path) \
338
+ if drop_path > 0. else nn.Identity()
339
+ self.norm2 = norm_layer(dim_out)
340
+
341
+ dim_mlp_hidden = int(dim_out * mlp_ratio)
342
+ self.mlp = Mlp(
343
+ in_features=dim_out,
344
+ hidden_features=dim_mlp_hidden,
345
+ act_layer=act_layer,
346
+ drop=drop
347
+ )
348
+
349
+ def forward(self, x, h, w):
350
+ res = x
351
+
352
+ x = self.norm1(x)
353
+ attn = self.attn(x, h, w)
354
+ x = res + self.drop_path(attn)
355
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
356
+
357
+ return x
358
+
359
+
360
+ class ConvEmbed(nn.Module):
361
+ """ Image to Conv Embedding
362
+
363
+ """
364
+
365
+ def __init__(self,
366
+ patch_size=7,
367
+ in_chans=1, #1 for spectrogram, 3 for rgb image
368
+ embed_dim=64,
369
+ stride=4,
370
+ padding=2,
371
+ norm_layer=None):
372
+ super().__init__()
373
+ patch_size = to_2tuple(patch_size)
374
+ self.patch_size = patch_size
375
+
376
+ self.proj = nn.Conv2d(
377
+ in_chans, embed_dim,
378
+ kernel_size=patch_size,
379
+ stride=stride,
380
+ padding=padding
381
+ )
382
+ self.norm = norm_layer(embed_dim) if norm_layer else None
383
+
384
+ def forward(self, x):
385
+ x = self.proj(x)
386
+
387
+ B, C, H, W = x.shape
388
+ x = rearrange(x, 'b c h w -> b (h w) c')
389
+ if self.norm:
390
+ x = self.norm(x)
391
+ x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
392
+
393
+ return x
394
+
395
+
396
+ class VisionTransformer(nn.Module):
397
+ """ Vision Transformer with support for patch or hybrid CNN input stage
398
+ """
399
+ def __init__(self,
400
+ patch_size=16,
401
+ patch_stride=16,
402
+ patch_padding=0,
403
+ in_chans=1, #1for spectrogram, 3 for RGB
404
+ embed_dim=768,
405
+ depth=12,
406
+ num_heads=12,
407
+ mlp_ratio=4.,
408
+ qkv_bias=False,
409
+ drop_rate=0.,
410
+ attn_drop_rate=0.,
411
+ drop_path_rate=0.,
412
+ act_layer=nn.GELU,
413
+ norm_layer=nn.LayerNorm,
414
+ init='trunc_norm',
415
+ **kwargs):
416
+ super().__init__()
417
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
418
+
419
+ self.rearrage = None
420
+
421
+ self.patch_embed = ConvEmbed(
422
+ # img_size=img_size,
423
+ patch_size=patch_size,
424
+ in_chans=in_chans,
425
+ stride=patch_stride,
426
+ padding=patch_padding,
427
+ embed_dim=embed_dim,
428
+ norm_layer=norm_layer
429
+ )
430
+
431
+ with_cls_token = kwargs['with_cls_token']
432
+ if with_cls_token:
433
+ self.cls_token = nn.Parameter(
434
+ torch.zeros(1, 1, embed_dim)
435
+ )
436
+ else:
437
+ self.cls_token = None
438
+
439
+ self.pos_drop = nn.Dropout(p=drop_rate)
440
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
441
+
442
+ blocks = []
443
+ for j in range(depth):
444
+ blocks.append(
445
+ Block(
446
+ dim_in=embed_dim,
447
+ dim_out=embed_dim,
448
+ num_heads=num_heads,
449
+ mlp_ratio=mlp_ratio,
450
+ qkv_bias=qkv_bias,
451
+ drop=drop_rate,
452
+ attn_drop=attn_drop_rate,
453
+ drop_path=dpr[j],
454
+ act_layer=act_layer,
455
+ norm_layer=norm_layer,
456
+ **kwargs
457
+ )
458
+ )
459
+ self.blocks = nn.ModuleList(blocks)
460
+
461
+ if self.cls_token is not None:
462
+ trunc_normal_(self.cls_token, std=.02)
463
+
464
+ if init == 'xavier':
465
+ self.apply(self._init_weights_xavier)
466
+ else:
467
+ self.apply(self._init_weights_trunc_normal)
468
+
469
+ def _init_weights_trunc_normal(self, m):
470
+ if isinstance(m, nn.Linear):
471
+ logging.info('=> init weight of Linear from trunc norm')
472
+ trunc_normal_(m.weight, std=0.02)
473
+ if m.bias is not None:
474
+ logging.info('=> init bias of Linear to zeros')
475
+ nn.init.constant_(m.bias, 0)
476
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
477
+ nn.init.constant_(m.bias, 0)
478
+ nn.init.constant_(m.weight, 1.0)
479
+
480
+ def _init_weights_xavier(self, m):
481
+ if isinstance(m, nn.Linear):
482
+ logging.info('=> init weight of Linear from xavier uniform')
483
+ nn.init.xavier_uniform_(m.weight)
484
+ if m.bias is not None:
485
+ logging.info('=> init bias of Linear to zeros')
486
+ nn.init.constant_(m.bias, 0)
487
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
488
+ nn.init.constant_(m.bias, 0)
489
+ nn.init.constant_(m.weight, 1.0)
490
+
491
+ def forward(self, x):
492
+ x = self.patch_embed(x)
493
+ B, C, H, W = x.size()
494
+
495
+ x = rearrange(x, 'b c h w -> b (h w) c')
496
+
497
+ cls_tokens = None
498
+ if self.cls_token is not None:
499
+ # stole cls_tokens impl from Phil Wang, thanks
500
+ cls_tokens = self.cls_token.expand(B, -1, -1)
501
+ x = torch.cat((cls_tokens, x), dim=1)
502
+
503
+ x = self.pos_drop(x)
504
+
505
+ for i, blk in enumerate(self.blocks):
506
+ x = blk(x, H, W)
507
+
508
+ if self.cls_token is not None:
509
+ cls_tokens, x = torch.split(x, [1, H*W], 1)
510
+ x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
511
+
512
+ return x, cls_tokens
513
+
514
+
515
+ class ConvolutionalVisionTransformer(nn.Module):
516
+ def __init__(self,
517
+ in_chans=1, #3 for RGB, 1 for Spectrogram
518
+ num_classes=1000,
519
+ act_layer=nn.GELU,
520
+ norm_layer=nn.LayerNorm,
521
+ init='trunc_norm',
522
+ spec=None):
523
+ super().__init__()
524
+ self.num_classes = num_classes
525
+
526
+ self.num_stages = spec['NUM_STAGES']
527
+ for i in range(self.num_stages):
528
+ kwargs = {
529
+ 'patch_size': spec['PATCH_SIZE'][i],
530
+ 'patch_stride': spec['PATCH_STRIDE'][i],
531
+ 'patch_padding': spec['PATCH_PADDING'][i],
532
+ 'embed_dim': spec['DIM_EMBED'][i],
533
+ 'depth': spec['DEPTH'][i],
534
+ 'num_heads': spec['NUM_HEADS'][i],
535
+ 'mlp_ratio': spec['MLP_RATIO'][i],
536
+ 'qkv_bias': spec['QKV_BIAS'][i],
537
+ 'drop_rate': spec['DROP_RATE'][i],
538
+ 'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
539
+ 'drop_path_rate': spec['DROP_PATH_RATE'][i],
540
+ 'with_cls_token': spec['CLS_TOKEN'][i],
541
+ 'method': spec['QKV_PROJ_METHOD'][i],
542
+ 'kernel_size': spec['KERNEL_QKV'][i],
543
+ 'padding_q': spec['PADDING_Q'][i],
544
+ 'padding_kv': spec['PADDING_KV'][i],
545
+ 'stride_kv': spec['STRIDE_KV'][i],
546
+ 'stride_q': spec['STRIDE_Q'][i],
547
+ }
548
+
549
+ stage = VisionTransformer(
550
+ in_chans=in_chans,
551
+ init=init,
552
+ act_layer=act_layer,
553
+ norm_layer=norm_layer,
554
+ **kwargs
555
+ )
556
+ setattr(self, f'stage{i}', stage)
557
+
558
+ in_chans = spec['DIM_EMBED'][i]
559
+
560
+ dim_embed = spec['DIM_EMBED'][-1]
561
+ self.norm = norm_layer(dim_embed)
562
+ self.cls_token = spec['CLS_TOKEN'][-1]
563
+
564
+ # Classifier head
565
+ #self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity()
566
+ #trunc_normal_(self.head.weight, std=0.02)
567
+ self.head = nn.Identity()
568
+
569
+
570
+
571
+ def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
572
+ if os.path.isfile(pretrained):
573
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
574
+ logging.info(f'=> loading pretrained model {pretrained}')
575
+ model_dict = self.state_dict()
576
+ pretrained_dict = {
577
+ k: v for k, v in pretrained_dict.items()
578
+ if k in model_dict.keys()
579
+ }
580
+ need_init_state_dict = {}
581
+ for k, v in pretrained_dict.items():
582
+ need_init = (
583
+ k.split('.')[0] in pretrained_layers
584
+ #or pretrained_layers[0] is '*'
585
+ or pretrained_layers[0] == '*'
586
+ )
587
+ if need_init:
588
+ if verbose:
589
+ logging.info(f'=> init {k} from {pretrained}')
590
+ if 'pos_embed' in k and v.size() != model_dict[k].size():
591
+ size_pretrained = v.size()
592
+ size_new = model_dict[k].size()
593
+ logging.info(
594
+ '=> load_pretrained: resized variant: {} to {}'
595
+ .format(size_pretrained, size_new)
596
+ )
597
+
598
+ ntok_new = size_new[1]
599
+ ntok_new -= 1
600
+
601
+ posemb_tok, posemb_grid = v[:, :1], v[0, 1:]
602
+
603
+ gs_old = int(np.sqrt(len(posemb_grid)))
604
+ gs_new = int(np.sqrt(ntok_new))
605
+
606
+ logging.info(
607
+ '=> load_pretrained: grid-size from {} to {}'
608
+ .format(gs_old, gs_new)
609
+ )
610
+
611
+ posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
612
+ zoom = (gs_new / gs_old, gs_new / gs_old, 1)
613
+ posemb_grid = scipy.ndimage.zoom(
614
+ posemb_grid, zoom, order=1
615
+ )
616
+ posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
617
+ v = torch.tensor(
618
+ np.concatenate([posemb_tok, posemb_grid], axis=1)
619
+ )
620
+
621
+ need_init_state_dict[k] = v
622
+ self.load_state_dict(need_init_state_dict, strict=False)
623
+
624
+ @torch.jit.ignore
625
+ def no_weight_decay(self):
626
+ layers = set()
627
+ for i in range(self.num_stages):
628
+ layers.add(f'stage{i}.pos_embed')
629
+ layers.add(f'stage{i}.cls_token')
630
+
631
+ return layers
632
+
633
+ def forward_features(self, x):
634
+ for i in range(self.num_stages):
635
+ x, cls_tokens = getattr(self, f'stage{i}')(x)
636
+
637
+ if self.cls_token:
638
+ x = self.norm(cls_tokens)
639
+ #x = cls_tokens
640
+ x = torch.squeeze(x)
641
+ else:
642
+ x = rearrange(x, 'b c h w -> b (h w) c')
643
+ x = self.norm(x)
644
+ x = torch.mean(x, dim=1)
645
+
646
+ return x
647
+
648
+ def forward(self, x):
649
+ x = self.forward_features(x)
650
+ x = self.head(x)
651
+
652
+ return x
653
+
654
+
655
+ @register_model
656
+ def get_cls_model(**kwargs):
657
+ msvit_spec = config.MODEL.SPEC
658
+ msvit = ConvolutionalVisionTransformer(
659
+ in_chans=1, #1 for spectrogram 3 for RGB
660
+ num_classes=config.MODEL.NUM_CLASSES,
661
+ act_layer=QuickGELU,
662
+ norm_layer=partial(LayerNorm, eps=1e-5),
663
+ init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
664
+ spec=msvit_spec
665
+ )
666
+
667
+ # if config.MODEL.INIT_WEIGHTS:
668
+ # msvit.init_weights(
669
+ # config.MODEL.PRETRAINED,
670
+ # config.MODEL.PRETRAINED_LAYERS,
671
+ # config.VERBOSE
672
+ # )
673
+
674
+ return msvit
675
+
676
+ def build_model(config, **kwargs):
677
+ model_name = config.MODEL.NAME
678
+ if not is_model(model_name):
679
+ raise ValueError(f'Unkown model: {model_name}')
680
+
681
+ return model_entrypoints(model_name)(config, **kwargs)
682
+
683
+ def cvt13(**kwargs):
684
+ f = open('config/cvt-13-224x224.yaml', 'r')
685
+ config = yaml.safe_load(f)
686
+ return ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC']) # only loades the config, no pretraining
687
+
688
+ if __name__ == '__main__':
689
+ f = open('config/cvt-13-224x224.yaml', 'r')
690
+ config = yaml.safe_load(f)
691
+ model = ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC'])
692
+ print(summary(model))
693
+ quit()
694
+ print(summary(model, input_size=(4, 1, 128, 301)))
melspectrogram.py ADDED
from torchaudio import transforms as T
import torch
import torch.nn as nn

MEAN, STD = 0.5347, 0.0772  # Xeno-Canto training-set statistics
SR = 16000
NFFT = 1024
HOPLEN = 320
NMELS = 128
FMIN = 50
FMAX = 8000


class Normalization(torch.nn.Module):
    """Min-max normalize a spectrogram to [0, 1]."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x - x.min()) / (x.max() - x.min())


class Standardization(torch.nn.Module):
    """Standardize with the dataset mean and standard deviation."""
    def __init__(self, mean, std):
        super().__init__()

        self.mean = mean
        self.std = std

    def forward(self, x):
        return (x - self.mean) / self.std


class MelSpectrogramProcessor:
    """Convert a 16 kHz waveform into a normalized, standardized log-Mel spectrogram."""
    def __init__(self, sample_rate=SR, n_mels=NMELS, n_fft=NFFT, hop_length=HOPLEN, f_min=FMIN, f_max=FMAX):
        self.transform = nn.Sequential(
            T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, f_min=f_min, f_max=f_max),
            T.AmplitudeToDB(),
            Normalization(),
            Standardization(mean=MEAN, std=STD),
        )

    def process(self, waveform):
        return self.transform(waveform)
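With these settings (hop length 320 at 16 kHz), a 6-second clip of 96,000 samples maps to 128 Mel bands by 301 frames; a quick sanity check of the shape the model expects (the dummy input is illustrative):

```python
import torch
from melspectrogram import MelSpectrogramProcessor

processor = MelSpectrogramProcessor()
dummy = torch.randn(1, 96000)   # one mono channel, 6 s at 16 kHz
spec = processor.process(dummy)
print(spec.shape)               # expected: torch.Size([1, 128, 301])
```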
protoclr.pth ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:b7eaf62e2f66084f50cbb9677420a3a93b81334fadddeb4eb9790e734f4a514f
size 78717724
requirements.txt ADDED
einops==0.8.0
numpy==2.1.3
PyYAML==6.0.2
scipy==1.14.1
timm==1.0.11
torch==2.2.2
torchaudio==2.2.2
torchinfo==1.8.0
yacs==0.1.8