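# NOTE: file-viewer page residue captured with the source; not part of the program.
# Author avatar: gbach1lg's picture | last commit: "Update app.py" (ee9cecf)
# Viewer links: raw, history, blame | file size: 4.55 kB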
# import dependencies
from IPython.display import display, Javascript, Image
import numpy as np
import PIL
import io
import html
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from models.stmodel import STModel
from predictor import Predictor
import argparse
from glob import glob
import os
from ipywidgets import Box, Image
import gradio as gr
# Lazily-built ``(predictor, n_styles)`` pair shared by every request.
# Stays ``None`` until the first call to ``_get_predictor``.
_PREDICTOR_STATE = None


def _get_predictor():
    """Build the predictor on first use and return the cached pair afterwards.

    Returns:
        A ``(predictor, n_styles)`` tuple, where ``predictor`` is the
        ``Predictor`` wrapping the loaded ``STModel`` and ``n_styles`` is the
        number of ``*.jpg`` files found in the styles directory.
    """
    global _PREDICTOR_STATE
    if _PREDICTOR_STATE is None:
        img_size = 512
        load_model_path = "./models/st_model_512_80k_12.pth"
        styles_path = "./styles/"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The count of style images is passed to STModel and is also the
        # number of style indices evaluated per request.
        n_styles = len(glob(os.path.join(styles_path, '*.jpg')))

        st_model = STModel(n_styles)
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        st_model.load_state_dict(torch.load(load_model_path, map_location=device))
        st_model = st_model.to(device)

        # Assign the finished pair in one step so a concurrent caller never
        # observes a half-initialised cache.
        _PREDICTOR_STATE = (Predictor(st_model, device, img_size), n_styles)
    return _PREDICTOR_STATE


def predict_gradio(image):
    """Apply every available style to ``image`` and return the results.

    The model is loaded from disk once (on the first call) and reused for
    subsequent calls instead of being rebuilt for every request.

    Args:
        image: The input image as supplied by the Gradio image component.

    Returns:
        A list with one entry per style index, each being the value returned
        by ``predictor.eval_image(image, style_index)``.
    """
    predictor, n_styles = _get_predictor()
    return [predictor.eval_image(image, s) for s in range(n_styles)]
def gradio_pls():
    """Assemble the Gradio interface for the style-transfer demo and launch it.

    The interface takes a single PIL image, runs it through ``predict_gradio``
    and shows the results in a carousel.

    Returns:
        The value returned by ``Interface.launch``.
    """
    # NOTE(review): the HTML below is kept exactly as written. It opens a
    # <center> that is never closed, ends with a </p> that has no opening tag,
    # and places a <div> directly inside a <tr>; confirm whether that is
    # intended before tidying it, since a fix may change the rendered layout.
    description = """
Upload a photo and click on submit to see the 12 styles applied to your photo. \n
Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
<center>
<table cellpadding=0 cellspacing=0>
<tr>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" class="image_responsive" width=300px height=300px></td>
</tr>
</table>
<style type="text/css">
#lampions
{
width: 100%;
height: auto;
}
.lampion
{
float:left;
margin:0 5px; /* 5px a droite et a gauche de l image */
padding:0;
width: 100%;
height: auto;
}
</style>
<table>
<tr>
<div id="lampions">
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
</div>
</tr>
</table>
</p>
"""

    # Components are created in the same order as the original call:
    # the input image first, then the output carousel.
    input_components = [gr.inputs.Image(type="pil", label="Image")]
    output_components = [gr.outputs.Carousel("image", label="Style")]

    demo = gr.Interface(
        predict_gradio,
        input_components,
        output_components,
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never',
    )

    launch_options = {
        "inbrowser": True,
        "enable_queue": True,
        "height": 800,
        "width": 800,
    }
    return demo.launch(**launch_options)
# Script entry point: build and launch the demo only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    gradio_pls()