# File size: 3,648 Bytes
# 9d0a4ae
import pdb
import torch as th
import math
import numpy as np
import torch
from video_loader import VideoLoader
from torch.utils.data import DataLoader
import argparse
from preprocessing import Preprocessing
import torch.nn.functional as F
from tqdm import tqdm
import os
import sys
from feature_extractor import clip
import argparse

# ---------------------------------------------------------------------------
# Extraction settings (module-level; read by ``extractor``).
# ---------------------------------------------------------------------------
model_version: str = "ViT-B/32"   # CLIP backbone name given to clip.load and VideoLoader
output_feat_size: int = 512       # width of each per-frame feature row; must match model.encode_image output
clip_len: int = 2                 # sampling interval: VideoLoader is built with framerate=1/clip_len
overwrite: bool = True            # forwarded to VideoLoader
num_decoding_thread: int = 4      # DataLoader worker processes used for decoding
half_precision: bool = False      # when True, video features are cast to float16 before saving

def _encode_text(model, text):
    """Return token-level CLIP text features for the first query in ``text``.

    Token positions with id 0 are treated as padding and dropped (assumes
    ``clip.tokenize`` pads with id 0 -- TODO confirm for this clip fork).
    Expected to run under the caller's ``torch.no_grad()`` context.
    """
    tokens = clip.tokenize(text).to('cuda')
    hidden = model.encode_text(tokens)['last_hidden_state']
    n_valid = int((tokens[0] != 0).sum().item())
    return hidden[0, :n_valid].cpu().numpy()


def _encode_video(model, frames, batch_size=1):
    """Encode preprocessed frames with ``model.encode_image``.

    Returns a numpy array of shape ``(len(frames), output_feat_size)`` built
    from a float32 CUDA buffer. Frames are moved to the GPU ``batch_size`` at
    a time. Expected to run under the caller's ``torch.no_grad()`` context.
    """
    n_frames = len(frames)
    features = torch.zeros(n_frames, output_feat_size, device='cuda')
    for start in range(0, n_frames, batch_size):
        stop = start + batch_size
        features[start:stop] = model.encode_image(frames[start:stop].cuda())
    return features.cpu().numpy()


@torch.no_grad()
def extractor(vid_path, text, output_file):
    """Extract CLIP text features and per-frame video features to .npz files.

    Args:
        vid_path: Video source handed to ``VideoLoader`` (the CLI passes a
            single video file path).
        text: Query string or list of query strings for ``clip.tokenize``;
            only the first tokenized query is saved.
        output_file: Output *directory*. ``txt.npz`` and ``vid.npz`` are
            written inside it; it is created if missing.

    Side effects:
        Writes ``<output_file>/txt.npz`` (text features with padding removed)
        and ``<output_file>/vid.npz`` (one ``output_feat_size`` row per frame,
        cast to float16 when ``half_precision`` is set). ``vid.npz`` is left
        untouched if it already exists and ``overwrite`` is False. Requires a
        CUDA device.
    """
    txt_out = os.path.join(output_file, 'txt.npz')
    vid_out = os.path.join(output_file, 'vid.npz')

    # Ensure the output directory exists before the first save: txt.npz is
    # written before any video is decoded.
    if output_file and not os.path.isdir(output_file):
        print(f"Output directory {output_file} does not exists, creating...")
        os.makedirs(output_file, exist_ok=True)

    dataset = VideoLoader(
        vid_path,
        framerate=1 / clip_len,
        size=224,
        centercrop=True,
        overwrite=overwrite,
        model_version=model_version,
    )
    # Default sequential sampling. The previous code passed an undefined
    # ``sampler`` for datasets with more than 10 items (NameError).
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_decoding_thread,
    )
    preprocess = Preprocessing()
    model, _ = clip.load(model_version, device="cuda", jit=False)

    np.savez(txt_out, features=_encode_text(model, text))

    total_num_frames = 0
    for data in tqdm(loader):
        input_file = data['input'][0]
        if os.path.isfile(vid_out) and not overwrite:
            continue  # already extracted and overwriting is disabled
        if not os.path.isfile(input_file):
            print(f'{input_file}, does not exist.\n')
            continue
        # DataLoader adds a batch dim, so a decoded clip has rank > 4;
        # presumably VideoLoader yields a lower-rank placeholder on decode
        # failure -- TODO confirm.
        if len(data['video'].shape) <= 4:
            print(f'{input_file}, failed at ffprobe.\n')
            continue
        video = data['video'].squeeze(0)
        if len(video.shape) != 4:
            continue
        vid_features = _encode_video(model, preprocess(video))
        if half_precision:
            vid_features = vid_features.astype('float16')
        total_num_frames += vid_features.shape[0]
        # NOTE(review): every item writes the same vid.npz; fine for the
        # single-video CLI, but multi-video inputs overwrite each other.
        np.savez(vid_out, features=vid_features)

    print(f"Total number of frames: {total_num_frames}")

def build_arg_parser():
    """Return the command-line parser for the feature-extraction script.

    ``--text`` accepts one or more words that are joined with spaces into a
    single query. Its default is a one-element list, so the join returns the
    default sentence unchanged. A bare-string default would be split into
    characters by ``' '.join``.
    """
    parser = argparse.ArgumentParser(description='')
    # NOTE(review): machine-specific default path; override with --vid_path.
    parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
    parser.add_argument('--text', nargs='+', type=str, default=['a boy is drinking.'])
    parser.add_argument('--save_dir', type=str, default='./tmp')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    query = ' '.join(args.text)

    print(args.vid_path)
    print(query)
    extractor(args.vid_path, [query], args.save_dir)