Charlie Li committed on
Commit
17f8269
•
1 Parent(s): 396546f

update page

Browse files
Files changed (2) hide show
  1. app.py +32 -34
  2. utils.py +12 -33
app.py CHANGED
@@ -39,10 +39,7 @@ captions = [
39
  "you",
40
  "letter",
41
  ]
42
- gif_base64_strings = {
43
- caption: get_base64_encoded_gif(f"gifs/{name}")
44
- for caption, name in zip(captions, gif_filenames)
45
- }
46
 
47
  sketches = [
48
  "bird.gif",
@@ -50,21 +47,23 @@ sketches = [
50
  "coffee.gif",
51
  "penguin.gif",
52
  ]
53
- sketches_base64_strings = {
54
- name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches
55
- }
56
 
57
  if not pre_generate:
58
- print("Downloading pre-generated videos from google drive.")
59
- # Download from gdown 1oT6zw1EbWg3lavBMXsL28piULGNmqJzA
60
- gdown.download(
61
- "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
62
- str(video_cache_dir / "gdrive_file.zip"),
63
- quiet=False,
64
- )
 
 
65
 
66
- # Unzip the file to video_cache_dir
67
- unzip_file(str(video_cache_dir / "gdrive_file.zip"))
 
 
68
  else:
69
  pregenerate_videos(video_cache_dir=video_cache_dir)
70
  print("Videos cached.")
@@ -143,14 +142,21 @@ def demo(Dataset, Model):
143
 
144
  with gr.Blocks() as app:
145
  gr.HTML(org_content)
146
- gr.Markdown(
147
- "# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write"
148
- )
149
  gr.HTML(
150
  """
151
- <div style="display: flex; align-items: center; margin-bottom: 20px;">
152
- <a href="https://arxiv.org/pdf/2402.05804.pdf" target="_blank" style="font-size: 16px; background-color: #4CAF50; color: white; padding: 5px 7px; text-decoration: none; border-radius: 2px;">
153
- 📄 Read the Paper
 
 
 
 
 
 
 
 
 
154
  </a>
155
  </div>
156
  """
@@ -163,9 +169,7 @@ with gr.Blocks() as app:
163
  """
164
  )
165
  with gr.Row():
166
- dataset = gr.Dropdown(
167
- ["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM"
168
- )
169
  model = gr.Dropdown(
170
  ["Small-i", "Large-i", "Small-p"],
171
  label="InkSight Model Variant",
@@ -179,18 +183,12 @@ with gr.Blocks() as app:
179
  # vanilla_img = gr.Image(label="Vanilla")
180
 
181
  with gr.Row():
182
- d_t_text = gr.Textbox(
183
- label="OCR recognition input to the model", interactive=False
184
- )
185
  r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
186
  vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
187
  with gr.Row():
188
- d_t_vid = gr.Video(
189
- label="Derender with Text (Click to stop/play)", autoplay=True
190
- )
191
- r_d_vid = gr.Video(
192
- label="Recognize and Derender (Click to stop/play)", autoplay=True
193
- )
194
  vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
195
 
196
  with gr.Row():
 
39
  "you",
40
  "letter",
41
  ]
42
+ gif_base64_strings = {caption: get_base64_encoded_gif(f"gifs/{name}") for caption, name in zip(captions, gif_filenames)}
 
 
 
43
 
44
  sketches = [
45
  "bird.gif",
 
47
  "coffee.gif",
48
  "penguin.gif",
49
  ]
50
+ sketches_base64_strings = {name: get_base64_encoded_gif(f"sketches/{name}") for name in sketches}
 
 
51
 
52
  if not pre_generate:
53
+ # Check if the file already exists
54
+ if not (video_cache_dir / "gdrive_file.zip").exists():
55
+ print("Downloading pre-generated videos from Google Drive.")
56
+ # Download from Google Drive using gdown
57
+ gdown.download(
58
+ "https://drive.google.com/uc?id=1oT6zw1EbWg3lavBMXsL28piULGNmqJzA",
59
+ str(video_cache_dir / "gdrive_file.zip"),
60
+ quiet=False,
61
+ )
62
 
63
+ # Unzip the file to video_cache_dir
64
+ unzip_file(str(video_cache_dir / "gdrive_file.zip"))
65
+ else:
66
+ print("File already exists. Skipping download.")
67
  else:
68
  pregenerate_videos(video_cache_dir=video_cache_dir)
69
  print("Videos cached.")
 
142
 
143
  with gr.Blocks() as app:
144
  gr.HTML(org_content)
145
+ gr.Markdown("# InkSight: Offline-to-Online Handwriting Conversion by Learning to Read and Write")
 
 
146
  gr.HTML(
147
  """
148
+ <div style="display: flex; gap: 10px; justify-content: left;">
149
+ <a href="https://arxiv.org/abs/2402.05804">
150
+ <img src="https://img.shields.io/badge/📄_Read_the_Paper-4CAF50?style=for-the-badge&logo=arxiv&logoColor=white" alt="Read the Paper">
151
+ </a>
152
+ <a href="https://github.com/google-research/inksight">
153
+ <img src="https://img.shields.io/badge/View_on_GitHub-181717?style=for-the-badge&logo=github&logoColor=white" alt="View on GitHub">
154
+ </a>
155
+ <a href="https://research.google/blog/a-return-to-hand-written-notes-by-learning-to-read-write/">
156
+ <img src="https://img.shields.io/badge/🌐_Google_Research_Blog-333333?style=for-the-badge&logo=google&logoColor=white" alt="Google Research Blog">
157
+ </a>
158
+ <a href="https://charlieleee.github.io/publication/inksight/">
159
+ <img src="https://img.shields.io/badge/ℹ️_Info-FFA500?style=for-the-badge&logo=info&logoColor=white" alt="Info">
160
  </a>
161
  </div>
162
  """
 
169
  """
170
  )
171
  with gr.Row():
172
+ dataset = gr.Dropdown(["IAM", "IMGUR5K", "HierText"], label="Dataset", value="IAM")
 
 
173
  model = gr.Dropdown(
174
  ["Small-i", "Large-i", "Small-p"],
175
  label="InkSight Model Variant",
 
183
  # vanilla_img = gr.Image(label="Vanilla")
184
 
185
  with gr.Row():
186
+ d_t_text = gr.Textbox(label="OCR recognition input to the model", interactive=False)
 
 
187
  r_d_text = gr.Textbox(label="Recognition from the model", interactive=False)
188
  vanilla_text = gr.Textbox(label="Vanilla", interactive=False)
189
  with gr.Row():
190
+ d_t_vid = gr.Video(label="Derender with Text (Click to stop/play)", autoplay=True)
191
+ r_d_vid = gr.Video(label="Recognize and Derender (Click to stop/play)", autoplay=True)
 
 
 
 
192
  vanilla_vid = gr.Video(label="Vanilla (Click to stop/play)", autoplay=True)
193
 
194
  with gr.Row():
utils.py CHANGED
@@ -32,6 +32,8 @@ def get_svg_content(svg_path):
32
 
33
 
34
  def download_file(url, filename):
 
 
35
  response = requests.get(url)
36
  with open(filename, "wb") as f:
37
  f.write(response.content)
@@ -84,22 +86,15 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
84
  base_color = base_colors(len(ink.strokes) - 1 - i)
85
  hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
86
 
87
- darker_color = colorsys.hsv_to_rgb(
88
- hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
89
- )
90
- colors = [
91
- mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x)))
92
- for j in range(len(x))
93
- ]
94
 
