gbach1lg commited on
Commit
42bd6ca
1 Parent(s): d338d04

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import dependencies
2
+ from IPython.display import display, Javascript, Image
3
+ from google.colab.output import eval_js
4
+ from google.colab.patches import cv2_imshow
5
+ from base64 import b64decode, b64encode
6
+ import cv2
7
+ import numpy as np
8
+ import PIL
9
+ import io
10
+ import html
11
+ import time
12
+ import torch
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ from PIL import Image
16
+ from models.stmodel import STModel
17
+ from predictor import Predictor
18
+ import argparse
19
+ from glob import glob
20
+ import os
21
+ from ipywidgets import Box, Image
22
+ import gradio as gr
23
+
24
+ def predict_gradio(image):
25
+ img_size = 512
26
+ load_model_path = "./models/st_model_512_80k_12.pth"
27
+ styles_path = "./styles/"
28
+
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ n_styles = len(glob(os.path.join(styles_path, '*.jpg')))
31
+ st_model = STModel(n_styles)
32
+ if True:
33
+ st_model.load_state_dict(torch.load(load_model_path, map_location=device))
34
+ st_model = st_model.to(device)
35
+
36
+ predictor = Predictor(st_model, device, img_size)
37
+
38
+ list_gen=[]
39
+ for s in range(n_styles):
40
+ gen = predictor.eval_image(image, s)
41
+ list_gen.append(gen)
42
+ return list_gen
43
+
44
+ def gradio_pls():
45
+ description="""
46
+ Upload a photo and click on submit to see the 12 styles applied to your photo. \n
47
+ Keep in mind that for compatibility reasons your photo is cropped before the neural net applied the different styles.
48
+ <center>
49
+ <table><tr>
50
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/a_muse_picasso.jpg" width=100px></td>
51
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/britto.jpg" width=100px></td>
52
+
53
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cat.jpg" width=100px></td>
54
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/cubist.jpg" width=100px></td>
55
+
56
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/fractal.jpg" width=100px></td>
57
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/horse.jpg" width=100px></td>
58
+
59
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/monet.jpg" width=100px></td>
60
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/sketch.jpg" width=100px></td>
61
+
62
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/starry_night.jpg" width=100px></td>
63
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/texture.jpg" width=100px></td>
64
+
65
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/tsunami.jpg" width=100px></td>
66
+ <td><img src="https://raw.githubusercontent.com/dabidou025/Live-Style-Transfer/main/styles/vibrant.jpg" width=100px></td>
67
+
68
+ </tr>
69
+ </table>
70
+ </center>
71
+ """
72
+ iface = gr.Interface(
73
+ predict_gradio,
74
+ [
75
+ gr.inputs.Image(type="pil", label="Image"),
76
+ ],
77
+ [
78
+ gr.outputs.Carousel("image", label="Style"),
79
+ ],
80
+ layout="unaligned",
81
+ title="Photo Style Transfer",
82
+ description=description,
83
+ theme="grass",
84
+ allow_flagging='never'
85
+ )
86
+
87
+ return iface.launch(inbrowser=True, height=800, width=800)
88
+
89
+
90
+ if __name__ == '__main__':
91
+ gradio_pls()