Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from multiprocessing import Pool | |
import numpy as np | |
import random | |
from PIL import Image | |
import re | |
import cv2 | |
import glob | |
from natsort import natsorted | |
class MultiProcessImageSaver(object):
    """Write images to disk using a pool of worker processes."""

    def __init__(self, n_workers=1):
        """Create the worker pool.

        Args:
            n_workers: Number of processes handed to ``multiprocessing.Pool``.
        """
        self.pool = Pool(n_workers)

    def __call__(self, images, output_files, resizes=None):
        """Dispatch one save job per ``(image, output_file, resize)`` triple.

        Args:
            images: Sequence of arrays accepted by ``PIL.Image.fromarray``.
            output_files: Destination path for each image.
            resizes: Optional per-image target size passed to
                ``Image.resize``; ``None`` (globally or per entry) keeps the
                original size.

        Returns:
            The iterator returned by ``Pool.imap`` over the save jobs.
        """
        if resizes is None:
            resizes = [None] * len(images)
        return self.pool.imap(
            self.save_image,
            zip(images, output_files, resizes),
        )

    def close(self):
        """Stop accepting jobs and wait for the workers to finish."""
        self.pool.close()
        self.pool.join()

    # Must be a staticmethod: it takes the packed job tuple as its only
    # argument and is referenced as ``self.save_image`` above. As a bound
    # method it would receive ``self`` in place of ``args``.
    @staticmethod
    def save_image(args):
        """Save a single image described by ``(image, filename, resize)``."""
        image, filename, resize = args
        image = Image.fromarray(image)
        if resize is not None:
            image = image.resize(tuple(resize))
        image.save(filename)
def list_dir_with_full_path(path):
    """Return every entry of ``path`` (as listed by ``os.listdir``) joined with ``path``."""
    entries = []
    for name in os.listdir(path):
        entries.append(os.path.join(path, name))
    return entries
def find_all_files_in_dir(path):
    """Recursively collect the paths of all files below ``path``.

    Args:
        path: Directory to walk with ``os.walk``.

    Returns:
        A list with ``os.path.join(root, name)`` for every file found; empty
        when the walk yields nothing.
    """
    # The accumulator must not share a name with the per-directory file list
    # yielded by os.walk, otherwise the loop appends to the list it iterates.
    found = []
    for root, _, filenames in os.walk(path):
        found.extend(os.path.join(root, name) for name in filenames)
    return found
def is_image(path):
    """Tell whether ``path`` ends with a recognised image extension.

    Only the all-lower-case and all-upper-case spellings of jpg, png and jpeg
    are accepted.
    """
    image_suffixes = ('.jpg', '.png', '.jpeg', '.JPG', '.PNG', '.JPEG')
    return path.endswith(image_suffixes)
def is_video(path):
    """Tell whether ``path`` ends with a recognised video extension.

    Accepts the all-lower-case and all-upper-case spellings of mp4, avi, webm
    and mkv.
    """
    # '.MKV' replaces the former '.MVK' typo, which rejected upper-case
    # Matroska files.
    video_suffixes = (
        '.mp4', '.avi', '.webm', '.mkv',
        '.MP4', '.AVI', '.WEBM', '.MKV',
    )
    return path.endswith(video_suffixes)
def random_square_crop(img, random_generator=None):
    """Crop a randomly positioned square whose side is the image's shorter side.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and a
            ``crop((left, upper, right, lower))`` method (e.g. a PIL image).
        random_generator: Object with an ``integers(low, high)`` method such
            as ``numpy.random.Generator``; a fresh ``default_rng()`` is used
            when omitted.

    Returns:
        Whatever ``img.crop`` returns for the chosen square box.
    """
    if random_generator is None:
        random_generator = np.random.default_rng()

    width, height = img.size
    min_size = min(width, height)

    # ``integers`` excludes its upper bound, so add 1: this makes every valid
    # offset reachable and keeps square images (slack == 0) from raising
    # "low >= high".
    if width > height:
        left = random_generator.integers(0, width - min_size + 1)
        upper = 0
    else:
        left = 0
        upper = random_generator.integers(0, height - min_size + 1)

    right = left + min_size
    lower = upper + min_size
    return img.crop((left, upper, right, lower))
