# Spaces: Runtime error  (page banner text captured along with the source; presumably not part of the program)
import gradio as gr | |
from PIL import Image | |
from io import BytesIO | |
import base64 | |
import requests | |
from io import BytesIO | |
from collections import Counter | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
def compute_fft_cross_correlation(img1, img2):
    """Circularly cross-correlate two 2-D arrays using the FFT.

    ``img2`` is rotated by 180 degrees and transformed on a grid of
    ``img1.shape`` (NumPy zero-pads or crops it to fit), so multiplying the
    two spectra and transforming back yields the circular cross-correlation.

    Args:
        img1: 2-D array; its shape fixes the shape of the result.
        img2: 2-D array to correlate against ``img1``.

    Returns:
        Real-valued array shaped like ``img1``. For identical inputs the
        zero-lag peak sits at the last index along each axis.
    """
    flipped = np.rot90(img2, 2)
    spectrum_product = np.fft.fft2(img1) * np.fft.fft2(flipped, s=img1.shape)
    return np.fft.ifft2(spectrum_product).real
def compute_offsets(reference, images, window_size):
    """Estimate each image's (row, column) offset relative to ``reference``.

    Every image is converted to grayscale and cross-correlated with the
    grayscale reference. The correlation peak is searched only inside a
    square window of side ``window_size`` centred on zero shift, which keeps
    distant spurious peaks from winning.

    Args:
        reference: Image object exposing ``convert('L')`` (e.g. a PIL image).
        images: Iterable of image objects exposing ``convert('L')``.
        window_size: Side length, in pixels, of the search window. Must be
            at least 1.

    Returns:
        A list with one ``(row_offset, col_offset)`` tuple of plain ints per
        image, each component limited to ``[-window_size, window_size]``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be at least 1")

    reference_gray = np.array(reference.convert('L'))
    offsets = []
    for img in images:
        img_gray = np.array(img.convert('L'))
        correlation = compute_fft_cross_correlation(reference_gray, img_gray)

        # Roll by half the size on both axes so small shifts (positive and
        # negative) end up next to each other around the array centre.
        height, width = correlation.shape
        center_row, center_col = height // 2, width // 2
        correlation = np.roll(correlation, (center_row, center_col), axis=(0, 1))

        # Search window around the centre, clamped to the array. Slice ends
        # are exclusive, so the upper clamp is the full size; clamping to
        # size - 1 would silently drop the last row and column.
        row_start = center_row - window_size // 2
        col_start = center_col - window_size // 2
        row_stop = min(row_start + window_size, height)
        col_stop = min(col_start + window_size, width)
        row_start = max(row_start, 0)
        col_start = max(col_start, 0)

        window = correlation[row_start:row_stop, col_start:col_stop]
        peak_row, peak_col = np.unravel_index(np.argmax(window), window.shape)

        # Convert window indices to signed shifts. The +1 is needed because
        # the zero-lag peak lands one index before the centre after rolling.
        row_offset = int(peak_row) + row_start - center_row + 1
        col_offset = int(peak_col) + col_start - center_col + 1

        # The correlation is circular: a shift past half the size is the
        # same as a smaller shift in the opposite direction.
        if row_offset > height // 2:
            row_offset -= height
        if col_offset > width // 2:
            col_offset -= width

        # Keep the result inside the advertised range.
        row_offset = max(-window_size, min(row_offset, window_size))
        col_offset = max(-window_size, min(col_offset, window_size))
        offsets.append((row_offset, col_offset))
    return offsets
def find_most_common_color(image):
    """Return the pixel value that occurs most often in ``image``.

    Args:
        image: Object whose ``getdata()`` yields hashable pixel values
            (e.g. a PIL image).

    Returns:
        The most frequent pixel value, in whatever form ``getdata()``
        produces it.

    Raises:
        IndexError: If ``getdata()`` yields no pixels.
    """
    # Counter consumes the pixel stream directly, so no intermediate list of
    # every pixel has to be built first.
    return Counter(image.getdata()).most_common(1)[0][0]
def slice_frames_final(original, centers, frame_width, frame_height, background_color=(255, 255, 0, 255)):
    """Cut fixed-size frames out of ``original`` around the given centers.

    Each frame is a new RGBA image filled with ``background_color``; the
    part of the requested rectangle that overlaps ``original`` is pasted
    into it, so regions falling outside the source keep the background.

    Args:
        original: Source image exposing ``size`` and ``crop``.
        centers: Iterable of ``(center_x, center_y)`` pixel coordinates.
        frame_width: Width of every output frame.
        frame_height: Height of every output frame.
        background_color: Fill color for areas not covered by the source.

    Returns:
        List of RGBA frames, one per center, in input order.
    """
    original_width, original_height = original.size
    sliced_frames = []
    for center_x, center_y in centers:
        left = center_x - frame_width // 2
        upper = center_y - frame_height // 2
        right = left + frame_width
        lower = upper + frame_height

        new_frame = Image.new("RGBA", (frame_width, frame_height), background_color)

        # Part of the requested rectangle that lies inside the source image.
        box = (
            max(0, left),
            max(0, upper),
            min(original_width, right),
            min(original_height, lower),
        )
        # A rectangle entirely outside the image gives an empty or inverted
        # box; skip the crop and leave the frame as plain background.
        if box[0] < box[2] and box[1] < box[3]:
            # When the rectangle starts left of / above the image, shift the
            # paste position by the amount that was clipped away.
            new_frame.paste(original.crop(box), (max(0, -left), max(0, -upper)))
        sliced_frames.append(new_frame)
    return sliced_frames
def create_aligned_gif(original_image, columns_per_row, window_size=200, duration=100, output_gif_path='output.gif'):
    """Turn a spritesheet into an animated GIF with frames aligned to the first one.

    The sheet is split into ``len(columns_per_row)`` rows of equal height;
    row ``i`` is split into ``columns_per_row[i]`` cells of equal width.
    Every cell is cropped to the narrowest cell width, its offset relative
    to the first frame is estimated with ``compute_offsets``, and the final
    frames are re-cut around the corrected centers before being saved.

    Args:
        original_image: Spritesheet image (exposes ``size``, ``crop``, ``getdata``).
        columns_per_row: Number of frames in each row, top to bottom.
        window_size: Search window passed to ``compute_offsets``.
        duration: Display time of each GIF frame, passed to ``save``.
        output_gif_path: Where the GIF is written.

    Returns:
        ``output_gif_path``.

    Raises:
        ValueError: If ``columns_per_row`` is empty or contains a count
            smaller than 1.
    """
    if not columns_per_row or any(cols < 1 for cols in columns_per_row):
        raise ValueError("columns_per_row must contain at least one row and only counts >= 1")

    original_width, original_height = original_image.size
    rows = len(columns_per_row)
    background_color = find_most_common_color(original_image)
    frame_height = original_height // rows
    min_frame_width = min(original_width // cols for cols in columns_per_row)

    # One pass over the grid: crop the provisional frame for each cell and
    # remember the cell's nominal center for the alignment step below.
    frames = []
    nominal_centers = []
    for row, cols in enumerate(columns_per_row):
        cell_width = original_width // cols
        upper = row * frame_height
        for col in range(cols):
            # Take the middle min_frame_width pixels of the cell so frames
            # from rows with different column counts share one size.
            left = col * cell_width + (cell_width - min_frame_width) // 2
            frames.append(original_image.crop((left, upper, left + min_frame_width, upper + frame_height)))
            nominal_centers.append((col * cell_width + cell_width // 2, upper + frame_height // 2))

    fft_offsets = compute_offsets(frames[0], frames, window_size=window_size)

    # compute_offsets reports (row, column), i.e. (y, x); subtract each
    # offset from the nominal center to get the corrected crop center.
    center_coordinates = [
        (center_x - offset_x, center_y - offset_y)
        for (center_x, center_y), (offset_y, offset_x) in zip(nominal_centers, fft_offsets)
    ]

    sliced_frames = slice_frames_final(original_image, center_coordinates, min_frame_width, frame_height, background_color=background_color)
    sliced_frames[0].save(output_gif_path, save_all=True, append_images=sliced_frames[1:], loop=0, duration=duration)
    return output_gif_path
def wrapper_func(img, columns_per_row_str, duration):
    """Gradio callback: build the aligned GIF and return its file path.

    Args:
        img: Uploaded spritesheet image, or ``None`` when nothing was uploaded.
        columns_per_row_str: Comma-separated frame counts per row, e.g.
            ``"3,4,3"``. Blank entries (such as a trailing comma) are ignored.
        duration: Display time of each frame, forwarded to ``create_aligned_gif``.

    Returns:
        Path of the generated GIF as returned by ``create_aligned_gif``.

    Raises:
        ValueError: If no image was provided, no column counts were given,
            or an entry is not an integer.
    """
    if img is None:
        raise ValueError("No spritesheet image was provided")

    # int() tolerates surrounding whitespace; only fully blank entries are dropped.
    columns_per_row = [int(token) for token in columns_per_row_str.split(',') if token.strip()]
    if not columns_per_row:
        raise ValueError("Columns per Row must contain at least one number")

    return create_aligned_gif(img, columns_per_row, duration=duration)
# Example spritesheet offered in the UI. It is downloaded once at startup;
# the download is best-effort so a network problem does not stop the app.
url = "https://raw.githubusercontent.com/nagolinc/notebooks/main/ss5.png"
try:
    # A timeout keeps startup from hanging forever on a stalled connection,
    # and raise_for_status() keeps an HTTP error page from being parsed as an image.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    example_image = Image.open(BytesIO(response.content))
except (requests.RequestException, OSError):
    # OSError also covers PIL failing to identify the downloaded bytes.
    response = None
    example_image = None

# "Columns per Row" value that goes with the example spritesheet.
example_columns_per_row = "5,5,5"

iface = gr.Interface(
    fn=wrapper_func,
    inputs=[
        gr.components.Image(label="Upload Spritesheet", type='pil'),
        gr.components.Textbox(label="Columns per Row", value="3,4,3"),
        gr.components.Slider(minimum=10, maximum=1000, step=10, value=100, label="Duration of each frame (ms)"),
    ],
    outputs=gr.components.Image(type="filepath", label="Generated GIF"),
    # Only advertise the example when its image could actually be loaded.
    examples=[[example_image, example_columns_per_row, 100]] if example_image is not None else None,
)
iface.launch()