File size: 4,547 Bytes
42bd6ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107c9d9
 
1d12497
 
 
 
 
 
 
 
 
 
 
 
42bd6ca
 
e97c45b
 
ee9cecf
 
 
 
 
e97c45b
 
 
53000ed
e97c45b
48fd256
 
e97c45b
 
48fd256
53000ed
e97c45b
 
 
 
 
 
53000ed
 
e97c45b
42bd6ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d96138
42bd6ca
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# import dependencies
from IPython.display import display, Javascript, Image
import numpy as np
import PIL
import io
import html
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from models.stmodel import STModel
from predictor import Predictor
import argparse
from glob import glob
import os
from ipywidgets import Box, Image
import gradio as gr

# Loaded predictors keyed by (checkpoint path, styles dir, image size), so the
# model is built and its weights read from disk once per process rather than on
# every request.
_PREDICTOR_CACHE = {}


def _get_predictor(load_model_path, styles_path, img_size):
    """Return ``(predictor, n_styles)`` for a checkpoint, loading it on first use.

    Args:
        load_model_path: path of the ``STModel`` state dict to load.
        styles_path: directory whose ``*.jpg`` files determine the style count.
        img_size: image size handed to ``Predictor``.

    Returns:
        A ``(Predictor, int)`` tuple; the int is the number of style images found.

    Raises:
        FileNotFoundError: if ``styles_path`` contains no ``*.jpg`` style images.
    """
    key = (load_model_path, styles_path, img_size)
    cached = _PREDICTOR_CACHE.get(key)
    if cached is not None:
        return cached

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # STModel is sized by the number of style images on disk; this count is
    # assumed to match what the checkpoint was trained with — TODO confirm.
    n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
    if n_styles == 0:
        raise FileNotFoundError(f"no *.jpg style images found in {styles_path!r}")

    st_model = STModel(n_styles)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    st_model.load_state_dict(torch.load(load_model_path, map_location=device))
    st_model = st_model.to(device)

    cached = (Predictor(st_model, device, img_size), n_styles)
    _PREDICTOR_CACHE[key] = cached
    return cached


def predict_gradio(image):
    """Apply every available style to ``image``.

    Args:
        image: the input photo as delivered by Gradio (a PIL image when called
            through the interface built in ``gradio_pls``, whose input uses
            ``type="pil"``).

    Returns:
        A list with one entry per style index ``0 .. n_styles-1``, each being
        whatever ``Predictor.eval_image`` returns for that style.
    """
    img_size = 512
    load_model_path = "./models/st_model_512_80k_12.pth"
    styles_path = "./styles/"

    # NOTE(review): reusing one Predictor across requests assumes eval_image
    # keeps no per-call state — confirm against predictor.py.
    predictor, n_styles = _get_predictor(load_model_path, styles_path, img_size)
    return [predictor.eval_image(image, s) for s in range(n_styles)]

def gradio_pls():
    description="""
Upload a photo and click on submit to see the 12 styles applied to your photo. \n 
Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
<center>
<table cellpadding=0 cellspacing=0>
<tr>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" class="image_responsive" width=300px height=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" class="image_responsive" width=300px height=300px></td>
</tr>
</table>

<style type="text/css">
#lampions
	{
	   width: 100%;
	   height: auto;
	}
	.lampion
	{
	   float:left;
	   margin:0 5px; /* 5px a droite et a gauche de l image */
	   padding:0;
	   width: 100%;
	   height: auto;
	}
</style>
<table>
<tr>
<div id="lampions">
	<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
	<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
	<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
	<img class="lampion" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" alt="" />
</div>
</tr>
</table>
</p>
"""
    iface = gr.Interface(
        predict_gradio,   
        [
            gr.inputs.Image(type="pil", label="Image"),
        ],
        [
            gr.outputs.Carousel("image", label="Style"),
        ],
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never'
    )

    return iface.launch(inbrowser=True, enable_queue=True, height=800, width=800)


if __name__ == "__main__":
    # Start the web UI only when executed as a script, never on import.
    gradio_pls()