# nagolinc's picture
# Update app.py
# ff294d6 verified
import gradio as gr
from PIL import Image
from io import BytesIO
import base64
import requests
from io import BytesIO
from collections import Counter
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
def compute_fft_cross_correlation(img1, img2):
    """Cross-correlate two 2-D arrays by multiplying their Fourier spectra.

    ``img2`` is rotated by 180 degrees and transformed at the shape of
    ``img1`` (``np.fft.fft2`` zero-pads or crops its input to ``s``), so the
    product of the two spectra corresponds to a circular cross-correlation.

    Returns:
        The real part of the inverse transform, with the shape of ``img1``.
    """
    flipped = np.rot90(img2, 2)
    spectrum_a = np.fft.fft2(img1)
    spectrum_b = np.fft.fft2(flipped, s=img1.shape)
    return np.fft.ifft2(spectrum_a * spectrum_b).real
def compute_offsets(reference, images, window_size):
    """Estimate the (row, column) shift of each image relative to ``reference``.

    Every image is converted to greyscale and cross-correlated with the
    greyscale reference through ``compute_fft_cross_correlation``. The
    strongest response inside a window of side ``window_size`` around the
    centre of the (re-centred) correlation map is taken as the offset.

    Args:
        reference: Image exposing ``convert('L')`` (a PIL image in this app).
        images: Iterable of images to align against ``reference``.
        window_size: Side length, in pixels, of the search window.

    Returns:
        A list with one ``(row_offset, col_offset)`` tuple per image; each
        component is clamped to ``[-window_size, window_size]``.
    """
    reference_gray = np.array(reference.convert('L'))
    offsets = []
    for img in images:
        img_gray = np.array(img.convert('L'))
        correlation = compute_fft_cross_correlation(reference_gray, img_gray)
        height, width = correlation.shape

        # Roll by half the size on both axes so small shifts sit near the
        # middle of the map instead of wrapping around its corners.
        correlation = np.roll(correlation, (height // 2, width // 2), axis=(0, 1))

        center_row, center_col = height // 2, width // 2
        row_start = center_row - window_size // 2
        col_start = center_col - window_size // 2
        row_end = row_start + window_size
        col_end = col_start + window_size

        # Clip the window to the map. Slice ends are exclusive, so the upper
        # bound is the full extent; clamping to ``size - 1`` would silently
        # drop the last row/column and leave an empty window for 1-pixel maps.
        row_start = max(row_start, 0)
        col_start = max(col_start, 0)
        row_end = min(row_end, height)
        col_end = min(col_end, width)

        window = correlation[row_start:row_end, col_start:col_end]
        peak_row, peak_col = np.unravel_index(np.argmax(window), window.shape)

        # Convert window coordinates to a signed distance from the centre.
        # The +1 accounts for the 180-degree flip used by the correlation:
        # two identical images peak at index ``size - 1``, which the roll
        # above moves to ``centre - 1``.
        peak_row += row_start - center_row + 1
        peak_col += col_start - center_col + 1

        # Offsets past half the map size are wrap-arounds of negative shifts.
        if peak_row > height // 2:
            peak_row -= height
        if peak_col > width // 2:
            peak_col -= width

        # Keep the result inside [-window_size, window_size].
        peak_row = min(max(peak_row, -window_size), window_size)
        peak_col = min(max(peak_col, -window_size), window_size)

        offsets.append((peak_row, peak_col))
    return offsets
def find_most_common_color(image):
    """Return the pixel value that occurs most often in ``image``.

    ``image`` must provide ``getdata()`` yielding hashable pixel values.
    """
    tally = Counter(image.getdata())
    dominant_color, _occurrences = tally.most_common(1)[0]
    return dominant_color
def slice_frames_final(original, centers, frame_width, frame_height, background_color=(255, 255, 0, 255)):
    """Cut one fixed-size RGBA frame out of ``original`` around each centre.

    Args:
        original: Source image; must expose ``size``, ``crop`` and be pasteable.
        centers: Iterable of ``(center_x, center_y)`` pixel coordinates.
        frame_width: Width of every output frame.
        frame_height: Height of every output frame.
        background_color: Fill used where a frame extends past the source.

    Returns:
        A list of new ``frame_width`` x ``frame_height`` RGBA images, in the
        order of ``centers``.
    """
    sheet_width, sheet_height = original.size
    half_width = frame_width // 2
    half_height = frame_height // 2

    def cut(center):
        cx, cy = center
        left = cx - half_width
        top = cy - half_height
        # Only the part of the frame that overlaps the sheet can be cropped.
        visible_box = (
            max(0, left),
            max(0, top),
            min(sheet_width, left + frame_width),
            min(sheet_height, top + frame_height),
        )
        canvas = Image.new("RGBA", (frame_width, frame_height), background_color)
        # If the frame starts left of / above the sheet, shift the pasted
        # region by the overhang so the content keeps its position.
        canvas.paste(original.crop(visible_box), (max(0, -left), max(0, -top)))
        return canvas

    return [cut(center) for center in centers]
def create_aligned_gif(original_image, columns_per_row, window_size=200, duration=100, output_gif_path='output.gif'):
    """Slice a spritesheet into frames, align them, and save an animated GIF.

    The sheet is split into ``len(columns_per_row)`` rows of equal height;
    row ``i`` is split into ``columns_per_row[i]`` cells of equal width. A
    frame as wide as the narrowest cell is taken from the middle of every
    cell, all frames are registered against the first one with
    ``compute_offsets``, and the sheet is re-sliced around the corrected
    centres.

    Args:
        original_image: The spritesheet (a PIL image in this app).
        columns_per_row: Number of frames in each row, top to bottom.
        window_size: Search window passed on to ``compute_offsets``.
        duration: Per-frame display time passed to the GIF writer.
        output_gif_path: Where the GIF is written.

    Returns:
        ``output_gif_path``.

    Raises:
        ValueError: If ``columns_per_row`` is empty or has a non-positive entry.
    """
    rows = len(columns_per_row)
    # Validate up front: an empty list or a zero entry would otherwise surface
    # as a bare ZeroDivisionError from the geometry computations below.
    if rows == 0 or any(cols <= 0 for cols in columns_per_row):
        raise ValueError("columns_per_row must be a non-empty sequence of positive integers")

    sheet_width, sheet_height = original_image.size
    background_color = find_most_common_color(original_image)
    frame_height = sheet_height // rows
    min_frame_width = min(sheet_width // cols for cols in columns_per_row)

    # (left, top, width) of every grid cell, in row-major order.
    cells = []
    for row, cols in enumerate(columns_per_row):
        cell_width = sheet_width // cols
        for col in range(cols):
            cells.append((col * cell_width, row * frame_height, cell_width))

    # Provisional frames: the narrowest-cell-wide strip in the middle of each cell.
    frames = []
    for cell_left, cell_top, cell_width in cells:
        left = cell_left + (cell_width - min_frame_width) // 2
        frames.append(original_image.crop((left, cell_top, left + min_frame_width, cell_top + frame_height)))

    fft_offsets = compute_offsets(frames[0], frames, window_size=window_size)

    # compute_offsets yields (row, column) shifts, i.e. (y, x).
    center_coordinates = []
    for (cell_left, cell_top, cell_width), (offset_y, offset_x) in zip(cells, fft_offsets):
        center_x = cell_left + cell_width // 2 - offset_x
        center_y = cell_top + frame_height // 2 - offset_y
        center_coordinates.append((center_x, center_y))

    sliced_frames = slice_frames_final(original_image, center_coordinates, min_frame_width, frame_height, background_color=background_color)
    sliced_frames[0].save(output_gif_path, save_all=True, append_images=sliced_frames[1:], loop=0, duration=duration)
    return output_gif_path
def wrapper_func(img, columns_per_row_str, duration):
    """Gradio callback: build the aligned GIF for an uploaded spritesheet.

    Args:
        img: The uploaded spritesheet, or ``None`` when nothing was uploaded.
        columns_per_row_str: Comma-separated frame counts per row, e.g.
            ``"3,4,3"``. Blank entries (such as a trailing comma) are ignored.
        duration: Per-frame display time forwarded to ``create_aligned_gif``.

    Returns:
        The path of the generated GIF, as reported by ``create_aligned_gif``.

    Raises:
        ValueError: If no image was supplied, no column count was given, or
            an entry is not an integer.
    """
    if img is None:
        raise ValueError("No spritesheet image was provided")

    columns_per_row = [int(token) for token in columns_per_row_str.split(',') if token.strip()]
    if not columns_per_row:
        raise ValueError("Columns per Row must contain at least one number")

    # Return the path the callee actually wrote to instead of repeating its
    # default here, so the two cannot drift apart.
    return create_aligned_gif(img, columns_per_row, duration=duration)
# Example spritesheet offered in the UI; downloaded once at start-up and
# handed to Gradio as a PIL image.
url = "https://raw.githubusercontent.com/nagolinc/notebooks/main/ss5.png"
response = None
example_image = None
try:
    # The timeout keeps start-up from hanging if the host is unreachable, and
    # raise_for_status stops an HTTP error page from being parsed as an image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    example_image = Image.open(BytesIO(response.content))
except (requests.RequestException, OSError) as exc:
    # The example is optional: start the app without it rather than crash.
    print(f"Could not load example image from {url}: {exc}")

# Example for "Columns per Row" as a string
example_columns_per_row = "5,5,5"

iface = gr.Interface(
    fn=wrapper_func,
    inputs=[
        gr.components.Image(label="Upload Spritesheet", type='pil'),
        gr.components.Textbox(label="Columns per Row", value="3,4,3"),
        gr.components.Slider(minimum=10, maximum=1000, step=10, value=100, label="Duration of each frame (ms)"),
    ],
    outputs=gr.components.Image(type="filepath", label="Generated GIF"),
    # Only advertise the example when its image could be loaded.
    examples=[[example_image, example_columns_per_row, 100]] if example_image is not None else None,
)

# Launch only when executed as a script so the module can be imported
# (e.g. for testing the helpers) without starting a server.
if __name__ == "__main__":
    iface.launch()