| | import os |
| | import os.path as osp |
| | import math |
| | import cv2 |
| | from PIL import Image |
| | import torch |
| | from torchvision import transforms |
| | from plyfile import PlyData, PlyElement |
| | import numpy as np |
| |
|
def _read_video_frames(path, interval):
    """Decode every ``interval``-th frame of a video file into RGB PIL images."""
    print(f"Loading frames from video: {path}")
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise IOError(f"Cannot open video file: {path}")
        frames = []
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % interval == 0:
                # OpenCV decodes frames as BGR; PIL expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            frame_idx += 1
        return frames
    finally:
        # Release the capture even if decoding/conversion raises.
        cap.release()


def _compute_target_size(width, height, pixel_limit, patch_size=14):
    """Return a (width, height) whose sides are multiples of ``patch_size``.

    The result approximates the aspect ratio of ``width`` x ``height`` and keeps
    the area at or below ``pixel_limit``. Each side is at least one patch, so the
    limit can be exceeded when ``pixel_limit < patch_size ** 2``.
    """
    scale = math.sqrt(pixel_limit / (width * height)) if width * height > 0 else 1
    w_target, h_target = width * scale, height * scale
    k, m = round(w_target / patch_size), round(h_target / patch_size)
    # Shrink whichever side is currently "too long" relative to the original
    # aspect ratio until the patch grid fits under the pixel budget.
    while (k * patch_size) * (m * patch_size) > pixel_limit:
        if k / m > w_target / h_target:
            k -= 1
        else:
            m -= 1
    return max(1, k) * patch_size, max(1, m) * patch_size


def load_images_as_tensor(path='data/truck', interval=1, PIXEL_LIMIT=255000):
    """
    Load images from a directory or frames from an .mp4 video, resize them to a
    uniform size, and stack them into a single [N, 3, H, W] PyTorch tensor.

    The target size is derived from the FIRST loaded image. It keeps that image's
    aspect ratio as closely as possible, uses side lengths that are multiples of 14,
    and keeps width * height at or below ``PIXEL_LIMIT``. Every image is resized
    (LANCZOS) to exactly that size, so images with another aspect ratio are stretched.

    Args:
        path (str | os.PathLike): A directory containing .png/.jpg/.jpeg files
            (read in sorted filename order), or a path ending in .mp4.
        interval (int): Keep every ``interval``-th image/frame. Must be >= 1.
        PIXEL_LIMIT (int): Upper bound on the target width * height.

    Returns:
        torch.Tensor: Stacked [N, 3, H, W] tensor produced by
        ``transforms.ToTensor``, or ``torch.empty(0)`` if nothing could be
        loaded or processed.

    Raises:
        ValueError: If ``interval`` < 1, or ``path`` is neither a directory nor
            an .mp4 file.
        IOError: If the video file cannot be opened.
    """
    if interval < 1:
        raise ValueError(f"interval must be a positive integer, got {interval!r}")
    # Accept pathlib.Path and other path-like objects, not just str.
    path = os.fspath(path)

    if osp.isdir(path):
        print(f"Loading images from directory: {path}")
        filenames = sorted(x for x in os.listdir(path) if x.lower().endswith(('.png', '.jpg', '.jpeg')))
        sources = []
        for name in filenames[::interval]:
            try:
                # convert() returns an independent copy, so the file can be closed.
                with Image.open(osp.join(path, name)) as img:
                    sources.append(img.convert('RGB'))
            except Exception as e:
                # Best effort: skip unreadable files but report them.
                print(f"Could not load image {name}: {e}")
    elif path.lower().endswith('.mp4'):
        sources = _read_video_frames(path, interval)
    else:
        raise ValueError(f"Unsupported path. Must be a directory or a .mp4 file: {path}")

    if not sources:
        print("No images found or loaded.")
        return torch.empty(0)

    print(f"Found {len(sources)} images/frames. Processing...")

    target_w, target_h = _compute_target_size(*sources[0].size, PIXEL_LIMIT)
    print(f"All images will be resized to a uniform size: ({target_w}, {target_h})")

    to_tensor = transforms.ToTensor()
    tensor_list = []
    for img_pil in sources:
        try:
            resized_img = img_pil.resize((target_w, target_h), Image.Resampling.LANCZOS)
            tensor_list.append(to_tensor(resized_img))
        except Exception as e:
            print(f"Error processing an image: {e}")

    if not tensor_list:
        print("No images were successfully processed.")
        return torch.empty(0)

    return torch.stack(tensor_list, dim=0)
