import gradio as gr
from transformers import PerceiverForOpticalFlow
import torch
import torch.nn.functional as F
import numpy as np
import requests
from PIL import Image
import itertools
import math
import cv2

model = PerceiverForOpticalFlow.from_pretrained("deepmind/optical-flow-perceiver")
TRAIN_SIZE = model.config.train_size
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
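# For the "deepmind/optical-flow-perceiver" checkpoint, train_size is
# (368, 496) per the model config; inference below runs on patches of exactly
# this size.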

def normalize(im):
  return im / 255.0 * 2 - 1
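
# Quick sanity check of the mapping from [0, 255] to [-1, 1]:
#   normalize(np.array([0., 127.5, 255.])) -> array([-1., 0., 1.])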

# source: https://discuss.pytorch.org/t/tf-extract-image-patches-in-pytorch/43837/9
def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF-style 'SAME' padding. Note that F.pad takes the padding for the
    # last (width) dimension first, so the column padding must come before the
    # row padding.
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2,
                  pad_row // 2, pad_row - pad_row // 2))
    
    # Extract patches
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0,4,5,1,2,3).contiguous()
    
    return patches.view(b,-1,patches.shape[-2], patches.shape[-1])
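
# Shape sanity check (a sketch; numbers assume the defaults stride=1,
# dilation=1):
#   x = torch.randn(2, 3, 368, 496)
#   extract_image_patches(x, kernel=3).shape -> torch.Size([2, 27, 368, 496])
# i.e. each pixel is replaced by its flattened 3x3x3 neighbourhood (27 channels).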

def compute_optical_flow(model, img1, img2, grid_indices, FLOW_SCALE_FACTOR=20):
  """Compute the optical flow between two images.

  To compute the flow between images of arbitrary sizes, we divide the image
  into patches, compute the flow for each patch, and stitch the flows together.

  Args:
    model: PyTorch Perceiver model.
    img1: first image (HWC numpy array, normalized to [-1, 1]).
    img2: second image (HWC numpy array, normalized to [-1, 1]).
    grid_indices: indices of the upper-left corner of each patch.
    FLOW_SCALE_FACTOR: factor by which the raw model output is scaled.
  """
  img1 = torch.tensor(np.moveaxis(img1, -1, 0))
  img2 = torch.tensor(np.moveaxis(img2, -1, 0))
  imgs = torch.stack([img1, img2], dim=0)[None]
  height = imgs.shape[-2]
  width = imgs.shape[-1]

  patch_size = model.config.train_size
  
  if height < patch_size[0]:
    raise ValueError(
        f"Height of image (shape: {imgs.shape}) must be at least {patch_size[0]}. "
        "Please pad or resize your image to the minimum dimension."
    )
  if width < patch_size[1]:
    raise ValueError(
        f"Width of image (shape: {imgs.shape}) must be at least {patch_size[1]}. "
        "Please pad or resize your image to the minimum dimension."
    )

  flows = 0
  flow_count = 0

  for y, x in grid_indices:
    inp_piece = imgs[..., y : y + patch_size[0],
                     x : x + patch_size[1]]
    
    batch_size, _, C, H, W = inp_piece.shape
    patches = extract_image_patches(inp_piece.view(batch_size*2,C,H,W), kernel=3)
    _, C, H, W = patches.shape
    patches = patches.view(batch_size, -1, C, H, W).float().to(model.device)
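    # patches now has shape (batch, 2, 27, 368, 496): the two frames, with each
    # pixel replaced by its flattened 3x3 RGB neighbourhood, which is the input
    # layout PerceiverForOpticalFlow expects.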
        
    # actual forward pass
    with torch.no_grad():
      output = model(inputs=patches).logits * FLOW_SCALE_FACTOR
    
    # Stitch the patch flow into the full-resolution flow with a weighted
    # average (done in NumPy below, but this could equally be implemented in
    # PyTorch). flow_piece has shape (batch, patch_height, patch_width, 2).
    flow_piece = output.detach().cpu().numpy()
    
    weights_x, weights_y = np.meshgrid(
        np.arange(patch_size[1]), np.arange(patch_size[0]))

    weights_x = np.minimum(weights_x + 1, patch_size[1] - weights_x)
    weights_y = np.minimum(weights_y + 1, patch_size[0] - weights_y)
    weights = np.minimum(weights_x, weights_y)[np.newaxis, :, :,
                                                np.newaxis]
    padding = [(0, 0), (y, height - y - patch_size[0]),
               (x, width - x - patch_size[1]), (0, 0)]
    flows += np.pad(flow_piece * weights, padding)
    flow_count += np.pad(weights, padding)

    # delete activations to avoid OOM
    del output

  flows /= flow_count
  return flows
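
# Blending sketch: each pixel's weight is its distance to the nearest patch
# border, so in overlapping regions the patch whose centre is closer dominates
# the average. For a hypothetical 1-D patch of size 6:
#   w[i] = min(i + 1, 6 - i), i = 0..5  ->  [1, 2, 3, 3, 2, 1]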

def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
  if min_overlap >= patch_size[0] or min_overlap >= patch_size[1]:
    raise ValueError(
        f"Overlap should be less than size of patch (got {min_overlap} "
        f"for patch size {patch_size}).")
  ys = list(range(0, image_shape[0], patch_size[0] - min_overlap))
  xs = list(range(0, image_shape[1], patch_size[1] - min_overlap))
  # Make sure the final patch is flush with the image boundary
  ys[-1] = image_shape[0] - patch_size[0]
  xs[-1] = image_shape[1] - patch_size[1]
  return itertools.product(ys, xs)
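
# Example (assuming TRAIN_SIZE == (368, 496) and the default min_overlap=20):
# for a 436x1024 Sintel frame, compute_grid_indices((436, 1024, 3)) yields
# ys = [0, 68] and xs = [0, 476, 528], i.e. a 2x3 grid of overlapping patches.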

def return_flow(flow):
  flow = np.array(flow)
  # Use the HSV (hue, saturation, value) colour model
  hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.uint8)
  hsv[..., 2] = 255

  mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
  hsv[..., 0] = ang / np.pi / 2 * 180
  hsv[..., 1] = np.clip(mag * 255 / 24, 0, 255)
  bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
  return Image.fromarray(bgr)
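
# In return_flow, flow direction maps to hue (the 0..2*pi angle from
# cv2.cartToPolar is rescaled to OpenCV's 0..180 hue range), flow magnitude
# maps to saturation (saturating at about 24 px of displacement), and value is
# fixed at 255.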

# load image examples
urls = ["https://storage.googleapis.com/perceiver_io/sintel_frame1.png", "https://storage.googleapis.com/perceiver_io/sintel_frame2.png"]

for idx, url in enumerate(urls):
  image = Image.open(requests.get(url, stream=True).raw)
  image.save(f"image_{idx}.png")

def process_images(image1, image2):
    im1 = np.array(image1)
    im2 = np.array(image2)

    # Divide images into patches, compute flow between corresponding patches
    # of both images, and stitch the flows together
    grid_indices = compute_grid_indices(im1.shape)
    output = compute_optical_flow(model, normalize(im1), normalize(im2), grid_indices)
        
    # return as PIL Image
    predicted_flow = return_flow(output[0])
    return predicted_flow

title = "Interactive demo: Perceiver for optical flow"
description = "Demo for predicting optical flow with Perceiver IO, i.e. the task of estimating, given 2 images, the 2D displacement of each pixel in the first image. To use it, simply upload 2 images (e.g. 2 consecutive video frames) or pick the example images below, then click 'Submit' to let the model predict the pixel flow. Results show up after a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.14795'>Perceiver IO: A General Architecture for Structured Inputs & Outputs</a> | <a href='https://deepmind.com/blog/article/building-architectures-that-can-handle-the-worlds-data/'>Official blog</a></p>"
examples =[[f"image_{idx}.png" for idx in range(len(urls))]]

# Note: gr.Image replaces the removed gr.inputs/gr.outputs namespaces, and
# queueing is enabled via .queue() rather than the old enable_queue argument.
iface = gr.Interface(fn=process_images,
                     inputs=[gr.Image(type="pil"), gr.Image(type="pil")],
                     outputs=gr.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples)
iface.queue().launch(debug=True)
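
# When launched locally, Gradio serves the demo at http://127.0.0.1:7860 by
# default; debug=True blocks the main thread and prints errors to the console,
# which helps when debugging on Spaces.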