deepkyu commited on
Commit
e8ce90b
1 Parent(s): c3014da

add checkpoint downloading code

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -1,7 +1,27 @@
1
- import gradio as gr
 
2
  from pathlib import Path
 
 
 
3
  from demo import SdmCompressionDemo
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  if __name__ == "__main__":
6
  servicer = SdmCompressionDemo()
7
  example_list = servicer.get_example_list()
 
1
+ import os
2
+ import subprocess
3
  from pathlib import Path
4
+
5
+ import gradio as gr
6
+
7
  from demo import SdmCompressionDemo
8
 
9
+ dest_path_config = Path('checkpoints/BK-SDM-Small_iter50000/unet/config.json')
10
+ dest_path_torch_ckpt = Path('checkpoints/BK-SDM-Small_iter50000/unet/diffusion_pytorch_model.bin')
11
+ BK_SDM_CONFIG_URL: str = os.getenv('CHECKPOINT_CONFIG', None)
12
+ BK_SDM_TORCH_CKPT_URL: str = os.getenv('CHECKPOINT_PYTORCH_BIN', None)
13
+ assert BK_SDM_CONFIG_URL is not None
14
+ assert BK_SDM_TORCH_CKPT_URL is not None
15
+
16
+ subprocess.call(
17
+ f"wget --no-check-certificate -O {dest_path_config} {BK_SDM_CONFIG_URL}",
18
+ shell=True
19
+ )
20
+ subprocess.call(
21
+ f"wget --no-check-certificate -O {dest_path_torch_ckpt} {BK_SDM_TORCH_CKPT_URL}",
22
+ shell=True
23
+ )
24
+
25
  if __name__ == "__main__":
26
  servicer = SdmCompressionDemo()
27
  example_list = servicer.get_example_list()