diff --git a/README.md b/README.md
index cd07d8dfdc6d66bbcfc1677ea9492238f41aaca4..905d4166401ee22a117c7b11fb6daae1acfeb064 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,55 @@
---
-title: MOFA-Video Traj
-emoji: 📚
-colorFrom: red
-colorTo: gray
-sdk: gradio
-sdk_version: 4.36.1
-app_file: app.py
-pinned: false
license: apache-2.0
+sdk_version: 4.5.0
---
+## Updates 🔥🔥🔥
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+We have released the Gradio demo for **Hybrid (Trajectory + Landmark)** Controls [HERE](https://huggingface.co/MyNiuuu/MOFA-Video-Hybrid)!
+
+
+## Introduction
+
+This repo provides the Gradio inference demo for the trajectory-based control of MOFA-Video.
+
+## Environment Setup
+
+`pip install -r requirements.txt`
+
+## Download checkpoints
+
+1. Download the pretrained checkpoints of [SVD_xt](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1) from Hugging Face into `./ckpts`.
+
+2. Download the checkpoint of [MOFA-Adapter](https://huggingface.co/MyNiuuu/MOFA-Video-Traj) from Hugging Face into `./ckpts`.
+
+The final structure of `./ckpts` should be:
+
+
+```text
+./ckpts/
+|-- controlnet
+| |-- config.json
+| `-- diffusion_pytorch_model.safetensors
+|-- stable-video-diffusion-img2vid-xt-1-1
+| |-- feature_extractor
+| |-- ...
+| |-- image_encoder
+| |-- ...
+| |-- scheduler
+| |-- ...
+| |-- unet
+| |-- ...
+| |-- vae
+| |-- ...
+| |-- svd_xt_1_1.safetensors
+| `-- model_index.json
+```
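+
+If you prefer to script the download, here is a minimal sketch using `huggingface_hub` (an assumption: it is not pinned in `requirements.txt`; you may also need `huggingface-cli login` for the gated SVD repo, and you may need to adjust `local_dir` so the files end up in the layout shown above):
+
+```python
+# Sketch only: fetch both checkpoints with huggingface_hub (assumed to be installed).
+from huggingface_hub import snapshot_download
+
+# SVD_xt 1.1 weights -> ./ckpts/stable-video-diffusion-img2vid-xt-1-1
+snapshot_download(
+    repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1",
+    local_dir="./ckpts/stable-video-diffusion-img2vid-xt-1-1",
+)
+
+# MOFA-Adapter weights (adjust local_dir if the repo layout differs from the tree above)
+snapshot_download(
+    repo_id="MyNiuuu/MOFA-Video-Traj",
+    local_dir="./ckpts",
+)
+```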
+
+## Run Gradio Demo
+
+`python run_gradio.py`
+
+Please follow the instructions on the Gradio interface during inference.
+
+## Paper
+
+arxiv.org/abs/2405.20222
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1855c619c5c275b4e7394bff94e52c88bbc491d6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,838 @@
+import gradio as gr
+import numpy as np
+import cv2
+import os
+from PIL import Image, ImageFilter
+import uuid
+from scipy.interpolate import interp1d, PchipInterpolator
+import torchvision
+# from utils import *
+import time
+from tqdm import tqdm
+import imageio
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from einops import rearrange, repeat
+
+from packaging import version
+
+from accelerate.utils import set_seed
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from diffusers import AutoencoderKLTemporalDecoder, EulerDiscreteScheduler
+from diffusers.utils import check_min_version
+from diffusers.utils.import_utils import is_xformers_available
+
+from utils.flow_viz import flow_to_image
+from utils.utils import split_filename, image2arr, image2pil, ensure_dirname
+
+
+output_dir_video = "./outputs/videos"
+output_dir_frame = "./outputs/frames"
+
+
+ensure_dirname(output_dir_video)
+ensure_dirname(output_dir_frame)
+
+
+def divide_points_afterinterpolate(resized_all_points, motion_brush_mask):
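+    # Split the interpolated trajectories by whether their starting point lies
+    # inside the motion brush mask (value 255): trajectories starting inside the
+    # brush go to in_masks, the rest go to out_masks.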
+ k = resized_all_points.shape[0]
+ starts = resized_all_points[:, 0] # [K, 2]
+
+ in_masks = []
+ out_masks = []
+
+ for i in range(k):
+ x, y = int(starts[i][1]), int(starts[i][0])
+ if motion_brush_mask[x][y] == 255:
+ in_masks.append(resized_all_points[i])
+ else:
+ out_masks.append(resized_all_points[i])
+
+ in_masks = np.array(in_masks)
+ out_masks = np.array(out_masks)
+
+ return in_masks, out_masks
+
+
+def get_sparseflow_and_mask_forward(
+ resized_all_points,
+ n_steps, H, W,
+ is_backward_flow=False
+ ):
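+    # Rasterize each interpolated trajectory into per-step sparse flow maps: at the
+    # trajectory's starting pixel, store the displacement towards that step's point
+    # (negated for backward flow) and mark the position in the mask, then sum the
+    # per-trajectory maps over K.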
+
+ K = resized_all_points.shape[0]
+
+ starts = resized_all_points[:, 0] # [K, 2]
+
+ interpolated_ends = resized_all_points[:, 1:]
+
+ s_flow = np.zeros((K, n_steps, H, W, 2))
+ mask = np.zeros((K, n_steps, H, W))
+
+ for k in range(K):
+ for i in range(n_steps):
+ start, end = starts[k], interpolated_ends[k][i]
+ flow = np.int64(end - start) * (-1 if is_backward_flow is True else 1)
+ s_flow[k][i][int(start[1]), int(start[0])] = flow
+ mask[k][i][int(start[1]), int(start[0])] = 1
+
+ s_flow = np.sum(s_flow, axis=0)
+ mask = np.sum(mask, axis=0)
+
+ return s_flow, mask
+
+
+
+def init_models(pretrained_model_name_or_path, resume_from_checkpoint, weight_dtype, device='cuda', enable_xformers_memory_efficient_attention=False, allow_tf32=False):
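+    # Load the frozen SVD components (image encoder, VAE, UNet), the MOFA-Adapter
+    # FlowControlNet checkpoint and the CMP model, cast them to weight_dtype on the
+    # target device, and assemble everything into a FlowControlNetPipeline.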
+
+ from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
+ from pipeline.pipeline import FlowControlNetPipeline
+ from models.svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine import FlowControlNet, CMP_demo
+
+ print('start loading models...')
+ # Load scheduler, tokenizer and models.
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ pretrained_model_name_or_path, subfolder="image_encoder", revision=None, variant="fp16"
+ )
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
+ pretrained_model_name_or_path, subfolder="vae", revision=None, variant="fp16")
+ unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="unet",
+ low_cpu_mem_usage=True,
+ variant="fp16",
+ )
+
+ controlnet = FlowControlNet.from_pretrained(resume_from_checkpoint)
+
+ cmp = CMP_demo(
+ './models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml',
+ 42000
+ ).to(device)
+ cmp.requires_grad_(False)
+
+ # Freeze vae and image_encoder
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ controlnet.requires_grad_(False)
+
+ # Move image_encoder and vae to gpu and cast to weight_dtype
+ image_encoder.to(device, dtype=weight_dtype)
+ vae.to(device, dtype=weight_dtype)
+ unet.to(device, dtype=weight_dtype)
+ controlnet.to(device, dtype=weight_dtype)
+
+ if enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ print(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError(
+ "xformers is not available. Make sure it is installed correctly")
+
+ if allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ pipeline = FlowControlNetPipeline.from_pretrained(
+ pretrained_model_name_or_path,
+ unet=unet,
+ controlnet=controlnet,
+ image_encoder=image_encoder,
+ vae=vae,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(device)
+
+ print('models loaded.')
+
+ return pipeline, cmp
+
+
+def interpolate_trajectory(points, n_points):
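+    # Resample a hand-drawn trajectory to exactly n_points samples using monotonic
+    # cubic (PCHIP) interpolation of x and y over a normalized time axis.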
+ x = [point[0] for point in points]
+ y = [point[1] for point in points]
+
+ t = np.linspace(0, 1, len(points))
+
+ fx = PchipInterpolator(t, x)
+ fy = PchipInterpolator(t, y)
+
+ new_t = np.linspace(0, 1, n_points)
+
+ new_x = fx(new_t)
+ new_y = fy(new_t)
+ new_points = list(zip(new_x, new_y))
+
+ return new_points
+
+
+def visualize_drag_v2(background_image_path, splited_tracks, width, height):
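+    # Draw the given trajectories as red arrows on a semi-transparent copy of the
+    # input image and return the composited visualization(s) plus the overlay layer.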
+ trajectory_maps = []
+
+ background_image = Image.open(background_image_path).convert('RGBA')
+ background_image = background_image.resize((width, height))
+ w, h = background_image.size
+ transparent_background = np.array(background_image)
+ transparent_background[:, :, -1] = 128
+ transparent_background = Image.fromarray(transparent_background)
+
+ # Create a transparent layer with the same size as the background image
+ transparent_layer = np.zeros((h, w, 4))
+ for splited_track in splited_tracks:
+ if len(splited_track) > 1:
+ splited_track = interpolate_trajectory(splited_track, 16)
+ splited_track = splited_track[:16]
+ for i in range(len(splited_track)-1):
+ start_point = (int(splited_track[i][0]), int(splited_track[i][1]))
+ end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1]))
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(splited_track)-2:
+ cv2.arrowedLine(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, start_point, end_point, (255, 0, 0, 192), 2)
+ else:
+ cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 2, (255, 0, 0, 192), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+ trajectory_maps.append(trajectory_map)
+ return trajectory_maps, transparent_layer
+
+
+class Drag:
+ def __init__(self, device, height, width, model_length):
+ self.device = device
+
+ svd_ckpt = "ckpts/stable-video-diffusion-img2vid-xt-1-1"
+ mofa_ckpt = "ckpts/controlnet"
+
+ self.device = 'cuda'
+ self.weight_dtype = torch.float16
+
+ self.pipeline, self.cmp = init_models(
+ svd_ckpt,
+ mofa_ckpt,
+ weight_dtype=self.weight_dtype,
+ device=self.device
+ )
+
+ self.height = height
+ self.width = width
+ self.model_length = model_length
+
+ def get_cmp_flow(self, frames, sparse_optical_flow, mask, brush_mask=None):
+
+ '''
+ frames: [b, 13, 3, 384, 384] (0, 1) tensor
+ sparse_optical_flow: [b, 13, 2, 384, 384] (-384, 384) tensor
+ mask: [b, 13, 2, 384, 384] {0, 1} tensor
+ '''
+
+ b, t, c, h, w = frames.shape
+ assert h == 384 and w == 384
+        frames = frames.flatten(0, 1) # [b*13, 3, 384, 384]
+        sparse_optical_flow = sparse_optical_flow.flatten(0, 1) # [b*13, 2, 384, 384]
+        mask = mask.flatten(0, 1) # [b*13, 2, 384, 384]
+        cmp_flow = self.cmp.run(frames, sparse_optical_flow, mask) # [b*13, 2, 384, 384]
+
+ if brush_mask is not None:
+ brush_mask = torch.from_numpy(brush_mask) / 255.
+ brush_mask = brush_mask.to(cmp_flow.device, dtype=cmp_flow.dtype)
+ brush_mask = brush_mask.unsqueeze(0).unsqueeze(0)
+ cmp_flow = cmp_flow * brush_mask
+
+ cmp_flow = cmp_flow.reshape(b, t, 2, h, w)
+ return cmp_flow
+
+
+ def get_flow(self, pixel_values_384, sparse_optical_flow_384, mask_384, motion_brush_mask=None):
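+        # Densify the sparse trajectory flow with CMP at 384x384 resolution, then
+        # resize the result to the working resolution and rescale the flow vectors
+        # by the corresponding width/height ratios.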
+
+ fb, fl, fc, _, _ = pixel_values_384.shape
+
+ controlnet_flow = self.get_cmp_flow(
+ pixel_values_384[:, 0:1, :, :, :].repeat(1, fl, 1, 1, 1),
+ sparse_optical_flow_384,
+ mask_384, motion_brush_mask
+ )
+
+ if self.height != 384 or self.width != 384:
+ scales = [self.height / 384, self.width / 384]
+ controlnet_flow = F.interpolate(controlnet_flow.flatten(0, 1), (self.height, self.width), mode='nearest').reshape(fb, fl, 2, self.height, self.width)
+ controlnet_flow[:, :, 0] *= scales[1]
+ controlnet_flow[:, :, 1] *= scales[0]
+
+ return controlnet_flow
+
+
+ @torch.no_grad()
+ def forward_sample(self, input_drag_384_inmask, input_drag_384_outmask, input_first_frame, input_mask_384_inmask, input_mask_384_outmask, in_mask_flag, out_mask_flag, motion_brush_mask=None, ctrl_scale=1., outputs=dict()):
+ '''
+ input_drag: [1, 13, 320, 576, 2]
+ input_drag_384: [1, 13, 384, 384, 2]
+ input_first_frame: [1, 3, 320, 576]
+ '''
+
+ seed = 42
+ num_frames = self.model_length
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+ input_first_frame_pil = Image.fromarray(np.uint8(input_first_frame[0].cpu().permute(1, 2, 0)*255))
+ height, width = input_first_frame.shape[-2:]
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ print('start diffusion process...')
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ val_output = self.pipeline(
+ input_first_frame_pil,
+ input_first_frame_pil,
+ controlnet_flow,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ decode_chunk_size=8,
+ motion_bucket_id=127,
+ fps=7,
+ noise_aug_strength=0.02,
+ controlnet_cond_scale=ctrl_scale,
+ )
+
+ video_frames, estimated_flow = val_output.frames[0], val_output.controlnet_flow
+
+ for i in range(num_frames):
+ img = video_frames[i]
+ video_frames[i] = np.array(img)
+ video_frames = torch.from_numpy(np.array(video_frames)).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ print(video_frames.shape)
+
+ viz_esti_flows = []
+ for i in range(estimated_flow.shape[1]):
+ temp_flow = estimated_flow[0][i].permute(1, 2, 0)
+ viz_esti_flows.append(flow_to_image(temp_flow))
+ viz_esti_flows = [np.uint8(np.ones_like(viz_esti_flows[-1]) * 255)] + viz_esti_flows
+ viz_esti_flows = np.stack(viz_esti_flows) # [t-1, h, w, c]
+
+ total_nps = viz_esti_flows
+
+ outputs['logits_imgs'] = video_frames
+ outputs['flows'] = torch.from_numpy(total_nps).cuda().permute(0, 3, 1, 2).unsqueeze(0) / 255.
+
+ return outputs
+
+ @torch.no_grad()
+ def get_cmp_flow_from_tracking_points(self, tracking_points, motion_brush_mask, first_frame_path):
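+        # Preview helper: build the sparse flows for the current trajectories and
+        # motion brush, densify them with CMP, and return the flow of the last frame
+        # rendered as an RGB image (no diffusion sampling is run here).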
+
+ original_width, original_height = self.width, self.height
+
+ input_all_points = tracking_points.constructor_args['value']
+
+ if len(input_all_points) == 0 or len(input_all_points[-1]) == 1:
+            return np.uint8(np.ones((original_height, original_width, 3))*255)
+
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], self.model_length))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], self.model_length))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+        motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), interpolation=cv2.INTER_NEAREST)
+
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ self.model_length - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((self.model_length - 1, 384, 384, 2)), \
+ np.zeros((self.model_length - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ self.model_length - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((self.model_length - 1, 384, 384, 2)), \
+ np.zeros((self.model_length - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0).to(self.device) # [1, 13, h, w]
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ input_first_frame = image2arr(first_frame_path)
+ input_first_frame = repeat(first_frames_transform(input_first_frame), 'c h w -> b c h w', b=1).to(self.device)
+
+ seed = 42
+ num_frames = self.model_length
+
+ set_seed(seed)
+
+ input_first_frame_384 = F.interpolate(input_first_frame, (384, 384))
+ input_first_frame_384 = input_first_frame_384.repeat(num_frames - 1, 1, 1, 1).unsqueeze(0)
+
+ input_drag_384_inmask = input_drag_384_inmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_inmask = input_mask_384_inmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+ input_drag_384_outmask = input_drag_384_outmask.permute(0, 1, 4, 2, 3) # [1, 13, 2, 384, 384]
+ mask_384_outmask = input_mask_384_outmask.unsqueeze(2).repeat(1, 1, 2, 1, 1) # [1, 13, 2, 384, 384]
+
+ input_drag_384_inmask = input_drag_384_inmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_inmask = mask_384_inmask.to(self.device, dtype=self.weight_dtype)
+ input_drag_384_outmask = input_drag_384_outmask.to(self.device, dtype=self.weight_dtype)
+ mask_384_outmask = mask_384_outmask.to(self.device, dtype=self.weight_dtype)
+
+ input_first_frame_384 = input_first_frame_384.to(self.device, dtype=self.weight_dtype)
+
+ if in_mask_flag:
+ flow_inmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_inmask, mask_384_inmask, motion_brush_mask_384
+ )
+ else:
+ fb, fl = mask_384_inmask.shape[:2]
+ flow_inmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ if out_mask_flag:
+ flow_outmask = self.get_flow(
+ input_first_frame_384,
+ input_drag_384_outmask, mask_384_outmask
+ )
+ else:
+ fb, fl = mask_384_outmask.shape[:2]
+ flow_outmask = torch.zeros(fb, fl, 2, self.height, self.width).to(self.device, dtype=self.weight_dtype)
+
+ inmask_no_zero = (flow_inmask != 0).all(dim=2)
+ inmask_no_zero = inmask_no_zero.unsqueeze(2).expand_as(flow_inmask)
+
+ controlnet_flow = torch.where(inmask_no_zero, flow_inmask, flow_outmask)
+
+ controlnet_flow = controlnet_flow[0, -1].permute(1, 2, 0)
+ viz_esti_flows = flow_to_image(controlnet_flow) # [h, w, c]
+
+ return viz_esti_flows
+
+ def run(self, first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale):
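+        # Full inference path: split the drawn trajectories into in-/out-of-brush
+        # sparse flows, run the diffusion pipeline via forward_sample, and save the
+        # hint image plus GIF/MP4/per-frame PNG outputs for both video and flow.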
+
+ original_width, original_height = self.width, self.height
+
+ input_all_points = tracking_points.constructor_args['value']
+ resized_all_points = [tuple([tuple([int(e1[0]*self.width/original_width), int(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
+ resized_all_points_384 = [tuple([tuple([int(e1[0]*384/original_width), int(e1[1]*384/original_height)]) for e1 in e]) for e in input_all_points]
+
+ new_resized_all_points = []
+ new_resized_all_points_384 = []
+ for tnum in range(len(resized_all_points)):
+ new_resized_all_points.append(interpolate_trajectory(input_all_points[tnum], self.model_length))
+ new_resized_all_points_384.append(interpolate_trajectory(resized_all_points_384[tnum], self.model_length))
+
+ resized_all_points = np.array(new_resized_all_points)
+ resized_all_points_384 = np.array(new_resized_all_points_384)
+
+        motion_brush_mask_384 = cv2.resize(motion_brush_mask, (384, 384), interpolation=cv2.INTER_NEAREST)
+
+ resized_all_points_384_inmask, resized_all_points_384_outmask = \
+ divide_points_afterinterpolate(resized_all_points_384, motion_brush_mask_384)
+
+ in_mask_flag = False
+ out_mask_flag = False
+
+ if resized_all_points_384_inmask.shape[0] != 0:
+ in_mask_flag = True
+ input_drag_384_inmask, input_mask_384_inmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_inmask,
+ self.model_length - 1, 384, 384
+ )
+ else:
+ input_drag_384_inmask, input_mask_384_inmask = \
+ np.zeros((self.model_length - 1, 384, 384, 2)), \
+ np.zeros((self.model_length - 1, 384, 384))
+
+ if resized_all_points_384_outmask.shape[0] != 0:
+ out_mask_flag = True
+ input_drag_384_outmask, input_mask_384_outmask = \
+ get_sparseflow_and_mask_forward(
+ resized_all_points_384_outmask,
+ self.model_length - 1, 384, 384
+ )
+ else:
+ input_drag_384_outmask, input_mask_384_outmask = \
+ np.zeros((self.model_length - 1, 384, 384, 2)), \
+ np.zeros((self.model_length - 1, 384, 384))
+
+ input_drag_384_inmask = torch.from_numpy(input_drag_384_inmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_inmask = torch.from_numpy(input_mask_384_inmask).unsqueeze(0) # [1, 13, h, w]
+ input_drag_384_outmask = torch.from_numpy(input_drag_384_outmask).unsqueeze(0) # [1, 13, h, w, 2]
+ input_mask_384_outmask = torch.from_numpy(input_mask_384_outmask).unsqueeze(0) # [1, 13, h, w]
+
+ dir, base, ext = split_filename(first_frame_path)
+ id = base.split('_')[0]
+
+ image_pil = image2pil(first_frame_path)
+ image_pil = image_pil.resize((self.width, self.height), Image.BILINEAR).convert('RGB')
+
+ visualized_drag, _ = visualize_drag_v2(first_frame_path, resized_all_points, self.width, self.height)
+
+ motion_brush_viz_pil = Image.fromarray(motion_brush_viz.astype(np.uint8)).convert('RGBA')
+ visualized_drag = visualized_drag[0].convert('RGBA')
+ visualized_drag_brush = Image.alpha_composite(motion_brush_viz_pil, visualized_drag)
+
+ first_frames_transform = transforms.Compose([
+ lambda x: Image.fromarray(x),
+ transforms.ToTensor(),
+ ])
+
+ outputs = None
+        output_video_list = []
+        output_flow_list = []
+ num_inference = 1
+ for i in tqdm(range(num_inference)):
+ if not outputs:
+ first_frames = image2arr(first_frame_path)
+ first_frames = repeat(first_frames_transform(first_frames), 'c h w -> b c h w', b=inference_batch_size).to(self.device)
+ else:
+ first_frames = outputs['logits_imgs'][:, -1]
+
+
+ outputs = self.forward_sample(
+ input_drag_384_inmask.to(self.device),
+ input_drag_384_outmask.to(self.device),
+ first_frames.to(self.device),
+ input_mask_384_inmask.to(self.device),
+ input_mask_384_outmask.to(self.device),
+ in_mask_flag,
+ out_mask_flag,
+ motion_brush_mask_384,
+ ctrl_scale)
+
+            output_video_list.append(outputs['logits_imgs'])
+            output_flow_list.append(outputs['flows'])
+
+ hint_path = os.path.join(output_dir_video, str(id), f'{id}_hint.png')
+ visualized_drag_brush.save(hint_path)
+
+ for i in range(inference_batch_size):
+            output_tensor = [output_video_list[0][i]]
+            flow_tensor = [output_flow_list[0][i]]
+ output_tensor = torch.cat(output_tensor, dim=0)
+ flow_tensor = torch.cat(flow_tensor, dim=0)
+
+ outputs_path = os.path.join(output_dir_video, str(id), f's{ctrl_scale}', f'{id}_output.gif')
+ flows_path = os.path.join(output_dir_video, str(id), f's{ctrl_scale}', f'{id}_flow.gif')
+
+ outputs_mp4_path = os.path.join(output_dir_video, str(id), f's{ctrl_scale}', f'{id}_output.mp4')
+ flows_mp4_path = os.path.join(output_dir_video, str(id), f's{ctrl_scale}', f'{id}_flow.mp4')
+
+ outputs_frames_path = os.path.join(output_dir_frame, str(id), f's{ctrl_scale}', f'{id}_output')
+ flows_frames_path = os.path.join(output_dir_frame, str(id), f's{ctrl_scale}', f'{id}_flow')
+
+ os.makedirs(os.path.join(output_dir_video, str(id), f's{ctrl_scale}'), exist_ok=True)
+ os.makedirs(os.path.join(outputs_frames_path), exist_ok=True)
+ os.makedirs(os.path.join(flows_frames_path), exist_ok=True)
+
+ print(output_tensor.shape)
+
+ output_RGB = output_tensor.permute(0, 2, 3, 1).mul(255).cpu().numpy()
+ flow_RGB = flow_tensor.permute(0, 2, 3, 1).mul(255).cpu().numpy()
+
+ torchvision.io.write_video(
+ outputs_mp4_path,
+ output_RGB,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+
+ torchvision.io.write_video(
+ flows_mp4_path,
+ flow_RGB,
+ fps=20, video_codec='h264', options={'crf': '10'}
+ )
+
+ imageio.mimsave(outputs_path, np.uint8(output_RGB), fps=20, loop=0)
+
+ imageio.mimsave(flows_path, np.uint8(flow_RGB), fps=20, loop=0)
+
+ for f in range(output_RGB.shape[0]):
+ Image.fromarray(np.uint8(output_RGB[f])).save(os.path.join(outputs_frames_path, f'{str(f).zfill(3)}.png'))
+ Image.fromarray(np.uint8(flow_RGB[f])).save(os.path.join(flows_frames_path, f'{str(f).zfill(3)}.png'))
+
+ return hint_path, outputs_path, flows_path, outputs_mp4_path, flows_mp4_path
+
+
+with gr.Blocks() as demo:
+ gr.Markdown("""
+    MOFA-Video
+    """)
+
+ gr.Markdown("""Official Gradio Demo for MOFA-Video: Controllable Image Animation via Generative Motion Field Adaptions in Frozen Image-to-Video Diffusion Model.
""")
+
+ gr.Markdown(
+ """
+    During inference, please follow these instructions:
+
+ 1. Use the "Upload Image" button to upload an image. Avoid dragging the image directly into the window.
+ 2. Proceed to draw trajectories:
+ 2.1. Click "Add Trajectory" first, then select points on the "Add Trajectory Here" image. The first click sets the starting point. Click multiple points to create a non-linear trajectory. To add a new trajectory, click "Add Trajectory" again and select points on the image. Avoid clicking the "Add Trajectory" button multiple times without clicking points in the image to add the trajectory, as this can lead to errors.
+ 2.2. After adding each trajectory, an optical flow image will be displayed automatically. Use it as a reference to adjust the trajectory for desired effects (e.g., area, intensity).
+ 2.3. To delete the latest trajectory, click "Delete Last Trajectory."
+    2.4. Choose the Control Scale with the slider; it determines the control intensity. Setting it to 0 means no control (the pure generation result of SVD itself), while setting it to 1 gives the strongest control (which in most cases does not lead to good results because of twisting artifacts). The preset value of 0.6 is recommended for most cases.
+ 2.5. To use the motion brush for restraining the control area of the trajectory, click to add masks on the "Add Motion Brush Here" image. The motion brush restricts the optical flow area derived from the trajectory whose starting point is within the motion brush. The displayed optical flow image will change correspondingly. Adjust the motion brush radius using the "Motion Brush Radius" bar.
+ 3. Click the "Run" button to animate the image according to the path.
+ """
+ )
+
+ target_size = 512
+ DragNUWA_net = Drag("cuda:0", target_size, target_size, 25)
+ first_frame_path = gr.State()
+ tracking_points = gr.State([])
+ motion_brush_points = gr.State([])
+ motion_brush_mask = gr.State()
+ motion_brush_viz = gr.State()
+ inference_batch_size = gr.State(1)
+
+ def preprocess_image(image):
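+        # Resize the uploaded image so its shorter side equals target_size, center-crop
+        # both dimensions to multiples of 64, save it as the first frame under ./outputs,
+        # and reset all trajectory / motion-brush state.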
+
+ image_pil = image2pil(image.name)
+ raw_w, raw_h = image_pil.size
+
+ max_edge = min(raw_w, raw_h)
+ resize_ratio = target_size / max_edge
+
+ image_pil = image_pil.resize((round(raw_w * resize_ratio), round(raw_h * resize_ratio)), Image.BILINEAR)
+
+ new_w, new_h = image_pil.size
+ crop_w = new_w - (new_w % 64)
+ crop_h = new_h - (new_h % 64)
+
+ image_pil = transforms.CenterCrop((crop_h, crop_w))(image_pil.convert('RGB'))
+
+ DragNUWA_net.width = crop_w
+ DragNUWA_net.height = crop_h
+
+ id = str(time.time()).split('.')[0]
+ os.makedirs(os.path.join(output_dir_video, str(id)), exist_ok=True)
+ os.makedirs(os.path.join(output_dir_frame, str(id)), exist_ok=True)
+
+ first_frame_path = os.path.join(output_dir_video, str(id), f"{id}_input.png")
+ image_pil.save(first_frame_path)
+
+ return first_frame_path, first_frame_path, first_frame_path, gr.State([]), gr.State([]), np.zeros((crop_h, crop_w)), np.zeros((crop_h, crop_w, 4))
+
+ def add_drag(tracking_points):
+ if len(tracking_points.constructor_args['value']) != 0 and tracking_points.constructor_args['value'][-1] == []:
+ return tracking_points
+ tracking_points.constructor_args['value'].append([])
+ return tracking_points
+
+ def add_mask(motion_brush_points):
+ motion_brush_points.constructor_args['value'].append([])
+ return motion_brush_points
+
+ def delete_last_drag(tracking_points, first_frame_path, motion_brush_mask):
+ if len(tracking_points.constructor_args['value']) > 0:
+ tracking_points.constructor_args['value'].pop()
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ def add_motion_brushes(motion_brush_points, motion_brush_mask, transparent_layer, first_frame_path, radius, tracking_points, evt: gr.SelectData):
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+
+ motion_points = motion_brush_points.constructor_args['value']
+ motion_points.append(evt.index)
+
+ x, y = evt.index
+
+ cv2.circle(motion_brush_mask, (x, y), radius, 255, -1)
+ cv2.circle(transparent_layer, (x, y), radius, (0, 0, 255, 255), -1)
+
+ transparent_layer_pil = Image.fromarray(transparent_layer.astype(np.uint8))
+ motion_map = Image.alpha_composite(transparent_background, transparent_layer_pil)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return motion_brush_mask, transparent_layer, motion_map, viz_flow
+
+ def add_tracking_points(tracking_points, first_frame_path, motion_brush_mask, evt: gr.SelectData):
+
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
+
+ if len(tracking_points.constructor_args['value']) == 0:
+ tracking_points.constructor_args['value'].append([])
+
+ tracking_points.constructor_args['value'][-1].append(evt.index)
+
+ # print(tracking_points.constructor_args['value'])
+
+ transparent_background = Image.open(first_frame_path).convert('RGBA')
+ w, h = transparent_background.size
+ transparent_layer = np.zeros((h, w, 4))
+ for track in tracking_points.constructor_args['value']:
+ if len(track) > 1:
+ for i in range(len(track)-1):
+ start_point = track[i]
+ end_point = track[i+1]
+ vx = end_point[0] - start_point[0]
+ vy = end_point[1] - start_point[1]
+ arrow_length = np.sqrt(vx**2 + vy**2)
+ if i == len(track)-2:
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
+ else:
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
+ else:
+ cv2.circle(transparent_layer, tuple(track[0]), 3, (255, 0, 0, 255), -1)
+
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
+
+ viz_flow = DragNUWA_net.get_cmp_flow_from_tracking_points(tracking_points, motion_brush_mask, first_frame_path)
+
+ return tracking_points, trajectory_map, viz_flow
+
+ with gr.Row():
+ with gr.Column(scale=2):
+            image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
+ add_drag_button = gr.Button(value="Add Trajectory")
+ run_button = gr.Button(value="Run")
+ delete_last_drag_button = gr.Button(value="Delete Last Trajectory")
+ brush_radius = gr.Slider(label='Motion Brush Radius',
+ minimum=1,
+ maximum=100,
+ step=1,
+ value=10)
+ ctrl_scale = gr.Slider(label='Control Scale',
+ minimum=0,
+ maximum=1.,
+ step=0.01,
+ value=0.6)
+
+ with gr.Column(scale=5):
+ input_image = gr.Image(label="Add Trajectory Here",
+ interactive=True)
+ with gr.Column(scale=5):
+ input_image_mask = gr.Image(label="Add Motion Brush Here",
+ interactive=True)
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ viz_flow = gr.Image(label="Visualized Flow")
+ with gr.Column(scale=6):
+ hint_image = gr.Image(label="Visualized Hint Image")
+ with gr.Row():
+ with gr.Column(scale=6):
+ output_video = gr.Image(label="Output Video")
+ with gr.Column(scale=6):
+ output_flow = gr.Image(label="Output Flow")
+
+ with gr.Row():
+ with gr.Column(scale=6):
+ output_video_mp4 = gr.Video(label="Output Video mp4")
+ with gr.Column(scale=6):
+ output_flow_mp4 = gr.Video(label="Output Flow mp4")
+
+ image_upload_button.upload(preprocess_image, image_upload_button, [input_image, input_image_mask, first_frame_path, tracking_points, motion_brush_points, motion_brush_mask, motion_brush_viz])
+
+ add_drag_button.click(add_drag, tracking_points, tracking_points)
+
+ delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ input_image.select(add_tracking_points, [tracking_points, first_frame_path, motion_brush_mask], [tracking_points, input_image, viz_flow])
+
+ input_image_mask.select(add_motion_brushes, [motion_brush_points, motion_brush_mask, motion_brush_viz, first_frame_path, brush_radius, tracking_points], [motion_brush_mask, motion_brush_viz, input_image_mask, viz_flow])
+
+ run_button.click(DragNUWA_net.run, [first_frame_path, tracking_points, inference_batch_size, motion_brush_mask, motion_brush_viz, ctrl_scale], [hint_image, output_video, output_flow, output_video_mp4, output_flow_mp4])
+
+ demo.launch(server_name="127.0.0.1", debug=True, server_port=9080)
diff --git a/ckpts/controlnet/config.json b/ckpts/controlnet/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a264a50622deb20fb4a4082cb823b134af7db7bf
--- /dev/null
+++ b/ckpts/controlnet/config.json
@@ -0,0 +1,45 @@
+{
+ "_class_name": "FlowControlNet",
+ "_diffusers_version": "0.25.1",
+ "_name_or_path": "/apdcephfs_cq10/share_1290939/myniu/svd_controlnet/svdxt11_featureflow_forward_avg_256256_stride4/unimatch_512384/checkpoint-100000/controlnet",
+ "addition_time_embed_dim": 256,
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "conditioning_channels": 3,
+ "conditioning_embedding_out_channels": [
+ 16,
+ 32,
+ 96,
+ 256
+ ],
+ "cross_attention_dim": 1024,
+ "down_block_types": [
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal"
+ ],
+ "in_channels": 8,
+ "layers_per_block": 2,
+ "num_attention_heads": [
+ 5,
+ 10,
+ 10,
+ 20
+ ],
+ "num_frames": 25,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 768,
+ "sample_size": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal"
+ ]
+}
diff --git a/ckpts/controlnet/diffusion_pytorch_model.safetensors b/ckpts/controlnet/diffusion_pytorch_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..3b575b15b495831af7a1970337ea33cf883f6d58
--- /dev/null
+++ b/ckpts/controlnet/diffusion_pytorch_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1303192a1e72d071e15e7eb37fd1ea15f6424aaf2cd6b6b1e1bb3b1e9e75d37e
+size 2777345452
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2944c4056bc0683c9a94bc20017c1056352356e1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/config.yaml
@@ -0,0 +1,59 @@
+model:
+ arch: CMP
+ total_iter: 140000
+ lr_steps: [80000, 120000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.000025
+ nms_ks: 81
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ - data/youtube9000/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc+youtube_voc_16gpu_140k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9296127836bea77efc5d6d28ccf363c6e8adbf91
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 70000
+ lr_steps: [40000, 60000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 41
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_16gpu_70k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6e8751ff794627d37449771734bb2fe1521f527a
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 140000
+ lr_steps: [80000, 120000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: alexnet_fcn_32x
+ sparse_encoder: shallownet32x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 12
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 41
+ max_num_guide: 150
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17cec90cef2555c6b7dd5acfe3b938c9be451346
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..94b7cccac61566afb3eef7924a6d8b56027b2d13
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..330d14c459f8549ea81f956c3497e13ddf68aed0
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..140bfae1f0543e2b186f06f3dfc7a934c0aeccf1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/alexnet_yfcc_voc_8gpu_140k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5dd44cf7642837242711326eb413b950c384dd26
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/config.yaml
@@ -0,0 +1,61 @@
+model:
+ arch: CMP
+ total_iter: 70000
+ lr_steps: [40000, 60000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 10
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [320, 320]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00015625
+ nms_ks: 15
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ - data/youtube9000/lists/train.txt
+ - data/VIP/lists/train.txt
+ - data/MPII/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc+youtube+vip+mpii_lip_16gpu_70k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a453c27947f570320609a61fde9c862819842bc
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 16
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 333
+ crop_size: [256, 256]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00005632
+ nms_ks: 49
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_coco_16gpu_42k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..47ba5c8c0d6f63247b7fcf6ac18554f4cddb0eac
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/config.yaml
@@ -0,0 +1,58 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderPlain
+ skip_layer: False
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 10
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [320, 320]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 0.00003629
+ nms_ks: 67
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/yfcc/lists/train.txt
+ val_source:
+ - data/yfcc/lists/val.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 10000
+ save_freq: 10000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: False
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..06bd63a2c51db22a687f347635759f3a41ea30b2
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..644276733e346ef31fa9d3aaa4110b0b360cff3f
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f2b03a431e84f04599c76865ec14cd499ff3063
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 \
+ --nnodes=2 --node_rank=$1 \
+ --master_addr="192.168.1.1" main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e9c1a9f27ef9e639802ecf29247297ff7eb022d1
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n16 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bfe2eec2f1a52089b86f7d8a2550f12251a269e
--- /dev/null
+++ b/models/cmp/experiments/rep_learning/resnet50_yfcc_voc_16gpu_42k/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar
new file mode 100644
index 0000000000000000000000000000000000000000..a15fde53bc352803ac906bb48f7ec6f08f55f817
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/checkpoints/ckpt_iter_42000.pth.tar
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cd3a385e227c29f89b5c7c6f4c89d356f6022fa7fcfc71ab1bd40e9833048dd6
+size 228465722
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc56f53ce2088872c5f6987a0f1a44dabaf76f9d
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/config.yaml
@@ -0,0 +1,59 @@
+model:
+ arch: CMP
+ total_iter: 42000
+ lr_steps: [24000, 36000]
+ lr_mults: [0.1, 0.1]
+ lr: 0.1
+ optim: SGD
+ warmup_lr: []
+ warmup_steps: []
+ module:
+ arch: CMP
+ image_encoder: resnet50
+ sparse_encoder: shallownet8x
+ flow_decoder: MotionDecoderSkipLayer
+ skip_layer: True
+ img_enc_dim: 256
+ sparse_enc_dim: 16
+ output_dim: 198
+ decoder_combo: [1,2,4]
+ pretrained_image_encoder: False
+ flow_criterion: "DiscreteLoss"
+ nbins: 99
+ fmax: 50
+data:
+ workers: 2
+ batch_size: 8
+ batch_size_test: 1
+ data_mean: [123.675, 116.28, 103.53] # RGB
+ data_div: [58.395, 57.12, 57.375]
+ short_size: 416
+ crop_size: [384, 384]
+ sample_strategy: ['grid', 'watershed']
+ sample_bg_ratio: 5.74e-5
+ nms_ks: 41
+ max_num_guide: -1
+
+ flow_file_type: "jpg"
+ image_flow_aug:
+ flip: False
+ flow_aug:
+ reverse: False
+ scale: False
+ rotate: False
+ train_source:
+ - data/VIP/lists/train.txt
+ - data/MPII/lists/train.txt
+ val_source:
+ - data/VIP/lists/randval.txt
+ memcached: False
+trainer:
+ initial_val: True
+ print_freq: 100
+ val_freq: 5000
+ save_freq: 5000
+ val_iter: -1
+ val_disp_start_iter: 0
+ val_disp_end_iter: 16
+ loss_record: ['loss_flow']
+ tensorboard: True
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh
new file mode 100644
index 0000000000000000000000000000000000000000..17cec90cef2555c6b7dd5acfe3b938c9be451346
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..94b7cccac61566afb3eef7924a6d8b56027b2d13
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/resume_slurm.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm \
+ --load-iter 10000 \
+ --resume
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..330d14c459f8549ea81f956c3497e13ddf68aed0
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..140bfae1f0543e2b186f06f3dfc7a934c0aeccf1
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/train_slurm.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition -n8 \
+ --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py \
+ --config $work_path/config.yaml --launcher slurm
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6d7abca3cc02d020fce66d22d880f5c9e03ce34c
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+work_path=$(dirname $0)
+python -m torch.distributed.launch --nproc_per_node=8 main.py \
+ --config $work_path/config.yaml --launcher pytorch \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh
new file mode 100644
index 0000000000000000000000000000000000000000..aef377a6e02a61de710eb8a72769ede93ce897e7
--- /dev/null
+++ b/models/cmp/experiments/semiauto_annot/resnet50_vip+mpii_liteflow/validate_slurm.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+work_path=$(dirname $0)
+partition=$1
+GLOG_vmodule=MemcachedClient=-1 srun --mpi=pmi2 -p $partition \
+ -n8 --gres=gpu:8 --ntasks-per-node=8 \
+ python -u main.py --config $work_path/config.yaml --launcher slurm \
+ --load-iter 70000 \
+ --validate
diff --git a/models/cmp/losses.py b/models/cmp/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..b562ff841da3e8508ddf5e1264de382fb510376d
--- /dev/null
+++ b/models/cmp/losses.py
@@ -0,0 +1,536 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import random
+import math
+
+def MultiChannelSoftBinaryCrossEntropy(input, target, reduction='mean'):
+ '''
+ input: N x 38 x H x W --> 19N x 2 x H x W
+ target: N x 19 x H x W --> 19N x 1 x H x W
+ '''
+ input = input.view(-1, 2, input.size(2), input.size(3))
+ target = target.view(-1, 1, input.size(2), input.size(3))
+
+ logsoftmax = nn.LogSoftmax(dim=1)
+ if reduction == 'mean':
+ return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
+ else:
+ return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
+
+class EdgeAwareLoss():
+ def __init__(self, nc=2, loss_type="L1", reduction='mean'):
+ assert loss_type in ['L1', 'BCE'], "Undefined loss type: {}".format(loss_type)
+ self.nc = nc
+ self.loss_type = loss_type
+ self.kernelx = Variable(torch.Tensor([[1,0,-1],[2,0,-2],[1,0,-1]]).cuda())
+ self.kernelx = self.kernelx.repeat(nc,1,1,1)
+ self.kernely = Variable(torch.Tensor([[1,2,1],[0,0,0],[-1,-2,-1]]).cuda())
+ self.kernely = self.kernely.repeat(nc,1,1,1)
+ self.bias = Variable(torch.zeros(nc).cuda())
+ self.reduction = reduction
+ if loss_type == 'L1':
+ self.loss = nn.SmoothL1Loss(reduction=reduction)
+ elif loss_type == 'BCE':
+ self.loss = self.bce2d
+
+ def bce2d(self, input, target):
+ assert not target.requires_grad
+ beta = 1 - torch.mean(target)
+ weights = 1 - beta + (2 * beta - 1) * target
+ loss = nn.functional.binary_cross_entropy(input, target, weights, reduction=self.reduction)
+ return loss
+
+ def get_edge(self, var):
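+        # Sobel-style x/y gradients computed with a grouped (per-channel)
+        # convolution; returns the per-pixel gradient magnitude averaged over channels.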
+ assert var.size(1) == self.nc, \
+ "input size at dim 1 should be consistent with nc, {} vs {}".format(var.size(1), self.nc)
+ outputx = nn.functional.conv2d(var, self.kernelx, bias=self.bias, padding=1, groups=self.nc)
+ outputy = nn.functional.conv2d(var, self.kernely, bias=self.bias, padding=1, groups=self.nc)
+ eps=1e-05
+ return torch.sqrt(outputx.pow(2) + outputy.pow(2) + eps).mean(dim=1, keepdim=True)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ target_edge = self.get_edge(target)
+ if self.loss_type == 'L1':
+ return self.loss(self.get_edge(input), target_edge)
+ elif self.loss_type == 'BCE':
+            raise NotImplementedError
+ #target_edge = torch.sign(target_edge - 0.1)
+ #pred = self.get_edge(nn.functional.sigmoid(input))
+ #return self.loss(pred, target_edge)
+
+def KLD(mean, logvar):
+ return -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
+
+class DiscreteLoss(nn.Module):
+ def __init__(self, nbins, fmax):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss()
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+
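+    # Each flow component in [-fmax, fmax] is quantized into `nbins` bins of
+    # width `step`; the decoder outputs 2 * nbins logits per pixel (x bins
+    # followed by y bins), supervised with two cross-entropy terms.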
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ if input.shape[2] != size[0] or input.shape[3] != size[1]:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ target = self.tobin(target)
+ assert input.size(1) == self.nbins * 2
+ # print(target.shape)
+ # print(input.shape)
+ # print(torch.max(target))
+        target[target >= 99] = 98  # edge case: quantization can yield bin index 99, but valid bins are 0..98 (nbins = 99)
+ return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...])
+
+class MultiDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean', xy_weight=(1., 1.), quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.x_weight, self.y_weight = xy_weight
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ if self.quantize_strategy == "linear":
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ elif self.quantize_strategy == "quadratic":
+ ind = target.data > 0
+ quantized_target = target.clone()
+ quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)
+            quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.)
+        else:
+            raise Exception("No such quantize strategy: {}".format(self.quantize_strategy))
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ target = self.tobin(target)
+ if isinstance(input, list):
+ input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input]
+ return sum([self.x_weight * self.loss(input[k][:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[k][:,self.nbins:,...], target[:,1,...]) for k in range(len(input))]) / float(len(input))
+ else:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ return self.x_weight * self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.y_weight * self.loss(input[:,self.nbins:,...], target[:,1,...])
+
+class MultiL1Loss():
+ def __init__(self, reduction='mean'):
+ self.loss = nn.SmoothL1Loss(reduction=reduction)
+
+ def __call__(self, input, target):
+ size = target.shape[2:4]
+ if isinstance(input, list):
+ input = [nn.functional.interpolate(ip, size=size, mode="bilinear", align_corners=True) for ip in input]
+ return sum([self.loss(input[k], target) for k in range(len(input))]) / float(len(input))
+ else:
+ input = nn.functional.interpolate(input, size=size, mode="bilinear", align_corners=True)
+ return self.loss(input, target)
+
+class MultiMSELoss():
+ def __init__(self):
+ self.loss = nn.MSELoss()
+
+ def __call__(self, predicts, targets):
+ loss = 0
+ for predict, target in zip(predicts, targets):
+ loss += self.loss(predict, target)
+ return loss
+
+class JointDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean', quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ if self.quantize_strategy == "linear":
+ quantized_target = torch.floor((target + self.fmax) / self.step)
+ elif self.quantize_strategy == "quadratic":
+ ind = target.data > 0
+ quantized_target = target.clone()
+ quantized_target[ind] = torch.floor(self.nbins * torch.sqrt(target[ind] / (4 * self.fmax)) + self.nbins / 2.)
+ quantized_target[~ind] = torch.floor(-self.nbins * torch.sqrt(-target[~ind] / (4 * self.fmax)) + self.nbins / 2.)
+ else:
+ raise Exception("No such quantize strategy: {}".format(self.quantize_strategy))
+ joint_target = quantized_target[:,0,:,:] * self.nbins + quantized_target[:,1,:,:]
+ return joint_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert input.size(1) == self.nbins ** 2
+ return self.loss(input, target)
+
+class PolarDiscreteLoss():
+ def __init__(self, abins=30, rbins=20, fmax=50., reduction='mean', ar_weight=(1., 1.), quantize_strategy='linear'):
+ self.loss = nn.CrossEntropyLoss(reduction=reduction)
+ self.fmax = fmax
+ self.rbins = rbins
+ self.abins = abins
+ self.a_weight, self.r_weight = ar_weight
+ self.quantize_strategy = quantize_strategy
+
+ def tobin(self, target):
+ indxneg = target.data[:,0,:,:] < 0
+ eps = torch.zeros(target.data[:,0,:,:].size()).cuda()
+ epsind = target.data[:,0,:,:] == 0
+ eps[epsind] += 1e-5
+ angle = torch.atan(target.data[:,1,:,:] / (target.data[:,0,:,:] + eps))
+ angle[indxneg] += np.pi
+ angle += np.pi / 2 # 0 to 2pi
+ angle = torch.clamp(angle, 0, 2 * np.pi - 1e-3)
+ radius = torch.sqrt(target.data[:,0,:,:] ** 2 + target.data[:,1,:,:] ** 2)
+ radius = torch.clamp(radius, 0, self.fmax - 1e-3)
+ quantized_angle = torch.floor(self.abins * angle / (2 * np.pi))
+ if self.quantize_strategy == 'linear':
+ quantized_radius = torch.floor(self.rbins * radius / self.fmax)
+ elif self.quantize_strategy == 'quadratic':
+ quantized_radius = torch.floor(self.rbins * torch.sqrt(radius / self.fmax))
+ else:
+ raise Exception("No such quantize strategy: {}".format(self.quantize_strategy))
+ quantized_target = torch.autograd.Variable(torch.cat([torch.unsqueeze(quantized_angle, 1), torch.unsqueeze(quantized_radius, 1)], dim=1))
+ return quantized_target.type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert (target >= 0).all() and (target[:,0,:,:] < self.abins).all() and (target[:,1,:,:] < self.rbins).all()
+ return self.a_weight * self.loss(input[:,:self.abins,...], target[:,0,...]) + self.r_weight * self.loss(input[:,self.abins:,...], target[:,1,...])
+
+class WeightedDiscreteLoss():
+ def __init__(self, nbins=19, fmax=47.5, reduction='mean'):
+ self.loss = CrossEntropy2d(reduction=reduction)
+ assert nbins % 2 == 1, "nbins should be odd"
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.weight = np.ones((nbins), dtype=np.float32)
+ self.weight[int(self.fmax / self.step)] = 0.01
+ self.weight = torch.from_numpy(self.weight).cuda()
+
+ def tobin(self, target):
+ target = torch.clamp(target, -self.fmax + 1e-3, self.fmax - 1e-3)
+ return torch.floor((target + self.fmax) / self.step).type(torch.cuda.LongTensor)
+
+ def __call__(self, input, target):
+ target = self.tobin(target)
+ assert (target >= 0).all() and (target < self.nbins).all()
+ return self.loss(input[:,:self.nbins,...], target[:,0,...]) + self.loss(input[:,self.nbins:,...], target[:,1,...], self.weight)
+
+
+class CrossEntropy2d(nn.Module):
+ def __init__(self, reduction='mean', ignore_label=-1):
+ super(CrossEntropy2d, self).__init__()
+ self.ignore_label = ignore_label
+ self.reduction = reduction
+
+ def forward(self, predict, target, weight=None):
+ """
+ Args:
+ predict:(n, c, h, w)
+ target:(n, h, w)
+ weight (Tensor, optional): a manual rescaling weight given to each class.
+ If given, has to be a Tensor of size "nclasses"
+ """
+ assert not target.requires_grad
+ assert predict.dim() == 4
+ assert target.dim() == 3
+ assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
+ assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
+        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
+ n, c, h, w = predict.size()
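+        # Keep only valid pixels (label >= 0 and != ignore_label) and gather the
+        # corresponding logits before computing cross-entropy.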
+ target_mask = (target >= 0) * (target != self.ignore_label)
+ target = target[target_mask]
+ predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
+ predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
+ loss = F.cross_entropy(predict, target, weight=weight, reduction=self.reduction)
+ return loss
+
+#class CrossPixelSimilarityLoss():
+# '''
+# Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+# '''
+# def __init__(self, sigma=0.0036, sampling_size=512):
+# self.sigma = sigma
+# self.sampling_size = sampling_size
+# self.epsilon = 1.0e-15
+# self.embed_norm = True # loss does not decrease no matter it is true or false.
+#
+# def __call__(self, embeddings, flows):
+# '''
+# embedding: Variable Nx256xHxW (not hyper-column)
+# flows: Variable Nx2xHxW
+# '''
+# assert flows.size(1) == 2
+#
+# # flow normalization
+# positive_mask = (flows > 0)
+# flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+# flows[positive_mask] = -flows[positive_mask]
+#
+# # embedding normalization
+# if self.embed_norm:
+# embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+#
+# # Spatially random sampling (512 samples)
+# flows_flatten = flows.view(flows.shape[0], 2, -1)
+# random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+# flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+#
+# # K_f
+# k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+# torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+# keepdim=False) ** 2
+# exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+#
+#
+# # mask
+# eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+# mask = torch.ones_like(exp_k_f) - eye
+#
+# # S_f
+# masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+# s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+#
+# # K_theta
+# embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+# embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+# embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)
+# k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))
+# exp_k_theta = torch.exp(k_theta)
+#
+# # S_theta
+# masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye
+# s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+#
+# # loss
+# loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+#
+# return loss
+
+class CrossPixelSimilarityLoss():
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.01, sampling_size=512):
+ self.sigma = sigma
+ self.sampling_size = sampling_size
+ self.epsilon = 1.0e-15
+        self.embed_norm = True  # note: the loss does not decrease regardless of whether this is True or False.
+
+ def __call__(self, embeddings, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
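+        # Idea: build a pairwise similarity kernel over sampled flow vectors (S_f)
+        # and use it as a soft target for the embeddings' pairwise similarity
+        # (S_theta) through a cross-entropy-style objective.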
+ assert flows.size(1) == 2
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # embedding normalization
+ if self.embed_norm:
+ embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+ # K_theta
+ embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+ embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+ embeddings_sample_norm = torch.norm(embeddings_sample, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embeddings_sample.permute(0, 2, 1), embeddings_sample)) / (self.epsilon + torch.matmul(embeddings_sample_norm.permute(0, 2, 1), embeddings_sample_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+class CrossPixelSimilarityFullLoss():
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.01):
+ self.sigma = sigma
+ self.epsilon = 1.0e-15
+        self.embed_norm = True  # note: the loss does not decrease regardless of whether this is True or False.
+
+ def __call__(self, embeddings, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
+ assert flows.size(1) == 2
+
+ # downsample flow
+ factor = flows.shape[2] // embeddings.shape[2]
+ flows = nn.functional.avg_pool2d(flows, factor, factor)
+ assert flows.shape[2] == embeddings.shape[2]
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # embedding normalization
+ if self.embed_norm:
+ embeddings /= torch.norm(embeddings, p=2, dim=1, keepdim=True)
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ #random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ #flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_flatten, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_flatten, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+ # K_theta
+ embeddings_flatten = embeddings.view(embeddings.shape[0], embeddings.shape[1], -1)
+ #embeddings_sample = torch.index_select(embeddings_flatten, 2, random_locations)
+ embeddings_flatten_norm = torch.norm(embeddings_flatten, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embeddings_flatten.permute(0, 2, 1), embeddings_flatten)) / (self.epsilon + torch.matmul(embeddings_flatten_norm.permute(0, 2, 1), embeddings_flatten_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+def get_column(embeddings, index, full_size):
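+    # Sample every feature map in `embeddings` at the given flattened spatial
+    # indices (rescaled to each map's resolution) and concatenate the sampled
+    # vectors along the channel dimension (hyper-column features).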
+ col = []
+ for embd in embeddings:
+ ind = (index.float() / full_size * embd.size(2)).long()
+ col.append(torch.index_select(embd.view(embd.shape[0], embd.shape[1], -1), 2, ind))
+ return torch.cat(col, dim=1) # N x coldim x sparsenum
+
+class CrossPixelSimilarityColumnLoss(nn.Module):
+ '''
+ Modified from: https://github.com/lppllppl920/Challenge2018/blob/master/loss.py
+ '''
+ def __init__(self, sigma=0.0036, sampling_size=512):
+ super(CrossPixelSimilarityColumnLoss, self).__init__()
+ self.sigma = sigma
+ self.sampling_size = sampling_size
+ self.epsilon = 1.0e-15
+        self.embed_norm = True  # note: the loss does not decrease regardless of whether this is True or False.
+ self.mlp = nn.Sequential(
+ nn.Linear(96 + 96 + 384 + 256 + 4096, 256),
+ nn.ReLU(inplace=True),
+ nn.Linear(256, 16))
+
+ def forward(self, feats, flows):
+ '''
+ embedding: Variable Nx256xHxW (not hyper-column)
+ flows: Variable Nx2xHxW
+ '''
+ assert flows.size(1) == 2
+
+ # flow normalization
+ positive_mask = (flows > 0)
+ flows = -torch.clamp(torch.log(torch.abs(flows) + 1) / math.log(50. + 1), max=1.)
+ flows[positive_mask] = -flows[positive_mask]
+
+ # Spatially random sampling (512 samples)
+ flows_flatten = flows.view(flows.shape[0], 2, -1)
+ random_locations = Variable(torch.from_numpy(np.array(random.sample(range(flows_flatten.shape[2]), self.sampling_size))).long().cuda())
+ flows_sample = torch.index_select(flows_flatten, 2, random_locations)
+
+ # K_f
+ k_f = self.epsilon + torch.norm(torch.unsqueeze(flows_sample, dim=-1).permute(0, 3, 2, 1) -
+ torch.unsqueeze(flows_sample, dim=-1).permute(0, 2, 3, 1), p=2, dim=3,
+ keepdim=False) ** 2
+ exp_k_f = torch.exp(-k_f / 2. / self.sigma)
+
+
+ # mask
+ eye = Variable(torch.unsqueeze(torch.eye(k_f.shape[1]), dim=0).cuda())
+ mask = torch.ones_like(exp_k_f) - eye
+
+ # S_f
+ masked_exp_k_f = torch.mul(mask, exp_k_f) + eye
+ s_f = masked_exp_k_f / torch.sum(masked_exp_k_f, dim=1, keepdim=True)
+
+
+ # column
+ column = get_column(feats, random_locations, flows.shape[2])
+ embedding = self.mlp(column)
+ # K_theta
+ embedding_norm = torch.norm(embedding, p=2, dim=1, keepdim=True)
+ k_theta = 0.25 * (torch.matmul(embedding.permute(0, 2, 1), embedding)) / (self.epsilon + torch.matmul(embedding_norm.permute(0, 2, 1), embedding_norm))
+ exp_k_theta = torch.exp(k_theta)
+
+ # S_theta
+ masked_exp_k_theta = torch.mul(mask, exp_k_theta) + math.exp(-0.75) * eye
+ s_theta = masked_exp_k_theta / torch.sum(masked_exp_k_theta, dim=1, keepdim=True)
+
+ # loss
+ loss = -torch.mean(torch.mul(s_f, torch.log(s_theta)))
+
+ return loss
+
+
+def print_info(name, var):
+    print(name, var.size(), torch.max(var).item(), torch.min(var).item(), torch.mean(var).item())
+
+
+def MaskL1Loss(input, target, mask):
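+    # L1 distance restricted to the masked region, normalized by the number of
+    # non-zero mask entries times the channel count.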
+ input_size = input.size()
+ res = torch.sum(torch.abs(input * mask - target * mask))
+ total = torch.sum(mask).item()
+ if total > 0:
+ res = res / (total * input_size[1])
+ return res
diff --git a/models/cmp/models/__init__.py b/models/cmp/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..168e1aae54937d9c23f4f40eae871b7fd73dc5c8
--- /dev/null
+++ b/models/cmp/models/__init__.py
@@ -0,0 +1,4 @@
+from .single_stage_model import *
+from .cmp import *
+from . import modules
+from . import backbone
diff --git a/models/cmp/models/backbone/__init__.py b/models/cmp/models/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea305c40902faaf491f9ed7ca70a56c0b9ae7fb
--- /dev/null
+++ b/models/cmp/models/backbone/__init__.py
@@ -0,0 +1,2 @@
+from .resnet import *
+from .alexnet import *
diff --git a/models/cmp/models/backbone/alexnet.py b/models/cmp/models/backbone/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ac39a8d8af096e9363854a2e7d720623ecd73e
--- /dev/null
+++ b/models/cmp/models/backbone/alexnet.py
@@ -0,0 +1,83 @@
+import torch.nn as nn
+import math
+
+class AlexNetBN_FCN(nn.Module):
+
+ def __init__(self, output_dim=256, stride=[4, 2, 2, 2], dilation=[1, 1], padding=[1, 1]):
+ super(AlexNetBN_FCN, self).__init__()
+ BN = nn.BatchNorm2d
+
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(3, 96, kernel_size=11, stride=stride[0], padding=5),
+ BN(96),
+ nn.ReLU(inplace=True))
+ self.pool1 = nn.MaxPool2d(kernel_size=3, stride=stride[1], padding=1)
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(96, 256, kernel_size=5, padding=2),
+ BN(256),
+ nn.ReLU(inplace=True))
+ self.pool2 = nn.MaxPool2d(kernel_size=3, stride=stride[2], padding=1)
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(256, 384, kernel_size=3, padding=1),
+ BN(384),
+ nn.ReLU(inplace=True))
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(384, 384, kernel_size=3, padding=padding[0], dilation=dilation[0]),
+ BN(384),
+ nn.ReLU(inplace=True))
+ self.conv5 = nn.Sequential(
+ nn.Conv2d(384, 256, kernel_size=3, padding=padding[1], dilation=dilation[1]),
+ BN(256),
+ nn.ReLU(inplace=True))
+ self.pool5 = nn.MaxPool2d(kernel_size=3, stride=stride[3], padding=1)
+
+ self.fc6 = nn.Sequential(
+ nn.Conv2d(256, 4096, kernel_size=3, stride=1, padding=1),
+ BN(4096),
+ nn.ReLU(inplace=True))
+ self.drop6 = nn.Dropout(0.5)
+ self.fc7 = nn.Sequential(
+ nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0),
+ BN(4096),
+ nn.ReLU(inplace=True))
+ self.drop7 = nn.Dropout(0.5)
+ self.conv8 = nn.Conv2d(4096, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def forward(self, x, ret_feat=False):
+ if ret_feat:
+            raise NotImplementedError
+ x = self.conv1(x)
+ x = self.pool1(x)
+ x = self.conv2(x)
+ x = self.pool2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.conv5(x)
+ x = self.pool5(x)
+ x = self.fc6(x)
+ x = self.drop6(x)
+ x = self.fc7(x)
+ x = self.drop7(x)
+ x = self.conv8(x)
+ return x
+
+def alexnet_fcn_32x(output_dim, pretrained=False, **kwargs):
+ assert pretrained == False
+ model = AlexNetBN_FCN(output_dim=output_dim, **kwargs)
+ return model
+
+def alexnet_fcn_8x(output_dim, use_ppm=False, pretrained=False, **kwargs):
+ assert pretrained == False
+ model = AlexNetBN_FCN(output_dim=output_dim, stride=[2, 2, 2, 1], **kwargs)
+ return model
diff --git a/models/cmp/models/backbone/resnet.py b/models/cmp/models/backbone/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..126ef386b7abba0ff9d09b2f051494ada0cfab30
--- /dev/null
+++ b/models/cmp/models/backbone/resnet.py
@@ -0,0 +1,201 @@
+import torch.nn as nn
+import math
+import torch.utils.model_zoo as model_zoo
+
+BN = None
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = BN(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = BN(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BN(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = BN(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BN(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(self, output_dim, block, layers):
+
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.inplanes = 64
+ super(ResNet, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = BN(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        self.conv5 = nn.Conv2d(512 * block.expansion, output_dim, kernel_size=1)  # 2048 channels for Bottleneck variants
+
+        ## dilation: run layer3 and layer4 at stride 1 with dilated 3x3 convolutions so the output stays at 1/8 resolution
+ for n, m in self.layer3.named_modules():
+ if 'conv2' in n:
+ m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
+ elif 'downsample.0' in n:
+ m.stride = (1, 1)
+ for n, m in self.layer4.named_modules():
+ if 'conv2' in n:
+ m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
+ elif 'downsample.0' in n:
+ m.stride = (1, 1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ BN(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, img, ret_feat=False):
+ x = self.conv1(img) # 1/2
+ x = self.bn1(x)
+ conv1 = self.relu(x) # 1/2
+ pool1 = self.maxpool(conv1) # 1/4
+
+ layer1 = self.layer1(pool1) # 1/4
+ layer2 = self.layer2(layer1) # 1/8
+ layer3 = self.layer3(layer2) # 1/8
+ layer4 = self.layer4(layer3) # 1/8
+ out = self.conv5(layer4)
+
+ if ret_feat:
+ return out, [img, conv1, layer1] # 3, 64, 256
+ else:
+ return out
+
+def resnet18(output_dim, pretrained=False):
+ model = ResNet(output_dim, BasicBlock, [2, 2, 2, 2])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
+ return model
+
+
+def resnet34(output_dim, pretrained=False):
+ model = ResNet(output_dim, BasicBlock, [3, 4, 6, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
+ return model
+
+
+def resnet50(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 4, 6, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
+ return model
+
+def resnet101(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 4, 23, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
+ return model
+
+
+def resnet152(output_dim, pretrained=False):
+ model = ResNet(output_dim, Bottleneck, [3, 8, 36, 3])
+ if pretrained:
+ model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
+ return model
diff --git a/models/cmp/models/cmp.py b/models/cmp/models/cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..11987b4b7c14a2a2e7a2ad01b34e84f7f77bc03f
--- /dev/null
+++ b/models/cmp/models/cmp.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+
+import models.cmp.losses as losses
+import models.cmp.utils as utils
+
+from . import SingleStageModel
+
+class CMP(SingleStageModel):
+
+ def __init__(self, params, dist_model=False):
+ super(CMP, self).__init__(params, dist_model)
+ model_params = params['module']
+
+ # define loss
+ if model_params['flow_criterion'] == 'L1':
+ self.flow_criterion = nn.SmoothL1Loss()
+ elif model_params['flow_criterion'] == 'L2':
+ self.flow_criterion = nn.MSELoss()
+ elif model_params['flow_criterion'] == 'DiscreteLoss':
+ self.flow_criterion = losses.DiscreteLoss(
+ nbins=model_params['nbins'], fmax=model_params['fmax'])
+ else:
+ raise Exception("No such flow loss: {}".format(model_params['flow_criterion']))
+
+ self.fuser = utils.Fuser(nbins=model_params['nbins'],
+ fmax=model_params['fmax'])
+ self.model_params = model_params
+
+ def eval(self, ret_loss=True):
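+        # Run the network without gradients; for DiscreteLoss, fuse the discrete
+        # bin logits back into a continuous flow field and upsample it to the
+        # input resolution before returning the visualization tensors.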
+ with torch.no_grad():
+ cmp_output = self.model(self.image_input, self.sparse_input)
+ if self.model_params['flow_criterion'] == "DiscreteLoss":
+ self.flow = self.fuser.convert_flow(cmp_output)
+ else:
+ self.flow = cmp_output
+ if self.flow.shape[2] != self.image_input.shape[2]:
+ self.flow = nn.functional.interpolate(
+ self.flow, size=self.image_input.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ ret_tensors = {
+ 'flow_tensors': [self.flow, self.flow_target],
+ 'common_tensors': [],
+ 'rgb_tensors': []} # except for image_input
+
+ if ret_loss:
+ if cmp_output.shape[2] != self.flow_target.shape[2]:
+ cmp_output = nn.functional.interpolate(
+ cmp_output, size=self.flow_target.shape[2:4],
+ mode="bilinear", align_corners=True)
+ loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size
+ return ret_tensors, {'loss_flow': loss_flow}
+ else:
+ return ret_tensors
+
+ def step(self):
+ cmp_output = self.model(self.image_input, self.sparse_input)
+ loss_flow = self.flow_criterion(cmp_output, self.flow_target) / self.world_size
+ self.optim.zero_grad()
+ loss_flow.backward()
+ utils.average_gradients(self.model)
+ self.optim.step()
+ return {'loss_flow': loss_flow}
diff --git a/models/cmp/models/modules/__init__.py b/models/cmp/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff11cb76475f299be4ce9641182686866e00f99
--- /dev/null
+++ b/models/cmp/models/modules/__init__.py
@@ -0,0 +1,6 @@
+from .warp import *
+from .others import *
+from .shallownet import *
+from .decoder import *
+from .cmp import *
+
diff --git a/models/cmp/models/modules/cmp.py b/models/cmp/models/modules/cmp.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c5130353000c6f971425a37c2588e45d8710664
--- /dev/null
+++ b/models/cmp/models/modules/cmp.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+import models.cmp.models as models
+
+
+class CMP(nn.Module):
+
+ def __init__(self, params):
+ super(CMP, self).__init__()
+ img_enc_dim = params['img_enc_dim']
+ sparse_enc_dim = params['sparse_enc_dim']
+ output_dim = params['output_dim']
+ pretrained = params['pretrained_image_encoder']
+ decoder_combo = params['decoder_combo']
+ self.skip_layer = params['skip_layer']
+ if self.skip_layer:
+ assert params['flow_decoder'] == "MotionDecoderSkipLayer"
+
+ self.image_encoder = models.backbone.__dict__[params['image_encoder']](
+ img_enc_dim, pretrained)
+ self.flow_encoder = models.modules.__dict__[params['sparse_encoder']](
+ sparse_enc_dim)
+ self.flow_decoder = models.modules.__dict__[params['flow_decoder']](
+ input_dim=img_enc_dim+sparse_enc_dim,
+ output_dim=output_dim, combo=decoder_combo)
+
+ def forward(self, image, sparse):
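+        # Encode the sparse motion guidance and the image separately, concatenate
+        # the two feature maps along the channel dimension, and decode them into
+        # dense flow (optionally fusing encoder skip features).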
+ sparse_enc = self.flow_encoder(sparse)
+ if self.skip_layer:
+ img_enc, skip_feat = self.image_encoder(image, ret_feat=True)
+ flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1), skip_feat)
+ else:
+ img_enc = self.image_encoder(image)
+ flow_dec = self.flow_decoder(torch.cat((img_enc, sparse_enc), dim=1))
+ return flow_dec
+
+
diff --git a/models/cmp/models/modules/decoder.py b/models/cmp/models/modules/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f1c0e395f55f4e348410c98d2d37a13441d7139
--- /dev/null
+++ b/models/cmp/models/modules/decoder.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+import math
+
+class MotionDecoderPlain(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4]):
+ super(MotionDecoderPlain, self).__init__()
+ BN = nn.BatchNorm2d
+
+ self.combo = combo
+ for c in combo:
+ assert c in [1,2,4,8], "invalid combo: {}".format(combo)
+
+ if 1 in combo:
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 2 in combo:
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 4 in combo:
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ if 8 in combo:
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.head = nn.Conv2d(128 * len(self.combo), output_dim, kernel_size=1, padding=0)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+                if m.weight is not None:
+                    m.weight.data.fill_(1)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+ def forward(self, x):
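+        # Run the enabled pyramid branches (identity / 2x / 4x / 8x downsampling),
+        # upsample each output back to the input resolution, concatenate them,
+        # and predict the flow with a 1x1 head.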
+
+ cat_list = []
+ if 1 in self.combo:
+ x1 = self.decoder1(x)
+ cat_list.append(x1)
+ if 2 in self.combo:
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x2)
+ if 4 in self.combo:
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x4)
+ if 8 in self.combo:
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x.size(2), x.size(3)),
+ mode="bilinear", align_corners=True)
+ cat_list.append(x8)
+
+ cat = torch.cat(cat_list, dim=1)
+ flow = self.head(cat)
+ return flow
+
+
+class MotionDecoderSkipLayer(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
+ super(MotionDecoderSkipLayer, self).__init__()
+
+ BN = nn.BatchNorm2d
+
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.fusion8 = nn.Sequential(
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ BN(256),
+ nn.ReLU(inplace=True))
+
+ self.skipconv4 = nn.Sequential(
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+ self.fusion4 = nn.Sequential(
+ nn.Conv2d(256 + 128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.skipconv2 = nn.Sequential(
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
+ BN(32),
+ nn.ReLU(inplace=True))
+ self.fusion2 = nn.Sequential(
+ nn.Conv2d(128 + 32, 64, kernel_size=3, padding=1),
+ BN(64),
+ nn.ReLU(inplace=True))
+
+ self.head = nn.Conv2d(64, output_dim, kernel_size=1, padding=0)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if m.weight is not None:
+ m.weight.data.fill_(1)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, skip_feat):
+ layer1, layer2, layer4 = skip_feat
+
+ x1 = self.decoder1(x)
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ cat = torch.cat([x1, x2, x4, x8], dim=1)
+ f8 = self.fusion8(cat)
+
+ f8_up = nn.functional.interpolate(
+ f8, size=(layer4.size(2), layer4.size(3)),
+ mode="bilinear", align_corners=True)
+ f4 = self.fusion4(torch.cat([f8_up, self.skipconv4(layer4)], dim=1))
+
+ f4_up = nn.functional.interpolate(
+ f4, size=(layer2.size(2), layer2.size(3)),
+ mode="bilinear", align_corners=True)
+ f2 = self.fusion2(torch.cat([f4_up, self.skipconv2(layer2)], dim=1))
+
+ flow = self.head(f2)
+ return flow
+
+
+class MotionDecoderFlowNet(nn.Module):
+
+ def __init__(self, input_dim=512, output_dim=2, combo=[1,2,4,8]):
+ super(MotionDecoderFlowNet, self).__init__()
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder2 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=2, stride=2),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder4 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=4, stride=4),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.decoder8 = nn.Sequential(
+ nn.MaxPool2d(kernel_size=8, stride=8),
+ nn.Conv2d(input_dim, 128, kernel_size=3, padding=1, stride=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
+ BN(128),
+ nn.ReLU(inplace=True))
+
+ self.fusion8 = nn.Sequential(
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ BN(256),
+ nn.ReLU(inplace=True))
+
+ # flownet head
+ self.predict_flow8 = predict_flow(256, output_dim)
+ self.predict_flow4 = predict_flow(384 + output_dim, output_dim)
+ self.predict_flow2 = predict_flow(192 + output_dim, output_dim)
+ self.predict_flow1 = predict_flow(67 + output_dim, output_dim)
+
+ self.upsampled_flow8_to_4 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+ self.upsampled_flow4_to_2 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+ self.upsampled_flow2_to_1 = nn.ConvTranspose2d(
+ output_dim, output_dim, 4, 2, 1, bias=False)
+
+ self.deconv8 = deconv(256, 128)
+ self.deconv4 = deconv(384 + output_dim, 128)
+ self.deconv2 = deconv(192 + output_dim, 64)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if m.weight is not None:
+ m.weight.data.fill_(1)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, skip_feat):
+ layer1, layer2, layer4 = skip_feat # 3, 64, 256
+
+ # propagation nets
+ x1 = self.decoder1(x)
+ x2 = nn.functional.interpolate(
+ self.decoder2(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x4 = nn.functional.interpolate(
+ self.decoder4(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ x8 = nn.functional.interpolate(
+ self.decoder8(x), size=(x1.size(2), x1.size(3)),
+ mode="bilinear", align_corners=True)
+ cat = torch.cat([x1, x2, x4, x8], dim=1)
+ feat8 = self.fusion8(cat) # 256
+
+ # flownet head
+ flow8 = self.predict_flow8(feat8)
+ flow8_up = self.upsampled_flow8_to_4(flow8)
+ out_deconv8 = self.deconv8(feat8) # 128
+
+ concat4 = torch.cat((layer4, out_deconv8, flow8_up), dim=1) # 384 + out
+ flow4 = self.predict_flow4(concat4)
+ flow4_up = self.upsampled_flow4_to_2(flow4)
+ out_deconv4 = self.deconv4(concat4) # 128
+
+ concat2 = torch.cat((layer2, out_deconv4, flow4_up), dim=1) # 192 + out
+ flow2 = self.predict_flow2(concat2)
+ flow2_up = self.upsampled_flow2_to_1(flow2)
+ out_deconv2 = self.deconv2(concat2) # 64
+
+ concat1 = torch.cat((layer1, out_deconv2, flow2_up), dim=1) # 67 + out
+ flow1 = self.predict_flow1(concat1)
+
+ return [flow1, flow2, flow4, flow8]
+
+
+def predict_flow(in_planes, out_planes):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+ stride=1, padding=1, bias=True)
+
+
+def deconv(in_planes, out_planes):
+ return nn.Sequential(
+ nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4,
+ stride=2, padding=1, bias=True),
+ nn.LeakyReLU(0.1, inplace=True)
+ )
+
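+# Note (added comment, not in the original CMP release): MotionDecoderFlowNet predicts flow in a
+# coarse-to-fine manner. Each stride-2 ConvTranspose2d above doubles the spatial resolution, so
+# the returned list [flow1, flow2, flow4, flow8] is ordered from the finest prediction (flow1,
+# at the resolution of skip feature layer1) down to the coarsest (flow8).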
+
diff --git a/models/cmp/models/modules/others.py b/models/cmp/models/modules/others.py
new file mode 100644
index 0000000000000000000000000000000000000000..591ce94f7d10db49fb3209d4d74a4a973f9a6cf5
--- /dev/null
+++ b/models/cmp/models/modules/others.py
@@ -0,0 +1,11 @@
+import torch.nn as nn
+
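+# Descriptive note (added): FixModule wraps a model so that single-GPU code can reach the
+# underlying network via `.module`, mirroring the interface of the distributed wrapper
+# (see DistModule in models/cmp/utils/distributed_utils.py) used by SingleStageModel.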
+class FixModule(nn.Module):
+
+ def __init__(self, m):
+ super(FixModule, self).__init__()
+ self.module = m
+
+ def forward(self, *args, **kwargs):
+ return self.module(*args, **kwargs)
+
diff --git a/models/cmp/models/modules/shallownet.py b/models/cmp/models/modules/shallownet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37fedd26b5096e34c0e6303f69e54b3d58c39b4
--- /dev/null
+++ b/models/cmp/models/modules/shallownet.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+import math
+
+class ShallowNet(nn.Module):
+
+ def __init__(self, input_dim=4, output_dim=16, stride=[2, 2, 2]):
+ super(ShallowNet, self).__init__()
+ global BN
+
+ BN = nn.BatchNorm2d
+
+ self.features = nn.Sequential(
+ nn.Conv2d(input_dim, 16, kernel_size=5, stride=stride[0], padding=2),
+ nn.BatchNorm2d(16),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=stride[1], stride=stride[1]),
+ nn.Conv2d(16, output_dim, kernel_size=3, padding=1),
+ nn.BatchNorm2d(output_dim),
+ nn.ReLU(inplace=True),
+ nn.AvgPool2d(kernel_size=stride[2], stride=stride[2]),
+ )
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ fan_in = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
+ scale = math.sqrt(2. / fan_in)
+ m.weight.data.uniform_(-scale, scale)
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ if m.weight is not None:
+ m.weight.data.fill_(1)
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ x = self.features(x)
+ return x
+
+
+def shallownet8x(output_dim):
+ model = ShallowNet(output_dim=output_dim, stride=[2,2,2])
+ return model
+
+def shallownet32x(output_dim, **kwargs):
+ model = ShallowNet(output_dim=output_dim, stride=[2,2,8])
+ return model
+
+
+
diff --git a/models/cmp/models/modules/warp.py b/models/cmp/models/modules/warp.py
new file mode 100644
index 0000000000000000000000000000000000000000..d32dc5db787345c9d2622fa6f65d463dd78ef8ba
--- /dev/null
+++ b/models/cmp/models/modules/warp.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+
+class WarpingLayerBWFlow(nn.Module):
+
+ def __init__(self):
+ super(WarpingLayerBWFlow, self).__init__()
+
+ def forward(self, image, flow):
+ # normalize the flow to the [-1, 1] coordinate range expected by grid_sample
+ flow_for_grid = torch.zeros_like(flow)
+ flow_for_grid[:,0,:,:] = flow[:,0,:,:] / ((flow.size(3) - 1.0) / 2.0)
+ flow_for_grid[:,1,:,:] = flow[:,1,:,:] / ((flow.size(2) - 1.0) / 2.0)
+
+ torchHorizontal = torch.linspace(
+ -1.0, 1.0, image.size(3)).view(
+ 1, 1, 1, image.size(3)).expand(
+ image.size(0), 1, image.size(2), image.size(3))
+ torchVertical = torch.linspace(
+ -1.0, 1.0, image.size(2)).view(
+ 1, 1, image.size(2), 1).expand(
+ image.size(0), 1, image.size(2), image.size(3))
+ grid = torch.cat([torchHorizontal, torchVertical], 1).cuda()
+
+ grid = (grid + flow_for_grid).permute(0, 2, 3, 1)
+ return torch.nn.functional.grid_sample(image, grid)
+
+
+class WarpingLayerFWFlow(nn.Module):
+
+ def __init__(self):
+ super(WarpingLayerFWFlow, self).__init__()
+ self.initialized = False
+
+ def forward(self, image, flow, ret_mask = False):
+ n, h, w = image.size(0), image.size(2), image.size(3)
+
+ if not self.initialized or n != self.meshx.shape[0] or h * w != self.meshx.shape[1]:
+ self.meshx = torch.arange(w).view(1, 1, w).expand(
+ n, h, w).contiguous().view(n, -1).cuda()
+ self.meshy = torch.arange(h).view(1, h, 1).expand(
+ n, h, w).contiguous().view(n, -1).cuda()
+ self.warped_image = torch.zeros((n, 3, h, w), dtype=torch.float32).cuda()
+ if ret_mask:
+ self.hole_mask = torch.ones((n, 1, h, w), dtype=torch.float32).cuda()
+ self.initialized = True
+
+ v = (flow[:,0,:,:] ** 2 + flow[:,1,:,:] ** 2).view(n, -1)
+ _, sortidx = torch.sort(v, dim=1)
+
+ warped_meshx = self.meshx + flow[:,0,:,:].long().view(n, -1)
+ warped_meshy = self.meshy + flow[:,1,:,:].long().view(n, -1)
+
+ warped_meshx = torch.clamp(warped_meshx, 0, w - 1)
+ warped_meshy = torch.clamp(warped_meshy, 0, h - 1)
+
+ self.warped_image.zero_()
+ if ret_mask:
+ self.hole_mask.fill_(1.)
+ for i in range(n):
+ for c in range(3):
+ ind = sortidx[i]
+ self.warped_image[i,c,warped_meshy[i][ind],warped_meshx[i][ind]] = image[i,c,self.meshy[i][ind],self.meshx[i][ind]]
+ if ret_mask:
+ self.hole_mask[i,0,warped_meshy[i],warped_meshx[i]] = 0.
+ if ret_mask:
+ return self.warped_image, self.hole_mask
+ else:
+ return self.warped_image
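+# Usage sketch (illustrative; the layer requires CUDA tensors because of the .cuda() calls above):
+#   warper = WarpingLayerFWFlow()
+#   warped, holes = warper(image, flow, ret_mask=True)
+# Forward warping splats each source pixel to its flow target; pixels are written in order of
+# increasing flow magnitude so that larger motions tend to win when several sources land on the
+# same location, and `holes` stays 1 wherever no source pixel arrived.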
diff --git a/models/cmp/models/single_stage_model.py b/models/cmp/models/single_stage_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79f5d4ab7ccffba99f72612ad6c77bf4dc3f2521
--- /dev/null
+++ b/models/cmp/models/single_stage_model.py
@@ -0,0 +1,72 @@
+import os
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+import models.cmp.models as models
+import models.cmp.utils as utils
+
+
+class SingleStageModel(object):
+
+ def __init__(self, params, dist_model=False):
+ model_params = params['module']
+ self.model = models.modules.__dict__[params['module']['arch']](model_params)
+ utils.init_weights(self.model, init_type='xavier')
+ self.model.cuda()
+ if dist_model:
+ self.model = utils.DistModule(self.model)
+ self.world_size = dist.get_world_size()
+ else:
+ self.model = models.modules.FixModule(self.model)
+ self.world_size = 1
+
+ if params['optim'] == 'SGD':
+ self.optim = torch.optim.SGD(
+ self.model.parameters(), lr=params['lr'],
+ momentum=0.9, weight_decay=0.0001)
+ elif params['optim'] == 'Adam':
+ self.optim = torch.optim.Adam(
+ self.model.parameters(), lr=params['lr'],
+ betas=(params['beta1'], 0.999))
+ else:
+ raise Exception("No such optimizer: {}".format(params['optim']))
+
+ cudnn.benchmark = True
+
+ def set_input(self, image_input, sparse_input, flow_target=None, rgb_target=None):
+ self.image_input = image_input
+ self.sparse_input = sparse_input
+ self.flow_target = flow_target
+ self.rgb_target = rgb_target
+
+ def eval(self, ret_loss=True):
+ pass
+
+ def step(self):
+ pass
+
+ def load_state(self, path, Iter, resume=False):
+ path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter))
+
+ if resume:
+ utils.load_state(path, self.model, self.optim)
+ else:
+ utils.load_state(path, self.model)
+
+ def load_pretrain(self, load_path):
+ utils.load_state(load_path, self.model)
+
+ def save_state(self, path, Iter):
+ path = os.path.join(path, "ckpt_iter_{}.pth.tar".format(Iter))
+
+ torch.save({
+ 'step': Iter,
+ 'state_dict': self.model.state_dict(),
+ 'optimizer': self.optim.state_dict()}, path)
+
+ def switch_to(self, phase):
+ if phase == 'train':
+ self.model.train()
+ else:
+ self.model.eval()
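+# Descriptive note (added): SingleStageModel is a thin training/inference wrapper around one CMP
+# network. It builds the module named in params['module']['arch'], attaches an SGD or Adam
+# optimizer, and exposes load_state / save_state helpers for ckpt_iter_*.pth.tar checkpoints.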
diff --git a/models/cmp/utils/__init__.py b/models/cmp/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..29be9c14049e0540b324db3fc65eedf1b492358e
--- /dev/null
+++ b/models/cmp/utils/__init__.py
@@ -0,0 +1,6 @@
+from .common_utils import *
+from .data_utils import *
+from .distributed_utils import *
+from .visualize_utils import *
+from .scheduler import *
+from . import flowlib
diff --git a/models/cmp/utils/__pycache__/__init__.cpython-310.pyc b/models/cmp/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5872018627bb22fd3215842559fbaa93afd8ce4e
Binary files /dev/null and b/models/cmp/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/__init__.cpython-38.pyc b/models/cmp/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de102f03a7a6867f25b01248be7b89f945b4c566
Binary files /dev/null and b/models/cmp/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/__init__.cpython-39.pyc b/models/cmp/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99e5558aa4b180da35343f1239ddab5f38a77077
Binary files /dev/null and b/models/cmp/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/common_utils.cpython-310.pyc b/models/cmp/utils/__pycache__/common_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2522ee55cc0c2dceb9c8b51b6d5e718095c19bfa
Binary files /dev/null and b/models/cmp/utils/__pycache__/common_utils.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/common_utils.cpython-38.pyc b/models/cmp/utils/__pycache__/common_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9eee12ec621d4270a048b28d91284d43e84ad044
Binary files /dev/null and b/models/cmp/utils/__pycache__/common_utils.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/common_utils.cpython-39.pyc b/models/cmp/utils/__pycache__/common_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f1f12bfb246ddd2a7bc0e5ed5806cca9ab739c54
Binary files /dev/null and b/models/cmp/utils/__pycache__/common_utils.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/data_utils.cpython-310.pyc b/models/cmp/utils/__pycache__/data_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e57d83cd45f4a05a1933e19641f9ccc300c6007
Binary files /dev/null and b/models/cmp/utils/__pycache__/data_utils.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/data_utils.cpython-38.pyc b/models/cmp/utils/__pycache__/data_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f54d673f0e39f5d705294c526d27c20ae9669d68
Binary files /dev/null and b/models/cmp/utils/__pycache__/data_utils.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/data_utils.cpython-39.pyc b/models/cmp/utils/__pycache__/data_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..465e8f7c943f32423fa98fbb5181302f7d0690a3
Binary files /dev/null and b/models/cmp/utils/__pycache__/data_utils.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/distributed_utils.cpython-310.pyc b/models/cmp/utils/__pycache__/distributed_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ab6a077f6d359e4d441d5068ce82ae1f5686515
Binary files /dev/null and b/models/cmp/utils/__pycache__/distributed_utils.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/distributed_utils.cpython-38.pyc b/models/cmp/utils/__pycache__/distributed_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a568804f27744cc22698a76f44b9dc27ef56d20
Binary files /dev/null and b/models/cmp/utils/__pycache__/distributed_utils.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/distributed_utils.cpython-39.pyc b/models/cmp/utils/__pycache__/distributed_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2fb989ee19f9cb4b9535e877cac3615124aa0385
Binary files /dev/null and b/models/cmp/utils/__pycache__/distributed_utils.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/flowlib.cpython-310.pyc b/models/cmp/utils/__pycache__/flowlib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d25af9b518c4e61bb70e3866c8a87bebab9d4608
Binary files /dev/null and b/models/cmp/utils/__pycache__/flowlib.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/flowlib.cpython-38.pyc b/models/cmp/utils/__pycache__/flowlib.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..04bc3affab11b62af84636b4c27c6fe3a2cde805
Binary files /dev/null and b/models/cmp/utils/__pycache__/flowlib.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/flowlib.cpython-39.pyc b/models/cmp/utils/__pycache__/flowlib.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbbd666cb4db9c0ba6db201f7039833197905dae
Binary files /dev/null and b/models/cmp/utils/__pycache__/flowlib.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/scheduler.cpython-310.pyc b/models/cmp/utils/__pycache__/scheduler.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1bd1b42a3ec0306854bf3b25cdb19a8fb8d2e2f7
Binary files /dev/null and b/models/cmp/utils/__pycache__/scheduler.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/scheduler.cpython-38.pyc b/models/cmp/utils/__pycache__/scheduler.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3307fd8ec223b1f045791e37695ffd0fc53c9fa
Binary files /dev/null and b/models/cmp/utils/__pycache__/scheduler.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/scheduler.cpython-39.pyc b/models/cmp/utils/__pycache__/scheduler.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d7268c2b21201cfb84d5b93eff29d01a3df7608
Binary files /dev/null and b/models/cmp/utils/__pycache__/scheduler.cpython-39.pyc differ
diff --git a/models/cmp/utils/__pycache__/visualize_utils.cpython-310.pyc b/models/cmp/utils/__pycache__/visualize_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee676a2714e864c62795db6b2854d6a95c8d1aca
Binary files /dev/null and b/models/cmp/utils/__pycache__/visualize_utils.cpython-310.pyc differ
diff --git a/models/cmp/utils/__pycache__/visualize_utils.cpython-38.pyc b/models/cmp/utils/__pycache__/visualize_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5cbc7cfb8617bf57cfbea9bf2c7d6fcd4da7f873
Binary files /dev/null and b/models/cmp/utils/__pycache__/visualize_utils.cpython-38.pyc differ
diff --git a/models/cmp/utils/__pycache__/visualize_utils.cpython-39.pyc b/models/cmp/utils/__pycache__/visualize_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..50c1e38846ef99069b905ed9f2932dc600ba62f7
Binary files /dev/null and b/models/cmp/utils/__pycache__/visualize_utils.cpython-39.pyc differ
diff --git a/models/cmp/utils/common_utils.py b/models/cmp/utils/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3862068c32b5094b7ee3caa045d250c1d63264
--- /dev/null
+++ b/models/cmp/utils/common_utils.py
@@ -0,0 +1,118 @@
+import os
+import logging
+import numpy as np
+
+import torch
+from torch.nn import init
+
+def init_weights(net, init_type='normal', init_gain=0.02):
+ """Initialize network weights.
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ net.apply(init_func) # apply the initialization function
+
+def create_logger(name, log_file, level=logging.INFO):
+ l = logging.getLogger(name)
+ formatter = logging.Formatter('[%(asctime)s] %(message)s')
+ fh = logging.FileHandler(log_file)
+ fh.setFormatter(formatter)
+ sh = logging.StreamHandler()
+ sh.setFormatter(formatter)
+ l.setLevel(level)
+ l.addHandler(fh)
+ l.addHandler(sh)
+ return l
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self, length=0):
+ self.length = length
+ self.reset()
+
+ def reset(self):
+ if self.length > 0:
+ self.history = []
+ else:
+ self.count = 0
+ self.sum = 0.0
+ self.val = 0.0
+ self.avg = 0.0
+
+ def update(self, val):
+ if self.length > 0:
+ self.history.append(val)
+ if len(self.history) > self.length:
+ del self.history[0]
+
+ self.val = self.history[-1]
+ self.avg = np.mean(self.history)
+ else:
+ self.val = val
+ self.sum += val
+ self.count += 1
+ self.avg = self.sum / self.count
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the precision@k for the specified values of k"""
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].view(-1).float().sum(0, keepdims=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+def load_state(path, model, optimizer=None):
+ def map_func(storage, location):
+ return storage.cuda()
+ if os.path.isfile(path):
+ print("=> loading checkpoint '{}'".format(path))
+ checkpoint = torch.load(path, map_location=map_func)
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ ckpt_keys = set(checkpoint['state_dict'].keys())
+ own_keys = set(model.state_dict().keys())
+ missing_keys = own_keys - ckpt_keys
+ # print(ckpt_keys)
+ # print(own_keys)
+ for k in missing_keys:
+ print('caution: missing keys from checkpoint {}: {}'.format(path, k))
+
+ last_iter = checkpoint['step']
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ print("=> also loaded optimizer from checkpoint '{}' (iter {})"
+ .format(path, last_iter))
+ return last_iter
+ else:
+ print("=> no checkpoint found at '{}'".format(path))
+
+
diff --git a/models/cmp/utils/data_utils.py b/models/cmp/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0651fc5d9fefa638c76f86ebdb9db3139fecf991
--- /dev/null
+++ b/models/cmp/utils/data_utils.py
@@ -0,0 +1,280 @@
+from PIL import Image, ImageOps
+import scipy.ndimage as ndimage
+import cv2
+import random
+import numpy as np
+from scipy.ndimage import maximum_filter
+from scipy import signal
+cv2.ocl.setUseOpenCL(False)
+
+def get_edge(data, blur=False):
+ if blur:
+ data = cv2.GaussianBlur(data, (3, 3), 1.)
+ sobel = np.array([[1,0,-1],[2,0,-2],[1,0,-1]]).astype(np.float32)
+ ch_edges = []
+ for k in range(data.shape[2]):
+ edgex = signal.convolve2d(data[:,:,k], sobel, boundary='symm', mode='same')
+ edgey = signal.convolve2d(data[:,:,k], sobel.T, boundary='symm', mode='same')
+ ch_edges.append(np.sqrt(edgex**2 + edgey**2))
+ return sum(ch_edges)
+
+def get_max(score, bbox):
+ u = max(0, bbox[0])
+ d = min(score.shape[0], bbox[1])
+ l = max(0, bbox[2])
+ r = min(score.shape[1], bbox[3])
+ return score[u:d,l:r].max()
+
+def nms(score, ks):
+ assert ks % 2 == 1
+ ret_score = score.copy()
+ maxpool = maximum_filter(score, footprint=np.ones((ks, ks)))
+ ret_score[score < maxpool] = 0.
+ return ret_score
+
+def image_flow_crop(img1, img2, flow, crop_size, phase):
+ assert len(crop_size) == 2
+ pad_h = max(crop_size[0] - img1.height, 0)
+ pad_w = max(crop_size[1] - img1.width, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+ if pad_h > 0 or pad_w > 0:
+ flow_expand = np.zeros((img1.height + pad_h, img1.width + pad_w, 2), dtype=np.float32)
+ flow_expand[pad_h_half:pad_h_half+img1.height, pad_w_half:pad_w_half+img1.width, :] = flow
+ flow = flow_expand
+ border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)
+ img1 = ImageOps.expand(img1, border=border, fill=(0,0,0))
+ img2 = ImageOps.expand(img2, border=border, fill=(0,0,0))
+ if phase == 'train':
+ hoff = int(np.random.rand() * (img1.height - crop_size[0]))
+ woff = int(np.random.rand() * (img1.width - crop_size[1]))
+ else:
+ hoff = (img1.height - crop_size[0]) // 2
+ woff = (img1.width - crop_size[1]) // 2
+
+ img1 = img1.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ img2 = img2.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ flow = flow[hoff:hoff+crop_size[0], woff:woff+crop_size[1], :]
+ offset = (hoff, woff)
+ return img1, img2, flow, offset
+
+def image_crop(img, crop_size):
+ pad_h = max(crop_size[0] - img.height, 0)
+ pad_w = max(crop_size[1] - img.width, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+ if pad_h > 0 or pad_w > 0:
+ border = (pad_w_half, pad_h_half, pad_w - pad_w_half, pad_h - pad_h_half)
+ img = ImageOps.expand(img, border=border, fill=(0,0,0))
+ hoff = (img.height - crop_size[0]) // 2
+ woff = (img.width - crop_size[1]) // 2
+ return img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0])), (pad_w_half, pad_h_half)
+
+def image_flow_resize(img1, img2, flow, short_size=None, long_size=None):
+ assert (short_size is None) ^ (long_size is None)
+ w, h = img1.width, img1.height
+ if short_size is not None:
+ if w < h:
+ neww = short_size
+ newh = int(short_size / float(w) * h)
+ else:
+ neww = int(short_size / float(h) * w)
+ newh = short_size
+ else:
+ if w < h:
+ neww = int(long_size / float(h) * w)
+ newh = long_size
+ else:
+ neww = long_size
+ newh = int(long_size / float(w) * h)
+ img1 = img1.resize((neww, newh), Image.BICUBIC)
+ img2 = img2.resize((neww, newh), Image.BICUBIC)
+ ratio = float(newh) / h
+ flow = cv2.resize(flow.copy(), (neww, newh), interpolation=cv2.INTER_LINEAR) * ratio
+ return img1, img2, flow, ratio
+
+def image_resize(img, short_size=None, long_size=None):
+ assert (short_size is None) ^ (long_size is None)
+ w, h = img.width, img.height
+ if short_size is not None:
+ if w < h:
+ neww = short_size
+ newh = int(short_size / float(w) * h)
+ else:
+ neww = int(short_size / float(h) * w)
+ newh = short_size
+ else:
+ if w < h:
+ neww = int(long_size / float(h) * w)
+ newh = long_size
+ else:
+ neww = long_size
+ newh = int(long_size / float(w) * h)
+ img = img.resize((neww, newh), Image.BICUBIC)
+ return img, [w, h]
+
+
+def image_pose_crop(img, posemap, crop_size, scale):
+ assert len(crop_size) == 2
+ assert crop_size[0] <= img.height
+ assert crop_size[1] <= img.width
+ hoff = (img.height - crop_size[0]) // 2
+ woff = (img.width - crop_size[1]) // 2
+ img = img.crop((woff, hoff, woff+crop_size[1], hoff+crop_size[0]))
+ posemap = posemap[hoff//scale:hoff//scale+crop_size[0]//scale, woff//scale:woff//scale+crop_size[1]//scale,:]
+ return img, posemap
+
+def neighbor_elim(ph, pw, d):
+ valid = np.ones((len(ph))).astype(int)
+ h_dist = np.fabs(np.tile(ph[:,np.newaxis], [1,len(ph)]) - np.tile(ph.T[np.newaxis,:], [len(ph),1]))
+ w_dist = np.fabs(np.tile(pw[:,np.newaxis], [1,len(pw)]) - np.tile(pw.T[np.newaxis,:], [len(pw),1]))
+ idx1, idx2 = np.where((h_dist < d) & (w_dist < d))
+ for i,j in zip(idx1, idx2):
+ if valid[i] and valid[j] and i != j:
+ if np.random.rand() > 0.5:
+ valid[i] = 0
+ else:
+ valid[j] = 0
+ valid_idx = np.where(valid==1)
+ return ph[valid_idx], pw[valid_idx]
+
+def remove_border(mask):
+ mask[0,:] = 0
+ mask[:,0] = 0
+ mask[mask.shape[0]-1,:] = 0
+ mask[:,mask.shape[1]-1] = 0
+
+def flow_sampler(flow, strategy=['grid'], bg_ratio=1./6400, nms_ks=15, max_num_guide=-1, guidepoint=None):
+ assert 0 < bg_ratio <= 1, "sampling ratio must be in (0, 1]"
+ for s in strategy:
+ assert s in ['grid', 'uniform', 'gradnms', 'watershed', 'single', 'full', 'specified'], "No such strategy: {}".format(s)
+ h = flow.shape[0]
+ w = flow.shape[1]
+ ds = max(1, max(h, w) // 400) # reduce computation
+
+ if 'full' in strategy:
+ sparse = flow.copy()
+ mask = np.ones(flow.shape, dtype=int)
+ return sparse, mask
+
+ pts_h = []
+ pts_w = []
+ if 'grid' in strategy:
+ stride = int(np.sqrt(1./bg_ratio))
+ mesh_start_h = int((h - h // stride * stride) / 2)
+ mesh_start_w = int((w - w // stride * stride) / 2)
+ mesh = np.meshgrid(np.arange(mesh_start_h, h, stride), np.arange(mesh_start_w, w, stride))
+ pts_h.append(mesh[0].flat)
+ pts_w.append(mesh[1].flat)
+ if 'uniform' in strategy:
+ pts_h.append(np.random.randint(0, h, int(bg_ratio * h * w)))
+ pts_w.append(np.random.randint(0, w, int(bg_ratio * h * w)))
+ if "gradnms" in strategy:
+ ks = w // ds // 20
+ edge = get_edge(flow[::ds,::ds,:])
+ kernel = np.ones((ks, ks), dtype=np.float32) / (ks * ks)
+ subkernel = np.ones((ks//2, ks//2), dtype=np.float32) / (ks//2 * ks//2)
+ score = signal.convolve2d(edge, kernel, boundary='symm', mode='same')
+ subscore = signal.convolve2d(edge, subkernel, boundary='symm', mode='same')
+ score = score / score.max() - subscore / subscore.max()
+ nms_res = nms(score, nms_ks)
+ pth, ptw = np.where(nms_res > 0.1)
+ pts_h.append(pth * ds)
+ pts_w.append(ptw * ds)
+ if "watershed" in strategy:
+ edge = get_edge(flow[::ds,::ds,:])
+ edge /= max(edge.max(), 0.01)
+ edge = (edge > 0.1).astype(np.float32)
+ watershed = ndimage.distance_transform_edt(1-edge)
+ nms_res = nms(watershed, nms_ks)
+ remove_border(nms_res)
+ pth, ptw = np.where(nms_res > 0)
+ pth, ptw = neighbor_elim(pth, ptw, (nms_ks-1)/2)
+ pts_h.append(pth * ds)
+ pts_w.append(ptw * ds)
+ if "single" in strategy:
+ pth, ptw = np.where((flow[:,:,0] != 0) | (flow[:,:,1] != 0))
+ randidx = np.random.randint(len(pth))
+ pts_h.append(pth[randidx:randidx+1])
+ pts_w.append(ptw[randidx:randidx+1])
+ if 'specified' in strategy:
+ assert guidepoint is not None, "if using \"specified\", switch \"with_info\" on."
+ pts_h.append(guidepoint[:,1])
+ pts_w.append(guidepoint[:,0])
+
+ pts_h = np.concatenate(pts_h)
+ pts_w = np.concatenate(pts_w)
+
+ if max_num_guide == -1:
+ max_num_guide = np.inf
+
+ randsel = np.random.permutation(len(pts_h))[:len(pts_h)]
+ selidx = randsel[np.arange(min(max_num_guide, len(randsel)))]
+ pts_h = pts_h[selidx]
+ pts_w = pts_w[selidx]
+
+ sparse = np.zeros(flow.shape, dtype=flow.dtype)
+ mask = np.zeros(flow.shape, dtype=int)
+
+ sparse[:, :, 0][(pts_h, pts_w)] = flow[:, :, 0][(pts_h, pts_w)]
+ sparse[:, :, 1][(pts_h, pts_w)] = flow[:, :, 1][(pts_h, pts_w)]
+
+ mask[:,:,0][(pts_h, pts_w)] = 1
+ mask[:,:,1][(pts_h, pts_w)] = 1
+ return sparse, mask
+
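+# Usage sketch (illustrative values, not from the released training config): given a dense flow
+# map `flow` of shape (H, W, 2), a sparse guidance flow and its validity mask of the same shape
+# can be sampled as
+#   sparse, mask = flow_sampler(flow, strategy=['grid', 'watershed'], nms_ks=15, max_num_guide=20)
+# 'grid' places points on a regular lattice whose density is set by bg_ratio, while 'watershed'
+# picks points far away from motion boundaries detected on the flow edges.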
+def image_flow_aug(img1, img2, flow, flip_horizon=True):
+ if flip_horizon:
+ if random.random() < 0.5:
+ img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
+ img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
+ flow = flow[:,::-1,:].copy()
+ flow[:,:,0] = -flow[:,:,0]
+ return img1, img2, flow
+
+def flow_aug(flow, reverse=True, scale=True, rotate=True):
+ if reverse:
+ if random.random() < 0.5:
+ flow = -flow
+ if scale:
+ rand_scale = random.uniform(0.5, 2.0)
+ flow = flow * rand_scale
+ if rotate and random.random() < 0.5:
+ length = np.sqrt(np.square(flow[:,:,0]) + np.square(flow[:,:,1]))
+ # arctan2 keeps the quadrant and handles zero horizontal components
+ alpha = np.arctan2(flow[:,:,1], flow[:,:,0])
+ theta = random.uniform(0, np.pi*2)
+ flow[:,:,0] = length * np.cos(alpha + theta)
+ flow[:,:,1] = length * np.sin(alpha + theta)
+ return flow
+
+def draw_gaussian(img, pt, sigma, type='Gaussian'):
+ # Check that any part of the gaussian is in-bounds
+ ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
+ br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
+ if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
+ br[0] < 0 or br[1] < 0):
+ # If not, just return the image as is
+ return img
+
+ # Generate gaussian
+ size = 6 * sigma + 1
+ x = np.arange(0, size, 1, float)
+ y = x[:, np.newaxis]
+ x0 = y0 = size // 2
+ # The gaussian is not normalized, we want the center value to equal 1
+ if type == 'Gaussian':
+ g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+ elif type == 'Cauchy':
+ g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
+
+ # Usable gaussian range
+ g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
+ g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
+ # Image range
+ img_x = max(0, ul[0]), min(br[0], img.shape[1])
+ img_y = max(0, ul[1]), min(br[1], img.shape[0])
+
+ img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+ return img
+
+
diff --git a/models/cmp/utils/distributed_utils.py b/models/cmp/utils/distributed_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..97056fc313c198ea11ec96a6ad7575db5de6b302
--- /dev/null
+++ b/models/cmp/utils/distributed_utils.py
@@ -0,0 +1,229 @@
+import os
+import subprocess
+import numpy as np
+import multiprocessing as mp
+import math
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+from torch.nn import Module
+
+class DistModule(Module):
+ def __init__(self, module):
+ super(DistModule, self).__init__()
+ self.module = module
+ broadcast_params(self.module)
+ def forward(self, *inputs, **kwargs):
+ return self.module(*inputs, **kwargs)
+ def train(self, mode=True):
+ super(DistModule, self).train(mode)
+ self.module.train(mode)
+
+def average_gradients(model):
+ """ average gradients """
+ for param in model.parameters():
+ if param.requires_grad:
+ dist.all_reduce(param.grad.data)
+
+def broadcast_params(model):
+ """ broadcast model parameters """
+ for p in model.state_dict().values():
+ dist.broadcast(p, 0)
+
+def dist_init(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError('Invalid launcher type: {}'.format(launcher))
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+def _init_dist_mpi(backend, **kwargs):
+ raise NotImplementedError
+
+def _init_dist_slurm(backend, port=10086, **kwargs):
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ 'scontrol show hostname {} | head -n1'.format(node_list))
+ os.environ['MASTER_PORT'] = str(port)
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+def gather_tensors(input_array):
+ world_size = dist.get_world_size()
+ ## gather shapes first
+ myshape = input_array.shape
+ mycount = input_array.size
+ shape_tensor = torch.Tensor(np.array(myshape)).cuda()
+ all_shape = [torch.Tensor(np.array(myshape)).cuda() for i in range(world_size)]
+ dist.all_gather(all_shape, shape_tensor)
+ ## compute largest shapes
+ all_shape = [x.cpu().numpy() for x in all_shape]
+ all_count = [int(x.prod()) for x in all_shape]
+ all_shape = [list(map(int, x)) for x in all_shape]
+ max_count = max(all_count)
+ ## padding tensors and gather them
+ output_tensors = [torch.Tensor(max_count).cuda() for i in range(world_size)]
+ padded_input_array = np.zeros(max_count)
+ padded_input_array[:mycount] = input_array.reshape(-1)
+ input_tensor = torch.Tensor(padded_input_array).cuda()
+ dist.all_gather(output_tensors, input_tensor)
+ ## unpadding gathered tensors
+ padded_output = [x.cpu().numpy() for x in output_tensors]
+ output = [x[:all_count[i]].reshape(all_shape[i]) for i,x in enumerate(padded_output)]
+ return output
+
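+# Descriptive note (added): torch.distributed.all_gather requires equally sized tensors, so
+# gather_tensors first exchanges the per-rank shapes, pads every flattened array to the largest
+# element count, gathers the padded buffers, and finally crops each result back to its original
+# shape on the CPU.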
+def gather_tensors_batch(input_array, part_size=10):
+ # gather
+ rank = dist.get_rank()
+ all_features = []
+ part_num = input_array.shape[0] // part_size + 1 if input_array.shape[0] % part_size != 0 else input_array.shape[0] // part_size
+ for i in range(part_num):
+ part_feat = input_array[i * part_size:min((i+1)*part_size, input_array.shape[0]),...]
+ assert part_feat.shape[0] > 0, "rank: {}, length of part features should > 0".format(rank)
+ print("rank: {}, gather part: {}/{}, length: {}".format(rank, i, part_num, len(part_feat)))
+ gather_part_feat = gather_tensors(part_feat)
+ all_features.append(gather_part_feat)
+ print("rank: {}, gather done.".format(rank))
+ all_features = np.concatenate([np.concatenate([all_features[i][j] for i in range(part_num)], axis=0) for j in range(len(all_features[0]))], axis=0)
+ return all_features
+
+def reduce_tensors(tensor):
+ reduced_tensor = tensor.clone()
+ dist.all_reduce(reduced_tensor)
+ return reduced_tensor
+
+class DistributedSequentialSampler(Sampler):
+ def __init__(self, dataset, world_size=None, rank=None):
+ if world_size is None:
+ world_size = dist.get_world_size()
+ if rank is None:
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.world_size = world_size
+ self.rank = rank
+ assert len(self.dataset) >= self.world_size, '{} vs {}'.format(len(self.dataset), self.world_size)
+ sub_num = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
+ self.beg = sub_num * self.rank
+ #self.end = min(self.beg+sub_num, len(self.dataset))
+ self.end = self.beg + sub_num
+ self.padded_ind = list(range(len(self.dataset))) + list(range(sub_num * self.world_size - len(self.dataset)))
+
+ def __iter__(self):
+ indices = [self.padded_ind[i] for i in range(self.beg, self.end)]
+ return iter(indices)
+
+ def __len__(self):
+ return self.end - self.beg
+
+class GivenIterationSampler(Sampler):
+ def __init__(self, dataset, total_iter, batch_size, last_iter=-1):
+ self.dataset = dataset
+ self.total_iter = total_iter
+ self.batch_size = batch_size
+ self.last_iter = last_iter
+
+ self.total_size = self.total_iter * self.batch_size
+ self.indices = self.gen_new_list()
+ self.call = 0
+
+ def __iter__(self):
+ if self.call == 0:
+ self.call = 1
+ return iter(self.indices[(self.last_iter + 1) * self.batch_size:])
+ else:
+ raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+ def gen_new_list(self):
+
+ # shuffle the full index list with a fixed seed so that every run produces the same order
+ np.random.seed(0)
+
+ all_size = self.total_size
+ indices = np.arange(len(self.dataset))
+ indices = indices[:all_size]
+ num_repeat = (all_size-1) // indices.shape[0] + 1
+ indices = np.tile(indices, num_repeat)
+ indices = indices[:all_size]
+
+ np.random.shuffle(indices)
+
+ assert len(indices) == self.total_size
+
+ return indices
+
+ def __len__(self):
+ return self.total_size
+
+
+class DistributedGivenIterationSampler(Sampler):
+ def __init__(self, dataset, total_iter, batch_size, world_size=None, rank=None, last_iter=-1):
+ if world_size is None:
+ world_size = dist.get_world_size()
+ if rank is None:
+ rank = dist.get_rank()
+ assert rank < world_size
+ self.dataset = dataset
+ self.total_iter = total_iter
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.rank = rank
+ self.last_iter = last_iter
+
+ self.total_size = self.total_iter*self.batch_size
+
+ self.indices = self.gen_new_list()
+ self.call = 0
+
+ def __iter__(self):
+ if self.call == 0:
+ self.call = 1
+ return iter(self.indices[(self.last_iter+1)*self.batch_size:])
+ else:
+ raise RuntimeError("this sampler is not designed to be called more than once!!")
+
+ def gen_new_list(self):
+
+ # every process shuffles the full list with the same seed and then picks its own slice according to its rank
+ np.random.seed(0)
+
+ all_size = self.total_size * self.world_size
+ indices = np.arange(len(self.dataset))
+ indices = indices[:all_size]
+ num_repeat = (all_size-1) // indices.shape[0] + 1
+ indices = np.tile(indices, num_repeat)
+ indices = indices[:all_size]
+
+ np.random.shuffle(indices)
+ beg = self.total_size * self.rank
+ indices = indices[beg:beg+self.total_size]
+
+ assert len(indices) == self.total_size
+
+ return indices
+
+ def __len__(self):
+ # note here we do not take last iter into consideration, since __len__
+ # should only be used for displaying, the correct remaining size is
+ # handled by dataloader
+ #return self.total_size - (self.last_iter+1)*self.batch_size
+ return self.total_size
+
+
diff --git a/models/cmp/utils/flowlib.py b/models/cmp/utils/flowlib.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ab1a8bf3cbe05b55c50449319a55d4ae8d1ee
--- /dev/null
+++ b/models/cmp/utils/flowlib.py
@@ -0,0 +1,308 @@
+#!/usr/bin/python
+"""
+# ==============================
+# flowlib.py
+# library for optical flow processing
+# Author: Ruoteng Li
+# Date: 6th Aug 2016
+# ==============================
+"""
+#import png
+import numpy as np
+from PIL import Image
+import io
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+"""
+=============
+Flow Section
+=============
+"""
+
+def write_flow(flow, filename):
+ """
+ write optical flow in Middlebury .flo format
+ :param flow: optical flow map
+ :param filename: optical flow file path to be saved
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = flow.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ flow.tofile(f)
+ f.close()
+
+
+def save_flow_image(flow, image_file):
+ """
+ save flow visualization into image file
+ :param flow: optical flow data
+ :param image_file: image file path to save the visualization to
+ :return: None
+ """
+ flow_img = flow_to_image(flow)
+ img_out = Image.fromarray(flow_img)
+ img_out.save(image_file)
+
+def segment_flow(flow):
+ h = flow.shape[0]
+ w = flow.shape[1]
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ idx = ((abs(u) > LARGEFLOW) | (abs(v) > LARGEFLOW))
+ idx2 = (abs(u) == SMALLFLOW)
+ class0 = (v == 0) & (u == 0)
+ u[idx2] = 0.00001
+ tan_value = v / u
+
+ class1 = (tan_value < 1) & (tan_value >= 0) & (u > 0) & (v >= 0)
+ class2 = (tan_value >= 1) & (u >= 0) & (v >= 0)
+ class3 = (tan_value < -1) & (u <= 0) & (v >= 0)
+ class4 = (tan_value < 0) & (tan_value >= -1) & (u < 0) & (v >= 0)
+ class8 = (tan_value >= -1) & (tan_value < 0) & (u > 0) & (v <= 0)
+ class7 = (tan_value < -1) & (u >= 0) & (v <= 0)
+ class6 = (tan_value >= 1) & (u <= 0) & (v <= 0)
+ class5 = (tan_value >= 0) & (tan_value < 1) & (u < 0) & (v <= 0)
+
+ seg = np.zeros((h, w))
+
+ seg[class1] = 1
+ seg[class2] = 2
+ seg[class3] = 3
+ seg[class4] = 4
+ seg[class5] = 5
+ seg[class6] = 6
+ seg[class7] = 7
+ seg[class8] = 8
+ seg[class0] = 0
+ seg[idx] = 0
+
+ return seg
+
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ maxu = -999.
+ maxv = -999.
+ minu = 999.
+ minv = 999.
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ maxu = max(maxu, np.max(u))
+ minu = min(minu, np.min(u))
+
+ maxv = max(maxv, np.max(v))
+ minv = min(minv, np.min(v))
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = max(5, np.max(rad))
+ #maxrad = max(-1, 99)
+
+ u = u/(maxrad + np.finfo(float).eps)
+ v = v/(maxrad + np.finfo(float).eps)
+
+ img = compute_color(u, v)
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+def disp_to_flowfile(disp, filename):
+ """
+ Save a disparity map as a Middlebury .flo file (the second flow channel is left empty)
+ :param disp: disparity matrix
+ :param filename: the flow file name to save
+ :return: None
+ """
+ f = open(filename, 'wb')
+ magic = np.array([202021.25], dtype=np.float32)
+ (height, width) = disp.shape[0:2]
+ w = np.array([width], dtype=np.int32)
+ h = np.array([height], dtype=np.int32)
+ empty_map = np.zeros((height, width), dtype=np.float32)
+ data = np.dstack((disp, empty_map))
+ magic.tofile(f)
+ w.tofile(f)
+ h.tofile(f)
+ data.tofile(f)
+ f.close()
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u**2+v**2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a+1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols+1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel,1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0-1] / 255
+ col1 = tmp[k1-1] / 255
+ col = (1-f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1-rad[idx]*(1-col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
+
+ return img
+
+
+def make_color_wheel():
+ """
+ Generate color wheel according Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
+ colorwheel[col:col+YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
+ colorwheel[col:col+CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
+ col += BM
+
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col+MR, 0] = 255
+
+ return colorwheel
+
+
+def read_flo_file(filename, memcached=False):
+ """
+ Read from Middlebury .flo file
+ :param filename: path of the flow file, or the raw file bytes when memcached is True
+ :return: optical flow data in matrix
+ """
+ if memcached:
+ # raw bytes fetched from memcached: wrap them in a file-like object
+ f = io.BytesIO(filename)
+ else:
+ f = open(filename, 'rb')
+ # np.frombuffer works for both real files and BytesIO, unlike np.fromfile
+ magic = np.frombuffer(f.read(4), np.float32)[0]
+ data2d = None
+
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ else:
+ w = np.frombuffer(f.read(4), np.int32)[0]
+ h = np.frombuffer(f.read(4), np.int32)[0]
+ data2d = np.frombuffer(f.read(4 * 2 * w * h), np.float32).copy()
+ # reshape data into 3D array (rows, columns, channels)
+ data2d = data2d.reshape((h, w, 2))
+ f.close()
+ return data2d
+
+
+# fast resample layer
+def resample(img, sz):
+ """
+ img: flow map to be resampled
+ sz: new flow map size. Must be [height, width]
+ """
+ original_image_size = img.shape
+ in_height = img.shape[0]
+ in_width = img.shape[1]
+ out_height = sz[0]
+ out_width = sz[1]
+ out_flow = np.zeros((out_height, out_width, 2))
+ # find scale
+ height_scale = float(in_height) / float(out_height)
+ width_scale = float(in_width) / float(out_width)
+
+ [x,y] = np.meshgrid(range(out_width), range(out_height))
+ xx = x * width_scale
+ yy = y * height_scale
+ x0 = np.floor(xx).astype(np.int32)
+ x1 = x0 + 1
+ y0 = np.floor(yy).astype(np.int32)
+ y1 = y0 + 1
+
+ x0 = np.clip(x0,0,in_width-1)
+ x1 = np.clip(x1,0,in_width-1)
+ y0 = np.clip(y0,0,in_height-1)
+ y1 = np.clip(y1,0,in_height-1)
+
+ Ia = img[y0,x0,:]
+ Ib = img[y1,x0,:]
+ Ic = img[y0,x1,:]
+ Id = img[y1,x1,:]
+
+ wa = (y1-yy) * (x1-xx)
+ wb = (yy-y0) * (x1-xx)
+ wc = (y1-yy) * (xx-x0)
+ wd = (yy-y0) * (xx-x0)
+ out_flow[:,:,0] = (Ia[:,:,0]*wa + Ib[:,:,0]*wb + Ic[:,:,0]*wc + Id[:,:,0]*wd) * out_width / in_width
+ out_flow[:,:,1] = (Ia[:,:,1]*wa + Ib[:,:,1]*wb + Ic[:,:,1]*wc + Id[:,:,1]*wd) * out_height / in_height
+
+ return out_flow
diff --git a/models/cmp/utils/scheduler.py b/models/cmp/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f34f6321b0b1f567c656e57ee0e32c6baafe9ce
--- /dev/null
+++ b/models/cmp/utils/scheduler.py
@@ -0,0 +1,102 @@
+import torch
+from bisect import bisect_right
+
+class _LRScheduler(object):
+ def __init__(self, optimizer, last_iter=-1):
+ if not isinstance(optimizer, torch.optim.Optimizer):
+ raise TypeError('{} is not an Optimizer'.format(
+ type(optimizer).__name__))
+ self.optimizer = optimizer
+ if last_iter == -1:
+ for group in optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ else:
+ for i, group in enumerate(optimizer.param_groups):
+ if 'initial_lr' not in group:
+ raise KeyError("param 'initial_lr' is not specified "
+ "in param_groups[{}] when resuming an optimizer".format(i))
+ self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+ self.last_iter = last_iter
+
+ def _get_new_lr(self):
+ raise NotImplementedError
+
+ def get_lr(self):
+ return list(map(lambda group: group['lr'], self.optimizer.param_groups))
+
+ def step(self, this_iter=None):
+ if this_iter is None:
+ this_iter = self.last_iter + 1
+ self.last_iter = this_iter
+ for param_group, lr in zip(self.optimizer.param_groups, self._get_new_lr()):
+ param_group['lr'] = lr
+
+class _WarmUpLRSchedulerOld(_LRScheduler):
+
+ def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ self.base_lr = base_lr
+ self.warmup_steps = warmup_steps
+ if warmup_steps == 0:
+ self.warmup_lr = base_lr
+ else:
+ self.warmup_lr = warmup_lr
+ super(_WarmUpLRSchedulerOld, self).__init__(optimizer, last_iter)
+
+ def _get_warmup_lr(self):
+ if self.warmup_steps > 0 and self.last_iter < self.warmup_steps:
+ # first compute relative scale for self.base_lr, then multiply to base_lr
+ scale = ((self.last_iter/self.warmup_steps)*(self.warmup_lr - self.base_lr) + self.base_lr)/self.base_lr
+ #print('last_iter: {}, warmup_lr: {}, base_lr: {}, scale: {}'.format(self.last_iter, self.warmup_lr, self.base_lr, scale))
+ return [scale * base_lr for base_lr in self.base_lrs]
+ else:
+ return None
+
+class _WarmUpLRScheduler(_LRScheduler):
+
+ def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ self.base_lr = base_lr
+ self.warmup_lr = warmup_lr
+ self.warmup_steps = warmup_steps
+ assert isinstance(warmup_lr, list)
+ assert isinstance(warmup_steps, list)
+ assert len(warmup_lr) == len(warmup_steps)
+ super(_WarmUpLRScheduler, self).__init__(optimizer, last_iter)
+
+ def _get_warmup_lr(self):
+ pos = bisect_right(self.warmup_steps, self.last_iter)
+ if pos >= len(self.warmup_steps):
+ return None
+ else:
+ if pos == 0:
+ curr_lr = self.base_lr + self.last_iter * (self.warmup_lr[pos] - self.base_lr) / self.warmup_steps[pos]
+ else:
+ curr_lr = self.warmup_lr[pos - 1] + (self.last_iter - self.warmup_steps[pos - 1]) * (self.warmup_lr[pos] - self.warmup_lr[pos - 1]) / (self.warmup_steps[pos] - self.warmup_steps[pos - 1])
+ scale = curr_lr / self.base_lr
+ return [scale * base_lr for base_lr in self.base_lrs]
+
+class StepLRScheduler(_WarmUpLRScheduler):
+ def __init__(self, optimizer, milestones, lr_mults, base_lr, warmup_lr, warmup_steps, last_iter=-1):
+ super(StepLRScheduler, self).__init__(optimizer, base_lr, warmup_lr, warmup_steps, last_iter)
+
+ assert len(milestones) == len(lr_mults), "{} vs {}".format(milestones, lr_mults)
+ for x in milestones:
+ assert isinstance(x, int)
+ if not list(milestones) == sorted(milestones):
+ raise ValueError('Milestones should be a list of'
+ ' increasing integers. Got {}'.format(milestones))
+ self.milestones = milestones
+ self.lr_mults = [1.0]
+ for x in lr_mults:
+ self.lr_mults.append(self.lr_mults[-1]*x)
+
+ def _get_new_lr(self):
+ warmup_lrs = self._get_warmup_lr()
+ if warmup_lrs is not None:
+ return warmup_lrs
+
+ pos = bisect_right(self.milestones, self.last_iter)
+ if len(self.warmup_lr) == 0:
+ scale = self.lr_mults[pos]
+ else:
+ scale = self.warmup_lr[-1] * self.lr_mults[pos] / self.base_lr
+ return [base_lr * scale for base_lr in self.base_lrs]
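+# Illustrative configuration (hypothetical values, not the released training recipe):
+#   scheduler = StepLRScheduler(optimizer, milestones=[80000, 120000], lr_mults=[0.1, 0.1],
+#                               base_lr=0.1, warmup_lr=[0.4], warmup_steps=[2500])
+#   for it in range(start_iter, total_iter):
+#       scheduler.step(it)          # sets the lr for this iteration (linear warmup, then steps)
+#       train_one_iteration(...)    # placeholder for the actual training step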
diff --git a/models/cmp/utils/visualize_utils.py b/models/cmp/utils/visualize_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfb4796a980156e9a9e23f0cf86604ba24dfbc4e
--- /dev/null
+++ b/models/cmp/utils/visualize_utils.py
@@ -0,0 +1,109 @@
+import numpy as np
+
+import torch
+from . import flowlib
+
+class Fuser(object):
+ def __init__(self, nbins, fmax):
+ self.nbins = nbins
+ self.fmax = fmax
+ self.step = 2 * fmax / float(nbins)
+ self.mesh = torch.arange(nbins).view(1,-1,1,1).float().cuda() * self.step - fmax + self.step / 2
+
+ def convert_flow(self, flow_prob):
+ flow_probx = torch.nn.functional.softmax(flow_prob[:, :self.nbins, :, :], dim=1)
+ flow_proby = torch.nn.functional.softmax(flow_prob[:, self.nbins:, :, :], dim=1)
+ flow_probx = flow_probx * self.mesh
+ flow_proby = flow_proby * self.mesh
+ flow = torch.cat([flow_probx.sum(dim=1, keepdim=True), flow_proby.sum(dim=1, keepdim=True)], dim=1)
+ return flow
+
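+# Descriptive note (added): Fuser.convert_flow turns a 2*nbins-channel probability volume into a
+# 2-channel flow map by taking a soft-argmax over the bin centers stored in self.mesh (the first
+# nbins channels give the x component, the remaining nbins channels the y component).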
+def visualize_tensor_old(image, mask, flow_pred, flow_target, warped, rgb_gen, image_target, image_mean, image_div):
+ together = [
+ draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.)),
+ flow_to_image(flow_pred.detach().cpu()),
+ flow_to_image(flow_target.detach().cpu())]
+ if warped is not None:
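+# Usage sketch (illustrative): for logits of shape [B, C] and integer labels of shape [B],
+#   top1, top5 = accuracy(output, target, topk=(1, 5))
+# returns the top-1 and top-5 precision as percentages (tensors of shape [1]).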
+ together.append(torch.clamp(unormalize(warped.detach().cpu(), mean=image_mean, div=image_div), 0, 255))
+ if rgb_gen is not None:
+ together.append(torch.clamp(unormalize(rgb_gen.detach().cpu(), mean=image_mean, div=image_div), 0, 255))
+ if image_target is not None:
+ together.append(torch.clamp(unormalize(image_target.cpu(), mean=image_mean, div=image_div), 0, 255))
+ together = torch.cat(together, dim=3)
+ return together
+
+def visualize_tensor(image, mask, flow_tensors, common_tensors, rgb_tensors, image_mean, image_div):
+ together = [
+ draw_cross(unormalize(image.cpu(), mean=image_mean, div=image_div), mask.cpu(), radius=int(image.size(3) / 50.))]
+ for ft in flow_tensors:
+ together.append(flow_to_image(ft.cpu()))
+ for ct in common_tensors:
+ together.append(torch.clamp(ct.cpu(), 0, 255))
+ for rt in rgb_tensors:
+ together.append(torch.clamp(unormalize(rt.cpu(), mean=image_mean, div=image_div), 0, 255))
+ together = torch.cat(together, dim=3)
+ return together
+
+
+def unormalize(tensor, mean, div):
+ for c, (m, d) in enumerate(zip(mean, div)):
+ tensor[:,c,:,:].mul_(d).add_(m)
+ return tensor
+
+
+def flow_to_image(flow):
+ flow = flow.numpy()
+ flow_img = np.array([flowlib.flow_to_image(fl.transpose((1,2,0))).transpose((2,0,1)) for fl in flow]).astype(np.float32)
+ return torch.from_numpy(flow_img)
+
+def shift_tensor(input, offh, offw):
+ new = torch.zeros(input.size())
+ h = input.size(2)
+ w = input.size(3)
+ new[:,:,max(0,offh):min(h,h+offh),max(0,offw):min(w,w+offw)] = input[:,:,max(0,-offh):min(h,h-offh),max(0,-offw):min(w,w-offw)]
+ return new
+
+def draw_block(mask, radius=5):
+ '''
+ input: tensor (NxCxHxW)
+ output: block_mask (Nx1xHxW)
+ '''
+ all_mask = []
+ mask = mask[:,0:1,:,:]
+ for offh in range(-radius, radius+1):
+ for offw in range(-radius, radius+1):
+ all_mask.append(shift_tensor(mask, offh, offw))
+ block_mask = sum(all_mask)
+ block_mask[block_mask > 0] = 1
+ return block_mask
+
+def expand_block(sparse, radius=5):
+ '''
+ input: sparse (NxCxHxW)
+ output: block_sparse (NxCxHxW)
+ '''
+ all_sparse = []
+ for offh in range(-radius, radius+1):
+ for offw in range(-radius, radius+1):
+ all_sparse.append(shift_tensor(sparse, offh, offw))
+ block_sparse = sum(all_sparse)
+ return block_sparse
+
+def draw_cross(tensor, mask, radius=5, thickness=2):
+ '''
+ input: tensor (NxCxHxW)
+ mask (NxXxHxW)
+ output: new_tensor (NxCxHxW)
+ '''
+ all_mask = []
+ mask = mask[:,0:1,:,:]
+ for off in range(-radius, radius+1):
+ for t in range(-thickness, thickness+1):
+ all_mask.append(shift_tensor(mask, off, t))
+ all_mask.append(shift_tensor(mask, t, off))
+ cross_mask = sum(all_mask)
+ new_tensor = tensor.clone()
+ new_tensor[:,0:1,:,:][cross_mask > 0] = 255.0
+ new_tensor[:,1:2,:,:][cross_mask > 0] = 0.0
+ new_tensor[:,2:3,:,:][cross_mask > 0] = 0.0
+ return new_tensor
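+
+# Summary note (illustrative): `draw_cross` overlays a red cross (R=255, G=B=0)
+# of the given radius and thickness at every non-zero location of the first
+# mask channel, which is how the sparse guidance points are visualised on top
+# of the unnormalised input image in `visualize_tensor`.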
diff --git a/models/controlnet_sdv.py b/models/controlnet_sdv.py
new file mode 100644
index 0000000000000000000000000000000000000000..d45f1597955446b5e8e6e92ac0346a94a56828f4
--- /dev/null
+++ b/models/controlnet_sdv.py
@@ -0,0 +1,782 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalControlnetMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_3d_blocks import (
+ get_down_block, get_up_block,UNetMidBlockSpatioTemporal,
+)
+from diffusers.models import UNetSpatioTemporalConditionModel
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetOutput(BaseOutput):
+ """
+ The output of [`ControlNetModel`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ControlNetConditioningEmbeddingSVD(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ # Fold the frames dimension into the batch dimension so that the 2D
+ # convolutions below are applied to every frame independently.
+ batch_size, frames, channels, height, width = conditioning.size()
+ conditioning = conditioning.view(batch_size * frames, channels, height, width)
+
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ # The batch and frames dimensions are intentionally left flattened here;
+ # the ControlNet trunk consumes the embedding in (batch * frames) layout.
+
+
+ return embedding
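+ # Shape sketch (illustrative): with the default block_out_channels of
+ # (16, 32, 96, 256) there are three stride-2 convolutions, so a
+ # (batch, frames, conditioning_channels, H, W) input is flattened and reduced
+ # to (batch*frames, conditioning_embedding_channels, H/8, W/8), which lets the
+ # result be added to `conv_in(sample)` at the latent resolution in the trunk.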
+
+
+class ControlNetSDVModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+ r"""
+ A conditional spatio-temporal UNet model that takes noisy video frames, a conditioning state, and a timestep and returns a
+ sample-shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ addition_time_embed_dim: (`int`, defaults to 256):
+ Dimension used to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ conditioning_channels: int = 3,
+ conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+ self.sample_size = sample_size
+
+ print("layers per block is", layers_per_block)
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+ self.controlnet_cond_embedding = ControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # down
+ output_channel = block_out_channels[0]
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block[i]):
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+
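+ # Count check (illustrative): with the default configuration this creates one
+ # 1x1 zero-conv for the conv_in feature, plus layers_per_block convs per down
+ # block and one more for each downsampler, i.e. 1 + 4*2 + 3 = 12 projections,
+ # matching the 12 residuals gathered in `down_block_res_samples` in forward().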
+ # mid
+ mid_block_channel = block_out_channels[-1]
+ controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+
+
+
+ # out
+ #self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ #self.conv_act = nn.SiLU()
+
+ #self.conv_out = nn.Conv2d(
+ # block_out_channels[0],
+ # out_channels,
+ # kernel_size=3,
+ # padding=1,
+ #)
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ controlnet_cond: torch.FloatTensor = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ guess_mode: bool = False,
+ conditioning_scale: float = 1.0,
+
+
+ ) -> Union[ControlNetOutput, Tuple]:
+ r"""
+ The [`ControlNetSDVModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`ControlNetOutput`] instead of a plain
+ tuple.
+ Returns:
+ [`ControlNetOutput`] or `tuple`:
+ If `return_dict` is True, a [`ControlNetOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the tuple of down-block residual samples.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ # print(t_emb.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ #controlnet cond
+ if controlnet_cond is not None:
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ sample = sample + controlnet_cond
+
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return ControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNetSpatioTemporalConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate a [`ControlNetSDVModel`] from a [`UNetSpatioTemporalConditionModel`].
+
+ Parameters:
+ unet (`UNetSpatioTemporalConditionModel`):
+ The UNet model weights to copy to the [`ControlNetSDVModel`]. All configuration options are also copied
+ where applicable.
+ """
+
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+ print(unet.config)
+ controlnet = cls(
+ in_channels=unet.config.in_channels,
+ down_block_types=unet.config.down_block_types,
+ block_out_channels=unet.config.block_out_channels,
+ addition_time_embed_dim=unet.config.addition_time_embed_dim,
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ num_frames=unet.config.num_frames,
+ sample_size=unet.config.sample_size, # Added based on the dict
+ layers_per_block=unet.config.layers_per_block,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ conditioning_channels = conditioning_channels,
+ conditioning_embedding_out_channels = conditioning_embedding_out_channels,
+ )
+ # The ControlNet RGB channel-order argument is ignored here; by default it makes no difference.
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ # if controlnet.class_embedding:
+ # controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
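+ # Usage sketch (illustrative, not part of the original file; `svd_ckpt_dir` is
+ # a hypothetical path to an SVD-xt checkpoint in diffusers layout):
+ #   unet = UNetSpatioTemporalConditionModel.from_pretrained(svd_ckpt_dir, subfolder="unet")
+ #   controlnet = ControlNetSDVModel.from_unet(unet)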
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ # def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ # if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ # module.gradient_checkpointing = value
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
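+
+# Design note (illustrative): zero-initialising these 1x1 projection layers
+# makes the ControlNet branch contribute exactly zero at the start of training,
+# so the frozen SVD UNet initially behaves as if no control signal were present
+# and the conditioning is blended in gradually as the adapter is trained.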
diff --git a/models/softsplat.py b/models/softsplat.py
new file mode 100644
index 0000000000000000000000000000000000000000..f35ccc21604479940c2c86580c287e73f3dc327d
--- /dev/null
+++ b/models/softsplat.py
@@ -0,0 +1,529 @@
+#!/usr/bin/env python
+
+import collections
+import cupy
+import os
+import re
+import torch
+import typing
+
+
+##########################################################
+
+
+objCudacache = {}
+
+
+def cuda_int32(intIn:int):
+ return cupy.int32(intIn)
+# end
+
+
+def cuda_float32(fltIn:float):
+ return cupy.float32(fltIn)
+# end
+
+
+def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict):
+ if 'device' not in objCudacache:
+ objCudacache['device'] = torch.cuda.get_device_name()
+ # end
+
+ strKey = strFunction
+
+ for strVariable in objVariables:
+ objValue = objVariables[strVariable]
+
+ strKey += strVariable
+
+ if objValue is None:
+ continue
+
+ elif type(objValue) == int:
+ strKey += str(objValue)
+
+ elif type(objValue) == float:
+ strKey += str(objValue)
+
+ elif type(objValue) == bool:
+ strKey += str(objValue)
+
+ elif type(objValue) == str:
+ strKey += objValue
+
+ elif type(objValue) == torch.Tensor:
+ strKey += str(objValue.dtype)
+ strKey += str(objValue.shape)
+ strKey += str(objValue.stride())
+
+ elif True:
+ print(strVariable, type(objValue))
+ assert(False)
+
+ # end
+ # end
+
+ strKey += objCudacache['device']
+
+ if strKey not in objCudacache:
+ for strVariable in objVariables:
+ objValue = objVariables[strVariable]
+
+ if objValue is None:
+ continue
+
+ elif type(objValue) == int:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == float:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == bool:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue))
+
+ elif type(objValue) == str:
+ strKernel = strKernel.replace('{{' + strVariable + '}}', objValue)
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8:
+ strKernel = strKernel.replace('{{type}}', 'unsigned char')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16:
+ strKernel = strKernel.replace('{{type}}', 'half')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32:
+ strKernel = strKernel.replace('{{type}}', 'float')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64:
+ strKernel = strKernel.replace('{{type}}', 'double')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32:
+ strKernel = strKernel.replace('{{type}}', 'int')
+
+ elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64:
+ strKernel = strKernel.replace('{{type}}', 'long')
+
+ elif type(objValue) == torch.Tensor:
+ print(strVariable, objValue.dtype)
+ assert(False)
+
+ elif True:
+ print(strVariable, type(objValue))
+ assert(False)
+
+ # end
+ # end
+
+ while True:
+ objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intArg = int(objMatch.group(2))
+
+ strTensor = objMatch.group(4)
+ intSizes = objVariables[strTensor].size()
+
+ strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))
+ # end
+
+ while True:
+ objMatch = re.search(r'(OFFSET_)([0-4])(\()', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intStart = objMatch.span()[1]
+ intStop = objMatch.span()[1]
+ intParentheses = 1
+
+ while True:
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
+
+ if intParentheses == 0:
+ break
+ # end
+
+ intStop += 1
+ # end
+
+ intArgs = int(objMatch.group(2))
+ strArgs = strKernel[intStart:intStop].split(',')
+
+ assert(intArgs == len(strArgs) - 1)
+
+ strTensor = strArgs[0]
+ intStrides = objVariables[strTensor].stride()
+
+ strIndex = []
+
+ for intArg in range(intArgs):
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
+ # end
+
+ strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')')
+ # end
+
+ while True:
+ objMatch = re.search(r'(VALUE_)([0-4])(\()', strKernel)
+
+ if objMatch is None:
+ break
+ # end
+
+ intStart = objMatch.span()[1]
+ intStop = objMatch.span()[1]
+ intParentheses = 1
+
+ while True:
+ intParentheses += 1 if strKernel[intStop] == '(' else 0
+ intParentheses -= 1 if strKernel[intStop] == ')' else 0
+
+ if intParentheses == 0:
+ break
+ # end
+
+ intStop += 1
+ # end
+
+ intArgs = int(objMatch.group(2))
+ strArgs = strKernel[intStart:intStop].split(',')
+
+ assert(intArgs == len(strArgs) - 1)
+
+ strTensor = strArgs[0]
+ intStrides = objVariables[strTensor].stride()
+
+ strIndex = []
+
+ for intArg in range(intArgs):
+ strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')')
+ # end
+
+ strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']')
+ # end
+
+ objCudacache[strKey] = {
+ 'strFunction': strFunction,
+ 'strKernel': strKernel
+ }
+ # end
+
+ return strKey
+# end
+
+
+@cupy.memoize(for_each_device=True)
+def cuda_launch(strKey:str):
+ if 'CUDA_HOME' not in os.environ:
+ os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
+ # end
+
+ return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction'])
+# end
+
+
+##########################################################
+
+
+def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str):
+ assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft'])
+
+ if strMode == 'sum': assert(tenMetric is None)
+ if strMode == 'avg': assert(tenMetric is None)
+ if strMode.split('-')[0] == 'linear': assert(tenMetric is not None)
+ if strMode.split('-')[0] == 'soft': assert(tenMetric is not None)
+
+ if strMode == 'avg':
+ tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1)
+
+ elif strMode.split('-')[0] == 'linear':
+ tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
+
+ elif strMode.split('-')[0] == 'soft':
+ tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1)
+
+ # end
+
+ tenOut = softsplat_func.apply(tenIn, tenFlow)
+
+ if strMode.split('-')[0] in ['avg', 'linear', 'soft']:
+ tenNormalize = tenOut[:, -1:, :, :]
+
+ if len(strMode.split('-')) == 1:
+ tenNormalize = tenNormalize + 0.0000001
+
+ elif strMode.split('-')[1] == 'addeps':
+ tenNormalize = tenNormalize + 0.0000001
+
+ elif strMode.split('-')[1] == 'zeroeps':
+ tenNormalize[tenNormalize == 0.0] = 1.0
+
+ elif strMode.split('-')[1] == 'clipeps':
+ tenNormalize = tenNormalize.clip(0.0000001, None)
+
+ # end
+
+ tenOut = tenOut[:, :-1, :, :] / tenNormalize
+ # end
+
+ return tenOut
+# end
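+
+# Usage sketch (illustrative, not part of the original file): forward-warp a
+# feature map of shape [B, C, H, W] along a forward flow of shape [B, 2, H, W]
+# with occlusion-aware averaging; 'avg' appends a ones channel, splats it
+# together with the features, and divides by it afterwards:
+#   warped = softsplat(tenIn=feat, tenFlow=flow, tenMetric=None, strMode='avg')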
+
+
+class softsplat_func(torch.autograd.Function):
+ @staticmethod
+ @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+ def forward(self, tenIn, tenFlow):
+ tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])
+
+ if tenIn.is_cuda == True:
+ cuda_launch(cuda_kernel('softsplat_out', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_out(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ {{type}}* __restrict__ tenOut
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
+ const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut);
+ const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut);
+ const int intX = ( intIndex ) % SIZE_3(tenOut);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest);
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast);
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest);
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) {
+ atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast);
+ }
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOut': tenOut
+ }))(
+ grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+
+ elif tenIn.is_cuda != True:
+ assert(False)
+
+ # end
+
+ self.save_for_backward(tenIn, tenFlow)
+
+ return tenOut
+ # end
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(self, tenOutgrad):
+ tenIn, tenFlow = self.saved_tensors
+
+ tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True)
+
+ tenIngrad = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None
+ tenFlowgrad = tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None
+
+ if tenIngrad is not None:
+ cuda_launch(cuda_kernel('softsplat_ingrad', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ const {{type}}* __restrict__ tenOutgrad,
+ {{type}}* __restrict__ tenIngrad,
+ {{type}}* __restrict__ tenFlowgrad
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad);
+ const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad);
+ const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad);
+ const int intX = ( intIndex ) % SIZE_3(tenIngrad);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltIngrad = 0.0f;
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY);
+ {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY);
+ {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY));
+ {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY));
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
+ fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
+ }
+
+ tenIngrad[intIndex] = fltIngrad;
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOutgrad': tenOutgrad,
+ 'tenIngrad': tenIngrad,
+ 'tenFlowgrad': tenFlowgrad
+ }))(
+ grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+ # end
+
+ if tenFlowgrad is not None:
+ cuda_launch(cuda_kernel('softsplat_flowgrad', '''
+ extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad(
+ const int n,
+ const {{type}}* __restrict__ tenIn,
+ const {{type}}* __restrict__ tenFlow,
+ const {{type}}* __restrict__ tenOutgrad,
+ {{type}}* __restrict__ tenIngrad,
+ {{type}}* __restrict__ tenFlowgrad
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
+ const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad);
+ const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad);
+ const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad);
+ const int intX = ( intIndex ) % SIZE_3(tenFlowgrad);
+
+ assert(SIZE_1(tenFlow) == 2);
+
+ {{type}} fltFlowgrad = 0.0f;
+
+ {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
+ {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);
+
+ if (isfinite(fltX) == false) { return; }
+ if (isfinite(fltY) == false) { return; }
+
+ int intNorthwestX = (int) (floor(fltX));
+ int intNorthwestY = (int) (floor(fltY));
+ int intNortheastX = intNorthwestX + 1;
+ int intNortheastY = intNorthwestY;
+ int intSouthwestX = intNorthwestX;
+ int intSouthwestY = intNorthwestY + 1;
+ int intSoutheastX = intNorthwestX + 1;
+ int intSoutheastY = intNorthwestY + 1;
+
+ {{type}} fltNorthwest = 0.0f;
+ {{type}} fltNortheast = 0.0f;
+ {{type}} fltSouthwest = 0.0f;
+ {{type}} fltSoutheast = 0.0f;
+
+ if (intC == 0) {
+ fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY);
+ fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY);
+ fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY));
+ fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY));
+
+ } else if (intC == 1) {
+ fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f));
+ fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f));
+ fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f));
+ fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f));
+
+ }
+
+ for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) {
+ {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX);
+
+ if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest;
+ }
+
+ if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast;
+ }
+
+ if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest;
+ }
+
+ if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) {
+ fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast;
+ }
+ }
+
+ tenFlowgrad[intIndex] = fltFlowgrad;
+ } }
+ ''', {
+ 'tenIn': tenIn,
+ 'tenFlow': tenFlow,
+ 'tenOutgrad': tenOutgrad,
+ 'tenIngrad': tenIngrad,
+ 'tenFlowgrad': tenFlowgrad
+ }))(
+ grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]),
+ block=tuple([512, 1, 1]),
+ args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()],
+ stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
+ )
+ # end
+
+ return tenIngrad, tenFlowgrad
+ # end
+# end
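+
+# Runtime note (illustrative): the splatting kernels are compiled on the fly
+# with CuPy, so a CUDA device is required; the CPU path in
+# `softsplat_func.forward` deliberately asserts rather than falling back.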
diff --git a/models/svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine.py b/models/svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ff51bf58e92c8fcc033102fecbd96f66379f4a
--- /dev/null
+++ b/models/svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine.py
@@ -0,0 +1,384 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.configuration_utils import register_to_config
+from diffusers.utils import BaseOutput
+
+from models.controlnet_sdv import ControlNetSDVModel, zero_module
+from models.softsplat import softsplat
+import models.cmp.models as cmp_models
+import models.cmp.utils as cmp_utils
+
+import yaml
+import os
+import torchvision.transforms as transforms
+
+
+class ArgObj(object):
+ def __init__(self):
+ pass
+
+
+class CMP_demo(nn.Module):
+ def __init__(self, configfn, load_iter):
+ super().__init__()
+ args = ArgObj()
+ with open(configfn) as f:
+ config = yaml.full_load(f)
+ for k, v in config.items():
+ setattr(args, k, v)
+ setattr(args, 'load_iter', load_iter)
+ setattr(args, 'exp_path', os.path.dirname(configfn))
+
+ self.model = cmp_models.__dict__[args.model['arch']](args.model, dist_model=False)
+ self.model.load_state("{}/checkpoints".format(args.exp_path), args.load_iter, False)
+ self.model.switch_to('eval')
+
+ self.data_mean = args.data['data_mean']
+ self.data_div = args.data['data_div']
+
+ self.img_transform = transforms.Compose([
+ transforms.Normalize(self.data_mean, self.data_div)])
+
+ self.args = args
+ self.fuser = cmp_utils.Fuser(args.model['module']['nbins'], args.model['module']['fmax'])
+ torch.cuda.synchronize()
+
+ def run(self, image, sparse, mask):
+ dtype = image.dtype
+ image = image * 2 - 1
+ self.model.set_input(image.float(), torch.cat([sparse, mask], dim=1).float(), None)
+ cmp_output = self.model.model(self.model.image_input, self.model.sparse_input)
+ flow = self.fuser.convert_flow(cmp_output)
+ if flow.shape[2] != self.model.image_input.shape[2]:
+ flow = nn.functional.interpolate(
+ flow, size=self.model.image_input.shape[2:4],
+ mode="bilinear", align_corners=True)
+
+ return flow.to(dtype) # [b, 2, h, w]
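+ # Summary (illustrative): given an image, sparse flow hints and a hint mask,
+ # `run` completes them into a dense flow field with the pretrained CMP model
+ # and bilinearly resizes it to the input resolution, returning a [b, 2, h, w]
+ # tensor in the caller's dtype.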
+
+
+
+class FlowControlNetConditioningEmbeddingSVD(nn.Module):
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+
+
+class FlowControlNetFirstFrameEncoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ c_in,
+ c_out,
+ is_downsample=False
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2 if is_downsample else 1)
+
+ def forward(self, feature):
+ '''
+ feature: [b, c, h, w]
+ '''
+
+ embedding = self.conv_in(feature)
+ embedding = F.silu(embedding)
+
+ return embedding
+
+
+
+class FlowControlNetFirstFrameEncoder(nn.Module):
+ def __init__(
+ self,
+ c_in=320,
+ channels=[320, 640, 1280],
+ downsamples=[True, True, True],
+ use_zeroconv=True
+ ):
+ super().__init__()
+
+ self.encoders = nn.ModuleList([])
+ self.zeroconvs = nn.ModuleList([])
+
+ for channel, downsample in zip(channels, downsamples):
+ self.encoders.append(FlowControlNetFirstFrameEncoderLayer(c_in, channel, is_downsample=downsample))
+ self.zeroconvs.append(zero_module(nn.Conv2d(channel, channel, kernel_size=1)) if use_zeroconv else nn.Identity())
+ c_in = channel
+
+ def forward(self, first_frame):
+ feature = first_frame
+ deep_features = []
+ for encoder, zeroconv in zip(self.encoders, self.zeroconvs):
+ feature = encoder(feature)
+ # print(feature.shape)
+ deep_features.append(zeroconv(feature))
+ return deep_features
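+ # Shape sketch (illustrative): with the defaults a [b, 320, h, w] first-frame
+ # feature map is downsampled through channel widths (320, 640, 1280), and each
+ # scale passes through a zero-initialised 1x1 conv, so the encoder injects
+ # nothing into the ControlNet trunk at initialisation.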
+
+
+@dataclass
+class FlowControlNetOutput(BaseOutput):
+ """
+ The output of [`FlowControlNet`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+ controlnet_flow: torch.Tensor
+ cmp_output: torch.Tensor
+
+
+class FlowControlNet(ControlNetSDVModel):
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ conditioning_channels: int = 3,
+ conditioning_embedding_out_channels : Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.flow_encoder = FlowControlNetFirstFrameEncoder()
+
+ self.controlnet_cond_embedding = FlowControlNetConditioningEmbeddingSVD(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ def get_warped_frames(self, first_frame, flows):
+ '''
+        first_frame: [b, c, w, h]
+        flows: [b, t-1, 2, w, h]
+ '''
+ dtype = first_frame.dtype
+ warped_frames = []
+ for i in range(flows.shape[1]):
+ warped_frame = softsplat(tenIn=first_frame.float(), tenFlow=flows[:, i].float(), tenMetric=None, strMode='avg').to(dtype) # [b, c, w, h]
+ warped_frames.append(warped_frame.unsqueeze(1)) # [b, 1, c, w, h]
+ warped_frames = torch.cat(warped_frames, dim=1) # [b, t-1, c, w, h]
+ return warped_frames
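+    # Shape sketch (illustrative): each of the t-1 flow fields forward-warps the same feature map via
+    # softsplat, and the results are stacked along a new time dimension.
+    #
+    #   feat   = torch.rand(1, 320, 64, 64).cuda()      # [b, c, w, h]
+    #   flows  = torch.rand(1, 24, 2, 64, 64).cuda()    # [b, t-1, 2, w, h]
+    #   warped = self.get_warped_frames(feat, flows)    # [1, 24, 320, 64, 64]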
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ added_time_ids: torch.Tensor,
+ controlnet_cond: torch.FloatTensor = None, # [b, 3, h, w]
+ controlnet_flow: torch.FloatTensor = None, # [b, 13, 2, h, w]
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ guess_mode: bool = False,
+ conditioning_scale: float = 1.0,
+ ) -> Union[FlowControlNetOutput, Tuple]:
+
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample) # [b*l, 320, h//8, w//8]
+
+ # controlnet cond
+        if controlnet_cond is not None:
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) # [b, 320, h//8, w//8]
+
+            controlnet_cond_features = [controlnet_cond] + self.flow_encoder(controlnet_cond) # list of 4 feature maps, one per scale
+
+ scales = [8, 16, 32, 64]
+ scale_flows = {}
+ fb, fl, fc, fh, fw = controlnet_flow.shape
+ # print(controlnet_flow.shape)
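+            # Build a pyramid of flow maps matching the feature resolutions. Flow vectors are measured
+            # in pixels, so after downsampling the map by `scale` the vector magnitudes are divided by
+            # the same factor to stay consistent with the new resolution.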
+ for scale in scales:
+ scaled_flow = F.interpolate(controlnet_flow.reshape(-1, fc, fh, fw), scale_factor=1/scale)
+ scaled_flow = scaled_flow.reshape(fb, fl, fc, fh // scale, fw // scale) / scale
+ scale_flows[scale] = scaled_flow
+
+ warped_cond_features = []
+ for cond_feature in controlnet_cond_features:
+ cb, cc, ch, cw = cond_feature.shape
+ # print(cond_feature.shape)
+ warped_cond_feature = self.get_warped_frames(cond_feature, scale_flows[fh // ch])
+                warped_cond_feature = torch.cat([cond_feature.unsqueeze(1), warped_cond_feature], dim=1) # [b, t, c, h, w]
+ wb, wl, wc, wh, ww = warped_cond_feature.shape
+ # print(warped_cond_feature.shape)
+ warped_cond_features.append(warped_cond_feature.reshape(wb * wl, wc, wh, ww))
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+
+ count = 0
+ length = len(warped_cond_features)
+
+ # add the warped feature in the first scale
+ sample = sample + warped_cond_features[count]
+ count += 1
+
+ down_block_res_samples = (sample,)
+
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ sample = sample + warped_cond_features[min(count, length - 1)]
+ count += 1
+
+ down_block_res_samples += res_samples
+
+ # add the warped feature in the last scale
+ sample = sample + warped_cond_features[-1]
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ controlnet_down_block_res_samples = ()
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample = self.controlnet_mid_block(sample)
+
+ # 6. scaling
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample, controlnet_flow, None)
+
+ return FlowControlNetOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample, controlnet_flow=controlnet_flow, cmp_output=None
+ )
+
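+# The residuals returned by `FlowControlNet.forward` are consumed by
+# `UNetSpatioTemporalConditionControlNetModel.forward` (models/unet_spatio_temporal_condition_controlnet.py):
+# each entry of `down_block_res_samples` is added to the corresponding UNet skip connection, and
+# `mid_block_res_sample` is added to the UNet's mid-block activation before upsampling.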
diff --git a/models/unet_spatio_temporal_condition_controlnet.py b/models/unet_spatio_temporal_condition_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1361eeb83ab634ed298c05d3dbddda2b56376c8b
--- /dev/null
+++ b/models/unet_spatio_temporal_condition_controlnet.py
@@ -0,0 +1,504 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNetSpatioTemporalConditionOutput(BaseOutput):
+ """
+ The output of [`UNetSpatioTemporalConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNetSpatioTemporalConditionControlNetModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+    A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state, and a timestep, and
+    returns a sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+        addition_time_embed_dim (`int`, defaults to 256):
+            Dimension used to encode the additional time ids.
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
+ The dimension of the projection of encoded `added_time_ids`.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
+ The number of attention heads.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 8,
+ out_channels: int = 4,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+ "DownBlockSpatioTemporal",
+ ),
+ up_block_types: Tuple[str] = (
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+ ),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ addition_time_embed_dim: int = 256,
+ projection_class_embeddings_input_dim: int = 768,
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
+ num_frames: int = 25,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ padding=1,
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-5,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlockSpatioTemporal(
+ block_out_channels[-1],
+ temb_channels=blocks_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ num_attention_heads=num_attention_heads[-1],
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=1e-5,
+ resolution_idx=i,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ resnet_act_fn="silu",
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
+ self.conv_act = nn.SiLU()
+
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0],
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(
+ name: str,
+ module: torch.nn.Module,
+ processors: Dict[str, AttentionProcessor],
+ ):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ added_time_ids: torch.Tensor=None,
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
+ r"""
+ The [`UNetSpatioTemporalConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
+ added_time_ids: (`torch.FloatTensor`):
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
+ embeddings and added to the time embeddings.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
+                tuple.
+        Returns:
+            [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ batch_size, num_frames = sample.shape[:2]
+ timesteps = timesteps.expand(batch_size)
+
+ t_emb = self.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
+ time_embeds = time_embeds.reshape((batch_size, -1))
+ time_embeds = time_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(time_embeds)
+ emb = emb + aug_emb
+
+ # Flatten the batch and frames dimensions
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
+ sample = sample.flatten(0, 1)
+ # Repeat the embeddings num_video_frames times
+ # emb: [batch, channels] -> [batch * frames, channels]
+ emb = emb.repeat_interleave(num_frames, dim=0)
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ down_block_res_samples += res_samples
+
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+
+ # 4. mid
+ sample = self.mid_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ sample = sample + mid_block_additional_residual
+
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ # 7. Reshape back to original shape
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNetSpatioTemporalConditionOutput(sample=sample)
diff --git a/pipeline/pipeline.py b/pipeline/pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5b67c869732c8bf7a9dfd3384835f8e5308a13d
--- /dev/null
+++ b/pipeline/pipeline.py
@@ -0,0 +1,640 @@
+import inspect
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from models.svdxt_featureflow_forward_controlnet_s2d_fixcmp_norefine import FlowControlNet
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models import AutoencoderKLTemporalDecoder
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from utils.scheduling_euler_discrete_karras_fix import EulerDiscreteScheduler
+
+from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_add_time_ids(
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ fps=4,
+ motion_bucket_id=128,
+ unet=None,
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ passed_add_embed_dim = unet.config.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ # add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+
+
+ return add_time_ids
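+# Illustrative example (values follow the hard-coded call further below; the SVD-XT UNet uses
+# addition_time_embed_dim=256 with a 768-dim additional embedding, so 3 ids pass the dimension check):
+#
+#   add_time_ids = _get_add_time_ids(0.02, torch.float16, batch_size=1, fps=6, motion_bucket_id=128, unet=unet)
+#   # -> shape [1, 3]: [[6.0, 128.0, 0.02]]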
+
+
+def _append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
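+# Example: _append_dims(torch.ones(2, 25), 5).shape == torch.Size([2, 25, 1, 1, 1]). This is how the
+# per-frame guidance scale is broadcast against the 5-D video latents in `__call__` below.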
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+ # Based on:
+ # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+ batch_size, channels, num_frames, height, width = video.shape
+ outputs = []
+ for batch_idx in range(batch_size):
+ batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+ batch_output = processor.postprocess(batch_vid, output_type)
+
+ outputs.append(batch_output)
+
+ return outputs
+
+
+@dataclass
+class FlowControlNetPipelineOutput(BaseOutput):
+ r"""
+    Output class for the flow-conditioned image-to-video pipeline.
+
+    Args:
+        frames (`List[PIL.Image.Image]` or `np.ndarray`):
+            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, num_frames,
+            height, width, num_channels)`.
+        controlnet_flow (`torch.Tensor`):
+            The dense optical flow used to condition the generation, of shape `(batch_size, num_frames - 1, 2, height,
+            width)`.
+    """
+
+ frames: Union[List[PIL.Image.Image], np.ndarray]
+ controlnet_flow: torch.Tensor
+
+
+class FlowControlNetPipeline(DiffusionPipeline):
+ model_cpu_offload_seq = "image_encoder->unet->vae"
+ _callback_tensor_inputs = ["latents"]
+ def __init__(
+ self,
+ vae: AutoencoderKLTemporalDecoder,
+ image_encoder: CLIPVisionModelWithProjection,
+ unet: UNetSpatioTemporalConditionControlNetModel,
+ controlnet: FlowControlNet,
+ scheduler: EulerDiscreteScheduler,
+ feature_extractor: CLIPImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ image_encoder=image_encoder,
+ controlnet=controlnet,
+ unet=unet,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.image_processor.pil_to_numpy(image)
+ image = self.image_processor.numpy_to_pt(image)
+
+ #image = image.unsqueeze(0)
+ image = _resize_with_antialiasing(image, (224, 224))
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeddings = self.image_encoder(image).image_embeds
+ image_embeddings = image_embeddings.unsqueeze(1)
+
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
+ bs_embed, seq_len, _ = image_embeddings.shape
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
+
+ return image_embeddings
+
+ def _encode_vae_image(
+ self,
+ image: torch.Tensor,
+ device,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ image = image.to(device=device)
+ image_latents = self.vae.encode(image).latent_dist.mode()
+
+ if do_classifier_free_guidance:
+ negative_image_latents = torch.zeros_like(image_latents)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ image_latents = torch.cat([negative_image_latents, image_latents])
+
+ # duplicate image_latents for each generation per prompt, using mps friendly method
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
+
+ return image_latents
+
+ def _get_add_time_ids(
+ self,
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ ):
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
+
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
+
+ if do_classifier_free_guidance:
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
+
+ return add_time_ids
+
+ def decode_latents(self, latents, num_frames, decode_chunk_size=14):
+ # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
+ latents = latents.flatten(0, 1)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ accepts_num_frames = "num_frames" in set(inspect.signature(self.vae.forward).parameters.keys())
+
+ # decode decode_chunk_size frames at a time to avoid OOM
+ frames = []
+ for i in range(0, latents.shape[0], decode_chunk_size):
+ num_frames_in = latents[i : i + decode_chunk_size].shape[0]
+ decode_kwargs = {}
+ if accepts_num_frames:
+ # we only pass num_frames_in if it's expected
+ decode_kwargs["num_frames"] = num_frames_in
+
+ frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
+ frames.append(frame)
+ frames = torch.cat(frames, dim=0)
+
+ # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
+ frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ frames = frames.float()
+ return frames
+
+ def check_inputs(self, image, height, width):
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
+ f" {type(image)}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ shape = (
+ batch_size,
+ num_frames,
+ num_channels_latents // 2,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], # PIL
+        controlnet_condition: Optional[torch.FloatTensor] = None,  # first-frame condition (PIL or tensor)
+        controlnet_flow: Optional[torch.FloatTensor] = None,  # dense flow, [1, 13, 2, h, w]
+ # controlnet_mask: [torch.FloatTensor] = None, # [1, 13, 2, h, w]
+ # val_pixel_values_384: [torch.FloatTensor] = None,
+ # val_sparse_optical_flow_384: [torch.FloatTensor] = None,
+ # val_mask_384: [torch.FloatTensor] = None,
+ height: int = 576,
+ width: int = 1024,
+ num_frames: Optional[int] = None,
+ num_inference_steps: int = 25,
+ min_guidance_scale: float = 1.0,
+ max_guidance_scale: float = 3.0,
+ fps: int = 7,
+ motion_bucket_id: int = 127,
+        noise_aug_strength: float = 0.02,
+ decode_chunk_size: Optional[int] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ return_dict: bool = True,
+ controlnet_cond_scale=1.0,
+ batch_size=1,
+ ):
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(image, height, width)
+
+ # 2. Define call parameters
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = max_guidance_scale > 1.0
+
+ # 3. Encode input image
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+
+ # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
+ # is why it is reduced here.
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
+ fps = fps - 1
+
+ # 4. Encode input image using VAE
+ image = self.image_processor.preprocess(image, height=height, width=width)
+ noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
+ image = image + noise_aug_strength * noise
+
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float32)
+
+ image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
+ image_latents = image_latents.to(image_embeddings.dtype)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Repeat the image latents for each frame so we can concatenate them with the noise
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
+ #image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+
+ # 5. Get Added Time IDs
+ added_time_ids = self._get_add_time_ids(
+ fps,
+ motion_bucket_id,
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ num_videos_per_prompt,
+ do_classifier_free_guidance,
+ )
+ added_time_ids = added_time_ids.to(device)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_frames,
+ num_channels_latents,
+ height,
+ width,
+ image_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ #prepare controlnet condition
+ controlnet_condition = self.image_processor.preprocess(controlnet_condition, height=height, width=width)
+ # controlnet_condition = controlnet_condition.unsqueeze(0)
+        controlnet_condition = torch.cat([controlnet_condition] * 2) if do_classifier_free_guidance else controlnet_condition
+        controlnet_condition = controlnet_condition.to(device, latents.dtype)
+
+        controlnet_flow = torch.cat([controlnet_flow] * 2) if do_classifier_free_guidance else controlnet_flow
+        controlnet_flow = controlnet_flow.to(device, latents.dtype)
+
+ # print(height, width)
+ # print(controlnet_condition.shape)
+ # print(controlnet_flow.shape)
+ # print(image.shape)
+
+ # assert False
+
+
+ # controlnet_mask = torch.cat([controlnet_mask] * 2) if do_classifier_free_guidance else latents
+ # controlnet_mask = controlnet_mask.to(device, latents.dtype)
+
+ # controlnet_init_flow = torch.cat([controlnet_init_flow] * 2) if do_classifier_free_guidance else latents
+ # controlnet_init_flow = controlnet_init_flow.to(device, latents.dtype)
+
+ # val_pixel_values_384 = torch.cat([val_pixel_values_384] * 2) if do_classifier_free_guidance else latents
+ # val_pixel_values_384 = val_pixel_values_384.to(device, latents.dtype)
+
+ # val_sparse_optical_flow_384 = torch.cat([val_sparse_optical_flow_384] * 2) if do_classifier_free_guidance else latents
+ # val_sparse_optical_flow_384 = val_sparse_optical_flow_384.to(device, latents.dtype)
+
+ # val_mask_384 = torch.cat([val_mask_384] * 2) if do_classifier_free_guidance else latents
+ # val_mask_384 = val_mask_384.to(device, latents.dtype)
+
+ # 7. Prepare guidance scale
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
+ guidance_scale = guidance_scale.to(device, latents.dtype)
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
+
+ self._guidance_scale = guidance_scale
+
+        # Recompute the added time ids with the fixed values used for inference (fps=6, motion_bucket_id=128).
+        noise_aug_strength = 0.02
+ added_time_ids = _get_add_time_ids(
+ noise_aug_strength,
+ image_embeddings.dtype,
+ batch_size,
+ 6,
+ 128,
+ unet=self.unet,
+ )
+ added_time_ids = torch.cat([added_time_ids] * 2)
+ added_time_ids = added_time_ids.to(latents.device)
+
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+            # Concatenate image_latents over the channel dimension
+
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
+
+ # print(latent_model_input.shape)
+ # print(controlnet_flow.shape)
+
+ # assert False
+
+ # controlnet_flow = None
+
+ down_block_res_samples, mid_block_res_sample, controlnet_flow, _ = self.controlnet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ controlnet_cond=controlnet_condition,
+ controlnet_flow=controlnet_flow,
+ # controlnet_mask=controlnet_mask,
+ # pixel_values_384=val_pixel_values_384,
+ # sparse_optical_flow_384=val_sparse_optical_flow_384,
+ # mask_384=val_mask_384,
+ added_time_ids=added_time_ids,
+ conditioning_scale=controlnet_cond_scale,
+ guess_mode=False,
+ return_dict=False,
+ )
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=image_embeddings,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_time_ids=added_time_ids,
+ return_dict=False,
+ )[0]
+
+
+ # assert False
+
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ frames = self.decode_latents(latents.to(self.vae.dtype), num_frames, decode_chunk_size)
+ frames = tensor2vid(frames, self.image_processor, output_type=output_type)
+ else:
+ frames = latents
+
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return frames, controlnet_flow
+
+ return FlowControlNetPipelineOutput(frames=frames, controlnet_flow=controlnet_flow)
+
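+# A minimal usage sketch (illustrative only; variable names are placeholders and the actual component
+# loading and wiring lives in run_gradio.py):
+#
+#   pipe = FlowControlNetPipeline(
+#       vae=vae, image_encoder=image_encoder, unet=unet,
+#       controlnet=controlnet, scheduler=scheduler, feature_extractor=feature_extractor,
+#   ).to("cuda")
+#   out = pipe(
+#       image=first_frame_pil,                # PIL.Image used for CLIP/VAE conditioning
+#       controlnet_condition=first_frame_pt,  # the same frame as a [1, 3, h, w] tensor (or PIL image)
+#       controlnet_flow=dense_flow,           # dense flow, e.g. [1, num_frames - 1, 2, h, w]
+#       height=512, width=512, num_frames=25,
+#   )
+#   frames = out.frames[0]                    # list of PIL frames for the generated video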
+
+# resizing utils
+# TODO: clean up later
+def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
+
+ if input.ndim == 3:
+ input = input.unsqueeze(0) # Add a batch dimension
+
+ h, w = input.shape[-2:]
+ factors = (h / size[0], w / size[1])
+
+ # First, we have to determine sigma
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
+ sigmas = (
+ max((factors[0] - 1.0) / 2.0, 0.001),
+ max((factors[1] - 1.0) / 2.0, 0.001),
+ )
+
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
+
+ # Make sure it is odd
+ if (ks[0] % 2) == 0:
+ ks = ks[0] + 1, ks[1]
+
+ if (ks[1] % 2) == 0:
+ ks = ks[0], ks[1] + 1
+
+ input = _gaussian_blur2d(input, ks, sigmas)
+
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
+ return output
+
+
+def _compute_padding(kernel_size):
+ """Compute padding tuple."""
+    # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
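+# Examples: _compute_padding([3, 3]) -> [1, 1, 1, 1]; _compute_padding([5, 3]) -> [1, 1, 2, 2].
+# The order matches torch.nn.functional.pad: (left, right, top, bottom), with even kernels padded asymmetrically.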
+
+
+def _filter2d(input, kernel):
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ padding_shape: list[int] = _compute_padding([height, width])
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ out = output.view(b, c, h, w)
+ return out
+
+
+def _gaussian(window_size: int, sigma):
+ if isinstance(sigma, float):
+ sigma = torch.tensor([[sigma]])
+
+ batch_size = sigma.shape[0]
+
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
+
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
+
+ return gauss / gauss.sum(-1, keepdim=True)
+
+
+def _gaussian_blur2d(input, kernel_size, sigma):
+ if isinstance(sigma, tuple):
+ sigma = torch.tensor([sigma], dtype=input.dtype)
+ else:
+ sigma = sigma.to(dtype=input.dtype)
+
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
+ bs = sigma.shape[0]
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
+ out_x = _filter2d(input, kernel_x[..., None, :])
+ out = _filter2d(out_x, kernel_y[..., None])
+
+ return out
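+# Illustrative example: blur a batch of images with a 3x3 kernel and unit sigmas.
+#
+#   img = torch.rand(2, 3, 64, 64)
+#   blurred = _gaussian_blur2d(img, (3, 3), (1.0, 1.0))  # same shape, [2, 3, 64, 64]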
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..37d2754f001133d113e83c9e705c5cf1b2d77529
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+diffusers==0.24.0
+gradio==4.5.0
+opencv-python
+opencv-python-headless
+scikit-image
+torch==2.0.1
+torchvision==0.15.2
+einops
+accelerate==0.30.1
+transformers==4.41.1
+colorlog
+cupy-cuda117
+av
+gpustat
+
+
+# xformers==0.0.22
+# decord
+# bitsandbytes
+
diff --git a/utils/flow_viz.py b/utils/flow_viz.py
new file mode 100755
index 0000000000000000000000000000000000000000..73c0a357d91e785127b2b9513b2a6951f4ceaf1e
--- /dev/null
+++ b/utils/flow_viz.py
@@ -0,0 +1,291 @@
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image
+import torch
+
+
+def make_colorwheel():
+ '''
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ '''
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col:col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col:col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col:col + MR, 0] = 255
+ return colorwheel
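+# Note: with the segment sizes above, the wheel has RY + YG + GC + CB + BM + MR = 55 entries,
+# i.e. make_colorwheel().shape == (55, 3).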
+
+
+def flow_compute_color(u, v, convert_to_bgr=False):
+ '''
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param u: np.ndarray, input horizontal flow
+ :param v: np.ndarray, input vertical flow
+ :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
+ :return:
+ '''
+
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 1
+ f = fk - k0
+
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range?
+
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
+
+ return flow_image
+
+
+def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
+ '''
+ Expects a two dimensional flow image of shape [H,W,2]
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+ :param flow_uv: np.ndarray of shape [H,W,2]
+ :param clip_flow: float, maximum clipping value for flow
+ :return:
+ '''
+
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
+
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+
+ return flow_compute_color(u, v, convert_to_bgr)
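+# Illustrative example: visualize a random flow field as an RGB image.
+#
+#   flow = np.random.randn(240, 320, 2).astype(np.float32)
+#   rgb = flow_to_color(flow)  # uint8 array of shape (240, 320, 3)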
+
+
+UNKNOWN_FLOW_THRESH = 1e7
+SMALLFLOW = 0.0
+LARGEFLOW = 1e8
+
+
+def make_color_wheel():
+ """
+    Generate a color wheel according to the Middlebury color code
+ :return: Color wheel
+ """
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
+ col += RY
+
+ # YG
+ colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
+ colorwheel[col:col + YG, 1] = 255
+ col += YG
+
+ # GC
+ colorwheel[col:col + GC, 1] = 255
+ colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
+ col += GC
+
+ # CB
+ colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
+ colorwheel[col:col + CB, 2] = 255
+ col += CB
+
+ # BM
+ colorwheel[col:col + BM, 2] = 255
+ colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
+    col += BM
+
+ # MR
+ colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
+ colorwheel[col:col + MR, 0] = 255
+
+ return colorwheel
+
+
+def compute_color(u, v):
+ """
+ compute optical flow color map
+ :param u: optical flow horizontal map
+ :param v: optical flow vertical map
+ :return: optical flow in color code
+ """
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nanIdx = np.isnan(u) | np.isnan(v)
+ u[nanIdx] = 0
+ v[nanIdx] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+
+ a = np.arctan2(-v, -u) / np.pi
+
+ fk = (a + 1) / 2 * (ncols - 1) + 1
+
+ k0 = np.floor(fk).astype(int)
+
+ k1 = k0 + 1
+ k1[k1 == ncols + 1] = 1
+ f = fk - k0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k0 - 1] / 255
+ col1 = tmp[k1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
+
+ return img
+
+
+# from https://github.com/gengshan-y/VCN
+def flow_to_image(flow):
+ """
+ Convert flow into middlebury color code image
+ :param flow: optical flow map
+ :return: optical flow image in middlebury color
+ """
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
+ u[idxUnknow] = 0
+ v[idxUnknow] = 0
+
+ rad = torch.sqrt(u ** 2 + v ** 2)
+ maxrad = max(-1, torch.max(rad).cpu().numpy())
+
+ u = u / (maxrad + np.finfo(float).eps)
+ v = v / (maxrad + np.finfo(float).eps)
+
+ img = compute_color(u.cpu().numpy(), v.cpu().numpy())
+
+ idx = np.repeat(idxUnknow[:, :, np.newaxis].cpu().numpy(), 3, axis=2)
+ img[idx] = 0
+
+ return np.uint8(img)
+
+
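+# Note: unlike `flow_to_color` above, this `flow_to_image` variant operates on a
+# torch tensor of shape [H, W, 2] (it calls `torch.sqrt` and `.cpu().numpy()`).
+# A minimal, hypothetical usage sketch:
+#
+#   demo_flow = torch.randn(64, 64, 2) * 4.0
+#   demo_rgb = flow_to_image(demo_flow)   # uint8 numpy array of shape [64, 64, 3]
+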
+def save_vis_flow_tofile(flow, output_path):
+ vis_flow = flow_to_image(flow)
+ Image.fromarray(vis_flow).save(output_path)
+
+
+def flow_tensor_to_image(flow):
+ """Used for tensorboard visualization"""
+ flow = flow.permute(1, 2, 0) # [H, W, 2]
+ flow = flow.detach().cpu().numpy()
+ flow = flow_to_image(flow) # [H, W, 3]
+ flow = np.transpose(flow, (2, 0, 1)) # [3, H, W]
+
+ return flow
diff --git a/utils/scheduling_euler_discrete_karras_fix.py b/utils/scheduling_euler_discrete_karras_fix.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de68461afb061e2bc5efb3efeb8e54c81b09ca6
--- /dev/null
+++ b/utils/scheduling_euler_discrete_karras_fix.py
@@ -0,0 +1,556 @@
+# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+import torch.nn.functional as F
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+class EulerDiscreteSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
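+# Sanity-check sketch (hypothetical, not part of the original file): after rescaling,
+# the terminal cumulative alpha is zero, i.e. the last timestep has zero SNR:
+#
+#   betas = torch.linspace(0.0001, 0.02, 1000)
+#   alphas_cumprod = torch.cumprod(1.0 - rescale_zero_terminal_snr(betas), dim=0)
+#   assert torch.isclose(alphas_cumprod[-1], torch.tensor(0.0), atol=1e-6)
+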
+class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Euler scheduler.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ interpolation_type(`str`, defaults to `"linear"`, *optional*):
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
+ `"linear"` or `"log_linear"`.
+ use_karras_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
+ the sigmas are determined according to a sequence of noise levels {σi}.
+ timestep_spacing (`str`, defaults to `"linspace"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ prediction_type: str = "epsilon",
+ interpolation_type: str = "linear",
+ use_karras_sigmas: Optional[bool] = False,
+ sigma_min: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+ timestep_spacing: str = "linspace",
+ timestep_type: str = "discrete", # can be "discrete" or "continuous"
+ steps_offset: int = 0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ if rescale_betas_zero_snr:
+ # Close to 0 without being 0 so first sigma is not inf
+ # FP16 smallest positive subnormal works well here
+ self.alphas_cumprod[-1] = 2**-24
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+
+ sigmas = sigmas[::-1].copy()
+
+ if use_karras_sigmas:
+ log_sigmas = np.log(sigmas)
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
+
+ # setable values
+ self.num_inference_steps = None
+
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if timestep_type == "continuous" and prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+ else:
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32))
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ self.is_scale_input_called = False
+ self.use_karras_sigmas = use_karras_sigmas
+
+ self._step_index = None
+
+ @property
+ def init_noise_sigma(self):
+ # standard deviation of the initial noise distribution
+ max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
+ if self.config.timestep_spacing in ["linspace", "trailing"]:
+ return max_sigma
+
+ return (max_sigma**2 + 1) ** 0.5
+
+ @property
+ def step_index(self):
+ """
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
+ """
+ return self._step_index
+
+ def scale_model_input(
+ self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ ) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ sigma = self.sigmas[self.step_index]
+ sample = sample / ((sigma**2 + 1) ** 0.5)
+
+ self.is_scale_input_called = True
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ """
+ self.num_inference_steps = num_inference_steps
+
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if self.config.timestep_spacing == "linspace":
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
+ ::-1
+ ].copy()
+ elif self.config.timestep_spacing == "leading":
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+ timesteps += self.config.steps_offset
+ elif self.config.timestep_spacing == "trailing":
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+ timesteps -= 1
+ else:
+ raise ValueError(
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+ )
+
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ log_sigmas = np.log(sigmas)
+
+ if self.config.interpolation_type == "linear":
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+ elif self.config.interpolation_type == "log_linear":
+ sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
+ else:
+ raise ValueError(
+ f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+ " 'linear' or 'log_linear'"
+ )
+
+ if self.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+
+ # TODO: Support the full EDM scalings for all prediction types and timestep types
+ if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
+ self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]).to(device=device)
+ else:
+ self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
+
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+ self._step_index = None
+
+ def _sigma_to_t(self, sigma, log_sigmas):
+ # get log sigma
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
+
+ # get distribution
+ dists = log_sigma - log_sigmas[:, np.newaxis]
+
+ # get sigmas range
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
+ high_idx = low_idx + 1
+
+ low = log_sigmas[low_idx]
+ high = log_sigmas[high_idx]
+
+ # interpolate sigmas
+ w = (low - log_sigma) / (low - high)
+ w = np.clip(w, 0, 1)
+
+ # transform interpolation to time range
+ t = (1 - w) * low_idx + w * high_idx
+ t = t.reshape(sigma.shape)
+ return t
+
+ # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+ """Constructs the noise schedule of Karras et al. (2022)."""
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
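+ # The schedule above follows Karras et al. (2022):
+ #   sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho, with rho = 7,
+ # so the sigmas decrease monotonically from sigma_max to sigma_min along the ramp.
+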
+ def _init_step_index(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ index_candidates = (self.timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(index_candidates) > 1:
+ step_index = index_candidates[1]
+ else:
+ step_index = index_candidates[0]
+
+ self._step_index = step_index.item()
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ s_churn: float = 0.0,
+ s_tmin: float = 0.0,
+ s_tmax: float = float("inf"),
+ s_noise: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ s_churn (`float`):
+ Amount of stochasticity ("churn") added at this step; `0.0` gives a deterministic Euler step.
+ s_tmin (`float`):
+ Lower bound of the sigma range in which churn is applied.
+ s_tmax (`float`):
+ Upper bound of the sigma range in which churn is applied.
+ s_noise (`float`, defaults to 1.0):
+ Scaling factor for noise added to the sample.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if not self.is_scale_input_called:
+ logger.warning(
+ "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+ "See `StableDiffusionPipeline` for a usage example."
+ )
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ sigma = self.sigmas[self.step_index]
+
+ gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+ noise = randn_tensor(
+ model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+ )
+
+ eps = noise * s_noise
+ sigma_hat = sigma * (gamma + 1)
+
+ if gamma > 0:
+ sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ # NOTE: "original_sample" should not be an expected prediction_type but is left in for
+ # backwards compatibility
+ if self.config.prediction_type == "original_sample" or self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif self.config.prediction_type == "epsilon":
+ pred_original_sample = sample - sigma_hat * model_output
+ elif self.config.prediction_type == "v_prediction":
+ # denoised = model_output * c_out + input * c_skip
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+ )
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma_hat
+
+ dt = self.sigmas[self.step_index + 1] - sigma_hat
+
+ prev_sample = sample + derivative * dt
+
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ # upon completion increase step index by one
+ self._step_index += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+ timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(original_samples.device)
+ timesteps = timesteps.to(original_samples.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
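+
+
+# Illustrative usage sketch (hypothetical, not part of the original file): the usual
+# Euler denoising loop pairs `scale_model_input` with `step`. `unet`, `latent_shape`
+# and the config values are placeholders.
+#
+#   scheduler = EulerDiscreteScheduler()  # or loaded via SchedulerMixin.from_pretrained / from_config
+#   scheduler.set_timesteps(25, device="cuda")
+#   sample = torch.randn(latent_shape, device="cuda") * scheduler.init_noise_sigma
+#   for t in scheduler.timesteps:
+#       model_input = scheduler.scale_model_input(sample, t)
+#       noise_pred = unet(model_input, t).sample
+#       sample = scheduler.step(noise_pred, t, sample).prev_sample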
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b296648f28598cd5f8d0fd9b0613b9173e1b9aad
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,269 @@
+# -*- coding:utf-8 -*-
+import os
+import sys
+import shutil
+import logging
+import colorlog
+from tqdm import tqdm
+import time
+import yaml
+import random
+import importlib
+from PIL import Image
+from warnings import simplefilter
+import imageio
+import math
+import collections
+import json
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim import Adam
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+from einops import rearrange, repeat
+import torch.distributed as dist
+from torchvision import datasets, transforms, utils
+
+logging.getLogger().setLevel(logging.WARNING)
+simplefilter(action='ignore', category=FutureWarning)
+
+def get_logger(filename=None):
+ """
+ examples:
+ logger = get_logger('try_logging.txt')
+
+ logger.debug("Do something.")
+ logger.info("Start print log.")
+ logger.warning("Something maybe fail.")
+ try:
+ raise ValueError()
+ except ValueError:
+ logger.error("Error", exc_info=True)
+
+ tips:
+ Do NOT log large tensors with logger.info(), since the color formatting is not helpful there.
+ """
+ logger = logging.getLogger('utils')
+ level = logging.DEBUG
+ logger.setLevel(level=level)
+ # Disable propagation to avoid duplicated log messages.
+ logger.propagate = False
+ # Remove %(levelname)s since we have colorlog to represent levelname.
+ format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'
+
+ streamHandler = logging.StreamHandler()
+ streamHandler.setLevel(level)
+ coloredFormatter = colorlog.ColoredFormatter(
+ '%(log_color)s' + format_str,
+ datefmt='%Y-%m-%d %H:%M:%S',
+ reset=True,
+ log_colors={
+ 'DEBUG': 'cyan',
+ # 'INFO': 'white',
+ 'WARNING': 'yellow',
+ 'ERROR': 'red',
+ 'CRITICAL': 'red,bg_white',
+ }
+ )
+
+ streamHandler.setFormatter(coloredFormatter)
+ logger.addHandler(streamHandler)
+
+ if filename:
+ fileHandler = logging.FileHandler(filename)
+ fileHandler.setLevel(level)
+ formatter = logging.Formatter(format_str)
+ fileHandler.setFormatter(formatter)
+ logger.addHandler(fileHandler)
+
+ # Fix multiple logging for torch.distributed
+ try:
+ class UniqueLogger:
+ def __init__(self, logger):
+ self.logger = logger
+ self.local_rank = torch.distributed.get_rank()
+
+ def info(self, msg, *args, **kwargs):
+ if self.local_rank == 0:
+ return self.logger.info(msg, *args, **kwargs)
+
+ def warning(self, msg, *args, **kwargs):
+ if self.local_rank == 0:
+ return self.logger.warning(msg, *args, **kwargs)
+
+ logger = UniqueLogger(logger)
+ # get_rank() raises if torch.distributed is not initialized
+ # (e.g. CPU-only or single-process runs), so fall back to the plain logger.
+ except Exception:
+ pass
+ return logger
+
+
+logger = get_logger()
+
+def split_filename(filename):
+ absname = os.path.abspath(filename)
+ dirname, basename = os.path.split(absname)
+ split_tmp = basename.rsplit('.', maxsplit=1)
+ if len(split_tmp) == 2:
+ rootname, extname = split_tmp
+ elif len(split_tmp) == 1:
+ rootname = split_tmp[0]
+ extname = None
+ else:
+ raise ValueError("programming error!")
+ return dirname, rootname, extname
+
+def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
+ dirname, rootname, extname = split_filename(filename)
+ print_did_not_save_flag = True
+ if type:
+ extname = type
+ if not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ if not os.path.exists(filename) or override:
+ if extname in ['jpg', 'png', 'jpeg']:
+ utils.save_image(data, filename, **kwargs)
+ elif extname == 'gif':
+ imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
+ elif extname == 'txt':
+ if kwargs is None:
+ kwargs = {}
+ max_step = kwargs.get('max_step')
+ if max_step is None:
+ max_step = np.inf
+
+ with open(filename, 'w', encoding='utf-8') as f:
+ for i, e in enumerate(data):
+ if i < max_step:
+ f.write(str(e) + '\n')
+ else:
+ break
+ else:
+ raise ValueError('Unsupported file type: %s' % extname)
+ if printable: logger.info('Saved data to %s' % os.path.abspath(filename))
+ else:
+ if print_did_not_save_flag: logger.info(
+ 'Did not save data to %s because file exists and override is False' % os.path.abspath(
+ filename))
+
+
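+# Illustrative usage sketch (hypothetical file name): saving a list of uint8 frames as a GIF
+# through the `gif` branch above.
+#
+#   data2file(frames, './outputs/demo.gif', duration=0.1, override=True)
+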
+def file2data(filename, type=None, printable=True, **kwargs):
+ dirname, rootname, extname = split_filename(filename)
+ print_load_flag = True
+ if type:
+ extname = type
+
+ if extname in ['pth', 'ckpt']:
+ data = torch.load(filename, map_location=kwargs.get('map_location'))
+ elif extname == 'txt':
+ top = kwargs.get('top', None)
+ with open(filename, encoding='utf-8') as f:
+ if top:
+ data = [f.readline() for _ in range(top)]
+ else:
+ data = [e for e in f.read().split('\n') if e]
+ elif extname == 'yaml':
+ with open(filename, 'r') as f:
+ data = yaml.load(f, Loader=yaml.FullLoader)
+ else:
+ raise ValueError('type can only support pth, ckpt, txt, yaml')
+ if printable:
+ if print_load_flag:
+ logger.info('Loaded data from %s' % os.path.abspath(filename))
+ return data
+
+
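+# Illustrative usage sketch (hypothetical path): loading a checkpoint on CPU through the
+# `pth`/`ckpt` branch above.
+#
+#   state_dict = file2data('./ckpts/model.pth', map_location='cpu')
+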
+def ensure_dirname(dirname, override=False):
+ if os.path.exists(dirname) and override:
+ logger.info('Removing dirname: %s' % os.path.abspath(dirname))
+ try:
+ shutil.rmtree(dirname)
+ except OSError as e:
+ raise ValueError('Failed to delete %s because %s' % (dirname, e))
+
+ if not os.path.exists(dirname):
+ logger.info('Making dirname: %s' % os.path.abspath(dirname))
+ os.makedirs(dirname, exist_ok=True)
+
+
+def import_filename(filename):
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def adaptively_load_state_dict(target, state_dict):
+ target_dict = target.state_dict()
+
+ try:
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()}
+ except Exception as e:
+ logger.warning('load error %s', e)
+ common_dict = {k: v for k, v in state_dict.items() if k in target_dict}
+
+ if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \
+ target.state_dict()['param_groups'][0]['params']:
+ logger.warning('Detected mismatched params, automatically adapting state_dict to the current model')
+ common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params']
+ target_dict.update(common_dict)
+ target.load_state_dict(target_dict)
+
+ missing_keys = [k for k in target_dict.keys() if k not in common_dict]
+ unexpected_keys = [k for k in state_dict.keys() if k not in common_dict]
+
+ if len(unexpected_keys) != 0:
+ logger.warning(
+ f"Some weights of state_dict were not used in target: {unexpected_keys}"
+ )
+ if len(missing_keys) != 0:
+ logger.warning(
+ f"Some weights of state_dict are missing used in target {missing_keys}"
+ )
+ if len(unexpected_keys) == 0 and len(missing_keys) == 0:
+ logger.warning("Strictly Loaded state_dict.")
+
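+# Illustrative usage sketch (hypothetical objects and path): load whatever keys match in both
+# name and shape, logging the rest instead of raising like `load_state_dict(strict=True)` would.
+#
+#   ckpt = torch.load('./ckpts/controlnet.pth', map_location='cpu')
+#   adaptively_load_state_dict(model, ckpt)
+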
+def set_seed(seed=42):
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+def image2pil(filename):
+ return Image.open(filename)
+
+
+def image2arr(filename):
+ pil = image2pil(filename)
+ return pil2arr(pil)
+
+
+# Format conversion helpers
+def pil2arr(pil):
+ if isinstance(pil, list):
+ arr = np.array(
+ [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil])
+ else:
+ arr = np.array(pil)
+ return arr
+
+
+def arr2pil(arr):
+ if arr.ndim == 3:
+ return Image.fromarray(arr.astype('uint8'), 'RGB')
+ elif arr.ndim == 4:
+ return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)]
+ else:
+ raise ValueError('arr must have ndim of 3 or 4, but got %s' % arr.ndim)
+
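+# Illustrative round-trip sketch (hypothetical file name): PIL image -> numpy array -> PIL image.
+#
+#   pil = image2pil('frame.png')
+#   arr = pil2arr(pil)        # [H, W, C] uint8 for a single RGB image
+#   pil_again = arr2pil(arr)  # back to a PIL.Image in RGB mode
+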
+def notebook_show(*images):
+ from IPython.display import Image
+ from IPython.display import display
+ display(*[Image(e) for e in images])
\ No newline at end of file