studyfar committed
Commit
41f97d1
1 Parent(s): 1559c7d
bvh2ts.py ADDED
@@ -0,0 +1,69 @@
+ from imusim.all import *
+ import imusim
+ import numpy as np
+ from tqdm import tqdm
+ import multiprocessing
+ import os
+
+ with open('./bvh/000000.bvh', 'r') as file:
+     lines = file.readlines()
+ line_109 = lines[108]
+ frame_time = line_109.split(': ')[1].strip()
+ frame_time_value = float(frame_time)
+ print(frame_time_value)
+
+ def process_file(f):
+
+     imu_file_path = './output/%s.npy' % f
+     if not os.path.exists(imu_file_path):
+
+         samplingPeriod = frame_time_value
+         imu = Orient3IMU()
+         env = Environment()
+
+         samples = 1000
+         rotationalVelocity = 20
+         calibrator = ScaleAndOffsetCalibrator(env, samples, samplingPeriod, rotationalVelocity)
+         calibration = calibrator.calibrate(imu)
+
+         try:
+             model = loadBVHFile('./bvh/%s.bvh' % f)
+             splinedModel = SplinedBodyModel(model)
+
+             imu_list = []
+             for i in range(22):
+                 sim = Simulation(environment=env)
+                 imu.simulation = sim
+
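+                 # Joints 4, 8, 13, 17 and 21 appear to be chain ends of the 22-joint
+                 # skeleton, so the IMU is attached to the parent joint's end site rather
+                 # than to a joint (comment added for clarity; inferred from the code).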
+                 if i not in [4, 8, 13, 17, 21]:
+                     imu.trajectory = splinedModel.getJoint('joint_%s' % str(i))
+                 else:
+                     imu.trajectory = splinedModel.getPoint('joint_%s_end' % str(i-1))
+
+                 sim.time = splinedModel.startTime
+                 BasicIMUBehaviour(imu, samplingPeriod, calibration, initialTime=sim.time)
+                 sim.run(splinedModel.endTime, printProgress=False)
+
+                 acc = imu.accelerometer.calibratedMeasurements.values
+                 gyro = imu.gyroscope.calibratedMeasurements.values
+
+                 imu_npy = np.concatenate((acc, gyro), axis=0)
+                 imu_list.append(imu_npy)
+
+             imu_npy = np.stack(imu_list, axis=1).transpose(2, 1, 0)
+             np.save('./output/%s' % f, imu_npy)
+
+         except (imusim.maths.splines.Spline.InsufficientPointsError, AttributeError, IndexError) as e:
+             print(f"Error processing file {f}: {e}. Skipping.")
+             with open('log.txt', 'a') as log_file:
+                 log_file.write(f + '\n')
+
+ source_dir = './bvh'
+ npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.bvh')]
+
+ # Process files in parallel
+ pool = multiprocessing.Pool(processes=8)
+ for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
+     pass
+ pool.close()
+ pool.join()
checkpoint/UniMTS.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9858c0084d936655240407e30ff9db9adeded6a67dc5650e3f667578e93b220
+ size 274583082
contrastive.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+ import clip
+ from model import ST_GCN_18
+
+ class ContrastiveModule(nn.Module):
+
+     def __init__(self, args):
+         super(ContrastiveModule, self).__init__()
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model, preprocess = clip.load("ViT-B/32", device=device)
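+         # Only CLIP's text tower is kept: the vision encoder is deleted and IMU
+         # sequences are encoded by an ST-GCN instead (attached as model.acc below).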
+         del model.visual
+         self.model = model
+
+         base_channel = 3
+         base_channel = base_channel * 2 if args.gyro else base_channel
+         base_channel = base_channel * 2 if args.stft else base_channel
+         self.model.acc = ST_GCN_18(in_channels=base_channel)
+
+         self.model = self.model.float()
+
+         if args.stage == 'finetune':
+             self.fc = nn.Linear(512, args.num_class)
+
+     def encode_image(self, image):
+         return self.model.acc(image.float()).squeeze(-1).squeeze(-1)
+
+     def encode_text(self, text):
+         x = self.model.token_embedding(text).float()  # b,t,512
+         x = x + self.model.positional_embedding.float()
+         x = x.permute(1, 0, 2)  # b,t,512 -> t,b,512
+         x = self.model.transformer(x)
+         x = x.permute(1, 0, 2)  # t,b,512 -> b,t,512
+         x = self.model.ln_final(x).float()  # b,t,512
+
+         # take features from the eot embedding (eot_token is the highest number in each sequence)
+         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection  # b,512
+
+         return x
+
+     def classifier(self, image):
+         # for fine-tuning
+         imu_features = self.model.acc(image.float()).squeeze(-1).squeeze(-1)
+         out = self.fc(imu_features)
+         return out
+
+     def forward(self, inputs_imu, inputs_text):
+
+         imu_features = self.encode_image(inputs_imu)
+         text_features = self.encode_text(inputs_text)
+
+         # normalized features
+         imu_features = imu_features / imu_features.norm(dim=-1, keepdim=True)
+         text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+         # logits
+         logit_scale = self.model.logit_scale.exp()
+         logits_per_image = logit_scale * imu_features @ text_features.t()
+         logits_per_text = logits_per_image.t()
+
+         return logits_per_image, logits_per_text
data.py ADDED
@@ -0,0 +1,321 @@
+ import torch
+ import numpy as np
+ import random
+ import os
+ import json
+ from scipy.signal import resample
+ import clip
+ from torch.utils.data import Dataset
+
+ class CLIPDataset(Dataset):
+
+     def __init__(self, args):
+
+         imu_dirs = [
+             f'{args.data_path}/sim/',
+         ]
+         text_dirs = [
+             f'{args.data_path}/aug_texts/',
+         ]
+         self.paths = []
+         for imu_dir, text_dir in zip(imu_dirs, text_dirs):
+             imu_files = [f.split('.')[0] for f in os.listdir(imu_dir) if os.path.isfile(os.path.join(imu_dir, f))]
+             text_files = [f.split('.')[0] for f in os.listdir(text_dir) if os.path.isfile(os.path.join(text_dir, f))]
+             common_files = [f for f in imu_files if f in text_files]
+             for f in common_files:
+                 self.paths.append((os.path.join(imu_dir, f + '.npy'), os.path.join(text_dir, f + '.txt')))
+
+         self.args = args
+         if args.sample < 1:
+             self.paths = random.sample(self.paths, int(len(self.paths) * args.sample))
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, idx):
+
+         # load imu
+         imu_path, text_path = self.paths[idx]
+         imu = np.load(imu_path)
+         imu[np.isnan(imu)] = 0
+
+         # padding
+         if len(imu) < self.args.padding_size:
+             imu = np.pad(imu, ((0, self.args.padding_size - len(imu)), (0, 0), (0, 0)), mode='wrap')
+         imu = imu[:self.args.padding_size]
+
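+         # Randomly keep k of the 22 skeleton joints and zero out the rest, mimicking
+         # sparse real-world device placements (comment added for clarity; the mask
+         # is applied to the IMU tensor during pre-training).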
+         # random masking
+         mask = np.zeros_like(imu)
+         k = np.random.randint(1, 6)  # randomly select k joints
+         selected_joints = np.random.choice(22, k, replace=False)
+         mask[:, selected_joints] = 1
+         imu = imu.reshape(len(imu), -1)
+         mask = mask.reshape(len(mask), -1)
+
+         # load text
+         with open(text_path, 'r') as file:
+             lines = file.readlines()
+
+         text = random.choice(lines).split('#')[0].strip()  # remove the comment starting from "#"
+
+         batch = {}
+         batch['imu'] = imu
+         batch['text'] = text
+         batch['mask'] = mask
+
+         return batch
+
+ def select_samples(data, masks, labels, k, name, data_path):
+     unique_labels = torch.unique(labels)
+     selected_data = []
+     selected_masks = []
+     selected_labels = []
+     all_indices = torch.load(f'{data_path}/few_shot_data_2/{name}_k={k}.pth')
+
+     for i, label in enumerate(unique_labels):
+         selected_indices = all_indices[i]
+         selected_data.append(data[selected_indices])
+         selected_masks.append(masks[selected_indices])
+         selected_labels.append(labels[selected_indices])
+
+     selected_data = torch.cat(selected_data, dim=0)
+     selected_masks = torch.cat(selected_masks, dim=0)
+     selected_labels = torch.cat(selected_labels, dim=0)
+
+     return selected_data, selected_masks, selected_labels
+
+ def load(dataset, padding_size, data_path, split='test', k=None):
+
+     print(dataset)
+
+     X = np.load(f'{data_path}/{dataset}/X_{split}.npy')
+     real_labels = torch.from_numpy(np.load(f'{data_path}/{dataset}/y_{split}.npy'))
+     with open(f'{data_path}/{dataset}/{dataset}.json', 'r') as file:
+         data = json.load(file)
+     all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))
+
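+     # Remap each dataset's raw channels into the canonical (N, T, 22 joints, 6 ch)
+     # layout of acc xyz + gyro xyz per joint, converting g / milli-g / deg/s into
+     # m/s^2 and rad/s where needed (comment added for clarity).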
+     if dataset == 'PAMAP':
+         all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,11] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
+         all_X[:,:,7] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
+         original_sampling_rate = 100
+         num_classes = 12
+
+     elif dataset == 'USCHAD':
+         all_X[:,:,5] = np.concatenate((X[:,:,0:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
+         original_sampling_rate = 100
+         num_classes = 12
+
+     elif dataset == 'UCIHAR':
+         all_X[:,:,9] = np.concatenate((X[:,:,6:9] * 9.80665, X[:,:,3:6]), axis=-1)  # linear accel, gyro, total accel
+         original_sampling_rate = 50
+         num_classes = 6
+
+     elif dataset == 'Opp_g':
+         all_X[:,:,10] = np.concatenate((X[:,:,0:3] / 1000 * 9.8, X[:,:,3:6] / 1000), axis=-1)  # convert unit from milli g to m/s^2
+         all_X[:,:,19] = np.concatenate((X[:,:,9:12] / 1000 * 9.8, X[:,:,12:15] / 1000), axis=-1)
+         all_X[:,:,20] = np.concatenate((X[:,:,18:21] / 1000 * 9.8, X[:,:,21:24] / 1000), axis=-1)
+         all_X[:,:,15] = np.concatenate((X[:,:,27:30] / 1000 * 9.8, X[:,:,30:33] / 1000), axis=-1)
+         all_X[:,:,16] = np.concatenate((X[:,:,36:39] / 1000 * 9.8, X[:,:,39:42] / 1000), axis=-1)
+         original_sampling_rate = 30
+         num_classes = 4  # locomotion
+
+     elif dataset == 'WISDM':
+         all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
+         original_sampling_rate = 20
+         num_classes = 18
+
+     elif dataset == 'DSADS':
+         all_X[:,:,11] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,21] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
+         all_X[:,:,17] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
+         all_X[:,:,6] = np.concatenate((X[:,:,27:30], X[:,:,30:33]), axis=-1)
+         all_X[:,:,2] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
+         original_sampling_rate = 25
+         num_classes = 19
+
+     elif dataset == 'Harth':
+         all_X[:,:,9,:3] = X[:,:,:3] * 9.80665
+         all_X[:,:,6,:3] = X[:,:,3:6] * 9.80665
+         original_sampling_rate = 50
+         num_classes = 12
+
+     elif dataset == 'Wharf':
+         X = -14.709 + X / 63 * (2 * 14.709)
+         all_X[:,:,21,:3] = X
+         original_sampling_rate = 32
+         num_classes = 14
+
+     elif dataset == 'Mhealth':
+         all_X[:,:,11,:3] = X[:,:,0:3]
+         all_X[:,:,3] = np.concatenate((X[:,:,6:9], X[:,:,9:12] / 180 * np.pi), axis=-1)
+         all_X[:,:,21] = np.concatenate((X[:,:,15:18], X[:,:,18:21] / 180 * np.pi), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 12
+
+     elif dataset == 'UTD-MHAD':
+         all_X[real_labels < 21,:,21,:] = np.concatenate((X[real_labels < 21,:,0:3] * 9.80665, X[real_labels < 21,:,3:6] / 180 * np.pi), axis=-1)
+         all_X[real_labels >= 21,:,5,:] = np.concatenate((X[real_labels >= 21,:,0:3] * 9.80665, X[real_labels >= 21,:,3:6] / 180 * np.pi), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 27
+
+     elif dataset == 'MotionSense':
+         all_X[:,:,5] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
+         all_X[:,:,1] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 6
+
+     elif dataset == 'w-HAR':
+         all_X[:,:,7] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
+         original_sampling_rate = 250
+         num_classes = 7
+
+     elif dataset == 'Shoaib':
+         all_X[:,:,1] = X[:,:,:6]
+         all_X[:,:,5] = X[:,:,6:12]
+         all_X[:,:,21] = X[:,:,12:18]
+         all_X[:,:,20] = X[:,:,18:24]
+         all_X[:,:,0] = X[:,:,24:30]
+         original_sampling_rate = 50
+         num_classes = 7
+
+     elif dataset == 'har70plus':
+         all_X[:,:,0,:3] = X[:,:,:3] * 9.80665
+         all_X[:,:,5,:3] = X[:,:,3:6] * 9.80665
+         original_sampling_rate = 50
+         num_classes = 7
+
+     elif dataset == 'MMAct':
+         all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,21,:3] = X[:,:,6:9]
+         original_sampling_rate = 50
+         num_classes = 35
+
+     elif dataset == 'realworld':
+         all_X[:,:,14] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,16] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
+         all_X[:,:,13] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
+         all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
+         all_X[:,:,1] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
+         all_X[:,:,15] = np.concatenate((X[:,:,30:33], X[:,:,33:36]), axis=-1)
+         all_X[:,:,9] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 8
+
+     elif dataset == 'TNDA-HAR':
+         all_X[:,:,20] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,2] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
+         all_X[:,:,21] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
+         all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
+         all_X[:,:,11] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 8
+
+     elif dataset == 'ut-complex':
+         all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
+         all_X[:,:,21] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
+         original_sampling_rate = 50
+         num_classes = 13
+
+     all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)
+
+     # resample real data to 20 Hz
+     new_sampling_rate = 20
+     new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
+     resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])
+
+     # pad real data to padding_size
+     masks = np.ones_like(resampled_data)
+     if resampled_data.shape[1] < padding_size:
+         resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 132)
+         masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # (N, padding_size, 132)
+     real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
+     real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()
+
+     if split == 'train' and k and k < len(real_inputs):
+         real_inputs, real_masks, real_labels = select_samples(real_inputs, real_masks, real_labels, k, dataset, data_path)
+         print(real_inputs.shape, real_labels.shape)
+
+     # load text
+     label_dictionary = data['label_dictionary']
+     label_list = [' '.join(labels) for labels in label_dictionary.values()]
+     all_text = clip.tokenize(label_list).cuda()
+
+     return real_inputs, real_masks, real_labels, label_list, all_text, num_classes
+
+ def load_multiple(dataset_list, padding_size, data_path, split='test', k=None):
+
+     real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list = [], [], [], [], [], []
+     for dataset in dataset_list:
+         real_inputs, real_masks, real_labels, label_list, all_text, num_classes = load(dataset, padding_size, data_path, split, k)
+         real_inputs_list.append(real_inputs)
+         real_masks_list.append(real_masks)
+         real_labels_list.append(real_labels)
+         label_list_list.append(label_list)
+         all_text_list.append(all_text)
+         num_classes_list.append(num_classes)
+
+     return real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list
+
+ def load_custom_data(X_path, y_path, config_path, joint_list, original_sampling_rate, padding_size=200, split='test', k=None, few_shot_path=None):
+
+     X = np.load(X_path)
+     real_labels = torch.from_numpy(np.load(y_path))
+     with open(config_path, 'r') as file:
+         data = json.load(file)
+     all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))
+
+     for i, joint in enumerate(joint_list):
+         all_X[:,:,joint] = np.concatenate((X[:,:,6*i:6*i+3], X[:,:,6*i+3:6*i+6]), axis=-1)
+
+     all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)
+
+     # resample real data to 20 Hz
+     new_sampling_rate = 20
+     new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
+     resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])
+
+     # pad real data to padding_size
+     masks = np.ones_like(resampled_data)
+     if resampled_data.shape[1] < padding_size:
+         resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 132)
+         masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # (N, padding_size, 132)
+     real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
+     real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()
+
+     if split == 'train' and k and k < len(real_inputs):
+
+         unique_labels = torch.unique(real_labels)
+
+         if few_shot_path is None:
+             print('Generating few shot indices ...')
+             all_indices = []
+             for i, label in enumerate(unique_labels):
+                 indices = torch.where(real_labels == label)[0]
+                 selected_indices = indices[torch.randperm(len(indices))[:k]]
+                 all_indices.append(selected_indices)
+         else:
+             print('Loading existing few shot indices ...')
+             all_indices = torch.load(few_shot_path)
+
+         selected_data = []
+         selected_masks = []
+         selected_labels = []
+         for i, label in enumerate(unique_labels):
+             selected_indices = all_indices[i]
+             selected_data.append(real_inputs[selected_indices])
+             selected_masks.append(real_masks[selected_indices])
+             selected_labels.append(real_labels[selected_indices])
+         selected_data = torch.cat(selected_data, dim=0)
+         selected_masks = torch.cat(selected_masks, dim=0)
+         selected_labels = torch.cat(selected_labels, dim=0)
+         real_inputs, real_masks, real_labels = selected_data, selected_masks, selected_labels
+
+     print(real_inputs.shape, real_labels.shape)
+
+     # load text
+     label_dictionary = data['label_dictionary']
+     label_list = [' '.join(labels) for labels in label_dictionary.values()]
+     all_text = clip.tokenize(label_list).cuda()
+
+     return real_inputs, real_masks, real_labels, label_list, all_text
evaluate.py ADDED
@@ -0,0 +1,99 @@
+ import numpy as np
+ import torch
+ import argparse
+ import os
+ import numpy as np
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+ import wandb
+ import datetime
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from data import load, load_multiple
+ from utils import compute_metrics_np
+ from contrastive import ContrastiveModule
+
+ def main(args):
+     # load real data
+     dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',\
+                     'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
+     real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path)
+     test_real_dataloader_list = []
+     for real_inputs, real_masks, real_labels in zip(real_inputs_list, real_masks_list, real_labels_list):
+         real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
+         test_real_dataloader_list.append(DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False))
+
+     date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
+     wandb.init(
+         project='UniMTS',
+         name=f"{args.run_tag}_{args.stage}_" + f"{date}"
+     )
+
+     model = ContrastiveModule(args).cuda()
+
+     model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
+
+     model.eval()
+     with torch.no_grad():
+         for ds, real_labels, test_real_dataloader, label_list, all_text in zip(dataset_list, real_labels_list, test_real_dataloader_list, label_list_list, all_text_list):
+             pred_whole, logits_whole = [], []
+             for input, mask, label in test_real_dataloader:
+
+                 input = input.cuda()
+                 mask = mask.cuda()
+                 label = label.cuda()
+
+                 if not args.gyro:
+                     b, t, c = input.shape
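+                     # channels are laid out as (acc xyz, gyro xyz) per joint, so keeping
+                     # the first 3 of every 6 channels drops the gyroscope (comment added
+                     # for clarity)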
+                     indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                     input = input[:,:,indices]
+
+                 b, t, c = input.shape
+                 if args.stft:
+                     input_stft = input.permute(0,2,1).reshape(b * c, t)
+                     input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                     input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                     input = torch.cat((input, input_stft), dim=-1)
+
+                 input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+                 logits_per_imu, logits_per_text = model(input, all_text)
+                 logits_whole.append(logits_per_imu)
+
+                 pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
+                 pred_whole.append(pred)
+
+             pred = np.concatenate(pred_whole)
+             acc = accuracy_score(real_labels, pred)
+             prec = precision_score(real_labels, pred, average='macro')
+             rec = recall_score(real_labels, pred, average='macro')
+             f1 = f1_score(real_labels, pred, average='macro')
+
+             print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
+             wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})
+
+             logits_whole = torch.cat(logits_whole)
+             r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())
+
+             print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
+             wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')
+
+     # data
+     parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
+     parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')
+
+     # training
+     parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
+     parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
+     parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
+     parser.add_argument('--stft', type=int, default=0, help='using stft or not')
+     parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+
+     parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')
+
+     args = parser.parse_args()
+
+     main(args)
evaluate_custom.py ADDED
@@ -0,0 +1,101 @@
+ import numpy as np
+ import torch
+ import argparse
+ import os
+ import numpy as np
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+ import wandb
+ import datetime
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from data import load, load_multiple, load_custom_data
+ from utils import compute_metrics_np
+ from contrastive import ContrastiveModule
+
+ def main(args):
+     # load real data
+
+     real_inputs, real_masks, real_labels, label_list, all_text = load_custom_data(
+         args.X_path, args.y_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
+     )
+     real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
+     test_real_dataloader = DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False)
+
+     date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
+     wandb.init(
+         project='UniMTS',
+         name=f"{args.run_tag}_{args.stage}_" + f"{date}"
+     )
+
+     model = ContrastiveModule(args).cuda()
+
+     model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
+
+     model.eval()
+     with torch.no_grad():
+         pred_whole, logits_whole = [], []
+         for input, mask, label in test_real_dataloader:
+
+             input = input.cuda()
+             mask = mask.cuda()
+             label = label.cuda()
+
+             if not args.gyro:
+                 b, t, c = input.shape
+                 indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                 input = input[:,:,indices]
+
+             b, t, c = input.shape
+             if args.stft:
+                 input_stft = input.permute(0,2,1).reshape(b * c, t)
+                 input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                 input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                 input = torch.cat((input, input_stft), dim=-1)
+
+             input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+             logits_per_imu, logits_per_text = model(input, all_text)
+             logits_whole.append(logits_per_imu)
+
+             pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
+             pred_whole.append(pred)
+
+         pred = np.concatenate(pred_whole)
+         acc = accuracy_score(real_labels, pred)
+         prec = precision_score(real_labels, pred, average='macro')
+         rec = recall_score(real_labels, pred, average='macro')
+         f1 = f1_score(real_labels, pred, average='macro')
+
+         print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
+         wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})
+
+         logits_whole = torch.cat(logits_whole)
+         r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())
+
+         print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
+         wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')
+
+     # data
+     parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
+     parser.add_argument('--X_path', type=str, required=True, help='/path/to/data/')
+     parser.add_argument('--y_path', type=str, required=True, help='/path/to/label/')
+     parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
+     parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
+     parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')
+
+     # training
+     parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
+     parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
+     parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
+     parser.add_argument('--stft', type=int, default=0, help='using stft or not')
+     parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+
+     parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')
+
+     args = parser.parse_args()
+
+     main(args)
finetune.py ADDED
@@ -0,0 +1,169 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ import argparse
+ import os
+ import numpy as np
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+ import wandb
+ import datetime
+ from torch.utils.data import DataLoader, TensorDataset
+ import torch.optim as optim
+
+ from data import load_multiple
+ from utils import compute_metrics_np
+ from contrastive import ContrastiveModule
+
+ def main(args):
+
+     # load real data
+     dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',\
+                     'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
+     train_inputs_list, train_masks_list, train_labels_list, label_list_list, all_text_list, num_classes_list = load_multiple(dataset_list, args.padding_size, args.data_path, split='train', k=args.k)
+     test_inputs_list, test_masks_list, test_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path, split='test')
+     train_dataloader_list, test_dataloader_list = [], []
+     for real_inputs, real_masks, real_labels in zip(train_inputs_list, train_masks_list, train_labels_list):
+         train_dataset = TensorDataset(real_inputs, real_masks, real_labels)
+         train_dataloader_list.append(DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True))
+     for real_inputs, real_masks, real_labels in zip(test_inputs_list, test_masks_list, test_labels_list):
+         test_dataset = TensorDataset(real_inputs, real_masks, real_labels)
+         test_dataloader_list.append(DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False))
+
+     date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
+     wandb.init(
+         project='UniMTS',
+         name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_" + f"{date}"
+     )
+
+     save_path = './checkpoint/%s/' % args.run_tag
+     os.makedirs(save_path, exist_ok=True)  # ensure the save directory exists before torch.save below
+     for ds, train_dataloader, test_dataloader, test_labels, label_list, all_text, num_class in \
+         zip(dataset_list, train_dataloader_list, test_dataloader_list, test_labels_list, label_list_list, all_text_list, num_classes_list):
+
+         args.num_class = num_class
+         model = ContrastiveModule(args).cuda()
+         optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
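+         # 'full' fine-tunes the whole pretrained backbone, 'probe' freezes it so only
+         # the linear head is trained, 'random' trains from scratch (see --mode help).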
+         if args.mode == 'full' or args.mode == 'probe':
+             model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
+         if args.mode == 'probe':
+             for name, param in model.model.named_parameters():
+                 param.requires_grad = False
+
+         best_loss = None
+         for epoch in range(args.num_epochs):
+
+             tol_loss = 0
+
+             model.train()
+             for i, (input, mask, label) in enumerate(train_dataloader):
+
+                 input = input.cuda()
+                 labels = label.cuda()
+
+                 if not args.gyro:
+                     b, t, c = input.shape
+                     indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                     input = input[:,:,indices]
+
+                 b, t, c = input.shape
+                 if args.stft:
+                     input_stft = input.permute(0,2,1).reshape(b * c, t)
+                     input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                     input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                     input = torch.cat((input, input_stft), dim=-1)
+
+                 input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+                 output = model.classifier(input)
+
+                 loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+
+                 tol_loss += len(input) * loss.item()
+
+                 # print(epoch, i, loss.item())
+
+             print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataloader.dataset):.4f}')
+             wandb.log({f'{ds} loss': tol_loss / len(train_dataloader.dataset)})
+
+             if best_loss is None or tol_loss < best_loss:
+                 best_loss = tol_loss
+                 torch.save(model.state_dict(), os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth'))
+
+         # evaluation
+         model.load_state_dict(torch.load(os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth')))
+         model.eval()
+         with torch.no_grad():
+
+             pred_whole, logits_whole = [], []
+             for input, mask, label in test_dataloader:
+
+                 input = input.cuda()
+                 label = label.cuda()
+
+                 if not args.gyro:
+                     b, t, c = input.shape
+                     indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                     input = input[:,:,indices]
+
+                 b, t, c = input.shape
+                 if args.stft:
+                     input_stft = input.permute(0,2,1).reshape(b * c, t)
+                     input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                     input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                     input = torch.cat((input, input_stft), dim=-1)
+
+                 input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+                 logits_per_imu = model.classifier(input)
+                 logits_whole.append(logits_per_imu)
+
+                 pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
+                 pred_whole.append(pred)
+
+             pred = np.concatenate(pred_whole)
+             acc = accuracy_score(test_labels, pred)
+             prec = precision_score(test_labels, pred, average='macro')
+             rec = recall_score(test_labels, pred, average='macro')
+             f1 = f1_score(test_labels, pred, average='macro')
+
+             print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
+             wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})
+
+             logits_whole = torch.cat(logits_whole)
+             r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())
+
+             print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
+             wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')
+
+     # model
+     parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')
+
+     # data
+     parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
+     parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
+     parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')
+
+     # training
+     parser.add_argument('--stage', type=str, default='finetune', help='training stage')
+     parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
+     parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
+     parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
+     parser.add_argument('--stft', type=int, default=0, help='using stft or not')
+     parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+
+     parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')
+
+     args = parser.parse_args()
+
+     main(args)
finetune_custom.py ADDED
@@ -0,0 +1,172 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ import argparse
+ import os
+ import numpy as np
+ from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
+ import wandb
+ import datetime
+ from torch.utils.data import DataLoader, TensorDataset
+ import torch.optim as optim
+
+ from data import load_multiple, load_custom_data
+ from utils import compute_metrics_np
+ from contrastive import ContrastiveModule
+
+ def main(args):
+
+     train_inputs, train_masks, train_labels, _, _ = load_custom_data(
+         args.X_train_path, args.y_train_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='train', k=args.k, few_shot_path=args.few_shot_path
+     )
+     train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
+     train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
+
+     test_inputs, test_masks, test_labels, _, _ = load_custom_data(
+         args.X_test_path, args.y_test_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
+     )
+     test_dataset = TensorDataset(test_inputs, test_masks, test_labels)
+     test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
+
+     date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
+     wandb.init(
+         project='UniMTS',
+         name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_" + f"{date}"
+     )
+
+     save_path = './checkpoint/%s/' % args.run_tag
+     os.makedirs(save_path, exist_ok=True)  # ensure the save directory exists before torch.save below
+     model = ContrastiveModule(args).cuda()
+     optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+     if args.mode == 'full' or args.mode == 'probe':
+         model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
+     if args.mode == 'probe':
+         for name, param in model.model.named_parameters():
+             param.requires_grad = False
+
+     best_loss = None
+     for epoch in range(args.num_epochs):
+
+         tol_loss = 0
+
+         model.train()
+         for i, (input, mask, label) in enumerate(train_dataloader):
+
+             input = input.cuda()
+             labels = label.cuda()
+
+             if not args.gyro:
+                 b, t, c = input.shape
+                 indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                 input = input[:,:,indices]
+
+             b, t, c = input.shape
+             if args.stft:
+                 input_stft = input.permute(0,2,1).reshape(b * c, t)
+                 input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                 input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                 input = torch.cat((input, input_stft), dim=-1)
+
+             input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+             output = model.classifier(input)
+
+             loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             tol_loss += len(input) * loss.item()
+
+             # print(epoch, i, loss.item())
+
+         print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataset):.4f}')
+         wandb.log({'loss': tol_loss / len(train_dataset)})
+
+         if best_loss is None or tol_loss < best_loss:
+             best_loss = tol_loss
+             torch.save(model.state_dict(), os.path.join(save_path, f'k={args.k}_best_loss.pth'))
+
+     # evaluation
+     model.load_state_dict(torch.load(os.path.join(save_path, f'k={args.k}_best_loss.pth')))
+     model.eval()
+     with torch.no_grad():
+
+         pred_whole, logits_whole = [], []
+         for input, mask, label in test_dataloader:
+
+             input = input.cuda()
+             label = label.cuda()
+
+             if not args.gyro:
+                 b, t, c = input.shape
+                 indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                 input = input[:,:,indices]
+
+             b, t, c = input.shape
+             if args.stft:
+                 input_stft = input.permute(0,2,1).reshape(b * c, t)
+                 input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                 input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                 input = torch.cat((input, input_stft), dim=-1)
+
+             input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+             logits_per_imu = model.classifier(input)
+             logits_whole.append(logits_per_imu)
+
+             pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
+             pred_whole.append(pred)
+
+         pred = np.concatenate(pred_whole)
+         acc = accuracy_score(test_labels, pred)
+         prec = precision_score(test_labels, pred, average='macro')
+         rec = recall_score(test_labels, pred, average='macro')
+         f1 = f1_score(test_labels, pred, average='macro')
+
+         print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
+         wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})
+
+         logits_whole = torch.cat(logits_whole)
+         r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())
+
+         print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
+         wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')
+
+     # model
+     parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')
+
+     # data
+     parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
+     parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
+     parser.add_argument('--X_train_path', type=str, required=True, help='/path/to/train/data/')
+     parser.add_argument('--X_test_path', type=str, required=True, help='/path/to/test/data/')
+     parser.add_argument('--y_train_path', type=str, required=True, help='/path/to/train/label/')
+     parser.add_argument('--y_test_path', type=str, required=True, help='/path/to/test/label/')
+     parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
+     parser.add_argument('--few_shot_path', type=str, help='/path/to/few/shot/indices/')
+     parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
+     parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')
+     parser.add_argument('--num_class', type=int, required=True, help='number of classes')
+
+     # training
+     parser.add_argument('--stage', type=str, default='finetune', help='training stage')
+     parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
+     parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
+     parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
+     parser.add_argument('--stft', type=int, default=0, help='using stft or not')
+     parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+
+     parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')
+
+     args = parser.parse_args()
+
+     main(args)
model.py ADDED
@@ -0,0 +1,350 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ class Graph():
+     """ The Graph to model the skeletons
+
+     Args:
+         strategy (string): must be one of the following candidates
+         - uniform: Uniform Labeling
+         - distance: Distance Partitioning
+         - spatial: Spatial Configuration
+         max_hop (int): the maximal distance between two connected nodes
+         dilation (int): controls the spacing between the kernel points
+
+     """
+     def __init__(self,
+                  strategy='spatial',
+                  max_hop=1,
+                  dilation=1):
+         self.max_hop = max_hop
+         self.dilation = dilation
+
+         self.get_edge()
+         self.hop_dis = get_hop_distance(self.num_node,
+                                         self.edge,
+                                         max_hop=max_hop)
+         self.get_adjacency(strategy)
+
+     def __str__(self):
+         return str(self.A)
+
+     def get_edge(self):
+         # edge is a list of [child, parent] pairs
+         self.num_node = 22
+         self_link = [(i, i) for i in range(self.num_node)]
+         neighbor_link = [(1,0), (2,1), (3,2), (4,3), (5,0), (6,5), (7,6), (8,7), (9,0), (10,9), (11,10), (12,11), \
+                          (13,12), (14,11), (15,14), (16,15), (17,16), (18,11), (19,18), (20,19), (21,20)]
+         self.edge = self_link + neighbor_link
+         self.center = 0
+
+     def get_adjacency(self, strategy):
+         valid_hop = range(0, self.max_hop + 1, self.dilation)
+         adjacency = np.zeros((self.num_node, self.num_node))
+         for hop in valid_hop:
+             adjacency[self.hop_dis == hop] = 1
+         normalize_adjacency = normalize_digraph(adjacency)
+
+         if strategy == 'uniform':
+             A = np.zeros((1, self.num_node, self.num_node))
+             A[0] = normalize_adjacency
+             self.A = A
+         elif strategy == 'distance':
+             A = np.zeros((len(valid_hop), self.num_node, self.num_node))
+             for i, hop in enumerate(valid_hop):
+                 A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
+                                                                 hop]
+             self.A = A
61
+ A = []
62
+ for hop in valid_hop:
63
+ a_root = np.zeros((self.num_node, self.num_node))
64
+ a_close = np.zeros((self.num_node, self.num_node))
65
+ a_further = np.zeros((self.num_node, self.num_node))
66
+ for i in range(self.num_node):
67
+ for j in range(self.num_node):
68
+ if self.hop_dis[j, i] == hop:
69
+ if self.hop_dis[j, self.center] == self.hop_dis[
70
+ i, self.center]:
71
+ a_root[j, i] = normalize_adjacency[j, i]
72
+ elif self.hop_dis[j, self.center] > self.hop_dis[
73
+ i, self.center]:
74
+ a_close[j, i] = normalize_adjacency[j, i]
75
+ else:
76
+ a_further[j, i] = normalize_adjacency[j, i]
77
+ if hop == 0:
78
+ A.append(a_root)
79
+ else:
80
+ A.append(a_root + a_close)
81
+ A.append(a_further)
82
+ A = np.stack(A)
83
+ self.A = A
84
+ else:
85
+ raise ValueError("Do Not Exist This Strategy")
86
+
87
+ def get_hop_distance(num_node, edge, max_hop=1):
88
+ A = np.zeros((num_node, num_node))
89
+ for i, j in edge:
90
+ A[j, i] = 1
91
+ A[i, j] = 1
92
+
93
+ # compute hop steps
94
+ hop_dis = np.zeros((num_node, num_node)) + np.inf
95
+ transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
96
+ arrive_mat = (np.stack(transfer_mat) > 0)
97
+ for d in range(max_hop, -1, -1):
98
+ hop_dis[arrive_mat[d]] = d
99
+ return hop_dis
100
+
101
+ def normalize_digraph(A):
102
+ Dl = np.sum(A, 0)
103
+ num_node = A.shape[0]
104
+ Dn = np.zeros((num_node, num_node))
105
+ for i in range(num_node):
106
+ if Dl[i] > 0:
107
+ Dn[i, i] = Dl[i]**(-1)
108
+ AD = np.dot(A, Dn)
109
+ return AD
110
+
111
+ def normalize_undigraph(A):
112
+ Dl = np.sum(A, 0)
113
+ num_node = A.shape[0]
114
+ Dn = np.zeros((num_node, num_node))
115
+ for i in range(num_node):
116
+ if Dl[i] > 0:
117
+ Dn[i, i] = Dl[i]**(-0.5)
118
+ DAD = np.dot(np.dot(Dn, A), Dn)
119
+ return DAD
120
+
121
+ def zero(x):
122
+ return 0
123
+
124
+ def iden(x):
125
+ return x
126
+
127
+ class ConvTemporalGraphical(nn.Module):
128
+ r"""The basic module for applying a graph convolution.
129
+
130
+ Args:
131
+ in_channels (int): Number of channels in the input sequence data
132
+ out_channels (int): Number of channels produced by the convolution
133
+ kernel_size (int): Size of the graph convolving kernel
134
+ t_kernel_size (int): Size of the temporal convolving kernel
135
+ t_stride (int, optional): Stride of the temporal convolution. Default: 1
136
+ t_padding (int, optional): Temporal zero-padding added to both sides of
137
+ the input. Default: 0
138
+ t_dilation (int, optional): Spacing between temporal kernel elements.
139
+ Default: 1
140
+ bias (bool, optional): If ``True``, adds a learnable bias to the output.
141
+ Default: ``True``
142
+
143
+ Shape:
144
+ - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
145
+ - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
146
+ - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
147
+ - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
148
+
149
+ where
150
+ :math:`N` is a batch size,
151
+ :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
152
+ :math:`T_{in}/T_{out}` is a length of input/output sequence,
153
+ :math:`V` is the number of graph nodes.
154
+ """
155
+ def __init__(self,
156
+ in_channels,
157
+ out_channels,
158
+ kernel_size,
159
+ t_kernel_size=1,
160
+ t_stride=1,
161
+ t_padding=0,
162
+ t_dilation=1,
163
+ bias=True):
164
+ super().__init__()
165
+
166
+ self.kernel_size = kernel_size
167
+ self.conv = nn.Conv2d(in_channels,
168
+ out_channels * kernel_size,
169
+ kernel_size=(t_kernel_size, 1),
170
+ padding=(t_padding, 0),
171
+ stride=(t_stride, 1),
172
+ dilation=(t_dilation, 1),
173
+ bias=bias)
174
+
175
+ def forward(self, x, A):
176
+ assert A.size(0) == self.kernel_size
177
+
178
+ x = self.conv(x)
179
+
180
+ n, kc, t, v = x.size()
181
+ x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
182
+ x = torch.einsum('nkctv,kvw->nctw', (x, A))
183
+
184
+ return x.contiguous(), A
185
+
186
+ class st_gcn_block(nn.Module):
187
+ r"""Applies a spatial temporal graph convolution over an input graph sequence.
188
+
189
+ Args:
190
+ in_channels (int): Number of channels in the input sequence data
191
+ out_channels (int): Number of channels produced by the convolution
192
+ kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
193
+ stride (int, optional): Stride of the temporal convolution. Default: 1
194
+ dropout (int, optional): Dropout rate of the final output. Default: 0
195
+ residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``
196
+
197
+ Shape:
198
+ - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
199
+ - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
200
+ - Output[0]: Outpu graph sequence in :math:`(N, out_channels, T_{out}, V)` format
201
+ - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format
202
+
203
+ where
204
+ :math:`N` is a batch size,
205
+ :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
206
+ :math:`T_{in}/T_{out}` is a length of input/output sequence,
207
+ :math:`V` is the number of graph nodes.
208
+
209
+ """
210
+ def __init__(self,
211
+ in_channels,
212
+ out_channels,
213
+ kernel_size,
214
+ stride=1,
215
+ dropout=0,
216
+ residual=True):
217
+ super().__init__()
218
+
219
+ assert len(kernel_size) == 2
220
+ assert kernel_size[0] % 2 == 1
221
+ padding = ((kernel_size[0] - 1) // 2, 0)
222
+
223
+ self.gcn = ConvTemporalGraphical(in_channels, out_channels,
224
+ kernel_size[1])
225
+
226
+ self.tcn = nn.Sequential(
227
+ nn.BatchNorm2d(out_channels),
228
+ nn.ReLU(inplace=True),
229
+ nn.Conv2d(
230
+ out_channels,
231
+ out_channels,
232
+ (kernel_size[0], 1),
233
+ (stride, 1),
234
+ padding,
235
+ ),
236
+             nn.BatchNorm2d(out_channels),
+             nn.Dropout(dropout, inplace=True),
+         )
+
+         if not residual:
+             self.residual = zero
+
+         elif (in_channels == out_channels) and (stride == 1):
+             self.residual = iden
+
+         else:
+             self.residual = nn.Sequential(
+                 nn.Conv2d(in_channels,
+                           out_channels,
+                           kernel_size=1,
+                           stride=(stride, 1)),
+                 nn.BatchNorm2d(out_channels),
+             )
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, x, A):
+
+         res = self.residual(x)
+         x, A = self.gcn(x, A)
+         x = self.tcn(x) + res
+
+         return self.relu(x), A
+
+ class ST_GCN_18(nn.Module):
+     r"""Spatial temporal graph convolutional networks.
+
+     Args:
+         in_channels (int): Number of channels in the input data
+         edge_importance_weighting (bool): If ``True``, adds a learnable
+             importance weighting to the edges of the graph
+         data_bn (bool): If ``True``, normalizes the input data with
+             batch normalization
+         **kwargs (optional): Other parameters for graph convolution units
+
+     Shape:
+         - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
+         - Output: :math:`(N, 512, 1, 1)` where
+             :math:`N` is the batch size,
+             :math:`T_{in}` is the length of the input sequence,
+             :math:`V_{in}` is the number of graph nodes,
+             :math:`M_{in}` is the number of instances in a frame.
+     """
+     def __init__(self,
+                  in_channels,
+                  edge_importance_weighting=True,
+                  data_bn=True,
+                  **kwargs):
+         super().__init__()
+
+         # load graph
+         self.graph = Graph()
+         A = torch.tensor(self.graph.A,
+                          dtype=torch.float32,
+                          requires_grad=False)
+         self.register_buffer('A', A)
+
+         # build networks
+         spatial_kernel_size = A.size(0)
+         temporal_kernel_size = 9
+         kernel_size = (temporal_kernel_size, spatial_kernel_size)
+         self.data_bn = nn.BatchNorm1d(in_channels *
+                                       A.size(1)) if data_bn else iden
+         kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
+         self.st_gcn_networks = nn.ModuleList((
+             st_gcn_block(in_channels,
+                          64,
+                          kernel_size,
+                          1,
+                          residual=False,
+                          **kwargs0),
+             st_gcn_block(64, 64, kernel_size, 1, **kwargs),
+             st_gcn_block(64, 64, kernel_size, 1, **kwargs),
+             st_gcn_block(64, 64, kernel_size, 1, **kwargs),
+             st_gcn_block(64, 128, kernel_size, 2, **kwargs),
+             st_gcn_block(128, 128, kernel_size, 1, **kwargs),
+             st_gcn_block(128, 128, kernel_size, 1, **kwargs),
+             st_gcn_block(128, 256, kernel_size, 2, **kwargs),
+             st_gcn_block(256, 256, kernel_size, 1, **kwargs),
+             st_gcn_block(256, 512, kernel_size, 1, **kwargs),
+         ))
+
+         # initialize parameters for edge importance weighting
+         if edge_importance_weighting:
+             self.edge_importance = nn.ParameterList([
+                 nn.Parameter(torch.ones(self.A.size()))
+                 for _ in self.st_gcn_networks
+             ])
+         else:
+             self.edge_importance = [1] * len(self.st_gcn_networks)
+
+     def forward(self, x):
+         # data normalization over the flattened (V * C) channels
+         N, C, T, V, M = x.size()
+         x = x.permute(0, 4, 3, 1, 2).contiguous()
+         x = x.view(N * M, V * C, T)
+         x = self.data_bn(x)
+         x = x.view(N, M, V, C, T)
+         x = x.permute(0, 1, 3, 4, 2).contiguous()
+         x = x.view(N * M, C, T, V)
+
+         # forward through the stacked ST-GCN blocks
+         for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
+             x, _ = gcn(x, self.A * importance)
+
+         # global pooling over the temporal and joint dimensions:
+         # (N * M, 512, T', V) -> (N * M, 512, 1, 1)
+         x = F.avg_pool2d(x, x.size()[2:])
+         x = x.view(N, M, -1, 1, 1).mean(dim=1)
+
+         return x
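
As a sanity check on the Shape contract above (not part of the commit): a minimal sketch, assuming `Graph()` builds the 22-joint skeleton that pretrain.py feeds in; the batch and sequence sizes below are arbitrary.

    import torch
    from model import ST_GCN_18

    # (N, C, T, V, M): 2 sequences, 3 accelerometer channels, 200 time
    # steps, 22 joints, 1 instance per frame -- the layout pretrain.py
    # produces before calling the encoder.
    x = torch.randn(2, 3, 200, 22, 1)
    encoder = ST_GCN_18(in_channels=3)
    print(encoder(x).shape)  # expected: torch.Size([2, 512, 1, 1])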
pos2bvh.py ADDED
@@ -0,0 +1,41 @@
+ import numpy as np
+ from Quaternions import Quaternions
+ from scipy_motion import myBVH
+ import BVH
+ from scipy_motion import myAnimation
+ import Animation
+ from scipy_motion import myInverseKinematics as myIK
+ import InverseKinematics as IK
+ from tqdm import tqdm
+ import multiprocessing
+ import os
+ import os.path as osp
+ from scipy.spatial.transform import Rotation as R
+
+
+ # parent index of each of the 22 joints (-1 marks the root)
+ parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
+ # names = ['root','leftleg1','leftleg2','leftleg3','leftleg4','rightleg1','rightleg2','rightleg3','rightleg4',
+ #          'spline1','spline2','spline3','spline4','spline5','rightarm1','rightarm2','rightarm3','rightarm4',
+ #          'leftarm1','leftarm2','leftarm3','leftarm4']
+
+ def process_file(f):
+
+     fk_positions = np.load('/path/to/joint/pos/%s.npy' % f)
+
+     frametime = 1 / 20
+
+     # recover joint rotations from positions via inverse kinematics
+     anim_ik, _, _, save_file = IK.animation_from_positions(fk_positions, parents=parents)
+
+     if save_file:
+         BVH.save('bvh/%s.bvh' % f, anim_ik, frametime=frametime)
+
+ source_dir = '/path/to/joint/pos'
+ error_file = ['M005836.npy', 'M000990.npy', '000990.npy', '005836.npy']
+ npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.npy') and file not in error_file]
+
+ # Process files in parallel
+ pool = multiprocessing.Pool(processes=8)
+ for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
+     pass
+ pool.close()
+ pool.join()
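
A side note on the `parents` array above: each entry is the index of a joint's parent (with -1 marking the root), so the five kinematic chains can be recovered by walking parent pointers; they correspond to the 't2m' connections listed in utils.py. An illustrative sketch (the helper name is hypothetical):

    parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]

    def chain_to_root(joint, parents):
        # Walk parent pointers from a joint back to the root.
        chain = [joint]
        while parents[chain[-1]] != -1:
            chain.append(parents[chain[-1]])
        return chain[::-1]

    # End effectors are joints that never appear as a parent.
    for end in [j for j in range(len(parents)) if j not in parents]:
        print(chain_to_root(end, parents))
    # [0, 1, 4, 7, 10], [0, 2, 5, 8, 11], [0, 3, 6, 9, 12, 15],
    # [0, 3, 6, 9, 13, 16, 18, 20], [0, 3, 6, 9, 14, 17, 19, 21]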
pretrain.py ADDED
@@ -0,0 +1,111 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import argparse
+ import os
+ import clip
+ import wandb
+ import datetime
+ import torch.optim as optim
+
+ from data import CLIPDataset
+ from utils import augment_data
+ from contrastive import ContrastiveModule
+
+ def main(args):
+
+     train_dataset = CLIPDataset(args)
+     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)
+
+     date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
+     wandb.init(
+         project='UniMTS',
+         name=f"{args.run_tag}_{args.stage}_{date}"
+     )
+
+     model = ContrastiveModule(args).cuda()
+     optimizer = optim.Adam(model.parameters(), lr=1e-4)
+
+     save_path = './checkpoint/%s/' % args.run_tag
+     if not os.path.exists(save_path):
+         os.makedirs(save_path)
+
+     for epoch in range(args.num_epochs):
+
+         tol_loss = 0
+
+         model.train()
+         for i, batch in enumerate(train_loader):
+
+             inputs_imu = batch['imu'].float().cuda()
+             inputs_text = clip.tokenize(batch['text'], truncate=True).cuda()
+             mask = batch['mask'].float().cuda()
+
+             input = inputs_imu * mask
+
+             # random rotations make the model rotation-invariant
+             if args.aug:
+                 input = augment_data(input)
+
+             # keep only the accelerometer channels (first 3 of every 6)
+             if not args.gyro:
+                 b, t, c = input.shape
+                 indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
+                 input = input[:,:,indices]
+
+             b, t, c = input.shape
+             if args.stft:
+                 # per-channel spectrogram magnitudes, flattened back to length t
+                 # (with n_fft=25 and hop_length=28 the flattened size equals t
+                 # for the default padding size of 200)
+                 input_stft = input.permute(0,2,1).reshape(b * c, t)
+                 input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
+                 input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
+                 input = torch.cat((input, input_stft), dim=-1)
+
+             # (B, T, 22 * C') -> (B, C', T, 22, 1) for the graph encoder
+             input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)
+
+             # IMU and text representations
+             logits_per_imu, logits_per_text = model(input, inputs_text)
+
+             # positive keys are the entries on the diagonal
+             labels = torch.arange(len(batch['imu'])).cuda()
+
+             loss = F.cross_entropy(logits_per_imu / args.temperature, labels, reduction="mean")
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             tol_loss += len(inputs_imu) * loss.item()
+
+             # print(epoch, i, loss.item())
+
+         print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataset):.4f}')
+         wandb.log({'loss': tol_loss / len(train_dataset)})
+
+         if epoch > 0 and epoch % args.log == 0:
+             torch.save(model.model.state_dict(), os.path.join(save_path, f'epoch_{epoch}.pth'))
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')
+
+     # data
+     parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
+     parser.add_argument('--sample', type=float, default=1.0, help='pre-training down-sample ratio (default: 1)')
+     parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')
+
+     # training
+     parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
+     parser.add_argument('--stage', type=str, default='pretrain', help='training stage')
+     parser.add_argument('--num_epochs', type=int, default=100, help='number of pre-training epochs')
+     parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
+     parser.add_argument('--stft', type=int, default=0, help='using stft or not')
+     parser.add_argument('--aug', type=int, default=1, help='using augmentation or not')
+     parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+     parser.add_argument('--temperature', type=float, default=0.1, help='temperature')
+     parser.add_argument('--log', type=int, default=10, help='logging step')
+
+     args = parser.parse_args()
+
+     main(args)
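
To spell out the objective in the training loop above: each batch pairs the i-th IMU sequence with the i-th text description, so the logits matrix should be largest on its diagonal. A self-contained sketch with stand-in embeddings (`imu_emb` and `text_emb` are random placeholders for the encoder outputs):

    import torch
    import torch.nn.functional as F

    B, D = 4, 512
    imu_emb = F.normalize(torch.randn(B, D), dim=-1)
    text_emb = F.normalize(torch.randn(B, D), dim=-1)

    # Cosine similarity of every IMU sequence against every caption.
    logits_per_imu = imu_emb @ text_emb.t()  # (B, B)

    # Row i should rank caption i highest: positives on the diagonal.
    labels = torch.arange(B)
    loss = F.cross_entropy(logits_per_imu / 0.1, labels)  # temperature 0.1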
run_evaluation.sh ADDED
@@ -0,0 +1,4 @@
+ python evaluate.py \
+     --batch_size 64 \
+     --checkpoint './checkpoint/UniMTS.pth' \
+     --data_path 'UniMTS_data'
run_evaluation_custom.sh ADDED
@@ -0,0 +1,8 @@
+ python evaluate_custom.py \
+     --batch_size 64 \
+     --checkpoint './checkpoint/UniMTS.pth' \
+     --X_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
+     --y_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
+     --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
+     --joint_list 20 2 21 3 11 \
+     --original_sampling_rate 50
run_finetune.sh ADDED
@@ -0,0 +1,19 @@
+ # few-shot fine-tuning with k samples per class
+ for k in 1 2 3 5 10
+ do
+
+     python finetune.py \
+         --mode full \
+         --k $k \
+         --batch_size 64 \
+         --num_epochs 200 \
+         --checkpoint './checkpoint/UniMTS.pth' \
+         --data_path 'UniMTS_data'
+
+ done
+
+ # fine-tuning on the full training set
+ python finetune.py \
+     --mode full \
+     --batch_size 64 \
+     --num_epochs 200 \
+     --checkpoint './checkpoint/UniMTS.pth' \
+     --data_path 'UniMTS_data'
run_finetune_custom.sh ADDED
@@ -0,0 +1,33 @@
+ # few-shot fine-tuning with k samples per class
+ for k in 1 2 3 5 10
+ do
+
+     python finetune_custom.py \
+         --mode full \
+         --k $k \
+         --batch_size 64 \
+         --num_epochs 200 \
+         --checkpoint './checkpoint/UniMTS.pth' \
+         --X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
+         --y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
+         --X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
+         --y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
+         --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
+         --joint_list 20 2 21 3 11 \
+         --original_sampling_rate 50 \
+         --num_class 8
+
+ done
+
+ # fine-tuning on the full training set
+ python finetune_custom.py \
+     --mode full \
+     --batch_size 64 \
+     --num_epochs 200 \
+     --checkpoint './checkpoint/UniMTS.pth' \
+     --X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
+     --y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
+     --X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
+     --y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
+     --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
+     --joint_list 20 2 21 3 11 \
+     --original_sampling_rate 50 \
+     --num_class 8
run_pretrain.sh ADDED
@@ -0,0 +1,4 @@
+ python pretrain.py \
+     --aug 1 \
+     --batch_size 64 \
+     --data_path 'UniMTS_data'
text_aug.py ADDED
@@ -0,0 +1,66 @@
+ import openai
+ import glob
+ import os
+ import shutil
+ from tqdm import tqdm
+
+ def load_api_key(file_path='api_key.txt'):
+     # expects a line of the form "api_key=..."
+     with open(file_path, 'r') as f:
+         for line in f:
+             if line.startswith('api_key='):
+                 return line.strip().split('=', 1)[1]
+     return None
+
+ openai.api_key = load_api_key()
+
+ if openai.api_key is None:
+     print("Error: API key not found.")
+     exit()
+
+ files = glob.glob('/path/to/txt/*')
+ aug_dir = '/path/to/output/'  # assumed to end with '/'
+
+ for f in tqdm(files):
+
+     file_id = f.split('/')[-1]
+     if not os.path.exists(aug_dir + file_id):
+
+         with open(f, 'r') as file:
+             lines = file.readlines()
+
+         # number each caption and make sure it ends with a period
+         text = []
+         for i, l in enumerate(lines):
+             desc = l.split('#')[0].strip()
+             if not desc:
+                 continue
+             text.append(str(i) + ': ')
+             text.append(desc)
+             if desc[-1] != '.':
+                 text.append('. ')
+             else:
+                 text.append(' ')
+         text = ''.join(text)
+
+         prompt = 'The following one or multiple descriptions are describing the same human activities: '
+         prompt += text
+         prompt += 'Generate 3 paraphrases to describe the same activities. One in a line in a plain text format ending with \n, without numbering or - at the beginning. Do not generate any other analysis except from the paraphrased descriptions.'
+
+         response = openai.ChatCompletion.create(
+             model="gpt-3.5-turbo",
+             messages=[
+                 {"role": "user", "content": prompt}
+             ]
+         )
+         pred = response.choices[0]['message']['content']
+
+         # keep the original captions and append the paraphrases
+         shutil.copy(f, aug_dir)
+         with open(aug_dir + file_id, 'a') as log_file:
+             log_file.write(pred)
+
+ # clean up: drop empty lines and leading "- " bullets from the outputs
+ files = glob.glob('/path/to/output/*')
+ for f in tqdm(files):
+     with open(f, 'r') as file:
+         lines = file.readlines()
+
+     lines = [line.lstrip("- ") for line in lines if line.strip()]
+
+     with open(f, 'w') as file:
+         file.writelines(lines)
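
The script above targets the legacy `openai<1.0` interface (`openai.ChatCompletion.create`). For environments on the `openai>=1.0` client, the equivalent call would look roughly like this sketch (an assumption about the setup, not part of the commit):

    from openai import OpenAI

    client = OpenAI(api_key=load_api_key())  # reuses the helper above
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    pred = response.choices[0].message.content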
utils.py ADDED
@@ -0,0 +1,215 @@
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import imageio
+ import io
+
+ def random_rotation_matrix():
+     # Random unit quaternion q = (w, x, y, z)
+     q = torch.randn(4)
+     q = q / torch.norm(q)
+
+     # Quaternion to rotation matrix
+     R = torch.tensor([
+         [1 - 2*q[2]**2 - 2*q[3]**2, 2*q[1]*q[2] - 2*q[3]*q[0], 2*q[1]*q[3] + 2*q[2]*q[0]],
+         [2*q[1]*q[2] + 2*q[3]*q[0], 1 - 2*q[1]**2 - 2*q[3]**2, 2*q[2]*q[3] - 2*q[1]*q[0]],
+         [2*q[1]*q[3] - 2*q[2]*q[0], 2*q[2]*q[3] + 2*q[1]*q[0], 1 - 2*q[1]**2 - 2*q[2]**2]
+     ])
+     return R
+
+ def augment_data(data):
+     # Rotate each 6-channel (acc, gyro) block by an independent random
+     # rotation so the model becomes invariant to device orientation.
+     B, T, M = data.shape
+     augmented_data = torch.zeros_like(data)
+
+     for i in range(B):
+         for c in range(0, M, 6):
+             R = random_rotation_matrix().cuda()
+             acc = data[i, :, c:c+3].transpose(0, 1)     # shape (3, T)
+             gyro = data[i, :, c+3:c+6].transpose(0, 1)  # shape (3, T)
+
+             # Apply the same rotation to accelerometer and gyroscope
+             rotated_acc = torch.matmul(R, acc)
+             rotated_gyro = torch.matmul(R, gyro)
+
+             augmented_data[i, :, c:c+3] = rotated_acc.transpose(0, 1)
+             augmented_data[i, :, c+3:c+6] = rotated_gyro.transpose(0, 1)
+
+     return augmented_data
+
+ def update_limits(data):
+     # Global min and max for each axis (y and z are swapped to match
+     # the plotting convention used below)
+     min_x, max_x = np.min(data[:, :, 0]), np.max(data[:, :, 0])
+     min_y, max_y = np.min(data[:, :, 2]), np.max(data[:, :, 2])
+     min_z, max_z = np.min(data[:, :, 1]), np.max(data[:, :, 1])
+
+     # Add some padding so the skeleton doesn't touch the plot edges
+     padding = 0.1
+     x_range = max_x - min_x
+     y_range = max_y - min_y
+     z_range = max_z - min_z
+
+     return (min_x - padding * x_range, max_x + padding * x_range), \
+            (min_y - padding * y_range, max_y + padding * y_range), \
+            (min_z - padding * z_range, max_z + padding * z_range)
+
+ def plot_skeleton(frame_data, xlims, ylims, zlims, dataset):
+     """
+     Plot a single frame of skeleton data.
+     """
+     fig = plt.figure()
+     ax = fig.add_subplot(111, projection='3d')
+     ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])
+
+     # Joint chains for each supported skeleton layout
+     if dataset == 't2m':
+         connections = [
+             [0, 2, 5, 8, 11],
+             [0, 1, 4, 7, 10],
+             [0, 3, 6, 9, 12, 15],
+             [9, 14, 17, 19, 21],
+             [9, 13, 16, 18, 20]
+         ]
+
+     if dataset == 'kit':
+         connections = [
+             [0, 11, 12, 13, 14, 15],
+             [0, 16, 17, 18, 19, 20],
+             [0, 1, 2, 3, 4],
+             [3, 5, 6, 7],
+             [3, 8, 9, 10]
+         ]
+
+     if dataset == 'ntu':
+         connections = [
+             [0, 12, 13, 14, 15],
+             [0, 16, 17, 18, 19],
+             [0, 1, 20, 2, 3],
+             [20, 4, 5, 6, 7, 21],
+             [7, 22],
+             [20, 8, 9, 10, 11, 23],
+             [11, 24],
+         ]
+
+     # Plot the lines for each chain
+     for connection in connections:
+         for i in range(len(connection)-1):
+             start_joint = connection[i]
+             end_joint = connection[i+1]
+             ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
+                     [frame_data[start_joint, 2], frame_data[end_joint, 2]],
+                     [frame_data[start_joint, 1], frame_data[end_joint, 1]])
+
+     ax.view_init(elev=10, azim=90)
+     ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))
+
+     ax.set_xlim(xlims)
+     ax.set_ylim(ylims)
+     ax.set_zlim(zlims)
+     ax.set_xlabel('X')
+     ax.set_ylabel('Z')
+     ax.set_zlabel('Y')
+
+     # Render the figure to an in-memory PNG buffer
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png')
+     buf.seek(0)
+     img = imageio.imread(buf)
+     buf.close()
+
+     plt.close(fig)  # close the figure to prevent display
+     return img
+
+ def plot_skeleton_gif(data, dataset):
+     xlims, ylims, zlims = update_limits(data)
+     images = [plot_skeleton(frame, xlims, ylims, zlims, dataset) for frame in data]
+     imageio.mimsave('./skeleton_animation.gif', images, fps=20)
+     return
+
+ def plot_single_skeleton(data, dataset, frame=0):
+
+     xlims, ylims, zlims = update_limits(data)
+     frame_data = data[frame]
+
+     fig = plt.figure()
+     ax = fig.add_subplot(111, projection='3d')
+     ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])
+
+     # Joint chains for each supported skeleton layout
+     if dataset == 't2m':
+         connections = [
+             [0, 2, 5, 8, 11],
+             [0, 1, 4, 7, 10],
+             [0, 3, 6, 9, 12, 15],
+             [9, 14, 17, 19, 21],
+             [9, 13, 16, 18, 20]
+         ]
+
+     if dataset == 'kit':
+         connections = [
+             [0, 11, 12, 13, 14, 15],
+             [0, 16, 17, 18, 19, 20],
+             [0, 1, 2, 3, 4],
+             [3, 5, 6, 7],
+             [3, 8, 9, 10]
+         ]
+
+     if dataset == 'ntu':
+         connections = [
+             [0, 12, 13, 14, 15],
+             [0, 16, 17, 18, 19],
+             [0, 1, 20, 2, 3],
+             [20, 4, 5, 6, 7, 21],
+             [7, 22],
+             [20, 8, 9, 10, 11, 23],
+             [11, 24],
+         ]
+
+     # Plot the lines for each chain
+     for connection in connections:
+         for i in range(len(connection)-1):
+             start_joint = connection[i]
+             end_joint = connection[i+1]
+             ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
+                     [frame_data[start_joint, 2], frame_data[end_joint, 2]],
+                     [frame_data[start_joint, 1], frame_data[end_joint, 1]])
+
+     #ax.view_init(elev=10, azim=90)
+     ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))
+
+     ax.set_xlim(xlims)
+     ax.set_ylim(ylims)
+     ax.set_zlim(zlims)
+
+     ax.set_xlabel('X')
+     ax.set_ylabel('Z')
+     ax.set_zlabel('Y')
+
+     plt.savefig('skeleton.pdf', bbox_inches='tight')
+
+ def compute_height(joints, head_index, l_foot_index, r_foot_index):
+     # Vertical (y) distance from head to each foot in the first frame,
+     # averaged over both feet
+     joints = torch.from_numpy(joints)
+     left = (joints[:, head_index, 1] - joints[:, l_foot_index, 1])[0]
+     right = (joints[:, head_index, 1] - joints[:, r_foot_index, 1])[0]
+     height = (left + right) / 2
+     return height
+
+ def compute_metrics_np(similarity_matrix, correct_labels):
+     # Rank all candidates for each query by descending similarity,
+     # then report retrieval metrics
+     B, _ = similarity_matrix.shape
+
+     ranked_indices = np.argsort(-similarity_matrix, axis=1)
+
+     # 1-based rank of the correct label for each query
+     correct_label_ranks = np.array([np.where(ranked_indices[i] == correct_labels[i])[0][0] for i in range(B)]) + 1
+
+     # Compute R@K
+     R_at_1 = np.mean(correct_label_ranks <= 1)
+     R_at_2 = np.mean(correct_label_ranks <= 2)
+     R_at_3 = np.mean(correct_label_ranks <= 3)
+     R_at_4 = np.mean(correct_label_ranks <= 4)
+     R_at_5 = np.mean(correct_label_ranks <= 5)
+
+     # Compute MRR
+     MRR = np.mean(1.0 / correct_label_ranks)
+
+     return R_at_1, R_at_2, R_at_3, R_at_4, R_at_5, MRR
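
As a quick sanity check of `compute_metrics_np`, a toy example with hand-computable ranks (the similarity values are made up):

    import numpy as np

    # Two queries, three candidates. Query 0's correct label (0) ranks
    # first; query 1's correct label (2) ranks second.
    sim = np.array([[0.9, 0.1, 0.2],
                    [0.3, 0.8, 0.5]])
    labels = np.array([0, 2])

    r1, r2, r3, _, _, mrr = compute_metrics_np(sim, labels)
    print(r1, r2, mrr)  # 0.5 1.0 0.75, since MRR = (1/1 + 1/2) / 2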