xiziwang committed on
Commit
2e36228
1 Parent(s): a9c14c0

push files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. README.md +61 -0
  2. __pycache__/dataLoader_multiperson.cpython-37.pyc +0 -0
  3. __pycache__/loconet.cpython-37.pyc +0 -0
  4. __pycache__/loss_multi.cpython-37.pyc +0 -0
  5. __pycache__/talkNet_config_multi.cpython-37.pyc +0 -0
  6. builder.py +95 -0
  7. configs/multi.yaml +51 -0
  8. dataLoaderTalkSet.py +182 -0
  9. dataLoader_multiperson.py +402 -0
  10. dlhammer/.gitignore +3 -0
  11. dlhammer/LICENSE +201 -0
  12. dlhammer/README.md +2 -0
  13. dlhammer/dlhammer/.ipynb_checkpoints/argparser-checkpoint.py +110 -0
  14. dlhammer/dlhammer/.ipynb_checkpoints/bootstrap-checkpoint.py +33 -0
  15. dlhammer/dlhammer/__init__.py +1 -0
  16. dlhammer/dlhammer/argparser.py +109 -0
  17. dlhammer/dlhammer/bootstrap.py +33 -0
  18. dlhammer/dlhammer/logger.py +66 -0
  19. dlhammer/dlhammer/test/config.yml +32 -0
  20. dlhammer/dlhammer/test/test_args.py +20 -0
  21. dlhammer/dlhammer/test/test_logger.py +22 -0
  22. dlhammer/dlhammer/utils/__init__.py +0 -0
  23. dlhammer/dlhammer/utils/misc.py +125 -0
  24. dlhammer/dlhammer/utils/system.py +25 -0
  25. environment.yml +298 -0
  26. legacy/talkNet_multi_multicard.py +124 -0
  27. legacy/talkNet_multicard.py +146 -0
  28. legacy/talkNet_orig.py +102 -0
  29. legacy/trainTalkNet_multicard.py +171 -0
  30. legacy/train_multi.py +156 -0
  31. loconet.py +182 -0
  32. loss_multi.py +72 -0
  33. metrics/AverageMeter.py +18 -0
  34. metrics/__pycache__/.nfs000000035f4a8257000000eb +0 -0
  35. metrics/__pycache__/AverageMeter.cpython-36.pyc +0 -0
  36. metrics/__pycache__/AverageMeter.cpython-38.pyc +0 -0
  37. metrics/__pycache__/accuracy.cpython-36.pyc +0 -0
  38. metrics/__pycache__/accuracy.cpython-38.pyc +0 -0
  39. metrics/accuracy.py +20 -0
  40. model/.DS_Store +0 -0
  41. model/__init__.py +5 -0
  42. model/__pycache__/__init__.cpython-36.pyc +0 -0
  43. model/__pycache__/__init__.cpython-37.pyc +0 -0
  44. model/__pycache__/attentionLayer.cpython-37.pyc +0 -0
  45. model/__pycache__/convLayer.cpython-37.pyc +0 -0
  46. model/__pycache__/loconet_encoder.cpython-37.pyc +0 -0
  47. model/__pycache__/position_encoding.cpython-36.pyc +0 -0
  48. model/__pycache__/talkNetModel.cpython-37.pyc +0 -0
  49. model/__pycache__/transformer.cpython-36.pyc +0 -0
  50. model/__pycache__/utils.cpython-36.pyc +0 -0
README.md ADDED
@@ -0,0 +1,61 @@
1
+ ## LoCoNet: Long-Short Context Network for Active Speaker Detection
2
+
3
+
4
+
5
+ ### Dependencies
6
+
7
+ Start by building the environment:
8
+ ```
9
+ conda env create -f requirements.yml
10
+ conda activate loconet
11
+ ```
12
+ export PYTHONPATH=**project_dir**/dlhammer:$PYTHONPATH
13
+ Replace **project_dir** with the location of your code base.
14
+
15
+
16
+
17
+ ### Data preparation
18
+
19
+ We follow TalkNet's data preparation script to download and prepare the AVA dataset.
20
+
21
+ ```
22
+ python train.py --dataPathAVA AVADataPath --download
23
+ ```
24
+
25
+ `AVADataPath` is the folder where the AVA dataset and its preprocessing outputs will be saved; the details can be found [here](https://github.com/TaoRuijie/TalkNet_ASD/blob/main/utils/tools.py#L34). Please read them carefully.
26
+
27
+ After the AVA dataset is downloaded, please update the `DATA.dataPathAVA` entry in the config file (`configs/multi.yaml`), as shown below.
28
+
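+ For reference, the relevant entry in `configs/multi.yaml` looks like this (the path below is only an example and should point to your own copy of the dataset):
+ ```
+ DATA:
+   dataPathAVA: /path/to/AVA_dataset/
+ ```
+ 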
29
+ #### Training script
30
+ ```
31
+ python -W ignore::UserWarning train.py --cfg configs/multi.yaml OUTPUT_DIR <output directory>
32
+ ```
33
+
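+ For example, to train with the default config and write checkpoints and logs to a local directory (the path below is only an illustration):
+ ```
+ python -W ignore::UserWarning train.py --cfg configs/multi.yaml OUTPUT_DIR ./exps/loconet_ava
+ ```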
34
+
35
+
36
+ #### Pretrained model
37
+
38
+ Please download the LoCoNet trained weights on AVA dataset [here](https://drive.google.com/file/d/1EX-V464jCD6S-wg68yGuAa-UcsMrw8mK/view?usp=sharing).
39
+
40
+ ```
41
+ python -W ignore::UserWarning test_multicard.py --cfg configs/multi.yaml RESUME_PATH {model download path}
42
+ ```
43
+
44
+ ### Citation
45
+
46
+ Please cite the following if our paper or code is helpful to your research.
47
+ ```
48
+ @article{wang2023loconet,
49
+ title={LoCoNet: Long-Short Context Network for Active Speaker Detection},
50
+ author={Wang, Xizi and Cheng, Feng and Bertasius, Gedas and Crandall, David},
51
+ journal={arXiv preprint arXiv:2301.08237},
52
+ year={2023}
53
+ }
54
+ ```
55
+
56
+
57
+ ### Acknowledgement
58
+
59
+ The code base of this project is adapted from [TalkNet](https://github.com/TaoRuijie/TalkNet-ASD), a very easy-to-use ASD pipeline.
60
+
61
+
__pycache__/dataLoader_multiperson.cpython-37.pyc ADDED
Binary file (10.8 kB).
 
__pycache__/loconet.cpython-37.pyc ADDED
Binary file (6.26 kB).
 
__pycache__/loss_multi.cpython-37.pyc ADDED
Binary file (2.61 kB).
 
__pycache__/talkNet_config_multi.cpython-37.pyc ADDED
Binary file (6.59 kB).
 
builder.py ADDED
@@ -0,0 +1,95 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import warnings
11
+
12
+ from mmcv.cnn import MODELS as MMCV_MODELS
13
+ from mmcv.utils import Registry
14
+
15
+ from mmaction.utils import import_module_error_func
16
+
17
+ MODELS = Registry('models', parent=MMCV_MODELS)
18
+ BACKBONES = MODELS
19
+ NECKS = MODELS
20
+ HEADS = MODELS
21
+ RECOGNIZERS = MODELS
22
+ LOSSES = MODELS
23
+ LOCALIZERS = MODELS
24
+
25
+ try:
26
+ from mmdet.models.builder import DETECTORS, build_detector
27
+ except (ImportError, ModuleNotFoundError):
28
+ # Define a fallback registry and build function so that imports still work when mmdet is not installed
29
+ DETECTORS = MODELS
30
+
31
+ @import_module_error_func('mmdet')
32
+ def build_detector(cfg, train_cfg, test_cfg):
33
+ pass
34
+
35
+
36
+ def build_backbone(cfg):
37
+ """Build backbone."""
38
+ return BACKBONES.build(cfg)
39
+
40
+
41
+ def build_head(cfg):
42
+ """Build head."""
43
+ return HEADS.build(cfg)
44
+
45
+
46
+ def build_recognizer(cfg, train_cfg=None, test_cfg=None):
47
+ """Build recognizer."""
48
+ if train_cfg is not None or test_cfg is not None:
49
+ warnings.warn(
50
+ 'train_cfg and test_cfg is deprecated, '
51
+ 'please specify them in model. Details see this '
52
+ 'PR: https://github.com/open-mmlab/mmaction2/pull/629', UserWarning)
53
+ assert cfg.get(
54
+ 'train_cfg'
55
+ ) is None or train_cfg is None, 'train_cfg specified in both outer field and model field' # noqa: E501
56
+ assert cfg.get(
57
+ 'test_cfg'
58
+ ) is None or test_cfg is None, 'test_cfg specified in both outer field and model field ' # noqa: E501
59
+ return RECOGNIZERS.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
60
+
61
+
62
+ def build_loss(cfg):
63
+ """Build loss."""
64
+ return LOSSES.build(cfg)
65
+
66
+
67
+ def build_localizer(cfg):
68
+ """Build localizer."""
69
+ return LOCALIZERS.build(cfg)
70
+
71
+
72
+ def build_model(cfg, train_cfg=None, test_cfg=None):
73
+ """Build model."""
74
+ args = cfg.copy()
75
+ obj_type = args.pop('type')
76
+ if obj_type in LOCALIZERS:
77
+ return build_localizer(cfg)
78
+ if obj_type in RECOGNIZERS:
79
+ return build_recognizer(cfg, train_cfg, test_cfg)
80
+ if obj_type in DETECTORS:
81
+ if train_cfg is not None or test_cfg is not None:
82
+ warnings.warn(
83
+ 'train_cfg and test_cfg is deprecated, '
84
+ 'please specify them in model. Details see this '
85
+ 'PR: https://github.com/open-mmlab/mmaction2/pull/629', UserWarning)
86
+ return build_detector(cfg, train_cfg, test_cfg)
87
+ model_in_mmdet = ['FastRCNN']
88
+ if obj_type in model_in_mmdet:
89
+ raise ImportError('Please install mmdet for spatial temporal detection tasks.')
90
+ raise ValueError(f'{obj_type} is not registered in ' 'LOCALIZERS, RECOGNIZERS or DETECTORS')
91
+
92
+
93
+ def build_neck(cfg):
94
+ """Build neck."""
95
+ return NECKS.build(cfg)
configs/multi.yaml ADDED
@@ -0,0 +1,51 @@
1
+ SEED: "20210617"
2
+ NUM_GPUS: 4
3
+ NUM_WORKERS: 6
4
+ LOG_NAME: 'config.txt'
5
+ OUTPUT_DIR: '/nfs/joltik/data/ssd/xiziwang/TalkNet_models/' # savePath
6
+ evalDataType: "val"
7
+ downloadAVA: False
8
+ evaluation: False
9
+ RESUME: False
10
+ RESUME_PATH: ""
11
+ RESUME_EPOCH: 0
12
+
13
+ DATA:
14
+ dataPathAVA: '/nfs/jolteon/data/ssd/xiziwang/AVA_dataset/'
15
+
16
+ DATALOADER:
17
+ nDataLoaderThread: 4
18
+
19
+
20
+ SOLVER:
21
+ OPTIMIZER: "adam"
22
+ BASE_LR: 5e-5
23
+ SCHEDULER:
24
+ NAME: "multistep"
25
+ GAMMA: 0.95
26
+
27
+ MODEL:
28
+ NUM_SPEAKERS: 3
29
+ CLIP_LENGTH: 200
30
+ AV: "speaker_temporal"
31
+ AV_layers: 3
32
+ ADJUST_ATTENTION: 0
33
+
34
+ TRAIN:
35
+ BATCH_SIZE: 1
36
+ MAX_EPOCH: 25
37
+ AUDIO_AUG: 1
38
+ TEST_INTERVAL: 1
39
+ TRAINER_GPU: 4
40
+
41
+
42
+ VAL:
43
+ BATCH_SIZE: 1
44
+
45
+ TEST:
46
+ BATCH_SIZE: 1
47
+ DATASET: 'seen'
48
+ MODEL: 'unseen'
49
+
50
+
51
+
dataLoaderTalkSet.py ADDED
@@ -0,0 +1,182 @@
1
+ import os, torch, numpy, cv2, imageio, random, python_speech_features
2
+ import matplotlib.pyplot as plt
3
+ from scipy.io import wavfile
4
+ from glob import glob
5
+ from torchvision.transforms import RandomCrop
6
+ from scipy import signal
7
+
8
+ def get_noise_list(musanPath, rirPath):
9
+ augment_files = glob(os.path.join(musanPath, '*/*/*/*.wav'))
10
+ noiselist = {}
11
+ rir = numpy.load(rirPath)
12
+ for file in augment_files:
13
+ if not file.split('/')[-4] in noiselist:
14
+ noiselist[file.split('/')[-4]] = []
15
+ noiselist[file.split('/')[-4]].append(file)
16
+ return rir, noiselist
17
+
18
+ def augment_wav(audio, aug_type, rir, noiselist):
19
+ if aug_type == 'rir':
20
+ rir_gains = numpy.random.uniform(-7,3,1)
21
+ rir_filts = random.choice(rir)
22
+ rir = numpy.multiply(rir_filts, pow(10, 0.1 * rir_gains))
23
+ audio = signal.convolve(audio, rir, mode='full')[:len(audio)]
24
+ else:
25
+ noisecat = aug_type
26
+ noisefile = random.choice(noiselist[noisecat].copy())
27
+ snr = [random.uniform({'noise':[0,15],'music':[5,15]}[noisecat][0], {'noise':[0,15],'music':[5,15]}[noisecat][1])]
28
+ _, noiseaudio = wavfile.read(noisefile)
29
+ if len(noiseaudio) < len(audio):
30
+ shortage = len(audio) - len(noiseaudio)
31
+ noiseaudio = numpy.pad(noiseaudio, (0, shortage), 'wrap')
32
+ else:
33
+ noiseaudio = noiseaudio[:len(audio)]
34
+
35
+ noise_db = 10 * numpy.log10(numpy.mean(abs(noiseaudio ** 2)) + 1e-4)
36
+ clean_db = 10 * numpy.log10(numpy.mean(abs(audio ** 2)) + 1e-4)
37
+ noise = numpy.sqrt(10 ** ((clean_db - noise_db - snr) / 10)) * noiseaudio
38
+ audio = audio + noise
39
+ return audio.astype(numpy.int16)
40
+
41
+ def load_audio(data, data_path, length, start, end, audio_aug, rirlist = None, noiselist = None):
42
+ # Find the path of the audio data
43
+ data_type = data[0]
44
+ id_name = data[1][:8]
45
+ file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
46
+ '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.wav'
47
+ audio_file_path = os.path.join(data_path, data_type, id_name, file_name)
48
+ # Load audio, compute MFCC, cut it to the required length
49
+ _, audio = wavfile.read(audio_file_path)
50
+
51
+ if audio_aug == True:
52
+ augtype = random.randint(0,3)
53
+ if augtype == 1: # rir
54
+ audio = augment_wav(audio, 'rir', rirlist, noiselist)
55
+ elif augtype == 2:
56
+ audio = augment_wav(audio, 'noise', rirlist, noiselist)
57
+ elif augtype == 3:
58
+ audio = augment_wav(audio, 'music', rirlist, noiselist)
59
+ else:
60
+ audio = audio
61
+
62
+ feature = python_speech_features.mfcc(audio, 16000, numcep = 13, winlen = 0.025, winstep = 0.010)
63
+ length_audio = int(round(length * 100))
64
+ if feature.shape[0] < length_audio:
65
+ shortage = length_audio - feature.shape[0]
66
+ feature = numpy.pad(feature, ((0, shortage), (0,0)), 'wrap')
67
+ feature = feature[int(round(start * 100)):int(round(end * 100)),:]
68
+ return feature
69
+
70
+ def load_video(data, data_path, length, start, end, visual_aug):
71
+ # Find the path of the visual data
72
+ data_type = data[0]
73
+ id_name = data[1][:8]
74
+ file_name = data[1].split('/')[0] + '_' + data[1].split('/')[1] + '_' + data[1].split('/')[2] + \
75
+ '_' + data[2].split('/')[0] + '_' + data[2].split('/')[1] + '_' + data[2].split('/')[2] + '.mp4'
76
+ video_file_path = os.path.join(data_path, data_type, id_name, file_name)
77
+ # Load visual frame-by-frame, cut it to the required length
78
+ length_video = int(round((end - start) * 25))
79
+ video = cv2.VideoCapture(video_file_path)
80
+ faces = []
81
+ augtype = 'orig'
82
+
83
+ if visual_aug == True:
84
+ new = int(112*random.uniform(0.7, 1))
85
+ x, y = numpy.random.randint(0, 112 - new), numpy.random.randint(0, 112 - new)
86
+ M = cv2.getRotationMatrix2D((112/2,112/2), random.uniform(-15, 15), 1)
87
+ augtype = random.choice(['orig', 'flip', 'crop', 'rotate'])
88
+
89
+ num_frame = 0
90
+ while video.isOpened():
91
+ ret, frames = video.read()
92
+ if ret == True:
93
+ num_frame += 1
94
+ if num_frame >= int(round(start * 25)) and num_frame < int(round(end * 25)):
95
+ face = cv2.cvtColor(frames, cv2.COLOR_BGR2GRAY)
96
+ face = cv2.resize(face, (224,224))
97
+ face = face[int(112-(112/2)):int(112+(112/2)), int(112-(112/2)):int(112+(112/2))]
98
+ if augtype == 'orig':
99
+ faces.append(face)
100
+ elif augtype == 'flip':
101
+ faces.append(cv2.flip(face, 1))
102
+ elif augtype == 'crop':
103
+ faces.append(cv2.resize(face[y:y+new, x:x+new] , (112,112)))
104
+ elif augtype == 'rotate':
105
+ faces.append(cv2.warpAffine(face, M, (112,112)))
106
+ else:
107
+ break
108
+ video.release()
109
+ faces = numpy.array(faces)
110
+ if faces.shape[0] < length_video:
111
+ shortage = length_video - faces.shape[0]
112
+ faces = numpy.pad(faces, ((0,shortage), (0,0),(0,0)), 'wrap')
113
+ # faces = numpy.array(faces)[int(round(start * 25)):int(round(end * 25)),:,:]
114
+ return faces
115
+
116
+ def load_label(data, length, start, end):
117
+ labels_all = []
118
+ labels = []
119
+ data_type = data[0]
120
+ start_T, end_T, start_F, end_F = float(data[4]), float(data[5]), float(data[6]), float(data[7])
121
+ for i in range(int(round(length * 100))):
122
+ if data_type == 'TAudio':
123
+ labels_all.append(1)
124
+ elif data_type == 'FAudio' or data_type == 'FSilence':
125
+ labels_all.append(0)
126
+ else:
127
+ if i >= int(round(start_T * 100)) and i <= int(round(end_T * 100)):
128
+ labels_all.append(1)
129
+ else:
130
+ labels_all.append(0)
131
+ for i in range(int(round(length * 25))):
132
+ labels.append(int(round(sum(labels_all[i*4: (i+1)*4]) / 4)))
133
+ return labels[round(start*25): round(end*25)]
134
+
135
+ class loader_TalkSet(object):
136
+ def __init__(self, trial_file_name, data_path, audio_aug, visual_aug, musanPath, rirPath,**kwargs):
137
+ self.data_path = data_path
138
+ self.audio_aug = audio_aug
139
+ self.visual_aug = visual_aug
140
+ self.minibatch = []
141
+ self.rir, self.noiselist = get_noise_list(musanPath, rirPath)
142
+ mix_lst = open(trial_file_name).read().splitlines()
143
+ mix_lst = list(filter(lambda x: float(x.split()[3]) >= 1, mix_lst)) # filter the video less than 1s
144
+ # mix_lst = list(filter(lambda x: x.split()[0] == 'TSilence', mix_lst))
145
+ sorted_mix_lst = sorted(mix_lst, key=lambda data: (float(data.split()[3]), int(data.split()[-1])), reverse=True)
146
+ start = 0
147
+ while True:
148
+ length_total = float(sorted_mix_lst[start].split()[3])
149
+ batch_size = int(250 / length_total)
150
+ end = min(len(sorted_mix_lst), start + batch_size)
151
+ self.minibatch.append(sorted_mix_lst[start:end])
152
+ if end == len(sorted_mix_lst):
153
+ break
154
+ start = end
155
+ # self.minibatch = self.minibatch[0:5]
156
+
157
+ def __getitem__(self, index):
158
+ batch_lst = self.minibatch[index]
159
+ length_total = float(batch_lst[-1].split()[3])
160
+ length_total = (int(round(length_total * 100)) - int(round(length_total * 100)) % 4) / 100
161
+ audio_feature, video_feature, labels = [], [], []
162
+ duration = random.choice([1,2,4,6])
163
+ #duration = 6
164
+ length = min(length_total, duration)
165
+ if length == duration:
166
+ start = int(round(random.randint(0, round(length_total * 25) - round(length * 25)) * 0.04 * 100)) / 100
167
+ end = int(round((start + length) * 100)) / 100
168
+ else:
169
+ start, end = 0, length
170
+
171
+ for line in batch_lst:
172
+ data = line.split()
173
+ audio_feature.append(load_audio(data, self.data_path, length_total, start, end, audio_aug = self.audio_aug, rirlist = self.rir, noiselist = self.noiselist))
174
+ video_feature.append(load_video(data, self.data_path, length_total, start, end, visual_aug = self.visual_aug))
175
+ labels.append(load_label(data, length_total, start, end))
176
+
177
+ return torch.FloatTensor(numpy.array(audio_feature)), \
178
+ torch.FloatTensor(numpy.array(video_feature)), \
179
+ torch.LongTensor(numpy.array(labels))
180
+
181
+ def __len__(self):
182
+ return len(self.minibatch)
dataLoader_multiperson.py ADDED
@@ -0,0 +1,402 @@
1
+ import os, torch, numpy, cv2, random, glob, python_speech_features, json, math
2
+ from scipy.io import wavfile
3
+ from torchvision.transforms import RandomCrop
4
+ from operator import itemgetter
5
+ from torchvggish import vggish_input, vggish_params, mel_features
6
+
7
+
8
+ def overlap(audio, noiseAudio):
9
+ snr = [random.uniform(-5, 5)]
10
+ if len(noiseAudio) < len(audio):
11
+ shortage = len(audio) - len(noiseAudio)
12
+ noiseAudio = numpy.pad(noiseAudio, (0, shortage), 'wrap')
13
+ else:
14
+ noiseAudio = noiseAudio[:len(audio)]
15
+ noiseDB = 10 * numpy.log10(numpy.mean(abs(noiseAudio**2)) + 1e-4)
16
+ cleanDB = 10 * numpy.log10(numpy.mean(abs(audio**2)) + 1e-4)
17
+ noiseAudio = numpy.sqrt(10**((cleanDB - noiseDB - snr) / 10)) * noiseAudio
18
+ audio = audio + noiseAudio
19
+ return audio.astype(numpy.int16)
20
+
21
+
22
+ def load_audio(data, dataPath, numFrames, audioAug, audioSet=None):
23
+ dataName = data[0]
24
+ fps = float(data[2])
25
+ audio = audioSet[dataName]
26
+ if audioAug == True:
27
+ augType = random.randint(0, 1)
28
+ if augType == 1:
29
+ # mix in another randomly chosen clip from audioSet (overlap takes two waveforms)
+ audio = overlap(audio, audioSet[random.choice(list(audioSet.keys()))])
30
+ else:
31
+ audio = audio
32
+ # fps is not always 25, in order to align the visual, we modify the window and step in MFCC extraction process based on fps
33
+ audio = python_speech_features.mfcc(audio,
34
+ 16000,
35
+ numcep=13,
36
+ winlen=0.025 * 25 / fps,
37
+ winstep=0.010 * 25 / fps)
38
+ maxAudio = int(numFrames * 4)
39
+ if audio.shape[0] < maxAudio:
40
+ shortage = maxAudio - audio.shape[0]
41
+ audio = numpy.pad(audio, ((0, shortage), (0, 0)), 'wrap')
42
+ audio = audio[:int(round(numFrames * 4)), :]
43
+ return audio
44
+
45
+
46
+ def load_single_audio(audio, fps, numFrames, audioAug=False):
47
+ audio = python_speech_features.mfcc(audio,
48
+ 16000,
49
+ numcep=13,
50
+ winlen=0.025 * 25 / fps,
51
+ winstep=0.010 * 25 / fps)
52
+ maxAudio = int(numFrames * 4)
53
+ if audio.shape[0] < maxAudio:
54
+ shortage = maxAudio - audio.shape[0]
55
+ audio = numpy.pad(audio, ((0, shortage), (0, 0)), 'wrap')
56
+ audio = audio[:int(round(numFrames * 4)), :]
57
+ return audio
58
+
59
+
60
+ def load_visual(data, dataPath, numFrames, visualAug):
61
+ dataName = data[0]
62
+ videoName = data[0][:11]
63
+ faceFolderPath = os.path.join(dataPath, videoName, dataName)
64
+ faceFiles = glob.glob("%s/*.jpg" % faceFolderPath)
65
+ sortedFaceFiles = sorted(faceFiles,
66
+ key=lambda data: (float(data.split('/')[-1][:-4])),
67
+ reverse=False)
68
+ faces = []
69
+ H = 112
70
+ if visualAug == True:
71
+ new = int(H * random.uniform(0.7, 1))
72
+ x, y = numpy.random.randint(0, H - new), numpy.random.randint(0, H - new)
73
+ M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
74
+ augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
75
+ else:
76
+ augType = 'orig'
77
+ for faceFile in sortedFaceFiles[:numFrames]:
78
+ face = cv2.imread(faceFile)
79
+
80
+ face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
81
+ face = cv2.resize(face, (H, H))
82
+ if augType == 'orig':
83
+ faces.append(face)
84
+ elif augType == 'flip':
85
+ faces.append(cv2.flip(face, 1))
86
+ elif augType == 'crop':
87
+ faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
88
+ elif augType == 'rotate':
89
+ faces.append(cv2.warpAffine(face, M, (H, H)))
90
+ faces = numpy.array(faces)
91
+ return faces
92
+
93
+
94
+ def load_label(data, numFrames):
95
+ res = []
96
+ labels = data[3].replace('[', '').replace(']', '')
97
+ labels = labels.split(',')
98
+ for label in labels:
99
+ res.append(int(label))
100
+ res = numpy.array(res[:numFrames])
101
+ return res
102
+
103
+
104
+ class train_loader(object):
105
+
106
+ def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
107
+ self.cfg = cfg
108
+ self.audioPath = audioPath
109
+ self.visualPath = visualPath
110
+ self.candidate_speakers = num_speakers
111
+ self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
112
+ self.entity_data = json.load(open(os.path.join(self.path, 'train_entity.json')))
113
+ self.ts_to_entity = json.load(open(os.path.join(self.path, 'train_ts.json')))
114
+ self.mixLst = open(trialFileName).read().splitlines()
115
+ self.list_length = len(self.mixLst)
116
+ random.shuffle(self.mixLst)
117
+
118
+ def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):
119
+ if audioAug:
120
+ augType = random.randint(0, 1)
121
+ if augType == 1:
122
+ audio = overlap(audio, aug_audio)
123
+ else:
124
+ audio = audio
125
+
126
+ res = vggish_input.waveform_to_examples(audio, 16000, numFrames, fps, return_tensor=False)
127
+ return res
128
+
129
+ def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts, visualAug=True):
130
+
131
+ faceFolderPath = os.path.join(self.visualPath, videoName, entityName)
132
+
133
+ faces = []
134
+ H = 112
135
+ if visualAug == True:
136
+ new = int(H * random.uniform(0.7, 1))
137
+ x, y = numpy.random.randint(0, H - new), numpy.random.randint(0, H - new)
138
+ M = cv2.getRotationMatrix2D((H / 2, H / 2), random.uniform(-15, 15), 1)
139
+ augType = random.choice(['orig', 'flip', 'crop', 'rotate'])
140
+ else:
141
+ augType = 'orig'
142
+ labels_dict = self.entity_data[videoName][entityName]
143
+ labels = numpy.zeros(len(target_ts))
144
+ mask = numpy.zeros(len(target_ts))
145
+
146
+ for i, time in enumerate(target_ts):
147
+ if time not in context_ts:
148
+ faces.append(numpy.zeros((H, H)))
149
+ else:
150
+ labels[i] = labels_dict[time]
151
+ mask[i] = 1
152
+ time = "%.2f" % float(time)
153
+ faceFile = os.path.join(faceFolderPath, str(time) + '.jpg')
154
+
155
+ face = cv2.imread(faceFile)
156
+
157
+ face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
158
+ face = cv2.resize(face, (H, H))
159
+ if augType == 'orig':
160
+ faces.append(face)
161
+ elif augType == 'flip':
162
+ faces.append(cv2.flip(face, 1))
163
+ elif augType == 'crop':
164
+ faces.append(cv2.resize(face[y:y + new, x:x + new], (H, H)))
165
+ elif augType == 'rotate':
166
+ faces.append(cv2.warpAffine(face, M, (H, H)))
167
+ faces = numpy.array(faces)
168
+ return faces, labels, mask
169
+
170
+ def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):
171
+
172
+ context_speakers = list(self.ts_to_entity[videoName][center_ts])
173
+ context = {}
174
+ chosen_speakers = []
175
+ context[target_entity] = all_ts
176
+ context_speakers.remove(target_entity)
177
+ num_frames = len(all_ts)
178
+ for candidate in context_speakers:
179
+ candidate_ts = self.entity_data[videoName][candidate]
180
+ shared_ts = set(all_ts).intersection(set(candidate_ts))
181
+ if (len(shared_ts) > (num_frames / 2)):
182
+ context[candidate] = shared_ts
183
+ chosen_speakers.append(candidate)
184
+ context_speakers = chosen_speakers
185
+ random.shuffle(context_speakers)
186
+ if not context_speakers:
187
+ context_speakers.insert(0, target_entity) # make sure is at 0
188
+ while len(context_speakers) < self.candidate_speakers:
189
+ context_speakers.append(random.choice(context_speakers))
190
+ elif len(context_speakers) < self.candidate_speakers:
191
+ context_speakers.insert(0, target_entity) # make sure is at 0
192
+ while len(context_speakers) < self.candidate_speakers:
193
+ context_speakers.append(random.choice(context_speakers[1:]))
194
+ else:
195
+ context_speakers.insert(0, target_entity) # make sure is at 0
196
+ context_speakers = context_speakers[:self.candidate_speakers]
197
+
198
+ assert set(context_speakers).issubset(set(list(context.keys()))), target_entity
199
+ assert target_entity in context_speakers, target_entity
200
+
201
+ return context_speakers, context
202
+
203
+ def __getitem__(self, index):
204
+
205
+ target_video = self.mixLst[index]
206
+ data = target_video.split('\t')
207
+ fps = float(data[2])
208
+ videoName = data[0][:11]
209
+ target_entity = data[0]
210
+ all_ts = list(self.entity_data[videoName][target_entity].keys())
211
+ numFrames = int(data[1])
212
+ assert numFrames == len(all_ts)
213
+
214
+ center_ts = all_ts[math.floor(numFrames / 2)]
215
+
216
+ # get context speakers which have more than half time overlapped with target speaker
217
+ context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
218
+ center_ts)
219
+
220
+ if self.cfg.TRAIN.AUDIO_AUG:
221
+ other_indices = list(range(0, index)) + list(range(index + 1, self.list_length))
222
+ augment_entity = self.mixLst[random.choice(other_indices)]
223
+ augment_data = augment_entity.split('\t')
224
+ augment_entity = augment_data[0]
225
+ augment_videoname = augment_data[0][:11]
226
+ aug_sr, aug_audio = wavfile.read(
227
+ os.path.join(self.audioPath, augment_videoname, augment_entity + '.wav'))
228
+ else:
229
+ aug_audio = None
230
+
231
+ audio_path = os.path.join(self.audioPath, videoName, target_entity + '.wav')
232
+ sr, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
233
+ audio = self.load_single_audio(audio,
234
+ fps,
235
+ numFrames,
236
+ audioAug=self.cfg.TRAIN.AUDIO_AUG,
237
+ aug_audio=aug_audio)
238
+
239
+ visualFeatures, labels, masks = [], [], []
240
+
241
+ # target_label = list(self.entity_data[videoName][target_entity].values())
242
+ visual, target_labels, target_masks = self.load_visual_label_mask(
243
+ videoName, target_entity, all_ts, all_ts)
244
+
245
+ for idx, context_entity in enumerate(context_speakers):
246
+ if context_entity == target_entity:
247
+ label = target_labels
248
+ visualfeat = visual
249
+ mask = target_masks
250
+ else:
251
+ visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
252
+ all_ts,
253
+ context[context_entity])
254
+ visualFeatures.append(visualfeat)
255
+ labels.append(label)
256
+ masks.append(mask)
257
+
258
+ audio = torch.FloatTensor(audio)[None, :, :]
259
+ visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
260
+ audio_t = audio.shape[1]
261
+ video_t = visualFeatures.shape[1]
262
+ if audio_t != video_t * 4:
263
+ print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
264
+ labels = torch.LongTensor(numpy.array(labels))
265
+ masks = torch.LongTensor(numpy.array(masks))
266
267
+ return audio, visualFeatures, labels, masks
268
+
269
+ def __len__(self):
270
+ return len(self.mixLst)
271
+
272
+
273
+ class val_loader(object):
274
+
275
+ def __init__(self, cfg, trialFileName, audioPath, visualPath, num_speakers):
276
+ self.cfg = cfg
277
+ self.audioPath = audioPath
278
+ self.visualPath = visualPath
279
+ self.candidate_speakers = num_speakers
280
+ self.path = os.path.join(cfg.DATA.dataPathAVA, "csv")
281
+ self.entity_data = json.load(open(os.path.join(self.path, 'val_entity.json')))
282
+ self.ts_to_entity = json.load(open(os.path.join(self.path, 'val_ts.json')))
283
+ self.mixLst = open(trialFileName).read().splitlines()
284
+
285
+ def load_single_audio(self, audio, fps, numFrames, audioAug=False, aug_audio=None):
286
+
287
+ res = vggish_input.waveform_to_examples(audio, 16000, numFrames, fps, return_tensor=False)
288
+ return res
289
+
290
+ def load_visual_label_mask(self, videoName, entityName, target_ts, context_ts):
291
+
292
+ faceFolderPath = os.path.join(self.visualPath, videoName, entityName)
293
+
294
+ faces = []
295
+ H = 112
296
+ labels_dict = self.entity_data[videoName][entityName]
297
+ labels = numpy.zeros(len(target_ts))
298
+ mask = numpy.zeros(len(target_ts))
299
+
300
+ for i, time in enumerate(target_ts):
301
+ if time not in context_ts:
302
+ faces.append(numpy.zeros((H, H)))
303
+ else:
304
+ labels[i] = labels_dict[time]
305
+ mask[i] = 1
306
+ time = "%.2f" % float(time)
307
+ faceFile = os.path.join(faceFolderPath, str(time) + '.jpg')
308
+
309
+ face = cv2.imread(faceFile)
310
+ face = cv2.cvtColor(face, cv2.COLOR_BGR2GRAY)
311
+ face = cv2.resize(face, (H, H))
312
+ faces.append(face)
313
+ faces = numpy.array(faces)
314
+ return faces, labels, mask
315
+
316
+ def get_speaker_context(self, videoName, target_entity, all_ts, center_ts):
317
+
318
+ context_speakers = list(self.ts_to_entity[videoName][center_ts])
319
+ context = {}
320
+ chosen_speakers = []
321
+ context[target_entity] = all_ts
322
+ context_speakers.remove(target_entity)
323
+ num_frames = len(all_ts)
324
+ for candidate in context_speakers:
325
+ candidate_ts = self.entity_data[videoName][candidate]
326
+ shared_ts = set(all_ts).intersection(set(candidate_ts))
327
+ context[candidate] = shared_ts
328
+ chosen_speakers.append(candidate)
329
+ # if (len(shared_ts) > (num_frames / 2)):
330
+ # context[candidate] = shared_ts
331
+ # chosen_speakers.append(candidate)
332
+ context_speakers = chosen_speakers
333
+ random.shuffle(context_speakers)
334
+ if not context_speakers:
335
+ context_speakers.insert(0, target_entity) # make sure is at 0
336
+ while len(context_speakers) < self.candidate_speakers:
337
+ context_speakers.append(random.choice(context_speakers))
338
+ elif len(context_speakers) < self.candidate_speakers:
339
+ context_speakers.insert(0, target_entity) # make sure is at 0
340
+ while len(context_speakers) < self.candidate_speakers:
341
+ context_speakers.append(random.choice(context_speakers[1:]))
342
+ else:
343
+ context_speakers.insert(0, target_entity) # make sure is at 0
344
+ context_speakers = context_speakers[:self.candidate_speakers]
345
+
346
+ assert set(context_speakers).issubset(set(list(context.keys()))), target_entity
347
+
348
+ return context_speakers, context
349
+
350
+ def __getitem__(self, index):
351
+
352
+ target_video = self.mixLst[index]
353
+ data = target_video.split('\t')
354
+ fps = float(data[2])
355
+ videoName = data[0][:11]
356
+ target_entity = data[0]
357
+ all_ts = list(self.entity_data[videoName][target_entity].keys())
358
+ numFrames = int(data[1])
359
+ # print(numFrames, len(all_ts))
360
+ assert numFrames == len(all_ts)
361
+
362
+ center_ts = all_ts[math.floor(numFrames / 2)]
363
+
364
+ # get context speakers which have more than half time overlapped with target speaker
365
+ context_speakers, context = self.get_speaker_context(videoName, target_entity, all_ts,
366
+ center_ts)
367
+
368
+ sr, audio = wavfile.read(os.path.join(self.audioPath, videoName, target_entity + '.wav'))
369
+ audio = self.load_single_audio(audio, fps, numFrames, audioAug=False)
370
+
371
+ visualFeatures, labels, masks = [], [], []
372
+
373
+ # target_label = list(self.entity_data[videoName][target_entity].values())
374
+ target_visual, target_labels, target_masks = self.load_visual_label_mask(
375
+ videoName, target_entity, all_ts, all_ts)
376
+
377
+ for idx, context_entity in enumerate(context_speakers):
378
+ if context_entity == target_entity:
379
+ label = target_labels
380
+ visualfeat = target_visual
381
+ mask = target_masks
382
+ else:
383
+ visualfeat, label, mask = self.load_visual_label_mask(videoName, context_entity,
384
+ all_ts,
385
+ context[context_entity])
386
+ visualFeatures.append(visualfeat)
387
+ labels.append(label)
388
+ masks.append(mask)
389
+
390
+ audio = torch.FloatTensor(audio)[None, :, :]
391
+ visualFeatures = torch.FloatTensor(numpy.array(visualFeatures))
392
+ audio_t = audio.shape[1]
393
+ video_t = visualFeatures.shape[1]
394
+ if audio_t != video_t * 4:
395
+ print(visualFeatures.shape, audio.shape, videoName, target_entity, numFrames)
396
+ labels = torch.LongTensor(numpy.array(labels))
397
+ masks = torch.LongTensor(numpy.array(masks))
398
+
399
+ return audio, visualFeatures, labels, masks
400
+
401
+ def __len__(self):
402
+ return len(self.mixLst)
dlhammer/.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ *.log
2
+ .vim-arsync
3
+ __pycache__/
dlhammer/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
dlhammer/README.md ADDED
@@ -0,0 +1,2 @@
1
+ # dl-hammer
2
+ tools for deep learning coding.
dlhammer/dlhammer/.ipynb_checkpoints/argparser-checkpoint.py ADDED
@@ -0,0 +1,110 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import os
11
+ import argparse
12
+ import datetime
13
+ from functools import partial
14
+ import yaml
15
+ from easydict import EasyDict
16
+
17
+ # from .utils import get_vacant_gpu
18
+ from .logger import bootstrap_logger, logger
19
+ from .utils.system import get_available_gpuids
20
+ from .utils.misc import merge_dict, merge_opts, to_string, eval_dict_leaf
21
+
22
+ CONFIG = EasyDict()
23
+
24
+ BASE_CONFIG = {
25
+ 'OUTPUT_DIR': './workspace',
26
+ 'SESSION': 'base',
27
+ 'NUM_GPUS': 1,
28
+ 'LOG_NAME': 'log.txt'
29
+ }
30
+
31
+
32
+ def bootstrap_args(default_params=None):
33
+ """get the params from yaml file and args. The args will override arguemnts in the yaml file.
34
+ Returns: EasyDict instance.
35
+
36
+ """
37
+ parser = define_default_arg_parser()
38
+ cfg = update_config(parser, default_params)
39
+ create_workspace(cfg) #create workspace
40
+
41
+ CONFIG.update(cfg)
42
+ bootstrap_logger(get_logfile(CONFIG)) # setup logger
43
+ setup_gpu(CONFIG.NUM_GPUS) #setup gpu
44
+
45
+ return cfg
46
+
47
+
48
+ def setup_gpu(ngpu):
49
+ gpuids = get_available_gpuids()
50
+ # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in gpuids[:ngpu]])
51
+
52
+
53
+ def get_logfile(config):
54
+ return os.path.join(config.WORKSPACE, config.LOG_NAME)
55
+
56
+
57
+ def define_default_arg_parser():
58
+ """Define a default arg_parser.
59
+
60
+ Returns:
61
+ A argparse.ArgumentParser. More arguments can be added.
62
+
63
+ """
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument('--cfg', help='load configs from yaml file', default='', type=str)
66
+ parser.add_argument('opts',
67
+ default=None,
68
+ nargs='*',
69
+ help='modify config options using the command-line')
70
+
71
+ return parser
72
+
73
+
74
+ def update_config(arg_parser, default_config=None):
75
+ """ update argparser to args.
76
+
77
+ Args:
78
+ arg_parser: argparse.ArgumentParser.
79
+ """
80
+
81
+ parsed, unknown = arg_parser.parse_known_args()
82
+ if default_config and parsed.cfg == "" and "cfg" in default_config:
83
+ parsed.cfg = default_config["cfg"]
84
+
85
+ config = EasyDict(BASE_CONFIG.copy())
86
+ config['cfg'] = parsed.cfg
87
+ # update default config
88
+ if default_config is not None:
89
+ config.update(default_config)
90
+
91
+ # merge config from yaml
92
+ if os.path.isfile(config.cfg):
93
+ with open(config.cfg, 'r') as f:
94
+ yml_config = yaml.full_load(f)
95
+ config = merge_dict(config, yml_config)
96
+
97
+ # merge opts
98
+ config = merge_opts(config, parsed.opts)
99
+
100
+ # eval values
101
+ config = eval_dict_leaf(config)
102
+
103
+ return config
104
+
105
+
106
+ def create_workspace(cfg):
107
+ cfg_name, ext = os.path.splitext(os.path.basename(cfg.cfg))
108
+ workspace = os.path.join(cfg.OUTPUT_DIR, cfg_name, cfg.SESSION)
109
+ os.makedirs(workspace, exist_ok=True)
110
+ cfg.WORKSPACE = workspace
dlhammer/dlhammer/.ipynb_checkpoints/bootstrap-checkpoint.py ADDED
@@ -0,0 +1,33 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import sys
11
+ import logging
12
+
13
+ from .logger import bootstrap_logger, logger
14
+ from .argparser import bootstrap_args, CONFIG
15
+ from .utils.misc import to_string
16
+
17
+ __all__ = ['bootstrap', 'logger', 'CONFIG']
18
+
19
+
20
+ def bootstrap(default_cfg=None, print_cfg=True):
21
+ """TODO: Docstring for bootstrap.
22
+
23
+ Kwargs:
24
+ use_argparser (TODO): TODO
25
+ use_logger (TODO): TODO
26
+
27
+ Returns: TODO
28
+
29
+ """
30
+ config = bootstrap_args(default_cfg)
31
+ if print_cfg:
32
+ logger.info(to_string(config))
33
+ return config
dlhammer/dlhammer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .bootstrap import *
dlhammer/dlhammer/argparser.py ADDED
@@ -0,0 +1,109 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import os
11
+ import argparse
12
+ import datetime
13
+ from functools import partial
14
+ import yaml
15
+ from easydict import EasyDict
16
+
17
+ # from .utils import get_vacant_gpu
18
+ from .logger import bootstrap_logger, logger
19
+ from .utils.system import get_available_gpuids
20
+ from .utils.misc import merge_dict, merge_opts, to_string, eval_dict_leaf
21
+
22
+ CONFIG = EasyDict()
23
+
24
+ BASE_CONFIG = {
25
+ 'OUTPUT_DIR': './workspace',
26
+ 'NUM_GPUS': 1,
27
+ 'LOG_NAME': 'log.txt'
28
+ }
29
+
30
+
31
+ def bootstrap_args(default_params=None):
32
+ """get the params from yaml file and args. The args will override arguemnts in the yaml file.
33
+ Returns: EasyDict instance.
34
+
35
+ """
36
+ parser = define_default_arg_parser()
37
+ cfg = update_config(parser, default_params)
38
+ create_workspace(cfg) #create workspace
39
+
40
+ CONFIG.update(cfg)
41
+ bootstrap_logger(get_logfile(CONFIG)) # setup logger
42
+ setup_gpu(CONFIG.NUM_GPUS) #setup gpu
43
+
44
+ return cfg
45
+
46
+
47
+ def setup_gpu(ngpu):
48
+ gpuids = get_available_gpuids()
49
+ # os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in gpuids[:ngpu]])
50
+
51
+
52
+ def get_logfile(config):
53
+ return os.path.join(config.WORKSPACE, config.LOG_NAME)
54
+
55
+
56
+ def define_default_arg_parser():
57
+ """Define a default arg_parser.
58
+
59
+ Returns:
60
+ A argparse.ArgumentParser. More arguments can be added.
61
+
62
+ """
63
+ parser = argparse.ArgumentParser()
64
+ parser.add_argument('--cfg', help='load configs from yaml file', default='', type=str)
65
+ parser.add_argument('opts',
66
+ default=None,
67
+ nargs='*',
68
+ help='modify config options using the command-line')
69
+
70
+ return parser
71
+
72
+
73
+ def update_config(arg_parser, default_config=None):
74
+ """ update argparser to args.
75
+
76
+ Args:
77
+ arg_parser: argparse.ArgumentParser.
78
+ """
79
+
80
+ parsed, unknown = arg_parser.parse_known_args()
81
+ if default_config and parsed.cfg == "" and "cfg" in default_config:
82
+ parsed.cfg = default_config["cfg"]
83
+
84
+ config = EasyDict(BASE_CONFIG.copy())
85
+ config['cfg'] = parsed.cfg
86
+ # update default config
87
+ if default_config is not None:
88
+ config.update(default_config)
89
+
90
+ # merge config from yaml
91
+ if os.path.isfile(config.cfg):
92
+ with open(config.cfg, 'r') as f:
93
+ yml_config = yaml.full_load(f)
94
+ config = merge_dict(config, yml_config)
95
+
96
+ # merge opts
97
+ config = merge_opts(config, parsed.opts)
98
+
99
+ # eval values
100
+ config = eval_dict_leaf(config)
101
+
102
+ return config
103
+
104
+
105
+ def create_workspace(cfg):
106
+ cfg_name, ext = os.path.splitext(os.path.basename(cfg.cfg))
107
+ workspace = os.path.join(cfg.OUTPUT_DIR)
108
+ os.makedirs(workspace, exist_ok=True)
109
+ cfg.WORKSPACE = workspace
dlhammer/dlhammer/bootstrap.py ADDED
@@ -0,0 +1,33 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import sys
11
+ import logging
12
+
13
+ from .logger import bootstrap_logger, logger
14
+ from .argparser import bootstrap_args, CONFIG
15
+ from .utils.misc import to_string
16
+
17
+ __all__ = ['bootstrap', 'logger', 'CONFIG']
18
+
19
+
20
+ def bootstrap(default_cfg=None, print_cfg=True):
21
+ """TODO: Docstring for bootstrap.
22
+
23
+ Kwargs:
24
+ use_argparser (TODO): TODO
25
+ use_logger (TODO): TODO
26
+
27
+ Returns: TODO
28
+
29
+ """
30
+ config = bootstrap_args(default_cfg)
31
+ if print_cfg:
32
+ logger.info(to_string(config))
33
+ return config
dlhammer/dlhammer/logger.py ADDED
@@ -0,0 +1,66 @@
1
+ # -*- coding: utf-8 -*-
2
+ #================================================================
3
+ # Don't go gently into that good night.
4
+ #
5
+ # author: klaus
6
+ # description:
7
+ #
8
+ #================================================================
9
+
10
+ import os
11
+ import sys
12
+ import logging
13
+
14
+ logger = logging.getLogger('DLHammer')
15
+
16
+
17
+ def bootstrap_logger(logfile=None, fmt=None):
18
+ """TODO: Docstring for bootstrap_logger.
19
+
20
+ Args:
21
+ logfile (str): file path logging to.
22
+
23
+ Kwargs:
24
+ fmt (TODO): TODO
25
+
26
+ Returns: TODO
27
+
28
+ """
29
+ if fmt is None:
30
+ # fmt = '%(asctime)s - %(levelname)-5s - [%(filename)s:%(lineno)d] %(message)s'
31
+ fmt = '%(message)s'
32
+ logging.basicConfig(level=logging.DEBUG, format=fmt)
33
+
34
+ #log to file
35
+ if logfile is not None:
36
+ formatter = logging.Formatter(fmt)
37
+ fh = logging.FileHandler(logfile)
38
+ fh.setLevel(logging.DEBUG)
39
+ fh.setFormatter(formatter)
40
+ logger.addHandler(fh)
41
+
42
+ # sys.stdout = LoggerWriter(sys.stdout, logger.info)
43
+ # sys.stderr = LoggerWriter(sys.stderr, logger.error)
44
+ return
45
+
46
+
47
+ class LoggerWriter(object):
48
+
49
+ def __init__(self, stream, logfct):
50
+ self.terminal = stream
51
+ self.logfct = logfct
52
+ self.buf = []
53
+
54
+ def write(self, msg):
55
+ if msg.endswith('\n'):
56
+ self.buf.append(msg.rstrip('\n'))
57
+
58
+ message = ''.join(self.buf)
59
+ self.logfct(message)
60
+
61
+ self.buf = []
62
+ else:
63
+ self.buf.append(msg)
64
+
65
+ def flush(self):
66
+ pass
dlhammer/dlhammer/test/config.yml ADDED
@@ -0,0 +1,32 @@
1
+ a_int: 12
2
+ a_float: 1e-2
3
+ a_list: [0,1,2]
4
+ eval_list: eval(list(range(10)))
5
+ DATA:
6
+ PATH_TO_DATA_DIR: /home/ubuntu/data/kinetics/Mini-Kinetics-200
7
+ PATH_PREFIX: /home/ubuntu/data/kinetics/k400_ver3
8
+ NUM_FRAMES: 16
9
+ SAMPLING_RATE: 8
10
+ TARGET_FPS: 25
11
+ TRAIN_JITTER_SCALES: [256, 320]
12
+ TRAIN_CROP_SIZE: 224
13
+ TEST_CROP_SIZE: 224
14
+ INPUT_CHANNEL_NUM: [3]
15
+ SOLVER:
16
+ BACKBONE:
17
+ OPTIMIZER: sgd
18
+ MOMENTUM: 0.9
19
+ BASE_LR: 1e-3
20
+ SCHEDULER:
21
+ NAME: warmup_multistep
22
+ MILESTONES: [13, 24]
23
+ WARMUP_EPOCHS: 0.5
24
+ GAMMA: 0.1
25
+ TEMPORAL_MODEL:
26
+ OPTIMIZER: sgd
27
+ MOMENTUM: 0.9
28
+ BASE_LR: 1e-3
29
+ SCHEDULER:
30
+ NAME: multistep
31
+ MILESTONES: [13, 24]
32
+ GAMMA: 0.1
dlhammer/dlhammer/test/test_args.py ADDED
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+#================================================================
+#   Don't go gently into that good night.
+#
+#   author: klaus
+#   description:
+#
+#================================================================
+
+import os
+import sys
+
+CURRENT_FILE_DIRECTORY = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '../..'))
+sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '.'))
+
+from dlhammer import bootstrap, CONFIG
+from dlhammer import logger
+
+config = bootstrap(print_cfg=True)
dlhammer/dlhammer/test/test_logger.py ADDED
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#================================================================
+#   Don't go gently into that good night.
+#
+#   author: klaus
+#   description:
+#
+#================================================================
+
+import os
+import sys
+
+CURRENT_FILE_DIRECTORY = os.path.abspath(os.path.dirname(__file__))
+sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '../..'))
+sys.path.append(os.path.join(CURRENT_FILE_DIRECTORY, '.'))
+
+from dlhammer import bootstrap, logger
+bootstrap()
+
+logger.info('dummy output')
+
+raise Exception('dummy error')
dlhammer/dlhammer/utils/__init__.py ADDED
File without changes
dlhammer/dlhammer/utils/misc.py ADDED
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+#================================================================
+#   Don't go gently into that good night.
+#
+#   author: klaus
+#   description:
+#
+#================================================================
+
+import ast
+
+
+def merge_dict(a, b, path=None):
+    """Merge b into a. The values in b override the values in a.
+
+    Args:
+        a (dict): dict to merge to.
+        b (dict): dict to merge from.
+
+    Returns: `a` with values merged from `b`.
+
+    """
+    if path is None: path = []
+    for key in b:
+        if key in a:
+            if isinstance(a[key], dict) and isinstance(b[key], dict):
+                merge_dict(a[key], b[key], path + [str(key)])
+            else:
+                a[key] = b[key]
+        else:
+            a[key] = b[key]
+    return a
+
+
+def merge_opts(d, opts):
+    """merge opts
+    Args:
+        d (dict): The dict.
+        opts (list): The opts to merge. format: [key1, value1, key2, value2,...]
+    Returns: d. the input dict `d` with merged opts.
+
+    """
+    assert len(opts) % 2 == 0, f'length of opts must be even. Got: {opts}'
+    for i in range(0, len(opts), 2):
+        full_k, v = opts[i], opts[i + 1]
+        keys = full_k.split('.')
+        sub_d = d
+        for i, k in enumerate(keys):
+            if not hasattr(sub_d, k):
+                raise ValueError(f'The key {k} not exist in the dict. Full key:{full_k}')
+            if i != len(keys) - 1:
+                sub_d = sub_d[k]
+            else:
+                sub_d[k] = v
+    return d
+
+
+def to_string(params, indent=2):
+    """format params to a string
+
+    Args:
+        params (EasyDict): the params.
+
+    Returns: The string to display.
+
+    """
+    msg = '{\n'
+    for i, (k, v) in enumerate(params.items()):
+        if isinstance(v, dict):
+            v = to_string(v, indent + 4)
+        spaces = ' ' * indent
+        msg += spaces + '{}: {}'.format(k, v)
+        if i == len(params) - 1:
+            msg += ' }'
+        else:
+            msg += '\n'
+    return msg
+
+
+def eval_dict_leaf(d):
+    """eval values of dict leaf.
+
+    Args:
+        d (dict): The dict to eval.
+
+    Returns: dict.
+
+    """
+    for k, v in d.items():
+        if not isinstance(v, dict):
+            d[k] = eval_string(v)
+        else:
+            eval_dict_leaf(v)
+    return d
+
+
+def eval_string(string):
+    """automatically evaluate string to corresponding types.
+
+    For example:
+        not a string     -> return the original input
+        '0'              -> 0
+        '0.2'            -> 0.2
+        '[0, 1, 2]'      -> [0,1,2]
+        'eval(1+2)'      -> 3
+        'eval(range(5))' -> [0,1,2,3,4]
+
+
+    Args:
+        string : the string to evaluate.
+
+    Returns: the corresponding type
+
+    """
+    if not isinstance(string, str):
+        return string
+    if len(string) > 1 and string[0] == '[' and string[-1] == ']':
+        return eval(string)
+    if string[0:5] == 'eval(':
+        return eval(string[5:-1])
+    try:
+        v = ast.literal_eval(string)
+    except:
+        v = string
+    return v
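
A small sketch exercising the helpers above, under the assumption that configs are `EasyDict` objects (as elsewhere in `dlhammer`): `merge_opts` checks keys with `hasattr`, which plain dicts would not satisfy.

```
from easydict import EasyDict
from dlhammer.utils.misc import merge_dict, merge_opts, eval_dict_leaf, eval_string

base = EasyDict({'SOLVER': {'BASE_LR': '1e-3', 'MOMENTUM': 0.9}})
cfg = merge_dict(base, {'SOLVER': {'BASE_LR': '5e-4'}})     # values from the second dict win
cfg = merge_opts(cfg, ['SOLVER.MOMENTUM', 0.95])            # dotted CLI-style override
cfg = eval_dict_leaf(cfg)                                   # '5e-4' (str) -> 0.0005 (float)

assert cfg.SOLVER.BASE_LR == 5e-4 and cfg.SOLVER.MOMENTUM == 0.95
assert eval_string('[0, 1, 2]') == [0, 1, 2]
assert eval_string('eval(1+2)') == 3
```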
dlhammer/dlhammer/utils/system.py ADDED
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+#================================================================
+#   Don't go gently into that good night.
+#
+#   author: klaus
+#   description:
+#
+#================================================================
+
+import os
+import sys
+import subprocess
+import numpy as np
+
+
+def get_available_gpuids():
+    """
+    Returns: the gpu ids sorted in descending order w.r.t occupied memory.
+    """
+    com = "nvidia-smi|sed -n '/%/p'|sed 's/|/\\n/g'|sed -n '/MiB/p'|sed 's/ //g'|sed 's/MiB/\\n/'|sed '/\\//d'"
+    gpum = subprocess.check_output(com, shell=True)
+    gpum = gpum.decode('utf-8').split('\n')
+    gpum = gpum[:-1]
+    sorted_gpuid = np.argsort(gpum)
+    return sorted_gpuid
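
A usage sketch for the GPU helper, assuming an NVIDIA machine with `nvidia-smi` on `PATH` (the helper shells out to it and sorts the parsed memory column as text).

```
import os
from dlhammer.utils.system import get_available_gpuids

gpu_ids = get_available_gpuids()
print('GPU ids ordered by parsed memory usage:', gpu_ids)
# e.g. pin a device before importing torch; verify the ordering on your machine,
# since the memory values are compared as strings.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_ids[0])
```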
environment.yml ADDED
@@ -0,0 +1,298 @@
1
+ name: loconet
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=1_gnu
8
+ - alsa-lib=1.2.3=h516909a_0
9
+ - anyio=3.5.0=py37h89c1867_0
10
+ - argon2-cffi=21.3.0=pyhd8ed1ab_0
11
+ - argon2-cffi-bindings=21.2.0=py37h5e8e339_1
12
+ - aria2=1.36.0=h319415d_2
13
+ - attrs=21.4.0=pyhd8ed1ab_0
14
+ - babel=2.9.1=pyh44b312d_0
15
+ - backcall=0.2.0=pyh9f0ad1d_0
16
+ - backports=1.0=py_2
17
+ - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
18
+ - bleach=4.1.0=pyhd8ed1ab_0
19
+ - bottleneck=1.3.4=py37h6c7ee08_0
20
+ - brotli=1.0.9=h7f98852_6
21
+ - brotli-bin=1.0.9=h7f98852_6
22
+ - brotlipy=0.7.0=py37h5e8e339_1003
23
+ - c-ares=1.18.1=h7f98852_0
24
+ - ca-certificates=2022.5.18.1=ha878542_0
25
+ - cffi=1.14.6=py37hc58025e_0
26
+ - configparser=5.2.0=pyhd8ed1ab_0
27
+ - cryptography=36.0.1=py37hf1a17b8_0
28
+ - cycler=0.11.0=pyhd8ed1ab_0
29
+ - cython=0.29.27=py37hcd2ae1e_0
30
+ - dbus=1.13.6=h48d8840_2
31
+ - debugpy=1.5.1=py37hcd2ae1e_0
32
+ - defusedxml=0.7.1=pyhd8ed1ab_0
33
+ - easydict=1.9=py_0
34
+ - entrypoints=0.4=pyhd8ed1ab_0
35
+ - expat=2.4.6=h27087fc_0
36
+ - flit-core=3.7.0=pyhd8ed1ab_0
37
+ - fontconfig=2.13.96=ha180cfb_0
38
+ - fonttools=4.29.1=py37h5e8e339_0
39
+ - freetype=2.10.4=h0708190_1
40
+ - gettext=0.19.8.1=h0b5b191_1005
41
+ - giflib=5.2.1=h36c2ea0_2
42
+ - glib=2.68.4=h9c3ff4c_0
43
+ - glib-tools=2.68.4=h9c3ff4c_0
44
+ - gst-plugins-base=1.18.5=hf529b03_0
45
+ - gstreamer=1.18.5=h76c114f_0
46
+ - icu=68.2=h9c3ff4c_0
47
+ - idna=3.3=pyhd8ed1ab_0
48
+ - importlib_resources=5.4.0=pyhd8ed1ab_0
49
+ - ipykernel=6.9.1=py37h6531663_0
50
+ - ipython=7.31.1=py37h89c1867_0
51
+ - ipython_genutils=0.2.0=py_1
52
+ - jbig=2.1=h7f98852_2003
53
+ - jedi=0.18.1=py37h89c1867_0
54
+ - jinja2=3.0.3=pyhd8ed1ab_0
55
+ - jpeg=9e=h7f98852_0
56
+ - json5=0.9.5=pyh9f0ad1d_0
57
+ - jsonschema=4.4.0=pyhd8ed1ab_0
58
+ - jupyter_client=7.1.2=pyhd8ed1ab_0
59
+ - jupyter_core=4.9.2=py37h89c1867_0
60
+ - jupyter_server=1.13.5=pyhd8ed1ab_1
61
+ - jupyterlab=3.2.9=pyhd8ed1ab_0
62
+ - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0
63
+ - jupyterlab_server=2.10.3=pyhd8ed1ab_0
64
+ - kiwisolver=1.3.2=py37h2527ec5_1
65
+ - krb5=1.19.2=hcc1bbae_3
66
+ - lcms2=2.12=hddcbb42_0
67
+ - ld_impl_linux-64=2.36.1=hea4e1c9_2
68
+ - lerc=3.0=h9c3ff4c_0
69
+ - libblas=3.9.0=13_linux64_openblas
70
+ - libbrotlicommon=1.0.9=h7f98852_6
71
+ - libbrotlidec=1.0.9=h7f98852_6
72
+ - libbrotlienc=1.0.9=h7f98852_6
73
+ - libcblas=3.9.0=13_linux64_openblas
74
+ - libclang=11.1.0=default_ha53f305_1
75
+ - libdeflate=1.10=h7f98852_0
76
+ - libedit=3.1.20191231=he28a2e2_2
77
+ - libevent=2.1.10=h9b69904_4
78
+ - libffi=3.3=h58526e2_2
79
+ - libgcc-ng=11.2.0=h1d223b6_12
80
+ - libgfortran-ng=11.2.0=h69a702a_12
81
+ - libgfortran5=11.2.0=h5c6108e_12
82
+ - libglib=2.68.4=h3e27bee_0
83
+ - libgomp=11.2.0=h1d223b6_12
84
+ - libiconv=1.16=h516909a_0
85
+ - liblapack=3.9.0=13_linux64_openblas
86
+ - libllvm11=11.1.0=hf817b99_3
87
+ - libogg=1.3.4=h7f98852_1
88
+ - libopenblas=0.3.18=pthreads_h8fe5266_0
89
+ - libopus=1.3.1=h7f98852_1
90
+ - libpng=1.6.37=h21135ba_2
91
+ - libpq=13.5=hd57d9b9_1
92
+ - libsodium=1.0.18=h36c2ea0_1
93
+ - libssh2=1.10.0=ha56f1ee_2
94
+ - libstdcxx-ng=11.2.0=he4da1e4_12
95
+ - libtiff=4.3.0=h542a066_3
96
+ - libuuid=2.32.1=h7f98852_1000
97
+ - libvorbis=1.3.7=h9c3ff4c_0
98
+ - libwebp=1.2.2=h3452ae3_0
99
+ - libwebp-base=1.2.2=h7f98852_1
100
+ - libxcb=1.13=h7f98852_1004
101
+ - libxkbcommon=1.0.3=he3ba5ed_0
102
+ - libxml2=2.9.12=h72842e0_0
103
+ - libzlib=1.2.11=h36c2ea0_1013
104
+ - llvmlite=0.38.0=py37h0761922_1
105
+ - lz4-c=1.9.3=h9c3ff4c_1
106
+ - markupsafe=2.1.0=py37h540881e_0
107
+ - matplotlib=3.5.1=py37h89c1867_0
108
+ - matplotlib-base=3.5.1=py37h1058ff1_0
109
+ - matplotlib-inline=0.1.3=pyhd8ed1ab_0
110
+ - mistune=0.8.4=py37h5e8e339_1005
111
+ - munkres=1.1.4=pyh9f0ad1d_0
112
+ - mysql-common=8.0.28=ha770c72_0
113
+ - mysql-libs=8.0.28=hfa10184_0
114
+ - nbclassic=0.3.5=pyhd8ed1ab_0
115
+ - nbclient=0.5.11=pyhd8ed1ab_0
116
+ - nbconvert=6.4.2=py37h89c1867_0
117
+ - nbformat=5.1.3=pyhd8ed1ab_0
118
+ - ncurses=6.2=h58526e2_4
119
+ - nest-asyncio=1.5.4=pyhd8ed1ab_0
120
+ - nomkl=1.0=h5ca1d4c_0
121
+ - notebook=6.4.8=pyha770c72_0
122
+ - nspr=4.32=h9c3ff4c_1
123
+ - nss=3.74=hb5efdd6_0
124
+ - numba=0.55.1=py37h2d894fd_0
125
+ - numexpr=2.8.0=py37hfe5f03c_101
126
+ - numpy=1.21.5=py37hf2998dd_0
127
+ - openjpeg=2.4.0=hb52868f_1
128
+ - openssl=1.1.1o=h166bdaf_0
129
+ - packaging=21.3=pyhd8ed1ab_0
130
+ - pandas=1.3.5=py37h8c16a72_0
131
+ - pandoc=2.17.1.1=ha770c72_0
132
+ - pandocfilters=1.5.0=pyhd8ed1ab_0
133
+ - parso=0.8.3=pyhd8ed1ab_0
134
+ - patsy=0.5.2=pyhd8ed1ab_0
135
+ - pcre=8.45=h9c3ff4c_0
136
+ - pexpect=4.8.0=pyh9f0ad1d_2
137
+ - pickleshare=0.7.5=py_1003
138
+ - pip=22.0.3=pyhd8ed1ab_0
139
+ - prometheus_client=0.13.1=pyhd8ed1ab_0
140
+ - prompt-toolkit=3.0.27=pyha770c72_0
141
+ - pthread-stubs=0.4=h36c2ea0_1001
142
+ - ptyprocess=0.7.0=pyhd3deb0d_0
143
+ - pycparser=2.21=pyhd8ed1ab_0
144
+ - pygments=2.11.2=pyhd8ed1ab_0
145
+ - pyopenssl=22.0.0=pyhd8ed1ab_0
146
+ - pyparsing=3.0.7=pyhd8ed1ab_0
147
+ - pyqt=5.12.3=py37h89c1867_8
148
+ - pyqt-impl=5.12.3=py37hac37412_8
149
+ - pyqt5-sip=4.19.18=py37hcd2ae1e_8
150
+ - pyqtchart=5.12=py37he336c9b_8
151
+ - pyqtwebengine=5.12.1=py37he336c9b_8
152
+ - pyrsistent=0.18.1=py37h5e8e339_0
153
+ - pysocks=1.7.1=py37h89c1867_4
154
+ - python=3.7.9=hffdb5ce_100_cpython
155
+ - python-dateutil=2.8.2=pyhd8ed1ab_0
156
+ - python_abi=3.7=2_cp37m
157
+ - pytz=2021.3=pyhd8ed1ab_0
158
+ - pyzmq=22.3.0=py37h336d617_1
159
+ - qt=5.12.9=hda022c4_4
160
+ - readline=8.1=h46c0cb4_0
161
+ - resampy=0.2.2=py_0
162
+ - scipy=1.7.3=py37hf2a6cf1_0
163
+ - seaborn=0.11.2=hd8ed1ab_0
164
+ - seaborn-base=0.11.2=pyhd8ed1ab_0
165
+ - send2trash=1.8.0=pyhd8ed1ab_0
166
+ - six=1.16.0=pyh6c4a22f_0
167
+ - sniffio=1.2.0=py37h89c1867_2
168
+ - sqlite=3.37.0=h9cd32fc_0
169
+ - statsmodels=0.13.2=py37hb1e94ed_0
170
+ - terminado=0.13.1=py37h89c1867_0
171
+ - testpath=0.5.0=pyhd8ed1ab_0
172
+ - tk=8.6.12=h27826a3_0
173
+ - tornado=6.1=py37h5e8e339_2
174
+ - traitlets=5.1.1=pyhd8ed1ab_0
175
+ - typing_extensions=4.1.1=pyha770c72_0
176
+ - unicodedata2=14.0.0=py37h5e8e339_0
177
+ - wcwidth=0.2.5=pyh9f0ad1d_2
178
+ - webencodings=0.5.1=py_1
179
+ - websocket-client=1.2.3=pyhd8ed1ab_0
180
+ - wheel=0.37.1=pyhd8ed1ab_0
181
+ - xorg-libxau=1.0.9=h7f98852_0
182
+ - xorg-libxdmcp=1.1.3=h7f98852_0
183
+ - xz=5.2.5=h516909a_1
184
+ - zeromq=4.3.4=h9c3ff4c_1
185
+ - zlib=1.2.11=h36c2ea0_1013
186
+ - zstd=1.5.2=ha95c52a_0
187
+ - pip:
188
+ - absl-py==1.0.0
189
+ - addict==2.4.0
190
+ - aiohttp==3.8.1
191
+ - aiosignal==1.2.0
192
+ - analytics-python==1.4.0
193
+ - appdirs==1.4.4
194
+ - asgiref==3.5.2
195
+ - async-timeout==4.0.2
196
+ - asynctest==0.13.0
197
+ - audioread==2.1.9
198
+ - backoff==1.10.0
199
+ - bcrypt==3.2.2
200
+ - beautifulsoup4==4.10.0
201
+ - cachetools==4.2.4
202
+ - certifi==2021.10.8
203
+ - charset-normalizer==2.0.9
204
+ - click==8.0.3
205
+ - decorator==4.4.2
206
+ - decord==0.6.0
207
+ - einops==0.4.0
208
+ - fastapi==0.78.0
209
+ - ffmpeg==1.4
210
+ - ffmpy==0.3.0
211
+ - filelock==3.4.0
212
+ - frozenlist==1.3.0
213
+ - fsspec==2022.1.0
214
+ - future==0.18.2
215
+ - fvcore==0.1.5.post20221221
216
+ - gdown==4.2.0
217
+ - google-auth==2.3.3
218
+ - google-auth-oauthlib==0.4.6
219
+ - gradio==3.0.2
220
+ - grpcio==1.43.0
221
+ - h11==0.13.0
222
+ - imageio==2.23.0
223
+ - imageio-ffmpeg==0.4.7
224
+ - importlib-metadata==4.10.0
225
+ - iopath==0.1.10
226
+ - ipywidgets==8.0.4
227
+ - joblib==1.1.0
228
+ - jupyterlab-widgets==3.0.5
229
+ - librosa==0.9.1
230
+ - linkify-it-py==1.0.3
231
+ - lmdb==1.4.1
232
+ - markdown==3.3.6
233
+ - markdown-it-py==2.1.0
234
+ - mdit-py-plugins==0.3.0
235
+ - mdurl==0.1.1
236
+ - mmaction2==0.24.1
237
+ - mmcv==1.7.0
238
+ - mmcv-full==1.4.6
239
+ - monotonic==1.6
240
+ - moviepy==1.0.3
241
+ - multidict==5.2.0
242
+ - oauthlib==3.1.1
243
+ - opencv-contrib-python==4.7.0.68
244
+ - opencv-python==4.5.5.62
245
+ - orjson==3.6.8
246
+ - paramiko==2.11.0
247
+ - pillow==8.3.2
248
+ - pooch==1.6.0
249
+ - portalocker==2.7.0
250
+ - proglog==0.1.10
251
+ - protobuf==3.19.3
252
+ - pyasn1==0.4.8
253
+ - pyasn1-modules==0.2.8
254
+ - pycryptodome==3.14.1
255
+ - pydantic==1.9.0
256
+ - pydeprecate==0.3.1
257
+ - pydub==0.25.1
258
+ - pynacl==1.5.0
259
+ - python-box==6.0.2
260
+ - python-multipart==0.0.5
261
+ - python-speech-features==0.6
262
+ - pytorch-lightning==1.5.8
263
+ - pyyaml==6.0
264
+ - requests==2.26.0
265
+ - requests-oauthlib==1.3.0
266
+ - rsa==4.8
267
+ - scenedetect==0.5.6.1
268
+ - scikit-learn==1.0.1
269
+ - setuptools==60.9.3
270
+ - soundfile==0.10.3.post1
271
+ - soupsieve==2.3.1
272
+ - starlette==0.19.1
273
+ - tabulate==0.9.0
274
+ - tensorboard==2.7.0
275
+ - tensorboard-data-server==0.6.1
276
+ - tensorboard-plugin-wit==1.8.1
277
+ - termcolor==2.2.0
278
+ - threadpoolctl==3.0.0
279
+ - timm==0.4.5
280
+ - torch==1.10.1
281
+ - torchaudio==0.10.1
282
+ - torchlibrosa==0.0.9
283
+ - torchmetrics==0.7.0
284
+ - torchvision==0.11.2
285
+ - tqdm==4.62.3
286
+ - typing-extensions==4.0.1
287
+ - uc-micro-py==1.0.1
288
+ - urllib3==1.26.7
289
+ - uvicorn==0.17.6
290
+ - warmup-scheduler-pytorch==0.1.2
291
+ - werkzeug==2.0.2
292
+ - wget==3.2
293
+ - widgetsnbextension==4.0.5
294
+ - yacs==0.1.8
295
+ - yapf==0.32.0
296
+ - yarl==1.7.2
297
+ - youtube-dl==2021.12.17
298
+ - zipp==3.6.0
legacy/talkNet_multi_multicard.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import sys, time, numpy, os, subprocess, pandas, tqdm
6
+
7
+ from loss_multi import lossAV, lossA, lossV
8
+ from model.talkNetModel import talkNetModel
9
+
10
+ import pytorch_lightning as pl
11
+ from torch import distributed as dist
12
+
13
+
14
+ class talkNet(pl.LightningModule):
15
+
16
+ def __init__(self, cfg):
17
+ super(talkNet, self).__init__()
18
+ self.model = talkNetModel().cuda()
19
+ self.cfg = cfg
20
+ self.lossAV = lossAV().cuda()
21
+ self.lossA = lossA().cuda()
22
+ self.lossV = lossV().cuda()
23
+ print(
24
+ time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
25
+ (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
26
+
27
+ def configure_optimizers(self):
28
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
29
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
30
+ step_size=1,
31
+ gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
32
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
33
+
34
+ def training_step(self, batch, batch_idx):
35
+ audioFeature, visualFeature, labels, masks = batch
36
+ b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
37
+ audioFeature = audioFeature.repeat(1, s, 1, 1)
38
+ audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
39
+ visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
40
+ labels = labels.view(b * s, *labels.shape[2:])
41
+ masks = masks.view(b * s, *masks.shape[2:])
42
+
43
+ audioEmbed = self.model.forward_audio_frontend(audioFeature) # feedForward
44
+ visualEmbed = self.model.forward_visual_frontend(visualFeature)
45
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
46
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
47
+ outsA = self.model.forward_audio_backend(audioEmbed)
48
+ outsV = self.model.forward_visual_backend(visualEmbed)
49
+ labels = labels.reshape((-1))
50
+ nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
51
+ nlossA = self.lossA.forward(outsA, labels, masks)
52
+ nlossV = self.lossV.forward(outsV, labels, masks)
53
+ loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
54
+ self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
55
+ return loss
56
+
57
+ def training_epoch_end(self, training_step_outputs):
58
+ self.saveParameters(
59
+ os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
60
+
61
+ def evaluate_network(self, loader):
62
+ self.eval()
63
+ predScores = []
64
+ self.model = self.model.cuda()
65
+ self.lossAV = self.lossAV.cuda()
66
+ self.lossA = self.lossA.cuda()
67
+ self.lossV = self.lossV.cuda()
68
+ evalCsvSave = self.cfg.evalCsvSave
69
+ evalOrig = self.cfg.evalOrig
70
+ for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
71
+ with torch.no_grad():
72
+ b, s = visualFeature.shape[0], visualFeature.shape[1]
73
+ t = visualFeature.shape[2]
74
+ audioFeature = audioFeature.repeat(1, s, 1, 1)
75
+ audioFeature = audioFeature.view(b * s, *audioFeature.shape[2:])
76
+ visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
77
+ labels = labels.view(b * s, *labels.shape[2:])
78
+ masks = masks.view(b * s, *masks.shape[2:])
79
+ audioEmbed = self.model.forward_audio_frontend(audioFeature.cuda())
80
+ visualEmbed = self.model.forward_visual_frontend(visualFeature.cuda())
81
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(
82
+ audioEmbed, visualEmbed)
83
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
84
+ labels = labels.reshape((-1)).cuda()
85
+ outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
86
+ labels = labels.view(b, s, t)[:, 0, :].view(b * t)
87
+ masks = masks.view(b, s, t)[:, 0, :].view(b * t)
88
+ _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
89
+ predScore = predScore.detach().cpu().numpy()
90
+ predScores.extend(predScore)
91
+ evalLines = open(evalOrig).read().splitlines()[1:]
92
+ labels = []
93
+ labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
94
+ scores = pandas.Series(predScores)
95
+ evalRes = pandas.read_csv(evalOrig)
96
+ evalRes['score'] = scores
97
+ evalRes['label'] = labels
98
+ evalRes.drop(['label_id'], axis=1, inplace=True)
99
+ evalRes.drop(['instance_id'], axis=1, inplace=True)
100
+ evalRes.to_csv(evalCsvSave, index=False)
101
+ cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
102
+ evalCsvSave)
103
+ mAP = float(
104
+ str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
105
+ return mAP
106
+
107
+ def saveParameters(self, path):
108
+ torch.save(self.state_dict(), path)
109
+
110
+ def loadParameters(self, path):
111
+ selfState = self.state_dict()
112
+ loadedState = torch.load(path)
113
+ for name, param in loadedState.items():
114
+ origName = name
115
+ if name not in selfState:
116
+ name = name.replace("module.", "")
117
+ if name not in selfState:
118
+ print("%s is not in the model." % origName)
119
+ continue
120
+ if selfState[name].size() != loadedState[origName].size():
121
+ sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
122
+ (origName, selfState[name].size(), loadedState[origName].size()))
123
+ continue
124
+ selfState[name].copy_(param)
legacy/talkNet_multicard.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import sys, time, numpy, os, subprocess, pandas, tqdm
6
+
7
+ from loss import lossAV, lossA, lossV
8
+ from model.talkNetModel import talkNetModel
9
+
10
+ import pytorch_lightning as pl
11
+ from torch import distributed as dist
12
+
13
+
14
+ class talkNet(pl.LightningModule):
15
+
16
+ def __init__(self, cfg):
17
+ super(talkNet, self).__init__()
18
+ self.cfg = cfg
19
+ self.model = talkNetModel()
20
+ self.lossAV = lossAV()
21
+ self.lossA = lossA()
22
+ self.lossV = lossV()
23
+ print(
24
+ time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
25
+ (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
26
+
27
+ def configure_optimizers(self):
28
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.SOLVER.BASE_LR)
29
+ scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
30
+ step_size=1,
31
+ gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
32
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
33
+
34
+ def training_step(self, batch, batch_idx):
35
+ audioFeature, visualFeature, labels = batch
36
+ audioEmbed = self.model.forward_audio_frontend(audioFeature[0]) # feedForward
37
+ visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
38
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
39
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
40
+ outsA = self.model.forward_audio_backend(audioEmbed)
41
+ outsV = self.model.forward_visual_backend(visualEmbed)
42
+ labels = labels[0].reshape((-1))
43
+ nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
44
+ nlossA = self.lossA.forward(outsA, labels)
45
+ nlossV = self.lossV.forward(outsV, labels)
46
+ loss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
47
+ self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
48
+
49
+ return loss
50
+
51
+ def training_epoch_end(self, training_step_outputs):
52
+ self.saveParameters(
53
+ os.path.join(self.cfg.WORKSPACE, "model", "{}.pth".format(self.current_epoch)))
54
+
55
+ def validation_step(self, batch, batch_idx):
56
+ audioFeature, visualFeature, labels, indices = batch
57
+ audioEmbed = self.model.forward_audio_frontend(audioFeature[0])
58
+ visualEmbed = self.model.forward_visual_frontend(visualFeature[0])
59
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
60
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
61
+ labels = labels[0].reshape((-1))
62
+ loss, predScore, _, _ = self.lossAV.forward(outsAV, labels)
63
+ predScore = predScore[:, -1:].detach().cpu().numpy()
64
+ # self.log("val_loss", loss)
65
+
66
+ return predScore
67
+
68
+ def validation_epoch_end(self, validation_step_outputs):
69
+ evalCsvSave = self.cfg.evalCsvSave
70
+ evalOrig = self.cfg.evalOrig
71
+ predScores = []
72
+
73
+ for out in validation_step_outputs: # batch size =1
74
+ predScores.extend(out)
75
+
76
+ evalLines = open(evalOrig).read().splitlines()[1:]
77
+ labels = []
78
+ labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
79
+ scores = pandas.Series(predScores)
80
+ evalRes = pandas.read_csv(evalOrig)
81
+ print(len(evalRes), len(predScores), len(evalLines))
82
+ evalRes['score'] = scores
83
+ evalRes['label'] = labels
84
+ evalRes.drop(['label_id'], axis=1, inplace=True)
85
+ evalRes.drop(['instance_id'], axis=1, inplace=True)
86
+ evalRes.to_csv(evalCsvSave, index=False)
87
+ cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
88
+ evalCsvSave)
89
+ mAP = float(
90
+ str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
91
+ print("validation mAP: {}".format(mAP))
92
+
93
+ def saveParameters(self, path):
94
+ torch.save(self.state_dict(), path)
95
+
96
+ def loadParameters(self, path):
97
+ selfState = self.state_dict()
98
+ loadedState = torch.load(path, map_location='cpu')
99
+ for name, param in loadedState.items():
100
+ origName = name
101
+ if name not in selfState:
102
+ name = name.replace("module.", "")
103
+ if name not in selfState:
104
+ print("%s is not in the model." % origName)
105
+ continue
106
+ if selfState[name].size() != loadedState[origName].size():
107
+ sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
108
+ (origName, selfState[name].size(), loadedState[origName].size()))
109
+ continue
110
+ selfState[name].copy_(param)
111
+
112
+ def evaluate_network(self, loader):
113
+ self.eval()
114
+ self.model = self.model.cuda()
115
+ self.lossAV = self.lossAV.cuda()
116
+ self.lossA = self.lossA.cuda()
117
+ self.lossV = self.lossV.cuda()
118
+ predScores = []
119
+ evalCsvSave = self.cfg.evalCsvSave
120
+ evalOrig = self.cfg.evalOrig
121
+ for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
122
+ with torch.no_grad():
123
+ audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
124
+ visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
125
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(
126
+ audioEmbed, visualEmbed)
127
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
128
+ labels = labels[0].reshape((-1)).cuda()
129
+ _, predScore, _, _ = self.lossAV.forward(outsAV, labels)
130
+ predScore = predScore[:, 1].detach().cpu().numpy()
131
+ predScores.extend(predScore)
132
+ evalLines = open(evalOrig).read().splitlines()[1:]
133
+ labels = []
134
+ labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
135
+ scores = pandas.Series(predScores)
136
+ evalRes = pandas.read_csv(evalOrig)
137
+ evalRes['score'] = scores
138
+ evalRes['label'] = labels
139
+ evalRes.drop(['label_id'], axis=1, inplace=True)
140
+ evalRes.drop(['instance_id'], axis=1, inplace=True)
141
+ evalRes.to_csv(evalCsvSave, index=False)
142
+ cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
143
+ evalCsvSave)
144
+ mAP = float(
145
+ str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
146
+ return mAP
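
The evaluation path shared by these legacy trainers boils down to attaching per-frame scores to the original AVA annotation CSV and calling the official scorer. Below is a condensed sketch of that step; `val_orig.csv` and `val_res.csv` are placeholder paths, and the stdout parsing mirrors the code above.

```
import subprocess
import pandas

def score_predictions(pred_scores, eval_orig='val_orig.csv', eval_csv_save='val_res.csv'):
    eval_res = pandas.read_csv(eval_orig)            # ground-truth rows, one per face-crop frame
    eval_res['score'] = pandas.Series(pred_scores)   # model scores, same order as the CSV rows
    eval_res['label'] = 'SPEAKING_AUDIBLE'           # dummy label column expected by the scorer
    eval_res = eval_res.drop(['label_id', 'instance_id'], axis=1)
    eval_res.to_csv(eval_csv_save, index=False)
    cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s" % (eval_orig, eval_csv_save)
    out = subprocess.run(cmd, shell=True, capture_output=True)
    return float(str(out.stdout).split(' ')[2][:5])  # same stdout parsing as evaluate_network above
```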
legacy/talkNet_orig.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import sys, time, numpy, os, subprocess, pandas, tqdm
6
+
7
+ from loss import lossAV, lossA, lossV
8
+ from model.talkNetModel import talkNetModel
9
+
10
+
11
+ class talkNet(nn.Module):
12
+
13
+ def __init__(self, lr=0.0001, lrDecay=0.95, **kwargs):
14
+ super(talkNet, self).__init__()
15
+ self.model = talkNetModel().cuda()
16
+ self.lossAV = lossAV().cuda()
17
+ self.lossA = lossA().cuda()
18
+ self.lossV = lossV().cuda()
19
+ self.optim = torch.optim.Adam(self.parameters(), lr=lr)
20
+ self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size=1, gamma=lrDecay)
21
+ print(
22
+ time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
23
+ (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
24
+
25
+ def train_network(self, loader, epoch, **kwargs):
26
+ self.train()
27
+ self.scheduler.step(epoch - 1)
28
+ index, top1, loss = 0, 0, 0
29
+ lr = self.optim.param_groups[0]['lr']
30
+ for num, (audioFeature, visualFeature, labels) in enumerate(loader, start=1):
31
+ self.zero_grad()
32
+ audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda()) # feedForward
33
+ visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
34
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
35
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
36
+ outsA = self.model.forward_audio_backend(audioEmbed)
37
+ outsV = self.model.forward_visual_backend(visualEmbed)
38
+ labels = labels[0].reshape((-1)).cuda() # Loss
39
+ nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels)
40
+ nlossA = self.lossA.forward(outsA, labels)
41
+ nlossV = self.lossV.forward(outsV, labels)
42
+ nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
43
+ loss += nloss.detach().cpu().numpy()
44
+ top1 += prec
45
+ nloss.backward()
46
+ self.optim.step()
47
+ index += len(labels)
48
+ sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
49
+ " [%2d] Lr: %5f, Training: %.2f%%, " %(epoch, lr, 100 * (num / loader.__len__())) + \
50
+ " Loss: %.5f, ACC: %2.2f%% \r" %(loss/(num), 100 * (top1/index)))
51
+ sys.stderr.flush()
52
+ sys.stdout.write("\n")
53
+ return loss / num, lr
54
+
55
+ def evaluate_network(self, loader, evalCsvSave, evalOrig, **kwargs):
56
+ self.eval()
57
+ predScores = []
58
+ for audioFeature, visualFeature, labels in tqdm.tqdm(loader):
59
+ with torch.no_grad():
60
+ audioEmbed = self.model.forward_audio_frontend(audioFeature[0].cuda())
61
+ visualEmbed = self.model.forward_visual_frontend(visualFeature[0].cuda())
62
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(
63
+ audioEmbed, visualEmbed)
64
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed)
65
+ labels = labels[0].reshape((-1)).cuda()
66
+ _, predScore, _, _ = self.lossAV.forward(outsAV, labels)
67
+ predScore = predScore[:, 1].detach().cpu().numpy()
68
+ predScores.extend(predScore)
69
+ evalLines = open(evalOrig).read().splitlines()[1:]
70
+ labels = []
71
+ labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
72
+ scores = pandas.Series(predScores)
73
+ evalRes = pandas.read_csv(evalOrig)
74
+ evalRes['score'] = scores
75
+ evalRes['label'] = labels
76
+ evalRes.drop(['label_id'], axis=1, inplace=True)
77
+ evalRes.drop(['instance_id'], axis=1, inplace=True)
78
+ evalRes.to_csv(evalCsvSave, index=False)
79
+ cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
80
+ evalCsvSave)
81
+ mAP = float(
82
+ str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
83
+ return mAP
84
+
85
+ def saveParameters(self, path):
86
+ torch.save(self.state_dict(), path)
87
+
88
+ def loadParameters(self, path):
89
+ selfState = self.state_dict()
90
+ loadedState = torch.load(path)
91
+ for name, param in loadedState.items():
92
+ origName = name
93
+ if name not in selfState:
94
+ name = name.replace("module.", "")
95
+ if name not in selfState:
96
+ print("%s is not in the model." % origName)
97
+ continue
98
+ if selfState[name].size() != loadedState[origName].size():
99
+ sys.stderr.write("Wrong parameter length: %s, model: %s, loaded: %s" %
100
+ (origName, selfState[name].size(), loadedState[origName].size()))
101
+ continue
102
+ selfState[name].copy_(param)
legacy/trainTalkNet_multicard.py ADDED
@@ -0,0 +1,171 @@
1
+ import time, os, torch, argparse, warnings, glob
2
+
3
+ from utils.tools import *
4
+ from dlhammer import bootstrap
5
+ import pytorch_lightning as pl
6
+ from pytorch_lightning import Trainer, seed_everything
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
9
+
10
+
11
+ class MyCollator(object):
12
+
13
+ def __init__(self, cfg):
14
+ self.cfg = cfg
15
+
16
+ def __call__(self, data):
17
+ audiofeatures = [item[0] for item in data]
18
+ visualfeatures = [item[1] for item in data]
19
+ labels = [item[2] for item in data]
20
+ masks = [item[3] for item in data]
21
+ cut_limit = self.cfg.MODEL.CLIP_LENGTH
22
+ # pad audio
23
+ lengths = torch.tensor([t.shape[1] for t in audiofeatures])
24
+ max_len = max(lengths)
25
+ padded_audio = torch.stack([
26
+ torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
27
+ for i in audiofeatures
28
+ ], 0)
29
+
30
+ if max_len > cut_limit * 4:
31
+ padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
32
+
33
+ # pad video
34
+ lengths = torch.tensor([t.shape[1] for t in visualfeatures])
35
+ max_len = max(lengths)
36
+ padded_video = torch.stack([
37
+ torch.cat(
38
+ [i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
39
+ for i in visualfeatures
40
+ ], 0)
41
+ padded_labels = torch.stack(
42
+ [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
43
+ padded_masks = torch.stack(
44
+ [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)
45
+
46
+ if max_len > cut_limit:
47
+ padded_video = padded_video[:, :, :cut_limit, ...]
48
+ padded_labels = padded_labels[:, :, :cut_limit, ...]
49
+ padded_masks = padded_masks[:, :, :cut_limit, ...]
50
+ return padded_audio, padded_video, padded_labels, padded_masks
51
+
52
+
53
+ class DataPrep(pl.LightningDataModule):
54
+
55
+ def __init__(self, cfg):
56
+ self.cfg = cfg
57
+
58
+ def train_dataloader(self):
59
+ cfg = self.cfg
60
+
61
+ if self.cfg.MODEL.NAME == "baseline":
62
+ from dataLoader import train_loader, val_loader
63
+ loader = train_loader(trialFileName = cfg.trainTrialAVA, \
64
+ audioPath = os.path.join(cfg.audioPathAVA , 'train'), \
65
+ visualPath = os.path.join(cfg.visualPathAVA, 'train'), \
66
+ batchSize=2500
67
+ )
68
+ elif self.cfg.MODEL.NAME == "multi":
69
+ from dataLoader_multiperson import train_loader, val_loader
70
+ loader = train_loader(trialFileName = cfg.trainTrialAVA, \
71
+ audioPath = os.path.join(cfg.audioPathAVA , 'train'), \
72
+ visualPath = os.path.join(cfg.visualPathAVA, 'train'), \
73
+ num_speakers=cfg.MODEL.NUM_SPEAKERS,
74
+ )
75
+ if cfg.MODEL.NAME == "baseline":
76
+ trainLoader = torch.utils.data.DataLoader(
77
+ loader,
78
+ batch_size=1,
79
+ shuffle=True,
80
+ num_workers=4,
81
+ )
82
+ elif cfg.MODEL.NAME == "multi":
83
+ collator = MyCollator(cfg)
84
+ trainLoader = torch.utils.data.DataLoader(loader,
85
+ batch_size=1,
86
+ shuffle=True,
87
+ num_workers=4,
88
+ collate_fn=collator)
89
+
90
+ return trainLoader
91
+
92
+ def val_dataloader(self):
93
+ cfg = self.cfg
94
+ loader = val_loader(trialFileName = cfg.evalTrialAVA, \
95
+ audioPath = os.path.join(cfg.audioPathAVA , cfg.evalDataType), \
96
+ visualPath = os.path.join(cfg.visualPathAVA, cfg.evalDataType), \
97
+ )
98
+ valLoader = torch.utils.data.DataLoader(loader,
99
+ batch_size=cfg.VAL.BATCH_SIZE,
100
+ shuffle=False,
101
+ num_workers=16)
102
+ return valLoader
103
+
104
+
105
+ def main():
106
+ # The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
107
+ cfg = bootstrap(print_cfg=False)
108
+ print(cfg)
109
+
110
+ warnings.filterwarnings("ignore")
111
+ seed_everything(42, workers=True)
112
+
113
+ cfg = init_args(cfg)
114
+
115
+ # checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(cfg.WORKSPACE, "model"),
116
+ # save_top_k=-1,
117
+ # filename='{epoch}')
118
+
119
+ data = DataPrep(cfg)
120
+
121
+ trainer = Trainer(
122
+ gpus=int(cfg.TRAIN.TRAINER_GPU),
123
+ precision=32,
124
+ # callbacks=[checkpoint_callback],
125
+ max_epochs=25,
126
+ replace_sampler_ddp=True)
127
+ # val_trainer = Trainer(deterministic=True, num_sanity_val_steps=-1, gpus=1)
128
+ if cfg.downloadAVA == True:
129
+ preprocess_AVA(cfg)
130
+ quit()
131
+
132
+ # if cfg.RESUME:
133
+ # modelfiles = glob.glob('%s/model_0*.model' % cfg.modelSavePath)
134
+ # modelfiles.sort()
135
+ # if len(modelfiles) >= 1:
136
+ # print("Model %s loaded from previous state!" % modelfiles[-1])
137
+ # epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
138
+ # s = talkNet(cfg)
139
+ # s.loadParameters(modelfiles[-1])
140
+ # else:
141
+ # epoch = 1
142
+ # s = talkNet(cfg)
143
+ epoch = 1
144
+ if cfg.MODEL.NAME == "baseline":
145
+ from talkNet_multicard import talkNet
146
+ elif cfg.MODEL.NAME == "multi":
147
+ from talkNet_multi import talkNet
148
+
149
+ s = talkNet(cfg)
150
+
151
+ # scoreFile = open(cfg.scoreSavePath, "a+")
152
+
153
+ trainer.fit(s, train_dataloaders=data.train_dataloader())
154
+
155
+ modelfiles = glob.glob('%s/*.pth' % os.path.join(cfg.WORKSPACE, "model"))
156
+
157
+ modelfiles.sort()
158
+ for path in modelfiles:
159
+ s.loadParameters(path)
160
+ prec = trainer.validate(s, data.val_dataloader())
161
+
162
+ # if epoch % cfg.testInterval == 0:
163
+ # s.saveParameters(cfg.modelSavePath + "/model_%04d.model" % epoch)
164
+ # trainer.validate(dataloaders=valLoader)
165
+ # print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, mAP %2.2f%%" % (epoch, mAPs[-1]))
166
+ # scoreFile.write("%d epoch, LOSS %f, mAP %2.2f%%\n" % (epoch, loss, mAPs[-1]))
167
+ # scoreFile.flush()
168
+
169
+
170
+ if __name__ == '__main__':
171
+ main()
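
Both training scripts rely on the same padding-collator idea: every sample carries a variable number of frames, so tensors are zero-padded along the time axis to the longest clip in the batch and then truncated at the clip limit (with 4x that budget for audio, which has four feature frames per video frame). A toy, self-contained sketch with illustrative shapes:

```
import torch

cut_limit = 200                                        # stands in for cfg.MODEL.CLIP_LENGTH
# two samples, each (num_speakers, T*4, n_mels) audio features of different length
audio = [torch.randn(3, 120 * 4, 13), torch.randn(3, 90 * 4, 13)]
max_len = max(a.shape[1] for a in audio)
padded_audio = torch.stack([
    torch.cat([a, a.new_zeros((a.shape[0], max_len - a.shape[1], a.shape[2]))], dim=1)
    for a in audio
], dim=0)
if max_len > cut_limit * 4:                            # truncate over-long clips
    padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
print(padded_audio.shape)                              # torch.Size([2, 3, 480, 13])
```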
legacy/train_multi.py ADDED
@@ -0,0 +1,156 @@
1
+ import time, os, torch, argparse, warnings, glob
2
+
3
+ from dataLoader_multiperson import train_loader, val_loader
4
+ from utils.tools import *
5
+ from talkNet_multi import talkNet
6
+
7
+
8
+ def collate_fn_padding(data):
9
+ audiofeatures = [item[0] for item in data]
10
+ visualfeatures = [item[1] for item in data]
11
+ labels = [item[2] for item in data]
12
+ masks = [item[3] for item in data]
13
+ cut_limit = 200
14
+ # pad audio
15
+ lengths = torch.tensor([t.shape[1] for t in audiofeatures])
16
+ max_len = max(lengths)
17
+ padded_audio = torch.stack([
18
+ torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2]))], 1)
19
+ for i in audiofeatures
20
+ ], 0)
21
+
22
+ if max_len > cut_limit * 4:
23
+ padded_audio = padded_audio[:, :, :cut_limit * 4, ...]
24
+
25
+ # pad video
26
+ lengths = torch.tensor([t.shape[1] for t in visualfeatures])
27
+ max_len = max(lengths)
28
+ padded_video = torch.stack([
29
+ torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1], i.shape[2], i.shape[3]))], 1)
30
+ for i in visualfeatures
31
+ ], 0)
32
+ padded_labels = torch.stack(
33
+ [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in labels], 0)
34
+ padded_masks = torch.stack(
35
+ [torch.cat([i, i.new_zeros((i.shape[0], max_len - i.shape[1]))], 1) for i in masks], 0)
36
+
37
+ if max_len > cut_limit:
38
+ padded_video = padded_video[:, :, :cut_limit, ...]
39
+ padded_labels = padded_labels[:, :, :cut_limit, ...]
40
+ padded_masks = padded_masks[:, :, :cut_limit, ...]
41
+ # print(padded_audio.shape, padded_video.shape, padded_labels.shape, padded_masks.shape)
42
+ return padded_audio, padded_video, padded_labels, padded_masks
43
+
44
+
45
+ def main():
46
+ # The structure of this code is learnt from https://github.com/clovaai/voxceleb_trainer
47
+ warnings.filterwarnings("ignore")
48
+
49
+ parser = argparse.ArgumentParser(description="TalkNet Training")
50
+ # Training details
51
+ parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
52
+ parser.add_argument('--lrDecay', type=float, default=0.95, help='Learning rate decay rate')
53
+ parser.add_argument('--maxEpoch', type=int, default=25, help='Maximum number of epochs')
54
+ parser.add_argument('--testInterval',
55
+ type=int,
56
+ default=1,
57
+ help='Test and save every [testInterval] epochs')
58
+ parser.add_argument(
59
+ '--batchSize',
60
+ type=int,
61
+ default=2500,
62
+ help=
63
+ 'Dynamic batch size, default is 2500 frames, other batchsize (such as 1500) will not affect the performance'
64
+ )
65
+ parser.add_argument('--batch_size', type=int, default=1, help='batch_size')
66
+ parser.add_argument('--num_speakers', type=int, default=5, help='num_speakers')
67
+ parser.add_argument('--nDataLoaderThread', type=int, default=4, help='Number of loader threads')
68
+ # Data path
69
+ parser.add_argument('--dataPathAVA',
70
+ type=str,
71
+ default="/data08/AVA",
72
+ help='Save path of AVA dataset')
73
+ parser.add_argument('--savePath', type=str, default="exps/exp1")
74
+ # Data selection
75
+ parser.add_argument('--evalDataType',
76
+ type=str,
77
+ default="val",
78
+ help='Only for AVA, to choose the dataset for evaluation, val or test')
79
+ # For download dataset only, for evaluation only
80
+ parser.add_argument('--downloadAVA',
81
+ dest='downloadAVA',
82
+ action='store_true',
83
+ help='Only download AVA dataset and do related preprocess')
84
+ parser.add_argument('--evaluation',
85
+ dest='evaluation',
86
+ action='store_true',
87
+ help='Only do evaluation by using pretrained model [pretrain_AVA.model]')
88
+ args = parser.parse_args()
89
+ # Data loader
90
+ args = init_args(args)
91
+
92
+ if args.downloadAVA == True:
93
+ preprocess_AVA(args)
94
+ quit()
95
+
96
+ loader = train_loader(trialFileName = args.trainTrialAVA, \
97
+ audioPath = os.path.join(args.audioPathAVA , 'train'), \
98
+ visualPath = os.path.join(args.visualPathAVA, 'train'), \
99
+ # num_speakers = args.num_speakers, \
100
+ **vars(args))
101
+ trainLoader = torch.utils.data.DataLoader(loader,
102
+ batch_size=args.batch_size,
103
+ shuffle=True,
104
+ num_workers=args.nDataLoaderThread,
105
+ collate_fn=collate_fn_padding)
106
+
107
+ loader = val_loader(trialFileName = args.evalTrialAVA, \
108
+ audioPath = os.path.join(args.audioPathAVA , args.evalDataType), \
109
+ visualPath = os.path.join(args.visualPathAVA, args.evalDataType), \
110
+ # num_speakers = args.num_speakers, \
111
+ **vars(args))
112
+ valLoader = torch.utils.data.DataLoader(loader, batch_size=1, shuffle=False, num_workers=16)
113
+
114
+ if args.evaluation == True:
115
+ download_pretrain_model_AVA()
116
+ s = talkNet(**vars(args))
117
+ s.loadParameters('pretrain_AVA.model')
118
+ print("Model %s loaded from previous state!" % ('pretrain_AVA.model'))
119
+ mAP = s.evaluate_network(loader=valLoader, **vars(args))
120
+ print("mAP %2.2f%%" % (mAP))
121
+ quit()
122
+
123
+ modelfiles = glob.glob('%s/model_0*.model' % args.modelSavePath)
124
+ modelfiles.sort()
125
+ if len(modelfiles) >= 1:
126
+ print("Model %s loaded from previous state!" % modelfiles[-1])
127
+ epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
128
+ s = talkNet(epoch=epoch, **vars(args))
129
+ s.loadParameters(modelfiles[-1])
130
+ else:
131
+ epoch = 1
132
+ s = talkNet(epoch=epoch, **vars(args))
133
+
134
+ mAPs = []
135
+ scoreFile = open(args.scoreSavePath, "a+")
136
+
137
+ while (1):
138
+ loss, lr = s.train_network(epoch=epoch, loader=trainLoader, **vars(args))
139
+
140
+ if epoch % args.testInterval == 0:
141
+ s.saveParameters(args.modelSavePath + "/model_%04d.model" % epoch)
142
+ mAPs.append(s.evaluate_network(epoch=epoch, loader=valLoader, **vars(args)))
143
+ print(time.strftime("%Y-%m-%d %H:%M:%S"),
144
+ "%d epoch, mAP %2.2f%%, bestmAP %2.2f%%" % (epoch, mAPs[-1], max(mAPs)))
145
+ scoreFile.write("%d epoch, LR %f, LOSS %f, mAP %2.2f%%, bestmAP %2.2f%%\n" %
146
+ (epoch, lr, loss, mAPs[-1], max(mAPs)))
147
+ scoreFile.flush()
148
+
149
+ if epoch >= args.maxEpoch:
150
+ quit()
151
+
152
+ epoch += 1
153
+
154
+
155
+ if __name__ == '__main__':
156
+ main()
loconet.py ADDED
@@ -0,0 +1,182 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import sys, time, numpy, os, subprocess, pandas, tqdm
6
+
7
+ from loss_multi import lossAV, lossA, lossV
8
+ from model.loconet_encoder import locoencoder
9
+
10
+ import torch.distributed as dist
11
+ from xxlib.utils.distributed import all_gather, all_reduce
12
+
13
+
14
+ class Loconet(nn.Module):
15
+
16
+ def __init__(self, cfg):
17
+ super(Loconet, self).__init__()
18
+ self.cfg = cfg
19
+ self.model = locoencoder(cfg)
20
+ self.lossAV = lossAV()
21
+ self.lossA = lossA()
22
+ self.lossV = lossV()
23
+
24
+ def forward(self, audioFeature, visualFeature, labels, masks):
25
+ b, s, t = visualFeature.shape[:3]
26
+ visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
27
+ labels = labels.view(b * s, *labels.shape[2:])
28
+ masks = masks.view(b * s, *masks.shape[2:])
29
+
30
+ audioEmbed = self.model.forward_audio_frontend(audioFeature) # B, C, T, 4
31
+ visualEmbed = self.model.forward_visual_frontend(visualFeature)
32
+ audioEmbed = audioEmbed.repeat(s, 1, 1)
33
+
34
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(audioEmbed, visualEmbed)
35
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
36
+ outsA = self.model.forward_audio_backend(audioEmbed)
37
+ outsV = self.model.forward_visual_backend(visualEmbed)
38
+
39
+ labels = labels.reshape((-1))
40
+ masks = masks.reshape((-1))
41
+ nlossAV, _, _, prec = self.lossAV.forward(outsAV, labels, masks)
42
+ nlossA = self.lossA.forward(outsA, labels, masks)
43
+ nlossV = self.lossV.forward(outsV, labels, masks)
44
+
45
+ nloss = nlossAV + 0.4 * nlossA + 0.4 * nlossV
46
+
47
+ num_frames = masks.sum()
48
+ return nloss, prec, num_frames
49
+
50
+
51
+ class loconet(nn.Module):
52
+
53
+ def __init__(self, cfg, rank=None, device=None):
54
+ super(loconet, self).__init__()
55
+ self.cfg = cfg
56
+ self.rank = rank
57
+ if rank != None:
58
+ self.rank = rank
59
+ self.device = device
60
+
61
+ self.model = Loconet(cfg).to(device)
62
+ self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
63
+ self.model = nn.parallel.DistributedDataParallel(self.model,
64
+ device_ids=[rank],
65
+ output_device=rank,
66
+ find_unused_parameters=False)
67
+ self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.SOLVER.BASE_LR)
68
+ self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim,
69
+ step_size=1,
70
+ gamma=self.cfg.SOLVER.SCHEDULER.GAMMA)
71
+ else:
72
+ self.model = locoencoder(cfg).cuda()
73
+ self.lossAV = lossAV().cuda()
74
+ self.lossA = lossA().cuda()
75
+ self.lossV = lossV().cuda()
76
+
77
+ print(
78
+ time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2f" %
79
+ (sum(param.numel() for param in self.model.parameters()) / 1024 / 1024))
80
+
81
+ def train_network(self, epoch, loader):
82
+ self.model.train()
83
+ self.scheduler.step(epoch - 1)
84
+ index, top1, loss = 0, 0, 0
85
+ lr = self.optim.param_groups[0]['lr']
86
+ loader.sampler.set_epoch(epoch)
87
+ device = self.device
88
+
89
+ pbar = enumerate(loader, start=1)
90
+ if self.rank == 0:
91
+ pbar = tqdm.tqdm(pbar, total=loader.__len__())
92
+
93
+ for num, (audioFeature, visualFeature, labels, masks) in pbar:
94
+
95
+ audioFeature = audioFeature.to(device)
96
+ visualFeature = visualFeature.to(device)
97
+ labels = labels.to(device)
98
+ masks = masks.to(device)
99
+ nloss, prec, num_frames = self.model(
100
+ audioFeature,
101
+ visualFeature,
102
+ labels,
103
+ masks,
104
+ )
105
+
106
+ self.optim.zero_grad()
107
+ nloss.backward()
108
+ self.optim.step()
109
+
110
+ [nloss, prec, num_frames] = all_reduce([nloss, prec, num_frames], average=False)
111
+ top1 += prec.detach().cpu().numpy()
112
+ loss += nloss.detach().cpu().numpy()
113
+ index += int(num_frames.detach().cpu().item())
114
+ if self.rank == 0:
115
+ pbar.set_postfix(
116
+ dict(epoch=epoch,
117
+ lr=lr,
118
+ loss=loss / (num * self.cfg.NUM_GPUS),
119
+ acc=(top1 / index)))
120
+ dist.barrier()
121
+ return loss / num, lr
122
+
123
+ def evaluate_network(self, epoch, loader):
124
+ self.eval()
125
+ predScores = []
126
+ evalCsvSave = os.path.join(self.cfg.WORKSPACE, "{}_res.csv".format(epoch))
127
+ evalOrig = self.cfg.evalOrig
128
+ for audioFeature, visualFeature, labels, masks in tqdm.tqdm(loader):
129
+ with torch.no_grad():
130
+ audioFeature = audioFeature.cuda()
131
+ visualFeature = visualFeature.cuda()
132
+ labels = labels.cuda()
133
+ masks = masks.cuda()
134
+ b, s, t = visualFeature.shape[0], visualFeature.shape[1], visualFeature.shape[2]
135
+ visualFeature = visualFeature.view(b * s, *visualFeature.shape[2:])
136
+ labels = labels.view(b * s, *labels.shape[2:])
137
+ masks = masks.view(b * s, *masks.shape[2:])
138
+ audioEmbed = self.model.forward_audio_frontend(audioFeature)
139
+ visualEmbed = self.model.forward_visual_frontend(visualFeature)
140
+ audioEmbed = audioEmbed.repeat(s, 1, 1)
141
+ audioEmbed, visualEmbed = self.model.forward_cross_attention(
142
+ audioEmbed, visualEmbed)
143
+ outsAV = self.model.forward_audio_visual_backend(audioEmbed, visualEmbed, b, s)
144
+ labels = labels.reshape((-1))
145
+ masks = masks.reshape((-1))
146
+ outsAV = outsAV.view(b, s, t, -1)[:, 0, :, :].view(b * t, -1)
147
+ labels = labels.view(b, s, t)[:, 0, :].view(b * t).cuda()
148
+ masks = masks.view(b, s, t)[:, 0, :].view(b * t)
149
+ _, predScore, _, _ = self.lossAV.forward(outsAV, labels, masks)
150
+ predScore = predScore[:, 1].detach().cpu().numpy()
151
+ predScores.extend(predScore)
152
+ evalLines = open(evalOrig).read().splitlines()[1:]
153
+ labels = []
154
+ labels = pandas.Series(['SPEAKING_AUDIBLE' for line in evalLines])
155
+ scores = pandas.Series(predScores)
156
+ evalRes = pandas.read_csv(evalOrig)
157
+ evalRes['score'] = scores
158
+ evalRes['label'] = labels
159
+ evalRes.drop(['label_id'], axis=1, inplace=True)
160
+ evalRes.drop(['instance_id'], axis=1, inplace=True)
161
+ evalRes.to_csv(evalCsvSave, index=False)
162
+ cmd = "python -O utils/get_ava_active_speaker_performance.py -g %s -p %s " % (evalOrig,
163
+ evalCsvSave)
164
+ mAP = float(
165
+ str(subprocess.run(cmd, shell=True, capture_output=True).stdout).split(' ')[2][:5])
166
+ return mAP
167
+
168
+ def saveParameters(self, path):
169
+ torch.save(self.state_dict(), path)
170
+
171
+ def loadParameters(self, path):
172
+ selfState = self.state_dict()
173
+ loadedState = torch.load(path, map_location='cpu')
174
+ if self.rank != None:
175
+ info = self.load_state_dict(loadedState)
176
+ else:
177
+ new_state = {}
178
+
179
+ for k, v in loadedState.items():
180
+ new_state[k.replace("model.module.", "")] = v
181
+ info = self.load_state_dict(new_state, strict=False)
182
+ print(info)
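
For reference, a hedged single-GPU evaluation sketch of the wrapper above; `cfg`, the checkpoint path, and `val_loader` are placeholders that must be built the same way as in the training entry point.

```
from loconet import loconet

model = loconet(cfg)                              # rank=None -> plain encoder + losses on one GPU
model.loadParameters('path/to/checkpoint.pth')    # placeholder path; strips the "model.module." prefix
mAP = model.evaluate_network(epoch=0, loader=val_loader)
print('mAP: %.2f%%' % mAP)
```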
loss_multi.py ADDED
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import utils.distributed as du
+
+
+class lossAV(nn.Module):
+
+    def __init__(self):
+        super(lossAV, self).__init__()
+        self.criterion = nn.CrossEntropyLoss(reduction='none')
+        self.FC = nn.Linear(256, 2)
+
+    def forward(self, x, labels=None, masks=None):
+        x = x.squeeze(1)
+        x = self.FC(x)
+        if labels == None:
+            predScore = x[:, 1]
+            predScore = predScore.t()
+            predScore = predScore.view(-1).detach().cpu().numpy()
+            return predScore
+        else:
+            nloss = self.criterion(x, labels) * masks
+
+            num_valid = masks.sum().float()
+            if self.training:
+                [num_valid] = du.all_reduce([num_valid], average=True)
+            nloss = torch.sum(nloss) / num_valid
+
+            predScore = F.softmax(x, dim=-1)
+            predLabel = torch.round(F.softmax(x, dim=-1))[:, 1]
+            correctNum = ((predLabel == labels) * masks).sum().float()
+            return nloss, predScore, predLabel, correctNum
+
+
+class lossA(nn.Module):
+
+    def __init__(self):
+        super(lossA, self).__init__()
+        self.criterion = nn.CrossEntropyLoss(reduction='none')
+        self.FC = nn.Linear(128, 2)
+
+    def forward(self, x, labels, masks=None):
+        x = x.squeeze(1)
+        x = self.FC(x)
+        nloss = self.criterion(x, labels) * masks
+        num_valid = masks.sum().float()
+        if self.training:
+            [num_valid] = du.all_reduce([num_valid], average=True)
+        nloss = torch.sum(nloss) / num_valid
+        #nloss = torch.sum(nloss) / torch.sum(masks)
+        return nloss
+
+
+class lossV(nn.Module):
+
+    def __init__(self):
+        super(lossV, self).__init__()
+
+        self.criterion = nn.CrossEntropyLoss(reduction='none')
+        self.FC = nn.Linear(128, 2)
+
+    def forward(self, x, labels, masks=None):
+        x = x.squeeze(1)
+        x = self.FC(x)
+        nloss = self.criterion(x, labels) * masks
+        # nloss = torch.sum(nloss) / torch.sum(masks)
+        num_valid = masks.sum().float()
+        if self.training:
+            [num_valid] = du.all_reduce([num_valid], average=True)
+        nloss = torch.sum(nloss) / num_valid
+        return nloss
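
A shape-level sketch of the masked losses, assuming the repo's `utils.distributed` module is importable: logits come from 256-d fused features per frame, and `masks` zeroes out padded frames so they contribute neither to the loss normalisation nor to the accuracy count.

```
import torch
from loss_multi import lossAV

criterion = lossAV()
criterion.eval()                       # eval mode skips the distributed all_reduce branch
feats = torch.randn(10, 256)           # 10 frames of fused audio-visual features
labels = torch.randint(0, 2, (10,))    # per-frame speaking / not-speaking targets
masks = torch.ones(10)
masks[8:] = 0                          # last two frames are padding and are ignored

nloss, pred_score, pred_label, correct = criterion(feats, labels, masks)
print(float(nloss), float(correct))
```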
metrics/AverageMeter.py ADDED
@@ -0,0 +1,18 @@
+#taken from pytorch imagenet example
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
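
A typical training-loop sketch for `AverageMeter` (run from the repository root so `metrics` is importable); `n` is the number of samples behind each batch value, keeping `avg` sample-weighted.

```
from metrics.AverageMeter import AverageMeter

loss_meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    loss_meter.update(batch_loss, n=batch_size)
print('last batch: %.3f  running average: %.3f' % (loss_meter.val, loss_meter.avg))
```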
metrics/__pycache__/.nfs000000035f4a8257000000eb ADDED
Binary file (896 Bytes). View file
 
metrics/__pycache__/AverageMeter.cpython-36.pyc ADDED
Binary file (897 Bytes). View file
 
metrics/__pycache__/AverageMeter.cpython-38.pyc ADDED
Binary file (908 Bytes). View file
 
metrics/__pycache__/accuracy.cpython-36.pyc ADDED
Binary file (870 Bytes). View file
 
metrics/__pycache__/accuracy.cpython-38.pyc ADDED
Binary file (876 Bytes). View file
 
metrics/accuracy.py ADDED
@@ -0,0 +1,20 @@
+import torch
+
+accuracy = lambda output,target : acc_topk(output, target)[0]
+
+#taken from pytorch imagenet example
+def acc_topk(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(1.0 / batch_size))
+        return res
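
A quick sketch of `acc_topk` on dummy logits: it returns one tensor per requested `k`, each holding the fraction of samples whose target falls in the top-k predictions.

```
import torch
from metrics.accuracy import acc_topk, accuracy

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
targets = torch.tensor([1, 0, 0])
(top1,) = acc_topk(logits, targets, topk=(1,))
print(float(top1))                       # 2 of 3 targets match the argmax here -> ~0.667
print(float(accuracy(logits, targets)))  # same value via the convenience lambda
```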
model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/__init__.py ADDED
@@ -0,0 +1,5 @@
+from model.transformer.position_encoding import PositionalEncoding
+from model.transformer.transformer import Transformer
+from model.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+from model.transformer.transformer import TransformerDecoder, TransformerDecoderLayer
+from model.transformer.utils import layer_norm, generate_square_subsequent_mask, generate_proposal_mask
model/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (561 Bytes). View file
 
model/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (573 Bytes). View file
 
model/__pycache__/attentionLayer.cpython-37.pyc ADDED
Binary file (1.38 kB). View file
 
model/__pycache__/convLayer.cpython-37.pyc ADDED
Binary file (1.32 kB). View file
 
model/__pycache__/loconet_encoder.cpython-37.pyc ADDED
Binary file (3.21 kB). View file
 
model/__pycache__/position_encoding.cpython-36.pyc ADDED
Binary file (1.26 kB). View file
 
model/__pycache__/talkNetModel.cpython-37.pyc ADDED
Binary file (6.33 kB). View file
 
model/__pycache__/transformer.cpython-36.pyc ADDED
Binary file (8.84 kB). View file
 
model/__pycache__/utils.cpython-36.pyc ADDED
Binary file (1.08 kB). View file