Spaces:
Sleeping
Sleeping
devjas1
commited on
Commit
·
71b3dbd
1
Parent(s):
078ed21
(FEAT)[Model Registry]: Expand registry with metadata and multi-model utilities
Browse files- Added '_MODEL_SPECS' dictionary containing metadata (input length, classes, description, modalities, citation) for each model.
- Added '_FUTURE_MODELS' placeholder for planned architectures
New Functions:
'planned_models': List planned models.
'build_multiple': Instantiate multiple models for comparison.
'register_model': Dynamically add new models.
'get_model_info': Retrieve detailed model info.
'models_for_modality': List models supporting a modality.
'validate_model_list': Filter valid models from input.
- Updated 'spec' and error messages to use new metadata.
Expanded public API via '__all__'
- models/registry.py +114 -11
models/registry.py
CHANGED
@@ -1,35 +1,138 @@
|
|
1 |
# models/registry.py
|
2 |
-
from typing import Callable, Dict
|
3 |
from models.figure2_cnn import Figure2CNN
|
4 |
from models.resnet_cnn import ResNet1D
|
5 |
-
from models.resnet18_vision import ResNet18Vision
|
6 |
|
7 |
# Internal registry of model builders keyed by short name.
|
8 |
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
9 |
"figure2": lambda L: Figure2CNN(input_length=L),
|
10 |
"resnet": lambda L: ResNet1D(input_length=L),
|
11 |
-
"resnet18vision": lambda L: ResNet18Vision(input_length=L)
|
12 |
}
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def choices():
|
15 |
"""Return the list of available model keys."""
|
16 |
return list(_REGISTRY.keys())
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def build(name: str, input_length: int):
|
19 |
"""Instantiate a model by short name with the given input length."""
|
20 |
if name not in _REGISTRY:
|
21 |
raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
|
22 |
return _REGISTRY[name](input_length)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def spec(name: str):
|
25 |
"""Return expected input length and number of classes for a model key."""
|
26 |
-
if name
|
27 |
-
return
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# models/registry.py
|
2 |
+
from typing import Callable, Dict, List, Any
|
3 |
from models.figure2_cnn import Figure2CNN
|
4 |
from models.resnet_cnn import ResNet1D
|
5 |
+
from models.resnet18_vision import ResNet18Vision
|
6 |
|
7 |
# Internal registry of model builders keyed by short name.
|
8 |
_REGISTRY: Dict[str, Callable[[int], object]] = {
|
9 |
"figure2": lambda L: Figure2CNN(input_length=L),
|
10 |
"resnet": lambda L: ResNet1D(input_length=L),
|
11 |
+
"resnet18vision": lambda L: ResNet18Vision(input_length=L),
|
12 |
}
|
13 |
|
14 |
+
# Model specifications with metadata for enhanced features
|
15 |
+
_MODEL_SPECS: Dict[str, Dict[str, Any]] = {
|
16 |
+
"figure2": {
|
17 |
+
"input_length": 500,
|
18 |
+
"num_classes": 2,
|
19 |
+
"description": "Figure 2 baseline custom implemetation",
|
20 |
+
"modalities": ["raman", "ftir"],
|
21 |
+
"citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
|
22 |
+
},
|
23 |
+
"resnet": {
|
24 |
+
"input_length": 500,
|
25 |
+
"num_classes": 2,
|
26 |
+
"description": "(Residual Network) uses skip connections to train much deeper networks",
|
27 |
+
"modalities": ["raman", "ftir"],
|
28 |
+
"citation": "Custom ResNet implementation",
|
29 |
+
},
|
30 |
+
"resnet18vision": {
|
31 |
+
"input_length": 500,
|
32 |
+
"num_classes": 2,
|
33 |
+
"description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
|
34 |
+
"modalities": ["raman", "ftir"],
|
35 |
+
"citation": "ResNet18 Vision adaptation",
|
36 |
+
},
|
37 |
+
}
|
38 |
+
|
39 |
+
# Placeholder for future model expansions
|
40 |
+
_FUTURE_MODELS = {
|
41 |
+
"densenet1d": {
|
42 |
+
"description": "DenseNet1D for spectroscopy (placeholder)",
|
43 |
+
"status": "planned",
|
44 |
+
},
|
45 |
+
"ensemble_cnn": {
|
46 |
+
"description": "Ensemble of CNN variants (placeholder)",
|
47 |
+
"status": "planned",
|
48 |
+
},
|
49 |
+
}
|
50 |
+
|
51 |
+
|
52 |
def choices():
|
53 |
"""Return the list of available model keys."""
|
54 |
return list(_REGISTRY.keys())
|
55 |
|
56 |
+
|
57 |
+
def planned_models():
|
58 |
+
"""Return the list of planned future model keys."""
|
59 |
+
return list(_FUTURE_MODELS.keys())
|
60 |
+
|
61 |
+
|
62 |
def build(name: str, input_length: int):
|
63 |
"""Instantiate a model by short name with the given input length."""
|
64 |
if name not in _REGISTRY:
|
65 |
raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
|
66 |
return _REGISTRY[name](input_length)
|
67 |
|
68 |
+
|
69 |
+
def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
|
70 |
+
"""Nuild multiple models for comparison."""
|
71 |
+
models = {}
|
72 |
+
for name in names:
|
73 |
+
if name in _REGISTRY:
|
74 |
+
models[name] = build(name, input_length)
|
75 |
+
else:
|
76 |
+
raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
|
77 |
+
return models
|
78 |
+
|
79 |
+
|
80 |
+
def register_model(
|
81 |
+
name: str, builder: Callable[[int], object], spec: Dict[str, Any]
|
82 |
+
) -> None:
|
83 |
+
"""Dynamically register a new model."""
|
84 |
+
if name in _REGISTRY:
|
85 |
+
raise ValueError(f"Model '{name}' already registered.")
|
86 |
+
if not callable(builder):
|
87 |
+
raise TypeError("Builder must be a callable that accepts an integer argument.")
|
88 |
+
_REGISTRY[name] = builder
|
89 |
+
_MODEL_SPECS[name] = spec
|
90 |
+
|
91 |
+
|
92 |
def spec(name: str):
|
93 |
"""Return expected input length and number of classes for a model key."""
|
94 |
+
if name in _MODEL_SPECS:
|
95 |
+
return _MODEL_SPECS[name].copy()
|
96 |
+
raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
|
97 |
+
|
98 |
+
|
99 |
+
def get_model_info(name: str) -> Dict[str, Any]:
|
100 |
+
"""Get comprehensive model information including metadata."""
|
101 |
+
if name in _MODEL_SPECS:
|
102 |
+
return _MODEL_SPECS[name].copy()
|
103 |
+
elif name in _FUTURE_MODELS:
|
104 |
+
return _FUTURE_MODELS[name].copy()
|
105 |
+
else:
|
106 |
+
raise KeyError(f"Unknown model '{name}'")
|
107 |
+
|
108 |
+
|
109 |
+
def models_for_modality(modality: str) -> List[str]:
|
110 |
+
"""Get list of models that support a specific modality."""
|
111 |
+
compatible = []
|
112 |
+
for name, spec_info in _MODEL_SPECS.items():
|
113 |
+
if modality in spec_info.get("modalities", []):
|
114 |
+
compatible.append(name)
|
115 |
+
return compatible
|
116 |
+
|
117 |
+
|
118 |
+
def validate_model_list(names: List[str]) -> List[str]:
|
119 |
+
"""Validate and return list of available models from input list."""
|
120 |
+
available = choices()
|
121 |
+
valid_models = []
|
122 |
+
for name in names:
|
123 |
+
if name is available:
|
124 |
+
valid_models.append(name)
|
125 |
+
return valid_models
|
126 |
|
127 |
|
128 |
+
__all__ = [
|
129 |
+
"choices",
|
130 |
+
"build",
|
131 |
+
"spec",
|
132 |
+
"build_multiple",
|
133 |
+
"register_model",
|
134 |
+
"get_model_info",
|
135 |
+
"models_for_modality",
|
136 |
+
"validate_model_list",
|
137 |
+
"planned_models",
|
138 |
+
]
|