KKYYKK committed on
Commit
e744d68
·
verified ·
1 Parent(s): d662031

Upload config_recog_bern_bypass_frame_linear.py with huggingface_hub

Browse files
config_recog_bern_bypass_frame_linear.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ import os
4
+ import logging
5
+ import pickle
6
+
7
def read_pkl_data(pkl_path, img_path):
    """Load per-video frame paths and phase/step labels from a pickle file.

    The pickle is expected to contain a mapping ``{video_name: [item, ...]}``
    in which every item provides the keys ``'Frame_id'``, ``'Phase_gt'`` and
    ``'Step_gt'``.

    Args:
        pkl_path: Path of the pickle annotation file.
        img_path: Root directory of the frames; each video is looked up in
            ``<img_path>/<video_name>/<Frame_id>.jpg``.

    Returns:
        A tuple ``(imgs, phases, steps)`` of three parallel lists with one
        entry per video, ordered by sorted video name. Each entry is itself
        a list with one element per annotated frame: the frame's image path,
        its ``'Phase_gt'`` value and its ``'Step_gt'`` value respectively.
    """
    # Lazy %-formatting: also works when pkl_path is not a plain str.
    logging.info('reading pickle file: %s', pkl_path)

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing -- only point this at annotation files you trust.
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(pkl_path, "rb") as fp:
        data = pickle.load(fp)

    root_dir = img_path
    if not os.path.exists(root_dir):
        # Fallback kept from the original: when the directory does not
        # exist, strip any split name embedded in the path and use that.
        root_dir = root_dir.replace('train', '').replace('val', '').replace('test', '')

    imgs, phases, steps = [], [], []
    for vid_name in sorted(data):
        items = data[vid_name]  # single lookup per video
        imgs.append(
            [os.path.join(root_dir, vid_name, f"{item['Frame_id']}.jpg") for item in items]
        )
        phases.append([item['Phase_gt'] for item in items])
        steps.append([item['Step_gt'] for item in items])

    return imgs, phases, steps
27
+
28
+
29
## Read the train / val / test pickle files
# Every split has its own label pickle; all three splits read their frames
# from the same directory. Each call returns (videos, phase labels, step
# labels) as parallel per-video lists.

#### TRAIN ####
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern',
    'labels_by70_splits/labels',
    'train',
    '1fps_100_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_train, phase_labels_train, step_labels_train = read_pkl_data(labels, images)

#### VAL ####
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern',
    'labels_by70_splits/labels',
    'val',
    '1fps_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_val, phase_labels_val, step_labels_val = read_pkl_data(labels, images)

#### TEST ####
labels = os.path.join(
    '/gpfswork/rech/okw/ukw13bv/MultiBypass140/labels',
    'bern',
    'labels_by70_splits/labels',
    'test',
    '1fps_0.pickle',
)
images = '/gpfsscratch/rech/okw/ukw13bv/bypass/BernBypass70/frames'
videos_test, phase_labels_test, step_labels_test = read_pkl_data(labels, images)
48
+
49
# Base config this file builds on.
_base_ = ['../base.py']

# One dataset entry is created per video for each split. All three splits use
# the same deterministic preprocessing: resize to 360x640, center-crop to
# 224, convert to a tensor, then normalize per channel. A fresh Compose is
# built for every entry.
config = dict(
    train_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_ids,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_ids in zip(videos_train, phase_labels_train)
    ],
    val_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_ids,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_ids in zip(videos_val, phase_labels_val)
    ],
    test_config=[
        dict(
            type='Recognition_frame_bypass',
            img_list=frame_paths,
            label_list=phase_ids,
            transforms=transforms.Compose([
                transforms.Resize((360, 640)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]),
        )
        for frame_paths, phase_ids in zip(videos_test, phase_labels_test)
    ],
    # Model: an image backbone plus a text backbone, both with 768-d outputs.
    model_config=dict(
        type='MVNet_feature_extractor',
        backbone_img=dict(
            type='img_backbones/ImageEncoder_feature_extractor',
            # alternative: type='img_backbones/ImageEncoder_CLIPVISUAL'
            num_classes=768,
            pretrained='imagenet',  # imagenet/ssl/random
            backbone_name='resnet_50',
            # alternative: backbone_name='resnet_50_clip'
            img_norm=False,
        ),
        backbone_text=dict(
            type='text_backbones/BertEncoder',
            text_bert_type='/gpfswork/rech/okw/ukw13bv/mmsl/biobert_pretrain_output_all_notes_150000',
            text_last_n_layers=4,
            text_aggregate_method='sum',
            text_norm=False,
            text_embedding_dim=768,
            text_freeze_bert=False,
            text_agg_tokens=True,
        ),
    ),
)
119
+