# Hugging Face Space: CycleGAN image style transfer demo.
# (Space status at time of paste: "Runtime error" — likely due to the
# mangled formatting and/or inconsistent checkpoint paths below.)
| from flask import Flask | |
| import gradio as gr | |
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| from config import MODEL_CONFIG | |
| from model import CycleGAN | |
| # Load the CycleGAN models | |
| model_paths = { | |
| "CycleGAN_Cezanne_Unet_300": "/checkpoints/checkpoints/cyclegan_cezanne_unet_300_epochs.ckpt", | |
| "CycleGAN_Monet_Unet_250": "/checkpoints/checkpoints/cyclegan_monet_unet_250_epochs.ckpt", | |
| "CycleGAN_Vangogh_Resnet_70": "/cyclegan_vangogh_resnet_70_epochs.ckpt", | |
| "CycleGAN_Vangogh_Unet_70":"/cyclegan_vangogh_unet_70_epochs.ckpt" | |
| } | |
| models = {name: CycleGAN.load_from_checkpoint(path, **MODEL_CONFIG) for name, path in model_paths.items()} | |
| # Define the image transformation | |
| transform = transforms.Compose([ | |
| transforms.Resize((256, 256)), | |
| transforms.ToTensor(), | |
| ]) | |
| # Define the image translation function | |
| def translate_image(input_image, style): | |
| model = models[style] | |
| image = transform(input_image).unsqueeze(0) | |
| with torch.no_grad(): | |
| translated_image = model(image) | |
| return transforms.ToPILImage()(translated_image.squeeze(0)) | |
| # Initialize the Gradio interface | |
| iface = gr.Interface( | |
| fn=translate_image, | |
| inputs=[ | |
| gr.Image(type="pil"), | |
| gr.Dropdown(choices=list(models.keys()), label="Select Style") | |
| ], | |
| outputs=gr.Image(type="pil"), | |
| title="CycleGAN Image Translation", | |
| description="Upload an image and select a style to translate it using CycleGAN." | |
| ) | |
| if __name__ == "__main__": | |
| iface.launch(debug=True) |