SpringWang08 committed on
Commit
9c71261
·
1 Parent(s): d9a0039

Lock Space to B2-only demo

Browse files
Files changed (4) hide show
  1. Dockerfile +1 -0
  2. web/README.md +12 -3
  3. web/main.py +17 -9
  4. web/static/index.html +29 -11
Dockerfile CHANGED
@@ -7,6 +7,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
7
  HF_HOME=/data/.huggingface \
8
  HUGGINGFACE_HUB_CACHE=/data/.huggingface/hub \
9
  TRANSFORMERS_CACHE=/data/.huggingface/transformers \
 
10
  WEB_PRELOAD_MODELS=0 \
11
  ANSWER_REWRITE_ENABLED=0
12
 
 
7
  HF_HOME=/data/.huggingface \
8
  HUGGINGFACE_HUB_CACHE=/data/.huggingface/hub \
9
  TRANSFORMERS_CACHE=/data/.huggingface/transformers \
10
+ MEDVQA_ACTIVE_VARIANTS=B2 \
11
  WEB_PRELOAD_MODELS=0 \
12
  ANSWER_REWRITE_ENABLED=0
13
 
web/README.md CHANGED
@@ -5,7 +5,8 @@ Thư mục này chứa FastAPI + web UI để:
5
  - upload ảnh
6
  - nhập câu hỏi VQA
7
  - chạy dự đoán
8
- - so sánh 6 model: `A1`, `A2`, `B1`, `B2`, `DPO`, `PPO`
 
9
 
10
  ### Chạy server
11
 
@@ -23,6 +24,14 @@ WEB_PRELOAD_MODELS=1 uvicorn web.main:app --host 0.0.0.0 --port 8000
23
 
24
  Mặc định hiện tại là `WEB_PRELOAD_MODELS=0` để Space khởi động nhẹ hơn. Chỉ bật `1` khi GPU đủ mạnh và bạn muốn preload trước.
25
 
 
 
 
 
 
 
 
 
26
  Khi chạy trên GPU, nên để `--workers 1` để tránh mỗi worker nạp một bản model riêng.
27
 
28
  ### Chạy bằng Docker
@@ -78,8 +87,8 @@ http://localhost:8000
78
  - form-data:
79
  - `question`: câu hỏi VQA
80
  - `image`: ảnh đầu vào
81
- - `model_name` hoặc `model_names`:
82
- - nếu bỏ trống thì chạy toàn bộ 6 model
83
  - `model_names` nhận chuỗi JSON list hoặc chuỗi phân tách bằng dấu phẩy
84
 
85
  ### Artifact cần có
 
5
  - upload ảnh
6
  - nhập câu hỏi VQA
7
  - chạy dự đoán
8
+ - chạy mặc định model `B2` trên Hugging Face Space
9
+ - nếu cần, vẫn có thể bật lại các model khác bằng biến môi trường
10
 
11
  ### Chạy server
12
 
 
24
 
25
  Mặc định hiện tại là `WEB_PRELOAD_MODELS=0` để Space khởi động nhẹ hơn. Chỉ bật `1` khi GPU đủ mạnh và bạn muốn preload trước.
26
 
27
+ Mặc định Space chỉ mở chế độ `B2` để giảm RAM/VRAM:
28
+
29
+ ```bash
30
+ MEDVQA_ACTIVE_VARIANTS=B2
31
+ ```
32
+
33
+ Nếu muốn chạy nhiều model hơn, đặt `MEDVQA_ACTIVE_VARIANTS` thành danh sách ngăn cách bởi dấu phẩy, ví dụ `A1,A2,B2`.
34
+
35
  Khi chạy trên GPU, nên để `--workers 1` để tránh mỗi worker nạp một bản model riêng.
36
 
37
  ### Chạy bằng Docker
 
87
  - form-data:
88
  - `question`: câu hỏi VQA
89
  - `image`: ảnh đầu vào