| |
|
| |
|
def tensor_to_pil(tensor):
    """
    Turn a PyTorch tensor (or an already-host-side array) into a PIL image.

    Torch tensors are first detached from the autograd graph, copied to the CPU
    and converted to a NumPy array. Any other input is passed through untouched.
    All shape handling and pixel conversion are delegated to ``array_to_pil``.

    Args:
        tensor (torch.Tensor | np.ndarray): Image data, e.g. [C, H, W], [H, W, C]
            or [H, W].

    Returns:
        PIL.Image: Whatever ``array_to_pil`` produces for the resulting array.
    """
    host_array = tensor.detach().cpu().numpy() if torch.is_tensor(tensor) else tensor
    return array_to_pil(host_array)
| |
|
| |
|
def array_to_pil(array):
    """
    Converts a NumPy array (values expected in [0, 1]) to a PIL image.

    Automatically:
    - Squeezes dimensions of size 1.
    - Moves the first axis to the end when it has size 3 ([3, H, W] -> [H, W, 3]).
    - Clips values to [0, 1] before scaling to 8-bit, so out-of-range values
      saturate instead of wrapping around in the uint8 cast.

    Args:
        array (np.ndarray | array-like): Expected shape after squeezing is
            [3, H, W], [H, W, 3], or [H, W].

    Returns:
        PIL.Image: A grayscale image for 2-D input, an RGB image for 3-channel input.

    Raises:
        ValueError: If the squeezed array is not 2-D or 3-channel 3-D.
    """
    array = np.squeeze(np.asarray(array))

    if array.ndim == 3 and array.shape[0] == 3:
        array = np.transpose(array, (1, 2, 0))

    if not (array.ndim == 2 or (array.ndim == 3 and array.shape[2] == 3)):
        raise ValueError(f"Unsupported array shape for PIL conversion: {array.shape}")

    # Without clipping, e.g. 1.01 * 255 = 257.55 would wrap to 1 in uint8,
    # turning slightly over-bright pixels nearly black (and negatives bright).
    pixels = (np.clip(array, 0.0, 1.0) * 255).astype(np.uint8)

    # For uint8 data Pillow infers "L" for [H, W] and "RGB" for [H, W, 3];
    # the explicit ``mode`` argument is deprecated in recent Pillow releases.
    return Image.fromarray(pixels)
| |
|
| |
|
def rotate_target_dim_to_last_axis(x, target_dim=3):
    """
    Move the last axis whose size equals ``target_dim`` to the final position.

    The relative order of all other axes is preserved. If no axis has size
    ``target_dim``, or the matching axis is already last, ``x`` is returned as-is.

    Args:
        x (np.ndarray | torch.Tensor): Input array or tensor.
        target_dim (int): Axis size to look for (3 for xyz / rgb data).

    Returns:
        Same type as ``x``, possibly with permuted axes.
    """
    shape = tuple(x.shape)
    matches = [axis for axis, size in enumerate(shape) if size == target_dim]
    # Prefer the right-most matching axis; nothing to do if it is already last.
    if not matches or matches[-1] == len(shape) - 1:
        return x

    axis = matches[-1]
    order = [i for i in range(len(shape)) if i != axis] + [axis]

    # torch.Tensor.transpose only swaps two dims, so a full reordering of a
    # tensor needs permute; NumPy's transpose accepts the complete axis order.
    if torch.is_tensor(x):
        return x.permute(*order)
    return x.transpose(*order)
