File size: 1,289 Bytes
218c86b
 
 
 
ba24c6a
218c86b
 
 
 
 
ba24c6a
218c86b
 
 
 
 
 
 
 
 
 
 
 
6373c5a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# models/registry.py
from typing import Callable, Dict
from models.figure2_cnn import Figure2CNN
from models.resnet_cnn import ResNet1D
from models.resnet18_vision import ResNet18Vision 

# Private lookup table: short model name -> factory taking the input length
# and returning a freshly constructed model instance.
_REGISTRY: Dict[str, Callable[[int], object]] = {
    "figure2": lambda length: Figure2CNN(input_length=length),
    "resnet": lambda length: ResNet1D(input_length=length),
    "resnet18vision": lambda length: ResNet18Vision(input_length=length),
}

def choices():
    """List every model key present in the registry, in insertion order."""
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return list(_REGISTRY)

def build(name: str, input_length: int):
    """Construct the model registered under ``name``.

    Args:
        name: Short registry key identifying the model.
        input_length: Value forwarded to the model factory.

    Returns:
        Whatever the registered factory returns for ``input_length``.

    Raises:
        ValueError: If ``name`` is not a registered model key.
    """
    factory = _REGISTRY.get(name)
    if factory is None:
        raise ValueError(f"Unknown model '{name}'. Choices: {choices()}")
    return factory(input_length)

def spec(name: str):
    """Look up the expected input length and class count for a model key.

    Args:
        name: Short model key.

    Returns:
        A new dict with the keys ``"input_length"`` and ``"num_classes"``.

    Raises:
        KeyError: If ``name`` is not one of the known model keys.
    """
    # One row per model: (key, input length, number of classes).
    known_specs = (
        ("figure2", 500, 2),
        ("resnet", 500, 2),
        ("resnet18vision", 500, 2),
    )
    for key, length, n_classes in known_specs:
        if name == key:
            # Build a fresh dict per call so callers may mutate the result.
            return {"input_length": length, "num_classes": n_classes}
    raise KeyError(f"Unknown model '{name}'")


# Public API of this module. `spec` is listed alongside `choices` and `build`
# so that star-imports expose all three public helpers.
__all__ = ["choices", "build", "spec"]