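"""Visualisation utilities for VideoMAE reconstructions.

Helpers to decode a clip from a video file, prepare the frames and a boolean
tube mask for a VideoMAE pre-training model, and plot the original, masked,
and reconstructed patches side by side.
"""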
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
from decord import VideoReader, cpu
from einops import rearrange
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms
from torchvision.transforms import ToPILImage


def get_frames(
    path: str, transform: transforms.Compose, num_frames: int = 16
) -> Tuple[torch.Tensor, List[int]]:
    """Decode and transform a clip of ``num_frames`` frames from the video at ``path``.

    Frames are sampled every second frame starting at frame index 60. Returns a
    tensor of shape (C, T, H, W) together with the sampled frame indices.
    """
    vr = VideoReader(path, ctx=cpu(0))
    # Sample every second frame, offset by 60 frames into the video.
    tmp = np.arange(0, num_frames * 2, 2) + 60
    frame_id_list = tmp.tolist()
    video_data = vr.get_batch(frame_id_list).asnumpy()
    frames, _ = transform(
        (
            [
                Image.fromarray(video_data[vid, :, :, :]).convert("RGB")
                for vid, _ in enumerate(frame_id_list)
            ],
            None,
        )
    )
    # (T * C, H, W) -> (T, C, H, W) -> (C, T, H, W)
    frames = frames.view((num_frames, 3) + frames.size()[-2:]).transpose(0, 1)
    return frames, frame_id_list


def prepare_frames_masks(
    frames: torch.Tensor, masks: torch.Tensor, device: "torch.device"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add a batch dimension and move the frames and boolean patch mask to ``device``."""
    frames = frames.unsqueeze(0)
    masks = masks.unsqueeze(0)
    frames = frames.to(device, non_blocking=True)
    masks = masks.to(device, non_blocking=True).flatten(1).to(torch.bool)
    return frames, masks


def get_videomae_outputs(
    frames: torch.Tensor,
    masks: torch.Tensor,
    outputs: torch.Tensor,
    ids: List[int],
    patch_size: Tuple[int, ...],
    device: "torch.device",
) -> List[List[Image.Image]]:
    """Turn the model's reconstructed patches into per-frame visualisations.

    Returns one ``[original, masked, reconstructed]`` triple of PIL images per
    sampled frame index in ``ids``. Assumes a 224x224 input and a tubelet size
    of 2 (hence the hard-coded 14x14 patch grid below).
    """
    visualisations = []
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None, None]
    ori_img = frames * std + mean  # in [0, 1]
    original_images = [
        ToPILImage()(ori_img[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]
    # Split the video into tubelets of 2 frames x patch_size x patch_size and
    # normalise each patch, mirroring the target construction used in training.
    img_squeeze = rearrange(
        ori_img,
        "b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2) c",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[0],
    )
    img_norm = (img_squeeze - img_squeeze.mean(dim=-2, keepdim=True)) / (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    )
    img_patch = rearrange(img_norm, "b n p c -> b n (p c)")
    # Write the model's reconstructed patches into the masked positions.
    img_patch[masks] = outputs
    # make mask
    mask = torch.ones_like(img_patch)
    mask[masks] = 0
    mask = rearrange(mask, "b n (p c) -> b n p c", c=3)
    mask = rearrange(
        mask,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    # save reconstruction video
    rec_img = rearrange(img_patch, "b n (p c) -> b n p c", c=3)
    # Undo the per-patch normalisation before rendering.
    rec_img = rec_img * (
        img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
    ) + img_squeeze.mean(dim=-2, keepdim=True)
    rec_img = rearrange(
        rec_img,
        "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)",
        p0=2,
        p1=patch_size[0],
        p2=patch_size[1],
        h=14,
        w=14,
    )
    reconstructed_images = [
        ToPILImage()(rec_img[0, :, vid, :, :].cpu().clamp(0, 0.996))
        for vid, _ in enumerate(ids)
    ]
    # save masked video
    img_mask = rec_img * mask
    masked_images = [
        ToPILImage()(img_mask[0, :, vid, :, :].cpu()) for vid, _ in enumerate(ids)
    ]
    assert len(original_images) == len(reconstructed_images) == len(masked_images)
    for i in range(len(original_images)):
        visualisations.append(
            [original_images[i], masked_images[i], reconstructed_images[i]]
        )
    return visualisations


def create_plot(images: List[List[Image.Image]]) -> plt.Figure:
    """Arrange the per-frame triples in a 16x3 grid of original / masked / reconstructed patches."""
    num_cols = 3
    num_rows = 16
    column_names = ["Original Patch", "Masked Patch", "Reconstructed Patch"]
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 48))
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i, j].imshow(images[i][j])
            axes[i, j].axis("off")
            if i == 0:
                axes[i, j].set_title(column_names[j], fontsize=16)
    plt.tight_layout()
    return fig
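

# A minimal end-to-end sketch of how these helpers fit together. It assumes a
# 224x224 input, 16x16 patches, a tubelet size of 2, and a ~90% tube-masking
# ratio; the video path, the simplified per-frame transform, and the random
# tensor standing in for a VideoMAE model's reconstructed patches are
# placeholders, not part of this module.
if __name__ == "__main__":

    def _group_transform(inp):
        # Stand-in for VideoMAE's clip augmentation: resize, normalise, and
        # stack the frames along the channel axis so get_frames can reshape
        # the result into (T, C, H, W).
        imgs, label = inp
        frame_tf = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
            ]
        )
        return torch.cat([frame_tf(img) for img in imgs], dim=0), label

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    patch_size = (16, 16)

    # "video.mp4" is a placeholder path; the clip must contain at least 91
    # frames, since indices 60 to 90 are sampled.
    frames, frame_ids = get_frames("video.mp4", _group_transform, num_frames=16)

    # Random tube mask over the 8 x 14 x 14 = 1568 patch positions.
    num_patches = (16 // 2) * (224 // 16) * (224 // 16)
    mask = torch.zeros(num_patches)
    mask[torch.randperm(num_patches)[: int(num_patches * 0.9)]] = 1

    frames, masks = prepare_frames_masks(frames, mask, device)

    # Placeholder for the model call: a real setup would use the reconstructed
    # patches returned by a VideoMAE pre-training model, e.g.
    # outputs = model(frames, masks).
    outputs = torch.randn(
        int(masks.sum()), 2 * patch_size[0] * patch_size[1] * 3, device=device
    )

    visualisations = get_videomae_outputs(
        frames, masks, outputs, frame_ids, patch_size, device
    )
    fig = create_plot(visualisations)
    fig.savefig("videomae_visualisation.png")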