def read_image_to_tensor(path, center_crop=1.0):
    """Load an image as a 256x256 RGB float32 array scaled to [0, 1].

    Args:
        path: Image file passed to ``PIL.Image.open``.
        center_crop: Fraction of each side to keep around the image centre;
            values below 1.0 trigger a centre crop before resizing.

    Returns:
        ``np.float32`` array built from the resized RGB image divided by 255.
    """
    # Convert inside the context manager so the underlying file is closed.
    with Image.open(path) as raw_im:
        pil_im = raw_im.convert('RGB')

    if center_crop < 1.0:
        width, height = pil_im.size
        # PIL's crop box is (left, upper, right, lower): x bounds come from
        # the width and y bounds from the height.
        pil_im = pil_im.crop((
            int((1 - center_crop) * width / 2),
            int((1 - center_crop) * height / 2),
            int((1 + center_crop) * width / 2),
            int((1 + center_crop) * height / 2),
        ))

    input_img = pil_im.resize((256, 256))
    input_img = np.array(input_img) / 255.0
    return input_img.astype(np.float32)
def match_mulitple_path(root_dir, regex):
    """Group files under ``root_dir`` by the first match of each pattern.

    Every file found by ``os.walk`` whose last ``/``-separated component does
    not start with ``.`` is tested against each pattern in ``regex``. The
    first ``findall`` result is used as the grouping key, and a group is kept
    only when it holds exactly ``len(regex)`` paths.

    Args:
        root_dir: Directory to walk.
        regex: Sequence of regular-expression patterns.

    Returns:
        A list of tuples of paths, one tuple per retained group.
    """
    candidates = [
        os.path.join(folder, name)
        for folder, _, names in os.walk(root_dir)
        for name in names
    ]
    visible = [p for p in candidates if not p.split('/')[-1].startswith('.')]

    buckets = {}
    for pattern in map(re.compile, regex):
        for candidate in visible:
            hits = pattern.findall(candidate)
            if hits:
                buckets.setdefault(hits[0], []).append(candidate)

    expected = len(regex)
    return [tuple(paths) for paths in buckets.values() if len(paths) == expected]
