"""Optical-flow utilities built around torchvision's RAFT model.

Usage:
    from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample
    image = load_image_as_tensor('hamburger_pic.jpeg', image_size)
    flow_estimator = RAFTFlow()
    res = generate_sample(
        image,
        flow_estimator,
        distortion_scale=distortion_scale,
    )
    f1 = res['input'][None]
    f2 = res['target'][None]
    flow = res['flow'][None]
    f1_warp = warp_image(f1, flow)
    show_image(f1_warp[0])
    show_image(f2[0])

NOTE(review): load_image_as_tensor / MyRandomPerspective / generate_sample are
referenced above but not defined in this chunk — presumably elsewhere in the module.
"""
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights
import numpy as np


def warp_image(image, flow, mode='bilinear'):
    """
    Warp an image using optical flow.

    Args:
        image (torch.Tensor): Input image tensor with shape (N, C, H, W)
            or (C, H, W) (a batch dim is added automatically).
        flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W)
            or (2, H, W), in pixels; channel 0 is the x (width)
            displacement, channel 1 the y (height) displacement.
        mode (str): Interpolation mode forwarded to ``F.grid_sample``.

    Returns:
        warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W).
    """
    # Promote unbatched inputs to batched form.
    if len(image.shape) == 3:
        image = image.unsqueeze(0)
    if len(flow.shape) == 3:
        flow = flow.unsqueeze(0)
    # Keep both tensors on the image's device.
    if image.device != flow.device:
        flow = flow.to(image.device)
    assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.'
    assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.'

    N, _, H, W = image.shape
    # Base sampling grid of absolute pixel coordinates; grid_sample expects the
    # last axis ordered (x, y), hence the stack order below.
    ys, xs = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=image.device),
        torch.arange(W, dtype=torch.float32, device=image.device),
        indexing='ij',
    )
    grid = torch.stack((xs, ys), dim=-1)[None].repeat(N, 1, 1, 1)  # (N, H, W, 2)
    # Displace each pixel by its flow vector.
    grid = grid + flow.permute(0, 2, 3, 1)

    # Normalize grid to [-1, 1] (grid_sample convention, align_corners=True).
    grid[..., 0] = 2 * (grid[..., 0] / (W - 1) - 0.5)
    grid[..., 1] = 2 * (grid[..., 1] / (H - 1) - 0.5)

    # Sample the input image at the displaced locations.
    warped_image = F.grid_sample(image, grid, mode=mode, align_corners=True)
    return warped_image


def resize_flow(flow, size):
    """
    Resize an optical-flow tensor to a new spatial size.

    Both the sampling grid and the flow *values* are rescaled, since flow is
    expressed in pixels of its own resolution.

    Args:
        flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
        size (tuple[int, int]): Target size as a tuple (H, W).

    Returns:
        flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W).
    """
    H, W = size
    h, w = flow.shape[2:]
    # Per-axis scale factors between old and new resolution.
    scale_x = W / w
    scale_y = H / h

    # Scale the flow vectors themselves (clone to avoid mutating the input).
    flow_scaled = flow.clone()
    flow_scaled[:, 0] *= scale_x
    flow_scaled[:, 1] *= scale_y

    # Resample the flow field to the new size.
    # NOTE(review): align_corners=False here vs. align_corners=True in
    # warp_image — a sub-pixel inconsistency at the borders; confirm intended.
    flow_resized = F.interpolate(flow_scaled, size=(H, W), mode='bilinear', align_corners=False)
    return flow_resized