95
  points = np.array([x, y]).T.reshape(-1, 1, 2)
96
  segments = np.concatenate([points[:-1], points[1:]], axis=1)
97
 
98
  lc = LineCollection(segments, colors=colors, linewidth=lw)
99
  if with_path:
100
- lc.set_path_effects(
101
- [withStroke(linewidth=lw * 1.25, foreground=path_color)]
102
- )
103
  ax.add_collection(lc)
104
 
105
  ax.set_xlim(0, 224)
@@ -107,9 +102,7 @@ def plot_ink(ink, ax, lw=1.8, input_image=None, with_path=True, path_color="whit
107
  ax.invert_yaxis()
108
 
109
 
110
- def plot_ink_to_video(
111
- ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30
112
- ):
113
  fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
114
 
115
  if input_image is not None:
@@ -143,26 +136,16 @@ def plot_ink_to_video(
143
 
144
  base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
145
  hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
146
- darker_color = colorsys.hsv_to_rgb(
147
- hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65)
148
- )
149
- visible_segments = (
150
- segments[: frame - points_drawn]
151
- if frame - points_drawn < len(segments)
152
- else segments
153
- )
154
  colors = [
155
- mcolors.to_rgba(
156
- darker_color, alpha=1 - (0.5 * j / len(visible_segments))
157
- )
158
  for j in range(len(visible_segments))
159
  ]
160
 
161
  if len(visible_segments) > 0:
162
  lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
163
- lc.set_path_effects(
164
- [withStroke(linewidth=lw * 1.25, foreground=path_color)]
165
- )
166
  ax.add_collection(lc)
167
 
168
  points_drawn += len(segments)
@@ -254,13 +237,9 @@ def pregenerate_videos(video_cache_dir):
254
  if not os.path.exists(path):
255
  continue
256
  samples = os.listdir(path)
257
- for name in tqdm(
258
- samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"
259
- ):
260
  example_id = name.strip(".png")
261
- inkml_file = os.path.join(
262
- inkml_path_base, mode, f"{example_id}.inkml"
263
- )
264
  if not os.path.exists(inkml_file):
265
  continue
266
  video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
 
32
 
33
 
34
  def download_file(url, filename):
35
+ if os.path.exists(filename):
36
+ return
37
  response = requests.get(url)
38
  with open(filename, "wb") as f:
39
  f.write(response.content)
 
86
  base_color = base_colors(len(ink.strokes) - 1 - i)
87
  hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
88
 
89
+ darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
90
+ colors = [mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(x))) for j in range(len(x))]
 
 
 
 
 
