ilyassmoummad committed aab5975 (parent: 2cc9f37): init

Browse files:
- README.md +75 -0
- config/__init__.py +4 -0
- config/comm.py +132 -0
- config/cvt-13-224x224.yaml +83 -0
- config/default.py +202 -0
- cvt.py +694 -0
- melspectrogram.py +40 -0
- protoclr.pth +3 -0
- requirements.txt +9 -0
README.md
ADDED
@@ -0,0 +1,75 @@
# ProtoCLR

This repository contains a CvT-13 ([Convolutional Vision Transformer](https://arxiv.org/abs/2103.15808)) model trained from scratch on the [Xeno-Canto dataset](https://huggingface.co/datasets/ilyassmoummad/Xeno-Canto-6s-16khz), specifically on 6-second audio segments sampled at 16 kHz. The model is trained on Mel spectrograms of bird sounds using [ProtoCLR (Prototypical Contrastive Loss)](https://arxiv.org/abs/2409.08589) for 300 epochs and can be used as a feature extractor for bird audio classification and related tasks.
## Files

- `cvt.py`: Defines the CvT-13 model architecture.
- `protoclr.pth`: Pre-trained ProtoCLR model weights.
- `config/`: Configuration files for the CvT-13 setup.
- `melspectrogram.py`: Contains the `MelSpectrogramProcessor` class, which converts audio waveforms into Mel spectrograms, the input format the model expects.
## Setup

1. **Install dependencies**:
   Ensure you have the required Python packages, including `torch` and the other dependencies listed in `requirements.txt`.
   ```bash
   pip install -r requirements.txt
   ```

2. **Prepare the audio**:
   - **Sample rate**: Ensure your audio is sampled at 16 kHz.
   - **Padding**: For audio shorter than 6 seconds, pad with zeros or repeat the audio to reach 6 seconds.
   - **Chunking**: For audio longer than 6 seconds, split it into 6-second chunks (a sketch of both operations follows this list).
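The following is a minimal sketch of both operations; the helper names `pad_or_repeat` and `chunk_waveform` are illustrative, not part of this repository.

```python
import torch

SR = 16000
SEGMENT_LEN = 6 * SR  # 6 s at 16 kHz = 96000 samples

def pad_or_repeat(waveform: torch.Tensor) -> torch.Tensor:
    """Repeat a short (channels, time) waveform until it is 6 s long."""
    n = waveform.shape[-1]
    if n < SEGMENT_LEN:
        reps = (SEGMENT_LEN + n - 1) // n  # ceil division
        waveform = waveform.repeat(1, reps)[..., :SEGMENT_LEN]
    return waveform

def chunk_waveform(waveform: torch.Tensor) -> list:
    """Split a long waveform into 6 s chunks; the last chunk is repeated to full length."""
    chunks = list(torch.split(waveform, SEGMENT_LEN, dim=-1))
    chunks[-1] = pad_or_repeat(chunks[-1])
    return chunks
```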
## Usage

To use the model, process your audio data with the `MelSpectrogramProcessor`, then pass the processed spectrograms to the CvT-13 model.

## Example Code

The following example demonstrates loading, preprocessing, and running inference on an audio file:
```python
import torch
import torchaudio
from cvt import cvt13  # Model architecture
from melspectrogram import MelSpectrogramProcessor  # Mel spectrogram processor

# Initialize the preprocessor and model
preprocessor = MelSpectrogramProcessor()
model = cvt13()
model.load_state_dict(torch.load("protoclr.pth"))
model.eval()

# Load a sample audio waveform with torchaudio and resample it to 16 kHz
def load_waveform(file_path, target_sr=16000):
    waveform, sr = torchaudio.load(file_path)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform.mean(dim=0, keepdim=True)  # downmix to mono

waveform = load_waveform("path/to/audio.wav")  # Load your audio file here

# The waveform is now sampled at 16 kHz; pad/chunk as needed for a 6 s length
input_tensor = preprocessor.process(waveform).unsqueeze(0)  # Add batch dimension

# Run the model on the preprocessed audio
with torch.no_grad():
    output = model(input_tensor)
    print("Model output shape:", output.shape)
```
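For recordings longer than 6 seconds, one reasonable pattern (an assumption; the repository does not prescribe a pooling strategy) is to embed each 6-second chunk and average the embeddings, reusing the `chunk_waveform` sketch from the Setup section:

```python
chunks = chunk_waveform(waveform)                               # list of (1, 96000) tensors
batch = torch.stack([preprocessor.process(c) for c in chunks])  # (num_chunks, 1, 128, 301)
with torch.no_grad():
    embeddings = model(batch)            # (num_chunks, 384)
    clip_embedding = embeddings.mean(0)  # one 384-d vector per recording
```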
## Citation

If you use our model in your research, please cite the following paper:

```bibtex
@misc{moummad2024dirlbs,
      title={Domain-Invariant Representation Learning of Bird Sounds},
      author={Ilyass Moummad and Romain Serizel and Emmanouil Benetos and Nicolas Farrugia},
      year={2024},
      eprint={2409.08589},
      archivePrefix={arXiv},
      primaryClass={cs.SD},
      url={https://arxiv.org/abs/2409.08589},
}
```
config/__init__.py
ADDED
@@ -0,0 +1,4 @@
```python
from .default import _C as config
from .default import update_config
from .default import _update_config_from_file
from .default import save_config
```
config/comm.py
ADDED
@@ -0,0 +1,132 @@
```python
import pickle

import torch
import torch.distributed as dist


class Comm(object):
    def __init__(self, local_rank=0):
        self.local_rank = local_rank  # fixed: the original ignored the argument and always stored 0

    @property
    def world_size(self):
        if not dist.is_available():
            return 1
        if not dist.is_initialized():
            return 1
        return dist.get_world_size()

    @property
    def rank(self):
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return dist.get_rank()

    @property
    def local_rank(self):
        if not dist.is_available():
            return 0
        if not dist.is_initialized():
            return 0
        return self._local_rank

    @local_rank.setter
    def local_rank(self, value):
        if not dist.is_available() or not dist.is_initialized():
            self._local_rank = 0
            return  # fixed: the original fell through and always overwrote with `value`
        self._local_rank = value

    @property
    def head(self):
        return 'Rank[{}/{}]'.format(self.rank, self.world_size)

    def is_main_process(self):
        return self.rank == 0

    def synchronize(self):
        """
        Helper function to synchronize (barrier) among all processes when
        using distributed training
        """
        if self.world_size == 1:
            return
        dist.barrier()


comm = Comm()


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = comm.world_size
    if world_size == 1:
        return [data]

    # serialize to a ByteTensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain the tensor size on each rank
    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving tensors from all ranks:
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = comm.world_size
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process accumulates, so only divide by
            # world_size there
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
```
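A hedged usage sketch of these helpers, assuming a process group already initialized on GPU by a launcher such as `torchrun` (no training script is included in this commit):

```python
import torch
from config.comm import comm, all_gather, reduce_dict

# Assumes torch.distributed.init_process_group(...) has already run.
losses = {"loss": torch.tensor(0.42, device="cuda")}
avg = reduce_dict(losses)                   # rank 0 ends up with the averaged values
gathered = all_gather({"rank": comm.rank})  # one entry per process
if comm.is_main_process():
    print(avg, gathered)
comm.synchronize()                          # barrier across all ranks
```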
config/cvt-13-224x224.yaml
ADDED
@@ -0,0 +1,83 @@
```yaml
OUTPUT_DIR: 'OUTPUT/'
WORKERS: 6
PRINT_FREQ: 500
AMP:
  ENABLED: true

MODEL:
  NAME: cls_cvt
  SPEC:
    INIT: 'trunc_norm'
    NUM_STAGES: 3
    PATCH_SIZE: [7, 3, 3]
    PATCH_STRIDE: [4, 2, 2]
    PATCH_PADDING: [2, 1, 1]
    DIM_EMBED: [64, 192, 384]
    NUM_HEADS: [1, 3, 6]
    DEPTH: [1, 2, 10]
    MLP_RATIO: [4.0, 4.0, 4.0]
    ATTN_DROP_RATE: [0.0, 0.0, 0.0]
    DROP_RATE: [0.0, 0.0, 0.0]
    DROP_PATH_RATE: [0.0, 0.0, 0.1]
    QKV_BIAS: [True, True, True]
    CLS_TOKEN: [False, False, True]
    POS_EMBED: [False, False, False]
    QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn']
    KERNEL_QKV: [3, 3, 3]
    PADDING_KV: [1, 1, 1]
    STRIDE_KV: [2, 2, 2]
    PADDING_Q: [1, 1, 1]
    STRIDE_Q: [1, 1, 1]
AUG:
  MIXUP_PROB: 1.0
  MIXUP: 0.8
  MIXCUT: 1.0
  TIMM_AUG:
    USE_LOADER: true
    RE_COUNT: 1
    RE_MODE: pixel
    RE_SPLIT: false
    RE_PROB: 0.25
    AUTO_AUGMENT: rand-m9-mstd0.5-inc1
    HFLIP: 0.5
    VFLIP: 0.0
    COLOR_JITTER: 0.4
    INTERPOLATION: bicubic
LOSS:
  LABEL_SMOOTHING: 0.1
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
DATASET:
  DATASET: 'imagenet'
  DATA_FORMAT: 'jpg'
  ROOT: 'DATASET/imagenet/'
  TEST_SET: 'val'
  TRAIN_SET: 'train'
TEST:
  BATCH_SIZE_PER_GPU: 32
  IMAGE_SIZE: [224, 224]
  MODEL_FILE: ''
  INTERPOLATION: 3
TRAIN:
  BATCH_SIZE_PER_GPU: 256
  LR: 0.00025
  IMAGE_SIZE: [224, 224]
  BEGIN_EPOCH: 0
  END_EPOCH: 300
  LR_SCHEDULER:
    METHOD: 'timm'
    ARGS:
      sched: 'cosine'
      warmup_epochs: 5
      warmup_lr: 0.000001
      min_lr: 0.00001
      cooldown_epochs: 10
      decay_rate: 0.1
  OPTIMIZER: adamW
  WD: 0.05
  WITHOUT_WD_LIST: ['bn', 'bias', 'ln']
  SHUFFLE: true
DEBUG:
  DEBUG: false
```
config/default.py
ADDED
@@ -0,0 +1,202 @@
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as op
import yaml
from yacs.config import CfgNode as CN

from .comm import comm  # fixed: `from config import comm` bound the module, not the Comm instance used below


_C = CN()

_C.BASE = ['']
_C.NAME = ''
_C.DATA_DIR = ''
_C.DIST_BACKEND = 'nccl'
_C.GPUS = (0,)
# _C.LOG_DIR = ''
_C.MULTIPROCESSING_DISTRIBUTED = True
_C.OUTPUT_DIR = ''
_C.PIN_MEMORY = True
_C.PRINT_FREQ = 20
_C.RANK = 0
_C.VERBOSE = True
_C.WORKERS = 4
_C.MODEL_SUMMARY = False

_C.AMP = CN()
_C.AMP.ENABLED = False
_C.AMP.MEMORY_FORMAT = 'nchw'

# Cudnn related params
_C.CUDNN = CN()
_C.CUDNN.BENCHMARK = True
_C.CUDNN.DETERMINISTIC = False
_C.CUDNN.ENABLED = True

# common params for NETWORK
_C.MODEL = CN()
_C.MODEL.NAME = 'cls_hrnet'
_C.MODEL.INIT_WEIGHTS = True
_C.MODEL.PRETRAINED = ''
_C.MODEL.PRETRAINED_LAYERS = ['*']
_C.MODEL.NUM_CLASSES = 1000
_C.MODEL.SPEC = CN(new_allowed=True)

_C.LOSS = CN(new_allowed=True)
_C.LOSS.LABEL_SMOOTHING = 0.0
_C.LOSS.LOSS = 'softmax'

# DATASET related params
_C.DATASET = CN()
_C.DATASET.ROOT = ''
_C.DATASET.DATASET = 'imagenet'
_C.DATASET.TRAIN_SET = 'train'
_C.DATASET.TEST_SET = 'val'
_C.DATASET.DATA_FORMAT = 'jpg'
_C.DATASET.LABELMAP = ''
_C.DATASET.TRAIN_TSV_LIST = []
_C.DATASET.TEST_TSV_LIST = []
_C.DATASET.SAMPLER = 'default'

_C.DATASET.TARGET_SIZE = -1

# training data augmentation
_C.INPUT = CN()
_C.INPUT.MEAN = [0.485, 0.456, 0.406]
_C.INPUT.STD = [0.229, 0.224, 0.225]

# data augmentation
_C.AUG = CN()
_C.AUG.SCALE = (0.08, 1.0)
_C.AUG.RATIO = (3.0/4.0, 4.0/3.0)
_C.AUG.COLOR_JITTER = [0.4, 0.4, 0.4, 0.1, 0.0]
_C.AUG.GRAY_SCALE = 0.0
_C.AUG.GAUSSIAN_BLUR = 0.0
_C.AUG.DROPBLOCK_LAYERS = [3, 4]
_C.AUG.DROPBLOCK_KEEP_PROB = 1.0
_C.AUG.DROPBLOCK_BLOCK_SIZE = 7
_C.AUG.MIXUP_PROB = 0.0
_C.AUG.MIXUP = 0.0
_C.AUG.MIXCUT = 0.0
_C.AUG.MIXCUT_MINMAX = []
_C.AUG.MIXUP_SWITCH_PROB = 0.5
_C.AUG.MIXUP_MODE = 'batch'
_C.AUG.MIXCUT_AND_MIXUP = False
_C.AUG.INTERPOLATION = 2
_C.AUG.TIMM_AUG = CN(new_allowed=True)
_C.AUG.TIMM_AUG.USE_LOADER = False
_C.AUG.TIMM_AUG.USE_TRANSFORM = False

# train
_C.TRAIN = CN()

_C.TRAIN.AUTO_RESUME = True
_C.TRAIN.CHECKPOINT = ''
_C.TRAIN.LR_SCHEDULER = CN(new_allowed=True)
_C.TRAIN.SCALE_LR = True
_C.TRAIN.LR = 0.001

_C.TRAIN.OPTIMIZER = 'sgd'
_C.TRAIN.OPTIMIZER_ARGS = CN(new_allowed=True)
_C.TRAIN.MOMENTUM = 0.9
_C.TRAIN.WD = 0.0001
_C.TRAIN.WITHOUT_WD_LIST = []
_C.TRAIN.NESTEROV = True
# for adam
_C.TRAIN.GAMMA1 = 0.99
_C.TRAIN.GAMMA2 = 0.0

_C.TRAIN.BEGIN_EPOCH = 0
_C.TRAIN.END_EPOCH = 100

_C.TRAIN.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TRAIN.BATCH_SIZE_PER_GPU = 32
_C.TRAIN.SHUFFLE = True

_C.TRAIN.EVAL_BEGIN_EPOCH = 0

_C.TRAIN.DETECT_ANOMALY = False

_C.TRAIN.CLIP_GRAD_NORM = 0.0
_C.TRAIN.SAVE_ALL_MODELS = False

# testing
_C.TEST = CN()

# size of images for each device
_C.TEST.BATCH_SIZE_PER_GPU = 32
_C.TEST.CENTER_CROP = True
_C.TEST.IMAGE_SIZE = [224, 224]  # width * height, ex: 192 * 256
_C.TEST.INTERPOLATION = 2
_C.TEST.MODEL_FILE = ''
_C.TEST.REAL_LABELS = False
_C.TEST.VALID_LABELS = ''

_C.FINETUNE = CN()
_C.FINETUNE.FINETUNE = False
_C.FINETUNE.USE_TRAIN_AUG = False
_C.FINETUNE.BASE_LR = 0.003
_C.FINETUNE.BATCH_SIZE = 512
_C.FINETUNE.EVAL_EVERY = 3000
_C.FINETUNE.TRAIN_MODE = True
# _C.FINETUNE.MODEL_FILE = ''
_C.FINETUNE.FROZEN_LAYERS = []
_C.FINETUNE.LR_SCHEDULER = CN(new_allowed=True)
_C.FINETUNE.LR_SCHEDULER.DECAY_TYPE = 'step'

# debug
_C.DEBUG = CN()
_C.DEBUG.DEBUG = False


def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)

    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, op.join(op.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()


def update_config(config, args):
    _update_config_from_file(config, args.cfg)

    config.defrost()
    config.merge_from_list(args.opts)
    if config.TRAIN.SCALE_LR:
        config.TRAIN.LR *= comm.world_size
    file_name, _ = op.splitext(op.basename(args.cfg))
    config.NAME = file_name + config.NAME
    config.RANK = comm.rank

    if 'timm' == config.TRAIN.LR_SCHEDULER.METHOD:
        config.TRAIN.LR_SCHEDULER.ARGS.epochs = config.TRAIN.END_EPOCH

    if 'timm' == config.TRAIN.OPTIMIZER:
        config.TRAIN.OPTIMIZER_ARGS.lr = config.TRAIN.LR

    aug = config.AUG
    if aug.MIXUP > 0.0 or aug.MIXCUT > 0.0 or aug.MIXCUT_MINMAX:
        aug.MIXUP_PROB = 1.0
    config.freeze()


def save_config(cfg, path):
    if comm.is_main_process():
        with open(path, 'w') as f:
            f.write(cfg.dump())


if __name__ == '__main__':
    import sys
    with open(sys.argv[1], 'w') as f:
        print(_C, file=f)
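```

A minimal sketch of how these helpers are typically wired together (an assumed argparse-based entry point; no training script ships with this commit):

```python
import argparse
from config import config, update_config, save_config

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default='config/cvt-13-224x224.yaml')
parser.add_argument('opts', nargs=argparse.REMAINDER, default=[],
                    help="config overrides, e.g. TRAIN.LR 0.001")
args = parser.parse_args()

update_config(config, args)         # merge the YAML file, then any CLI overrides
print(config.MODEL.SPEC.DIM_EMBED)  # [64, 192, 384]
save_config(config, 'config_dump.yaml')
```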
cvt.py
ADDED
@@ -0,0 +1,694 @@
```python
from functools import partial
from itertools import repeat
#from torch._six import container_abcs
import collections.abc as container_abcs

import logging
import os
from collections import OrderedDict

import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import DropPath, trunc_normal_

#from .registry import register_model
from config import config

from torchinfo import summary
import yaml

_model_entrypoints = {}


def register_model(fn):
    module_name_split = fn.__module__.split('.')
    model_name = module_name_split[-1]

    _model_entrypoints[model_name] = fn

    return fn


def model_entrypoints(model_name):
    return _model_entrypoints[model_name]


def is_model(model_name):
    return model_name in _model_entrypoints


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class Mlp(nn.Module):
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                # typo 'rearrage' fixed; Rearrange has no parameters, so checkpoints still load
                ('rearrange', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrange', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)
        else:
            q = k = v = x  # fixed: q, k, v were undefined when no conv projection is used

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    @staticmethod
    def compute_macs(module, input, output):
        # T: num_token
        # S: num_token
        input = input[0]
        flops = 0

        _, T, C = input.shape
        H = W = int(np.sqrt(T-1)) if module.with_cls_token else int(np.sqrt(T))

        H_Q = H / module.stride_q
        W_Q = W / module.stride_q  # fixed: was H / module.stride_q
        T_Q = H_Q * W_Q + 1 if module.with_cls_token else H_Q * W_Q

        H_KV = H / module.stride_kv
        W_KV = W / module.stride_kv
        T_KV = H_KV * W_KV + 1 if module.with_cls_token else H_KV * W_KV

        # C = module.dim
        # S = T
        # Scaled-dot-product macs
        # [B x T x C] x [B x C x T] --> [B x T x S]
        # multiplication-addition is counted as 1 because operations can be fused
        flops += T_Q * T_KV * module.dim
        # [B x T x S] x [B x S x C] --> [B x T x C]
        flops += T_Q * module.dim * T_KV

        if (
            hasattr(module, 'conv_proj_q')
            and hasattr(module.conv_proj_q, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_q.conv.parameters()
                ]
            )
            flops += params * H_Q * W_Q

        if (
            hasattr(module, 'conv_proj_k')
            and hasattr(module.conv_proj_k, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_k.conv.parameters()
                ]
            )
            flops += params * H_KV * W_KV

        if (
            hasattr(module, 'conv_proj_v')
            and hasattr(module.conv_proj_v, 'conv')
        ):
            params = sum(
                [
                    p.numel()
                    for p in module.conv_proj_v.conv.parameters()
                ]
            )
            flops += params * H_KV * W_KV

        params = sum([p.numel() for p in module.proj_q.parameters()])
        flops += params * T_Q
        params = sum([p.numel() for p in module.proj_k.parameters()])
        flops += params * T_KV
        params = sum([p.numel() for p in module.proj_v.parameters()])
        flops += params * T_KV
        params = sum([p.numel() for p in module.proj.parameters()])
        flops += params * T

        module.__flops__ += flops


class Block(nn.Module):

    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 **kwargs):
        super().__init__()

        self.with_cls_token = kwargs['with_cls_token']

        self.norm1 = norm_layer(dim_in)
        self.attn = Attention(
            dim_in, dim_out, num_heads, qkv_bias, attn_drop, drop,
            **kwargs
        )

        self.drop_path = DropPath(drop_path) \
            if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim_out)

        dim_mlp_hidden = int(dim_out * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim_out,
            hidden_features=dim_mlp_hidden,
            act_layer=act_layer,
            drop=drop
        )

    def forward(self, x, h, w):
        res = x

        x = self.norm1(x)
        attn = self.attn(x, h, w)
        x = res + self.drop_path(attn)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class ConvEmbed(nn.Module):
    """ Image to Conv Embedding
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=1,  # 1 for spectrogram, 3 for RGB image
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self,
                 patch_size=16,
                 patch_stride=16,
                 patch_padding=0,
                 in_chans=1,  # 1 for spectrogram, 3 for RGB
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.rearrange = None

        self.patch_embed = ConvEmbed(
            # img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            stride=patch_stride,
            padding=patch_padding,
            embed_dim=embed_dim,
            norm_layer=norm_layer
        )

        with_cls_token = kwargs['with_cls_token']
        if with_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros(1, 1, embed_dim)
            )
        else:
            self.cls_token = None

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        blocks = []
        for j in range(depth):
            blocks.append(
                Block(
                    dim_in=embed_dim,
                    dim_out=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    **kwargs
                )
            )
        self.blocks = nn.ModuleList(blocks)

        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

        if init == 'xavier':
            self.apply(self._init_weights_xavier)
        else:
            self.apply(self._init_weights_trunc_normal)

    def _init_weights_trunc_normal(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from trunc norm')
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_weights_xavier(self, m):
        if isinstance(m, nn.Linear):
            logging.info('=> init weight of Linear from xavier uniform')
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                logging.info('=> init bias of Linear to zeros')
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.size()

        x = rearrange(x, 'b c h w -> b (h w) c')

        cls_tokens = None
        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, H, W)

        if self.cls_token is not None:
            cls_tokens, x = torch.split(x, [1, H*W], 1)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x, cls_tokens


class ConvolutionalVisionTransformer(nn.Module):
    def __init__(self,
                 in_chans=1,  # 3 for RGB, 1 for spectrogram
                 num_classes=1000,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 init='trunc_norm',
                 spec=None):
        super().__init__()
        self.num_classes = num_classes

        self.num_stages = spec['NUM_STAGES']
        for i in range(self.num_stages):
            kwargs = {
                'patch_size': spec['PATCH_SIZE'][i],
                'patch_stride': spec['PATCH_STRIDE'][i],
                'patch_padding': spec['PATCH_PADDING'][i],
                'embed_dim': spec['DIM_EMBED'][i],
                'depth': spec['DEPTH'][i],
                'num_heads': spec['NUM_HEADS'][i],
                'mlp_ratio': spec['MLP_RATIO'][i],
                'qkv_bias': spec['QKV_BIAS'][i],
                'drop_rate': spec['DROP_RATE'][i],
                'attn_drop_rate': spec['ATTN_DROP_RATE'][i],
                'drop_path_rate': spec['DROP_PATH_RATE'][i],
                'with_cls_token': spec['CLS_TOKEN'][i],
                'method': spec['QKV_PROJ_METHOD'][i],
                'kernel_size': spec['KERNEL_QKV'][i],
                'padding_q': spec['PADDING_Q'][i],
                'padding_kv': spec['PADDING_KV'][i],
                'stride_kv': spec['STRIDE_KV'][i],
                'stride_q': spec['STRIDE_Q'][i],
            }

            stage = VisionTransformer(
                in_chans=in_chans,
                init=init,
                act_layer=act_layer,
                norm_layer=norm_layer,
                **kwargs
            )
            setattr(self, f'stage{i}', stage)

            in_chans = spec['DIM_EMBED'][i]

        dim_embed = spec['DIM_EMBED'][-1]
        self.norm = norm_layer(dim_embed)
        self.cls_token = spec['CLS_TOKEN'][-1]

        # Classifier head (disabled: the model is used as a feature extractor)
        # self.head = nn.Linear(dim_embed, num_classes) if num_classes > 0 else nn.Identity()
        # trunc_normal_(self.head.weight, std=0.02)
        self.head = nn.Identity()

    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')
                    if 'pos_embed' in k and v.size() != model_dict[k].size():
                        size_pretrained = v.size()
                        size_new = model_dict[k].size()
                        logging.info(
                            '=> load_pretrained: resized variant: {} to {}'
                            .format(size_pretrained, size_new)
                        )

                        ntok_new = size_new[1]
                        ntok_new -= 1

                        posemb_tok, posemb_grid = v[:, :1], v[0, 1:]

                        gs_old = int(np.sqrt(len(posemb_grid)))
                        gs_new = int(np.sqrt(ntok_new))

                        logging.info(
                            '=> load_pretrained: grid-size from {} to {}'
                            .format(gs_old, gs_new)
                        )

                        posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                        posemb_grid = scipy.ndimage.zoom(
                            posemb_grid, zoom, order=1
                        )
                        posemb_grid = posemb_grid.reshape(1, gs_new ** 2, -1)
                        v = torch.tensor(
                            np.concatenate([posemb_tok, posemb_grid], axis=1)
                        )

                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)

    @torch.jit.ignore
    def no_weight_decay(self):
        layers = set()
        for i in range(self.num_stages):
            layers.add(f'stage{i}.pos_embed')
            layers.add(f'stage{i}.cls_token')

        return layers

    def forward_features(self, x):
        for i in range(self.num_stages):
            x, cls_tokens = getattr(self, f'stage{i}')(x)

        if self.cls_token:
            x = self.norm(cls_tokens)
            # x = cls_tokens
            x = torch.squeeze(x, dim=1)  # fixed: a bare squeeze() would also drop a batch dimension of size 1
        else:
            x = rearrange(x, 'b c h w -> b (h w) c')
            x = self.norm(x)
            x = torch.mean(x, dim=1)

        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x


@register_model
def get_cls_model(config, **kwargs):  # fixed: build_model passes `config` positionally
    msvit_spec = config.MODEL.SPEC
    msvit = ConvolutionalVisionTransformer(
        in_chans=1,  # 1 for spectrogram, 3 for RGB
        num_classes=config.MODEL.NUM_CLASSES,
        act_layer=QuickGELU,
        norm_layer=partial(LayerNorm, eps=1e-5),
        init=getattr(msvit_spec, 'INIT', 'trunc_norm'),
        spec=msvit_spec
    )

    # if config.MODEL.INIT_WEIGHTS:
    #     msvit.init_weights(
    #         config.MODEL.PRETRAINED,
    #         config.MODEL.PRETRAINED_LAYERS,
    #         config.VERBOSE
    #     )

    return msvit


def build_model(config, **kwargs):
    model_name = config.MODEL.NAME
    if not is_model(model_name):
        raise ValueError(f'Unknown model: {model_name}')

    return model_entrypoints(model_name)(config, **kwargs)


def cvt13(**kwargs):
    with open('config/cvt-13-224x224.yaml', 'r') as f:
        config = yaml.safe_load(f)
    return ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC'])  # only loads the config, no pretrained weights


if __name__ == '__main__':
    with open('config/cvt-13-224x224.yaml', 'r') as f:
        config = yaml.safe_load(f)
    model = ConvolutionalVisionTransformer(spec=config['MODEL']['SPEC'])
    print(summary(model))
    quit()  # remove to also run the summary below with a full forward pass
    print(summary(model, input_size=(4, 1, 128, 301)))
```
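A quick shape sanity check, as a sketch (the input size matches the `__main__` block above; the 384-dimensional output follows from the last stage's `DIM_EMBED` and the `nn.Identity()` head):

```python
import torch
from cvt import cvt13

model = cvt13().eval()
x = torch.randn(4, 1, 128, 301)  # 4 Mel spectrograms: 128 mel bins x 301 frames (6 s at 16 kHz)
with torch.no_grad():
    feats = model(x)
print(feats.shape)  # expected: torch.Size([4, 384])
```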
melspectrogram.py
ADDED
@@ -0,0 +1,40 @@
```python
from torchaudio import transforms as T
import torch
import torch.nn as nn

MEAN, STD = 0.5347, 0.0772  # Xeno-Canto stats
SR = 16000
NFFT = 1024
HOPLEN = 320
NMELS = 128
FMIN = 50
FMAX = 8000

class Normalization(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return (x - x.min()) / (x.max() - x.min())

class Standardization(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()

        self.mean = mean
        self.std = std

    def forward(self, x):
        return (x - self.mean) / self.std

class MelSpectrogramProcessor:
    def __init__(self, sample_rate=SR, n_mels=NMELS, n_fft=NFFT, hop_length=HOPLEN, f_min=FMIN, f_max=FMAX):
        self.transform = nn.Sequential(
            T.MelSpectrogram(sample_rate=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, f_min=f_min, f_max=f_max),
            T.AmplitudeToDB(),
            Normalization(),
            Standardization(mean=MEAN, std=STD),
        )

    def process(self, waveform):
        return self.transform(waveform)
```
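A small shape check for the processor (a sketch; the 301-frame count follows from the hop length of 320 with torchaudio's default centered framing):

```python
import torch
from melspectrogram import MelSpectrogramProcessor

proc = MelSpectrogramProcessor()
waveform = torch.randn(1, 6 * 16000)  # 6 s of mono audio at 16 kHz
spec = proc.process(waveform)
print(spec.shape)  # expected: torch.Size([1, 128, 301]) = (channel, mel bins, frames)
```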
protoclr.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b7eaf62e2f66084f50cbb9677420a3a93b81334fadddeb4eb9790e734f4a514f
size 78717724
requirements.txt
ADDED
@@ -0,0 +1,9 @@
einops==0.8.0
numpy==2.1.3
PyYAML==6.0.2
scipy==1.14.1
timm==1.0.11
torch==2.2.2
torchaudio==2.2.2
torchinfo==1.8.0
yacs==0.1.8