| import torch |
| import re |
| import numpy as np |
| import torch |
| import cv2 |
| import os |
| import math |
| from copy import deepcopy |
| from typing import Tuple |
| import pandas as pd |
| import io |
| from pydub import AudioSegment |
| import librosa |
| from PIL import Image |
|
|
|
|
def preprocess_image_tensor(image_path, device, target_dtype, h_w_multiple_of=32, resize_total_area=720*720):
    """Preprocess image data into a standardized tensor and (optionally) resize to a target area.

    Args:
        image_path: One of:
            - str: path to an image file readable by OpenCV,
            - np.ndarray: HWC image array (assumed already RGB — TODO confirm with callers),
            - PIL.Image.Image: converted to RGB if needed.
        device: torch device for the output tensor.
        target_dtype: torch dtype for the output tensor.
        h_w_multiple_of (int): both output dimensions are forced to a multiple of this.
        resize_total_area: target pixel area; an int, a (W, H) pair, or a "WxH" string.
            None disables area-based resizing (dims are still floored to the multiple).

    Returns:
        torch.Tensor: shape (1, C, H, W), values in [-1, 1].

    Raises:
        FileNotFoundError: if a path is given and the image cannot be read.
        ValueError: if resize_total_area cannot be parsed.
    """
    def _parse_area(val):
        # Accept None / number / (W, H) pair / "WxH", "W*H", "W H" string.
        if val is None:
            return None
        if isinstance(val, (int, float)):
            return int(val)
        if isinstance(val, (tuple, list)) and len(val) == 2:
            return int(val[0]) * int(val[1])
        if isinstance(val, str):
            m = re.match(r"\s*(\d+)\s*[x\*\s]\s*(\d+)\s*$", val, flags=re.IGNORECASE)
            if m:
                return int(m.group(1)) * int(m.group(2))
            if val.strip().isdigit():
                return int(val.strip())
        raise ValueError(f"resize_total_area={val!r} is not a valid area or WxH.")

    def _best_hw_for_area(h, w, area_target, multiple):
        # Search a small neighborhood of candidate (H, W) pairs that are
        # multiples of `multiple`, scoring first by closeness to the target
        # area and then by aspect-ratio preservation.
        if area_target <= 0:
            return h, w
        ratio_wh = w / float(h)
        area_unit = multiple * multiple
        tgt_units = max(1, area_target // area_unit)
        p0 = max(1, int(round(np.sqrt(tgt_units / max(ratio_wh, 1e-8)))))
        candidates = []
        for dp in range(-3, 4):
            p = max(1, p0 + dp)
            q = max(1, int(round(p * ratio_wh)))
            candidates.append((p * multiple, q * multiple))
        # Also include the straightforward uniform-scale candidate.
        scale = np.sqrt(area_target / (h * float(w)))
        H_sc = max(multiple, int(round(h * scale / multiple)) * multiple)
        W_sc = max(multiple, int(round(w * scale / multiple)) * multiple)
        candidates.append((H_sc, W_sc))

        def score(HW):
            H, W = HW
            return (abs(H * W - area_target), abs((W / max(H, 1e-8)) - ratio_wh))

        return min(candidates, key=score)

    # --- Load the image as an RGB HWC array ---
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (instead of raising) on a missing or
            # unreadable file; fail loudly rather than crash in cvtColor.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    elif isinstance(image_path, np.ndarray):
        image = image_path
    else:
        assert isinstance(image_path, Image.Image)
        if image_path.mode != "RGB":
            image_path = image_path.convert("RGB")
        image = np.array(image_path)

    # HWC -> CHW, scale to [0, 1].
    image = image.transpose(2, 0, 1)
    image = image.astype(np.float32) / 255.0

    # To tensor, add batch dim, rescale to [-1, 1].
    image_tensor = torch.from_numpy(image).float().to(device, dtype=target_dtype).unsqueeze(0)
    image_tensor = image_tensor * 2.0 - 1.0

    _, c, h, w = image_tensor.shape
    area_target = _parse_area(resize_total_area)
    if area_target is not None:
        target_h, target_w = _best_hw_for_area(h, w, area_target, h_w_multiple_of)
    else:
        # No area target: just floor each dim down to the required multiple.
        target_h = (h // h_w_multiple_of) * h_w_multiple_of
        target_w = (w // h_w_multiple_of) * h_w_multiple_of

    # Never collapse a dimension below one grid unit.
    target_h = max(h_w_multiple_of, int(target_h))
    target_w = max(h_w_multiple_of, int(target_w))

    if (h != target_h) or (w != target_w):
        image_tensor = torch.nn.functional.interpolate(
            image_tensor,
            size=(target_h, target_w),
            mode='bicubic',
            align_corners=False
        )

    return image_tensor
|
|
|
|
def preprocess_audio_tensor(audio_path, device, target_dtype, sr=16000, clip_len=None, normalize=True):
    """Preprocess audio data into standardized tensor format.

    Args:
        audio_path: path (or file-like object) accepted by librosa.load.
        device: torch device for the output tensor.
        target_dtype: torch dtype for the output tensor.
        sr (int): target sample rate; librosa resamples to this.
        clip_len (int, optional): if given, keep only the first clip_len samples.
        normalize (bool): peak-normalize to ~0.95 max amplitude.

    Returns:
        torch.Tensor: shape (1, num_samples) on `device` with dtype `target_dtype`.
    """
    wave_data, sample_rate = librosa.load(audio_path, sr=sr, mono=True)
    assert (sample_rate == sr)
    if normalize:
        # +1e-6 guards against division by zero on silent audio.
        wave_data = wave_data / (np.max(np.abs(wave_data)) + 1e-6) * 0.95
    if clip_len is not None:
        # .copy() detaches the slice from the full buffer so the remainder
        # of the waveform can be freed (replaces an unnecessary deepcopy).
        wave_data = wave_data[:clip_len].copy()
    # Bug fix: target_dtype was previously accepted but never applied.
    audio_tensor = torch.from_numpy(wave_data).float().squeeze().unsqueeze(0).to(device, dtype=target_dtype)

    return audio_tensor
|
|
|
|
def calc_dims_from_area(
    aspect_ratio: str,
    total_area: int = 720*720,
    divisible_by: int = 32
) -> Tuple[int, int]:
    """
    Calculate height and width given an aspect ratio (h:w), total area,
    and divisibility constraint.

    Args:
        aspect_ratio (str): Aspect ratio string in format "h:w" (e.g., "9:16").
        total_area (int): Target maximum area (width * height <= total_area).
        divisible_by (int): Force width and height to be divisible by this value.

    Returns:
        (height, width): Tuple of integers that satisfy the constraints.
            NOTE: the order is (height, width), matching the h:w ratio order.
    """
    h_ratio, w_ratio = map(int, aspect_ratio.split(":"))

    # Reduce the ratio to lowest terms so the scale factor is well conditioned.
    gcd = math.gcd(h_ratio, w_ratio)
    h_ratio //= gcd
    w_ratio //= gcd

    # k is the scale such that (k*h_ratio) * (k*w_ratio) == total_area.
    k = math.sqrt(total_area / (h_ratio * w_ratio))

    # Floor each dimension to the divisibility grid.
    height = (int(k * h_ratio) // divisible_by) * divisible_by
    width = (int(k * w_ratio) // divisible_by) * divisible_by

    # Never collapse a dimension to zero.
    height = max(height, divisible_by)
    width = max(width, divisible_by)

    return height, width
|
|
|
|
def snap_hw_to_multiple_of_32(h: int, w: int, area = 720 * 720) -> tuple[int, int]:
    """
    Rescale (h, w) toward a target area (when one is given), then snap each
    dimension to the closest multiple of 32, never going below 32.

    Args:
        h (int): original height
        w (int): original width
        area (int, optional): target area to scale to; None skips scaling.

    Returns:
        (new_h, new_w): adjusted dimensions
    """
    if h <= 0 or w <= 0:
        raise ValueError(f"h and w must be positive, got {(h, w)}")

    if area is not None and area > 0:
        # Uniform scale so h*w lands on the requested area.
        factor = math.sqrt(area / float(h * w))
        h = int(round(h * factor))
        w = int(round(w * factor))

    snap = lambda v: max(32, int(round(v / 32)) * 32)
    return snap(h), snap(w)
|
|
|
|
def scale_hw_to_area_divisible(h, w, area=1024*1024, n=16):
    """
    Scale (h, w) so that h*w ≈ area while keeping the aspect ratio,
    then round both dimensions so they are divisible by n.

    Args:
        h (int): original height
        w (int): original width
        area (int or float): target area
        n (int): divisibility requirement

    Returns:
        (new_h, new_w): scaled and adjusted dimensions

    Raises:
        ValueError: if h or w is zero.
    """
    current_area = h * w
    if current_area == 0:
        raise ValueError("Height and width must be positive")

    # Uniform scale factor preserving the aspect ratio.
    scale = math.sqrt(area / current_area)
    new_h = h * scale
    new_w = w * scale

    # Round to the nearest multiple of n, but never below n itself.
    new_h = max(n, int(round(new_h / n) * n))
    new_w = max(n, int(round(new_w / n) * n))

    return new_h, new_w
|
|
|
|
def validate_and_process_user_prompt(text_prompt: str, image_path: str = None, ip_image_path: str = None, ip_audio_path: str = None, mode: str = "id2v") -> Tuple[list, list, list, list]:
    """
    Validate user input and expand it into parallel per-sample lists.

    `text_prompt` is either a literal prompt string or a path to a .csv/.tsv
    batch file with a required "text_prompt" column and optional
    "image_path", "ip_image_path" and "ip_audio_path" columns.

    Args:
        text_prompt (str): prompt text, or path to a .csv/.tsv batch file.
        image_path (str, optional): conditioning image path (single-prompt case).
        ip_image_path (str, optional): reference image path (single-prompt case).
        ip_audio_path (str, optional): reference audio path (single-prompt case).
        mode (str): generation mode; optional columns are only consumed for
            "i2v" (image_path) and "id2v" (ip_image_path / ip_audio_path).

    Returns:
        (text_prompts, image_paths, ip_image_paths, ip_audio_paths): four
        equal-length lists; missing entries are None.

    Raises:
        ValueError: if text_prompt is not a string, or the batch file has an
            unsupported extension.
        AssertionError: if a required column is missing or a listed path
            does not exist.
    """
    if not isinstance(text_prompt, str):
        raise ValueError("User input must be a string")

    text_prompt = text_prompt.strip()

    if os.path.isfile(text_prompt):
        # Batch mode: the "prompt" is actually a table of prompts.
        _, ext = os.path.splitext(text_prompt.lower())
        if ext == ".csv":
            df = pd.read_csv(text_prompt)
            df = df.fillna("")
        elif ext == ".tsv":
            df = pd.read_csv(text_prompt, sep="\t")
            df = df.fillna("")
        else:
            raise ValueError(f"Unsupported file type: {ext}. Only .csv and .tsv are allowed.")

        assert "text_prompt" in df.keys(), f"Missing required columns in TSV file."
        text_prompts = list(df["text_prompt"])

        if mode == "i2v" and 'image_path' in df.keys():
            image_paths = list(df["image_path"])
            assert all(p is None or len(p) == 0 or os.path.isfile(p) for p in image_paths), "One or more image paths in the TSV file do not exist."
        else:
            print("Warning: image_path was not found, assuming t2v or t2i2v mode...")
            image_paths = [None] * len(text_prompts)

        if mode == "id2v" and 'ip_image_path' in df.keys():
            ip_image_paths = list(df["ip_image_path"])
            print(f"ip images: {ip_image_paths}")
            assert all(p is None or len(p) == 0 or os.path.isfile(p) for p in ip_image_paths), "One or more ip image paths in the TSV file do not exist."
        else:
            print("Warning: ip_image_path was not found, assuming i2v or t2v or t2i2v mode...")
            ip_image_paths = [None] * len(text_prompts)

        if mode == "id2v" and 'ip_audio_path' in df.keys():
            ip_audio_paths = list(df["ip_audio_path"])
            assert all(p is None or len(p) == 0 or os.path.isfile(p) for p in ip_audio_paths), "One or more ip audio paths in the TSV file do not exist."
        else:
            print("Warning: ip_audio_path was not found, assuming i2v or t2v or t2i2v mode...")
            ip_audio_paths = [None] * len(text_prompts)
    else:
        # Single-prompt mode: wrap everything into length-1 lists.
        assert image_path is None or os.path.isfile(image_path), f"Image path is not None but {image_path} does not exist."
        assert ip_image_path is None or os.path.isfile(ip_image_path), f"IP image path is not None but {ip_image_path} does not exist."
        assert ip_audio_path is None or os.path.isfile(ip_audio_path), f"IP audio path is not None but {ip_audio_path} does not exist."
        text_prompts = [text_prompt]
        image_paths = [image_path]
        ip_image_paths = [ip_image_path]
        ip_audio_paths = [ip_audio_path]

    return text_prompts, image_paths, ip_image_paths, ip_audio_paths
|
|
|
|
def format_prompt_for_filename(text: str) -> str:
    """Turn a prompt into a short, filesystem-safe name.

    Strips anything inside angle-bracket tags, replaces spaces and
    slashes with underscores, and truncates to 50 characters.
    """
    stripped = re.sub(r"<.*?>", "", text)
    sanitized = stripped.replace(" ", "_").replace("/", "_")
    return sanitized[:50]
|
|
|
|
def audio_bytes_to_tensor(audio_bytes, target_sr=16000, audio_format="wav"):
    """
    Convert encoded audio bytes to a mono torch tensor in [-1, 1] at target_sr.

    Args:
        audio_bytes (bytes): Raw encoded audio bytes.
        target_sr (int): Target sample rate.
        audio_format (str): Container/codec hint passed to pydub
            (generalized from the previously hard-coded "wav").

    Returns:
        torch.Tensor: shape (num_samples,), float32 in [-1, 1]
        int: sample rate (== target_sr)
    """
    audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format=audio_format)

    # Downmix to mono before resampling.
    if audio.channels != 1:
        audio = audio.set_channels(1)

    if audio.frame_rate != target_sr:
        audio = audio.set_frame_rate(target_sr)

    raw = np.array(audio.get_array_of_samples())
    # Normalize by the max of the ORIGINAL integer sample dtype (pydub
    # yields integer PCM), then convert to float32 in [-1, 1]. Reading the
    # dtype before astype() makes the previous implicit ordering explicit.
    samples = raw.astype(np.float32) / np.iinfo(raw.dtype).max

    tensor = torch.from_numpy(samples)

    return tensor, target_sr
|
|
|
|
def audio_path_to_tensor(path, target_sr=16000):
    """Read an audio file from disk and convert it via audio_bytes_to_tensor."""
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    return audio_bytes_to_tensor(raw_bytes, target_sr=target_sr)
|
|
|
|
def clean_text(text: str) -> str:
    """
    Strip every <S>...<E> span and every <AUDCAP>...<ENDAUDCAP> span
    (tags included) from *text*, then trim surrounding whitespace.
    """
    for pattern in (r"<S>.*?<E>", r"<AUDCAP>.*?<ENDAUDCAP>"):
        text = re.sub(pattern, "", text, flags=re.DOTALL)
    return text.strip()