|
import numpy as np |
|
from PIL import Image |
|
import torch |
|
from einops import einsum, rearrange |
|
|
|
from .permutations import make_jigsaw_perm, get_inv_perm |
|
from .view_permute import PermuteView |
|
from .jigsaw_helpers import get_jigsaw_pieces |
|
|
|
class JigsawView(PermuteView):
    '''
    Implements a 4x4 jigsaw puzzle view.

    Pixel permutations for 64x64 and 256x256 images are built by
    ``make_jigsaw_perm`` and handed to ``PermuteView``. The piece-level
    description of the 256x256 permutation (``piece_perms`` and
    ``edge_swaps``) is kept so the rearrangement can be animated with
    ``make_frame``.
    '''

    def __init__(self, seed=11):
        '''
        seed (int) :
            Seed forwarded to ``make_jigsaw_perm`` for both resolutions
        '''
        # 64x64 pixel permutation; its piece-level description is unused
        self.perm_64, _ = make_jigsaw_perm(64, seed=seed)

        # 256x256 pixel permutation plus its piece-level description
        self.perm_256, jigsaw_perm = make_jigsaw_perm(256, seed=seed)
        self.piece_perms, self.edge_swaps = jigsaw_perm

        super().__init__(self.perm_64, self.perm_256)

    def extract_pieces(self, im):
        '''
        Given an image, extract jigsaw puzzle pieces from it

        im (PIL.Image or np.ndarray) :
            Image of the jigsaw illusion. PIL images are converted to RGB
            first so the piece mask can be appended as a 4th (alpha)
            channel; arrays are assumed to already be (H, W, 3) uint8.

        Returns a list of RGBA PIL Images, one per mask returned by
        ``get_jigsaw_pieces``, each cropped to the bounding box of its mask.
        '''
        if isinstance(im, Image.Image):
            # Other modes break below: 'L' gives a 2-D array that cannot be
            # concatenated with the alpha channel, and 'RGBA' gives a
            # 5-channel array that Image.fromarray rejects.
            im = im.convert('RGB')
        im = np.array(im)
        size = im.shape[0]

        pieces = []
        for piece_mask in get_jigsaw_pieces(size):
            # Use the mask as an alpha channel so pixels outside the piece
            # become transparent. Cast to uint8 so the stacked array stays
            # uint8, which Image.fromarray needs for an RGBA image.
            alpha = (np.asarray(piece_mask)[:, :, None] * 255).astype(np.uint8)
            im_piece = np.concatenate([im, alpha], axis=2)

            # Crop to the tight bounding box of the piece's opaque pixels
            opaque = alpha[:, :, 0] > 0
            cols = np.flatnonzero(opaque.any(axis=0))
            rows = np.flatnonzero(opaque.any(axis=1))
            im_piece = im_piece[rows.min():rows.max() + 1,
                                cols.min():cols.max() + 1]

            pieces.append(Image.fromarray(im_piece))

        return pieces

    def paste_piece(self, piece, x, y, theta, xc, yc, canvas_size=384):
        '''
        Given a PIL Image of a piece, place it so that its center is at
        (x, y) and rotate it about that center by theta degrees

        piece (PIL.Image) : piece image; it is also used as its own paste
            mask, so an RGBA piece is pasted through its alpha channel
        x (float) : x coordinate to place piece at
        y (float) : y coordinate to place piece at
        theta (float) : degrees to rotate piece about center
            (counter-clockwise, following PIL's Image.rotate)
        xc (float) : x coordinate of center of piece
        yc (float) : y coordinate of center of piece
        canvas_size (int) : side length of the square output canvas

        Returns a transparent canvas_size x canvas_size RGBA image that
        contains only the placed, rotated piece. The paste offset is
        rounded to the nearest pixel.
        '''
        canvas = Image.new("RGBA",
                           (canvas_size, canvas_size),
                           (255, 255, 255, 0))

        # PIL's paste requires integer offsets; round so float positions
        # (as documented above) work as well as ints
        corner = (int(round(x - xc)), int(round(y - yc)))
        canvas.paste(piece, corner, piece)

        return canvas.rotate(theta, resample=Image.BILINEAR, center=(x, y))

    def make_frame(self, im, t, canvas_size=384, knot_seed=0):
        '''
        This function returns a PIL image of a frame animating a jigsaw
        permutation. Pieces move and rotate from the identity view
        (t = 0) to the rearranged view (t = 1) along quadratic Bezier
        curves.

        The approach is as follows:

        1. Extract all 16 pieces
        2. Figure out start locations for each of these pieces (t=0)
        3. Figure out how these pieces permute
        4. Using these permutations, figure out end locations (t=1)
        5. Make knots for the curves, randomly offset perpendicular to the
           segment between the start and end locations, around its midpoint
        6. Paste pieces into correct locations, determined by
           interpolating along the curves

        im (PIL.Image) :
            PIL image representing the jigsaw illusion. Assumed square:
            only its width is read.

        t (float) :
            Interpolation parameter in [0,1] indicating what frame of the
            animation to generate

        canvas_size (int) :
            Side length of the frame. The image is centered on the canvas,
            so this should presumably be at least the image size.

        knot_seed (int) :
            Seed for random offsets for the knots

        Returns an RGBA PIL image of size canvas_size x canvas_size with an
        opaque white background.
        '''
        im_size = im.size[0]
        # Side length in pixels of one cell of the 4x4 puzzle grid
        cell_size = im_size / 4
        # Offset that centers the image on the canvas
        pad = (canvas_size - im_size) // 2

        # 1. Extract pieces. Piece i is rotated by 90 * (i % 4) degrees
        #    counter-clockwise. The start angles below follow the same i % 4
        #    pattern with the opposite sign, so the net rotation at t=0 is 0.
        #    NOTE(review): presumably this gives every piece a common
        #    orientation so that (xc, yc) below is valid for all of them;
        #    confirm against get_jigsaw_pieces' piece ordering.
        pieces = [p.rotate(90 * (i % 4), resample=Image.BILINEAR, expand=1)
                  for i, p in enumerate(pieces_ for pieces_ in ())] \
            if False else [
                p.rotate(90 * (i % 4), resample=Image.BILINEAR, expand=1)
                for i, p in enumerate(self.extract_pieces(im))]

        # 2. Start locations, as (y, x) in cell units relative to the puzzle
        #    center. There is one base location per piece group, and each is
        #    rotated by the k-th quarter-turn matrix (angle -k * 90 degrees).
        #    Piece i therefore belongs to group i // 4 with rotation i % 4.
        base_start_locs = np.array([[-1.5, -1.5],   # corner
                                    [-0.5, -0.5],   # inner
                                    [-1.5, -0.5],   # edge "e"
                                    [-1.5, 0.5]])   # edge "f"
        rot_mats = np.stack([np.array([[np.cos(a), -np.sin(a)],
                                       [np.sin(a), np.cos(a)]])
                             for a in -np.arange(4) * 90 / 180 * np.pi])
        start_locs = einsum(base_start_locs, rot_mats,
                            'start i, rot j i -> start rot j')
        start_locs = rearrange(start_locs, 'start rot j -> (start rot) j')

        # Third column: start angle in degrees, -90 * (i % 4)
        thetas = np.tile(np.arange(4) * -90, 4)[:, None]
        start_locs = np.concatenate([start_locs, thetas], axis=1)

        # 3. Piece permutation. piece_perms[i] is offset by 4 * (i // 4)
        #    (presumably piece_perms holds within-group indices 0..3). Each
        #    flagged edge swap then relabels values 8 + k <-> 12 + k.
        perm = self.piece_perms + np.repeat(np.arange(4), 4) * 4
        for edge_idx, to_swap in enumerate(self.edge_swaps):
            if not to_swap:
                continue
            swap_perm = np.arange(16)
            swap_perm[8 + edge_idx], swap_perm[12 + edge_idx] = \
                swap_perm[12 + edge_idx], swap_perm[8 + edge_idx]
            perm = np.array([swap_perm[perm[i]] for i in range(16)])

        # 4. End locations (get_inv_perm assumed to invert the permutation)
        perm_inv = get_inv_perm(torch.tensor(perm))
        end_locs = start_locs[perm_inv]

        # Convert (y, x) from cell units to canvas pixels. Shift so the
        # puzzle's top-left corner is 0, scale by the cell size, then offset
        # so the image is centered on the canvas.
        for locs in (start_locs, end_locs):
            locs[:, :2] = (locs[:, :2] + 2) * cell_size + pad

        # 5. Knots. A private RandomState seeded with knot_seed draws the
        #    same numbers as seeding NumPy's global RNG would, but leaves the
        #    global random state untouched.
        rng = np.random.RandomState(knot_seed)
        rand_offsets = (rng.rand(16, 1) * 2 - 1) * 2     # uniform in [-2, 2)
        eps = rng.randn(16, 2)

        avg_locs = (start_locs[:, :2] + end_locs[:, :2]) / 2.
        # Unit direction of travel. The jitter eps also keeps the norm
        # nonzero for pieces whose start and end locations coincide.
        direction = end_locs[:, :2] - start_locs[:, :2] + eps
        direction = direction / np.linalg.norm(direction, axis=1,
                                               keepdims=True)
        # Row vector @ [[0, 1], [-1, 0]] maps (a, b) -> (-b, a), which is
        # perpendicular to the direction of travel
        perpendicular = direction @ np.array([[0, 1], [-1, 0]])
        knot_locs = avg_locs + perpendicular * (rand_offsets * cell_size)

        # 6. Paste every piece at its interpolated position and angle
        def lerp(a, b):
            return a * (1 - t) + b * t

        canvas = Image.new("RGBA", (canvas_size, canvas_size),
                           (255, 255, 255, 255))
        # Offset in pixels from a piece's top-left to its center
        xc = yc = im_size // 4 // 2
        for i in range(16):
            y_0, x_0, theta_0 = start_locs[i]
            y_1, x_1, theta_1 = end_locs[i]
            y_k, x_k = knot_locs[i]

            # Quadratic Bezier (de Casteljau) with the knot as control point
            x = int(np.round(lerp(lerp(x_0, x_k), lerp(x_k, x_1))))
            y = int(np.round(lerp(lerp(y_0, y_k), lerp(y_k, y_1))))

            # Angle is interpolated linearly
            theta = int(np.round(lerp(theta_0, theta_1)))

            # Forward canvas_size so the per-piece layer matches the frame;
            # otherwise pieces beyond the default 384 px are clipped
            pasted_piece = self.paste_piece(pieces[i], x, y, theta, xc, yc,
                                            canvas_size=canvas_size)
            canvas.paste(pasted_piece, (0, 0), pasted_piece)

        return canvas
|
|