callum-canavan's picture
Add helpers, change to hot dog example
954caab
import numpy as np
from PIL import Image
import torch
from einops import einsum, rearrange
from .permutations import make_jigsaw_perm, get_inv_perm
from .view_permute import PermuteView
from .jigsaw_helpers import get_jigsaw_pieces
class JigsawView(PermuteView):
    '''
    Implements a 4x4 jigsaw puzzle view...

    Besides the pixel permutations handed to `PermuteView`, the view
    keeps the piece-level permutation metadata returned by
    `make_jigsaw_perm` (`piece_perms`, `edge_swaps`) so that the
    rearrangement can be animated piece by piece (see `make_frame`).
    '''

    def __init__(self, seed=11):
        '''
        seed (int) :
            Seed forwarded to `make_jigsaw_perm` for both the 64 and
            the 256 resolution permutations
        '''
        # Get pixel permutations, corresponding to jigsaw permutations
        self.perm_64, _ = make_jigsaw_perm(64, seed=seed)
        self.perm_256, jigsaw_perm = make_jigsaw_perm(256, seed=seed)

        # Keep track of jigsaw permutation as well
        self.piece_perms, self.edge_swaps = jigsaw_perm

        # Init parent PermuteView, with above pixel perms
        super().__init__(self.perm_64, self.perm_256)

    def extract_pieces(self, im):
        '''
        Given an image, extract jigsaw puzzle pieces from it

        im (PIL.Image) :
            PIL Image of the jigsaw illusion

        Returns a list of PIL Images, one per mask returned by
        `get_jigsaw_pieces`, each cropped to the extent of its mask
        and carrying the mask (scaled by 255) as its last channel.
        '''
        im = np.array(im)
        size = im.shape[0]

        pieces = []
        for piece_mask in get_jigsaw_pieces(size):
            # Add mask as alpha mask to image
            im_piece = np.concatenate(
                [im, piece_mask[:, :, None] * 255], axis=2)

            # Columns / rows in which the mask is non-zero
            alpha = im_piece[:, :, -1]
            cols = np.nonzero(alpha.sum(0))[0]
            rows = np.nonzero(alpha.sum(1))[0]

            # Crop to the extents of the piece
            im_piece = im_piece[rows.min():rows.max() + 1,
                                cols.min():cols.max() + 1]
            pieces.append(Image.fromarray(im_piece))

        return pieces

    def paste_piece(self, piece, x, y, theta, xc, yc, canvas_size=384):
        '''
        Given a PIL Image of a piece, place it so that its center is at
        (x,y) and it is rotated about that center by theta degrees

        piece (PIL.Image) : image of the piece; also used as paste mask
        x (float) : x coordinate to place piece at
        y (float) : y coordinate to place piece at
        theta (float) : degrees to rotate piece about center
        xc (float) : x coordinate of center of piece
        yc (float) : y coordinate of center of piece
        canvas_size (int) : side length of the returned canvas

        Returns a fully transparent RGBA canvas of side `canvas_size`
        containing only the placed and rotated piece.
        '''
        # Make canvas
        canvas = Image.new("RGBA",
                           (canvas_size, canvas_size),
                           (255, 255, 255, 0))

        # Paste piece so center is at (x, y). The offset is rounded to
        # integers so float coordinates are accepted as documented.
        top_left = (int(round(x - xc)), int(round(y - yc)))
        canvas.paste(piece, top_left, piece)

        # Rotate about (x, y)
        canvas = canvas.rotate(theta, resample=Image.BILINEAR, center=(x, y))

        return canvas

    @staticmethod
    def _grid_start_locs():
        '''
        Return a (16, 3) array of start poses (loc_0, loc_1, theta) for
        the pieces, on a 4x4 grid centered on the origin.
        '''
        # (Hardcoded) start locations for each base piece
        base_start_locs = np.array([[-1.5, -1.5],    # corner
                                    [-0.5, -0.5],    # inner
                                    [-1.5, -0.5],    # edge e
                                    [-1.5,  0.5]])   # edge f

        # Construct all start locations by rotating around (0,0)
        # by 90 degrees, 4 times, and concatenating the results
        rot_mats = np.stack([
            np.array([[np.cos(theta), -np.sin(theta)],
                      [np.sin(theta), np.cos(theta)]])
            for theta in -np.arange(4) * 90 / 180 * np.pi
        ])
        start_locs = einsum(base_start_locs, rot_mats,
                            'start i, rot j i -> start rot j')
        start_locs = rearrange(start_locs,
                               'start rot j -> (start rot) j')

        # Add rotation information to start locations
        thetas = np.tile(np.arange(4) * -90, 4)[:, None]
        return np.concatenate([start_locs, thetas], axis=1)

    def _piece_permutation(self):
        '''
        Build the explicit 16-element piece permutation from the
        permutation metadata (`self.piece_perms`, `self.edge_swaps`).
        '''
        # Offset each group of 4 entries by 4 * group index
        # (presumably piece_perms holds within-group indices)
        perm = self.piece_perms + np.repeat(np.arange(4), 4) * 4

        for edge_idx, to_swap in enumerate(self.edge_swaps):
            if not to_swap:
                continue

            # Make swap permutation array, exchanging entries
            # 8 + edge_idx and 12 + edge_idx
            swap_perm = np.arange(16)
            swap_perm[8 + edge_idx] = 12 + edge_idx
            swap_perm[12 + edge_idx] = 8 + edge_idx

            # Apply swap permutation after perm
            perm = np.array([swap_perm[p] for p in perm])

        return perm

    def make_frame(self, im, t, canvas_size=384, knot_seed=0):
        '''
        This function returns a PIL image of a frame animating a jigsaw
        permutation. Pieces move and rotate from the identity view
        (t = 0) to the rearranged view (t = 1) along splines.

        The approach is as follows:

        1. Extract all 16 pieces
        2. Figure out start locations for each of these pieces (t=0)
        3. Figure out how these pieces permute
        4. Using these permutations, figure out end locations (t=1)
        5. Make knots for splines, randomly offset normally from the
            midpoint of the start and end locations
        6. Paste pieces into correct locations, determined by
            spline interpolation

        im (PIL.Image) :
            PIL image representing the jigsaw illusion

        t (float) :
            Interpolation parameter in [0,1] indicating what frame of the
            animation to generate

        canvas_size (int) :
            Side length of the frame

        knot_seed (int) :
            Seed for random offsets for the knots
        '''
        im_size = im.size[0]

        # Side length of one cell of the 4x4 grid, in pixels
        piece_size = im_size / 4

        # Extract 16 jigsaw pieces
        pieces = self.extract_pieces(im)

        # Rotate all pieces to "base" piece orientation
        pieces = [p.rotate(90 * (i % 4),
                           resample=Image.BILINEAR,
                           expand=1) for i, p in enumerate(pieces)]

        # Start poses (t=0) on the origin-centered 4x4 grid
        start_locs = self._grid_start_locs()

        # Get explicit permutation of pieces from permutation metadata
        perm = self._piece_permutation()

        # Get inverse perm (the actual permutation needed)...
        perm_inv = get_inv_perm(torch.tensor(perm))

        # ...and use it to get the final locations of pieces. Fancy
        # indexing copies, so edits to start_locs below do not alias.
        end_locs = start_locs[perm_inv]

        # Convert start and end locations to pixel coordinate system,
        # and add offset so pieces are centered on canvas
        offset = (canvas_size - im_size) // 2
        start_locs[:, :2] = (start_locs[:, :2] + 2) * piece_size + offset
        end_locs[:, :2] = (end_locs[:, :2] + 2) * piece_size + offset

        # Get random offsets from middle for spline knot (so path is
        # pretty). A private RandomState gives the same draws as seeding
        # the global generator, without touching global state.
        rng = np.random.RandomState(knot_seed)
        rand_offsets = (rng.rand(16, 1) * 2 - 1) * 2
        eps = rng.randn(16, 2)      # Add epsilon for divide by zero

        # Make spline knots by taking average of start and end,
        # and offsetting by some amount normal from the line
        avg_locs = (start_locs[:, :2] + end_locs[:, :2]) / 2.
        norm = end_locs[:, :2] - start_locs[:, :2] + eps
        norm = norm / np.linalg.norm(norm, axis=1, keepdims=True)
        norm = norm @ np.array([[0, 1], [-1, 0]])   # rotate by 90 degrees
        rand_offsets = rand_offsets * piece_size
        knot_locs = avg_locs + norm * rand_offsets

        def lerp(a, b):
            # Linear interpolation between a and b at parameter t
            return a * (1 - t) + b * t

        # Center of a piece, relative to its top-left corner
        xc = yc = im_size // 4 // 2

        # Paste pieces on to a canvas
        canvas = Image.new("RGBA",
                           (canvas_size, canvas_size),
                           (255, 255, 255, 255))
        for piece, start, end, knot in zip(pieces, start_locs,
                                           end_locs, knot_locs):
            # Get start and end coords
            y_0, x_0, theta_0 = start
            y_1, x_1, theta_1 = end
            y_k, x_k = knot

            # Take spline interpolation for x and y
            x = int(np.round(lerp(lerp(x_0, x_k), lerp(x_k, x_1))))
            y = int(np.round(lerp(lerp(y_0, y_k), lerp(y_k, y_1))))

            # Just take normal interpolation for theta
            theta = int(np.round(lerp(theta_0, theta_1)))

            # Get piece in location and rotation. The scratch canvas
            # must match the frame size, otherwise pieces get clipped.
            pasted_piece = self.paste_piece(piece, x, y, theta, xc, yc,
                                            canvas_size=canvas_size)
            canvas.paste(pasted_piece, (0, 0), pasted_piece)

        return canvas