# batmangiaicuuthegioi's picture
# Upload 7 files
# dca2470 verified
from flask import Flask
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from config import MODEL_CONFIG
from model import CycleGAN
# Load the CycleGAN models.
# Maps each style name (shown in the UI dropdown) to the checkpoint file for that model.
# NOTE(review): the Van Gogh checkpoints sit at the filesystem root while the others live
# under /checkpoints/checkpoints/ — confirm these locations match the deployed files.
model_paths = {
    "CycleGAN_Cezanne_Unet_300": "/checkpoints/checkpoints/cyclegan_cezanne_unet_300_epochs.ckpt",
    "CycleGAN_Monet_Unet_250": "/checkpoints/checkpoints/cyclegan_monet_unet_250_epochs.ckpt",
    "CycleGAN_Vangogh_Resnet_70": "/cyclegan_vangogh_resnet_70_epochs.ckpt",
    "CycleGAN_Vangogh_Unet_70": "/cyclegan_vangogh_unet_70_epochs.ckpt",
}

# All models are loaded eagerly at import time. Each one is switched to eval mode
# because Lightning's load_from_checkpoint leaves the module in training mode, and
# inference must not run with training-time behavior (e.g. dropout) enabled.
# nn.Module.eval() returns the module itself, so it can be chained here.
models = {
    name: CycleGAN.load_from_checkpoint(path, **MODEL_CONFIG).eval()
    for name, path in model_paths.items()
}
# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 256x256, then convert the PIL image to a tensor.
# NOTE(review): there is no Normalize step — confirm the checkpoints were
# trained on inputs in the same value range that ToTensor produces.
_resize_step = transforms.Resize((256, 256))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])
# Define the image translation function
def translate_image(input_image, style):
    """Translate an uploaded image into the selected CycleGAN style.

    Args:
        input_image: PIL image from the Gradio upload component, or None
            when the user submitted without uploading an image.
        style: key of ``models`` chosen in the dropdown, or None when no
            style is selected.

    Returns:
        The model output converted back to a PIL image.

    Raises:
        gr.Error: if no image was provided or ``style`` is not a loaded
            model. Gradio shows this message to the user instead of a
            generic failure.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")
    model = models.get(style)
    if model is None:
        raise gr.Error(f"Unknown style: {style!r}. Please select a style.")

    # Preprocess and add a leading batch dimension of size 1 for the model.
    image = transform(input_image).unsqueeze(0)
    with torch.no_grad():
        translated_image = model(image)

    # ToPILImage multiplies float tensors by 255 and casts to uint8, so values
    # outside [0, 1] would wrap around into noise; clamp before converting.
    # NOTE(review): assumes the generator output is meant to lie in [0, 1]
    # (the input transform has no Normalize step) — confirm against training.
    output = translated_image.squeeze(0).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output)
# Initialize the Gradio interface: an image upload plus a style dropdown,
# wired to translate_image, returning the translated image.
iface = gr.Interface(
    fn=translate_image,
    inputs=[
        gr.Image(type="pil"),
        # Preselect the first style explicitly so a submit never sends an
        # empty selection, independent of the Gradio version's Dropdown default.
        gr.Dropdown(
            choices=list(models.keys()),
            value=next(iter(models), None),
            label="Select Style",
        ),
    ],
    outputs=gr.Image(type="pil"),
    title="CycleGAN Image Translation",
    description="Upload an image and select a style to translate it using CycleGAN.",
)
def main():
    """Start the Gradio app defined by ``iface`` with debug=True."""
    iface.launch(debug=True)


if __name__ == "__main__":
    main()