"""Gradio interface that renders a morphing GIF between two uploaded images.

Both images are passed through ``model.encode``; the first value of each
result is linearly interpolated, every intermediate vector is passed through
``model.decode``, and the decoded frames are saved as a looping GIF that
plays forward and then backward.
"""

import os
import tempfile
import uuid

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageFilter
from torchvision import transforms

from model import model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this file never moves `model` to `device` nor calls
# `model.eval()`; presumably model.py does so — confirm.

# Preprocessing applied to every upload before encoding.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Upscaling applied to each decoded frame before it is written to the GIF.
resize_transform = transforms.Resize((512, 512))


def load_image(image):
    """Turn an uploaded image array into a single-item batch on ``device``.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (the UI below
            supplies NumPy arrays via ``type="numpy"``).

    Returns:
        The image converted to RGB, run through ``transform``, with a
        leading dimension added by ``unsqueeze(0)``, moved to ``device``.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        raise ValueError("No image was provided.")
    rgb = Image.fromarray(image).convert('RGB')
    return transform(rgb).unsqueeze(0).to(device)


def interpolate_vectors(v1, v2, num_steps):
    """Return ``num_steps`` linear blends from ``v1`` to ``v2``.

    The blend weights come from ``np.linspace(0, 1, num_steps)``, so both
    endpoints are included whenever ``num_steps >= 2``.

    Args:
        v1: Start value (must support ``*`` with a float and ``+``).
        v2: End value, compatible with ``v1``.
        num_steps: Number of blends; cast to ``int`` because UI sliders may
            deliver floats, which ``np.linspace`` rejects.

    Returns:
        A list of ``num_steps`` interpolated values.
    """
    weights = np.linspace(0.0, 1.0, int(num_steps)).tolist()
    return [v1 * (1.0 - alpha) + v2 * alpha for alpha in weights]


def infer_and_interpolate(image1, image2, num_interpolations=24):
    """Encode two images, interpolate between them, and decode every step.

    Args:
        image1: First image array (see ``load_image``).
        image2: Second image array (see ``load_image``).
        num_interpolations: Number of interpolation steps to decode.

    Returns:
        A list with one decoded output per step, each with ``squeeze(0)``
        applied to drop the leading dimension.
    """
    batch1 = load_image(image1)
    batch2 = load_image(image2)

    with torch.no_grad():
        # Only the first element returned by encode() is interpolated; the
        # second (logvar) is not used.
        mu1, _logvar1 = model.encode(batch1)
        mu2, _logvar2 = model.encode(batch2)
        latents = interpolate_vectors(mu1, mu2, num_interpolations)
        return [model.decode(z).squeeze(0) for z in latents]


def create_gif(decoded_images, duration=200, apply_blur=False):
    """Write the decoded frames to a looping forward-then-backward GIF.

    Each frame is min-max rescaled on its own to 0..255, converted to an RGB
    PIL image, resized with ``resize_transform`` and optionally blurred.

    Args:
        decoded_images: Non-empty sequence of tensors accepted by
            ``transforms.ToPILImage`` once converted to uint8.
        duration: Display time per frame in milliseconds.
        apply_blur: When true, apply a radius-1 Gaussian blur to each frame.

    Returns:
        Path of the GIF written to the system temporary directory.

    Raises:
        ValueError: If ``decoded_images`` is empty.
    """
    frames = list(decoded_images)
    if not frames:
        raise ValueError("decoded_images must contain at least one frame.")

    to_pil = transforms.ToPILImage()
    blur = ImageFilter.GaussianBlur(radius=1) if apply_blur else None

    pil_frames = []
    for img in frames:
        shifted = img - img.min()
        span = shifted.max()
        # A constant frame has zero span; skip the division so it renders as
        # black instead of producing NaNs.
        normalized = shifted / span if span > 0 else shifted
        as_bytes = (normalized * 255).byte()
        pil_img = resize_transform(to_pil(as_bytes.cpu()).convert("RGB"))
        if blur is not None:
            pil_img = pil_img.filter(blur)
        pil_frames.append(pil_img)

    # Mirror the already-converted frames instead of converting each twice.
    sequence = pil_frames + pil_frames[::-1]

    gif_filename = os.path.join(
        tempfile.gettempdir(), f"morphing_{uuid.uuid4().hex}.gif"
    )
    sequence[0].save(
        gif_filename,
        save_all=True,
        append_images=sequence[1:],
        duration=int(duration),
        loop=0,
    )
    return gif_filename


def create_morphing_gif(image1, image2, num_interpolations=24, duration=200):
    """Button callback: build a morphing GIF from two uploaded images.

    Args:
        image1: First uploaded image array, or ``None`` if missing.
        image2: Second uploaded image array, or ``None`` if missing.
        num_interpolations: Number of interpolation steps.
        duration: Display time per frame in milliseconds.

    Returns:
        Path of the generated GIF file.

    Raises:
        gr.Error: If either image has not been uploaded.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images before generating the GIF.")

    decoded_images = infer_and_interpolate(
        image1, image2, int(num_interpolations)
    )
    return create_gif(decoded_images, int(duration))


examples = [
    ["example_images/image1.jpg", "example_images/image2.png", 24, 200],
    ["example_images/image3.jpg", "example_images/image4.jpg", 30, 150],
]

# NOTE(review): the layout nesting below was reconstructed from a
# whitespace-collapsed source — confirm it matches the intended UI.
with gr.Blocks() as morphing:
    with gr.Column():
        with gr.Column():
            num_interpolations = gr.Slider(minimum=2, maximum=50, value=24, step=1, label="Number of interpolations")
            duration = gr.Slider(minimum=100, maximum=1000, value=200, step=50, label="Duration per frame (ms)")
            generate_button = gr.Button("Generate Morphing GIF")
            output_gif = gr.Image(label="Morphing GIF")
        with gr.Row():
            image1 = gr.Image(label="Upload first image", type="numpy")
            image2 = gr.Image(label="Upload second image", type="numpy")
    generate_button.click(fn=create_morphing_gif, inputs=[image1, image2, num_interpolations, duration], outputs=output_gif)
    gr.Examples(examples=examples, inputs=[image1, image2, num_interpolations, duration])