HanningChen commited on
Commit
e992b8d
·
1 Parent(s): a9de562
Files changed (2) hide show
  1. webui/app.py +1 -1
  2. webui/weights.py +9 -2
webui/app.py CHANGED
@@ -24,7 +24,7 @@ app.mount("/static", StaticFiles(directory=str(WEBUI_DIR / "static")), name="sta
24
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
25
 
26
  # ---- weights repo ----
27
- WEIGHTS_REPO = os.getenv("TASKCLIP_WEIGHTS_REPO", "BiasLab2025/YOUR-WEIGHTS-REPO") # <-- change default
28
  WEIGHTS_DIR = get_weights_dir(WEIGHTS_REPO)
29
 
30
  CKPT_DIR = WEIGHTS_DIR / "checkpoints"
 
24
  app.mount("/results", StaticFiles(directory=str(RESULT_DIR)), name="results")
25
 
26
  # ---- weights repo ----
27
+ WEIGHTS_REPO = os.getenv("TASKCLIP_WEIGHTS_REPO", "BiasLab2025/taskclip-weights")
28
  WEIGHTS_DIR = get_weights_dir(WEIGHTS_REPO)
29
 
30
  CKPT_DIR = WEIGHTS_DIR / "checkpoints"
webui/weights.py CHANGED
@@ -1,12 +1,19 @@
1
- # webui/weights.py
2
  import os
3
  from pathlib import Path
4
  from huggingface_hub import snapshot_download
5
 
6
  def get_weights_dir(repo_id: str) -> Path:
7
- token = os.getenv("HF_TOKEN") # only needed if repo is private
 
 
 
 
 
 
 
8
  p = snapshot_download(
9
  repo_id=repo_id,
 
10
  local_dir="weights_cache",
11
  local_dir_use_symlinks=False,
12
  token=token,
 
 
1
  import os
2
  from pathlib import Path
3
  from huggingface_hub import snapshot_download
4
 
5
  def get_weights_dir(repo_id: str) -> Path:
6
+ # repo_id must be like "BiasLab2025/taskclip-weights" (NOT a URL)
7
+ repo_id = repo_id.strip()
8
+ if repo_id.startswith("http"):
9
+ # allow passing a full URL by accident
10
+ repo_id = repo_id.rstrip("/").split("huggingface.co/")[-1]
11
+
12
+ token = os.getenv("HF_TOKEN") # only needed if the repo is private
13
+
14
  p = snapshot_download(
15
  repo_id=repo_id,
16
+ repo_type="model", # IMPORTANT for your weights repo
17
  local_dir="weights_cache",
18
  local_dir_use_symlinks=False,
19
  token=token,