File size: 5,045 Bytes
3beb455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import numpy as np
import json
from tqdm import tqdm
from argparse import ArgumentParser, Namespace

# Keys read from each per-frame content .npz, mapped to the key under which the
# stacked array is saved in the merged per-video .npz (insertion order is kept,
# so the output archive lists them in this order).
CONTENT_KEY_MAP = {
    'pred_latent': 'pred_latent_000',
    'input_latent': 'input_latent_000',
    'input_unet_000': 'input_unet_000_000',
    'input_unet_001': 'input_unet_000_001',
    'output_unet_000': 'output_unet_000_000',
    'output_unet_001': 'output_unet_000_001',
}


def select_frame_indices(num_available, num_frames):
    """Return exactly ``num_frames`` indices into a sorted list of ``num_available`` frames.

    * ``num_available < num_frames``: every frame is repeated as evenly as
      possible, in temporal order; the first ``num_frames % num_available``
      frames get one extra copy.
    * otherwise: frames are uniformly subsampled at
      ``floor(i * num_available / num_frames)``. Integer arithmetic keeps the
      picks exact (no float rounding just below an integer boundary).

    Returns an empty list when ``num_available <= 0`` or ``num_frames <= 0``.
    """
    if num_available <= 0 or num_frames <= 0:
        return []
    if num_available < num_frames:
        quotient, remainder = divmod(num_frames, num_available)
        indices = []
        for idx in range(num_available):
            indices.extend([idx] * (quotient + 1 if idx < remainder else quotient))
        return indices
    return [i * num_available // num_frames for i in range(num_frames)]


def list_video_frames(feat_root, candidates, video_name):
    """Return, sorted, the names in ``candidates`` that start with ``video_name``
    and are regular files inside ``feat_root``.

    NOTE(review): this is a plain prefix match, so video ``abc`` also picks up
    frames belonging to ``abcd``. Tighten it (e.g. ``video_name + '_'``) once the
    frame naming used by generate_frame.py is confirmed.
    """
    return sorted(
        name for name in candidates
        if name.startswith(video_name) and os.path.isfile(os.path.join(feat_root, name))
    )


def merge_content_features(frame_files, feat_root, save_path):
    """Stack the CONTENT_KEY_MAP arrays of each frame file (in the given order)
    along a new axis 0 and save them to ``save_path`` under the renamed keys.

    Raises KeyError if a frame archive lacks one of the expected keys, and
    ValueError (from np.stack) if ``frame_files`` is empty or shapes differ.
    """
    stacked = {in_key: [] for in_key in CONTENT_KEY_MAP.values()}
    for frame_file in frame_files:
        # The context manager closes the NpzFile handle that np.load opens.
        with np.load(os.path.join(feat_root, frame_file)) as data:
            for in_key, arrays in stacked.items():
                arrays.append(data[in_key])
    np.savez(save_path, **{out_key: np.stack(stacked[in_key], axis=0)
                           for out_key, in_key in CONTENT_KEY_MAP.items()})


def merge_degradation_features(frame_files, feat_root, save_path):
    """Stack every array found in the given degradation frame files along a new
    axis 0 and save them to ``save_path`` under their original keys.
    """
    stacked = {}
    for frame_file in frame_files:
        with np.load(os.path.join(feat_root, frame_file)) as data:
            for key in data.files:
                stacked.setdefault(key, []).append(data[key])
    np.savez(save_path, **{key: np.stack(arrays, axis=0) for key, arrays in stacked.items()})


def merge_video_features(video_name, content_files, degradation_files,
                         feat_content_root, feat_degradation_root,
                         content_save_root, degradation_save_root, num_frames):
    """Resample one video's frames to ``num_frames`` and write both merged .npz files.

    Content and degradation frames are paired by position in their sorted
    lists, and the SAME index list is applied to both so the two outputs stay
    frame-aligned. Expects two non-empty lists of equal length.

    Returns:
        (selected_content, selected_degradation): the frame file names used,
        in order; names repeat when a short video was upsampled.
    """
    indices = select_frame_indices(len(content_files), num_frames)
    selected_content = [content_files[i] for i in indices]
    selected_degradation = [degradation_files[i] for i in indices]

    merge_content_features(selected_content, feat_content_root,
                           os.path.join(content_save_root, video_name + '.npz'))
    merge_degradation_features(selected_degradation, feat_degradation_root,
                               os.path.join(degradation_save_root, video_name + '.npz'))
    return selected_content, selected_degradation


def remove_files(feat_root, names):
    """Delete each distinct ``name`` under ``feat_root``.

    Returns the paths that could not be removed (best effort, never raises OSError).
    """
    failed = []
    for name in sorted(set(names)):  # de-duplicate: upsampled lists repeat names
        path = os.path.join(feat_root, name)
        try:
            os.remove(path)
        except OSError:
            failed.append(path)
    return failed


def main():
    """Merge per-frame diffusion features into one content .npz and one
    degradation .npz per video found in --video_folder."""
    parser = ArgumentParser()
    parser.add_argument("--video_folder", default='demo_videos', type=str, help="Folder path of your videos")
    parser.add_argument("--diff_feat_folder", default='output_diffusion_features', type=str, help="Folder path of your extracted diffusion features")
    parser.add_argument("--merged_feat_folder", default='output_merged_diffusion_features', type=str, help="Folder path of output merged diffusion features")
    parser.add_argument("--num_frames", default=15, type=int,
                        help="Frames per video in the merged features (must match the number set in generate_frame.py)")
    parser.add_argument("--delete_source_feats", action="store_true",
                        help="Delete the merged per-frame feature files to save storage")
    args = parser.parse_args()
    if args.num_frames < 1:
        parser.error("--num_frames must be >= 1")

    feat_content_root = os.path.join(args.diff_feat_folder, 'feat_content')
    feat_degradation_root = os.path.join(args.diff_feat_folder, 'feat_degradation')

    # Listed once up front; per-video filtering re-checks isfile().
    all_frames_content = os.listdir(feat_content_root)
    all_frames_degradation = os.listdir(feat_degradation_root)

    content_save_root = os.path.join(args.merged_feat_folder, 'feat_content')
    os.makedirs(content_save_root, exist_ok=True)
    degradation_save_root = os.path.join(args.merged_feat_folder, 'feat_degradation')
    os.makedirs(degradation_save_root, exist_ok=True)

    for video_file in tqdm(os.listdir(args.video_folder)):
        # Same naming rule as the feature files: everything before the first '.'.
        video_name = video_file.split('.')[0]
        if not video_name:
            # Hidden files such as '.DS_Store' give an empty name, which would
            # prefix-match (and merge) every frame in the folder.
            continue

        content_files = list_video_frames(feat_content_root, all_frames_content, video_name)
        degradation_files = list_video_frames(feat_degradation_root, all_frames_degradation, video_name)
        if not content_files:
            tqdm.write('Skipping {}: no content features found.'.format(video_file))
            continue
        if len(content_files) != len(degradation_files):
            tqdm.write('Skipping {}: {} content vs {} degradation feature files.'.format(
                video_file, len(content_files), len(degradation_files)))
            continue

        selected_content, selected_degradation = merge_video_features(
            video_name, content_files, degradation_files,
            feat_content_root, feat_degradation_root,
            content_save_root, degradation_save_root, args.num_frames)

        if args.delete_source_feats:
            failed = (remove_files(feat_content_root, selected_content)
                      + remove_files(feat_degradation_root, selected_degradation))
            for path in failed:
                tqdm.write('{} cannot be deleted!'.format(path))


if __name__ == "__main__":
    main()