def randomly_subsample_frame_indices(length, n_frames, max_stride=30, random_start=True):
    """Pick ``n_frames`` equally spaced indices from ``range(length)``.

    The stride is drawn uniformly from ``1`` up to the smaller of
    ``max_stride`` and the largest stride that still fits in ``length``.

    Args:
        length: Number of available frames; must be at least ``n_frames``.
        n_frames: Number of indices to return (at least 1).
        max_stride: Upper bound for the sampled stride.
        random_start: Draw the first index at random when true, otherwise
            start at 0.

    Returns:
        ``np.ndarray`` of ``n_frames`` increasing indices.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    assert length >= n_frames
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    # With a single frame there is no spacing constraint; skipping the
    # division avoids a ZeroDivisionError for n_frames == 1.
    if n_frames > 1:
        max_stride = min((length - 1) // (n_frames - 1), max_stride)
    stride = np.random.randint(1, max_stride + 1)

    if random_start:
        # Exclusive upper bound: the last index is at most length - 1.
        start = np.random.randint(0, length - (n_frames - 1) * stride)
    else:
        start = 0
    return np.arange(n_frames) * stride + start
def read_frames_from_dir(dir_path, n_frames, stride, random_start=True, center_crop=1.0):
    """Load ``n_frames`` equally spaced image frames from a directory.

    Image files in ``dir_path`` (as decided by ``is_image``) are ordered with
    ``natsorted`` and loaded through ``read_image_to_tensor``.

    Args:
        dir_path: Directory containing the frame images.
        n_frames: Number of frames to load (at least 1).
        stride: Requested spacing between frames; reduced when the directory
            holds too few frames for it.
        random_start: Draw the first frame index at random when true,
            otherwise start at the first file.
        center_crop: Forwarded to ``read_image_to_tensor``.

    Returns:
        The loaded frames stacked along a new first axis, or ``None`` when the
        directory holds fewer than ``n_frames`` images.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    files = [os.path.join(dir_path, x) for x in os.listdir(dir_path)]
    files = natsorted([x for x in files if is_image(x)])
    total_frames = len(files)
    if total_frames < n_frames:
        return None

    # A single frame imposes no spacing constraint; skipping the division
    # avoids a ZeroDivisionError for n_frames == 1.
    if n_frames > 1:
        stride = min((total_frames - 1) // (n_frames - 1), stride)

    if random_start:
        start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
    else:
        start = 0
    frame_indices = np.arange(n_frames) * stride + start

    frames = [
        read_image_to_tensor(files[frame_index], center_crop=center_crop)
        for frame_index in sorted(frame_indices)
    ]
    return np.stack(frames, axis=0)
def read_frames_from_video(video_path, n_frames, stride, random_start=True, center_crop=1.0):
    """Read ``n_frames`` equally spaced frames from a video file.

    Each frame is optionally centre-cropped, resized to 256x256, scaled to
    [0, 1] as float32 and converted from OpenCV's BGR order to RGB.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        n_frames: Number of frames to read (at least 1).
        stride: Requested spacing between frames; reduced when the video is
            too short for it.
        random_start: Draw the first frame index at random when true,
            otherwise start at frame 0.
        center_crop: Fraction of each side to keep around the frame centre;
            values below 1.0 trigger a centre crop before resizing.

    Returns:
        Array with the frames stacked along the first axis and channels in RGB
        order, or ``None`` when the video reports fewer than ``n_frames``
        frames or fewer than ``n_frames`` frames could be read.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    frames = []
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released on every exit path.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < n_frames:
            return None

        # A single frame imposes no spacing constraint; skipping the division
        # avoids a ZeroDivisionError for n_frames == 1.
        if n_frames > 1:
            stride = min((total_frames - 1) // (n_frames - 1), stride)

        if random_start:
            start = np.random.randint(0, total_frames - (n_frames - 1) * stride)
        else:
            start = 0
        frame_indices = np.arange(n_frames) * stride + start

        for frame_index in sorted(frame_indices):
            if 0 <= frame_index < total_frames:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                ret, frame = cap.read()
                if ret:
                    if center_crop < 1.0:
                        height, width, _ = frame.shape
                        frame = frame[
                            int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
                            int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
                            :
                        ]
                    frame = cv2.resize(frame, (256, 256))
                    frames.append(frame)
            else:
                print(f"Frame index {frame_index} is out of bounds. Skipping...")
    finally:
        cap.release()

    if len(frames) < n_frames:
        return None

    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
    # From BGR to RGB
    return np.stack(
        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
    )
def read_all_frames_from_video(video_path, center_crop=1.0):
    """Read every frame of a video as 256x256 RGB float32 images in [0, 1].

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        center_crop: Fraction of each side to keep around the frame centre;
            values below 1.0 trigger a centre crop before resizing.

    Returns:
        Array with the frames stacked along the first axis and channels in RGB
        order, or ``None`` when no frame could be read.
    """
    frames = []
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released on every exit path.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Frames are consumed in order, so successive read() calls are used
        # instead of seeking to each index. The former per-index bounds check
        # was always true inside range(total_frames) and has been dropped.
        for _ in range(total_frames):
            ret, frame = cap.read()
            if not ret:
                continue
            if center_crop < 1.0:
                height, width, _ = frame.shape
                frame = frame[
                    int((1 - center_crop) * height / 2):int((1 + center_crop) * height / 2),
                    int((1 - center_crop) * width / 2):int((1 + center_crop) * width / 2),
                    :
                ]
            frames.append(cv2.resize(frame, (256, 256)))
    finally:
        cap.release()

    if len(frames) == 0:
        return None

    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
    # From BGR to RGB
    return np.stack(
        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
    )
def read_max_span_frames_from_video(video_path, n_frames):
    """Read ``n_frames`` frames spread as widely as possible over a video.

    Frames are taken from index 0 with the largest integer stride that keeps
    the last index inside the video, resized to 256x256, scaled to [0, 1] as
    float32 and converted from OpenCV's BGR order to RGB.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        n_frames: Number of frames to read (at least 1).

    Returns:
        Array with the frames stacked along the first axis and channels in RGB
        order, or ``None`` when the video reports fewer than ``n_frames``
        frames or fewer than ``n_frames`` frames could be read.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    if n_frames < 1:
        raise ValueError("n_frames must be at least 1")

    frames = []
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released on every exit path.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < n_frames:
            return None

        # A single frame needs no stride; skipping the division avoids a
        # ZeroDivisionError for n_frames == 1.
        stride = (total_frames - 1) // (n_frames - 1) if n_frames > 1 else 0
        frame_indices = np.arange(n_frames) * stride

        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.resize(frame, (256, 256)))
    finally:
        cap.release()

    if len(frames) < n_frames:
        return None

    frames = np.stack(frames, axis=0).astype(np.float32) / 255.0
    # From BGR to RGB
    return np.stack(
        [frames[..., 2], frames[..., 1], frames[..., 0]], axis=-1
    )