File size: 6,005 Bytes
39d5658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import sys
sys.path.insert(0, './')

import decord
import numpy as np
import torch
import os

from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr

def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Choose one frame index per segment of the range starting at ``start_frame``.

    The span is divided into ``num_segments`` pieces of width
    ``(end_frame - start_frame - 1) / num_segments``; piece boundaries are
    rounded with ``np.round`` and each upper boundary is clamped to
    ``end_frame``.

    Args:
        start_frame: First frame of the range.
        end_frame: Frame bound used for the segment width and the clamp.
        num_segments: Number of indices to return.
        jitter: If true, draw a random index in ``[seg_lo, seg_hi]`` per
            segment with the global NumPy RNG; otherwise use the floored
            midpoint ``(seg_lo + seg_hi) // 2``.

    Returns:
        A list of ``num_segments`` frame indices.
    """
    step = float(end_frame - start_frame - 1) / num_segments

    def _boundary(k):
        # k-th rounded segment boundary, offset by the start of the range.
        return int(np.round(step * k) + start_frame)

    frame_ids = []
    for idx in range(num_segments):
        lo = _boundary(idx)
        hi = min(_boundary(idx + 1), end_frame)
        frame_ids.append(
            np.random.randint(low=lo, high=(hi + 1)) if jitter else (lo + hi) // 2
        )
    return frame_ids

def video_loader_by_frames(root, vid, frame_ids):
    """Decode the frames listed in ``frame_ids`` from the video at ``root/vid``.

    Each decoded frame is converted to a float32 tensor (pixel values are not
    rescaled) and the frames are stacked along a new leading dimension.
    decord yields channels-last frames, matching the (240, 320, 3) fallback
    below — presumably (T, H, W, C) overall; TODO confirm for odd inputs.

    If decoding raises ``IndexError`` or ``decord.DECORDError``, the error and
    the video name are printed and a stack of all-zero (240, 320, 3) frames,
    one per requested index, is returned instead. Errors raised while opening
    the video are not caught.
    """
    reader = decord.VideoReader(os.path.join(root, vid))
    try:
        batch = reader.get_batch(frame_ids).asnumpy()
        clip = [torch.tensor(img, dtype=torch.float32) for img in batch]
    except (IndexError, decord.DECORDError) as err:
        print(err)
        print("Erroneous video: ", vid)
        # Best-effort fallback: keep the expected frame count so callers can continue.
        clip = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(clip, dim=0)

def iter_clips(video_path, num_segments=4, stride_size=16):
    """Yield consecutive, non-overlapping clips sampled from ``video_path``.

    The video is cut into windows of ``num_segments * stride_size`` frames
    (64 by default). From each window ``num_segments`` frames are picked at
    the segment midpoints (no jitter), so by default every clip is
    represented by 4 frames.

    Windows start at frames 0, W, 2W, ... (W = window size) as long as a full
    window fits; a trailing partial window is dropped. A video shorter than
    one window still yields a single clip covering the whole video, and an
    empty video yields nothing.

    Yields:
        ``(start_sec, stop_sec, frames)``: the window's time span computed
        from the video's average FPS, and the tensor returned by
        ``video_loader_by_frames`` for the sampled frame indices.
    """
    vr = decord.VideoReader(video_path)
    num_frames = len(vr)
    if num_frames == 0:
        return
    frame_sample_size = num_segments * stride_size
    fps = vr.get_avg_fps()
    # Last start position at which a full window still fits; clamp to 0 so a
    # short video still produces one (partial) window starting at frame 0.
    last_start = max(num_frames - frame_sample_size, 0)
    for curr_frame in range(0, last_start + 1, frame_sample_size):
        # The window ends relative to where it starts (the original code used
        # a constant end frame), clamped to the length of the video.
        stop_frame = min(curr_frame + frame_sample_size, num_frames)
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_frame / fps, stop_frame / fps, frames


class Pipeline:
    """Video captioning pipeline around LaViLa's VCLM TimeSformer + GPT-2 model.

    Calling an instance with a video path captions each clip produced by
    ``iter_clips`` (64-frame windows with the default arguments) and returns
    all sampled captions as one formatted string.
    """

    def __init__(self, path=""):
        """Load the checkpoint, then build the model, tokenizer and transforms.

        Args:
            path: Directory containing the checkpoint file named below
                (default: the current working directory).
        """
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Drop 'module.' from parameter names (presumably the checkpoint was
        # saved from a DataParallel/DDP-wrapped model). Note str.replace
        # removes every occurrence, not just a leading prefix.
        state_dict = OrderedDict(
            (k.replace('module.', ''), v) for k, v in ckpt['state_dict'].items()
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        # strict=True: checkpoint keys must match the architecture exactly.
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)

        crop_size = 224
        self.val_transform = transforms.Compose([
            # Presumably tensor.permute(3, 0, 1, 2): (T, H, W, C) -> (C, T, H, W).
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            # Mean/std are on the 0-255 scale; the loader does not rescale pixels.
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
        ])

    def decode_one(self, generated_ids, tokenizer):
        """Decode one generated sequence into text, stopping at the first EOS.

        Position 0 (the BOS token) is always skipped. When EOS and BOS share
        an id, the EOS search starts at position 1 so the leading BOS is not
        mistaken for EOS. If no EOS is found, the final token is dropped.

        Args:
            generated_ids: 1-D tensor of token ids for a single sequence.
            tokenizer: Wrapper exposing ``eos_token_id``, ``bos_token_id`` and
                an inner ``tokenizer`` with a ``decode`` method.

        Returns:
            The decoded caption string.
        """
        ids = generated_ids.tolist()  # convert once instead of per lookup
        eos = tokenizer.eos_token_id
        if eos == tokenizer.bos_token_id:
            eos_id = ids.index(eos, 1) if eos in ids[1:] else len(ids) - 1
        elif eos in ids:
            eos_id = ids.index(eos)
        else:
            eos_id = len(ids) - 1
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
        """Caption every clip of ``video_path`` and return a printable report.

        For each clip from ``iter_clips`` a header line with the clip's time
        span is written. It is followed by ``num_return_sequences`` captions
        sampled with nucleus sampling (``top_p``) at ``temperature``, one per
        tab-indented line.

        Returns:
            The concatenated report; empty if ``iter_clips`` yields no clips.
        """
        parts = []  # collected pieces, joined once at the end (avoids quadratic +=)
        with torch.autocast(self.device), torch.no_grad():
            for start, stop, frames in iter_clips(video_path):
                parts.append(f"{'-'*30} Predictions From: {start:10.3f}-{stop:10.3f} seconds {'-'*30}\n")
                # Add a leading batch dimension of size 1.
                frames = self.val_transform(frames).unsqueeze(0)
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()

                image_features = self.model.encode_image(frames)
                generated_text_ids, _ppls = self.model.generate(
                    image_features,
                    self.tokenizer,
                    target=None,  # free-form generation
                    max_text_length=max_text_length,
                    top_k=None,
                    top_p=top_p,  # nucleus sampling
                    num_return_sequences=num_return_sequences,  # number of candidates
                    temperature=temperature,
                    early_stopping=True,
                )
                for i in range(num_return_sequences):
                    caption = self.decode_one(generated_text_ids[i], self.tokenizer)
                    parts.append(f'\t{i}: {caption}\n')
        return ''.join(parts)

# Gradio front end. Gradio passes the input components' values to the
# callable in order, i.e. Pipeline.__call__(video_path, temperature, top_p);
# max_text_length and num_return_sequences keep their defaults.
# Note: constructing Pipeline here loads the checkpoint at import time.
# NOTE(review): the temperature slider permits 0.0 -- confirm that
# model.generate tolerates temperature == 0.
_captioner = Pipeline()
_ui_inputs = [
    gr.Video(label='video_path'),
    gr.Slider(0.0, 1.0, 0.7, label='temperature'),
    gr.Slider(0.0, 1.0, 0.95, label='top_p'),
]
interface = gr.Interface(_captioner, inputs=_ui_inputs, outputs='text')


def _main():
    """Launch the Gradio demo with debug=True."""
    interface.launch(debug=True)


if __name__ == '__main__':
    _main()