Spaces:
Sleeping
Sleeping
| import torch | |
| import math | |
| import random | |
| import numpy as np | |
| from PIL import Image | |
| def random_insert_latent_frame( | |
| image_latent: torch.Tensor, | |
| noisy_model_input: torch.Tensor, | |
| target_latents: torch.Tensor, | |
| input_intervals: torch.Tensor, | |
| output_intervals: torch.Tensor, | |
| special_info | |
| ): | |
| """ | |
| Inserts latent frames into noisy input, pads targets, and builds flattened intervals with flags. | |
| Args: | |
| image_latent: [B, latent_count, C, H, W] | |
| noisy_model_input:[B, F, C, H, W] | |
| target_latents: [B, F, C, H, W] | |
| input_intervals: [B, N, frames_per_latent, L] | |
| output_intervals: [B, M, frames_per_latent, L] | |
| For each sample randomly choose: | |
| Mode A (50%): | |
| - Insert two image_latent frames at start of noisy input and targets. | |
| - Pad target_latents by prepending two zero-frames. | |
| - Pad input_intervals by repeating its last group once. | |
| Mode B (50%): | |
| - Insert one image_latent frame at start and repeat last noisy frame at end. | |
| - Pad target_latents by prepending one one-frame and appending last target frame. | |
| - Pad output_intervals by repeating its last group once. | |
| After padding intervals, flatten each group from [frames_per_latent, L] to [frames_per_latent * L], | |
| then append a 4-element flag (1 for input groups, 0 for output groups). | |
| Returns: | |
| outputs: Tensor [B, F+2, C, H, W] | |
| new_targets: Tensor [B, F+2, C, H, W] | |
| masks: Tensor [B, F+2] bool mask of latent inserts | |
| intervals: Tensor [B, N+M+1, fpl * L + 4] | |
| """ | |
| B, F, C, H, W = noisy_model_input.shape | |
| _, N, fpl, L = input_intervals.shape | |
| _, M, _, _ = output_intervals.shape | |
| device = noisy_model_input.device | |
| new_F = F + 1 if special_info == "just_one" else F + 2 | |
| outputs = torch.empty((B, new_F, C, H, W), device=device) | |
| masks = torch.zeros((B, new_F), dtype=torch.bool, device=device) | |
| combined_groups = N + M #+ 1 | |
| feature_len = fpl * L | |
| # intervals = torch.empty((B, combined_groups, feature_len + 4), device=device, | |
| # dtype=input_intervals.dtype) | |
| intervals = torch.empty((B, combined_groups, feature_len), device=device, | |
| dtype=input_intervals.dtype) | |
| new_targets = torch.empty((B, new_F, C, H, W), device=device, | |
| dtype=target_latents.dtype) | |
| for b in range(B): | |
| latent = image_latent[b, 0] | |
| frames = noisy_model_input[b] | |
| tgt = target_latents[b] | |
| limit = 10 if special_info == "use_a" else 0.5 | |
| if special_info == "just_one": #ALWAYS_MODE_A | |
| # Mode A: two latent inserts, zero-prefixed targets | |
| outputs[b, 0] = latent | |
| masks[b, :1] = True | |
| outputs[b, 1:] = frames | |
| # pad targets: two large-numbers - these should be ignored | |
| large_number = torch.ones_like(tgt[0])*10000 | |
| new_targets[b, 0] = large_number | |
| new_targets[b, 1:] = tgt | |
| # pad intervals: input + replicated last input group | |
| #pad_group = input_intervals[b, -1:].clone() | |
| in_groups = input_intervals[b] #torch.cat([input_intervals[b], pad_group], dim=0) | |
| out_groups = output_intervals[b] | |
| elif random.random() < limit: #ALWAYS_MODE_A | |
| # Mode A: two latent inserts, zero-prefixed targets | |
| outputs[b, 0] = latent | |
| outputs[b, 1] = latent | |
| masks[b, :2] = True | |
| outputs[b, 2:] = frames | |
| # pad targets: two large-numbers - these should be ignored | |
| large_number = torch.ones_like(tgt[0])*10000 | |
| new_targets[b, 0] = large_number | |
| new_targets[b, 1] = large_number | |
| new_targets[b, 2:] = tgt | |
| # pad intervals: input + replicated last input group | |
| pad_group = input_intervals[b, -1:].clone() | |
| in_groups = torch.cat([input_intervals[b], pad_group], dim=0) | |
| out_groups = output_intervals[b] | |
| else: | |
| # Mode B: one latent insert & last-frame repeat, one-prefixed/appended targets | |
| outputs[b, 0] = latent | |
| masks[b, 0] = True | |
| outputs[b, 1:new_F-1] = frames | |
| outputs[b, new_F-1] = frames[-1] | |
| # pad targets: one one-frame then original then last frame | |
| zero = torch.zeros_like(tgt[0]) | |
| new_targets[b, 0] = zero | |
| new_targets[b, 1:new_F-1] = tgt | |
| new_targets[b, new_F-1] = tgt[-1] | |
| # pad intervals: output + replicated last output group | |
| in_groups = input_intervals[b] | |
| pad_group = output_intervals[b, -1:].clone() | |
| out_groups = torch.cat([output_intervals[b], pad_group], dim=0) | |
| # flatten & flag groups | |
| flat_in = in_groups.reshape(-1, feature_len) | |
| proc_in = torch.cat([flat_in], dim=1) | |
| flat_out = out_groups.reshape(-1, feature_len) | |
| proc_out = torch.cat([flat_out], dim=1) | |
| intervals[b] = torch.cat([proc_in, proc_out], dim=0) | |
| return outputs, new_targets, masks, intervals | |
| def transform_intervals( | |
| intervals: torch.Tensor, | |
| frames_per_latent: int = 4, | |
| repeat_first: bool = True | |
| ) -> torch.Tensor: | |
| """ | |
| Pad and reshape intervals into [B, num_latent_frames, frames_per_latent, L]. | |
| Args: | |
| intervals: Tensor of shape [B, N, L] | |
| frames_per_latent: number of frames per latent group (e.g., 4) | |
| repeat_first: if True, pad at the beginning by repeating the first row; otherwise pad at the end by repeating the last row. | |
| Returns: | |
| Tensor of shape [B, num_latent_frames, frames_per_latent, L] | |
| """ | |
| B, N, L = intervals.shape | |
| num_latent = math.ceil(N / frames_per_latent) | |
| target_N = num_latent * frames_per_latent | |
| pad_count = target_N - N | |
| if pad_count > 0: | |
| # choose row to repeat | |
| pad_row = intervals[:, :1, :] if repeat_first else intervals[:, -1:, :] | |
| # replicate pad_row pad_count times | |
| pad = pad_row.repeat(1, pad_count, 1) | |
| # pad at beginning or end | |
| if repeat_first: | |
| expanded = torch.cat([pad, intervals], dim=1) | |
| else: | |
| expanded = torch.cat([intervals, pad], dim=1) | |
| else: | |
| expanded = intervals[:, :target_N, :] | |
| # reshape into latent-frame groups | |
| return expanded.view(B, num_latent, frames_per_latent, L) | |
| import random | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import random | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| def build_blur(frame_paths, gamma=2.2): | |
| """ | |
| Simulate motion blur using inverse-gamma (linear-light) summation: | |
| - Load each image, convert to float32 sRGB [0,255] | |
| - Linearize via inverse gamma: linear = (img/255)^gamma | |
| - Sum linear values, average, then re-encode via gamma: (linear_avg)^(1/gamma)*255 | |
| Returns a uint8 numpy array. | |
| """ | |
| acc_lin = None | |
| for p in frame_paths: | |
| img = np.array(Image.open(p).convert('RGB'), dtype=np.float32) | |
| # normalize to [0,1] then linearize | |
| lin = np.power(img / 255.0, gamma) | |
| acc_lin = lin if acc_lin is None else acc_lin + lin | |
| # average in linear domain | |
| avg_lin = acc_lin / len(frame_paths) | |
| # gamma-encode back to sRGB domain | |
| srgb = np.power(avg_lin, 1.0 / gamma) * 255.0 | |
| return np.clip(srgb, 0, 255).astype(np.uint8) | |
| def generate_1x_sequence(frame_paths, window_max =16, output_len=17, base_rate=1, start = None): | |
| """ | |
| 1× mode at arbitrary base_rate (units of 1/240s): | |
| - Treat each output step as the sum of `base_rate` consecutive raw frames. | |
| - Pick window size W ∈ [1, output_len] | |
| - Randomly choose start index so W*base_rate frames fit | |
| - Group raw frames into W groups of length base_rate | |
| - Build blur image over all W*base_rate frames for input | |
| - For each group, build a blurred output frame by summing its base_rate frames | |
| - Pad sequence of W blurred frames to output_len by repeating last blurred frame | |
| - Input interval always [-0.5, 0.5] | |
| - Output intervals reflect each group’s coverage within [-0.5,0.5] | |
| """ | |
| N = len(frame_paths) | |
| max_w = min(output_len, N // base_rate) | |
| max_w = min(max_w, window_max) | |
| W = random.randint(1, max_w) | |
| if start is not None: | |
| # choose start so that W*base_rate frames fit | |
| assert N >= W * base_rate, f"Not enough frames for base_rate={base_rate}, need {W * base_rate}, got {N}" | |
| else: | |
| start = random.randint(0, N - W * base_rate) | |
| # group start indices | |
| group_starts = [start + i * base_rate for i in range(W)] | |
| # flatten raw frame paths for blur input | |
| blur_paths = [] | |
| for gs in group_starts: | |
| blur_paths.extend(frame_paths[gs:gs + base_rate]) | |
| blur_img = build_blur(blur_paths) | |
| # build blurred output frames per group | |
| seq = [] | |
| for gs in group_starts: | |
| group = frame_paths[gs:gs + base_rate] | |
| seq.append(build_blur(group)) | |
| # pad with last blurred frame | |
| seq += [seq[-1]] * (output_len - len(seq)) | |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) | |
| # each group covers interval of length 1/W | |
| step = 1.0 / W | |
| intervals = [[-0.5 + i * step, -0.5 + (i + 1) * step] for i in range(W)] | |
| num_frames = len(intervals) | |
| intervals += [intervals[-1]] * (output_len - W) | |
| output_intervals = torch.tensor(intervals, dtype=torch.float) | |
| return blur_img, seq, input_interval, output_intervals, num_frames | |
| def generate_2x_sequence(frame_paths, window_max =16, output_len=17, base_rate=1): | |
| """ | |
| 2× mode: | |
| - Logical window of W output-steps so that 2*W ≤ output_len | |
| - Raw window spans W*base_rate frames | |
| - Build blur only over that raw window (flattened) for input | |
| - before_count = W//2, after_count = W - before_count | |
| - Define groups for before, during, and after each of length base_rate | |
| - Build blurred frames for each group | |
| - Pad sequence of 2*W blurred frames to output_len by repeating last | |
| - Input interval always [-0.5,0.5] | |
| - Output intervals relative to window: each group’s center | |
| """ | |
| N = len(frame_paths) | |
| max_w = min(output_len // 2, N // base_rate) | |
| max_w = min(max_w, window_max) | |
| W = random.randint(1, max_w) | |
| before_count = W // 2 | |
| after_count = W - before_count | |
| # choose start so that before and after stay within bounds | |
| min_start = before_count * base_rate | |
| max_start = N - (W + after_count) * base_rate | |
| # ensure we can pick a valid start, else fail | |
| assert max_start >= min_start, f"Cannot satisfy before/after window for W={W}, base_rate={base_rate}, N={N}" | |
| start = random.randint(min_start, max_start) | |
| # window group starts | |
| window_starts = [start + i * base_rate for i in range(W)] | |
| # flatten for blur input | |
| blur_paths = [] | |
| for gs in window_starts: | |
| blur_paths.extend(frame_paths[gs:gs + base_rate]) | |
| blur_img = build_blur(blur_paths) | |
| # define before/after group starts | |
| before_count = W // 2 | |
| after_count = W - before_count | |
| before_starts = [max(0, start - (i + 1) * base_rate) for i in range(before_count)][::-1] | |
| after_starts = [min(N - base_rate, start + W * base_rate + i * base_rate) for i in range(after_count)] | |
| # all group starts in sequence | |
| group_starts = before_starts + window_starts + after_starts | |
| # build blurred frames per group | |
| seq = [] | |
| for gs in group_starts: | |
| group = frame_paths[gs:gs + base_rate] | |
| seq.append(build_blur(group)) | |
| # pad blurred frames to output_len | |
| seq += [seq[-1]] * (output_len - len(seq)) | |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) | |
| # each group covers 1/(2W) around its center within [-0.5,0.5] | |
| half = 0.5 / W | |
| centers = [((gs - start) / (W * base_rate)) - 0.5 + half | |
| for gs in group_starts] | |
| intervals = [[c - half, c + half] for c in centers] | |
| num_frames = len(intervals) | |
| intervals += [intervals[-1]] * (output_len - len(intervals)) | |
| output_intervals = torch.tensor(intervals, dtype=torch.float) | |
| return blur_img, seq, input_interval, output_intervals, num_frames | |
| def generate_large_blur_sequence(frame_paths, window_max=16, output_len=17, base_rate=1): | |
| """ | |
| Large blur mode (fixed output_len=25) with instantaneous outputs: | |
| - Raw window spans 25 * base_rate consecutive frames | |
| - Build blur over that full raw window for input | |
| - For output sequence: | |
| • Pick 1 raw frame every `base_rate` (group_starts) | |
| • Each output frame is the instantaneous frame at that raw index | |
| - Input interval always [-0.5, 0.5] | |
| - Output intervals reflect each 1-frame slice’s coverage within the blur window, | |
| leaving gaps between. | |
| """ | |
| N = len(frame_paths) | |
| total_raw = window_max * base_rate | |
| assert N >= total_raw, f"Not enough frames for base_rate={base_rate}, need {total_raw}, got {N}" | |
| start = random.randint(0, N - total_raw) | |
| # build blur input over the full raw block | |
| raw_block = frame_paths[start:start + total_raw] | |
| blur_img = build_blur(raw_block) | |
| # output sequence: instantaneous frames at each group_start | |
| seq = [] | |
| group_starts = [start + i * base_rate for i in range(window_max)] | |
| for gs in group_starts: | |
| img = np.array(Image.open(frame_paths[gs]).convert('RGB'), dtype=np.uint8) | |
| seq.append(img) | |
| # pad blurred frames to output_len | |
| seq += [seq[-1]] * (output_len - len(seq)) | |
| # compute intervals for each instantaneous frame: | |
| # each covers [gs, gs+1) over total_raw, normalized to [-0.5, 0.5] | |
| intervals = [] | |
| for gs in group_starts: | |
| t0 = (gs - start) / total_raw - 0.5 | |
| t1 = (gs + 1 - start) / total_raw - 0.5 | |
| intervals.append([t0, t1]) | |
| num_frames = len(intervals) | |
| intervals += [intervals[-1]] * (output_len - len(intervals)) | |
| output_intervals = torch.tensor(intervals, dtype=torch.float) | |
| # input interval | |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) | |
| return blur_img, seq, input_interval, output_intervals, num_frames | |
| def generate_test_case(frame_paths, | |
| window_max=16, | |
| output_len=17, | |
| in_start=None, | |
| in_end=None, | |
| out_start=None, | |
| out_end = None, | |
| center=None, | |
| mode="1x", | |
| fps=240): | |
| """ | |
| Generate blurred input + a target sequence + normalized intervals. | |
| Args: | |
| frame_paths: list of all frame filepaths | |
| window_max: number of groups/bins W | |
| output_len: desired length of the output sequence | |
| in_start, in_end: integer indices defining the raw window [in_start, in_end) | |
| mode: one of "1x", "2x", or "lb" | |
| fps: frames-per-second (only used to override mode=="2x" if fps==120) | |
| Returns: | |
| blur_img: np.ndarray of the global blur over the window | |
| seq: list of np.ndarray, length = output_len (blured groups or raw frames) | |
| input_interval: torch.Tensor [[-0.5, 0.5]] | |
| output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5] | |
| """ | |
| # 1) slice and blur | |
| raw_paths = frame_paths[in_start:in_end] | |
| blur_img = build_blur(raw_paths) | |
| # 2) build the sequence | |
| # one target per frame | |
| seq = [ | |
| np.array(Image.open(p).convert("RGB"), dtype=np.uint8) | |
| for p in frame_paths[out_start:out_end] | |
| ] | |
| # 3) compute normalized intervals | |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) | |
| # 2) define the normalizer | |
| def normalize(x, in_start, in_end): | |
| return (x - in_start) / (in_end - in_start) - 0.5 | |
| base_rate = 240 // fps | |
| # 3) define the raw intervals in absolute frame‐indices | |
| base_rate = 240 // fps | |
| if mode == "1x": | |
| assert in_start == out_start and in_end == out_end | |
| #assert fps == 240, "haven't implemented 120fps in 1x yet" | |
| W = (out_end - out_start) // base_rate | |
| # one frame per window | |
| group_starts = [out_start + i * base_rate for i in range(W)] | |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] | |
| elif mode == "2x": | |
| W = (out_end - out_start) // base_rate | |
| # every base_rate frames, starting at out_start | |
| group_starts = [out_start + i * base_rate for i in range(W)] | |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] | |
| elif mode == "lb": | |
| W = (out_end - out_start) // base_rate | |
| # sparse “key‐frame” windows from the raw input range | |
| group_starts = [in_start + i * base_rate for i in range(W)] | |
| group_ends = [s + 1 for s in group_starts] | |
| else: | |
| raise ValueError(f"Unsupported mode: {mode}") | |
| # --- after mode‐switch, once you have raw group_starts & group_ends --- | |
| # 4) build a summed video sequence by blurring each interval | |
| summed_seq = [] | |
| for s, e in zip(group_starts, group_ends): | |
| # make sure indices lie in [in_start, in_end) | |
| s_clamped = max(in_start, min(s, in_end-1)) | |
| e_clamped = max(s_clamped+1, min(e, in_end)) | |
| # sum/blur the frames in [s_clamped:e_clamped) | |
| summed = build_blur(frame_paths[s_clamped:e_clamped]) | |
| summed_seq.append(summed) | |
| # pad to output_len | |
| if len(summed_seq) < output_len: | |
| summed_seq += [summed_seq[-1]] * (output_len - len(summed_seq)) | |
| # 5) now normalize your intervals as before | |
| def normalize(x): | |
| return (x - in_start) / (in_end - in_start) - 0.5 | |
| intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)] | |
| num_frames = len(intervals) | |
| if len(intervals) < output_len: | |
| intervals += [intervals[-1]] * (output_len - len(intervals)) | |
| output_intervals = torch.tensor(intervals, dtype=torch.float) | |
| # final return now also includes summed_seq | |
| return blur_img, summed_seq, input_interval, output_intervals, seq, num_frames | |
| def get_conditioning( | |
| output_len=17, | |
| in_start=None, | |
| in_end=None, | |
| out_start=None, | |
| out_end=None, | |
| mode="1x", | |
| fps=240, | |
| ): | |
| """ | |
| Generate normalized intervals conditioning singals. Just like the above function but without | |
| loading any images (for inference only). | |
| Args: | |
| output_len: desired length of the output sequence | |
| in_start, in_end: integer indices defining the raw window [in_start, in_end) | |
| mode: one of "1x", "2x", or "lb" | |
| fps: frames-per-second (only used to override mode=="2x" if fps==120) | |
| Returns: | |
| input_interval: torch.Tensor [[-0.5, 0.5]] | |
| output_intervals: torch.Tensor shape [output_len, 2], normalized in [-0.5,0.5] | |
| """ | |
| # 3) compute normalized intervals | |
| input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float) | |
| # 2) define the normalizer | |
| def normalize(x, in_start, in_end): | |
| return (x - in_start) / (in_end - in_start) - 0.5 | |
| base_rate = 240 // fps | |
| # 3) define the raw intervals in absolute frame‐indices | |
| base_rate = 240 // fps | |
| if mode == "1x": | |
| assert in_start == out_start and in_end == out_end | |
| #assert fps == 240, "haven't implemented 120fps in 1x yet" | |
| W = (out_end - out_start) // base_rate | |
| # one frame per window | |
| group_starts = [out_start + i * base_rate for i in range(W)] | |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] | |
| elif mode == "2x": | |
| W = (out_end - out_start) // base_rate | |
| # every base_rate frames, starting at out_start | |
| group_starts = [out_start + i * base_rate for i in range(W)] | |
| group_ends = [out_start + (i + 1) * base_rate for i in range(W)] | |
| elif mode == "lb": | |
| W = (out_end - out_start) // base_rate | |
| # sparse “key‐frame” windows from the raw input range | |
| group_starts = [in_start + i * base_rate for i in range(W)] | |
| group_ends = [s + 1 for s in group_starts] | |
| else: | |
| raise ValueError(f"Unsupported mode: {mode}") | |
| # 5) now normalize your intervals as before | |
| def normalize(x): | |
| return (x - in_start) / (in_end - in_start) - 0.5 | |
| intervals = [[normalize(s), normalize(e)] for s, e in zip(group_starts, group_ends)] | |
| num_frames = len(intervals) | |
| if len(intervals) < output_len: | |
| intervals += [intervals[-1]] * (output_len - len(intervals)) | |
| output_intervals = torch.tensor(intervals, dtype=torch.float) | |
| return input_interval, output_intervals, num_frames | |