SpringWang08 commited on
Commit
bfbf130
·
1 Parent(s): 0a2bc32

Load checkpoints from Hugging Face Hub

Browse files
Files changed (1) hide show
  1. web/main.py +79 -7
web/main.py CHANGED
@@ -13,6 +13,7 @@ import torch
13
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from fastapi.staticfiles import StaticFiles
 
16
  from PIL import Image
17
  from peft import PeftModel
18
  from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
@@ -106,6 +107,17 @@ class VQAServerState:
106
  self.model_b_cfg = CFG.get("model_b", {})
107
  self.eval_cfg = CFG.get("eval", {})
108
  self.models_dir = ROOT_DIR / "checkpoints"
 
 
 
 
 
 
 
 
 
 
 
109
  self.qa_tokenizer = None
110
  self.translator = MedicalTranslator(device="cpu")
111
  self.answer_rewriter = MedicalAnswerRewriter()
@@ -134,6 +146,19 @@ def _artifact_exists(path: Path) -> bool:
134
  return path.exists()
135
 
136
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  def _as_bool(value: Any) -> bool:
138
  if isinstance(value, bool):
139
  return value
@@ -352,7 +377,20 @@ def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
352
  ckpt_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_best.pth"
353
  if not ckpt_path.exists():
354
  resume_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_resume.pth"
355
- ckpt_path = resume_path if resume_path.exists() else ckpt_path
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  return {"type": "direction_a", "path": ckpt_path}
357
 
358
  if variant == "B1":
@@ -360,15 +398,49 @@ def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
360
 
361
  if variant == "B2":
362
  ckpt_dir = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
 
 
 
 
 
 
 
 
363
  return {"type": "llava_adapter", "path": ckpt_dir}
364
 
365
  if variant == "DPO":
366
  final_adapter = ROOT_DIR / "checkpoints" / "DPO" / "final_adapter"
367
  fallback = ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25"
368
- return {"type": "llava_adapter", "path": final_adapter if final_adapter.exists() else fallback}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
  if variant == "PPO":
371
  final_adapter = ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"
 
 
 
 
 
 
 
 
 
 
 
 
372
  return {"type": "llava_adapter", "path": final_adapter}
373
 
374
  raise ValueError(f"Unknown variant: {variant}")
@@ -857,12 +929,12 @@ def _variant_availability() -> dict[str, dict[str, Any]]:
857
  b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
858
  cuda_ready = torch.cuda.is_available()
859
  return {
860
- "A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth")), "artifact": "checkpoints/medical_vqa_A1_best.pth"},
861
- "A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth")), "artifact": "checkpoints/medical_vqa_A2_best.pth"},
862
  "B1": {"available": cuda_ready, "artifact": state.llava_model_id},
863
- "B2": {"available": cuda_ready and b2_checkpoint is not None, "artifact": str(b2_checkpoint) if b2_checkpoint else ""},
864
- "DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25")), "artifact": "checkpoints/DPO/final_adapter"},
865
- "PPO": {"available": cuda_ready and _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"), "artifact": "checkpoints/PPO/final_adapter"},
866
  }
867
 
868
 
 
13
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
14
  from fastapi.responses import FileResponse, JSONResponse
15
  from fastapi.staticfiles import StaticFiles
16
+ from huggingface_hub import snapshot_download
17
  from PIL import Image
18
  from peft import PeftModel
19
  from transformers import AutoTokenizer, LlavaForConditionalGeneration, LlavaProcessor
 
107
  self.model_b_cfg = CFG.get("model_b", {})
108
  self.eval_cfg = CFG.get("eval", {})
109
  self.models_dir = ROOT_DIR / "checkpoints"
110
+ self.artifact_cache_dir = Path(
111
+ os.getenv("MEDVQA_ARTIFACT_CACHE", str(ROOT_DIR / ".cache" / "hub_artifacts"))
112
+ )
113
+ self.artifact_cache_dir.mkdir(parents=True, exist_ok=True)
114
+ self.hub_model_ids = {
115
+ "A1": os.getenv("MEDVQA_A1_MODEL_ID", "SpringWang08/medical-vqa-a1"),
116
+ "A2": os.getenv("MEDVQA_A2_MODEL_ID", "SpringWang08/medical-vqa-a2"),
117
+ "B2": os.getenv("MEDVQA_B2_MODEL_ID", "SpringWang08/medical-vqa-b2"),
118
+ "DPO": os.getenv("MEDVQA_DPO_MODEL_ID", "SpringWang08/medical-vqa-dpo"),
119
+ "PPO": os.getenv("MEDVQA_PPO_MODEL_ID", "SpringWang08/medical-vqa-ppo"),
120
+ }
121
  self.qa_tokenizer = None
122
  self.translator = MedicalTranslator(device="cpu")
123
  self.answer_rewriter = MedicalAnswerRewriter()
 
146
  return path.exists()
147
 
148
 
