Spaces:
Runtime error
Runtime error
# video_processor.py | |
from io import BytesIO | |
import av | |
import base64 | |
from PIL import Image | |
from typing import List | |
from dataclasses import dataclass | |
def sample(N, K):
    """Choose K frame indices spread over range(N), always ending on the last one.

    If K is smaller than 2, or not smaller than N, every index of range(N) is
    returned. Otherwise the result is the first K - 1 multiples of N // K
    (starting at 0) followed by N - 1.
    """
    everything = list(range(N))
    total = len(everything)
    # Equivalent to the early-out "K >= total or K < 2".
    if not 2 <= K < total:
        return everything
    stride = total // K
    return [stride * pos for pos in range(K - 1)] + [everything[-1]]
def grid_sample(array, N, K):
    """Split ``array`` into K contiguous groups and thin each group to N items.

    Groups differ in size by at most one element: the first
    ``len(array) % K`` groups receive the extra element. A group with at most
    N elements is kept whole; a larger group is reduced to N elements taken at
    a fixed stride of ``len(group) // N`` starting from its first element.
    Groups can be empty when ``len(array) < K``.
    """
    base, extra = divmod(len(array), K)
    groups = []
    start = 0
    for idx in range(K):
        size = base + (1 if idx < extra else 0)
        chunk = array[start:start + size]
        start += size
        if len(chunk) <= N:
            groups.append(chunk)
            continue
        stride = len(chunk) // N
        groups.append([chunk[stride * j] for j in range(N)])
    return groups
class VideoProcessor:
    """Decode video frames with PyAV and package them as PIL images or base64 text.

    The class attributes below act as defaults and can be overridden per instance.
    """

    # Image format passed to ``Image.save`` by ``to_base64_list``.
    frame_format: str = "JPEG"
    # Number of frames ``_decode`` samples across the video.
    frame_limit: int = 10
    # ``decode`` keeps every ``frame_skip``-th decoded frame; 1 keeps them all.
    frame_skip: int = 1

    def _decode(self, video_path: str) -> List[Image.Image]:
        """Return about ``frame_limit`` frames sampled across the first video stream.

        Frame indices come from ``sample(src.frames, self.frame_limit)``. For each
        index the container seeks to the preceding keyframe, then decodes forward
        until the frame whose timestamp reaches the target.
        """
        frames = []
        with av.open(video_path) as container:
            src = container.streams.video[0]
            time_base = src.time_base
            framerate = src.average_rate
            # Timestamps in the stream are offset by its start time (in time_base units).
            start = src.start_time or 0
            # NOTE(review): ``src.frames`` can be 0 when the container does not store a
            # frame count, in which case nothing is sampled — confirm for target inputs.
            for i in sample(src.frames, self.frame_limit):
                target = round((i / framerate) / time_base) + start
                container.seek(target, backward=True, stream=src)
                # The seek lands on a keyframe at or before ``target``; keep decoding so
                # the returned image is the requested frame, not just that keyframe.
                chosen = None
                for frame in container.decode(video=0):
                    chosen = frame
                    if frame.pts is None or frame.pts >= target:
                        break
                # Stream may end before ``target`` (e.g. rounding on the last index);
                # fall back to the last frame decoded, if any.
                if chosen is not None:
                    frames.append(chosen.to_image())
        return frames

    def decode(self, video_path: str) -> List[Image.Image]:
        """Decode the first video stream fully, keeping every ``frame_skip``-th frame.

        A ``frame_skip`` below 1 is treated as 1 (keep every frame).
        """
        step = max(self.frame_skip, 1)
        # ``with`` closes the container; it was previously left open.
        with av.open(video_path) as container:
            return [
                frame.to_image()
                for i, frame in enumerate(container.decode(video=0))
                if i % step == 0
            ]

    def concatenate(self, frames: List[Image.Image], direction: str = "horizontal") -> Image.Image:
        """Paste ``frames`` side by side into one RGB image.

        ``direction == "horizontal"`` lays frames left to right; any other value
        stacks them top to bottom. Unused canvas area stays black.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if not frames:
            raise ValueError("concatenate() requires at least one frame")
        widths, heights = zip(*(frame.size for frame in frames))
        if direction == "horizontal":
            concatenated_image = Image.new('RGB', (sum(widths), max(heights)))
            x_offset = 0
            for frame in frames:
                concatenated_image.paste(frame, (x_offset, 0))
                x_offset += frame.width
        else:
            concatenated_image = Image.new('RGB', (max(widths), sum(heights)))
            y_offset = 0
            for frame in frames:
                concatenated_image.paste(frame, (0, y_offset))
                y_offset += frame.height
        return concatenated_image

    def grid_concatenate(self, frames: List[Image.Image], group_size, limit=10) -> List[Image.Image]:
        """Split ``frames`` into up to ``limit`` groups and concatenate each one.

        Each group is thinned to at most ``group_size`` frames by ``grid_sample`` and
        joined horizontally with ``concatenate``.
        """
        sampled_groups = grid_sample(frames, group_size, limit)
        # grid_sample yields empty groups when there are fewer frames than ``limit``;
        # concatenating an empty group would fail, so those groups are dropped.
        return [self.concatenate(group) for group in sampled_groups if group]

    def to_base64_list(self, images: List[Image.Image]) -> List[str]:
        """Encode each image in ``frame_format`` and return the bytes as base64 text."""
        encoded = []
        for image in images:
            buffered = BytesIO()
            image.save(buffered, format=self.frame_format)
            encoded.append(base64.b64encode(buffered.getvalue()).decode('utf-8'))
        return encoded