Leon Sick commited on
Commit
1045df5
1 Parent(s): f7d197a

start of space

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -0
  2. model.py +13 -3
Dockerfile CHANGED
@@ -49,6 +49,7 @@ RUN pip install --no-cache-dir -U \
49
  Pillow==9.5.0 \
50
  colored==1.4.4
51
  RUN pip install --no-cache-dir -U gradio==4.43.0
 
52
 
53
  COPY --chown=1000 . ${HOME}/app
54
  RUN git clone https://github.com/facebookresearch/CutLER.git
 
49
  Pillow==9.5.0 \
50
  colored==1.4.4
51
  RUN pip install --no-cache-dir -U gradio==4.43.0
52
+ RUN pip install -U "huggingface_hub[cli]"
53
 
54
  COPY --chown=1000 . ${HOME}/app
55
  RUN git clone https://github.com/facebookresearch/CutLER.git
model.py CHANGED
@@ -3,6 +3,7 @@
3
 
4
  import argparse
5
  import multiprocessing as mp
 
6
  import pathlib
7
  import shlex
8
  import subprocess
@@ -87,14 +88,23 @@ WEIGHT_URL = 'http://dl.fbaipublicfiles.com/cutler/checkpoints/cutler_cascade_fi
87
 
88
 
89
  def load_model(score_threshold: float) -> VisualizationDemo:
 
 
 
 
90
  model_dir = pathlib.Path('checkpoints')
91
  model_dir.mkdir(exist_ok=True)
92
- weight_path = model_dir / WEIGHT_URL.split('/')[-1]
 
93
  print("***weight_path***", weight_path)
94
  print("torch.cuda.is_available()", torch.cuda.is_available())
95
 
96
- if not weight_path.exists():
97
- subprocess.run(shlex.split(f'wget {WEIGHT_URL} -O {weight_path}'))
 
 
 
 
98
 
99
  arg_list = [
100
  '--config-file',
 
3
 
4
  import argparse
5
  import multiprocessing as mp
6
+ import os
7
  import pathlib
8
  import shlex
9
  import subprocess
 
88
 
89
 
90
  def load_model(score_threshold: float) -> VisualizationDemo:
91
+ # Get secrets
92
+ hf_token = os.getenv('HF_TOKEN')
93
+ model_filename = os.getenv('MODEL_NAME')
94
+
95
  model_dir = pathlib.Path('checkpoints')
96
  model_dir.mkdir(exist_ok=True)
97
+ #weight_path = model_dir / WEIGHT_URL.split('/')[-1]
98
+ weight_path = model_dir / ".cache" / "huggingface" / "download"/ model_filename
99
  print("***weight_path***", weight_path)
100
  print("torch.cuda.is_available()", torch.cuda.is_available())
101
 
102
+
103
+ # Load the model weights file from huggingface hub, use token to authenticate
104
+ subprocess.run(shlex.split(f'huggingface-cli download leonsick/cuts3d_zeroshot {model_filename} --token {hf_token} --local-dir {model_dir}'))
105
+
106
+ #if not weight_path.exists():
107
+ # subprocess.run(shlex.split(f'wget {WEIGHT_URL} -O {weight_path}'))
108
 
109
  arg_list = [
110
  '--config-file',