|
import gradio as gr |
|
import os |
|
import random |
|
import datetime |
|
from utils import * |
|
from pathlib import Path |
|
import gdown |
|
|
|
# When True, videos are produced by `pregenerate_videos`; when False, a
# pre-generated archive is fetched with gdown instead.
pre_generate = False

# Archive that is downloaded and unpacked at start-up (see the log line below).
file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
filename = "derendering_supp.zip"

# Folder that holds the rendered sample videos.
video_cache_dir = Path("./cached_videos")
video_cache_dir.mkdir(exist_ok=True)

download_file(file_url, filename)
unzip_file(filename)
print("Downloaded and unzipped the inks.")

# SVG markup embedded into the page header.
diagram = get_svg_content("derendering_supp/derender_diagram.svg")
org = get_svg_content("org/cor.svg")
org_content = f"{org}"


def _encode_gifs(folder, keys, names):
    """Map each key to the base64 string of the GIF stored at ``folder/name``."""
    return {
        key: get_base64_encoded_gif(f"{folder}/{name}")
        for key, name in zip(keys, names)
    }


# Word-level GIFs and the caption shown under each one (paired by position).
gif_filenames = [
    "christians.gif",
    "good.gif",
    "october.gif",
    "welcome.gif",
    "you.gif",
    "letter.gif",
]
captions = [
    "CHRISTIANS",
    "Good",
    "October",
    "WELOME",  # NOTE(review): possibly a typo for "WELCOME" (paired with welcome.gif) - confirm
    "you",
    "letter",
]
gif_base64_strings = _encode_gifs("gifs", captions, gif_filenames)

# Sketch GIFs, keyed by their own file name.
sketches = [
    "bird.gif",
    "cat.gif",
    "coffee.gif",
    "penguin.gif",
]
sketches_base64_strings = _encode_gifs("sketches", sketches, sketches)

if pre_generate:
    pregenerate_videos(video_cache_dir=video_cache_dir)
    print("Videos cached.")
else:
    _gdrive_zip = video_cache_dir / "gdrive_file.zip"
    if _gdrive_zip.exists():
        # The archive's presence is used as the marker that it was fetched before.
        print("File already exists. Skipping download.")
    else:
        print("Downloading pre-generated videos from Google Drive.")
        gdown.download(
            "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
            str(_gdrive_zip),
            quiet=False,
        )
        unzip_file(str(_gdrive_zip))
