xinyiW915 committed
Commit 3879b24 · verified · 1 Parent(s): f88996a

Upload 5 files

extractor/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # __init__.py
+ print("Initializing extractor")
extractor/vf_extract.py ADDED
@@ -0,0 +1,72 @@
+ import math
+ import pandas as pd
+ import cv2
+ import os
+
+ def extract_frames(video_path, sampled_path, frame_interval, residual=False):
+     try:
+         video_name = os.path.splitext(os.path.basename(video_path))[0]
+         cap = cv2.VideoCapture(video_path)
+         frames = []
+
+         if not cap.isOpened():
+             print(f"Error: Could not open video file {video_path}")
+             return frames
+
+         frame_count = 0
+         saved_frame_count = 0
+         while True:
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             if (frame_count % frame_interval == 0 and not residual) or (
+                     (frame_count - 1) % frame_interval == 0 and residual):
+                 # suffix = '_next' if residual else ''
+                 # output_filename = os.path.join(sampled_path, f'{video_name}_{saved_frame_count + 1}{suffix}.png')
+                 # cv2.imwrite(output_filename, frame)
+                 frames.append(frame)
+                 saved_frame_count += 1
+
+             frame_count += 1
+         cap.release()
+         frame_type = 'next frames' if residual else 'sampled frames'
+         print(f'Extraction of {frame_type} for {video_name} completed!')
+         return frames
+     except Exception as e:
+         print(f"An unexpected error occurred: {e}")
+
+
+ def process_video_residual(video_type, video_name, framerate, video_path, sampled_path):
+     if not os.path.exists(sampled_path):
+         os.makedirs(sampled_path)
+     # cap = cv2.VideoCapture(video_path)
+     # framerate = cap.get(cv2.CAP_PROP_FPS)
+     # print(f'framerate: {framerate}')
+     frame_interval = math.ceil(framerate / 2) if framerate < 2 else int(framerate / 2)
+     # print(f'Frame interval: {frame_interval}')
+
+     frames = extract_frames(video_path, sampled_path, frame_interval, residual=False)
+     frames_next = extract_frames(video_path, sampled_path, frame_interval, residual=True)
+     return frames, frames_next
+
+
+ if __name__ == '__main__':
+     video_type = 'test'
+
+     if video_type == 'test':
+         ugcdata = pd.read_csv("../../metadata/test_videos.csv")
+
+     for i in range(len(ugcdata)):
+         video_name = ugcdata['vid'][i]
+         framerate = ugcdata['framerate'][i]
+         print(f'Processing video: {video_name}, framerate: {framerate}')
+
+         video_path = f"../../ugc_original_videos/{video_name}.mp4"
+         sampled_path = f'../../video_sampled_frame/original_sampled_frame/{video_name}/'
+
+         if not os.path.exists(sampled_path):
+             os.makedirs(sampled_path)
+
+         print(f'{video_name}')
+         frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_path)
+         print(f'Extracted {len(frames)} frames and {len(frames_next)} residual frames for video: {video_name}')
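
A minimal usage sketch for the sampler above (not part of the commit), assuming OpenCV is installed, the repo root is on PYTHONPATH, and placeholder paths; process_video_residual samples roughly two frames per second and also returns the frames offset by one position for residual computation:

# Hedged sketch: sample a local clip; 'my_video.mp4' and './sampled/' are placeholder paths.
import cv2
from extractor.vf_extract import process_video_residual

video_path = 'my_video.mp4'
cap = cv2.VideoCapture(video_path)
framerate = cap.get(cv2.CAP_PROP_FPS)  # the script reads this from the metadata CSV instead
cap.release()

frames, frames_next = process_video_residual('test', 'my_video', framerate, video_path, './sampled/')
print(f'{len(frames)} sampled frames, {len(frames_next)} next frames')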
extractor/visualise_resnet.py ADDED
@@ -0,0 +1,187 @@
+ import warnings
+ warnings.filterwarnings("ignore")
+ import os
+ import glob
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import torch
+ from torchvision import models, transforms
+ from thop import profile
+ is_flop_cal = False
+
+ # get the activation
+ def get_activation(model, layer, input_img_data):
+     model.eval()
+     activations = []
+     inputs = []
+
+     def hook(module, input, output):
+         activations.append(output)
+         inputs.append(input[0])
+
+     hook_handle = layer.register_forward_hook(hook)
+     with torch.no_grad():
+         model(input_img_data)
+     hook_handle.remove()
+     return activations, inputs
+
+ def get_activation_map(frame, layer_name, resnet50, device):
+     # image pre-processing
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     # Apply the transformations (resize and normalize)
+     frame_tensor = transform(frame)
+
+     # adding index 0 changes the original [C, H, W] shape to [1, C, H, W]
+     if frame_tensor.dim() == 3:
+         frame_tensor = frame_tensor.unsqueeze(0)
+     # print(f'Image dimension: {frame_tensor.shape}')
+
+     # getting the activation of a given layer
+     layer_obj = eval(layer_name)
+     activations, inputs = get_activation(resnet50, layer_obj, frame_tensor)
+     activated_img = activations[0][0]
+     activation_array = activated_img.cpu().numpy()
+
+     # calculate FLOPs for layer
+     if is_flop_cal == True:
+         flops, params = profile(layer_obj, inputs=(inputs[0],), verbose=False)
+         if params == 0 and isinstance(layer_obj, torch.nn.Conv2d):
+             params = layer_obj.in_channels * layer_obj.out_channels * layer_obj.kernel_size[0] * layer_obj.kernel_size[1]
+             if layer_obj.bias is not None:
+                 params += layer_obj.out_channels
+         # print(f"FLOPs for {layer_name}: {flops}, Params: {params}")
+     else:
+         flops, params = None, None
+     return activated_img, activation_array, flops, params
+
+ def process_video_frame(video_name, frame, frame_number, all_layers, resnet50, device):
+     # create a dictionary to store activation arrays for each layer
+     activations_dict = {}
+     total_flops = 0
+     total_params = 0
+     for layer_name in all_layers:
+         fig_name = f"resnet50_feature_map_layer_{layer_name}"
+         combined_name = f"resnet50_feature_map"
+
+         activated_img, activation_array, flops, params = get_activation_map(frame, layer_name, resnet50, device)
+         if is_flop_cal == True:
+             total_flops += flops
+             total_params += params
+
+         # save activation maps as png
+         # png_path = f'../visualisation/resnet50/{video_name}/frame_{frame_number}/'
+         # npy_path = f'../features/resnet50/{video_name}/frame_{frame_number}/'
+         # os.makedirs(png_path, exist_ok=True)
+         # os.makedirs(npy_path, exist_ok=True)
+         # get_activation_png(png_path, fig_name, activated_img)
+         # save activation features as npy
+         # get_activation_npy(npy_path, fig_name, activation_array)
+         # save to the dictionary
+         activations_dict[layer_name] = activated_img
+
+     # print(f"total FLOPs for Resnet50 layerstack: {total_flops}, Params: {total_params}")
+     frame_npy_path = f'../features/resnet50/{video_name}/frame_{frame_number}_{combined_name}.npy'
+     return activations_dict, frame_npy_path, total_flops, total_params
+
+ def get_activation_png(png_path, fig_name, activated_img, n=8):
+     fig = plt.figure(figsize=(10, 10))
+
+     # visualise activation map for 64 channels
+     for i in range(n):
+         for j in range(n):
+             idx = (n * i) + j
+             if idx >= activated_img.shape[0]:
+                 break
+             ax = fig.add_subplot(n, n, idx + 1)
+             ax.imshow(activated_img[idx].cpu().numpy(), cmap='viridis')
+             ax.axis('off')
+
+     # save figures
+     fig_path = f'{png_path}{fig_name}.png'
+     print(fig_path)
+     print("----------------" + '\n')
+     plt.savefig(fig_path)
+     plt.close()
+
+ def get_activation_npy(npy_path, fig_name, activation_array):
+     np.save(f'{npy_path}{fig_name}.npy', activation_array)
+
+ if __name__ == '__main__':
+     device_name = "gpu"
+     if device_name == "gpu":
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     else:
+         device = torch.device("cpu")
+     print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
+     # pre-trained ResNet-50 model to device
+     resnet50 = models.resnet50(pretrained=True).to(device)
+
+     all_layers = ['resnet50.conv1',
+                   'resnet50.layer1[0]', 'resnet50.layer1[1]', 'resnet50.layer1[2]',
+                   'resnet50.layer2[0]', 'resnet50.layer2[1]', 'resnet50.layer2[2]', 'resnet50.layer2[3]',
+                   'resnet50.layer3[0]', 'resnet50.layer3[1]', 'resnet50.layer3[2]', 'resnet50.layer3[3]',
+                   'resnet50.layer4[0]', 'resnet50.layer4[1]', 'resnet50.layer4[2]']
+
+     video_type = 'test'
+     # Test
+     if video_type == 'test':
+         metadata_path = "../../metadata/test_videos.csv"
+     # NR:
+     elif video_type == 'resolution_ugc':
+         resolution = '360P'
+         metadata_path = f"../../metadata/YOUTUBE_UGC_{resolution}_metadata.csv"
+     else:
+         metadata_path = f'../../metadata/{video_type.upper()}_metadata.csv'
+
+     ugcdata = pd.read_csv(metadata_path)
+     for i in range(len(ugcdata)):
+         video_name = ugcdata['vid'][i]
+         sampled_frame_path = os.path.join('../..', 'video_sampled_frame', 'sampled_frame', f'{video_name}')
+
+         print(f"Processing video: {video_name}")
+         image_paths = glob.glob(os.path.join(sampled_frame_path, f'{video_name}_*.png'))
+         frame_number = 0
+         for image in image_paths:
+             print(f"{image}")
+             frame_number += 1
+             process_video_frame(video_name, image, frame_number, all_layers, resnet50, device)
+
+ # # ResNet-50 layers to visualize
+ # layers_to_visualize_resnet50 = {
+ #     'conv1': 0,
+ #     'layer1.0.conv1': 2,
+ #     'layer1.0.conv2': 3,
+ #     'layer1.1.conv1': 5,
+ #     'layer1.1.conv2': 6,
+ #     'layer1.2.conv1': 8,
+ #     'layer1.2.conv2': 9,
+ #     'layer2.0.conv1': 11,
+ #     'layer2.0.conv2': 12,
+ #     'layer2.1.conv1': 14,
+ #     'layer2.1.conv2': 15,
+ #     'layer2.2.conv1': 17,
+ #     'layer2.2.conv2': 18,
+ #     'layer2.3.conv1': 20,
+ #     'layer2.3.conv2': 21,
+ #     'layer3.0.conv1': 23,
+ #     'layer3.0.conv2': 24,
+ #     'layer3.0.downsample.0': 25,
+ #     'layer3.1.conv1': 27,
+ #     'layer3.1.conv2': 28,
+ #     'layer3.2.conv1': 30,
+ #     'layer3.2.conv2': 31,
+ #     'layer3.3.conv1': 33,
+ #     'layer3.3.conv2': 34,
+ #     'layer4.0.conv1': 36,
+ #     'layer4.0.conv2': 37,
+ #     'layer4.0.downsample.0': 38,
+ #     'layer4.1.conv1': 40,
+ #     'layer4.1.conv2': 41,
+ #     'layer4.2.conv1': 43,
+ #     'layer4.2.conv2': 44,
+ # }
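
A minimal sketch of driving get_activation_map directly (not part of the commit), assuming a sampled frame saved at a placeholder PNG path; the frame is converted to a tensor before the transform, since transforms.Normalize operates on tensors, and the layer string is evaluated against the resnet50 argument inside the forward-hook helper:

# Hedged sketch: one layer's activation for one frame; 'frame.png' is a placeholder path.
import torch
from PIL import Image
from torchvision import models, transforms
from extractor.visualise_resnet import get_activation_map

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = models.resnet50(pretrained=True).to(device)

frame = transforms.ToTensor()(Image.open('frame.png').convert('RGB')).to(device)
activated_img, activation_array, flops, params = get_activation_map(frame, 'resnet50.conv1', resnet50, device)
print(activation_array.shape)  # (64, 112, 112) for conv1 on a 224x224 input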
extractor/visualise_resnet_layer.py ADDED
@@ -0,0 +1,194 @@
+ import warnings
+ warnings.filterwarnings("ignore")
+ import os
+ import glob
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import torch
+ from torchvision import models, transforms
+ from thop import profile
+ is_flop_cal = False
+
+ # get the activation
+ def get_activation(model, layer, input_img_data):
+     model.eval()
+     activations = []
+     inputs = []
+
+     def hook(module, input, output):
+         activations.append(output)
+         inputs.append(input[0])
+
+     hook_handle = layer.register_forward_hook(hook)
+     with torch.no_grad():
+         model(input_img_data)
+     hook_handle.remove()
+     return activations, inputs
+
+ def get_activation_map(frame, layer_name, resnet50, device):
+     # image pre-processing
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+
+     # Apply the transformations (resize and normalize)
+     frame_tensor = transform(frame)
+
+     # adding index 0 changes the original [C, H, W] shape to [1, C, H, W]
+     if frame_tensor.dim() == 3:
+         frame_tensor = frame_tensor.unsqueeze(0)
+     # print(f'Image dimension: {frame_tensor.shape}')
+
+     # getting the activation of a given layer
+     conv_idx = layer_name
+     layer_obj = eval(conv_idx)
+     activations, inputs = get_activation(resnet50, layer_obj, frame_tensor)
+     activated_img = activations[0][0]
+     activation_array = activated_img.cpu().numpy()
+
+     # calculate FLOPs for layer
+     if is_flop_cal == True:
+         flops, params = profile(layer_obj, inputs=(inputs[0],), verbose=False)
+         if params == 0 and isinstance(layer_obj, torch.nn.Conv2d):
+             params = layer_obj.in_channels * layer_obj.out_channels * layer_obj.kernel_size[0] * layer_obj.kernel_size[1]
+             if layer_obj.bias is not None:
+                 params += layer_obj.out_channels
+         # print(f"FLOPs for {layer_name}: {flops}, Params: {params}")
+     else:
+         flops, params = None, None
+     return activated_img, activation_array, flops, params
+
+ def process_video_frame(video_name, frame, frame_number, layer_name, resnet50, device):
+     # create a dictionary to store activation arrays for each layer
+     activations_dict = {}
+     total_flops = 0
+     total_params = 0
+     fig_name = f"resnet50_feature_map_layer_{layer_name}"
+     combined_name = f"resnet50_feature_map"
+
+     activated_img, activation_array, flops, params = get_activation_map(frame, layer_name, resnet50, device)
+     if is_flop_cal == True:
+         total_flops += flops
+         total_params += params
+
+     # save activation maps as png
+     # png_path = f'../visualisation/resnet50/{video_name}/frame_{frame_number}/'
+     # npy_path = f'../features/resnet50/{video_name}/frame_{frame_number}/'
+     # os.makedirs(png_path, exist_ok=True)
+     # os.makedirs(npy_path, exist_ok=True)
+     # get_activation_png(png_path, fig_name, activated_img)
+     # save activation features as npy
+     # get_activation_npy(npy_path, fig_name, activation_array)
+
+     # print(f"total FLOPs for Resnet50 layerstack: {total_flops}, Params: {total_params}")
+     frame_npy_path = f'../features/resnet50/{video_name}/frame_{frame_number}_{combined_name}.npy'
+     return activated_img, frame_npy_path, total_flops, total_params
+
+ def get_activation_png(png_path, fig_name, activated_img, n=8):
+     fig = plt.figure(figsize=(10, 10))
+
+     # visualise activation map for 64 channels
+     for i in range(n):
+         for j in range(n):
+             idx = (n * i) + j
+             if idx >= activated_img.shape[0]:
+                 break
+             ax = fig.add_subplot(n, n, idx + 1)
+             ax.imshow(activated_img[idx].cpu().numpy(), cmap='viridis')
+             ax.axis('off')
+
+     # save figures
+     fig_path = f'{png_path}{fig_name}.png'
+     print(fig_path)
+     print("----------------" + '\n')
+     plt.savefig(fig_path)
+     plt.close()
+
+ def get_activation_npy(npy_path, fig_name, activation_array):
+     np.save(f'{npy_path}{fig_name}.npy', activation_array)
+
+ if __name__ == '__main__':
+     device_name = "gpu"
+     if device_name == "gpu":
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     else:
+         device = torch.device("cpu")
+     print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
+     # pre-trained ResNet-50 model to device
+     resnet50 = models.resnet50(pretrained=True).to(device)
+
+     for idx, (name, layer) in enumerate(resnet50.named_children()):
+         print(f"Index: {idx}, Layer Name: {name}, Layer Type: {type(layer)}")
+     layer_name = 'layer4.2.conv2'
+
+     video_type = 'test'
+     # Test
+     if video_type == 'test':
+         metadata_path = "../../metadata/test_videos.csv"
+     # NR:
+     elif video_type == 'resolution_ugc':
+         resolution = '360P'
+         metadata_path = f"../../metadata/YOUTUBE_UGC_{resolution}_metadata.csv"
+     else:
+         metadata_path = f'../../metadata/{video_type.upper()}_metadata.csv'
+
+     ugcdata = pd.read_csv(metadata_path)
+     for i in range(len(ugcdata)):
+         video_name = ugcdata['vid'][i]
+         sampled_frame_path = os.path.join('../..', 'video_sampled_frame', 'sampled_frame', f'{video_name}')
+
+         print(f"Processing video: {video_name}")
+         image_paths = glob.glob(os.path.join(sampled_frame_path, f'{video_name}_*.png'))
+         frame_number = 0
+         for image in image_paths:
+             print(f"{image}")
+             frame_number += 1
+             process_video_frame(video_name, image, frame_number, layer_name, resnet50, device)
+
+ # # ResNet-50 layers to visualize
+ # layers_to_visualize_resnet50 = {
+ #     'conv1': 0,
+ #     'layer1.0.conv1': 2,
+ #     'layer1.0.conv2': 3,
+ #     'layer1.1.conv1': 5,
+ #     'layer1.1.conv2': 6,
+ #     'layer1.2.conv1': 8,
+ #     'layer1.2.conv2': 9,
+ #     'layer2.0.conv1': 11,
+ #     'layer2.0.conv2': 12,
+ #     'layer2.1.conv1': 14,
+ #     'layer2.1.conv2': 15,
+ #     'layer2.2.conv1': 17,
+ #     'layer2.2.conv2': 18,
+ #     'layer2.3.conv1': 20,
+ #     'layer2.3.conv2': 21,
+ #     'layer3.0.conv1': 23,
+ #     'layer3.0.conv2': 24,
+ #     'layer3.0.downsample.0': 25,
+ #     'layer3.1.conv1': 27,
+ #     'layer3.1.conv2': 28,
+ #     'layer3.2.conv1': 30,
+ #     'layer3.2.conv2': 31,
+ #     'layer3.3.conv1': 33,
+ #     'layer3.3.conv2': 34,
+ #     'layer4.0.conv1': 36,
+ #     'layer4.0.conv2': 37,
+ #     'layer4.0.downsample.0': 38,
+ #     'layer4.1.conv1': 40,
+ #     'layer4.1.conv2': 41,
+ #     'layer4.2.conv1': 43,
+ #     'layer4.2.conv2': 44,
+ # }
+
+ # Index: 0, Layer Name: conv1, Layer Type: <class 'torch.nn.modules.conv.Conv2d'>
+ # Index: 1, Layer Name: bn1, Layer Type: <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
+ # Index: 2, Layer Name: relu, Layer Type: <class 'torch.nn.modules.activation.ReLU'>
+ # Index: 3, Layer Name: maxpool, Layer Type: <class 'torch.nn.modules.pooling.MaxPool2d'>
+ # Index: 4, Layer Name: layer1, Layer Type: <class 'torch.nn.modules.container.Sequential'>
+ # Index: 5, Layer Name: layer2, Layer Type: <class 'torch.nn.modules.container.Sequential'>
+ # Index: 6, Layer Name: layer3, Layer Type: <class 'torch.nn.modules.container.Sequential'>
+ # Index: 7, Layer Name: layer4, Layer Type: <class 'torch.nn.modules.container.Sequential'>
+ # Index: 8, Layer Name: avgpool, Layer Type: <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>
+ # Index: 9, Layer Name: fc, Layer Type: <class 'torch.nn.modules.linear.Linear'>
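
One caveat worth noting for this single-layer variant (an observation, not part of the commit): get_activation_map resolves layer_name with eval against the resnet50 argument, so an expression such as 'resnet50.layer4[2].conv2' evaluates directly, whereas the dotted module name 'layer4.2.conv2' set in __main__ would first need to be mapped to that form (or looked up via resnet50.get_submodule('layer4.2.conv2')). A hedged sketch under that assumption:

# Hedged sketch: address the last 3x3 convolution of layer4 by expression.
import torch
from torchvision import models
from extractor.visualise_resnet_layer import get_activation_map

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = models.resnet50(pretrained=True).to(device)

frame = torch.rand(3, 224, 224).to(device)  # stand-in for a real sampled frame tensor
activated_img, activation_array, _, _ = get_activation_map(frame, 'resnet50.layer4[2].conv2', resnet50, device)
print(activation_array.shape)  # (512, 7, 7) for a 224x224 input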
extractor/visualise_vit_layer.py ADDED
@@ -0,0 +1,516 @@
+ import os
+ import glob
+ import math
+ from functools import partial
+ import torch
+
+ import ipywidgets as widgets
+ import io
+ from PIL import Image
+ from torchvision import transforms
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from torch import nn
+ from thop import profile
+ is_flop_cal = False
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Step 2: Creating a Vision Transformer
+ # truncated normal initialisation for tensors
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     # type: (Tensor, float, float, float, float) -> Tensor
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+ # Performs truncated normal initialisation without gradients. These two functions are used during model initialisation to make sure the weights are initialised appropriately.
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     def norm_cdf(x):
+         # computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ # Randomly drops a fraction of the input elements to implement stochastic depth.
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     # work with diff dim tensors, not just 2D ConvNets
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+     random_tensor = keep_prob + \
+         torch.rand(shape, dtype=x.dtype, device=x.device)
+     random_tensor.floor_()  # binarize
+     output = x.div(keep_prob) * random_tensor
+     return output
+
+ # Applies the drop_path function on the main path of residual blocks.
+ class DropPath(nn.Module):
+     """
+     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+     """
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
+
+ # A multi-layer perceptron (MLP) class with two linear layers and an activation function, used to apply a non-linear mapping to the features inside a residual block.
+ class Mlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+ # Self-attention class that computes the attention weights and applies them inside a residual block.
+ class Attention(nn.Module):
+     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = qk_scale or head_dim ** -0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x):
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
+                                   self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv[0], qkv[1], qkv[2]
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x, attn
+
+ # A residual block class containing a self-attention module and an MLP module.
+ class Block(nn.Module):
+     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+         self.drop_path = DropPath(
+             drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                        act_layer=act_layer, drop=drop)
+
+     def forward(self, x, return_attention=False):
+         y, attn = self.attn(self.norm1(x))
+         if return_attention:
+             return attn
+         x = x + self.drop_path(y)
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+ # Image-to-patch embedding class that splits the input image into patches and maps them into the embedding space.
+ class PatchEmbed(nn.Module):
+     """
+     Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         num_patches = (img_size // patch_size) * (img_size // patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+         self.proj = nn.Conv2d(in_chans, embed_dim,
+                               kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+ # Main implementation of the Vision Transformer model, containing the residual blocks, the embedding layer, etc. (still need to study what each step of the code does in detail)
+ class VisionTransformer(nn.Module):
+     """
+     Vision Transformer
+     """
+     def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                  num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                  drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, num_patches + 1, embed_dim))
+         self.pos_drop = nn.Dropout(p=drop_rate)
+
+         # stochastic depth decay rule
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # classifier head
+         self.head = nn.Linear(
+             embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+         trunc_normal_(self.pos_embed, std=.02)
+         trunc_normal_(self.cls_token, std=.02)
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def interpolate_pos_encoding(self, x, w, h):
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         class_pos_embed = self.pos_embed[:, 0]
+         patch_pos_embed = self.pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_embed.patch_size
+         h0 = h // self.patch_embed.patch_size
+
+         # see discussion at https://github.com/facebookresearch/dino/issues/8
+         w0, h0 = w0 + 0.1, h0 + 0.1
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, int(math.sqrt(N)), int(
+                 math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+             mode='bicubic',
+         )
+         assert int(
+             w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+     def prepare_tokens(self, x):
+         B, nc, w, h = x.shape
+         x = self.patch_embed(x)  # patch linear embedding
+
+         # add the [CLS] token to the embed patch tokens
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         # add positional encoding to each token
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         return self.pos_drop(x)
+
+     def forward(self, x):
+         x = self.prepare_tokens(x)
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+         return x[:, 0], x[:, 1:]  # return CLS token and attention_features maps
+
+     def get_last_selfattention(self, x):
+         x = self.prepare_tokens(x)
+         for i, blk in enumerate(self.blocks):
+             if i < len(self.blocks) - 1:
+                 x = blk(x)
+             else:
+                 # return attention of the last block
+                 # print(f"return attention of the last block: {x.shape}")
+                 # print(blk(x, return_attention=True).shape)
+                 return blk(x, return_attention=True)
+
+     def get_intermediate_layers(self, x, n=1):
+         x = self.prepare_tokens(x)
+
+         output = []
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if len(self.blocks) - i <= n:
+                 output.append(self.norm(x))
+         return output
+
+ # Generator class for the Vision Transformer model, used to instantiate and configure a specific model.
+ class VitGenerator(object):
+     def __init__(self, name_model, patch_size, device, evaluate=True, random=False, verbose=False):
+         self.name_model = name_model
+         self.patch_size = patch_size
+         self.evaluate = evaluate
+         self.device = device
+         self.verbose = verbose
+         self.model = self._getModel()
+         self._initializeModel()
+         if not random:
+             self._loadPretrainedWeights()
+
+     def _getModel(self):
+         if self.verbose:
+             pass
+             # print((f"[INFO] Initializing {self.name_model} with patch size of {self.patch_size}"))
+         if self.name_model == 'vit_tiny':
+             model = VisionTransformer(patch_size=self.patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+                                       qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
+
+         elif self.name_model == 'vit_small':
+             model = VisionTransformer(patch_size=self.patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+                                       qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
+
+         elif self.name_model == 'vit_base':
+             model = VisionTransformer(patch_size=self.patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+                                       qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
+         else:
+             raise ValueError(f"No model found with {self.name_model}")
+
+         return model
+
+     def _initializeModel(self):
+         if self.evaluate:
+             for p in self.model.parameters():
+                 p.requires_grad = False
+
+             self.model.eval()
+
+         self.model.to(self.device)
+
+     def _loadPretrainedWeights(self):
+         if self.verbose:
+             pass
+             # print(("[INFO] Loading weights"))
+         url = None
+         if self.name_model == 'vit_small' and self.patch_size == 16:
+             url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
+
+         elif self.name_model == 'vit_small' and self.patch_size == 8:
+             url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
+
+         elif self.name_model == 'vit_base' and self.patch_size == 16:
+             url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
+
+         elif self.name_model == 'vit_base' and self.patch_size == 8:
+             url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
+
+         if url is None:
+             pass
+             # print((f"Since no pretrained weights have been found with name {self.name_model} and patch size {self.patch_size}, random weights will be used"))
+
+         else:
+             state_dict = torch.hub.load_state_dict_from_url(
+                 url="https://dl.fbaipublicfiles.com/dino/" + url)
+             self.model.load_state_dict(state_dict, strict=True)
+             # print(url)
+
+     def get_last_selfattention(self, img):
+         return self.model.get_last_selfattention(img.to(self.device))
+
+     def __call__(self, x):
+         return self.model(x)
+
+ # Step 3: Creating Visualization Functions
+ def transform(img, img_size):
+     img = transforms.Resize(img_size)(img)
+     img = transforms.ToTensor()(img)
+     return img
+
+ def visualize_predict(model, img_tensor, patch_size, device, video_name, frame_number, fig_name, combined_name):
+     if img_tensor.dim() == 3:
+         img_tensor = img_tensor.unsqueeze(0)
+     attention = visualize_attention(model, img_tensor, patch_size, device)
+     # save activation maps as png
+     # png_path = f'../visualisation/resnet50/{video_name}/frame_{frame_number}/'
+     # os.makedirs(png_path, exist_ok=True)
+     # get_activation_png(img, png_path, fig_name, attention)
+     # save activation features as npy
+     activations_dict, frame_npy_path = get_activation_npy(video_name, frame_number, fig_name, combined_name, attention)
+     return activations_dict, frame_npy_path
+
+ def visualize_attention(model, img_tensor, patch_size, device):
+     # img_tensor: format [1, C, H, W]
+     # Adjust the image dimensions to be divisible by the patch size
+     w, h = img_tensor.shape[2] - img_tensor.shape[2] % patch_size, img_tensor.shape[3] - img_tensor.shape[3] % patch_size
+     img_tensor = img_tensor[:, :, :w, :h]
+
+     w_featmap = img_tensor.shape[-2] // patch_size
+     h_featmap = img_tensor.shape[-1] // patch_size
+
+     attentions = model.get_last_selfattention(img_tensor.to(device))
+     nh = attentions.shape[1]  # number of heads
+
+     # keep only the output patch attention
+     attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+     attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()
+
+     return attentions
+
+ def get_activation_png(img, png_path, fig_name, attention):
+     n_heads = attention.shape[0]
+
+     # attention maps
+     for i in range(n_heads):
+         plt.imshow(attention[i], cmap='viridis')  # cmap='viridis', cmap='inferno'
+         plt.title(f"Head n: {i + 1}")
+         plt.axis('off')  # Turn off axis ticks and labels
+
+         # Save figures
+         fig_path = f'{png_path}{fig_name}_head_{i + 1}.png'
+         print(fig_path)
+         plt.savefig(fig_path)
+         plt.close()
+
+     # head mean map
+     plt.figure(figsize=(10, 10))
+     image_name = fig_name.replace('vit_feature_map_', '')
+     text = [f"{image_name}", "Head Mean"]
+     for i, fig in enumerate([img, np.mean(attention, 0)]):
+         plt.subplot(1, 2, i+1)
+         plt.imshow(fig, cmap='viridis')
+         plt.title(text[i])
+         plt.axis('off')  # Turn off axis ticks and labels
+     fig_path1 = f'{png_path}{fig_name}_head_mean.png'
+     print(fig_path1)
+     print("----------------" + '\n')
+     plt.savefig(fig_path1)
+     plt.close()
+
+     # combine
+     # plt.figure(figsize=(20, 20))
+     # for i in range(n_heads):
+     #     plt.subplot(n_heads//3, 3, i+1)
+     #     plt.imshow(attention[i], cmap='inferno')
+     #     plt.title(f"Head n: {i+1}")
+     # plt.tight_layout()
+     # fig_path2 = png_path + fig_name + '_heads.png'
+     # print(fig_path2 + '\n')
+     # plt.savefig(fig_path2)
+     # plt.close()
+
+ def get_activation_npy(video_name, frame_number, fig_name, combined_name, attention):
+     # save activation features as npy
+     # npy_path = f'../features/vit/{video_name}/frame_{frame_number}/'
+     # os.makedirs(npy_path, exist_ok=True)
+
+     mean_attention = attention.mean(axis=0)
+     frame_npy_path = f'../features/vit/{video_name}/frame_{frame_number}_{combined_name}.npy'
+
+     return mean_attention, frame_npy_path
+
+
+ class Loader(object):
+     def __init__(self):
+         self.uploader = widgets.FileUpload(accept='image/*', multiple=False)
+         self._start()
+
+     def _start(self):
+         display(self.uploader)
+
+     def getLastImage(self):
+         try:
+             for uploaded_filename in self.uploader.value:
+                 uploaded_filename = uploaded_filename
+             img = Image.open(io.BytesIO(
+                 bytes(self.uploader.value[uploaded_filename]['content'])))
+
+             return img
+         except:
+             return None
+
+     def saveImage(self, path):
+         with open(path, 'wb') as output_file:
+             for uploaded_filename in self.uploader.value:
+                 content = self.uploader.value[uploaded_filename]['content']
+                 output_file.write(content)
+
+ def process_video_frame(video_name, frame, frame_number, model, patch_size, device):
+     # resize image
+     if frame.dim() == 3:
+         frame = frame.unsqueeze(0)
+     if frame.shape[2:] != (224, 224):
+         frame_tensor = torch.nn.functional.interpolate(frame, size=(224, 224), mode='bicubic', align_corners=False)
+     else:
+         frame_tensor = frame
+
+     # Calculate FLOPs and Params
+     if is_flop_cal == True:
+         total_flops, total_params = profile(model.model, inputs=(frame_tensor,), verbose=False)
+         print(f"total FLOPs for ViT layerstack: {total_flops}, Params: {total_params}")
+     else:
+         total_flops, total_params = None, None
+
+     fig_name = f"vit_feature_map"
+     combined_name = f"vit_feature_map"
+
+     # activations_dict, frame_npy_path = visualize_predict(model, frame_tensor, patch_size, device, video_name, frame_number, fig_name, combined_name)
+     attention_features, frame_feature_npy_path = extract_features(model, frame_tensor, video_name, frame_number, combined_name)
+     return attention_features, frame_feature_npy_path, total_flops, total_params
+
+ def extract_features(model, img_tensor, video_name, frame_number, combined_name):
+     if img_tensor.dim() == 3:
+         img_tensor = img_tensor.unsqueeze(0)
+     cls_token, attention_features = model(img_tensor)
+
+     attention_features = attention_features.squeeze(0)
+     frame_feature_npy_path = f'../features/vit/{video_name}/frame_attention_{frame_number}_{combined_name}.npy'
+     return attention_features, frame_feature_npy_path
+
+ if __name__ == '__main__':
+     # Step 4: Visualizing Images
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     if device.type == "cuda":
+         torch.cuda.set_device(0)
+
+     name_model = 'vit_base'
+     patch_size = 16
+
+     model = VitGenerator(name_model, patch_size,
+                          device, evaluate=True, random=False, verbose=True)
+
+     video_type = 'test'
+     # Test
+     if video_type == 'test':
+         metadata_path = "../../metadata/test_videos.csv"
+     # NR:
+     elif video_type == 'resolution_ugc':
+         resolution = '360P'
+         metadata_path = f"../../metadata/YOUTUBE_UGC_{resolution}_metadata.csv"
+     else:
+         metadata_path = f'../../metadata/{video_type.upper()}_metadata.csv'
+
+     ugcdata = pd.read_csv(metadata_path)
+     for i in range(len(ugcdata)):
+         video_name = ugcdata['vid'][i]
+         sampled_frame_path = os.path.join('../..', 'video_sampled_frame', 'sampled_frame', f'{video_name}')
+
+         print(f"Processing video: {video_name}")
+         image_paths = glob.glob(os.path.join(sampled_frame_path, f'{video_name}_*.png'))
+         frame_number = 0
+         for image in image_paths:
+             print(f"{image}")
+             frame_number += 1
+             process_video_frame(video_name, image, frame_number, model, patch_size, device)
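
A minimal sketch of extracting DINO ViT-B/16 patch features for a single frame with the classes above (not part of the commit), assuming the frame is already loaded as a [C, H, W] tensor, since process_video_frame resizes tensors rather than opening image paths; 'frame.png' and 'my_video' are placeholders:

# Hedged sketch: ViT patch features for one sampled frame.
import torch
from PIL import Image
from torchvision import transforms
from extractor.visualise_vit_layer import VitGenerator, process_video_frame

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=False)

frame = transforms.ToTensor()(Image.open('frame.png').convert('RGB')).to(device)
features, npy_path, flops, params = process_video_frame('my_video', frame, 1, model, 16, device)
print(features.shape)  # 196 patch tokens x 768 dimensions for a 224x224 input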