# import dependencies
# NOTE: the original import order is kept on purpose. Three of these imports
# bind the name `Image` (IPython.display, PIL, ipywidgets) and the last one
# wins, so reordering them would change what `Image` refers to.
import functools
from IPython.display import display, Javascript, Image
import numpy as np
import PIL
import io
import html
import time
import torch
import matplotlib.pyplot as plt
from PIL import Image
from models.stmodel import STModel
from predictor import Predictor
import argparse
from glob import glob
import os
from ipywidgets import Box, Image
import gradio as gr

# Edge length handed to Predictor as its image size.
_IMG_SIZE = 512
# Checkpoint holding the trained style-transfer weights.
_LOAD_MODEL_PATH = "./models/st_model_512_80k_12.pth"
# Directory whose *.jpg files determine how many styles the model has.
_STYLES_PATH = "./styles/"


@functools.lru_cache(maxsize=1)
def _load_predictor():
    """Build the predictor once and reuse it for every later request.

    The number of styles is taken from the count of ``*.jpg`` files in the
    styles directory. The model is created with that count, its weights are
    loaded from the checkpoint, and it is moved to CUDA when available
    (CPU otherwise).

    Returns:
        A ``(predictor, n_styles)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_styles = len(glob(os.path.join(_STYLES_PATH, '*.jpg')))

    st_model = STModel(n_styles)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    st_model.load_state_dict(torch.load(_LOAD_MODEL_PATH, map_location=device))
    st_model = st_model.to(device)

    return Predictor(st_model, device, _IMG_SIZE), n_styles


def predict_gradio(image):
    """Run the input image through every available style.

    Args:
        image: The image supplied by the Gradio input component (configured
            below with ``type="pil"``).

    Returns:
        A list with one ``predictor.eval_image`` result per style index,
        ordered from style 0 to ``n_styles - 1``.
    """
    # Loading the model is expensive; the cached helper does it only on the
    # first call instead of on every request.
    predictor, n_styles = _load_predictor()
    return [predictor.eval_image(image, s) for s in range(n_styles)]


def gradio_pls():
    """Create the Gradio interface and launch it.

    Returns:
        Whatever ``iface.launch`` returns.
    """
    # The description text is kept verbatim, including its trailing newline.
    description=""" Upload a photo and click on submit to see the 12 styles applied to your photo. \n Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
"""
    iface = gr.Interface(
        predict_gradio,
        [
            gr.inputs.Image(type="pil", label="Image"),
        ],
        [
            gr.outputs.Carousel("image", label="Style"),
        ],
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never'
    )

    return iface.launch(inbrowser=True, enable_queue=True, height=800, width=800)


if __name__ == '__main__':
    gradio_pls()