|
|
|
|
|
def demo(Dataset, Model):
    """Pick one random sample image and return it with its three derenderings.

    For the chosen image, the InkML file of every inference mode
    ("d+t", "r+d", "vanilla") is looked up, its ``textField`` annotation is
    turned into a caption, and a video of the ink is rendered with
    ``plot_ink_to_video`` into ``video_cache_dir`` unless that file already
    exists.

    Args:
        Dataset: Name of the dataset folder under ``./derendering_supp``.
        Model: Model variant; one of "Small-i", "Small-p" or "Large-i".

    Returns:
        A 7-tuple ``(image, d+t text, d+t video, r+d text, r+d video,
        vanilla text, vanilla video)``. ``image`` is whatever
        ``load_and_pad_img_dir`` returns; each video is a path string that
        starts with ``"./"``.

    Raises:
        ValueError: If ``Model`` is not one of the supported variants.
        FileNotFoundError: If the sample folder is missing (raised by
            ``os.listdir``) or contains no files.
    """
    supported_models = ("Small-i", "Small-p", "Large-i")
    if Model not in supported_models:
        # Fail with a clear message instead of an UnboundLocalError later on.
        raise ValueError(
            f"Unsupported model variant {Model!r}; "
            f"expected one of: {', '.join(supported_models)}."
        )
    # Folder names use the lower-cased variant, e.g. "small-i_IAM_inkml".
    inkml_path = f"./derendering_supp/{Model.lower()}_{Dataset}_inkml"

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(
        timestamp,
        "Taking sample from dataset:",
        Dataset,
        "and model:",
        Model,
    )

    path = f"./derendering_supp/{Dataset}/images_sample"
    samples = os.listdir(path)
    if not samples:
        raise FileNotFoundError(f"No sample images found in {path}")

    # The module-level RNG is already seeded when `random` is imported, so it
    # is not reseeded from the wall clock on every request.
    name = random.choice(samples)

    # Drop only the file extension. `str.strip(".png")` would remove any
    # leading/trailing '.', 'p', 'n' or 'g' characters and corrupt some ids.
    example_id = os.path.splitext(name)[0]
    img = load_and_pad_img_dir(os.path.join(path, name))

    # The order here fixes the order of the returned text/video pairs.
    query_modes = ["d+t", "r+d", "vanilla"]
    plot_title = {"r+d": "Recognized: ", "d+t": "OCR Input: ", "vanilla": ""}

    text_outputs = []
    video_outputs = []
    for mode in query_modes:
        inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
        text_field = parse_inkml_annotations(inkml_file)["textField"]
        text_outputs.append(f"{plot_title[mode]}{text_field}")

        video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
        video_filepath = video_cache_dir / video_filename
        if not video_filepath.exists():
            # Only parse the ink when a video actually has to be rendered.
            ink = inkml_to_ink(inkml_file)
            plot_ink_to_video(ink, str(video_filepath), input_image=img)
            print("Cached video at:", video_filepath)
        video_outputs.append("./" + str(video_filepath))

    return (
        img,
        text_outputs[0],
        video_outputs[0],
        text_outputs[1],
        video_outputs[1],
        text_outputs[2],
        video_outputs[2],
    )
