File size: 5,569 Bytes
d4b77ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# coding=utf-8
import os
import sys

sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))

import argparse
import os
import cv2
import glob
import copy
import numpy as np
import torch
from PIL import Image
import scipy.ndimage
import torchvision.transforms.functional as F
import torch.nn.functional as F2
from RAFT import utils
from RAFT import RAFT

import utils.region_fill as rf
from torchvision.transforms import ToTensor
import time


def to_tensor(img):
    """Convert an image array into a float32 torch tensor.

    The array is first wrapped as a PIL image, then passed through
    torchvision's ``to_tensor``. Per torchvision's documentation, that
    yields a (C, H, W) tensor, scaled to [0, 1] for uint8 input.
    """
    pil_image = Image.fromarray(img)
    return F.to_tensor(pil_image).float()


def gradient_mask(mask):
    """Produce the mask of pixels whose forward-difference gradient touches ``mask``.

    Pixel (i, j) is set when ``mask[i, j]``, ``mask[i + 1, j]`` or
    ``mask[i, j + 1]`` is set. That is, either forward difference (vertical
    or horizontal) taken at that pixel reads a masked value. Positions past
    the last row/column are treated as unmasked.

    Args:
        mask: 2-D array of shape (H, W); presumably boolean or 0/1.

    Returns:
        Boolean array of shape (H, W).
    """
    # ``np.bool`` was deprecated in NumPy 1.20 and removed in 1.24. The
    # builtin ``bool`` is the same dtype and works on every NumPy version.
    below = np.concatenate((mask[1:, :], np.zeros((1, mask.shape[1]), dtype=bool)), axis=0)
    right = np.concatenate((mask[:, 1:], np.zeros((mask.shape[0], 1), dtype=bool)), axis=1)
    return np.logical_or.reduce((mask, below, right))


