Spaces:
Paused
Paused
Add SOUP model option
Browse files
app.py
CHANGED
|
@@ -35,7 +35,7 @@ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
|
| 35 |
|
| 36 |
ROOT_DIR = Path(__file__).resolve().parent
|
| 37 |
CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
|
| 38 |
-
VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO"]
|
| 39 |
MODEL_DISPLAY_NAMES = {
|
| 40 |
"A1": "A1 LSTM",
|
| 41 |
"A2": "A2 Transformer",
|
|
@@ -43,6 +43,7 @@ MODEL_DISPLAY_NAMES = {
|
|
| 43 |
"B2": "B2 Fine-tuned",
|
| 44 |
"DPO": "DPO Alignment",
|
| 45 |
"PPO": "PPO RL refinement",
|
|
|
|
| 46 |
}
|
| 47 |
HF_MODEL_REPOS = {
|
| 48 |
"A1": "SpringWang08/medical-vqa-a1",
|
|
@@ -51,6 +52,7 @@ HF_MODEL_REPOS = {
|
|
| 51 |
"B2": "SpringWang08/medical-vqa-b2",
|
| 52 |
"DPO": "SpringWang08/medical-vqa-dpo",
|
| 53 |
"PPO": "SpringWang08/medical-vqa-ppo",
|
|
|
|
| 54 |
}
|
| 55 |
|
| 56 |
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
|
@@ -185,7 +187,7 @@ def _ensure_llava_bundle() -> dict[str, Any]:
|
|
| 185 |
return llava_bundle
|
| 186 |
|
| 187 |
wrapper, processor, base_model = _build_llava_base_and_processor()
|
| 188 |
-
adapter_variants = ["B2", "DPO", "PPO"]
|
| 189 |
first_variant = adapter_variants[0]
|
| 190 |
model = PeftModel.from_pretrained(
|
| 191 |
base_model,
|
|
|
|
| 35 |
|
| 36 |
ROOT_DIR = Path(__file__).resolve().parent
|
| 37 |
CONFIG_PATH = ROOT_DIR / "configs" / "medical_vqa.yaml"
|
| 38 |
+
VARIANT_ORDER = ["A1", "A2", "B1", "B2", "DPO", "PPO", "SOUP"]
|
| 39 |
MODEL_DISPLAY_NAMES = {
|
| 40 |
"A1": "A1 LSTM",
|
| 41 |
"A2": "A2 Transformer",
|
|
|
|
| 43 |
"B2": "B2 Fine-tuned",
|
| 44 |
"DPO": "DPO Alignment",
|
| 45 |
"PPO": "PPO RL refinement",
|
| 46 |
+
"SOUP": "SOUP Model Soup",
|
| 47 |
}
|
| 48 |
HF_MODEL_REPOS = {
|
| 49 |
"A1": "SpringWang08/medical-vqa-a1",
|
|
|
|
| 52 |
"B2": "SpringWang08/medical-vqa-b2",
|
| 53 |
"DPO": "SpringWang08/medical-vqa-dpo",
|
| 54 |
"PPO": "SpringWang08/medical-vqa-ppo",
|
| 55 |
+
"SOUP": "SpringWang08/medical-vqa-soup",
|
| 56 |
}
|
| 57 |
|
| 58 |
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
|
|
|
| 187 |
return llava_bundle
|
| 188 |
|
| 189 |
wrapper, processor, base_model = _build_llava_base_and_processor()
|
| 190 |
+
adapter_variants = ["B2", "DPO", "PPO", "SOUP"]
|
| 191 |
first_variant = adapter_variants[0]
|
| 192 |
model = PeftModel.from_pretrained(
|
| 193 |
base_model,
|