Charlie Li committed
Commit • 17f8269
1 Parent(s): 396546f
update page
app.py
CHANGED
@@ -39,10 +39,7 @@ captions = [
     "you",
     "letter",
 ]
-gif_base64_strings = {
-    caption: get_base64_encoded_gif(f"gifs/{name}")
-    for caption, name in zip(captions, gif_filenames)
-}
+gif_base64_strings = {caption: get_base64_encoded_gif(f"gifs/{name}") for caption, name in zip(captions, gif_filenames)}
 
 sketches = [
     "bird.gif",
@@ -50,21 +47,23 @@ sketches = [
     "coffee.gif",
     "penguin.gif",
 ]
-sketches_base64_strings = {
-    name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches
-}
+sketches_base64_strings = {name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches}
 
 if not pre_generate:
-    …
+    # Check if the file already exists
+    if not (video_cache_dir / "gdrive_file.zip").exists():
+        print("Downloading pre-generated videos from Google Drive.")
+        # Download from Google Drive using gdown
+        gdown.download(
+            "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
+            str(video_cache_dir / "gdrive_file.zip"),
+            quiet=False,
+        )
 
-    …
+        # Unzip the file to video_cache_dir
+        unzip_file(str(video_cache_dir / "gdrive_file.zip"))
+    else:
+        print("File already exists. Skipping download.")
 else:
     pregenerate_videos(video_cache_dir=video_cache_dir)
     print("Videos cached.")
@@ -143,14 +142,21 @@ def demo(Dataset, Model):
 
 with gr.Blocks() as app:
     gr.HTML(org_content)
-    gr.Markdown(
-        "# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write"
-    )
+    gr.Markdown("# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write")
     gr.HTML(
         """
-        <div style="display: flex; …
-        <a href="https://arxiv.org/…
-        …
+        <div style="display: flex; gap: 10px; justify-content: left;">
+        <a href="https://arxiv.org/abs/2402.05804">
+        <img src="https://img.shields.io/badge/📖_Read_the_Paper-4CAF50?style=for-the-badge&logo=arxiv&logoColor=white" alt="Read the Paper">
+        </a>
+        <a href="https://github.com/google-research/inksight">
+        <img src="https://img.shields.io/badge/View_on_GitHub-181717?style=for-the-badge&logo=github&logoColor=white" alt="View on GitHub">
+        </a>
+        <a href="https://research.google/blog/a-return-to-hand-written-notes-by-learning-to-read-write/">
+        <img src="https://img.shields.io/badge/📝_Google_Research_Blog-333333?style=for-the-badge&logo=google&logoColor=white" alt="Google Research Blog">
+        </a>
+        <a href="https://charlieleee.github.io/publication/inksight/">
+        <img src="https://img.shields.io/badge/ℹ️_Info-FFA500?style=for-the-badge&logo=info&logoColor=white" alt="Info">
         </a>
         </div>
         """
@@ -163,9 +169,7 @@ with gr.Blocks() as app:
         """
     )
     with gr.Row():
-        dataset = gr.Dropdown(
-            ["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM"
-        )
+        dataset = gr.Dropdown(["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM")
         model = gr.Dropdown(
             ["Small-i", "Large-i", "Small-p"],
             label="InkSight Model Variant",
@@ -179,18 +183,12 @@ with gr.Blocks() as app:
     # vanilla_img = gr.Image(label="Vanilla")
 
     with gr.Row():
-        d_t_text = gr.Textbox(
-            label="OCR recognition input to the model", interactive=False
-        )
+        d_t_text = gr.Textbox(label="OCR recognition input to the model", interactive=False)
        r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
        vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
    with gr.Row():
-        d_t_vid = gr.Video(
-            label="Derender with Text (Click to stop/play)", autoplay=True
-        )
-        r_d_vid = gr.Video(
-            label="Recognize and Derender (Click to stop/play)", autoplay=True
-        )
+        d_t_vid = gr.Video(label="Derender with Text (Click to stop/play)", autoplay=True)
+        r_d_vid = gr.Video(label="Recognize and Derender (Click to stop/play)", autoplay=True)
        vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
 
    with gr.Row():
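The new download path calls `unzip_file`, a helper that lives in `utils.py` but is not touched by this commit, so its body does not appear in the diff. A minimal sketch of what such a helper could look like, assuming it simply extracts the archive next to itself with the standard-library `zipfile` module (the implementation below is an assumption; only the name and call site come from the diff):

```python
# Hypothetical sketch of the unzip_file helper used above -- not part of this commit.
# Assumes the archive is extracted into its own parent directory (video_cache_dir).
import zipfile
from pathlib import Path


def unzip_file(zip_path: str) -> None:
    """Extract the zip archive at zip_path into its parent directory."""
    archive = Path(zip_path)
    with zipfile.ZipFile(archive, "r") as zf:
        zf.extractall(archive.parent)
```

Under that assumption, `unzip_file(str(video_cache_dir / "gdrive_file.zip"))` leaves the pre-generated videos directly under `video_cache_dir`, which is what the cached-video branch expects.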
utils.py
CHANGED
@@ -32,6 +32,8 @@ def get_svg_content(svg_path):
 
 
 def download_file(url, filename):
+    if os.path.exists(filename):
+        return
     response = requests.get(url)
     with open(filename, "wb") as f:
         f.write(response.content)
@@ -84,22 +86,15 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
         base_color = base_colors(len(ink.strokes) - 1 - i)
         hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
 
-        darker_color = colorsys.hsv_to_rgb(
-            hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
-        )
-        colors = [
-            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
-            for j in range(len(x))
-        ]
+        darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
+        colors = [mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x))) for j in range(len(x))]
 
         points = np.array([x, y]).T.reshape(-1, 1, 2)
         segments = np.concatenate([points[:-1], points[1:]], axis=1)
 
         lc = LineCollection(segments, colors=colors, linewidth=lw)
         if with_path:
-            lc.set_path_effects(
-                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
-            )
+            lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
         ax.add_collection(lc)
 
     ax.set_xlim(0, 224)
@@ -107,9 +102,7 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
     ax.invert_yaxis()
 
 
-def plot_ink_to_video(
-    ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
-):
+def plot_ink_to_video(ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30):
     fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
 
     if input_image is not None:
@@ -143,26 +136,16 @@ def plot_ink_to_video(
 
         base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
         hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
-        darker_color = colorsys.hsv_to_rgb(
-            hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
-        )
-        visible_segments = (
-            segments[: frame - points_drawn]
-            if frame - points_drawn < len(segments)
-            else segments
-        )
+        darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
+        visible_segments = segments[: frame - points_drawn] if frame - points_drawn < len(segments) else segments
         colors = [
-            mcolors.to_rgba(
-                darker_color, alpha=1 - (0.5 * j / len(visible_segments))
-            )
+            mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(visible_segments)))
             for j in range(len(visible_segments))
         ]
 
         if len(visible_segments) > 0:
             lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
-            lc.set_path_effects(
-                [withStroke(linewidth=lw * 1.25, foreground=path_color)]
-            )
+            lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
             ax.add_collection(lc)
 
         points_drawn += len(segments)
@@ -254,13 +237,9 @@ def pregenerate_videos(video_cache_dir):
        if not os.path.exists(path):
            continue
        samples = os.listdir(path)
-        for name in tqdm(
-            samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"
-        ):
+        for name in tqdm(samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"):
            example_id = name.strip(".png")
-            inkml_file = os.path.join(
-                inkml_path_base, mode, f"{example_id}.inkml"
-            )
+            inkml_file = os.path.join(inkml_path_base, mode, f"{example_id}.inkml")
            if not os.path.exists(inkml_file):
                continue
            video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
|