devjas1 commited on
Commit
71b3dbd
·
1 Parent(s): 078ed21

(FEAT)[Model Registry]: Expand registry with metadata and multi-model utilities

Browse files

- Added '_MODEL_SPECS' dictionary containing metadata (input length, classes, description, modalities, citation) for each model.
- Added '_FUTURE_MODELS' placeholder for planned architectures

New Functions:
'planned_models': List planned models.
'build_multiple': Instantiate multiple models for comparison.
'register_model': Dynamically add new models.
'get_model_info': Retrieve detailed model info.
'models_for_modality': List models supporting a modality.
'validate_model_list': Filter valid models from input.
- Updated 'spec' and error messages to use new metadata.
Expanded public API via '__all__'

Files changed (1) hide show
  1. models/registry.py +114 -11
models/registry.py CHANGED
@@ -1,35 +1,138 @@
1
  # models/registry.py
2
- from typing import Callable, Dict
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
- from models.resnet18_vision import ResNet18Vision
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
- "resnet18vision": lambda L: ResNet18Vision(input_length=L)
12
  }
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def choices():
15
  """Return the list of available model keys."""
16
  return list(_REGISTRY.keys())
17
 
 
 
 
 
 
 
18
  def build(name: str, input_length: int):
19
  """Instantiate a model by short name with the given input length."""
20
  if name not in _REGISTRY:
21
  raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
22
  return _REGISTRY[name](input_length)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  def spec(name: str):
25
  """Return expected input length and number of classes for a model key."""
26
- if name == "figure2":
27
- return {"input_length": 500, "num_classes": 2}
28
- if name == "resnet":
29
- return {"input_length": 500, "num_classes": 2}
30
- if name == "resnet18vision":
31
- return {"input_length": 500, "num_classes": 2}
32
- raise KeyError(f"Unknown model '{name}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
 
35
- __all__ = ["choices", "build"]
 
 
 
 
 
 
 
 
 
 
 
1
  # models/registry.py
2
+ from typing import Callable, Dict, List, Any
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
+ from models.resnet18_vision import ResNet18Vision
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
+ "resnet18vision": lambda L: ResNet18Vision(input_length=L),
12
  }
13
 
14
+ # Model specifications with metadata for enhanced features
15
+ _MODEL_SPECS: Dict[str, Dict[str, Any]] = {
16
+ "figure2": {
17
+ "input_length": 500,
18
+ "num_classes": 2,
19
+ "description": "Figure 2 baseline custom implemetation",
20
+ "modalities": ["raman", "ftir"],
21
+ "citation": "Neo et al., 2023, Resour. Conserv. Recycl., 188, 106718",
22
+ },
23
+ "resnet": {
24
+ "input_length": 500,
25
+ "num_classes": 2,
26
+ "description": "(Residual Network) uses skip connections to train much deeper networks",
27
+ "modalities": ["raman", "ftir"],
28
+ "citation": "Custom ResNet implementation",
29
+ },
30
+ "resnet18vision": {
31
+ "input_length": 500,
32
+ "num_classes": 2,
33
+ "description": "excels at image recognition tasks by using 'residual blocks' to train more efficiently",
34
+ "modalities": ["raman", "ftir"],
35
+ "citation": "ResNet18 Vision adaptation",
36
+ },
37
+ }
38
+
39
+ # Placeholder for future model expansions
40
+ _FUTURE_MODELS = {
41
+ "densenet1d": {
42
+ "description": "DenseNet1D for spectroscopy (placeholder)",
43
+ "status": "planned",
44
+ },
45
+ "ensemble_cnn": {
46
+ "description": "Ensemble of CNN variants (placeholder)",
47
+ "status": "planned",
48
+ },
49
+ }
50
+
51
+
52
  def choices():
53
  """Return the list of available model keys."""
54
  return list(_REGISTRY.keys())
55
 
56
+
57
+ def planned_models():
58
+ """Return the list of planned future model keys."""
59
+ return list(_FUTURE_MODELS.keys())
60
+
61
+
62
  def build(name: str, input_length: int):
63
  """Instantiate a model by short name with the given input length."""
64
  if name not in _REGISTRY:
65
  raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
66
  return _REGISTRY[name](input_length)
67
 
68
+
69
+ def build_multiple(names: List[str], input_length: int) -> Dict[str, Any]:
70
+ """Nuild multiple models for comparison."""
71
+ models = {}
72
+ for name in names:
73
+ if name in _REGISTRY:
74
+ models[name] = build(name, input_length)
75
+ else:
76
+ raise ValueError(f"Unknown model '{name}'. Available: {choices()}")
77
+ return models
78
+
79
+
80
+ def register_model(
81
+ name: str, builder: Callable[[int], object], spec: Dict[str, Any]
82
+ ) -> None:
83
+ """Dynamically register a new model."""
84
+ if name in _REGISTRY:
85
+ raise ValueError(f"Model '{name}' already registered.")
86
+ if not callable(builder):
87
+ raise TypeError("Builder must be a callable that accepts an integer argument.")
88
+ _REGISTRY[name] = builder
89
+ _MODEL_SPECS[name] = spec
90
+
91
+
92
  def spec(name: str):
93
  """Return expected input length and number of classes for a model key."""
94
+ if name in _MODEL_SPECS:
95
+ return _MODEL_SPECS[name].copy()
96
+ raise KeyError(f"Unknown model '{name}'. Available: {choices()}")
97
+
98
+
99
+ def get_model_info(name: str) -> Dict[str, Any]:
100
+ """Get comprehensive model information including metadata."""
101
+ if name in _MODEL_SPECS:
102
+ return _MODEL_SPECS[name].copy()
103
+ elif name in _FUTURE_MODELS:
104
+ return _FUTURE_MODELS[name].copy()
105
+ else:
106
+ raise KeyError(f"Unknown model '{name}'")
107
+
108
+
109
+ def models_for_modality(modality: str) -> List[str]:
110
+ """Get list of models that support a specific modality."""
111
+ compatible = []
112
+ for name, spec_info in _MODEL_SPECS.items():
113
+ if modality in spec_info.get("modalities", []):
114
+ compatible.append(name)
115
+ return compatible
116
+
117
+
118
+ def validate_model_list(names: List[str]) -> List[str]:
119
+ """Validate and return list of available models from input list."""
120
+ available = choices()
121
+ valid_models = []
122
+ for name in names:
123
+ if name is available:
124
+ valid_models.append(name)
125
+ return valid_models
126
 
127
 
128
+ __all__ = [
129
+ "choices",
130
+ "build",
131
+ "spec",
132
+ "build_multiple",
133
+ "register_model",
134
+ "get_model_info",
135
+ "models_for_modality",
136
+ "validate_model_list",
137
+ "planned_models",
138
+ ]