| |
|
| |
|
def _pseudo_colors_from_xyz(xyz):
    """Derive a per-point RGB colour in [0, 1] from (N, 3) coordinates.

    Each axis is min-max normalised, blended into a hue (0.7*x + 0.2*y + 0.1*z),
    and converted from HSV (saturation 0.9, value 0.8) to RGB with the standard
    six-sector formula.
    """
    if xyz.shape[0] == 0:
        return np.zeros((0, 3))

    min_coord = xyz.min(axis=0)
    max_coord = xyz.max(axis=0)
    normalized = (xyz - min_coord) / (max_coord - min_coord + 1e-8)
    # Column vectors of shape (N, 1) so the channels can be hstack-ed per sector.
    hue = 0.7 * normalized[:, 0:1] + 0.2 * normalized[:, 1:2] + 0.1 * normalized[:, 2:3]

    saturation, value = 0.9, 0.8
    chroma = np.full_like(hue, value * saturation)
    second = chroma * (1 - np.abs((hue * 6) % 2 - 1))
    zeros = np.zeros_like(hue)
    sector_pos = (hue * 6 % 6)[:, 0]

    # (R, G, B) before the value offset, for each 60-degree hue sector.
    sector_channels = (
        (chroma, second, zeros),
        (second, chroma, zeros),
        (zeros, chroma, second),
        (zeros, second, chroma),
        (second, zeros, chroma),
        (chroma, zeros, second),
    )
    rgb = np.zeros((xyz.shape[0], 3))
    for sector, channels in enumerate(sector_channels):
        mask = (sector <= sector_pos) & (sector_pos < sector + 1)
        rgb[mask] = np.hstack([ch[mask] for ch in channels])
    return rgb + (value - chroma)


def write_ply(
    xyz,
    rgb=None,
    path='output.ply',
) -> None:
    """
    Write a point cloud to a PLY file (via plyfile) with x/y/z, zero normals and
    8-bit RGB colours per vertex.

    Args:
        xyz (np.ndarray | torch.Tensor): Point coordinates. An axis of size 3 is
            moved last (see ``rotate_target_dim_to_last_axis``) and the data is
            reshaped to (N, 3).
        rgb (np.ndarray | torch.Tensor | None): Optional colours with the same
            layout handling. If any value exceeds 1 the data is treated as 0-255
            and divided by 255; otherwise it is treated as [0, 1]. When None,
            colours are derived from the coordinates.
        path (str): Output file path.

    Raises:
        ValueError: If ``rgb`` does not yield the same number of points as ``xyz``.
    """
    if torch.is_tensor(xyz):
        xyz = xyz.detach().cpu().numpy()
    if torch.is_tensor(rgb):
        rgb = rgb.detach().cpu().numpy()

    xyz = rotate_target_dim_to_last_axis(np.asarray(xyz), 3).reshape(-1, 3)
    num_points = xyz.shape[0]

    if rgb is None:
        rgb = _pseudo_colors_from_xyz(xyz)
    else:
        rgb = np.asarray(rgb)
        # size check avoids max() on an empty array raising.
        if rgb.size and rgb.max() > 1:
            rgb = rgb / 255.
        rgb = rotate_target_dim_to_last_axis(rgb, 3).reshape(-1, 3)
        if rgb.shape[0] != num_points:
            raise ValueError(
                f"rgb has {rgb.shape[0]} points but xyz has {num_points}"
            )

    # Round (not truncate) so v/255*255 maps back to v, and clip so
    # out-of-range values saturate instead of wrapping in the uint8 cast.
    colors = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.uint8)

    dtype = [
        ("x", "f4"),
        ("y", "f4"),
        ("z", "f4"),
        ("nx", "f4"),
        ("ny", "f4"),
        ("nz", "f4"),
        ("red", "u1"),
        ("green", "u1"),
        ("blue", "u1"),
    ]
    elements = np.empty(num_points, dtype=dtype)
    # Vectorized field assignment instead of building one Python tuple per point.
    for i, axis in enumerate("xyz"):
        elements[axis] = xyz[:, i]
        elements["n" + axis] = 0.0  # no normals are available; written as zeros
    for i, channel in enumerate(("red", "green", "blue")):
        elements[channel] = colors[:, i]

    vertex_element = PlyElement.describe(elements, "vertex")
    PlyData([vertex_element]).write(path)