90
+ - `model_name` hoặc `model_names`:
91
+ - nếu bỏ trống thì chạy các model đang bật trong `MEDVQA_ACTIVE_VARIANTS`
92
  - `model_names` nhận chuỗi JSON list hoặc chuỗi phân tách bằng dấu phẩy
93
 
94
  ### Artifact cần có
web/main.py CHANGED
@@ -133,6 +133,12 @@ class VQAServerState:
133
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
134
  # Chạy lần lượt và giải phóng model sau mỗi lượt để giảm đỉnh RAM/VRAM.
135
  self.release_after_predict = os.getenv("WEB_RELEASE_AFTER_PREDICT", "1") == "1"
 
 
 
 
 
 
136
  self.progress_state: dict[str, Any] = {
137
  "job_id": "",
138
  "active": False,
@@ -507,6 +513,8 @@ def _resolve_variant_artifact(variant: str) -> dict[str, Any]:
507
  def _llava_adapter_specs() -> list[tuple[str, Path]]:
508
  specs: list[tuple[str, Path]] = []
509
  for variant in ("B2", "DPO", "PPO"):
 
 
510
  artifact = _resolve_variant_artifact(variant)["path"]
511
  if isinstance(artifact, Path) and artifact.exists():
512
  specs.append((variant, artifact))
@@ -1051,26 +1059,26 @@ def _parse_model_selection(raw_model_name: Optional[str], raw_model_names: Optio
1051
  parsed = [part.strip() for part in raw_model_names.split(",") if part.strip()]
1052
  if isinstance(parsed, str):
1053
  parsed = [parsed]
1054
- selected = [name for name in parsed if name in VARIANT_ORDER]
1055
  if selected:
1056
  return selected
1057
 
1058
- if raw_model_name and raw_model_name in VARIANT_ORDER:
1059
  return [raw_model_name]
1060
 
1061
- return VARIANT_ORDER[:]
1062
 
1063
 
1064
  def _variant_availability() -> dict[str, dict[str, Any]]:
1065
  b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
1066
  cuda_ready = torch.cuda.is_available()
1067
  return {
1068
- "A1": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") or bool(state.hub_model_ids.get("A1"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") else state.hub_model_ids.get("A1", "")},
1069
- "A2": {"available": (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") or bool(state.hub_model_ids.get("A2"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") else state.hub_model_ids.get("A2", "")},
1070
- "B1": {"available": cuda_ready, "artifact": state.llava_model_id},
1071
- "B2": {"available": cuda_ready and (b2_checkpoint is not None or bool(state.hub_model_ids.get("B2"))), "artifact": str(b2_checkpoint) if b2_checkpoint else state.hub_model_ids.get("B2", "")},
1072
- "DPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25") or bool(state.hub_model_ids.get("DPO"))), "artifact": "checkpoints/DPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") else state.hub_model_ids.get("DPO", "")},
1073
- "PPO": {"available": cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") or bool(state.hub_model_ids.get("PPO"))), "artifact": "checkpoints/PPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") else state.hub_model_ids.get("PPO", "")},
1074
  }
1075
 
1076
 
 
133
  self.preload_models = os.getenv("WEB_PRELOAD_MODELS", "0") == "1"
134
  # Chạy lần lượt và giải phóng model sau mỗi lượt để giảm đỉnh RAM/VRAM.
135
  self.release_after_predict = os.getenv("WEB_RELEASE_AFTER_PREDICT", "1") == "1"
