SpringWang08 commited on
Commit
45b44ff
·
verified ·
1 Parent(s): cb6aa4c

Add SOUP model option

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -35,7 +35,7 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
35
 
36
  ROOT_DIR = Path(__file__).resolve().parent
37
  CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
38
- VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"]
39
  MODEL_DISPLAY_NAMES = {
40
  "A1": "A1 LSTM",
41
  "A2": "A2 Transformer",
@@ -43,6 +43,7 @@ MODEL_DISPLAY_NAMES = {
43
  "B2": "B2 Fine-tuned",
44
  "DPO": "DPO Alignment",
45
  "PPO": "PPO RL refinement",
 
46
  }
47
  HF_MODEL_REPOS = {
48
  "A1": "SpringWang08/medical-vqa-a1",
@@ -51,6 +52,7 @@ HF_MODEL_REPOS = {
51
  "B2": "SpringWang08/medical-vqa-b2",
52
  "DPO": "SpringWang08/medical-vqa-dpo",
53
  "PPO": "SpringWang08/medical-vqa-ppo",
 
54
  }
55
 
56
  with open(CONFIG_PATH, "r", encoding="utf-8") as f:
@@ -185,7 +187,7 @@ def _ensure_llava_bundle() -> dict[str, Any]:
185
  return llava_bundle
186
 
187
  wrapper, processor, base_model = _build_llava_base_and_processor()
188
- adapter_variants = ["B2", "DPO", "PPO"]
189
  first_variant = adapter_variants[0]
190
  model = PeftModel.from_pretrained(
191
  base_model,
 
35
 
36
  ROOT_DIR = Path(__file__).resolve().parent
37
  CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
38
+ VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO", "SOUP"]
39
  MODEL_DISPLAY_NAMES = {
40
  "A1": "A1 LSTM",
41
  "A2": "A2 Transformer",
 
43
  "B2": "B2 Fine-tuned",
44
  "DPO": "DPO Alignment",
45
  "PPO": "PPO RL refinement",
46
+ "SOUP": "SOUP Model Soup",
47
  }
48
  HF_MODEL_REPOS = {
49
  "A1": "SpringWang08/medical-vqa-a1",
 
52
  "B2": "SpringWang08/medical-vqa-b2",
53
  "DPO": "SpringWang08/medical-vqa-dpo",
54
  "PPO": "SpringWang08/medical-vqa-ppo",
55
+ "SOUP": "SpringWang08/medical-vqa-soup",
56
  }
57
 
58
  with open(CONFIG_PATH, "r", encoding="utf-8") as f:
 
187
  return llava_bundle
188
 
189
  wrapper, processor, base_model = _build_llava_base_and_processor()
190
+ adapter_variants = ["B2", "DPO", "PPO", "SOUP"]
191
  first_variant = adapter_variants[0]
192
  model = PeftModel.from_pretrained(
193
  base_model,