levihsu commited on
Commit
04cc394
·
verified ·
1 Parent(s): 81d335d

Update preprocess/humanparsing/run_parsing.py

Browse files
preprocess/humanparsing/run_parsing.py CHANGED
@@ -10,6 +10,8 @@ import torch
10
 
11
  from huggingface_hub import hf_hub_download
12
 
 
 
13
  class Parsing:
14
  def __init__(self, gpu_id: int):
15
  self.gpu_id = gpu_id
@@ -19,14 +21,13 @@ class Parsing:
19
  session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
20
  session_options.add_session_config_entry('gpu_id', str(gpu_id))
21
 
22
- print('start download')
23
- hf_hub_download(repo_id="levihsu/OOTDiffusion",
24
- filename="checkpoints/humanparsing/parsing_atr.onnx",
25
- local_dir=os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing'))
26
- hf_hub_download(repo_id="levihsu/OOTDiffusion",
27
- filename="checkpoints/humanparsing/parsing_lip.onnx",
28
- local_dir=os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing'))
29
- print('finish download')
30
 
31
  self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_atr.onnx'),
32
  sess_options=session_options, providers=['CPUExecutionProvider'])
 
10
 
11
  from huggingface_hub import hf_hub_download
12
 
13
+
14
+
15
  class Parsing:
16
  def __init__(self, gpu_id: int):
17
  self.gpu_id = gpu_id
 
21
  session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
22
  session_options.add_session_config_entry('gpu_id', str(gpu_id))
23
 
24
+ parsing_ckpt_path = os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing')
25
+ atr_model_path = 'https://huggingface.co/levihsu/OOTDiffusion/blob/main/checkpoints/humanparsing/parsing_atr.onnx'
26
+ lip_model_path = 'https://huggingface.co/levihsu/OOTDiffusion/blob/main/checkpoints/humanparsing/parsing_lip.onnx'
27
+
28
+ from basicsr.utils.download_util import load_file_from_url
29
+ load_file_from_url(atr_model_path, model_dir=parsing_ckpt_path)
30
+ load_file_from_url(lip_model_path, model_dir=parsing_ckpt_path)
 
31
 
32
  self.session = ort.InferenceSession(os.path.join(Path(__file__).absolute().parents[2].absolute(), 'checkpoints/humanparsing/parsing_atr.onnx'),
33
  sess_options=session_options, providers=['CPUExecutionProvider'])