Dinov2-Video / utils.py
EduardoPacheco's picture
First commit
ce78b5d
raw
history blame
No virus
4.72 kB
from typing import List
import cv2
import torch
import numpy as np
from tqdm import tqdm
import supervision as sv
import torch.nn.functional as F
from transformers import AutoModel
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler
def load_video_frames(video_path: str) -> List[np.ndarray]:
    """Decode every frame of the video at ``video_path`` into an RGB array.

    Frames are read with ``sv.get_video_frames_generator`` and converted from
    OpenCV's BGR channel order to RGB. A tqdm progress bar counts frames as
    they are read. The whole video is held in memory.
    """
    frame_iter = tqdm(
        sv.get_video_frames_generator(source_path=video_path),
        unit=" frames",
    )
    return [cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) for bgr_frame in frame_iter]
def preprocess(image: np.ndarray, n_patches: int, device: str, patch_size: int = 14) -> torch.Tensor:
    """Convert one (H, W, C) image into a normalised single-image batch.

    The pixel values are divided by 255, so they are presumably in the 0-255
    range (TODO confirm with callers). The image is resized to a square of
    ``n_patches * patch_size`` pixels per side, so the model sees an
    ``n_patches x n_patches`` grid of patches. It is then normalised with
    ImageNet channel statistics.

    Returns:
        A float32 tensor of shape (1, C, n_patches * patch_size,
        n_patches * patch_size) on ``device``.
    """
    side = n_patches * patch_size
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    pipeline = T.Compose([
        T.Resize((side, side)),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    # HWC -> CHW, scaled to [0, 1] before normalisation.
    chw = torch.from_numpy(image).float().permute(2, 0, 1) / 255
    batch = pipeline(chw).unsqueeze(0)
    return batch.to(device)
def _pca_foreground_rgb(features: np.ndarray, is_larger: bool, threshold: float) -> np.ndarray:
    """Map (N, D) patch features to (N, 3) colours in [0, 1]; background patches stay black.

    A first 3-component PCA is min-max scaled, and its first component is
    thresholded to split background from foreground. Foreground patches are
    then re-projected with a second 3-component PCA, whose scaled components
    become the RGB colour.
    """
    pca = PCA(n_components=3)
    scaler = MinMaxScaler(clip=True)
    first_component = scaler.fit_transform(pca.fit_transform(features))[:, 0]
    background = first_component > threshold if is_larger else first_component < threshold
    foreground = ~background

    rgb = np.zeros((features.shape[0], 3))
    # PCA cannot fit 3 components on fewer than 3 samples. Rather than crash,
    # leave such (nearly all-background) batches black.
    if np.count_nonzero(foreground) >= pca.n_components:
        foreground_pca = pca.fit_transform(features[foreground])
        rgb[foreground] = scaler.fit_transform(foreground_pca)
    return rgb


def process_video(
    model: AutoModel,
    video: str | List[np.ndarray],
    is_larger: bool = True,
    batch_size: int = 4,
    threshold: float = 0.5,
    n_patches: int = 40,
    interpolate: bool = False,
    device: str = "cpu"
) -> List[np.ndarray]:
    """Visualise per-patch model features of each video frame as PCA colours.

    Frames go through ``model`` in batches of ``batch_size``. For each batch,
    the patch tokens (the first token is dropped, presumably CLS) are reduced
    with PCA. The scaled first component, compared with ``threshold``,
    separates background patches (painted black) from foreground patches,
    which are coloured by a second PCA. PCA is fit per batch, so colours are
    only comparable within a batch.

    Args:
        model: Vision transformer exposing ``config.patch_size`` and returning
            ``last_hidden_state`` (presumably a DINOv2 ``AutoModel``).
        video: Path to a video file, or an already-decoded list of (H, W, 3)
            RGB frames.
        is_larger: If True, patches whose scaled first component is above
            ``threshold`` are background; otherwise those below it are.
        batch_size: Frames per forward pass. Must be >= 1.
        threshold: Cut-off on the scaled first PCA component.
        n_patches: Each frame is resized so that the model yields an
            ``n_patches x n_patches`` feature grid.
        interpolate: If True, upsample each colour map bilinearly to the
            original frame size.
        device: Device the input tensors are moved to. The model is assumed
            to already live on it.

    Returns:
        One float64 array with values in [0, 1] per input frame. Its shape is
        (n_patches, n_patches, 3), or (H, W, 3) when ``interpolate`` is True.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    frames = load_video_frames(video) if isinstance(video, str) else video
    if len(frames) == 0:
        return []

    # NP = N_PATCHES, P = PATCH_SIZE
    patch_size = model.config.patch_size
    original_height, original_width = frames[0].shape[:2]  # frames are H, W, C
    n_tokens = n_patches * n_patches

    final_frames: List[np.ndarray] = []
    # Step over every frame so that a trailing partial batch is also processed.
    for start in range(0, len(frames), batch_size):
        batch = frames[start:start + batch_size]
        n_batch = len(batch)  # the last batch may be smaller than batch_size
        pixel_values = torch.stack(
            [preprocess(f, n_patches, device, patch_size).squeeze(0) for f in batch]
        )  # B, C, NP * P, NP * P
        with torch.no_grad():
            out = model(pixel_values=pixel_values)
        features = out.last_hidden_state[:, 1:].cpu().numpy()  # B, NP * NP, HIDDEN_DIM
        features = features.reshape(n_batch * n_tokens, -1)  # B * NP * NP, HIDDEN_DIM

        rgb = _pca_foreground_rgb(features, is_larger, threshold)
        rgb = rgb.reshape(n_batch, n_patches, n_patches, 3)  # B, NP, NP, 3

        if interpolate:
            upsampled = F.interpolate(
                torch.from_numpy(rgb).permute(0, 3, 1, 2),  # B, 3, NP, NP
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            ).permute(0, 2, 3, 1)  # B, H, W, 3
            # Convert back to numpy to honour the declared return type.
            final_frames.extend(frame.numpy() for frame in upsampled)
        else:
            final_frames.extend(rgb)  # iterating yields one (NP, NP, 3) array per frame
    return final_frames
def create_video_from_frames_rgb(
    frame_list: List[np.ndarray],
    output_filename: str = "animation.mp4",
    fps: int = 15
) -> str:
    """Write RGB frames to an MP4 file (``mp4v`` codec) and return its path.

    The video size is taken from the first frame's (H, W, 3) shape. Each frame
    is converted to BGR before writing, because OpenCV expects BGR input.
    Frames are presumably expected as uint8 arrays (TODO confirm). Float output
    from ``process_video`` would need scaling and casting first.

    Raises:
        ValueError: If ``frame_list`` is empty.
        OSError: If OpenCV cannot open a writer for ``output_filename``.
    """
    if len(frame_list) == 0:
        raise ValueError("frame_list must contain at least one frame")
    frame_height, frame_width, _ = frame_list[0].shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # You can change the codec as needed
    out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_width, frame_height))
    # Release the writer even if a conversion or write fails part-way.
    try:
        # Otherwise write() is a silent no-op and a bogus path would be returned.
        if not out.isOpened():
            raise OSError(f"Could not open video writer for {output_filename!r}")
        for frame in frame_list:
            out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        out.release()
    return output_filename