initial
- bvh2ts.py +69 -0
- checkpoint/UniMTS.pth +3 -0
- contrastive.py +62 -0
- data.py +321 -0
- evaluate.py +99 -0
- evaluate_custom.py +101 -0
- finetune.py +169 -0
- finetune_custom.py +172 -0
- model.py +350 -0
- pos2bvh.py +41 -0
- pretrain.py +111 -0
- run_evaluation.sh +4 -0
- run_evaluation_custom.sh +8 -0
- run_finetune.sh +19 -0
- run_finetune_custom.sh +33 -0
- run_pretrain.sh +4 -0
- text_aug.py +66 -0
- utils.py +215 -0
bvh2ts.py
ADDED
@@ -0,0 +1,69 @@
from imusim.all import *
import imusim
import numpy as np
from tqdm import tqdm
import multiprocessing
import os

# Read the sampling period from a reference BVH file: line 109 holds "Frame Time: <seconds>".
with open('./bvh/000000.bvh', 'r') as file:
    lines = file.readlines()
line_109 = lines[108]
frame_time_value = float(line_109.split(': ')[1].strip())
print(frame_time_value)

def process_file(f):

    imu_file_path = './output/%s.npy' % f
    if not os.path.exists(imu_file_path):

        samplingPeriod = frame_time_value
        imu = Orient3IMU()
        env = Environment()

        samples = 1000
        rotationalVelocity = 20
        calibrator = ScaleAndOffsetCalibrator(env, samples, samplingPeriod, rotationalVelocity)
        calibration = calibrator.calibrate(imu)

        try:
            model = loadBVHFile('./bvh/%s.bvh' % f)
            splinedModel = SplinedBodyModel(model)

            imu_list = []
            for i in range(22):
                sim = Simulation(environment=env)
                imu.simulation = sim

                # Joints 4, 8, 13, 17, 21 are end sites, exposed as end points of the previous joint.
                if i not in [4, 8, 13, 17, 21]:
                    imu.trajectory = splinedModel.getJoint('joint_%s' % str(i))
                else:
                    imu.trajectory = splinedModel.getPoint('joint_%s_end' % str(i - 1))

                sim.time = splinedModel.startTime
                BasicIMUBehaviour(imu, samplingPeriod, calibration, initialTime=sim.time)
                sim.run(splinedModel.endTime, printProgress=False)

                acc = imu.accelerometer.calibratedMeasurements.values   # (3, T)
                gyro = imu.gyroscope.calibratedMeasurements.values      # (3, T)

                imu_npy = np.concatenate((acc, gyro), axis=0)           # (6, T)
                imu_list.append(imu_npy)

            # (6, 22, T) -> (T, 22, 6): time x joints x (acc, gyro) channels
            imu_npy = np.stack(imu_list, axis=1).transpose(2, 1, 0)
            np.save('./output/%s' % f, imu_npy)

        except (imusim.maths.splines.Spline.InsufficientPointsError, AttributeError, IndexError) as e:
            print(f"Error processing file {f}: {e}. Skipping.")
            with open('log.txt', 'a') as log_file:
                log_file.write(f + '\n')

if __name__ == '__main__':  # guard so the pool also works with the spawn start method
    source_dir = './bvh'
    npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.bvh')]

    # Process files in parallel
    pool = multiprocessing.Pool(processes=8)
    for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
        pass
    pool.close()
    pool.join()
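A quick sanity check on the simulated output, assuming the script above produced ./output/000000.npy; each saved array is time x 22 joints x 6 channels (accelerometer xyz followed by gyroscope xyz), which is the layout data.py consumes:

import numpy as np

imu = np.load('./output/000000.npy')  # illustrative file name
assert imu.ndim == 3 and imu.shape[1:] == (22, 6), imu.shape
print(imu.shape)  # (T, 22, 6)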
checkpoint/UniMTS.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9858c0084d936655240407e30ff9db9adeded6a67dc5650e3f667578e93b220
size 274583082
contrastive.py
ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
import clip
from model import ST_GCN_18

class ContrastiveModule(nn.Module):

    def __init__(self, args):
        super(ContrastiveModule, self).__init__()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, _ = clip.load("ViT-B/32", device=device)  # the image preprocess transform is unused here
        del model.visual  # drop CLIP's vision tower; an ST-GCN IMU encoder takes its place
        self.model = model

        # 3 accelerometer channels per joint, doubled if gyro is used, doubled again for STFT features
        base_channel = 3
        base_channel = base_channel * 2 if args.gyro else base_channel
        base_channel = base_channel * 2 if args.stft else base_channel
        self.model.acc = ST_GCN_18(in_channels=base_channel)

        self.model = self.model.float()

        if args.stage == 'finetune':
            self.fc = nn.Linear(512, args.num_class)

    def encode_image(self, image):
        return self.model.acc(image.float()).squeeze(-1).squeeze(-1)

    def encode_text(self, text):
        x = self.model.token_embedding(text).float()  # b,t,512
        x = x + self.model.positional_embedding.float()
        x = x.permute(1, 0, 2)  # b,t,512 -> t,b,512
        x = self.model.transformer(x)
        x = x.permute(1, 0, 2)  # t,b,512 -> b,t,512
        x = self.model.ln_final(x).float()  # b,t,512

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.model.text_projection  # b,512

        return x

    def classifier(self, image):
        # for fine-tuning
        imu_features = self.model.acc(image.float()).squeeze(-1).squeeze(-1)
        out = self.fc(imu_features)
        return out

    def forward(self, inputs_imu, inputs_text):

        imu_features = self.encode_image(inputs_imu)
        text_features = self.encode_text(inputs_text)

        # normalized features
        imu_features = imu_features / imu_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # logits
        logit_scale = self.model.logit_scale.exp()
        logits_per_image = logit_scale * imu_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
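A minimal zero-shot sketch of how this module is driven (mirroring the evaluation loop in evaluate.py below): encode tokenized label phrases, compare them against the IMU embedding, and take the argmax over the IMU-text logits. The args namespace, label strings, and random input are illustrative only, and a CUDA device is assumed:

import argparse
import clip
import torch
from contrastive import ContrastiveModule

args = argparse.Namespace(gyro=0, stft=0, stage='evaluation')
model = ContrastiveModule(args).cuda().eval()

# (batch, channels, time, 22 joints, 1): accelerometer only, 200 padded steps
imu = torch.randn(2, 3, 200, 22, 1).cuda()
text = clip.tokenize(['a person walks', 'a person sits down']).cuda()  # example labels

with torch.no_grad():
    logits_per_imu, _ = model(imu, text)
print(logits_per_imu.argmax(dim=-1))  # predicted label index per sequence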
data.py
ADDED
@@ -0,0 +1,321 @@
import torch
import numpy as np
import random
import os
import json
from scipy.signal import resample
import clip
from torch.utils.data import Dataset

class CLIPDataset(Dataset):

    def __init__(self, args):

        imu_dirs = [
            f'{args.data_path}/sim/',
        ]
        text_dirs = [
            f'{args.data_path}/aug_texts/',
        ]
        self.paths = []
        for imu_dir, text_dir in zip(imu_dirs, text_dirs):
            imu_files = [f.split('.')[0] for f in os.listdir(imu_dir) if os.path.isfile(os.path.join(imu_dir, f))]
            text_files = [f.split('.')[0] for f in os.listdir(text_dir) if os.path.isfile(os.path.join(text_dir, f))]
            common_files = [f for f in imu_files if f in text_files]
            for f in common_files:
                self.paths.append((os.path.join(imu_dir, f + '.npy'), os.path.join(text_dir, f + '.txt')))

        self.args = args
        if args.sample < 1:
            self.paths = random.sample(self.paths, int(len(self.paths) * args.sample))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):

        # load imu: (T, 22 joints, 6 channels)
        imu_path, text_path = self.paths[idx]
        imu = np.load(imu_path)
        imu[np.isnan(imu)] = 0

        # padding
        if len(imu) < self.args.padding_size:
            imu = np.pad(imu, ((0, self.args.padding_size - len(imu)), (0, 0), (0, 0)), mode='wrap')
        imu = imu[:self.args.padding_size]

        # random masking: keep only k randomly chosen joints visible
        mask = np.zeros_like(imu)
        k = np.random.randint(1, 6)  # randomly select k joints
        selected_joints = np.random.choice(22, k, replace=False)
        mask[:, selected_joints] = 1
        imu = imu.reshape(len(imu), -1)
        mask = mask.reshape(len(mask), -1)

        # load text
        with open(text_path, 'r') as file:
            lines = file.readlines()

        text = random.choice(lines).split('#')[0].strip()  # remove the comment starting from "#"

        batch = {}
        batch['imu'] = imu
        batch['text'] = text
        batch['mask'] = mask

        return batch

def select_samples(data, masks, labels, k, name, data_path):
    unique_labels = torch.unique(labels)
    selected_data = []
    selected_masks = []
    selected_labels = []
    # pre-computed k-shot indices: a list with one tensor of sample indices per class
    all_indices = torch.load(f'{data_path}/few_shot_data_2/{name}_k={k}.pth')

    for i, label in enumerate(unique_labels):
        selected_indices = all_indices[i]
        selected_data.append(data[selected_indices])
        selected_masks.append(masks[selected_indices])
        selected_labels.append(labels[selected_indices])

    selected_data = torch.cat(selected_data, dim=0)
    selected_masks = torch.cat(selected_masks, dim=0)
    selected_labels = torch.cat(selected_labels, dim=0)

    return selected_data, selected_masks, selected_labels

def load(dataset, padding_size, data_path, split='test', k=None):

    print(dataset)

    X = np.load(f'{data_path}/{dataset}/X_{split}.npy')
    real_labels = torch.from_numpy(np.load(f'{data_path}/{dataset}/y_{split}.npy'))
    with open(f'{data_path}/{dataset}/{dataset}.json', 'r') as file:
        data = json.load(file)
    # all_X: (N, T, 22 joints, 6 channels); each dataset's sensor streams are mapped onto
    # the skeleton joints where the devices were worn, with units converted to m/s^2 and rad/s
    all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))

    if dataset == 'PAMAP':
        all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,11] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,7] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
        original_sampling_rate = 100
        num_classes = 12

    elif dataset == 'USCHAD':
        all_X[:,:,5] = np.concatenate((X[:,:,0:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 100
        num_classes = 12

    elif dataset == 'UCIHAR':
        all_X[:,:,9] = np.concatenate((X[:,:,6:9] * 9.80665, X[:,:,3:6]), axis=-1)  # linear accel, gyro, total accel
        original_sampling_rate = 50
        num_classes = 6

    elif dataset == 'Opp_g':
        all_X[:,:,10] = np.concatenate((X[:,:,0:3] / 1000 * 9.8, X[:,:,3:6] / 1000), axis=-1)  # convert unit from milli g to m/s^2
        all_X[:,:,19] = np.concatenate((X[:,:,9:12] / 1000 * 9.8, X[:,:,12:15] / 1000), axis=-1)
        all_X[:,:,20] = np.concatenate((X[:,:,18:21] / 1000 * 9.8, X[:,:,21:24] / 1000), axis=-1)
        all_X[:,:,15] = np.concatenate((X[:,:,27:30] / 1000 * 9.8, X[:,:,30:33] / 1000), axis=-1)
        all_X[:,:,16] = np.concatenate((X[:,:,36:39] / 1000 * 9.8, X[:,:,39:42] / 1000), axis=-1)
        original_sampling_rate = 30
        num_classes = 4  # locomotion

    elif dataset == 'WISDM':
        all_X[:,:,21] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        original_sampling_rate = 20
        num_classes = 18

    elif dataset == 'DSADS':
        all_X[:,:,11] = np.concatenate((X[:,:,0:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,9:12], X[:,:,12:15]), axis=-1)
        all_X[:,:,17] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,6] = np.concatenate((X[:,:,27:30], X[:,:,30:33]), axis=-1)
        all_X[:,:,2] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
        original_sampling_rate = 25
        num_classes = 19

    elif dataset == 'Harth':
        all_X[:,:,9,:3] = X[:,:,:3] * 9.80665
        all_X[:,:,6,:3] = X[:,:,3:6] * 9.80665
        original_sampling_rate = 50
        num_classes = 12

    elif dataset == 'Wharf':
        # map raw codes in [0, 63] to +-1.5 g expressed in m/s^2 (14.709 = 1.5 * 9.806)
        X = -14.709 + X / 63 * (2 * 14.709)
        all_X[:,:,21,:3] = X
        original_sampling_rate = 32
        num_classes = 14

    elif dataset == 'Mhealth':
        all_X[:,:,11,:3] = X[:,:,0:3]
        all_X[:,:,3] = np.concatenate((X[:,:,6:9], X[:,:,9:12] / 180 * np.pi), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,15:18], X[:,:,18:21] / 180 * np.pi), axis=-1)
        original_sampling_rate = 50
        num_classes = 12

    elif dataset == 'UTD-MHAD':
        # the first 21 action classes were recorded with the sensor on the wrist, the rest on the thigh
        all_X[real_labels < 21,:,21,:] = np.concatenate((X[real_labels < 21,:,0:3] * 9.80665, X[real_labels < 21,:,3:6] / 180 * np.pi), axis=-1)
        all_X[real_labels >= 21,:,5,:] = np.concatenate((X[real_labels >= 21,:,0:3] * 9.80665, X[real_labels >= 21,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 50
        num_classes = 27

    elif dataset == 'MotionSense':
        all_X[:,:,5] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
        all_X[:,:,1] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6]), axis=-1)
        original_sampling_rate = 50
        num_classes = 6

    elif dataset == 'w-HAR':
        all_X[:,:,7] = np.concatenate((X[:,:,:3] * 9.80665, X[:,:,3:6] / 180 * np.pi), axis=-1)
        original_sampling_rate = 250
        num_classes = 7

    elif dataset == 'Shoaib':
        all_X[:,:,1] = X[:,:,:6]
        all_X[:,:,5] = X[:,:,6:12]
        all_X[:,:,21] = X[:,:,12:18]
        all_X[:,:,20] = X[:,:,18:24]
        all_X[:,:,0] = X[:,:,24:30]
        original_sampling_rate = 50
        num_classes = 7

    elif dataset == 'har70plus':
        all_X[:,:,0,:3] = X[:,:,:3] * 9.80665
        all_X[:,:,5,:3] = X[:,:,3:6] * 9.80665
        original_sampling_rate = 50
        num_classes = 7

    elif dataset == 'MMAct':
        all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21,:3] = X[:,:,6:9]
        original_sampling_rate = 50
        num_classes = 35

    elif dataset == 'realworld':
        all_X[:,:,14] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,16] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        all_X[:,:,13] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
        all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,1] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
        all_X[:,:,15] = np.concatenate((X[:,:,30:33], X[:,:,33:36]), axis=-1)
        all_X[:,:,9] = np.concatenate((X[:,:,36:39], X[:,:,39:42]), axis=-1)
        original_sampling_rate = 50
        num_classes = 8

    elif dataset == 'TNDA-HAR':
        all_X[:,:,20] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,2] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,12:15], X[:,:,15:18]), axis=-1)
        all_X[:,:,3] = np.concatenate((X[:,:,18:21], X[:,:,21:24]), axis=-1)
        all_X[:,:,11] = np.concatenate((X[:,:,24:27], X[:,:,27:30]), axis=-1)
        original_sampling_rate = 50
        num_classes = 8

    elif dataset == 'ut-complex':
        all_X[:,:,5] = np.concatenate((X[:,:,:3], X[:,:,3:6]), axis=-1)
        all_X[:,:,21] = np.concatenate((X[:,:,6:9], X[:,:,9:12]), axis=-1)
        original_sampling_rate = 50
        num_classes = 13

    all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)

    # resample real data to 20 Hz
    new_sampling_rate = 20
    new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
    resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])

    # pad real data to args.padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 132)
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # (N, padding_size, 132)
    real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
    real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()

    if split == 'train' and k and k < len(real_inputs):
        real_inputs, real_masks, real_labels = select_samples(real_inputs, real_masks, real_labels, k, dataset, data_path)
    print(real_inputs.shape, real_labels.shape)

    # load text
    label_dictionary = data['label_dictionary']
    label_list = [' '.join(labels) for labels in label_dictionary.values()]
    all_text = clip.tokenize(label_list).cuda()

    return real_inputs, real_masks, real_labels, label_list, all_text, num_classes

def load_multiple(dataset_list, padding_size, data_path, split='test', k=None):

    real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list = [], [], [], [], [], []
    for dataset in dataset_list:
        real_inputs, real_masks, real_labels, label_list, all_text, num_classes = load(dataset, padding_size, data_path, split, k)
        real_inputs_list.append(real_inputs)
        real_masks_list.append(real_masks)
        real_labels_list.append(real_labels)
        label_list_list.append(label_list)
        all_text_list.append(all_text)
        num_classes_list.append(num_classes)

    return real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, num_classes_list

def load_custom_data(X_path, y_path, config_path, joint_list, original_sampling_rate, padding_size=200, split='test', k=None, few_shot_path=None):

    X = np.load(X_path)
    real_labels = torch.from_numpy(np.load(y_path))
    with open(config_path, 'r') as file:
        data = json.load(file)
    all_X = np.zeros((X.shape[0], X.shape[1], 22, 6))

    # X carries 6 channels (acc xyz, gyro xyz) per listed joint, in joint_list order
    for i, joint in enumerate(joint_list):
        all_X[:,:,joint] = np.concatenate((X[:,:,6*i:6*i+3], X[:,:,6*i+3:6*i+6]), axis=-1)

    all_X = all_X.reshape(all_X.shape[0], all_X.shape[1], 22 * 6)

    # resample real data to 20 Hz
    new_sampling_rate = 20
    new_length = int((all_X.shape[1] / original_sampling_rate) * new_sampling_rate)
    resampled_data = np.array([resample(sequence, new_length) for sequence in all_X])

    # pad real data to args.padding_size
    masks = np.ones_like(resampled_data)
    if resampled_data.shape[1] < padding_size:
        resampled_data = np.pad(resampled_data, ((0, 0), (0, padding_size - resampled_data.shape[1]), (0, 0)), 'wrap')  # (N, padding_size, 132)
        masks = np.pad(masks, ((0, 0), (0, padding_size - masks.shape[1]), (0, 0)), 'constant')  # (N, padding_size, 132)
    real_inputs = torch.from_numpy(resampled_data[:,:padding_size,:]).float()
    real_masks = torch.from_numpy(masks[:,:padding_size,:]).float()

    if split == 'train' and k and k < len(real_inputs):

        unique_labels = torch.unique(real_labels)

        if few_shot_path is None:
            print('Generating few shot indices ...')
            all_indices = []
            for i, label in enumerate(unique_labels):
                indices = torch.where(real_labels == label)[0]
                selected_indices = indices[torch.randperm(len(indices))[:k]]
                all_indices.append(selected_indices)
        else:
            print('Loading existing few shot indices ...')
            all_indices = torch.load(few_shot_path)

        selected_data = []
        selected_masks = []
        selected_labels = []
        for i, label in enumerate(unique_labels):
            selected_indices = all_indices[i]
            selected_data.append(real_inputs[selected_indices])
            selected_masks.append(real_masks[selected_indices])
            selected_labels.append(real_labels[selected_indices])
        selected_data = torch.cat(selected_data, dim=0)
        selected_masks = torch.cat(selected_masks, dim=0)
        selected_labels = torch.cat(selected_labels, dim=0)
        real_inputs, real_masks, real_labels = selected_data, selected_masks, selected_labels

    print(real_inputs.shape, real_labels.shape)

    # load text
    label_dictionary = data['label_dictionary']
    label_list = [' '.join(labels) for labels in label_dictionary.values()]
    all_text = clip.tokenize(label_list).cuda()

    return real_inputs, real_masks, real_labels, label_list, all_text
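For a custom dataset, load_custom_data expects X of shape (N, T, 6 * len(joint_list)) with accelerometer then gyroscope channels per device, integer labels y of shape (N,), and a JSON config whose 'label_dictionary' maps class ids to lists of words. A minimal end-to-end sketch with made-up file names and a single device on joint 21; since the loader tokenizes labels onto the GPU, a CUDA machine is assumed:

import json
import numpy as np
from data import load_custom_data

# toy inputs: 8 windows, 100 timesteps, 1 device -> 6 channels (acc xyz + gyro xyz)
np.save('X_test.npy', np.random.randn(8, 100, 6).astype(np.float32))
np.save('y_test.npy', np.random.randint(0, 2, size=8))
with open('config.json', 'w') as f:
    json.dump({'label_dictionary': {'0': ['walking'], '1': ['sitting', 'down']}}, f)

inputs, masks, labels, label_list, all_text = load_custom_data(
    'X_test.npy', 'y_test.npy', 'config.json',
    joint_list=[21], original_sampling_rate=50, padding_size=200)
print(inputs.shape)  # (8, 200, 132): resampled to 20 Hz, wrap-padded, 22 joints x 6 channels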
evaluate.py
ADDED
@@ -0,0 +1,99 @@
import numpy as np
import torch
import argparse
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset

from data import load_multiple
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):
    # load real data
    dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',
                    'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
    real_inputs_list, real_masks_list, real_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path)
    test_real_dataloader_list = []
    for real_inputs, real_masks, real_labels in zip(real_inputs_list, real_masks_list, real_labels_list):
        real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        test_real_dataloader_list.append(DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False))

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{date}"
    )

    model = ContrastiveModule(args).cuda()

    model.model.load_state_dict(torch.load(f'{args.checkpoint}'))

    model.eval()
    with torch.no_grad():
        for ds, real_labels, test_real_dataloader, label_list, all_text in zip(dataset_list, real_labels_list, test_real_dataloader_list, label_list_list, all_text_list):
            pred_whole, logits_whole = [], []
            for input, mask, label in test_real_dataloader:

                input = input.cuda()
                mask = mask.cuda()
                label = label.cuda()

                # keep only the 3 accelerometer channels of each joint's 6-channel block
                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                    input = input[:, :, indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                    input = torch.cat((input, input_stft), dim=-1)

                # (b, t, 22*C) -> (b, C, t, 22, 1) for the ST-GCN encoder
                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                logits_per_imu, logits_per_text = model(input, all_text)
                logits_whole.append(logits_per_imu)

                pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
                pred_whole.append(pred)

            pred = np.concatenate(pred_whole)
            acc = accuracy_score(real_labels, pred)
            prec = precision_score(real_labels, pred, average='macro')
            rec = recall_score(real_labels, pred, average='macro')
            f1 = f1_score(real_labels, pred, average='macro')

            print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
            wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})

            logits_whole = torch.cat(logits_whole)
            r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())

            print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
            wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
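The STFT branch above flattens the spectrogram back to the original sequence length. With padding_size=200 this works out exactly: n_fft=25 with onesided=False gives 25 frequency bins, and hop_length=28 with centering gives 8 frames, so 25 x 8 = 200. A quick check of that assumption:

import torch

t = 200  # padding_size; the in-place reshape above relies on this exact length
x = torch.randn(4, t)
spec = torch.abs(torch.stft(x, n_fft=25, hop_length=28, onesided=False,
                            center=True, return_complex=True))
print(spec.shape)                 # torch.Size([4, 25, 8])
print(spec.reshape(4, t).shape)   # flattens back to 200 "time" steps per channel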
evaluate_custom.py
ADDED
@@ -0,0 +1,101 @@
import numpy as np
import torch
import argparse
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset

from data import load_custom_data
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):
    # load real data
    real_inputs, real_masks, real_labels, label_list, all_text = load_custom_data(
        args.X_path, args.y_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
    )
    real_dataset = TensorDataset(real_inputs, real_masks, real_labels)
    test_real_dataloader = DataLoader(real_dataset, batch_size=args.batch_size, shuffle=False)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{date}"
    )

    model = ContrastiveModule(args).cuda()

    model.model.load_state_dict(torch.load(f'{args.checkpoint}'))

    model.eval()
    with torch.no_grad():
        pred_whole, logits_whole = [], []
        for input, mask, label in test_real_dataloader:

            input = input.cuda()
            mask = mask.cuda()
            label = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                input = input[:, :, indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            logits_per_imu, logits_per_text = model(input, all_text)
            logits_whole.append(logits_per_imu)

            pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
            pred_whole.append(pred)

        pred = np.concatenate(pred_whole)
        acc = accuracy_score(real_labels, pred)
        prec = precision_score(real_labels, pred, average='macro')
        rec = recall_score(real_labels, pred, average='macro')
        f1 = f1_score(real_labels, pred, average='macro')

        print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
        wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})

        logits_whole = torch.cat(logits_whole)
        r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), real_labels.numpy())

        print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
        wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--X_path', type=str, required=True, help='/path/to/data/')
    parser.add_argument('--y_path', type=str, required=True, help='/path/to/label/')
    parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
    parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
    parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='evaluation', help='training or evaluation stage')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
finetune.py
ADDED
@@ -0,0 +1,169 @@
import numpy as np
import torch
import torch.nn.functional as F

import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from data import load_multiple
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):

    # load real data
    dataset_list = ['Opp_g','UCIHAR','MotionSense','w-HAR','Shoaib','har70plus','realworld','TNDA-HAR','PAMAP',
                    'USCHAD','Mhealth','Harth','ut-complex','Wharf','WISDM','DSADS','UTD-MHAD','MMAct']
    train_inputs_list, train_masks_list, train_labels_list, label_list_list, all_text_list, num_classes_list = load_multiple(dataset_list, args.padding_size, args.data_path, split='train', k=args.k)
    test_inputs_list, test_masks_list, test_labels_list, label_list_list, all_text_list, _ = load_multiple(dataset_list, args.padding_size, args.data_path, split='test')
    train_dataloader_list, test_dataloader_list = [], []
    for real_inputs, real_masks, real_labels in zip(train_inputs_list, train_masks_list, train_labels_list):
        train_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        train_dataloader_list.append(DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True))
    for real_inputs, real_masks, real_labels in zip(test_inputs_list, test_masks_list, test_labels_list):
        test_dataset = TensorDataset(real_inputs, real_masks, real_labels)
        test_dataloader_list.append(DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False))

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_{date}"
    )

    save_path = './checkpoint/%s/' % args.run_tag
    os.makedirs(save_path, exist_ok=True)  # torch.save below needs the directory to exist

    for ds, train_dataloader, test_dataloader, test_labels, label_list, all_text, num_class in \
        zip(dataset_list, train_dataloader_list, test_dataloader_list, test_labels_list, label_list_list, all_text_list, num_classes_list):

        args.num_class = num_class
        model = ContrastiveModule(args).cuda()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        if args.mode == 'full' or args.mode == 'probe':
            model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
        if args.mode == 'probe':
            # linear probe: freeze the pre-trained encoder, train only the classification head
            for name, param in model.model.named_parameters():
                param.requires_grad = False

        best_loss = None
        for epoch in range(args.num_epochs):

            tol_loss = 0

            model.train()
            for i, (input, mask, label) in enumerate(train_dataloader):

                input = input.cuda()
                labels = label.cuda()

                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                    input = input[:, :, indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                    input = torch.cat((input, input_stft), dim=-1)

                # (b, t, 22*C) -> (b, C, t, 22, 1) for the ST-GCN encoder
                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                output = model.classifier(input)

                loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                tol_loss += len(input) * loss.item()

            # average over the current dataset's training set
            avg_loss = tol_loss / len(train_dataloader.dataset)
            print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {avg_loss:.4f}')
            wandb.log({f'{ds} loss': avg_loss})

            if best_loss is None or tol_loss < best_loss:
                best_loss = tol_loss
                torch.save(model.state_dict(), os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth'))

        # evaluation
        model.load_state_dict(torch.load(os.path.join(save_path, f'{ds}_k={args.k}_best_loss.pth')))
        model.eval()
        with torch.no_grad():

            pred_whole, logits_whole = [], []
            for input, mask, label in test_dataloader:

                input = input.cuda()
                label = label.cuda()

                if not args.gyro:
                    b, t, c = input.shape
                    indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                    input = input[:, :, indices]

                b, t, c = input.shape
                if args.stft:
                    input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                    input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                    input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                    input = torch.cat((input, input_stft), dim=-1)

                input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

                logits_per_imu = model.classifier(input)
                logits_whole.append(logits_per_imu)

                pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
                pred_whole.append(pred)

            pred = np.concatenate(pred_whole)
            acc = accuracy_score(test_labels, pred)
            prec = precision_score(test_labels, pred, average='macro')
            rec = recall_score(test_labels, pred, average='macro')
            f1 = f1_score(test_labels, pred, average='macro')

            print(f"{ds} acc: {acc}, {ds} prec: {prec}, {ds} rec: {rec}, {ds} f1: {f1}")
            wandb.log({f"{ds} acc": acc, f"{ds} prec": prec, f"{ds} rec": rec, f"{ds} f1": f1})

            logits_whole = torch.cat(logits_whole)
            r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())

            print(f"{ds} R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
            wandb.log({f"{ds} R@1": r_at_1, f"{ds} R@2": r_at_2, f"{ds} R@3": r_at_3, f"{ds} R@4": r_at_4, f"{ds} R@5": r_at_5, f"{ds} MRR": mrr_score})


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # model
    parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--stage', type=str, default='finetune', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
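For the benchmark datasets, select_samples in data.py loads pre-computed index files from {data_path}/few_shot_data_2/{dataset}_k={k}.pth, each a list with one tensor of k sample indices per class. A sketch of how such a file could be generated, mirroring the generation branch in load_custom_data; the dataset name and k are illustrative:

import os
import numpy as np
import torch

y = torch.from_numpy(np.load('./data/UCIHAR/y_train.npy'))  # per-sample class ids
k = 5
all_indices = []
for label in torch.unique(y):
    idx = torch.where(y == label)[0]
    all_indices.append(idx[torch.randperm(len(idx))[:k]])  # k random samples of this class
os.makedirs('./data/few_shot_data_2', exist_ok=True)
torch.save(all_indices, f'./data/few_shot_data_2/UCIHAR_k={k}.pth')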
finetune_custom.py
ADDED
@@ -0,0 +1,172 @@
import numpy as np
import torch
import torch.nn.functional as F

import argparse
import os
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
import wandb
import datetime
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from data import load_custom_data
from utils import compute_metrics_np
from contrastive import ContrastiveModule

def main(args):

    # reuse saved few-shot indices when --few_shot_path is given
    train_inputs, train_masks, train_labels, _, _ = load_custom_data(
        args.X_train_path, args.y_train_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='train', k=args.k, few_shot_path=args.few_shot_path
    )
    train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    test_inputs, test_masks, test_labels, _, _ = load_custom_data(
        args.X_test_path, args.y_test_path, args.config_path, args.joint_list, args.original_sampling_rate, padding_size=args.padding_size, split='test'
    )
    test_dataset = TensorDataset(test_inputs, test_masks, test_labels)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_{args.mode}_k={args.k}_{date}"
    )

    save_path = './checkpoint/%s/' % args.run_tag
    os.makedirs(save_path, exist_ok=True)  # torch.save below needs the directory to exist

    model = ContrastiveModule(args).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    if args.mode == 'full' or args.mode == 'probe':
        model.model.load_state_dict(torch.load(f'{args.checkpoint}'))
    if args.mode == 'probe':
        # linear probe: freeze the pre-trained encoder, train only the classification head
        for name, param in model.model.named_parameters():
            param.requires_grad = False

    best_loss = None
    for epoch in range(args.num_epochs):

        tol_loss = 0

        model.train()
        for i, (input, mask, label) in enumerate(train_dataloader):

            input = input.cuda()
            labels = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                input = input[:, :, indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            output = model.classifier(input)

            loss = F.cross_entropy(output.float(), labels.long(), reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tol_loss += len(input) * loss.item()

        avg_loss = tol_loss / len(train_dataset)
        print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {avg_loss:.4f}')
        wandb.log({'loss': avg_loss})

        if best_loss is None or tol_loss < best_loss:
            best_loss = tol_loss
            torch.save(model.state_dict(), os.path.join(save_path, f'k={args.k}_best_loss.pth'))

    # evaluation
    model.load_state_dict(torch.load(os.path.join(save_path, f'k={args.k}_best_loss.pth')))
    model.eval()
    with torch.no_grad():

        pred_whole, logits_whole = [], []
        for input, mask, label in test_dataloader:

            input = input.cuda()
            label = label.cuda()

            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
                input = input[:, :, indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0, 2, 1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0, 2, 1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            logits_per_imu = model.classifier(input)
            logits_whole.append(logits_per_imu)

            pred = torch.argmax(logits_per_imu, dim=-1).detach().cpu().numpy()
            pred_whole.append(pred)

        pred = np.concatenate(pred_whole)
        acc = accuracy_score(test_labels, pred)
        prec = precision_score(test_labels, pred, average='macro')
        rec = recall_score(test_labels, pred, average='macro')
        f1 = f1_score(test_labels, pred, average='macro')

        print(f"acc: {acc}, prec: {prec}, rec: {rec}, f1: {f1}")
        wandb.log({"acc": acc, "prec": prec, "rec": rec, "f1": f1})

        logits_whole = torch.cat(logits_whole)
        r_at_1, r_at_2, r_at_3, r_at_4, r_at_5, mrr_score = compute_metrics_np(logits_whole.detach().cpu().numpy(), test_labels.numpy())

        print(f"R@1: {r_at_1}, R@2: {r_at_2}, R@3: {r_at_3}, R@4: {r_at_4}, R@5: {r_at_5}, MRR: {mrr_score}")
        wandb.log({"R@1": r_at_1, "R@2": r_at_2, "R@3": r_at_3, "R@4": r_at_4, "R@5": r_at_5, "MRR": mrr_score})


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # model
    parser.add_argument('--mode', type=str, default='full', choices=['random','probe','full'], help='full fine-tuning, linear probe, random init')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--k', type=int, help='few shot samples per class (default: None)')
    parser.add_argument('--X_train_path', type=str, required=True, help='/path/to/train/data/')
    parser.add_argument('--X_test_path', type=str, required=True, help='/path/to/test/data/')
    parser.add_argument('--y_train_path', type=str, required=True, help='/path/to/train/label/')
    parser.add_argument('--y_test_path', type=str, required=True, help='/path/to/test/label/')
    parser.add_argument('--config_path', type=str, required=True, help='/path/to/config/')
    parser.add_argument('--few_shot_path', type=str, help='/path/to/few/shot/indices/')
    parser.add_argument('--joint_list', nargs='+', type=int, required=True, help='List of joint indices')
    parser.add_argument('--original_sampling_rate', type=int, required=True, help='original sampling rate')
    parser.add_argument('--num_class', type=int, required=True, help='number of classes')

    # training
    parser.add_argument('--stage', type=str, default='finetune', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=200, help='number of fine-tuning epochs (default: 200)')
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')

    parser.add_argument('--checkpoint', type=str, default='./checkpoint/', help='/path/to/checkpoint/')

    args = parser.parse_args()

    main(args)
model.py
ADDED
@@ -0,0 +1,350 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
class Graph():
|
7 |
+
""" The Graph to model the skeletons
|
8 |
+
|
9 |
+
Args:
|
10 |
+
strategy (string): must be one of the follow candidates
|
11 |
+
- uniform: Uniform Labeling
|
12 |
+
- distance: Distance Partitioning
|
13 |
+
- spatial: Spatial Configuration
|
14 |
+
max_hop (int): the maximal distance between two connected nodes
|
15 |
+
dilation (int): controls the spacing between the kernel points
|
16 |
+
|
17 |
+
"""
|
18 |
+
def __init__(self,
|
19 |
+
strategy='spatial',
|
20 |
+
max_hop=1,
|
21 |
+
dilation=1):
|
22 |
+
self.max_hop = max_hop
|
23 |
+
self.dilation = dilation
|
24 |
+
|
25 |
+
self.get_edge()
|
26 |
+
self.hop_dis = get_hop_distance(self.num_node,
|
27 |
+
self.edge,
|
28 |
+
max_hop=max_hop)
|
29 |
+
self.get_adjacency(strategy)
|
30 |
+
|
31 |
+
def __str__(self):
|
32 |
+
return self.A
|
33 |
+
|
34 |
+
def get_edge(self):
|
35 |
+
# edge is a list of [child, parent] paris
|
36 |
+
self.num_node = 22
|
37 |
+
self_link = [(i, i) for i in range(self.num_node)]
|
38 |
+
neighbor_link = [(1,0), (2,1), (3,2), (4,3), (5,0), (6,5), (7,6), (8,7), (9,0), (10,9), (11,10), (12,11), \
|
39 |
+
(13,12), (14,11), (15,14), (16,15), (17,16), (18,11), (19,18), (20,19), (21,20)]
|
40 |
+
self.edge = self_link + neighbor_link
|
41 |
+
self.center = 0
|
42 |
+
|
43 |
+
def get_adjacency(self, strategy):
|
44 |
+
valid_hop = range(0, self.max_hop + 1, self.dilation)
|
45 |
+
adjacency = np.zeros((self.num_node, self.num_node))
|
46 |
+
for hop in valid_hop:
|
47 |
+
adjacency[self.hop_dis == hop] = 1
|
48 |
+
normalize_adjacency = normalize_digraph(adjacency)
|
49 |
+
|
50 |
+
if strategy == 'uniform':
|
51 |
+
A = np.zeros((1, self.num_node, self.num_node))
|
52 |
+
A[0] = normalize_adjacency
|
53 |
+
self.A = A
|
54 |
+
elif strategy == 'distance':
|
55 |
+
A = np.zeros((len(valid_hop), self.num_node, self.num_node))
|
56 |
+
for i, hop in enumerate(valid_hop):
|
57 |
+
A[i][self.hop_dis == hop] = normalize_adjacency[self.hop_dis ==
|
58 |
+
hop]
|
59 |
+
self.A = A
|
60 |
+
elif strategy == 'spatial':
|
61 |
+
A = []
|
62 |
+
for hop in valid_hop:
|
63 |
+
a_root = np.zeros((self.num_node, self.num_node))
|
64 |
+
a_close = np.zeros((self.num_node, self.num_node))
|
65 |
+
a_further = np.zeros((self.num_node, self.num_node))
|
66 |
+
for i in range(self.num_node):
|
67 |
+
for j in range(self.num_node):
|
68 |
+
if self.hop_dis[j, i] == hop:
|
69 |
+
if self.hop_dis[j, self.center] == self.hop_dis[
|
70 |
+
i, self.center]:
|
71 |
+
a_root[j, i] = normalize_adjacency[j, i]
|
72 |
+
elif self.hop_dis[j, self.center] > self.hop_dis[
|
73 |
+
i, self.center]:
|
74 |
+
a_close[j, i] = normalize_adjacency[j, i]
|
75 |
+
else:
|
76 |
+
a_further[j, i] = normalize_adjacency[j, i]
|
77 |
+
if hop == 0:
|
78 |
+
A.append(a_root)
|
79 |
+
else:
|
80 |
+
A.append(a_root + a_close)
|
81 |
+
A.append(a_further)
|
82 |
+
A = np.stack(A)
|
83 |
+
self.A = A
|
84 |
+
else:
|
85 |
+
raise ValueError("Do Not Exist This Strategy")
|
86 |
+
|
87 |
+
def get_hop_distance(num_node, edge, max_hop=1):
    A = np.zeros((num_node, num_node))
    for i, j in edge:
        A[j, i] = 1
        A[i, j] = 1

    # compute hop steps
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
    arrive_mat = (np.stack(transfer_mat) > 0)
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis

def normalize_digraph(A):
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1)
    AD = np.dot(A, Dn)
    return AD

def normalize_undigraph(A):
    Dl = np.sum(A, 0)
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-0.5)
    DAD = np.dot(np.dot(Dn, A), Dn)
    return DAD

def zero(x):
    return 0

def iden(x):
    return x

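A quick worked example (hypothetical, not part of the source) of what these two helpers compute for a 3-node chain with self links:

# hypothetical sanity check for get_hop_distance / normalize_digraph
edge = [(0, 0), (1, 1), (2, 2), (1, 0), (2, 1)]   # chain 0-1-2 plus self links
hop = get_hop_distance(3, edge, max_hop=2)
# hop == [[0., 1., 2.],
#         [1., 0., 1.],
#         [2., 1., 0.]]
adj = (hop <= 1).astype(float)
# normalize_digraph scales column j by 1 / sum(adj[:, j]), so each column sums to 1
print(normalize_digraph(adj).sum(axis=0))         # -> [1. 1. 1.]
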
class ConvTemporalGraphical(nn.Module):
    r"""The basic module for applying a graph convolution.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int): Size of the graph convolving kernel
        t_kernel_size (int): Size of the temporal convolving kernel
        t_stride (int, optional): Stride of the temporal convolution. Default: 1
        t_padding (int, optional): Temporal zero-padding added to both sides of
            the input. Default: 0
        t_dilation (int, optional): Spacing between temporal kernel elements.
            Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output.
            Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(in_channels,
                              out_channels * kernel_size,
                              kernel_size=(t_kernel_size, 1),
                              padding=(t_padding, 0),
                              stride=(t_stride, 1),
                              dilation=(t_dilation, 1),
                              bias=bias)

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size

        x = self.conv(x)

        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        x = torch.einsum('nkctv,kvw->nctw', (x, A))

        return x.contiguous(), A

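A minimal shape walkthrough of this module (a sketch with illustrative sizes, assuming the torch import at the top of model.py):

# K = 3 spatial partitions, V = 22 joints
gcn = ConvTemporalGraphical(in_channels=6, out_channels=64, kernel_size=3)
x = torch.randn(2, 6, 50, 22)      # (N, C, T, V)
A = torch.rand(3, 22, 22)          # (K, V, V)
y, _ = gcn(x, A)
# the 1x1 conv expands channels to K * out_channels; the einsum contracts over K and V
print(y.shape)                     # torch.Size([2, 64, 50, 22])
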
class st_gcn_block(nn.Module):
    r"""Applies a spatial temporal graph convolution over an input graph sequence.

    Args:
        in_channels (int): Number of channels in the input sequence data
        out_channels (int): Number of channels produced by the convolution
        kernel_size (tuple): Size of the temporal convolving kernel and graph convolving kernel
        stride (int, optional): Stride of the temporal convolution. Default: 1
        dropout (float, optional): Dropout rate of the final output. Default: 0
        residual (bool, optional): If ``True``, applies a residual mechanism. Default: ``True``

    Shape:
        - Input[0]: Input graph sequence in :math:`(N, in_channels, T_{in}, V)` format
        - Input[1]: Input graph adjacency matrix in :math:`(K, V, V)` format
        - Output[0]: Output graph sequence in :math:`(N, out_channels, T_{out}, V)` format
        - Output[1]: Graph adjacency matrix for output data in :math:`(K, V, V)` format

        where
            :math:`N` is a batch size,
            :math:`K` is the spatial kernel size, as :math:`K == kernel_size[1]`,
            :math:`T_{in}/T_{out}` is a length of input/output sequence,
            :math:`V` is the number of graph nodes.

    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        padding = ((kernel_size[0] - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels,
                                         kernel_size[1])

        self.tcn = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                out_channels,
                out_channels,
                (kernel_size[0], 1),
                (stride, 1),
                padding,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        if not residual:
            self.residual = zero

        elif (in_channels == out_channels) and (stride == 1):
            self.residual = iden

        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels,
                          out_channels,
                          kernel_size=1,
                          stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):

        res = self.residual(x)
        x, A = self.gcn(x, A)
        x = self.tcn(x) + res

        return self.relu(x), A

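A sketch (illustrative numbers, not from the source) of the residual path when channels or stride change:

# with stride 2 the temporal length halves; the 1x1-conv shortcut matches
# both the new channel count and the downsampled length
blk = st_gcn_block(64, 128, kernel_size=(9, 3), stride=2)
x = torch.randn(2, 64, 50, 22)     # (N, C, T, V)
A = torch.rand(3, 22, 22)
y, _ = blk(x, A)
print(y.shape)                     # torch.Size([2, 128, 25, 22])
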
class ST_GCN_18(nn.Module):
    r"""Spatial temporal graph convolutional networks.

    Args:
        in_channels (int): Number of channels in the input data
        num_class (int): Number of classes for the classification task
        graph_cfg (dict): The arguments for building the graph
        edge_importance_weighting (bool): If ``True``, adds a learnable
            importance weighting to the edges of the graph
        **kwargs (optional): Other parameters for graph convolution units

    Shape:
        - Input: :math:`(N, in_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num_class)` where
            :math:`N` is a batch size,
            :math:`T_{in}` is a length of input sequence,
            :math:`V_{in}` is the number of graph nodes,
            :math:`M_{in}` is the number of instances in a frame.
    """
    def __init__(self,
                 in_channels,
                 edge_importance_weighting=True,
                 data_bn=True,
                 **kwargs):
        super().__init__()

        # load graph
        self.graph = Graph()
        A = torch.tensor(self.graph.A,
                         dtype=torch.float32,
                         requires_grad=False)
        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        self.data_bn = nn.BatchNorm1d(in_channels *
                                      A.size(1)) if data_bn else iden
        kwargs0 = {k: v for k, v in kwargs.items() if k != 'dropout'}
        self.st_gcn_networks = nn.ModuleList((
            st_gcn_block(in_channels,
                         64,
                         kernel_size,
                         1,
                         residual=False,
                         **kwargs0),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 64, kernel_size, 1, **kwargs),
            st_gcn_block(64, 128, kernel_size, 2, **kwargs),
            st_gcn_block(128, 128, kernel_size, 1, **kwargs),
            st_gcn_block(128, 128, kernel_size, 1, **kwargs),
            st_gcn_block(128, 256, kernel_size, 2, **kwargs),
            st_gcn_block(256, 256, kernel_size, 1, **kwargs),
            st_gcn_block(256, 512, kernel_size, 1, **kwargs),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for i in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

    def forward(self, x):
        # data normalization
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        # forward
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)

        # global pooling
        x = F.avg_pool2d(x, x.size()[2:])  # (b, 512, t, joint)
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        return x
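End-to-end usage sketch for the backbone (illustrative input sizes; assumes Graph()'s defaults defined earlier in model.py; 6 channels = accelerometer + gyroscope per joint):

model = ST_GCN_18(in_channels=6)
x = torch.randn(2, 6, 200, 22, 1)  # (N, C, T, V, M): 2 samples, 200 frames, 22 joints, 1 instance
feat = model(x)
print(feat.shape)                  # torch.Size([2, 512, 1, 1]), the pooled feature
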
pos2bvh.py
ADDED
@@ -0,0 +1,41 @@
import numpy as np
from Quaternions import Quaternions
from scipy_motion import myBVH
import BVH
from scipy_motion import myAnimation
import Animation
from scipy_motion import myInverseKinematics as myIK
import InverseKinematics as IK
from tqdm import tqdm
import multiprocessing
import os
import os.path as osp
from scipy.spatial.transform import Rotation as R


parents = [-1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
# names = ['root','leftleg1','leftleg2','leftleg3','leftleg4','rightleg1','rightleg2','rightleg3','rightleg4',
#          'spline1','spline2','spline3','spline4','spline5','rightarm1','rightarm2','rightarm3','rightarm4',
#          'leftarm1','leftarm2','leftarm3','leftarm4']

def process_file(f):

    fk_positions = np.load('/path/to/joint/pos/%s.npy' % (f))

    frametime = 1 / 20

    anim_ik, _, _, save_file = IK.animation_from_positions(fk_positions, parents=parents)

    if save_file:
        BVH.save('bvh/%s.bvh' % f, anim_ik, frametime=frametime)

source_dir = '/path/to/joint/pos'
error_file = ['M005836.npy', 'M000990.npy', '000990.npy', '005836.npy']
npy_files = [file[:-4] for file in os.listdir(source_dir) if file.endswith('.npy') and file not in error_file]

# Process files in parallel
pool = multiprocessing.Pool(processes=8)
for _ in tqdm(pool.imap_unordered(process_file, npy_files), total=len(npy_files)):
    pass
pool.close()
pool.join()
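A small sanity check (hypothetical, not in the source) for the kinematic tree encoded by `parents`:

# every non-root joint must point to a lower-indexed parent
assert parents[0] == -1
assert all(0 <= p < i for i, p in enumerate(parents) if i > 0)
bones = [(i, p) for i, p in enumerate(parents) if p >= 0]   # (child, parent) pairs
print(len(bones))   # 21 bones connecting the 22 joints
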
pretrain.py
ADDED
@@ -0,0 +1,111 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import argparse
import os
import clip
import wandb
import datetime
import torch.optim as optim

from data import CLIPDataset
from utils import augment_data
from contrastive import ContrastiveModule

def main(args):

    train_dataset = CLIPDataset(args)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True)

    date = datetime.datetime.now().strftime("%d-%m-%y_%H:%M")
    wandb.init(
        project='UniMTS',
        name=f"{args.run_tag}_{args.stage}_" + f"{date}"
    )

    model = ContrastiveModule(args).cuda()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    save_path = './checkpoint/%s/' % args.run_tag
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    for epoch in range(args.num_epochs):

        tol_loss = 0

        model.train()
        for i, batch in enumerate(train_loader):

            inputs_imu = batch['imu'].float().cuda()
            inputs_text = clip.tokenize(batch['text'], truncate=True).cuda()
            mask = batch['mask'].float().cuda()

            input = inputs_imu * mask

            # rotation-invariance augmentation
            if args.aug:
                input = augment_data(input)

            # keep only the accelerometer triplet of every 6-channel (acc + gyro) group
            if not args.gyro:
                b, t, c = input.shape
                indices = np.array([range(i, i+3) for i in range(0, c, 6)]).flatten()
                input = input[:,:,indices]

            b, t, c = input.shape
            if args.stft:
                input_stft = input.permute(0,2,1).reshape(b * c, t)
                input_stft = torch.abs(torch.stft(input_stft, n_fft=25, hop_length=28, onesided=False, center=True, return_complex=True))
                input_stft = input_stft.reshape(b, c, input_stft.shape[-2], input_stft.shape[-1]).reshape(b, c, t).permute(0,2,1)
                input = torch.cat((input, input_stft), dim=-1)

            input = input.reshape(b, t, 22, -1).permute(0, 3, 1, 2).unsqueeze(-1)

            # IMU and text representations
            logits_per_imu, logits_per_text = model(input, inputs_text)

            # positive keys are the entries on the diagonal
            labels = torch.arange(len(batch['imu'])).cuda()

            loss = F.cross_entropy(logits_per_imu / args.temperature, labels, reduction="mean")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tol_loss += len(inputs_imu) * loss.item()

            # print(epoch, i, loss.item())

        print(f'Epoch [{epoch+1}/{args.num_epochs}], Loss: {tol_loss / len(train_dataset):.4f}')
        wandb.log({'loss': tol_loss / len(train_dataset)})

        if epoch > 0 and epoch % args.log == 0:
            torch.save(model.model.state_dict(), os.path.join(save_path, f'epoch_{epoch}.pth'))

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Unified Pre-trained Motion Time Series Model')

    # data
    parser.add_argument('--padding_size', type=int, default=200, help='padding size (default: 200)')
    parser.add_argument('--sample', type=float, default=1.0, help='pre-training down-sample ratio (default: 1)')
    parser.add_argument('--data_path', type=str, default='./data/', help='/path/to/data/')

    # training
    parser.add_argument('--run_tag', type=str, default='exp0', help='logging tag')
    parser.add_argument('--stage', type=str, default='pretrain', help='training stage')
    parser.add_argument('--num_epochs', type=int, default=100, help='number of pre-training epochs')
    parser.add_argument('--gyro', type=int, default=0, help='using gyro or not')
    parser.add_argument('--stft', type=int, default=0, help='using stft or not')
    parser.add_argument('--aug', type=int, default=1, help='using augmentation or not')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--temperature', type=float, default=0.1, help='temperature')
    parser.add_argument('--log', type=int, default=10, help='logging step')

    args = parser.parse_args()

    main(args)
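For reference, the `--gyro 0` branch above keeps the first three of every six channels; a tiny sketch of the index pattern (illustrative channel count):

import numpy as np
c = 12                                                        # e.g. 2 joints x (3 acc + 3 gyro)
indices = np.array([range(i, i + 3) for i in range(0, c, 6)]).flatten()
print(indices)                                                # [0 1 2 6 7 8] -> accelerometer only
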
run_evaluation.sh
ADDED
@@ -0,0 +1,4 @@
python evaluate.py \
    --batch_size 64 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --data_path 'UniMTS_data'
run_evaluation_custom.sh
ADDED
@@ -0,0 +1,8 @@
python evaluate_custom.py \
    --batch_size 64 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --X_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
    --y_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
    --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
    --joint_list 20 2 21 3 11 \
    --original_sampling_rate 50
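A hypothetical pre-flight check before running the script (not part of the repo; assumes the .npy files pair one label per sequence):

import numpy as np
X = np.load('UniMTS_data/TNDA-HAR/X_test.npy')
y = np.load('UniMTS_data/TNDA-HAR/y_test.npy')
print(X.shape, y.shape)
assert len(X) == len(y), 'one label per sequence expected'
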
run_finetune.sh
ADDED
@@ -0,0 +1,19 @@
for k in 1 2 3 5 10
do

python finetune.py \
    --mode full \
    --k $k \
    --batch_size 64 \
    --num_epochs 200 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --data_path 'UniMTS_data'

done

python finetune.py \
    --mode full \
    --batch_size 64 \
    --num_epochs 200 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --data_path 'UniMTS_data'
run_finetune_custom.sh
ADDED
@@ -0,0 +1,33 @@
for k in 1 2 3 5 10
do

python finetune_custom.py \
    --mode full \
    --k $k \
    --batch_size 64 \
    --num_epochs 200 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
    --y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
    --X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
    --y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
    --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
    --joint_list 20 2 21 3 11 \
    --original_sampling_rate 50 \
    --num_class 8

done

python finetune_custom.py \
    --mode full \
    --batch_size 64 \
    --num_epochs 200 \
    --checkpoint './checkpoint/UniMTS.pth' \
    --X_train_path 'UniMTS_data/TNDA-HAR/X_train.npy' \
    --y_train_path 'UniMTS_data/TNDA-HAR/y_train.npy' \
    --X_test_path 'UniMTS_data/TNDA-HAR/X_test.npy' \
    --y_test_path 'UniMTS_data/TNDA-HAR/y_test.npy' \
    --config_path 'UniMTS_data/TNDA-HAR/TNDA-HAR.json' \
    --joint_list 20 2 21 3 11 \
    --original_sampling_rate 50 \
    --num_class 8
run_pretrain.sh
ADDED
@@ -0,0 +1,4 @@
python pretrain.py \
    --aug 1 \
    --batch_size 64 \
    --data_path 'UniMTS_data'
text_aug.py
ADDED
@@ -0,0 +1,66 @@
import openai
import glob
import os
import shutil
from tqdm import tqdm

def load_api_key(file_path='api_key.txt'):
    with open(file_path, 'r') as f:
        for line in f:
            if line.startswith('api_key='):
                return line.strip().split('=', 1)[1]
    return None

openai.api_key = load_api_key()

if openai.api_key is None:
    print("Error: API key not found.")
    exit()

files = glob.glob('/path/to/txt')
aug_dir = '/path/to/output'

for f in tqdm(files):

    file_id = f.split('/')[-1]
    if not os.path.exists(aug_dir + file_id):

        with open(f, 'r') as file:
            lines = file.readlines()

        text = []
        for i, l in enumerate(lines):
            text.append(str(i) + ': ')
            text.append((l).split('#')[0].strip())
            if text[-1][-1] != '.':
                text.append('. ')
            else:
                text.append(' ')
        text = ''.join(text)

        prompt = 'The following one or multiple descriptions are describing the same human activities: '
        prompt += text
        prompt += 'Generate 3 paraphrases to describe the same activities. One in a line in a plain text format ending with \n, without numbering or - at the beginning. Do not generate any other analysis except from the paraphrased descriptions.'

        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "user", "content": prompt}
            ]
        )
        pred = response.choices[0]['message']['content']
        # res = pred.split('\n')

        shutil.copy(f, aug_dir)
        with open(aug_dir + file_id, 'a') as log_file:
            log_file.write(pred)

files = glob.glob('/path/to/output')
for f in tqdm(files):
    with open(f, 'r') as file:
        lines = file.readlines()

    lines = [line.lstrip("- ") for line in lines if line.strip()]

    with open(f, 'w') as file:
        file.writelines(lines)
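`load_api_key` expects a line of the form `api_key=...`; a hypothetical usage sketch (placeholder value, not a real key):

# api_key.txt contents would look like:
#   api_key=YOUR_OPENAI_KEY
key = load_api_key('api_key.txt')   # returns 'YOUR_OPENAI_KEY'
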
utils.py
ADDED
@@ -0,0 +1,215 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import imageio
import io

def random_rotation_matrix():
    # Random quaternion
    q = torch.randn(4)
    q = q / torch.norm(q)

    # Quaternion to rotation matrix
    R = torch.tensor([
        [1 - 2*q[2]**2 - 2*q[3]**2, 2*q[1]*q[2] - 2*q[3]*q[0], 2*q[1]*q[3] + 2*q[2]*q[0]],
        [2*q[1]*q[2] + 2*q[3]*q[0], 1 - 2*q[1]**2 - 2*q[3]**2, 2*q[2]*q[3] - 2*q[1]*q[0]],
        [2*q[1]*q[3] - 2*q[2]*q[0], 2*q[2]*q[3] + 2*q[1]*q[0], 1 - 2*q[1]**2 - 2*q[2]**2]
    ])
    return R

def augment_data(data):
    B, T, M = data.shape
    augmented_data = torch.zeros_like(data)

    for i in range(B):
        for c in range(0, M, 6):
            R = random_rotation_matrix().cuda()
            acc = data[i, :, c:c+3].transpose(0, 1)     # Shape (3, T)
            gyro = data[i, :, c+3:c+6].transpose(0, 1)  # Shape (3, T)

            # Apply rotation
            rotated_acc = torch.matmul(R, acc)
            rotated_gyro = torch.matmul(R, gyro)

            # Concatenate and assign to augmented_data
            augmented_data[i, :, c:c+3] = rotated_acc.transpose(0, 1)
            augmented_data[i, :, c+3:c+6] = rotated_gyro.transpose(0, 1)

    return augmented_data

|
41 |
+
# Get global min and max for each axis
|
42 |
+
min_x, max_x = np.min(data[:, :, 0]), np.max(data[:, :, 0])
|
43 |
+
min_y, max_y = np.min(data[:, :, 2]), np.max(data[:, :, 2])
|
44 |
+
min_z, max_z = np.min(data[:, :, 1]), np.max(data[:, :, 1])
|
45 |
+
|
46 |
+
# Add some padding to ensure the skeleton doesn't touch the plot edges
|
47 |
+
padding = 0.1
|
48 |
+
x_range = max_x - min_x
|
49 |
+
y_range = max_y - min_y
|
50 |
+
z_range = max_z - min_z
|
51 |
+
|
52 |
+
return (min_x - padding * x_range, max_x + padding * x_range), \
|
53 |
+
(min_y - padding * y_range, max_y + padding * y_range), \
|
54 |
+
(min_z - padding * z_range, max_z + padding * z_range)
|
55 |
+
|
56 |
+
def plot_skeleton(frame_data, xlims, ylims, zlims, dataset):
|
57 |
+
"""
|
58 |
+
Plot a single frame of skeleton data.
|
59 |
+
"""
|
60 |
+
fig = plt.figure()
|
61 |
+
ax = fig.add_subplot(111, projection='3d')
|
62 |
+
ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])
|
63 |
+
|
64 |
+
# Add code here to connect the joints as per your skeleton structure
|
65 |
+
if dataset == 't2m':
|
66 |
+
connections = [
|
67 |
+
[0, 2, 5, 8, 11],
|
68 |
+
[0, 1, 4, 7, 10],
|
69 |
+
[0, 3, 6, 9, 12, 15],
|
70 |
+
[9, 14, 17, 19, 21],
|
71 |
+
[9, 13, 16, 18, 20]
|
72 |
+
]
|
73 |
+
|
74 |
+
if dataset == 'kit':
|
75 |
+
connections = [
|
76 |
+
[0, 11, 12, 13, 14, 15],
|
77 |
+
[0, 16, 17, 18, 19, 20],
|
78 |
+
[0, 1, 2, 3, 4],
|
79 |
+
[3, 5, 6, 7],
|
80 |
+
[3, 8, 9, 10]
|
81 |
+
]
|
82 |
+
|
83 |
+
if dataset == 'ntu':
|
84 |
+
connections = [
|
85 |
+
[0, 12, 13, 14, 15],
|
86 |
+
[0, 16, 17, 18, 19],
|
87 |
+
[0, 1, 20, 2, 3],
|
88 |
+
[20, 4, 5, 6, 7, 21],
|
89 |
+
[7, 22],
|
90 |
+
[20, 8, 9, 10, 11, 23],
|
91 |
+
[11, 24],
|
92 |
+
]
|
93 |
+
|
94 |
+
# Plot the lines for each sequence
|
95 |
+
for connection in connections:
|
96 |
+
for i in range(len(connection)-1):
|
97 |
+
start_joint = connection[i]
|
98 |
+
end_joint = connection[i+1]
|
99 |
+
ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
|
100 |
+
[frame_data[start_joint, 2], frame_data[end_joint, 2]],
|
101 |
+
[frame_data[start_joint, 1], frame_data[end_joint, 1]])
|
102 |
+
|
103 |
+
ax.view_init(elev=10, azim=90)
|
104 |
+
ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))
|
105 |
+
|
106 |
+
ax.set_xlim(xlims)
|
107 |
+
ax.set_ylim(ylims)
|
108 |
+
ax.set_zlim(zlims)
|
109 |
+
ax.set_xlabel('X')
|
110 |
+
ax.set_ylabel('Z')
|
111 |
+
ax.set_zlabel('Y')
|
112 |
+
|
113 |
+
# Save the plot to a buffer
|
114 |
+
buf = io.BytesIO()
|
115 |
+
plt.savefig(buf, format='png')
|
116 |
+
buf.seek(0)
|
117 |
+
img = imageio.imread(buf)
|
118 |
+
buf.close()
|
119 |
+
|
120 |
+
plt.close(fig) # Close the figure to prevent display
|
121 |
+
return img
|
122 |
+
|
123 |
+
def plot_skeleton_gif(data, dataset):
|
124 |
+
xlims, ylims, zlims = update_limits(data)
|
125 |
+
images = [plot_skeleton(frame, xlims, ylims, zlims, dataset) for frame in data]
|
126 |
+
imageio.mimsave('./skeleton_animation.gif', images, fps=20)
|
127 |
+
return
|
128 |
+
|
129 |
+
def plot_single_skeleton(data, dataset, frame=0):
|
130 |
+
|
131 |
+
xlims, ylims, zlims = update_limits(data)
|
132 |
+
frame_data = data[frame]
|
133 |
+
|
134 |
+
fig = plt.figure()
|
135 |
+
ax = fig.add_subplot(111, projection='3d')
|
136 |
+
ax.scatter(frame_data[:, 0], frame_data[:, 2], frame_data[:, 1])
|
137 |
+
|
138 |
+
# Add code here to connect the joints as per your skeleton structure
|
139 |
+
if dataset == 't2m':
|
140 |
+
connections = [
|
141 |
+
[0, 2, 5, 8, 11],
|
142 |
+
[0, 1, 4, 7, 10],
|
143 |
+
[0, 3, 6, 9, 12, 15],
|
144 |
+
[9, 14, 17, 19, 21],
|
145 |
+
[9, 13, 16, 18, 20]
|
146 |
+
]
|
147 |
+
|
148 |
+
if dataset == 'kit':
|
149 |
+
connections = [
|
150 |
+
[0, 11, 12, 13, 14, 15],
|
151 |
+
[0, 16, 17, 18, 19, 20],
|
152 |
+
[0, 1, 2, 3, 4],
|
153 |
+
[3, 5, 6, 7],
|
154 |
+
[3, 8, 9, 10]
|
155 |
+
]
|
156 |
+
|
157 |
+
if dataset == 'ntu':
|
158 |
+
connections = [
|
159 |
+
[0, 12, 13, 14, 15],
|
160 |
+
[0, 16, 17, 18, 19],
|
161 |
+
[0, 1, 20, 2, 3],
|
162 |
+
[20, 4, 5, 6, 7, 21],
|
163 |
+
[7, 22],
|
164 |
+
[20, 8, 9, 10, 11, 23],
|
165 |
+
[11, 24],
|
166 |
+
]
|
167 |
+
|
168 |
+
# Plot the lines for each sequence
|
169 |
+
for connection in connections:
|
170 |
+
for i in range(len(connection)-1):
|
171 |
+
start_joint = connection[i]
|
172 |
+
end_joint = connection[i+1]
|
173 |
+
ax.plot([frame_data[start_joint, 0], frame_data[end_joint, 0]],
|
174 |
+
[frame_data[start_joint, 2], frame_data[end_joint, 2]],
|
175 |
+
[frame_data[start_joint, 1], frame_data[end_joint, 1]])
|
176 |
+
|
177 |
+
#ax.view_init(elev=10, azim=90)
|
178 |
+
ax.set_box_aspect((np.ptp(xlims), np.ptp(ylims), np.ptp(zlims)))
|
179 |
+
|
180 |
+
ax.set_xlim(xlims)
|
181 |
+
ax.set_ylim(ylims)
|
182 |
+
ax.set_zlim(zlims)
|
183 |
+
|
184 |
+
ax.set_xlabel('X')
|
185 |
+
ax.set_ylabel('Z')
|
186 |
+
ax.set_zlabel('Y')
|
187 |
+
|
188 |
+
plt.savefig('skeleton.pdf', bbox_inches='tight')
|
189 |
+
|
190 |
+
def compute_height(joints, head_index, l_foot_index, r_foot_index):
|
191 |
+
joints = torch.from_numpy(joints)
|
192 |
+
left = (joints[:,head_index,1] - joints[:,l_foot_index,1])[0]
|
193 |
+
right = (joints[:,head_index,1] - joints[:,r_foot_index,1])[0]
|
194 |
+
height = (left + right) / 2
|
195 |
+
return height
|
196 |
+
|
197 |
+
def compute_metrics_np(similarity_matrix, correct_labels):
|
198 |
+
|
199 |
+
B, _ = similarity_matrix.shape
|
200 |
+
|
201 |
+
ranked_indices = np.argsort(-similarity_matrix, axis=1)
|
202 |
+
|
203 |
+
correct_label_ranks = np.array([np.where(ranked_indices[i] == correct_labels[i])[0][0] for i in range(B)]) + 1
|
204 |
+
|
205 |
+
# Compute R@K
|
206 |
+
R_at_1 = np.mean(correct_label_ranks <= 1)
|
207 |
+
R_at_2 = np.mean(correct_label_ranks <= 2)
|
208 |
+
R_at_3 = np.mean(correct_label_ranks <= 3)
|
209 |
+
R_at_4 = np.mean(correct_label_ranks <= 4)
|
210 |
+
R_at_5 = np.mean(correct_label_ranks <= 5)
|
211 |
+
|
212 |
+
# Compute MRR
|
213 |
+
MRR = np.mean(1.0 / correct_label_ranks)
|
214 |
+
|
215 |
+
return R_at_1, R_at_2, R_at_3, R_at_4, R_at_5, MRR
|
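A worked example (hypothetical numbers) of compute_metrics_np on a 2x2 similarity matrix:

sim = np.array([[0.1, 0.9],    # sample 0: true class 0 ranked 2nd
                [0.2, 0.8]])   # sample 1: true class 1 ranked 1st
labels = np.array([0, 1])
r1, r2, r3, r4, r5, mrr = compute_metrics_np(sim, labels)
print(r1, r2, mrr)             # 0.5 1.0 0.75, since MRR = (1/2 + 1/1) / 2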