Vivien commited on
Commit
8ca63da
1 Parent(s): 7328a87

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1 @@
 
1
+ result.jpg
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Depth Aware Caption
3
- emoji: 🐢
4
- colorFrom: red
5
- colorTo: indigo
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
1
  ---
2
+ title: Depth-aware text addition
3
+ emoji: ✍️
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.2.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import PIL
3
+ import torch
4
+ import streamlit as st
5
+ import cv2
6
+
7
+ DEBUG = False
8
+ if DEBUG:
9
+ device = torch.device("cpu")
10
+ model_name = "MiDaS_small"
11
+ else:
12
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
13
+ model_name = "DPT_Large"
14
+
15
+ FONTS = [
16
+ "Font: Serif - EBGaramond",
17
+ "Font: Serif - Cinzel",
18
+ "Font: Sans - Roboto",
19
+ "Font: Sans - Lato",
20
+ "Font: Display - Lobster",
21
+ "Font: Display - LilitaOne",
22
+ "Font: Handwriting - GreatVibes",
23
+ "Font: Handwriting - Pacifico",
24
+ "Font: Mono - Inconsolata",
25
+ "Font: Mono - Cutive",
26
+ ]
27
+
28
+ CACHE_KWARGS = {
29
+ "show_spinner": False,
30
+ "hash_funcs": {torch.nn.parameter.Parameter: lambda _: None},
31
+ "allow_output_mutation": True,
32
+ "ttl": 900,
33
+ "max_entries": 20,
34
+ }
35
+
36
+
37
+ def hex_to_rgb(hex):
38
+ rgb = []
39
+ for i in (0, 2, 4):
40
+ decimal = int(hex[i : i + 2], 16)
41
+ rgb.append(decimal)
42
+ return tuple(rgb)
43
+
44
+
45
+ @st.cache(
46
+ show_spinner=True,
47
+ hash_funcs={torch.nn.parameter.Parameter: lambda _: None},
48
+ allow_output_mutation=True,
49
+ )
50
+ def load(model_type):
51
+ midas = torch.hub.load("intel-isl/MiDaS", model_type)
52
+ midas.to(device)
53
+ _ = midas.eval()
54
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
55
+ if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
56
+ transform = midas_transforms.dpt_transform
57
+ else:
58
+ transform = midas_transforms.small_transform
59
+ return midas, transform
60
+
61
+
62
+ midas, transform = load(model_name)
63
+
64
+
65
+ @st.cache(**CACHE_KWARGS)
66
+ def compute_depth(img):
67
+ with torch.no_grad():
68
+ prediction = midas(transform(img).to(device))
69
+
70
+ prediction = torch.nn.functional.interpolate(
71
+ prediction.unsqueeze(1),
72
+ size=img.shape[:2],
73
+ mode="bicubic",
74
+ align_corners=False,
75
+ ).squeeze()
76
+ return prediction.cpu().numpy()
77
+
78
+
79
+ @st.cache(**CACHE_KWARGS)
80
+ def get_mask1(shape, caption, font=None, font_size=0.08, color=(0, 0, 0), alpha=0.8):
81
+ img_text = PIL.Image.new("RGBA", (shape[1], shape[0]), (0, 0, 0, 0))
82
+ draw = PIL.ImageDraw.Draw(img_text)
83
+ font = PIL.ImageFont.truetype(font, int(font_size * img.shape[1]))
84
+ draw.text(
85
+ (x * img.shape[1], (1 - y) * img.shape[0]),
86
+ caption,
87
+ fill=(*color, int(max(min(1, alpha), 0) * 255)),
88
+ font=font,
89
+ )
90
+ text = np.array(img_text)
91
+ mask1 = np.dot(np.expand_dims(text[:, :, -1] / 255, -1), np.ones((1, 3)))
92
+ text = text[:, :, :-1]
93
+ return text, mask1
94
+
95
+
96
+ @st.cache(**CACHE_KWARGS)
97
+ def get_mask2(depth_map, depth):
98
+ m = np.expand_dims(
99
+ (depth_map[:, :] < depth * np.min(depth_map) + (1 - depth) * np.max(depth_map)),
100
+ -1,
101
+ )
102
+ return np.dot(m, np.ones((1, 3)))
103
+
104
+
105
+ @st.cache(**CACHE_KWARGS)
106
+ def add_caption(
107
+ img,
108
+ caption,
109
+ depth_map=None,
110
+ x=0.5,
111
+ y=0.5,
112
+ depth=0.5,
113
+ font_size=50,
114
+ color=(255, 255, 255),
115
+ font="",
116
+ alpha=1,
117
+ ):
118
+ if depth_map is None:
119
+ depth_map = compute_depth(img)
120
+ text, mask1 = get_mask1(
121
+ img.shape, caption, font=font, font_size=font_size, color=color, alpha=alpha
122
+ )
123
+ mask2 = get_mask2(depth_map, depth)
124
+ mask = mask1 * mask2
125
+
126
+ return ((1 - mask) * img + mask * text).astype(np.uint8)
127
+
128
+
129
+ st.markdown(
130
+ """
131
+ <style>
132
+ label{
133
+ height: 0px !important;
134
+ min-height: 0px !important;
135
+ margin-bottom: 0px !important;
136
+ }
137
+ </style>
138
+ """,
139
+ unsafe_allow_html=True,
140
+ )
141
+
142
+ st.sidebar.markdown(
143
+ """
144
+ # Depth-aware text addition
145
+
146
+ Add text ***inside*** an image!
147
+
148
+ Upload an image, enter some text and adjust the ***depth*** where you want the text to be displayed. You can also define its location and appearance (font, color, transparency and size).
149
+
150
+ Built with [PyTorch](https://pytorch.org/), Intel's [MiDaS model](https://pytorch.org/hub/intelisl_midas_v2/), [Streamlit](https://streamlit.io/), [pillow](https://python-pillow.org/) and inspired by the official [video](https://youtu.be/eTa1jHk1Lxc) of *Jenny of Oldstones* by Florence + the Machine
151
+ """
152
+ )
153
+
154
+ uploaded_file = st.file_uploader("", type=["jpg", "jpeg"])
155
+
156
+
157
+ @st.cache(**CACHE_KWARGS)
158
+ def load_img(uploaded_file):
159
+ if uploaded_file is None:
160
+ img = np.array(PIL.Image.open("pulp.jpg"))
161
+ default = True
162
+ else:
163
+ img = np.array(PIL.Image.open(uploaded_file))
164
+ if img.shape[0] > 800 or img.shape[1] > 800:
165
+ if img.shape[0] < img.shape[1]:
166
+ new_size = (800, int(800 * img.shape[0] / img.shape[1]))
167
+ else:
168
+ new_size = (int(800 * img.shape[1] / img.shape[0]), 800)
169
+ img = cv2.resize(img, dsize=new_size, interpolation=cv2.INTER_CUBIC)
170
+ default = False
171
+ depth_map = compute_depth(img)
172
+ return img, depth_map, default
173
+
174
+
175
+ img, depth_map, default = load_img(uploaded_file)
176
+
177
+ if default:
178
+ x0, y0, alpha0, font_size0, depth0, font0 = 0.02, 0.68, 0.99, 0.07, 0.23, 4
179
+ text0 = "Pulp Fiction"
180
+ else:
181
+ x0, y0, alpha0, font_size0, depth0, font0 = 0.1, 0.9, 0.8, 0.08, 0.5, 0
182
+ text0 = "Enter your text here"
183
+
184
+ colA, colB, colC = st.columns((13, 1, 1))
185
+
186
+ with colA:
187
+ text = st.text_input("", text0)
188
+
189
+ with colB:
190
+ st.markdown("Color:")
191
+
192
+ with colC:
193
+ color = st.color_picker("", value="#FFFFFF")
194
+
195
+
196
+ col1, _, col2 = st.columns((4, 1, 4))
197
+
198
+ with col1:
199
+ depth = st.select_slider(
200
+ "",
201
+ options=[i / 100 for i in range(101)],
202
+ value=depth0,
203
+ format_func=lambda x: "Foreground"
204
+ if x == 0.0
205
+ else "Background"
206
+ if x == 1.0
207
+ else "",
208
+ )
209
+ x = st.select_slider(
210
+ "",
211
+ options=[i / 100 for i in range(101)],
212
+ value=x0,
213
+ format_func=lambda x: "Left" if x == 0.0 else "Right" if x == 1.0 else "",
214
+ )
215
+ y = st.select_slider(
216
+ "",
217
+ options=[i / 100 for i in range(101)],
218
+ value=y0,
219
+ format_func=lambda x: "Bottom" if x == 0.0 else "Top" if x == 1.0 else "",
220
+ )
221
+
222
+ with col2:
223
+ font_size = st.select_slider(
224
+ "",
225
+ options=[0.04 + i / 100 for i in range(0, 17)],
226
+ value=font_size0,
227
+ format_func=lambda x: "Small font"
228
+ if x == 0.04
229
+ else "Large font"
230
+ if x == 0.2
231
+ else "",
232
+ )
233
+ alpha = st.select_slider(
234
+ "",
235
+ options=[i / 100 for i in range(101)],
236
+ value=alpha0,
237
+ format_func=lambda x: "Transparent"
238
+ if x == 0.0
239
+ else "Opaque"
240
+ if x == 1.0
241
+ else "",
242
+ )
243
+ font = st.selectbox("", FONTS, index=font0)
244
+
245
+ font = f"fonts/{font[6:]}.ttf"
246
+
247
+ captioned = add_caption(
248
+ img,
249
+ text,
250
+ depth_map=depth_map,
251
+ x=x,
252
+ y=y,
253
+ depth=depth,
254
+ font=font,
255
+ font_size=font_size,
256
+ alpha=alpha,
257
+ color=hex_to_rgb(color[1:]),
258
+ )
259
+
260
+ st.image(captioned)
261
+
262
+ PIL.Image.fromarray(captioned).save("result.jpg")
263
+ with open("result.jpg", "rb") as file:
264
+ btn = st.download_button(
265
+ label="Download image", data=file, file_name="result.jpg", mime="image/jpeg"
266
+ )
fonts/Display - LilitaOne.ttf ADDED
Binary file (26.8 kB). View file
fonts/Display - Lobster.ttf ADDED
Binary file (397 kB). View file
fonts/Handwriting - GreatVibes.ttf ADDED
Binary file (154 kB). View file
fonts/Handwriting - Pacifico.ttf ADDED
Binary file (315 kB). View file
fonts/Mono - Cutive.ttf ADDED
Binary file (77.4 kB). View file
fonts/Mono - Inconsolata.ttf ADDED
Binary file (339 kB). View file
fonts/Sans - Lato.ttf ADDED
Binary file (75.2 kB). View file
fonts/Sans - Roboto.ttf ADDED
Binary file (168 kB). View file
fonts/Serif - Cinzel.ttf ADDED
Binary file (125 kB). View file
fonts/Serif - EBGaramond.ttf ADDED
Binary file (929 kB). View file
packages.txt ADDED
@@ -0,0 +1 @@
 
1
+ libgl1
pulp.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ timm
4
+ pillow
5
+ opencv-python