def create_dir(dir):
    """Create directory ``dir``, including missing parents, if it does not exist.

    Safe to call repeatedly, and safe when another process creates the same
    directory concurrently. ``exist_ok=True`` removes the check-then-create
    race of a separate ``os.path.exists`` test.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    os.makedirs(dir, exist_ok=True)


def initialize_RAFT(args):
    """Build the RAFT network and load its checkpoint.

    The checkpoint at ``args.model`` is loaded through a ``DataParallel``
    wrapper, since its state-dict keys carry that wrapper's prefix. The
    wrapper is then discarded. The bare network is returned on CUDA in
    eval mode.
    """
    wrapped = torch.nn.DataParallel(RAFT(args))
    checkpoint = torch.load(args.model)
    wrapped.load_state_dict(checkpoint)

    net = wrapped.module
    net.to('cuda')
    net.eval()
    return net


def calculate_flow(args, model, vid, video, mode):
    """Compute RAFT optical flow between consecutive frames and write it to disk.

    One ``.flo`` file named ``'%05d.flo' % i`` is written for every adjacent
    frame pair (i, i + 1). Files go to ``<args.outroot>/<vid>/<mode>_flo``.

    Args:
        args: parsed command-line namespace; only ``args.outroot`` is read here.
        model: RAFT model, called as ``model(image1, image2, iters=20, test_mode=True)``.
        vid: video name, used as the output sub-directory.
        video: frame tensor indexed along dim 0. Presumably (nFrame, 3, H, W)
            on the model's device; confirm with the caller.
        mode: ``'forward'`` computes flow i -> i + 1; ``'backward'`` computes
            flow i + 1 -> i.

    Raises:
        NotImplementedError: if ``mode`` is neither ``'forward'`` nor ``'backward'``.
    """
    if mode not in ('forward', 'backward'):
        raise NotImplementedError('Unsupported flow mode: {}'.format(mode))

    out_dir = os.path.join(args.outroot, vid, mode + '_flo')
    create_dir(out_dir)

    with torch.no_grad():
        for i in range(video.shape[0] - 1):
            print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
            # `video[i, None]` keeps a leading batch dimension of size 1.
            if mode == 'forward':
                image1, image2 = video[i, None], video[i + 1, None]
            else:
                image1, image2 = video[i + 1, None], video[i, None]

            _, flow = model(image1, image2, iters=20, test_mode=True)
            # (1, 2, H, W) -> (H, W, 2), the layout written to the .flo file.
            flow = flow[0].permute(1, 2, 0).cpu().numpy()
            utils.frame_utils.writeFlow(os.path.join(out_dir, '%05d.flo' % i), flow)


def _list_skipped_videos(expdir):
    """Return the set of entry names in ``expdir``; videos with these names are skipped.

    Returns an empty set when ``expdir`` is None or cannot be listed.
    """
    # os.listdir(None) lists the *current* directory, so None must be handled
    # explicitly instead of being passed through.
    if expdir is None:
        return set()
    try:
        return set(os.listdir(expdir))
    except OSError:
        return set()


def _load_video(frame_dir, width, height):
    """Load the .png/.jpg frames of one video, in filename order, as a float tensor.

    Frames are converted to RGB. When both ``width`` and ``height`` are
    non-zero, they are resized to (width, height) with bilinear interpolation.

    Returns:
        A (nFrame, 3, H, W) float tensor with values in [0, 255], or None if
        ``frame_dir`` contains no .png/.jpg files (e.g. it is not a directory).
    """
    filename_list = sorted(glob.glob(os.path.join(frame_dir, '*.png')) +
                           glob.glob(os.path.join(frame_dir, '*.jpg')))
    if not filename_list:
        return None

    frames = []
    for filename in filename_list:
        print(filename)
        # RGB conversion guarantees 3 channels for grayscale, palette and RGBA inputs.
        with Image.open(filename) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if width != 0 and height != 0:
            # cv2.resize's third positional parameter is `dst`, so the
            # interpolation flag must be passed by keyword.
            img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        frames.append(torch.from_numpy(img.astype(np.uint8)).permute(2, 0, 1).float())
    return torch.stack(frames, dim=0)


def main(args):
    """Compute forward and backward RAFT flow for every video folder under ``args.path``.

    Each entry of ``args.path`` is treated as one video whose frames are
    .png/.jpg files. Videos are skipped when their name appears in
    ``args.expdir`` (if given) or when they contain no frames. Flow is written
    by ``calculate_flow`` under ``args.outroot/<video>``.
    """
    # Flow model.
    RAFT_model = initialize_RAFT(args)

    videos = sorted(os.listdir(args.path))
    videoLen = len(videos)
    exceptList = _list_skipped_videos(args.expdir)

    for v, vid in enumerate(videos, start=1):
        print('[{}]/[{}] Video {} is being processed'.format(v, videoLen, vid))
        if vid in exceptList:
            print('Video: {} skipped'.format(vid))
            continue

        video = _load_video(os.path.join(args.path, vid), args.width, args.height)
        if video is None:
            print('Video: {} has no .png/.jpg frames, skipped'.format(vid))
            continue
        nFrame = video.shape[0]
        print('images are loaded')
        video = video.to('cuda')

        # Computes the forward and backward flow.
        start = time.time()
        calculate_flow(args, RAFT_model, vid, video, 'forward')
        calculate_flow(args, RAFT_model, vid, video, 'backward')
        sumTime = time.time() - start
        print('{}/{}, video {} is finished. {} frames takes {}s, {}s/frame.'.format(
            v, videoLen, vid, nFrame, sumTime, sumTime / (2 * nFrame)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Compute forward/backward RAFT optical flow for every video folder under --path.')

    # flow basic setting
    parser.add_argument('--path', required=True, type=str,
                        help='directory containing one sub-directory of .png/.jpg frames per video')
    parser.add_argument('--expdir', type=str,
                        help='optional directory; videos whose name appears in it are skipped')
    parser.add_argument('--outroot', required=True, type=str,
                        help='output root; flow is written to <outroot>/<video>/{forward,backward}_flo')
    parser.add_argument('--width', type=int, default=432,
                        help='resize width (no resizing if width or height is 0)')
    parser.add_argument('--height', type=int, default=256,
                        help='resize height (no resizing if width or height is 0)')

    # RAFT
    # NOTE: the default checkpoint path is resolved relative to the current
    # working directory, not relative to this file.
    parser.add_argument('--model', default='../weight/raft-things.pth', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')

    args = parser.parse_args()

    main(args)