136
+ raw_active_variants = os.getenv("MEDVQA_ACTIVE_VARIANTS", "B2")
137
+ self.active_variants = {
138
+ variant.strip()
139
+ for variant in raw_active_variants.split(",")
140
+ if variant.strip() in VARIANT_ORDER
141
+ } or {"B2"}
142
  self.progress_state: dict[str, Any] = {
143
  "job_id": "",
144
  "active": False,
 
513
  def _llava_adapter_specs() -> list[tuple[str, Path]]:
514
  specs: list[tuple[str, Path]] = []
515
  for variant in ("B2", "DPO", "PPO"):
516
+ if variant not in state.active_variants:
517
+ continue
518
  artifact = _resolve_variant_artifact(variant)["path"]
519
  if isinstance(artifact, Path) and artifact.exists():
520
  specs.append((variant, artifact))
 
1059
  parsed = [part.strip() for part in raw_model_names.split(",") if part.strip()]
1060
  if isinstance(parsed, str):
1061
  parsed = [parsed]
1062
+ selected = [name for name in parsed if name in VARIANT_ORDER and name in state.active_variants]
1063
  if selected:
1064
  return selected
1065
 
1066
+ if raw_model_name and raw_model_name in VARIANT_ORDER and raw_model_name in state.active_variants:
1067
  return [raw_model_name]
1068
 
1069
+ return [variant for variant in VARIANT_ORDER if variant in state.active_variants]
1070
 
1071
 
1072
  def _variant_availability() -> dict[str, dict[str, Any]]:
1073
  b2_checkpoint = _select_best_b2_checkpoint(ROOT_DIR / "checkpoints" / "B2")
1074
  cuda_ready = torch.cuda.is_available()
1075
  return {
1076
+ "A1": {"available": ("A1" in state.active_variants) and (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") or bool(state.hub_model_ids.get("A1"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A1_best.pth") else state.hub_model_ids.get("A1", "")},
1077
+ "A2": {"available": ("A2" in state.active_variants) and (_artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") or bool(state.hub_model_ids.get("A2"))), "artifact": str(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") if _artifact_exists(ROOT_DIR / "checkpoints" / "medical_vqa_A2_best.pth") else state.hub_model_ids.get("A2", "")},
1078
+ "B1": {"available": ("B1" in state.active_variants) and cuda_ready, "artifact": state.llava_model_id},
1079
+ "B2": {"available": ("B2" in state.active_variants) and cuda_ready and (b2_checkpoint is not None or bool(state.hub_model_ids.get("B2"))), "artifact": str(b2_checkpoint) if b2_checkpoint else state.hub_model_ids.get("B2", "")},
1080
+ "DPO": {"available": ("DPO" in state.active_variants) and cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") or _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "checkpoint-25") or bool(state.hub_model_ids.get("DPO"))), "artifact": "checkpoints/DPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "DPO" / "final_adapter") else state.hub_model_ids.get("DPO", "")},
1081
+ "PPO": {"available": ("PPO" in state.active_variants) and cuda_ready and (_artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") or bool(state.hub_model_ids.get("PPO"))), "artifact": "checkpoints/PPO/final_adapter" if _artifact_exists(ROOT_DIR / "checkpoints" / "PPO" / "final_adapter") else state.hub_model_ids.get("PPO", "")},
1082
  }
1083
 
1084
 
web/static/index.html CHANGED
@@ -177,7 +177,7 @@ X2 Vision
177
  <div class="flex flex-col items-center text-center max-w-4xl mx-auto mb-14">
178
  <div class="mb-4 flex items-center gap-2">
179
  <div class="h-[1px] w-12 bg-china-gold"></div>
180
- <span class="text-china-gold font-display text-sm tracking-[0.2em] uppercase">6-model comparison</span>
181
  <div class="h-[1px] w-12 bg-china-gold"></div>
182
  </div>
183
  <h1 class="text-imperial-red text-[42px] md:text-[64px] font-display font-bold leading-[1.1] tracking-tight mb-6 drop-shadow-sm">
@@ -298,7 +298,7 @@ Reset
298
  <span class="material-symbols-outlined absolute right-6 text-[28px] opacity-20 group-hover:opacity-40 transition-opacity text-gold-light">chess_knight</span>
299
  </button>
300
 
301
- <div class="text-center text-sm font-serif italic text-ink-black/60" id="status-text">Select an image, enter a question, then run all six models.</div>
302
  </div>
303
  </div>
304
  </div>
@@ -359,7 +359,7 @@ Alignment and RL variants now have equal room in the grid, making the comparison
359
  <span class="font-display font-bold text-lg tracking-wider">VQA RESEARCH</span>
360
  </div>
361
  <div class="text-[13px] text-paper-white/60 font-serif">
362
- Medical VQA web demo for six-model comparison.
363
  </div>
364
  </div>
365
  <div class="flex gap-8 text-[13px] text-paper-white/80 font-display tracking-widest uppercase">
@@ -399,9 +399,10 @@ Medical VQA web demo for six-model comparison.
399
  };
400
 
401
  let currentImageFile = null;
402
- let selectedModels = new Set(MODEL_ORDER);
403
  let questionSuggestions = [];
404
  let progressTimer = null;
 
405
 
406
  function escapeHtml(value) {
407
  return String(value ?? "")
@@ -597,10 +598,18 @@ Medical VQA web demo for six-model comparison.
597
  function updateModelChips() {
598
  document.querySelectorAll(".model-chip").forEach((chip) => {
599
  const variant = chip.dataset.model;
 
600
  const active = selectedModels.has(variant);
 
 
 
601
  chip.style.background = active ? "#A8181B" : "#fff";
602
  chip.style.color = active ? "#FDFBF7" : "#1A1A1A";
603
  chip.style.borderColor = active ? "#A8181B" : "rgba(212,175,55,0.35)";
 
 
 
 
604
  });
605
  }
606
 
@@ -626,8 +635,14 @@ Medical VQA web demo for six-model comparison.
626
  try {
627
  const res = await fetch("/v1/models");
628
  const data = await res.json();
 
 
 
 
 
 
629
  updateModelChips();
630
- setStatus("Ready. Upload an image and run all six models.");
631
  } catch (err) {
632
  setStatus(`Failed to load model metadata: ${err.message}`);
633
  }
@@ -666,17 +681,20 @@ Medical VQA web demo for six-model comparison.
666
  document.querySelectorAll(".model-chip").forEach((chip) => {
667
  chip.addEventListener("click", () => {
668
  const variant = chip.dataset.model;
 
 
 
669
  if (selectedModels.has(variant)) selectedModels.delete(variant);
670
- else selectedModels.add(variant);
671
  if (selectedModels.size === 0) {
672
- selectedModels = new Set(MODEL_ORDER);
673
  }
674
  updateModelChips();
675
  });
676
  });
677
 
678
  el.resetBtn.addEventListener("click", () => {
679
- selectedModels = new Set(MODEL_ORDER);
680
  el.question.value = "";
681
  el.imageInput.value = "";
682
  setPreview(null);
@@ -696,13 +714,13 @@ Medical VQA web demo for six-model comparison.
696
  return;
697
  }
698
  if (selectedModels.size === 0) {
699
- setStatus("Please select at least one model.");
700
  return;
701
  }
702
 
703
  el.runBtn.disabled = true;
704
  el.runBtn.querySelector("span").textContent = "Running...";
705
- setStatus("Running all selected models...");
706
  renderRunningModelGrid();
707
  applyTiltEffect(".tilt-card", 5);
708
  startProgressPolling();
@@ -730,7 +748,7 @@ Medical VQA web demo for six-model comparison.
730
 
731
  renderModelGrid(resultData?.payload?.results || []);
732
  applyTiltEffect(".tilt-card", 5);
733
- setStatus(`Done. ${resultData?.payload?.summary?.success_count ?? 0} models succeeded.`);
734
  } catch (err) {
735
  setStatus(err.message || "Prediction failed");
736
  } finally {
 
177
  <div class="flex flex-col items-center text-center max-w-4xl mx-auto mb-14">
178
  <div class="mb-4 flex items-center gap-2">
179
  <div class="h-[1px] w-12 bg-china-gold"></div>
180
+ <span class="text-china-gold font-display text-sm tracking-[0.2em] uppercase">B2-only comparison</span>
181
  <div class="h-[1px] w-12 bg-china-gold"></div>
182
  </div>
183
  <h1 class="text-imperial-red text-[42px] md:text-[64px] font-display font-bold leading-[1.1] tracking-tight mb-6 drop-shadow-sm">
 
298
  <span class="material-symbols-outlined absolute right-6 text-[28px] opacity-20 group-hover:opacity-40 transition-opacity text-gold-light">chess_knight</span>
299
  </button>
300
 
301
+ <div class="text-center text-sm font-serif italic text-ink-black/60" id="status-text">Select an image, enter a question, then run B2.</div>
302
  </div>
303
  </div>
304
  </div>
 
359
  <span class="font-display font-bold text-lg tracking-wider">VQA RESEARCH</span>
360
  </div>
361
  <div class="text-[13px] text-paper-white/60 font-serif">
362
+ Medical VQA web demo for B2-only inference.
363
  </div>
364
  </div>
365
  <div class="flex gap-8 text-[13px] text-paper-white/80 font-display tracking-widest uppercase">
 
399
  };
400
 
401
  let currentImageFile = null;
402
+ let selectedModels = new Set(["B2"]);
403
  let questionSuggestions = [];
404
  let progressTimer = null;
405
+ let modelAvailability = {};
406
 
407
  function escapeHtml(value) {
408
  return String(value ?? "")
 
598
  function updateModelChips() {
599
  document.querySelectorAll(".model-chip").forEach((chip) => {
600
  const variant = chip.dataset.model;
601
+ const available = modelAvailability[variant] !== false;
602
  const active = selectedModels.has(variant);
603
+ chip.disabled = !available;
604
+ chip.style.opacity = available ? "1" : "0.35";
605
+ chip.style.cursor = available ? "pointer" : "not-allowed";
606
  chip.style.background = active ? "#A8181B" : "#fff";
607
  chip.style.color = active ? "#FDFBF7" : "#1A1A1A";
608
  chip.style.borderColor = active ? "#A8181B" : "rgba(212,175,55,0.35)";
609
+ if (!available && !active) {
610
+ chip.style.background = "#faf7f0";
611
+ chip.style.color = "rgba(26,26,26,0.45)";
612
+ }
613
  });
614
  }
615
 
 
635
  try {
636
  const res = await fetch("/v1/models");
637
  const data = await res.json();
638
+ modelAvailability = Object.fromEntries((data.models || []).map((item) => [item.name, Boolean(item.available)]));
639
+ if (!modelAvailability.B2) {
640
+ selectedModels = new Set();
641
+ } else if (!selectedModels.has("B2")) {
642
+ selectedModels = new Set(["B2"]);
643
+ }
644
  updateModelChips();
645
+ setStatus("Ready. Upload an image and run B2.");
646
  } catch (err) {
647
  setStatus(`Failed to load model metadata: ${err.message}`);
648
  }
 
681
  document.querySelectorAll(".model-chip").forEach((chip) => {
682
  chip.addEventListener("click", () => {
683
  const variant = chip.dataset.model;
684
+ if (modelAvailability[variant] === false) {
685
+ return;
686
+ }
687
  if (selectedModels.has(variant)) selectedModels.delete(variant);
688
+ else selectedModels = new Set([variant]);
689
  if (selectedModels.size === 0) {
690
+ selectedModels = new Set(["B2"]);
691
  }
692
  updateModelChips();
693
  });
694
  });
695
 
696
  el.resetBtn.addEventListener("click", () => {
697
+ selectedModels = new Set(["B2"]);
698
  el.question.value = "";
699
  el.imageInput.value = "";
700
  setPreview(null);
 
714
  return;
715
  }
716
  if (selectedModels.size === 0) {
717
+ setStatus("Please select B2.");
718
  return;
719
  }
720
 
721
  el.runBtn.disabled = true;
722
  el.runBtn.querySelector("span").textContent = "Running...";
723
+ setStatus("Running B2...");
724
  renderRunningModelGrid();
725
  applyTiltEffect(".tilt-card", 5);
726
  startProgressPolling();
 
748
 
749
  renderModelGrid(resultData?.payload?.results || []);
750
  applyTiltEffect(".tilt-card", 5);
751
+ setStatus(`Done. B2 succeeded.`);
752
  } catch (err) {
753
  setStatus(err.message || "Prediction failed");
754
  } finally {