AmrYassinIsFree committed on
Commit
d6ca6d1
·
1 Parent(s): 2daebaf

add custom models

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app.py +44 -1
  3. bench.py +32 -4
  4. models.py +34 -1
.gitignore CHANGED
@@ -205,3 +205,6 @@ cython_debug/
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
 
 
 
 
205
  marimo/_static/
206
  marimo/_lsp/
207
  __marimo__/
208
+
209
+ # Embedding Bench custom models
210
+ custom_models.json
app.py CHANGED
@@ -14,9 +14,18 @@ from corpus import build_corpus
14
  from dataset_config import DATASET_PRESETS, DatasetConfig
15
  from evals.quality import evaluate_quality
16
  from evals.speed import evaluate_speed
17
- from models import REGISTRY, ModelConfig
 
 
 
 
 
 
 
18
  from wrapper import load_model
19
 
 
 
20
  # ---------------------------------------------------------------------------
21
  # Page config & custom CSS
22
  # ---------------------------------------------------------------------------
@@ -114,6 +123,40 @@ selected_models = st.sidebar.multiselect(
114
  label_visibility="collapsed",
115
  )
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  st.sidebar.markdown("**Datasets**")
118
  available_datasets = list(DATASET_PRESETS.keys())
119
  selected_datasets = st.sidebar.multiselect(
 
14
  from dataset_config import DATASET_PRESETS, DatasetConfig
15
  from evals.quality import evaluate_quality
16
  from evals.speed import evaluate_speed
17
+ from models import (
18
+ REGISTRY,
19
+ VALID_BACKENDS,
20
+ ModelConfig,
21
+ load_custom_models_from_file,
22
+ register_model,
23
+ save_custom_model_to_file,
24
+ )
25
  from wrapper import load_model
26
 
27
+ load_custom_models_from_file()
28
+
29
  # ---------------------------------------------------------------------------
30
  # Page config & custom CSS
31
  # ---------------------------------------------------------------------------
 
123
  label_visibility="collapsed",
124
  )
125
 
126
+ with st.sidebar.expander("➕ Add Custom Model"):
127
+ with st.form("add_model_form", clear_on_submit=True):
128
+ new_key = st.text_input("Registry key", placeholder="my-model")
129
+ new_name = st.text_input("Display name", placeholder="My Custom Model")
130
+ new_model_id = st.text_input("HuggingFace model ID", placeholder="org/model-name")
131
+ new_backend = st.selectbox("Backend", sorted(VALID_BACKENDS))
132
+ new_gguf_file = st.text_input(
133
+ "GGUF filename (gguf backend only)", value="", placeholder="model.gguf"
134
+ )
135
+ new_is_baseline = st.checkbox("Mark as baseline", value=False)
136
+ new_persist = st.checkbox("Save to disk", value=False,
137
+ help="Persist to custom_models.json so it loads next session")
138
+ submitted = st.form_submit_button("Add Model", use_container_width=True)
139
+ if submitted:
140
+ if not new_key or not new_name or not new_model_id:
141
+ st.sidebar.error("Key, name, and model ID are required.")
142
+ elif new_backend == "gguf" and not new_gguf_file:
143
+ st.sidebar.error("GGUF filename is required for gguf backend.")
144
+ else:
145
+ cfg = ModelConfig(
146
+ name=new_name,
147
+ model_id=new_model_id,
148
+ is_baseline=new_is_baseline,
149
+ backend=new_backend,
150
+ gguf_file=new_gguf_file or None,
151
+ )
152
+ try:
153
+ register_model(new_key, cfg)
154
+ if new_persist:
155
+ save_custom_model_to_file(new_key, cfg)
156
+ st.rerun()
157
+ except ValueError as e:
158
+ st.sidebar.error(str(e))
159
+
160
  st.sidebar.markdown("**Datasets**")
161
  available_datasets = list(DATASET_PRESETS.keys())
