Charlie Li commited on
Commit
4697797
β€’
1 Parent(s): f44710a
Files changed (6) hide show
  1. .gitignore +7 -0
  2. README.md +2 -2
  3. app.py +101 -0
  4. org/cor.svg +264 -0
  5. requirements.txt +5 -0
  6. utils.py +235 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__
2
+ *.mp4
3
+ flagged/
4
+ derendering_supp/
5
+ *.zip
6
+ __MACOSX/
7
+ .DS_Store
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Model Output Playground
3
- emoji: 🐨
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
@@ -10,4 +10,4 @@ pinned: false
10
  license: apache-2.0
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Model Output Playground
3
+ emoji: πŸ›
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
 
10
  license: apache-2.0
11
  ---
12
 
13
+ Paper: https://arxiv.org/abs/2402.05804
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils import *
3
+
4
+ file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
5
+ filename = "derendering_supp.zip"
6
+
7
+ download_file(file_url, filename)
8
+ unzip_file(filename)
9
+ print("Downloaded and unzipped the file.")
10
+
11
+ diagram = get_svg_content("derendering_supp/derender_diagram.svg")
12
+ org = get_svg_content("org/cor.svg")
13
+
14
+ org_content = f"""
15
+ {org}
16
+ """
17
+
18
+
19
+ def demo(Dataset, Model):
20
+ if Model == "Small-i":
21
+ inkml_path = f"./derendering_supp/small-i_{Dataset}_inkml"
22
+ elif Model == "Small-p":
23
+ inkml_path = f"./derendering_supp/small-p_{Dataset}_inkml"
24
+ elif Model == "Large-i":
25
+ inkml_path = f"./derendering_supp/large-i_{Dataset}_inkml"
26
+
27
+ path = f"./derendering_supp/{Dataset}/images_sample"
28
+ samples = os.listdir(path)
29
+ # Randomly pick a sample
30
+ picked_samples = random.sample(samples, min(1, len(samples)))
31
+
32
+ query_modes = ["d+t", "r+d", "vanilla"]
33
+ plot_title = {"r+d": "Recognized: ", "d+t": "OCR Input: ", "vanilla": ""}
34
+ text_outputs = []
35
+
36
+ for name in picked_samples:
37
+ img_path = os.path.join(path, name)
38
+ img = load_and_pad_img_dir(img_path)
39
+
40
+ for mode in query_modes:
41
+ example_id = name.strip(".png")
42
+ inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
43
+ text_field = parse_inkml_annotations(inkml_file)["textField"]
44
+ output_text = f"{plot_title[mode]}{text_field}"
45
+ text_outputs.append(output_text) # Append text output for the current mode
46
+ ink = inkml_to_ink(inkml_file)
47
+ plot_ink_to_video(ink, mode + ".mp4", input_image=img)
48
+
49
+ return (
50
+ img,
51
+ text_outputs[0],
52
+ "d+t.mp4",
53
+ text_outputs[1],
54
+ "r+d.mp4",
55
+ text_outputs[2],
56
+ "vanilla.mp4",
57
+ )
58
+
59
+
60
+ with gr.Blocks() as app:
61
+ gr.HTML(org_content)
62
+ gr.Markdown(
63
+ f"""
64
+ # InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write<br>
65
+ <div>{diagram}</div>
66
+ πŸ”” This demo showcases the outputs of <b>Small-i</b>, <b>Small-p</b>, and <b>Large-i</b> on three public datasets (100 samples each).<br>
67
+ ℹ️ Choose a model variant and dataset, then click 'Sample' to see an input with its corresponding outputs for all three inference types..<br>
68
+ """
69
+ )
70
+ with gr.Row():
71
+ dataset = gr.Dropdown(
72
+ ["IMGUR5K", "IAM", "HierText"], label="Dataset", value="HierText"
73
+ )
74
+ model = gr.Dropdown(
75
+ ["Small-i", "Large-i", "Small-p"],
76
+ label="InkSight Model Variant",
77
+ value="Small-i",
78
+ )
79
+ im = gr.Image(label="Input Image")
80
+ with gr.Row():
81
+ d_t_text = gr.Textbox(
82
+ label="OCR recognition input to the model", interactive=False
83
+ )
84
+ r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
85
+ vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
86
+
87
+ with gr.Row():
88
+ d_t = gr.Video(label="Derender with Text", autoplay=True)
89
+ r_d = gr.Video(label="Recognize and Derender", autoplay=True)
90
+ vanilla = gr.Video(label="Vanilla", autoplay=True)
91
+
92
+ with gr.Row():
93
+ btn_sub = gr.Button("Sample")
94
+
95
+ btn_sub.click(
96
+ fn=demo,
97
+ inputs=[dataset, model],
98
+ outputs=[im, d_t_text, d_t, r_d_text, r_d, vanilla_text, vanilla],
99
+ )
100
+
101
+ app.launch()
org/cor.svg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tqdm
2
+ numpy
3
+ matplotlib
4
+ Pillow
5
+ numpy
utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import xml.etree.ElementTree as ET
6
+ from xml.dom import minidom
7
+ import os
8
+ from PIL import Image
9
+ import matplotlib.animation as animation
10
+ import copy
11
+ from PIL import ImageEnhance
12
+ import colorsys
13
+ import matplotlib.colors as mcolors
14
+ from matplotlib.collections import LineCollection
15
+ from matplotlib.patheffects import withStroke
16
+ import random
17
+ import warnings
18
+ from matplotlib.figure import Figure
19
+ from io import BytesIO
20
+ from matplotlib.animation import FuncAnimation, FFMpegWriter, PillowWriter
21
+ import requests
22
+ import zipfile
23
+
24
+
25
+ warnings.filterwarnings("ignore")
26
+
27
+
28
+ def get_svg_content(svg_path):
29
+ with open(svg_path, "r") as file:
30
+ return file.read()
31
+
32
+
33
+ def download_file(url, filename):
34
+ response = requests.get(url)
35
+ with open(filename, "wb") as f:
36
+ f.write(response.content)
37
+
38
+
39
+ def unzip_file(filename, extract_to="."):
40
+ with zipfile.ZipFile(filename, "r") as zip_ref:
41
+ zip_ref.extractall(extract_to)
42
+
43
+
44
+ def load_and_pad_img_dir(file_dir):
45
+ image_path = os.path.join(file_dir)
46
+ image = Image.open(image_path)
47
+ width, height = image.size
48
+ ratio = min(224 / width, 224 / height)
49
+ image = image.resize((int(width * ratio), int(height * ratio)))
50
+ width, height = image.size
51
+ if height < 224:
52
+ # If width is shorter than height pad top and bottom.
53
+ top_padding = (224 - height) // 2
54
+ bottom_padding = 224 - height - top_padding
55
+ padded_image = Image.new("RGB", (width, 224), (255, 255, 255))
56
+ padded_image.paste(image, (0, top_padding))
57
+ else:
58
+ # Otherwise pad left and right.
59
+ left_padding = (224 - width) // 2
60
+ right_padding = 224 - width - left_padding
61
+ padded_image = Image.new("RGB", (224, height), (255, 255, 255))
62
+ padded_image.paste(image, (left_padding, 0))
63
+ return padded_image
64
+
65
+
66
+ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="white"):
67
+ if input_image is not None:
68
+ img = copy.deepcopy(input_image)
69
+ enhancer = ImageEnhance.Brightness(img)
70
+ img = enhancer.enhance(0.45)
71
+ ax.imshow(img)
72
+
73
+ base_colors = plt.cm.get_cmap("rainbow", len(ink.strokes))
74
+
75
+ for i, stroke in enumerate(ink.strokes):
76
+ x, y = np.array(stroke.x), np.array(stroke.y)
77
+
78
+ base_color = base_colors(len(ink.strokes) - 1 - i)
79
+ hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
80
+
81
+ darker_color = colorsys.hsv_to_rgb(
82
+ hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
83
+ )
84
+ colors = [
85
+ mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
86
+ for j in range(len(x))
87
+ ]
88
+
89
+ points = np.array([x, y]).T.reshape(-1, 1, 2)
90
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
91
+
92
+ lc = LineCollection(segments, colors=colors, linewidth=lw)
93
+ if with_path:
94
+ lc.set_path_effects(
95
+ [withStroke(linewidth=lw * 1.25, foreground=path_color)]
96
+ )
97
+ ax.add_collection(lc)
98
+
99
+ ax.set_xlim(0, 224)
100
+ ax.set_ylim(0, 224)
101
+ ax.invert_yaxis()
102
+
103
+
104
+ def plot_ink_to_video(
105
+ ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
106
+ ):
107
+ fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
108
+
109
+ if input_image is not None:
110
+ img = copy.deepcopy(input_image)
111
+ enhancer = ImageEnhance.Brightness(img)
112
+ img = enhancer.enhance(0.45)
113
+ ax.imshow(img)
114
+
115
+ ax.set_xlim(0, 224)
116
+ ax.set_ylim(0, 224)
117
+ ax.invert_yaxis()
118
+ ax.axis("off")
119
+
120
+ base_colors = plt.cm.get_cmap("rainbow", len(ink.strokes))
121
+ all_points = sum([len(stroke.x) for stroke in ink.strokes], 0)
122
+
123
+ def update(frame):
124
+ ax.clear()
125
+ if input_image is not None:
126
+ ax.imshow(img)
127
+ ax.set_xlim(0, 224)
128
+ ax.set_ylim(0, 224)
129
+ ax.invert_yaxis()
130
+ ax.axis("off")
131
+
132
+ points_drawn = 0
133
+ for stroke_index, stroke in enumerate(ink.strokes):
134
+ x, y = np.array(stroke.x), np.array(stroke.y)
135
+ points = np.array([x, y]).T.reshape(-1, 1, 2)
136
+ segments = np.concatenate([points[:-1], points[1:]], axis=1)
137
+
138
+ base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
139
+ hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
140
+ darker_color = colorsys.hsv_to_rgb(
141
+ hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
142
+ )
143
+ visible_segments = (
144
+ segments[: frame - points_drawn]
145
+ if frame - points_drawn < len(segments)
146
+ else segments
147
+ )
148
+ colors = [
149
+ mcolors.to_rgba(
150
+ darker_color, alpha=1 - (0.5 * j / len(visible_segments))
151
+ )
152
+ for j in range(len(visible_segments))
153
+ ]
154
+
155
+ if len(visible_segments) > 0:
156
+ lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
157
+ lc.set_path_effects(
158
+ [withStroke(linewidth=lw * 1.25, foreground=path_color)]
159
+ )
160
+ ax.add_collection(lc)
161
+
162
+ points_drawn += len(segments)
163
+ if points_drawn >= frame:
164
+ break
165
+
166
+ ani = FuncAnimation(fig, update, frames=all_points + 1, blit=False)
167
+ Writer = FFMpegWriter(fps=fps)
168
+ ani.save(output_name, writer=Writer)
169
+ plt.close(fig)
170
+
171
+
172
+ class Stroke:
173
+ def __init__(self, list_of_coordinates=None) -> None:
174
+ self.x = []
175
+ self.y = []
176
+ if list_of_coordinates:
177
+ for point in list_of_coordinates:
178
+ self.x.append(point[0])
179
+ self.y.append(point[1])
180
+
181
+ def __len__(self):
182
+ return len(self.x)
183
+
184
+ def __getitem__(self, index):
185
+ return (self.x[index], self.y[index])
186
+
187
+
188
+ class Ink:
189
+ def __init__(self, list_of_strokes=None) -> None:
190
+ self.strokes = []
191
+ if list_of_strokes:
192
+ self.strokes = list_of_strokes
193
+
194
+ def __len__(self):
195
+ return len(self.strokes)
196
+
197
+ def __getitem__(self, index):
198
+ return self.strokes[index]
199
+
200
+
201
+ def inkml_to_ink(inkml_file):
202
+ """Convert inkml file to Ink"""
203
+ tree = ET.parse(inkml_file)
204
+ root = tree.getroot()
205
+
206
+ inkml_namespace = {"inkml": "http://www.w3.org/2003/InkML"}
207
+
208
+ strokes = []
209
+
210
+ for trace in root.findall("inkml:trace", inkml_namespace):
211
+ points = trace.text.strip().split()
212
+ stroke_points = []
213
+
214
+ for point in points:
215
+ x, y = point.split(",")
216
+ stroke_points.append((float(x), float(y)))
217
+ strokes.append(Stroke(stroke_points))
218
+ return Ink(strokes)
219
+
220
+
221
+ def parse_inkml_annotations(inkml_file):
222
+ tree = ET.parse(inkml_file)
223
+ root = tree.getroot()
224
+
225
+ annotations = root.findall(".//{http://www.w3.org/2003/InkML}annotation")
226
+
227
+ annotation_dict = {}
228
+
229
+ for annotation in annotations:
230
+ annotation_type = annotation.get("type")
231
+ annotation_text = annotation.text
232
+
233
+ annotation_dict[annotation_type] = annotation_text
234
+
235
+ return annotation_dict