Spaces:
Runtime error
Runtime error
File size: 7,369 Bytes
39d5658 6c22a09 96f2ad7 39d5658 6a77b6d 39d5658 a5d0544 eb5f8c6 39d5658 eb5f8c6 6a77b6d a5d0544 6a77b6d 39d5658 8b7f367 c88defd 8b7f367 39d5658 6a77b6d 8b7f367 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import sys
sys.path.insert(0, './')
import decord
import numpy as np
import torch
import os
from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame index per segment of the range starting at ``start_frame``.

    The span ``end_frame - start_frame - 1`` is divided into ``num_segments``
    equal-width segments whose boundaries are rounded with ``np.round``. Each
    segment's upper boundary is capped at ``end_frame``.

    Args:
        start_frame: first frame index of the range.
        end_frame: end of the range, used both to size the segments and as a cap.
        num_segments: number of indices to return.
        jitter: if True, draw a random index within each segment (inclusive of
            both boundaries) via ``np.random.randint``; otherwise take the
            integer midpoint of the segment.

    Returns:
        A list of ``num_segments`` frame indices, in segment order.
    """
    step = float(end_frame - start_frame - 1) / num_segments
    frame_ids = []
    for seg_idx in range(num_segments):
        lo = int(np.round(step * seg_idx) + start_frame)
        hi = min(int(np.round(step * (seg_idx + 1)) + start_frame), end_frame)
        # randint's upper bound is exclusive, hence hi + 1 to include hi itself.
        frame_ids.append(np.random.randint(low=lo, high=(hi + 1)) if jitter else (lo + hi) // 2)
    return frame_ids
def video_loader_by_frames(root, vid, frame_ids):
    """Decode the requested frames of ``root/vid`` and stack them into one tensor.

    Each decoded frame is converted to a float32 tensor and the frames are
    stacked along a new leading dimension.

    If decord raises ``IndexError`` or ``DECORDError`` while fetching the
    batch, the error is printed and a stack of ``(240, 320, 3)`` zero frames
    (one per requested index) is returned instead. This is a best-effort
    fallback, so a bad clip does not abort the caller. Errors raised while
    opening the reader are not caught.
    """
    reader = decord.VideoReader(os.path.join(root, vid))
    try:
        batch = reader.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(img, dtype=torch.float32) for img in batch]
    except (IndexError, decord.DECORDError) as err:
        print(err)
        print("Erroneous video: ", vid)
        # NOTE(review): the placeholder resolution is fixed and may differ from
        # the real video's resolution.
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)
def iter_clips(video_path, num_segments=4, stride_size=16):
    """Yield consecutive, non-overlapping clips of a video.

    Each clip covers a window of ``num_segments * stride_size`` frames and is
    represented by ``num_segments`` frames (the midpoint of each segment, via
    ``get_frame_ids(..., jitter=False)``). Windows start at frame 0 and advance
    by one full window. A window is yielded whenever it fits entirely inside the
    video, including a window that ends exactly on the last frame. A video
    shorter than one window still yields a single, shorter clip starting at 0.

    Args:
        video_path: path to the video file.
        num_segments: frames sampled per clip.
        stride_size: frames per segment, so the window length is
            ``num_segments * stride_size``.

    Yields:
        ``(start_sec, stop_sec, frames)`` tuples. The times come from the
        reader's average FPS, and ``frames`` is the tensor returned by
        ``video_loader_by_frames``.
    """
    vr = decord.VideoReader(video_path)
    num_frames = len(vr)
    fps = vr.get_avg_fps()
    frame_sample_size = num_segments * stride_size
    # Last start index at which a full window still fits. The bound is
    # inclusive, and it is clamped at 0 so short videos still yield one clip.
    last_start = max(num_frames - frame_sample_size, 0)
    for curr_frame in range(0, last_start + 1, frame_sample_size):
        stop_frame = min(curr_frame + frame_sample_size, num_frames)
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        # os.path.join('./', p) leaves an absolute p unchanged. Note that
        # video_loader_by_frames reopens the file for every clip.
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_frame / fps, stop_frame / fps, frames
class Pipeline:
    """Callable narration pipeline around a pretrained LaViLa VCLM narrator.

    Construction loads the checkpoint, tokenizer and evaluation transform.
    Calling an instance with a video path captions consecutive clips of that
    video and returns the accumulated text report.
    """

    # Checkpoint file expected inside the ``path`` directory given to __init__.
    CKPT_NAME = 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth'

    def __init__(self, path=""):
        """Load model weights, tokenizer and preprocessing.

        Args:
            path: directory containing ``CKPT_NAME``. The default ``""``
                resolves it relative to the current working directory.
        """
        ckpt_path = os.path.join(path, self.CKPT_NAME)
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Remove 'module.' from parameter names (presumably a DataParallel/DDP
        # wrapper prefix from training) so the keys match the bare model below.
        state_dict = OrderedDict(
            (k.replace('module.', ''), v) for k, v in ckpt['state_dict'].items()
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,  # must match the num_segments used by iter_clips
            drop_path_rate=0.
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
        crop_size = 224
        # Permute presumably reorders stacked (T, H, W, C) frames to (C, T, H, W)
        # before resizing/cropping. The mean/std values look like they are on a
        # 0-255 pixel scale, matching the undivided float frames. TODO confirm.
        self.val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
        ])

    def decode_one(self, generated_ids, tokenizer):
        """Decode one generated id sequence into text.

        Position 0 is assumed to hold the BOS token and is skipped. Decoding
        stops before the first EOS token. When BOS and EOS share an id, the
        search starts at index 1 so the leading BOS is not mistaken for EOS.
        If no EOS is present, every token after BOS is decoded.

        Args:
            generated_ids: 1-D tensor/array of token ids (anything with ``.tolist()``).
            tokenizer: wrapper exposing ``eos_token_id``, ``bos_token_id`` and
                an inner ``tokenizer.decode``.

        Returns:
            The decoded string.
        """
        ids = generated_ids.tolist()
        eos = tokenizer.eos_token_id
        search_from = 1 if eos == tokenizer.bos_token_id else 0
        try:
            eos_id = ids.index(eos, search_from)
        except ValueError:
            # No EOS: keep every token. The old ``len - 1`` dropped the last one.
            eos_id = len(ids)
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77,
                 num_return_sequences=10, max_clips=5):
        """Narrate the first clips of a video.

        Args:
            video_path: path to the video file.
            temperature: sampling temperature passed to ``model.generate``.
            top_p: nucleus-sampling threshold passed to ``model.generate``.
            max_text_length: maximum generated length passed to ``model.generate``.
            num_return_sequences: candidate narrations generated per clip.
            max_clips: stop after this many clips. At least one clip is always
                processed if the video yields any.

        Returns:
            The report text. For each clip it contains a header line with the
            clip's time span, followed by one numbered line per candidate.
            The same lines are also printed as they are produced.
        """
        parts = []
        with torch.autocast(self.device), torch.no_grad():
            for clip_idx, (start, stop, frames) in enumerate(iter_clips(video_path)):
                header = f"{'-'*30} Predictions From: {start:2.3f}-{stop:2.3f} seconds {'-'*30}\n"
                print(header)
                parts.append(header)
                frames = self.val_transform(frames).unsqueeze(0)  # add batch dim
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()
                image_features = self.model.encode_image(frames)
                generated_text_ids, _ = self.model.generate(
                    image_features,
                    self.tokenizer,
                    target=None,  # free-form generation
                    max_text_length=max_text_length,
                    top_k=None,
                    top_p=top_p,  # nucleus sampling
                    num_return_sequences=num_return_sequences,
                    temperature=temperature,
                    early_stopping=True,
                )
                # presumably one row of ids per returned candidate. TODO confirm
                for i in range(num_return_sequences):
                    line = '\t{}: {}\n'.format(i, self.decode_one(generated_text_ids[i], self.tokenizer))
                    print(line)
                    parts.append(line)
                if clip_idx + 1 >= max_clips:
                    break
        return ''.join(parts)
title = "LaViLa"
description = """LaViLa (**L**anguage **a**ugmented **Vi**deo **La**nguage Pretraining) is a new approach to learning video representations from Large Language Models (LLMs). We repurpose LLMs to be visually conditioned "Narrators", and use them to automatically generate video-language paired data. We use this data to then learn a video-language representation, outperforming prior work by large margins. \nGradio Demo for LaViLa. To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."""
article = "<p style='text-align: center'><a href='https://github.com/facebookresearch/LaViLa' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2212.04501' target='_blank'>Paper on arxiv</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=nateraw_lavila' alt='visitor badge'></center>"

# The three inputs map positionally onto Pipeline.__call__(video_path, temperature, top_p).
# Pipeline() loads the checkpoint here, at import time, before the UI is built.
interface = gr.Interface(
    Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(0.0, 1.0, 0.7, label='temperature'),
        gr.Slider(0.0, 1.0, 0.95, label='top_p'),
    ],
    outputs='text',
    examples=[['eating_spaghetti.mp4', 0.7, 0.95], ['assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', 0.7, 0.95]],
    title=title,
    description=description,
    article=article,
)
# Keep `interface` bound to the Interface object itself, not to launch()'s return value.
interface.launch(debug=True)