162
  selected_datasets = st.sidebar.multiselect(
bench.py CHANGED
@@ -5,7 +5,7 @@ import argparse
5
  from corpus import build_corpus
6
  from dataset_config import DATASET_PRESETS, DatasetConfig
7
  from evals import evaluate_memory, evaluate_quality, evaluate_speed
8
- from models import REGISTRY
9
  from report import print_report
10
  from wrapper import load_model
11
 
@@ -18,9 +18,15 @@ def main(argv: list[str] | None = None) -> None:
18
  parser.add_argument(
19
  "--models",
20
  nargs="+",
21
- default=list(REGISTRY.keys()),
22
- choices=list(REGISTRY.keys()),
23
- help="Models to benchmark (default: all)",
 
 
 
 
 
 
24
  )
25
  parser.add_argument("--corpus-size", type=int, default=1000)
26
  parser.add_argument("--batch-size", type=int, default=64)
@@ -61,6 +67,28 @@ def main(argv: list[str] | None = None) -> None:
61
 
62
  args = parser.parse_args(argv)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Build list of dataset configs
65
  if args.dataset:
66
  # Custom dataset overrides presets
 
5
  from corpus import build_corpus
6
  from dataset_config import DATASET_PRESETS, DatasetConfig
7
  from evals import evaluate_memory, evaluate_quality, evaluate_speed
8
+ from models import REGISTRY, ModelConfig, load_custom_models_from_file, register_model
9
  from report import print_report
10
  from wrapper import load_model
11
 
 
18
  parser.add_argument(
19
  "--models",
20
  nargs="+",
21
+ default=None,
22
+ help="Models to benchmark (default: all registered)",
23
+ )
24
+ parser.add_argument(
25
+ "--add-model",
26
+ action="append",
27
+ default=[],
28
+ metavar="KEY:NAME:MODEL_ID:BACKEND[:GGUF_FILE]",
29
+ help="Register a custom model. Can be repeated.",
30
  )
31
  parser.add_argument("--corpus-size", type=int, default=1000)
32
  parser.add_argument("--batch-size", type=int, default=64)
 
67
 
68
  args = parser.parse_args(argv)
69
 
70
+ # Load persisted custom models and register any --add-model entries
71
+ load_custom_models_from_file()
72
+ for spec in args.add_model:
73
+ parts = spec.split(":")
74
+ if len(parts) < 4:
75
+ parser.error(f"--add-model requires KEY:NAME:MODEL_ID:BACKEND, got: {spec}")
76
+ key, name, model_id, backend = parts[0], parts[1], parts[2], parts[3]
77
+ gguf_file = parts[4] if len(parts) > 4 else None
78
+ try:
79
+ register_model(key, ModelConfig(
80
+ name=name, model_id=model_id, backend=backend, gguf_file=gguf_file,
81
+ ))
82
+ except ValueError as e:
83
+ parser.error(str(e))
84
+
85
+ if args.models is None:
86
+ args.models = list(REGISTRY.keys())
87
+ else:
88
+ for k in args.models:
89
+ if k not in REGISTRY:
90
+ parser.error(f"Unknown model key: '{k}'. Available: {list(REGISTRY.keys())}")
91
+
92
  # Build list of dataset configs
93
  if args.dataset:
94
  # Custom dataset overrides presets
models.py CHANGED
@@ -1,6 +1,8 @@
1
  from __future__ import annotations
2
 
3
- from dataclasses import dataclass
 
 
4
 
5
 
6
  @dataclass
@@ -43,3 +45,34 @@ REGISTRY: dict[str, ModelConfig] = {
43
  # backend="libembedding",
44
  # ),
45
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import json
4
+ from dataclasses import asdict, dataclass
5
+ from pathlib import Path
6
 
7
 
8
  @dataclass
 
45
  # backend="libembedding",
46
  # ),
47
  }
48
+
49
+ VALID_BACKENDS = {"sbert", "fastembed", "libembedding", "gguf"}
50
+ CUSTOM_MODELS_PATH = Path(__file__).parent / "custom_models.json"
51
+
52
+
53
def register_model(key: str, config: ModelConfig) -> None:
    """Add a custom model to the in-memory ``REGISTRY``.

    Args:
        key: Registry key for the model; must be non-empty and not yet used.
        config: Model configuration to store under ``key``.

    Raises:
        ValueError: If ``key`` is empty or already registered, if
            ``config.backend`` is not in ``VALID_BACKENDS``, or if the
            ``gguf`` backend is selected without a ``gguf_file``.
    """
    if not key:
        raise ValueError("Model key must not be empty")
    if key in REGISTRY:
        raise ValueError(f"Model key '{key}' already exists in registry")
    if config.backend not in VALID_BACKENDS:
        # Sort the backends so the message is stable; iteration order of a
        # set of strings can differ from one interpreter run to the next.
        raise ValueError(
            f"Invalid backend '{config.backend}'. "
            f"Must be one of: {sorted(VALID_BACKENDS)}"
        )
    if config.backend == "gguf" and not config.gguf_file:
        # Enforce here the same rule the UI form applies, so every caller
        # (including the CLI) gets it.
        raise ValueError("GGUF filename is required for gguf backend.")
    REGISTRY[key] = config
59
+
60
+
61
def load_custom_models_from_file(path: Path = CUSTOM_MODELS_PATH) -> None:
    """Merge models persisted in a JSON file into ``REGISTRY``.

    The file must contain a JSON object mapping registry keys to the keyword
    arguments of ``ModelConfig``. A missing file is a no-op, and keys that
    are already present in ``REGISTRY`` are left untouched.

    Args:
        path: Location of the JSON file to read.

    Raises:
        ValueError: If the file is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass), if its top level is not a JSON
            object, or if an entry cannot be turned into a ``ModelConfig``.
    """
    # Open directly instead of checking exists() first: avoids a race with
    # the file being removed between the check and the read.
    try:
        with open(path, encoding="utf-8") as f:
            entries = json.load(f)
    except FileNotFoundError:
        return

    if not isinstance(entries, dict):
        raise ValueError(
            f"Expected a JSON object in {path}, got {type(entries).__name__}"
        )

    for key, fields in entries.items():
        if key in REGISTRY:
            continue
        try:
            REGISTRY[key] = ModelConfig(**fields)
        except TypeError as err:
            # Unknown/missing fields or a non-mapping entry: report which
            # entry is broken rather than surfacing a bare TypeError.
            raise ValueError(
                f"Invalid custom model entry '{key}' in {path}: {err}"
            ) from err
69
+
70
+
71
def save_custom_model_to_file(key: str, config: ModelConfig, path: Path = CUSTOM_MODELS_PATH) -> None:
    """Persist one model configuration into the custom-models JSON file.

    Existing entries in the file are kept; an entry already stored under
    ``key`` is overwritten.

    Args:
        key: Registry key to store the configuration under.
        config: Dataclass instance serialised with ``dataclasses.asdict``.
        path: Location of the JSON file to update.

    Raises:
        ValueError: If the existing file is not valid JSON or its top level
            is not a JSON object.
    """
    target = Path(path)

    existing: dict = {}
    try:
        with open(target, encoding="utf-8") as f:
            existing = json.load(f)
    except FileNotFoundError:
        pass  # First save: start from an empty mapping.

    if not isinstance(existing, dict):
        raise ValueError(
            f"Expected a JSON object in {target}, got {type(existing).__name__}"
        )

    existing[key] = asdict(config)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot leave a truncated/corrupt JSON file behind.
    tmp_path = target.with_name(target.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(existing, f, indent=2)
        tmp_path.replace(target)
    finally:
        # No-op after a successful replace; cleans up after a failed write.
        tmp_path.unlink(missing_ok=True)