Spaces:
Runtime error
Runtime error
File size: 4,547 Bytes
42bd6ca 107c9d9 1d12497 42bd6ca e97c45b ee9cecf e97c45b 53000ed e97c45b 48fd256 e97c45b 48fd256 53000ed e97c45b 53000ed e97c45b 42bd6ca 5d96138 42bd6ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
# import dependencies
import argparse
import functools
import html
import io
import os
import time
from glob import glob

# NOTE: `Image` is bound three times below (IPython.display, PIL, ipywidgets);
# the last one imported, ipywidgets.Image, is the binding left in scope, so
# keep this relative order.
from IPython.display import display, Javascript, Image
import numpy as np
import PIL
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from ipywidgets import Box, Image
import gradio as gr

from models.stmodel import STModel
from predictor import Predictor
@functools.lru_cache(maxsize=1)
def _load_predictor():
    """Build the style-transfer ``Predictor`` once and reuse it for all requests.

    The checkpoint ``./models/st_model_512_80k_12.pth`` is loaded onto the GPU
    when CUDA is available, otherwise onto the CPU. The number of styles is the
    number of ``*.jpg`` files found in ``./styles/`` on the first call. Because
    the result is cached, style files added later are not picked up until the
    process restarts.

    Returns:
        tuple: ``(predictor, n_styles)``.
    """
    img_size = 512
    load_model_path = "./models/st_model_512_80k_12.pth"
    styles_path = "./styles/"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One model slot per style image on disk; presumably this must match the
    # style count the checkpoint was trained with -- TODO confirm.
    n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
    st_model = STModel(n_styles)
    st_model.load_state_dict(torch.load(load_model_path, map_location=device))
    st_model = st_model.to(device)
    return Predictor(st_model, device, img_size), n_styles


def predict_gradio(image):
    """Apply every available style to ``image``.

    The model is built and its weights are loaded only on the first call
    (see ``_load_predictor``); later calls reuse the cached predictor instead
    of re-reading the checkpoint for every request.

    Args:
        image: The uploaded photo as passed in by the Gradio image input.

    Returns:
        list: One ``Predictor.eval_image`` result per style, ordered by style
        index ``0 .. n_styles - 1``.
    """
    predictor, n_styles = _load_predictor()
    return [predictor.eval_image(image, style) for style in range(n_styles)]
# Base URL of the style images shown in the app description.
_STYLE_IMAGE_BASE_URL = (
    "https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/"
)

# Style image names (without the ".jpg" suffix), in display order.
_STYLE_NAMES = (
    "a_muse_picasso",
    "britto",
    "cat",
    "cubist",
    "fractal",
    "horse",
    "monet",
    "sketch",
    "starry_night",
    "texture",
    "tsunami",
    "vibrant",
)

# CSS for the "lampions" gallery plus a second table showing the "vibrant"
# style image four times. The CSS comment reads "5px to the left and right of
# the image".
# NOTE(review): this section looks like leftover layout experimentation --
# confirm it is still wanted.
_LAMPION_HTML = """<style type="text/css">
#lampions
{
width: 100%;
height: auto;
}
.lampion
{
float:left;
margin:0 5px; /* 5px a droite et a gauche de l image */
padding:0;
width: 100%;
height: auto;
}
</style>
<table>
<tr>
<td>
<div id="lampions">
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
</div>
</td>
</tr>
</table>
"""


def _style_cell(name):
    """Return the ``<td>`` cell that displays the style image ``<name>.jpg``."""
    return (
        '<td><img src="' + _STYLE_IMAGE_BASE_URL + name + '.jpg" '
        'class="image_responsive" width=300px height=300px></td>'
    )


def gradio_pls():
    """Build the style-transfer Gradio interface and launch it.

    The interface takes one PIL image, runs ``predict_gradio`` on it and shows
    the per-style outputs in a carousel. The description lists the style images
    in ``_STYLE_NAMES``.

    Returns:
        Whatever ``iface.launch(...)`` returns.
    """
    style_row = "\n".join(_style_cell(name) for name in _STYLE_NAMES)
    description = (
        "\n"
        f"Upload a photo and click on submit to see the {len(_STYLE_NAMES)} "
        "styles applied to your photo. \n\n"
        "Keep in mind that for compatibility reasons your photo is cropped "
        "before the neural net applies the different styles.\n"
        "<center>\n"
        "<table cellpadding=0 cellspacing=0>\n"
        "<tr>\n"
        + style_row
        + "\n</tr>\n"
        "</table>\n"
        + _LAMPION_HTML
        # Close the <center> opened above so the markup is well-formed.
        + "</center>\n"
    )
    # NOTE(review): gr.inputs / gr.outputs.Carousel, layout=, theme="grass" and
    # launch(enable_queue=...) look like legacy Gradio APIs; confirm they exist
    # in the pinned gradio version (the Space reports a runtime error).
    iface = gr.Interface(
        predict_gradio,
        [
            gr.inputs.Image(type="pil", label="Image"),
        ],
        [
            gr.outputs.Carousel("image", label="Style"),
        ],
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never'
    )
    return iface.launch(inbrowser=True, enable_queue=True, height=800, width=800)
# Launch the demo only when run as a script, not when imported.
# (The trailing " |" in the scraped source was page residue and a syntax error.)
if __name__ == '__main__':
    gradio_pls()