91
 
92
  points = np.array([x, y]).T.reshape(-1, 1, 2)
93
  segments = np.concatenate([points[:-1], points[1:]], axis=1)
94
 
95
  lc = LineCollection(segments, colors=colors, linewidth=lw)
96
  if with_path:
97
+ lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
 
 
98
  ax.add_collection(lc)
99
 
100
  ax.set_xlim(0, 224)
 
102
  ax.invert_yaxis()
103
 
104
 
105
+ def plot_ink_to_video(ink, output_name, lw=1.8, input_image=None, path_color="white", fps=30):
 
 
106
  fig, ax = plt.subplots(figsize=(4, 4), dpi=150)
107
 
108
  if input_image is not None:
 
136
 
137
  base_color = base_colors(len(ink.strokes) - 1 - stroke_index)
138
  hsv_color = colorsys.rgb_to_hsv(*base_color[:3])
139
+ darker_color = colorsys.hsv_to_rgb(hsv_color[0], hsv_color[1], max(0, hsv_color[2] * 0.65))
140
+ visible_segments = segments[: frame - points_drawn] if frame - points_drawn < len(segments) else segments
 
 
 
 
 
 
141
  colors = [
142
+ mcolors.to_rgba(darker_color, alpha=1 - (0.5 * j / len(visible_segments)))
 
 
143
  for j in range(len(visible_segments))
144
  ]
145
 
146
  if len(visible_segments) > 0:
147
  lc = LineCollection(visible_segments, colors=colors, linewidth=lw)
148
+ lc.set_path_effects([withStroke(linewidth=lw * 1.25, foreground=path_color)])
 
 
149
  ax.add_collection(lc)
150
 
151
  points_drawn += len(segments)
 
237
  if not os.path.exists(path):
238
  continue
239
  samples = os.listdir(path)
240
+ for name in tqdm(samples, desc=f"Generating {Model}-{Dataset}-{mode} videos"):
 
 
241
  example_id = name.strip(".png")
242
+ inkml_file = os.path.join(inkml_path_base, mode, f"{example_id}.inkml")
 
 
243
  if not os.path.exists(inkml_file):
244
  continue
245
  video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"