Spaces:
Runtime error
Runtime error
| # import dependencies | |
| from IPython.display import display, Javascript, Image | |
| from google.colab.output import eval_js | |
| from google.colab.patches import cv2_imshow | |
| from base64 import b64decode, b64encode | |
| import cv2 | |
| import numpy as np | |
| import PIL | |
| import io | |
| import html | |
| import time | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
| from models.stmodel import STModel | |
| from predictor import Predictor | |
| import argparse | |
| from glob import glob | |
| import os | |
| from ipywidgets import Box, Image | |
| import gradio as gr | |
# Process-level cache of loaded predictors, keyed by
# (img_size, load_model_path, styles_path). Loading the checkpoint is
# expensive, so it is done once per process instead of once per request.
_PREDICTOR_CACHE = {}


def _load_predictor(img_size=512,
                    load_model_path="./models/st_model_512_80k_12.pth",
                    styles_path="./styles/"):
    """Build (or fetch from the cache) the style-transfer predictor.

    Args:
        img_size: size passed to ``Predictor``.
        load_model_path: path of the saved ``STModel`` state dict.
        styles_path: folder whose ``*.jpg`` files determine the style count.

    Returns:
        A ``(predictor, n_styles)`` tuple.
    """
    key = (img_size, load_model_path, styles_path)
    cached = _PREDICTOR_CACHE.get(key)
    if cached is not None:
        return cached

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One style index per *.jpg file in the styles folder. NOTE(review):
    # presumably this count must match the checkpoint's style count, or
    # load_state_dict will reject it -- confirm.
    n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
    st_model = STModel(n_styles)
    st_model.load_state_dict(torch.load(load_model_path, map_location=device))
    st_model = st_model.to(device)
    predictor = Predictor(st_model, device, img_size)

    # Only cache after a successful load so a failure is retried next call.
    _PREDICTOR_CACHE[key] = (predictor, n_styles)
    return predictor, n_styles


def predict_gradio(image):
    """Apply every available style to ``image``.

    Args:
        image: the uploaded photo. The Gradio input is declared with
            ``type="pil"``, so presumably a PIL image -- confirm for other
            callers.

    Returns:
        A list with one entry per style index ``0 .. n_styles - 1``, in
        order. Each entry is whatever ``Predictor.eval_image`` returns.
    """
    predictor, n_styles = _load_predictor()
    return [predictor.eval_image(image, style) for style in range(n_styles)]
def gradio_pls():
    """Build the "Photo Style Transfer" Gradio interface and launch it inline.

    The interface takes a single uploaded image (delivered as PIL) and sends
    it to ``predict_gradio``. The results are shown in a carousel.

    Returns:
        Whatever ``Interface.launch`` returns.
    """
    # Markdown/HTML shown above the interface: usage notes plus a row of
    # style thumbnails hosted on GitHub.
    # NOTE(review): the "12 styles" text and the thumbnail list are hard-coded
    # and not tied to the *.jpg count that predict_gradio uses.
    description = """
Upload a photo and click on submit to see the 12 styles applied to your photo. \n
Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
<center>
<table><tr>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" width=100px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=100px></td>
</tr>
</table>
</center>
"""

    # Components are created in the same order as before: input, then output.
    photo_input = gr.inputs.Image(type="pil", label="Image")
    styles_output = gr.outputs.Carousel("image", label="Style")

    demo = gr.Interface(
        fn=predict_gradio,
        inputs=[photo_input],
        outputs=[styles_output],
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never',
    )
    return demo.launch(inline=True, height=800, width=800)