File size: 3,930 Bytes
42bd6ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73b09c7
107c9d9
73b09c7
 
 
 
 
 
 
 
 
 
 
 
42bd6ca
 
545f68c
 
 
 
 
 
 
 
 
 
 
 
42bd6ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d96138
42bd6ca
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# import dependencies
from IPython.display import display, Javascript, Image
import numpy as np
import PIL
import io
import functools
import html
import time
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from models.stmodel import STModel
from predictor import Predictor
import argparse
from glob import glob
import os
from ipywidgets import Box, Image
import gradio as gr

@functools.lru_cache(maxsize=1)
def _load_predictor(img_size, load_model_path, styles_path):
    """Build the style-transfer Predictor once and reuse it across requests.

    The number of styles is the count of ``*.jpg`` files in *styles_path*.
    ``STModel`` is constructed with that count before the checkpoint at
    *load_model_path* is loaded into it, so the directory contents presumably
    have to agree with the checkpoint (TODO confirm).

    Cached with ``lru_cache`` so the checkpoint is read from disk only on the
    first request instead of on every call. A failed load is not cached, so
    the next request retries.
    NOTE(review): this shares one Predictor between requests; that assumes
    ``Predictor.eval_image`` keeps no per-call state. Confirm.

    Returns:
        tuple: ``(predictor, n_styles)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
    st_model = STModel(n_styles)
    st_model.load_state_dict(torch.load(load_model_path, map_location=device))
    st_model = st_model.to(device)
    return Predictor(st_model, device, img_size), n_styles


def predict_gradio(image):
    """Apply every available style to *image*.

    Args:
        image: the uploaded photo. It is passed unchanged to
            ``Predictor.eval_image``; the Gradio input feeding it is declared
            with ``type="pil"``.

    Returns:
        list: one generated output per style, ordered by style index
        ``0 .. n_styles - 1``.
    """
    img_size = 512
    load_model_path = "./models/st_model_512_80k_12.pth"
    styles_path = "./styles/"

    predictor, n_styles = _load_predictor(img_size, load_model_path, styles_path)
    return [predictor.eval_image(image, s) for s in range(n_styles)]

def gradio_pls():
    """Build the Gradio style-transfer interface and launch it.

    Wires ``predict_gradio`` to a single PIL image input and a carousel
    output. The HTML ``description`` shows the style reference images.

    Returns:
        Whatever ``iface.launch`` returns.
    """
    # The description is rendered as HTML by Gradio. The <style> element must
    # be closed with a real "</style>" (not "<\style>", which is also an
    # invalid Python escape); otherwise everything after it is treated as CSS
    # text and never rendered.
    description="""
Upload a photo and click on submit to see the 12 styles applied to your photo. \n
Keep in mind that for compatibility reasons your photo is cropped before the neural net applies the different styles.
<center>
<table cellspacing=0>
<tr>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" class="image_responsive" width=300px></td>
<td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" class="image_responsive" width=300px></td>
</tr>
</table>
</center>

<style>
.conteneur{
width: 100%;
}
</style>

<div class="conteneur">
<img class="image_responsive" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=50%>
<img class="image_responsive" src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=50%>
</div>

"""
    iface = gr.Interface(
        predict_gradio,
        [
            gr.inputs.Image(type="pil", label="Image"),
        ],
        [
            gr.outputs.Carousel("image", label="Style"),
        ],
        layout="unaligned",
        title="Photo Style Transfer",
        description=description,
        theme="grass",
        allow_flagging='never'
    )

    return iface.launch(inbrowser=True, enable_queue=True, height=800, width=800)


if __name__ == "__main__":
    # Start the Gradio demo only when executed as a script, never on import.
    gradio_pls()