Shingome commited on
Commit
286a8b1
1 Parent(s): 21823e0

Initial commit

Browse files
Files changed (3) hide show
  1. app.py +28 -0
  2. neural_style_transfer.py +107 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # Load compressed models from tensorflow_hub
4
+ os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
5
+
6
+ import matplotlib as mpl
7
+
8
+ mpl.rcParams['figure.figsize'] = (12, 12)
9
+ mpl.rcParams['axes.grid'] = False
10
+
11
+ import tensorflow_hub as hub
12
+
13
+ from neural_style_transfer import StyleStealer
14
+
15
+ import gradio as gr
16
+
17
+ if __name__ == "__main__":
18
+ hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')
19
+
20
+ worker = StyleStealer(hub_model)
21
+
22
+ demo = gr.Interface(
23
+ fn=worker.steal,
24
+ inputs=["image", "image"],
25
+ outputs=["image"]
26
+ )
27
+
28
+ demo.launch(share=True)
neural_style_transfer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import tensorflow as tf
4
+
5
+ os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED'
6
+
7
+ import matplotlib as mpl
8
+
9
+ mpl.rcParams['figure.figsize'] = (12, 12)
10
+ mpl.rcParams['axes.grid'] = False
11
+
12
+ import numpy as np
13
+ import PIL.Image
14
+
15
+ import tensorflow_hub as hub
16
+
17
+
18
+ def tensor_to_image(tensor):
19
+ tensor = tensor * 255
20
+ tensor = np.array(tensor, dtype=np.uint8)
21
+ if np.ndim(tensor) > 3:
22
+ assert tensor.shape[0] == 1
23
+ tensor = tensor[0]
24
+ return PIL.Image.fromarray(tensor)
25
+
26
+
27
+ def load_img(path_to_img):
28
+ max_dim = 512
29
+ img = tf.io.read_file(path_to_img)
30
+ img = tf.image.decode_image(img, channels=3)
31
+ img = tf.image.convert_image_dtype(img, tf.float32)
32
+
33
+ shape = tf.cast(tf.shape(img)[:-1], tf.float32)
34
+ long_dim = max(shape)
35
+ scale = max_dim / long_dim
36
+
37
+ new_shape = tf.cast(shape * scale, tf.int32)
38
+
39
+ img = tf.image.resize(img, new_shape)
40
+ img = img[tf.newaxis, :]
41
+ return img
42
+
43
+
44
+ def convert_img(pil_image):
45
+ max_dim = 512
46
+
47
+ # Конвертируем PIL изображение в тензор
48
+ img = tf.convert_to_tensor(np.array(pil_image))
49
+
50
+ # Убедимся, что изображение имеет 3 канала (RGB)
51
+ if img.shape[-1] != 3:
52
+ img = tf.stack([img, img, img], axis=-1)
53
+
54
+ img = tf.image.convert_image_dtype(img, tf.float32)
55
+
56
+ shape = tf.cast(tf.shape(img)[:-1], tf.float32)
57
+ long_dim = max(shape)
58
+ scale = max_dim / long_dim
59
+
60
+ new_shape = tf.cast(shape * scale, tf.int32)
61
+
62
+ img = tf.image.resize(img, new_shape)
63
+ img = img[tf.newaxis, :]
64
+ return img
65
+
66
+ def convert_img_break(pil_image):
67
+ max_dim = 512
68
+ # Преобразуем PIL изображение в тензор
69
+ img = tf.keras.preprocessing.image.img_to_array(pil_image)
70
+ img = tf.image.convert_image_dtype(img, tf.float32)
71
+
72
+ shape = tf.cast(tf.shape(img)[:-1], tf.float32)
73
+ long_dim = max(shape)
74
+ scale = max_dim / long_dim
75
+
76
+ new_shape = tf.cast(shape * scale, tf.int32)
77
+
78
+ img = tf.image.resize(img, new_shape)
79
+ img = img[tf.newaxis, :]
80
+ return img
81
+
82
+
83
+ class StyleStealer:
84
+ def __init__(self, model):
85
+ self.model = model
86
+
87
+ def steal_break(self, content_img, style_img):
88
+ return tensor_to_image(self.model(tf.constant(convert_img_break(content_img)),
89
+ tf.constant(convert_img_break(style_img)))[0])
90
+
91
+ def steal(self, content_img, style_img):
92
+ return tensor_to_image(self.model(tf.constant(convert_img(content_img)),
93
+ tf.constant(convert_img(style_img)))[0])
94
+
95
+
96
+ if __name__ == "__main__":
97
+ content_path = "images/input_4.jpg"
98
+ style_path = "images/style_6.png"
99
+
100
+ content_image = load_img(content_path)
101
+ style_image = load_img(style_path)
102
+
103
+ print("Loading model...")
104
+ hub_model = hub.load('https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2')
105
+ print("Model loaded")
106
+ stylized_image = hub_model(tf.constant(content_image), tf.constant(style_image))[0]
107
+ tensor_to_image(stylized_image).save("output.jpg")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==4.28.3
2
+ matplotlib==3.8.4
3
+ numpy==1.26.4
4
+ Pillow==10.3.0
5
+ tensorflow==2.16.1
6
+ tensorflow_hub==0.16.1