Charlie Li commited on
Commit
0c738cb
1 Parent(s): ec8676d

add cache video to speed up

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -3,10 +3,15 @@ import os
3
  import random
4
  import datetime
5
  from utils import *
 
6
 
7
  file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
8
  filename = "derendering_supp.zip"
9
 
 
 
 
 
10
  download_file(file_url, filename)
11
  unzip_file(filename)
12
  print("Downloaded and unzipped the file.")
@@ -86,17 +91,17 @@ def demo(Dataset, Model, Output_Format):
86
  inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
87
  text_field = parse_inkml_annotations(inkml_file)["textField"]
88
  output_text = f"{plot_title[mode]}{text_field}"
89
- # Text output for three modes
90
- # d+t: OCR recognition input to the model
91
- # r+d: Recognition from the model
92
- # vanilla: None
93
  text_outputs.append(output_text)
94
  ink = inkml_to_ink(inkml_file)
95
 
 
 
 
96
  if Output_Format == "Image+Video":
97
- video_filename = mode + ".mp4"
98
- plot_ink_to_video(ink, video_filename, input_image=img)
99
- video_outputs.append(video_filename)
 
100
  else:
101
  video_outputs.append(None)
102
 
 
3
  import random
4
  import datetime
5
  from utils import *
6
+ from pathlib import Path
7
 
8
  file_url = "https://storage.googleapis.com/derendering_model/derendering_supp.zip"
9
  filename = "derendering_supp.zip"
10
 
11
+ # Cache videos to speed up demo
12
+ video_cache_dir = Path("./cached_videos")
13
+ video_cache_dir.mkdir(exist_ok=True)
14
+
15
  download_file(file_url, filename)
16
  unzip_file(filename)
17
  print("Downloaded and unzipped the file.")
 
91
  inkml_file = os.path.join(inkml_path, mode, example_id + ".inkml")
92
  text_field = parse_inkml_annotations(inkml_file)["textField"]
93
  output_text = f"{plot_title[mode]}{text_field}"
 
 
 
 
94
  text_outputs.append(output_text)
95
  ink = inkml_to_ink(inkml_file)
96
 
97
+ video_filename = f"{Model}_{Dataset}_{mode}_{example_id}.mp4"
98
+ video_filepath = video_cache_dir / video_filename
99
+
100
  if Output_Format == "Image+Video":
101
+ if not video_filepath.exists():
102
+ plot_ink_to_video(ink, str(video_filepath), input_image=img)
103
+ print("Cached video at:", video_filepath)
104
+ video_outputs.append("./" + str(video_filepath))
105
  else:
106
  video_outputs.append(None)
107