|
|
|
|
|
# Page layout: header and badges, the interactive sampler, then static galleries.
with gr.Blocks() as app:
    gr.HTML(org_content)
    gr.Markdown("# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write")

    # Badge links: blog post, paper, project page, dataset and code.
    # NOTE(review): the "π" in the project-page badge label looks like a
    # mis-encoded emoji - confirm the intended character.
    gr.HTML(
        """
        <div style="display: flex; gap: 10px; justify-content: left;">
        <a href="https://research.google/blog/a-return-to-hand-written-notes-by-learning-to-read-write/">
        <img src="https://img.shields.io/badge/Google_Research_Blog-333333?&logo=google&logoColor=white" alt="Google Research Blog">
        </a>
        <a href="https://arxiv.org/abs/2402.05804">
        <img src="https://img.shields.io/badge/Read_the_Paper-4CAF50?&logo=arxiv&logoColor=white" alt="Read the Paper">
        </a>
        <a href="https://charlieleee.github.io/publication/inksight/">
        <img src="https://img.shields.io/badge/π_Project_Page-FFA500?&logo=link&logoColor=white" alt="Project Page">
        </a>
        <a href="https://huggingface.co/datasets/Derendering/InkSight-Derenderings">
        <img src="https://img.shields.io/badge/Dataset-InkSight-40AF40?&logo=huggingface&logoColor=white" alt="Hugging Face Dataset">
        </a>
        <a href="https://github.com/google-research/inksight">
        <img src="https://img.shields.io/badge/GitHub-InkSight-green?&logo=github&logoColor=white" alt="GitHub Repository">
        </a>
        </div>
        """
    )
    gr.HTML(f'<div style="margin: 20px 0; text-align: center;">{diagram}<br><em>InkSight system diagram (<a href="https://github.com/google-research/inksight/blob/main/figures/full_diagram.gif">gif version</a>)</em></div>')

    # NOTE(review): the leading "π" / "π²" glyphs below look like mis-encoded
    # emoji - confirm the intended characters.
    gr.Markdown(
        """
        π This demo highlights the capabilities of Small-i, Small-p, and Large-i across three public datasets (word-level, with 100 random samples each).<br>
        π We've released the InkSight-Small-p model on Hugging Face! Check it out [here](https://huggingface.co/Derendering/InkSight-Small-p).<br>
        π² Select a model variant and dataset (IAM, IMGUR5K, HierText), then hit 'Sample' to view a randomly selected input alongside its corresponding outputs for all three types of inference.<br>
        """
    )

    # Sampler inputs and the image that was drawn.
    # NOTE(review): placement of `im` inside this row is assumed - confirm.
    with gr.Row():
        dataset = gr.Dropdown(choices=["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM")
        model = gr.Dropdown(
            choices=["Small-i", "Large-i", "Small-p"],
            label="InkSight Model Variant",
            value="Small-i",
        )
        im = gr.Image(label="Input Image")

    # One caption per inference mode ...
    with gr.Row():
        d_t_text = gr.Textbox(label="OCR recognition input to the model", interactive=False)
        r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
        vanilla_text = gr.Textbox(label="Vanilla", interactive=False)

    # ... and the matching ink video underneath it.
    with gr.Row():
        d_t_vid = gr.Video(label="Derender with Text (Click to stop/play)", autoplay=True)
        r_d_vid = gr.Video(label="Recognize and Derender (Click to stop/play)", autoplay=True)
        vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)

    with gr.Row():
        btn_sub = gr.Button("Sample")

    # The output order must match the tuple returned by `demo`.
    btn_sub.click(
        fn=demo,
        inputs=[dataset, model],
        outputs=[im, d_t_text, d_t_vid, r_d_text, r_d_vid, vanilla_text, vanilla_vid],
    )

    # Opening tag shared by the two GIF galleries below.
    gallery_open = """
    <div style="display: flex; justify-content: space-around; flex-wrap: wrap; gap: 0px;">
    """

    gr.Markdown("## More Word-level Samples")

    # Each word-level GIF is inlined as a data URI with its caption below it.
    word_cards = "".join(
        f"""
        <div>
        <img src="data:image/gif;base64,{encoded}" alt="{label}" style="width: 100%; max-width: 200px;">
        <p style="text-align: center;">{label}</p>
        </div>
        """
        for label, encoded in gif_base64_strings.items()
    )
    html_content = gallery_open + word_cards + "</div>"
    gr.HTML(html_content)

    gr.Markdown("## Sketch Samples")

    # Sketch GIFs are shown without captions, so only the encoded data is used.
    sketch_cards = "".join(
        f"""
        <div>
        <img src="data:image/gif;base64,{encoded}" style="width: 100%; max-width: 200px;">
        </div>
        """
        for encoded in sketches_base64_strings.values()
    )
    html_content = gallery_open + sketch_cards + "</div>"
    gr.HTML(html_content)

    gr.Markdown("## Scale Up to Full Page")

    svg1_content = get_svg_content("full_page/danke.svg")
    svg2_content = get_svg_content("full_page/multilingual_demo.svg")
    svg3_content = get_svg_content("full_page/unsplash_frame.svg")

    # Three stacked figures; the slots alternate between SVG markup and caption.
    svg_html_template = """
    <div style="display: block;">
    <div>
    <div style="margin-bottom: 10px;">{}</div>
    <p style="text-align: center;">{}</p>
    </div>
    <div>
    <div style="margin-bottom: 10px;">{}</div>
    <p style="text-align: center;">{}</p>
    </div>
    <div>
    <div style="margin-bottom: 10px;">{}</div>
    <p style="text-align: center;">{}</p>
    </div>
    </div>
    """

    full_page_figures = [
        (
            svg1_content,
            'Writings on the beach. <a href="https://unsplash.com/photos/text-rG-PerMFjFA">Credit</a>',
        ),
        (
            svg2_content,
            "Multilingual handwriting.",
        ),
        (
            svg3_content,
            "Handwriting in a frame. <a href='https://unsplash.com/photos/white-wooden-framed-white-board-t7fLWMQl2Lw'>Credit</a>",
        ),
    ]
    # Flatten the (svg, caption) pairs into the template's positional slots.
    full_svg_display = svg_html_template.format(
        *(part for figure in full_page_figures for part in figure)
    )

    gr.HTML(full_svg_display)

app.launch()
|
|