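# Spaces: Sleeping / Sleeping
# NOTE(review): the line above looks like web-page scrape residue rather than
# program text — confirm and delete.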
import gradio as gr | |
import os | |
import random | |
import datetime | |
from utils import * | |
from pathlib import Path | |
import gdown | |
# ---------------------------------------------------------------------------
# Start-up work. Everything in this section runs once, top to bottom, when the
# module is loaded: fetch the supplementary archive, read the SVG assets,
# base64-encode the gallery GIFs and make sure the video cache is populated.
# ---------------------------------------------------------------------------

# False -> download a ready-made video archive; True -> call
# pregenerate_videos() to build the cache locally instead.
pre_generate = False

file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
filename = "derendering_supp.zip"

# Cache videos to speed up demo
video_cache_dir = Path("./cached_videos")
video_cache_dir.mkdir(exist_ok=True)

# The archive must be unpacked before any "derendering_supp/..." path is read.
download_file(file_url, filename)
unzip_file(filename)
print("Downloaded and unzipped the inks.")

# SVG markup that is embedded directly into the page.
diagram = get_svg_content("derendering_supp/derender_diagram.svg")
org = get_svg_content("org/cor.svg")
org_content = f"{org}"

# Word-level gallery: gif_filenames[i] is displayed under captions[i].
gif_filenames = [
    "christians.gif",
    "good.gif",
    "october.gif",
    "welcome.gif",
    "you.gif",
    "letter.gif",
]
# NOTE(review): "WELOME" may be a typo for "WELCOME" — confirm against
# gifs/welcome.gif before changing it; it is shown to users as-is.
captions = [
    "CHRISTIANS",
    "Good",
    "October",
    "WELOME",
    "you",
    "letter",
]
# caption -> encoded GIF, in the order listed above.
gif_base64_strings = dict(
    zip(
        captions,
        [get_base64_encoded_gif(f"gifs/{name}") for name in gif_filenames],
    )
)

# Sketch gallery: file name -> encoded GIF.
sketches = [
    "bird.gif",
    "cat.gif",
    "coffee.gif",
    "penguin.gif",
]
sketches_base64_strings = dict(
    (name, get_base64_encoded_gif(f"sketches/{name}")) for name in sketches
)

# Populate the video cache, either by rendering locally or by downloading.
if pre_generate:
    pregenerate_videos(video_cache_dir=video_cache_dir)
else:
    print("Downloading pre-generated videos from google drive.")
    # Google Drive file id: 1oT6zw1EbWg3lavBMXsL28piULGNmqJzA
    _gdrive_zip = str(video_cache_dir / "gdrive_file.zip")
    gdown.download(
        "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
        _gdrive_zip,
        quiet=False,
    )
    unzip_file(_gdrive_zip)
print("Videos cached.")
# UI model name -> prefix of the directory holding that model's InkML output.
_MODEL_DIR_PREFIX = {
    "Small-i": "small-i",
    "Small-p": "small-p",
    "Large-i": "large-i",
}


def demo(Dataset, Model):
    """Pick one random sample image and gather its three inference outputs.

    Args:
        Dataset: Dataset name; used to build the
            ``./derendering_supp/{Dataset}/images_sample`` directory path.
        Model: Model variant; one of ``"Small-i"``, ``"Small-p"``,
            ``"Large-i"``.

    Returns:
        A 7-tuple ``(img, d+t text, d+t video, r+d text, r+d video,
        vanilla text, vanilla video)`` where ``img`` is whatever
        ``load_and_pad_img_dir`` returns for the picked sample, each text is
        the mode's title prefix followed by the InkML ``textField``
        annotation, and each video is a ``./``-prefixed path inside
        ``video_cache_dir``.

    Raises:
        ValueError: If ``Model`` is not a known variant.
        FileNotFoundError: If the sample directory is missing or empty.
    """
    # Fail early and clearly instead of hitting an unbound ``inkml_path``.
    try:
        model_prefix = _MODEL_DIR_PREFIX[Model]
    except KeyError:
        raise ValueError(
            f"Unknown model {Model!r}; expected one of {sorted(_MODEL_DIR_PREFIX)}"
        ) from None
    inkml_path = f"./derendering_supp/{model_prefix}_{Dataset}_inkml"

    now = datetime.datetime.now()
    # Re-seed from the current time on every call, as the original did.
    random.seed(now.timestamp())
    print(
        now.strftime("%Y-%m-%d %H:%M:%S"),
        "Taking sample from dataset:",
        Dataset,
        "and model:",
        Model,
    )

    path = f"./derendering_supp/{Dataset}/images_sample"
    samples = os.listdir(path)
    if not samples:
        # Without this guard the function would crash later on unbound names.
        raise FileNotFoundError(f"No sample images found in {path}")

    # Randomly pick a sample
    name = random.choice(samples)
    # Drop only the file extension. ``str.strip(".png")`` removes any of the
    # characters ".", "p", "n", "g" from BOTH ends and so mangles some ids.
    example_id = os.path.splitext(name)[0]
    img = load_and_pad_img_dir(os.path.join(path, name))

    query_modes = ("d+t", "r+d", "vanilla")
    plot_title = {"r+d": "Recognized: ", "d+t": "OCR Input: ", "vanilla": ""}

    outputs = [img]
    for mode in query_modes:
        inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
        text_field = parse_inkml_annotations(inkml_file)["textField"]

        video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
        video_filepath = video_cache_dir / video_filename
        if not video_filepath.exists():
            # Only parse the ink when a video actually has to be rendered.
            ink = inkml_to_ink(inkml_file)
            plot_ink_to_video(ink, str(video_filepath), input_image=img)
            print("Cached video at:", video_filepath)

        outputs.append(f"{plot_title[mode]}{text_field}")
        outputs.append("./" + str(video_filepath))

    return tuple(outputs)
# Build the demo page and start serving it.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation was lost; in particular, confirm that the input image belongs in
# the same row as the two dropdowns.
with gr.Blocks() as app:
    # --- Header: banner, title, paper link, pipeline diagram ---------------
    gr.HTML(org_content)
    gr.Markdown(
        "# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write"
    )
    # NOTE(review): the leading "π" / "π²" glyphs in the strings below look
    # like mis-decoded emoji; they are kept exactly as found.
    gr.HTML(
        """
        <div style="display: flex; align-items: center; margin-bottom: 20px;">
        <a href="https://arxiv.org/pdf/2402.05804.pdf" target="_blank" style="font-size: 16px; background-color: #4CAF50; color: white; padding: 5px 7px; text-decoration: none; border-radius: 2px;">
        π Read the Paper
        </a>
        </div>
        """
    )
    gr.HTML(f"<div style='margin: 20px 0;'>{diagram}</div>")
    gr.Markdown(
        """
        π This demo highlights the capabilities of Small-i, Small-p, and Large-i across three public datasets (word-level, with 100 random samples each).<br>
        π² Select a model variant and dataset (IAM, IMGUR5K, HierText), then hit 'Sample' to view a randomly selected input alongside its corresponding outputs for all three types of inference.<br>
        """
    )

    # --- Interactive sampler ------------------------------------------------
    with gr.Row():
        dataset = gr.Dropdown(
            ["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM"
        )
        model = gr.Dropdown(
            ["Small-i", "Large-i", "Small-p"],
            label="InkSight Model Variant",
            value="Small-i",
        )
        im = gr.Image(label="Input Image")
    with gr.Row():
        d_t_text = gr.Textbox(
            label="OCR recognition input to the model", interactive=False
        )
        r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
        vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
    with gr.Row():
        d_t_vid = gr.Video(
            label="Derender with Text (Click to stop/play)", autoplay=True
        )
        r_d_vid = gr.Video(
            label="Recognize and Derender (Click to stop/play)", autoplay=True
        )
        vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
    with gr.Row():
        btn_sub = gr.Button("Sample")

    # Output order must match the tuple returned by demo().
    sample_outputs = [
        im,
        d_t_text,
        d_t_vid,
        r_d_text,
        r_d_vid,
        vanilla_text,
        vanilla_vid,
    ]
    btn_sub.click(fn=demo, inputs=[dataset, model], outputs=sample_outputs)

    # --- Static galleries ---------------------------------------------------
    # Both galleries share the same flex-container opening tag.
    gallery_open = """
        <div style="display: flex; justify-content: space-around; flex-wrap: wrap; gap: 0px;">
        """

    gr.Markdown("## More Word-level Samples")
    word_cards = [
        f"""
        <div>
        <img src="data:image/gif;base64,{encoded}" alt="{caption}" style="width: 100%; max-width: 200px;">
        <p style="text-align: center;">{caption}</p>
        </div>
        """
        for caption, encoded in gif_base64_strings.items()
    ]
    gr.HTML(gallery_open + "".join(word_cards) + "</div>")

    # Sketches
    gr.Markdown("## Sketch Samples")
    sketch_cards = [
        f"""
        <div>
        <img src="data:image/gif;base64,{encoded}" style="width: 100%; max-width: 200px;">
        </div>
        """
        for encoded in sketches_base64_strings.values()
    ]
    gr.HTML(gallery_open + "".join(sketch_cards) + "</div>")

    # --- Full-page examples -------------------------------------------------
    gr.Markdown("## Scale Up to Full Page")
    # (SVG file, caption HTML) pairs, in display order.
    full_page_samples = [
        (
            "full_page/danke.svg",
            'Writings on the beach. <a href="https://unsplash.com/photos/text-rG-PerMFjFA">Credit</a>',
        ),
        (
            "full_page/multilingual_demo.svg",
            "Multilingual handwriting.",
        ),
        (
            "full_page/unsplash_frame.svg",
            "Handwriting in a frame. <a href='https://unsplash.com/photos/white-wooden-framed-white-board-t7fLWMQl2Lw'>Credit</a>",
        ),
    ]
    # All SVGs are read first, then interleaved with their captions to fill
    # the six positional slots of the template.
    svg_contents = [get_svg_content(svg_path) for svg_path, _ in full_page_samples]
    template_args = []
    for svg_markup, (_, caption_html) in zip(svg_contents, full_page_samples):
        template_args.append(svg_markup)
        template_args.append(caption_html)

    svg_html_template = """
        <div style="display: block;">
        <div>
        <div style="margin-bottom: 10px;">{}</div>
        <p style="text-align: center;">{}</p>
        </div>
        <div>
        <div style="margin-bottom: 10px;">{}</div>
        <p style="text-align: center;">{}</p>
        </div>
        <div>
        <div style="margin-bottom: 10px;">{}</div>
        <p style="text-align: center;">{}</p>
        </div>
        </div>
        """
    full_svg_display = svg_html_template.format(*template_args)
    gr.HTML(full_svg_display)

app.launch()