149
def _download_hub_snapshot(repo_id: str, cache_subdir: str, allow_patterns: Optional[list[str]] = None) -> Path:
    """Materialize a model repository from the Hugging Face Hub into the local artifact cache.

    Args:
        repo_id: Hub repository id (e.g. ``"SpringWang08/medical-vqa-a1"``).
        cache_subdir: Sub-directory under ``state.artifact_cache_dir`` that
            receives the downloaded files.
        allow_patterns: Optional glob patterns limiting which repo files are
            fetched; ``None`` downloads the full snapshot.

    Returns:
        Path to the local directory holding the snapshot files.
    """
    destination = state.artifact_cache_dir / cache_subdir
    destination.mkdir(parents=True, exist_ok=True)
    # NOTE(review): `local_dir_use_symlinks` is deprecated and ignored by
    # huggingface_hub >= 0.23 (real files are written into `local_dir`
    # regardless). Kept here for byte-identical behavior on older hub
    # versions — consider dropping it once the pinned version is confirmed.
    snapshot_download(
        repo_id=repo_id,
        repo_type="model",
        local_dir=str(destination),
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
    )
    return destination
161
+
162
  def _as_bool(value: Any) -> bool:
163
  if isinstance(value, bool):
164
  return value
 
377
  ckpt_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_best.pth"
378
  if not ckpt_path.exists():
379
  resume_path = ROOT_DIR / "checkpoints" / f"medical_vqa_{variant}_resume.pth"
380
+ if resume_path.exists():
381
+ ckpt_path = resume_path
382
+ else:
383
+ repo_id = state.hub_model_ids.get(variant, "")
384
+ if repo_id:
385
+ downloaded_dir = _download_hub_snapshot(
386
+ repo_id=repo_id,
387
+ cache_subdir=variant.lower(),
388
+ allow_patterns=["README.md", "*.pth"],
389
+ )
390
+ downloaded_ckpt = downloaded_dir / f"medical_vqa_{variant}_best.pth"
391
+ if not downloaded_ckpt.exists():
392
+ downloaded_ckpt = downloaded_dir / f"medical_vqa_{variant}_resume.pth"
393
+ ckpt_path = downloaded_ckpt
394
  return {"type": "direction_a", "path": ckpt_path}
395
 
396
  if variant == "B1":
 
398
 
399
  if variant == "B2":
400
  ckpt_dir = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
401
+ if ckpt_dir is None:
402
+ repo_id = state.hub_model_ids.get("B2", "")
403
+ if repo_id:
404
+ ckpt_dir = _download_hub_snapshot(
405
+ repo_id=repo_id,
406
+ cache_subdir="b2",
407
+ allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
408
+ )
409
  return {"type": "llava_adapter", "path": ckpt_dir}
410
 
411
  if variant == "DPO":
412
  final_adapter = ROOT_DIR / "checkpoints" / "DPO" / "final_adapter"
413
  fallback = ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25"
414
+ if final_adapter.exists():
415
+ return {"type": "llava_adapter", "path": final_adapter}
416
+ if fallback.exists():
417
+ return {"type": "llava_adapter", "path": fallback}
418
+ repo_id = state.hub_model_ids.get("DPO", "")
419
+ if repo_id:
420
+ return {
421
+ "type": "llava_adapter",
422
+ "path": _download_hub_snapshot(
423
+ repo_id=repo_id,
424
+ cache_subdir="dpo",
425
+ allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
426
+ ),
427
+ }
428
+ return {"type": "llava_adapter", "path": final_adapter}
429
 
430
  if variant == "PPO":
431
  final_adapter = ROOT_DIR / "checkpoints" / "PPO" / "final_adapter"
432
+ if final_adapter.exists():
433
+ return {"type": "llava_adapter", "path": final_adapter}
434
+ repo_id = state.hub_model_ids.get("PPO", "")
435
+ if repo_id:
436
+ return {
437
+ "type": "llava_adapter",
438
+ "path": _download_hub_snapshot(
439
+ repo_id=repo_id,
440
+ cache_subdir="ppo",
441
+ allow_patterns=["README.md", "adapter_model.safetensors", "adapter_config.json", "tokenizer.json", "tokenizer_config.json", "processor_config.json", "chat_template.jinja"],
442
+ ),
443
+ }
444
  return {"type": "llava_adapter", "path": final_adapter}
445
 
446
  raise ValueError(f"Unknown variant: {variant}")
 
929
  b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
930
  cuda_ready = torch.cuda.is_available()
931
  return {
932
+ "A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") or bool(state.hub_model_ids.get("A1"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") else state.hub_model_ids.get("A1", "")},
933
+ "A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") or bool(state.hub_model_ids.get("A2"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") else state.hub_model_ids.get("A2", "")},
934
  "B1": {"available": cuda_ready, "artifact": state.llava_model_id},
935
+ "B2": {"available": cuda_ready and (b2_checkpoint is not None or bool(state.hub_model_ids.get("B2"))), "artifact": str(b2_checkpoint) if b2_checkpoint else state.hub_model_ids.get("B2", "")},
936
+ "DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25") or bool(state.hub_model_ids.get("DPO"))), "artifact": "checkpoints/DPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") else state.hub_model_ids.get("DPO", "")},
937
+ "PPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") or bool(state.hub_model_ids.get("PPO"))), "artifact": "checkpoints/PPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") else state.hub_model_ids.get("PPO", "")},
938
  }
939
 
940