def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor:
    """
    Forward/backward consistency check between two optical flows.

    A pixel is marked reliable when following flow1 and then flow2 (bilinearly
    interpolated at the displaced position) brings it back close to where it
    started, and it does not sit on a strong motion edge.

    Args:
        flow1: (B, 2, H, W) — e.g. the forward flow.
        flow2: (B, 2, H, W) — the flow in the opposite direction.

    Returns:
        Float mask of 1.0 (reliable) / 0.0 (unreliable). For B == 1 the shape
        is (1, 1, H, W); the broadcasting below assumes a single batch element
        — TODO confirm before using with B > 1.
    """
    device = flow1.device
    height, width = flow1.shape[2:]

    # Central-difference kernels for the flow gradient (motion-edge strength).
    kernel_x = torch.tensor([[0.5, 0, -0.5]]).unsqueeze(0).unsqueeze(0).to(device)
    kernel_y = torch.tensor([[0.5], [0], [-0.5]]).unsqueeze(0).unsqueeze(0).to(device)
    grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1))
    grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0))
    motion_edge = (grad_x * grad_x + grad_y * grad_y).sum(dim=1).squeeze(0)

    # Pixel coordinates (a) and their positions after applying flow1 (b).
    ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy')
    bx, by = ax + flow1[:, 0], ay + flow1[:, 1]

    # Integer corners for bilinear interpolation of flow2 at (bx, by),
    # clamped to the image bounds.
    x1, y1 = torch.floor(bx).long(), torch.floor(by).long()
    x2, y2 = x1 + 1, y1 + 1
    x1 = torch.clamp(x1, 0, width - 1)
    x2 = torch.clamp(x2, 0, width - 1)
    y1 = torch.clamp(y1, 0, height - 1)
    y2 = torch.clamp(y2, 0, height - 1)
    alpha_x, alpha_y = bx - x1.float(), by - y1.float()

    # Bilinear interpolation of flow2's x component...
    a = (1.0 - alpha_x) * flow2[:, 0, y1, x1] + alpha_x * flow2[:, 0, y1, x2]
    b = (1.0 - alpha_x) * flow2[:, 0, y2, x1] + alpha_x * flow2[:, 0, y2, x2]
    u = (1.0 - alpha_y) * a + alpha_y * b
    # ...and y component.
    a = (1.0 - alpha_x) * flow2[:, 1, y1, x1] + alpha_x * flow2[:, 1, y1, x2]
    b = (1.0 - alpha_x) * flow2[:, 1, y2, x1] + alpha_x * flow2[:, 1, y2, x2]
    v = (1.0 - alpha_y) * a + alpha_y * b

    # Round trip: position after flow1 then flow2 should land back on (ax, ay).
    cx, cy = bx + u, by + v
    u2, v2 = flow1[:, 0], flow1[:, 1]
    reliable = ((((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5))
                & (motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 0.002)).float()
    return reliable


class RAFTFlow(torch.nn.Module):
    '''
    Pretrained RAFT (large) optical-flow estimator.

    # Instantiate the RAFTFlow class
    raft_flow = RAFTFlow()
    # Load a pair of image frames as PyTorch tensors
    img1 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32)
    img2 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32)
    # Compute optical flow between the two frames
    # (optional) image_size = (256, 256) or None
    flow = raft_flow(img1, img2, image_size)
    # flow will be computed at the original image size if image_size is None
    # this flow can be used to warp the second image to the first image
    # Warp the second image using the flow
    warped_img = warp_image(img2, flow)
    '''

    def __init__(self, *args):
        """
        Load the pretrained RAFT-large model and its input transform.

        Extra positional args are forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args)
        weights = Raft_Large_Weights.DEFAULT
        self.model = raft_large(weights=weights, progress=False)
        # Inference-only module (forward runs under no_grad): use eval mode.
        self.model.eval()
        self.model_transform = weights.transforms()

    def forward(self, img1, img2, img_size=None):
        """
        Compute optical flow between two frames using the RAFT model.

        Args:
            img1 (torch.Tensor): First frame tensor with shape (B, C, H, W).
            img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W).
            img_size (tuple): Optional (H, W) at which to run the model; the
                returned flow is rescaled back to the input resolution.

        Returns:
            flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W).
        """
        original_size = img1.shape[2:]
        # Optionally downscale before inference to bound compute/memory.
        if img_size is not None:
            img1 = TF.resize(img1, size=img_size, antialias=False)
            img2 = TF.resize(img2, size=img_size, antialias=False)
        img1, img2 = self.model_transform(img1, img2)

        # RAFT returns a list of iteratively refined flows; keep the final one.
        with torch.no_grad():
            list_of_flows = self.model(img1, img2)
        flow = list_of_flows[-1]

        # Map the flow back to the original resolution if we resized.
        if img_size is not None:
            flow = resize_flow(flow, original_size)
        return flow