devjas1 committed on
Commit
ba24c6a
·
1 Parent(s): 4b9a18f

(feat): add ResNet18Vision (1D); register; inference --arch supports it

Browse files
models/registry.py CHANGED
@@ -2,13 +2,13 @@
2
  from typing import Callable, Dict
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
- # from models.resnet18_vision import ResNet18Vision # (Step 2)
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
- # "resnet18vision": lambda L: ResNet18Vision(input_length=L)
12
  }
13
 
14
  def choices():
 
2
  from typing import Callable, Dict
3
  from models.figure2_cnn import Figure2CNN
4
  from models.resnet_cnn import ResNet1D
5
+ from models.resnet18_vision import ResNet18Vision
6
 
7
  # Internal registry of model builders keyed by short name.
8
  _REGISTRY: Dict[str, Callable[[int], object]] = {
9
  "figure2": lambda L: Figure2CNN(input_length=L),
10
  "resnet": lambda L: ResNet1D(input_length=L),
11
+ "resnet18vision": lambda L: ResNet18Vision(input_length=L)
12
  }
13
 
14
  def choices():
models/resnet18_vision.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # models/resnet18_vision.py
2
+ # 1D ResNet-18 style model for spectra: input (B, 1, L)
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Callable, List
6
+
7
class BasicBlock1D(nn.Module):
    """Two-convolution residual block for 1D feature maps.

    The main branch is conv(k=3) -> BatchNorm -> ReLU -> conv(k=3) -> BatchNorm.
    The shortcut branch is added to the main branch and a final ReLU is applied.

    Args:
        in_planes: Number of channels of the incoming tensor.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the block input to build the
            shortcut. When ``None`` the unmodified input is used, so the
            caller must make sure its shape already matches the main branch.
    """

    # Channel multiplier: code building on this block uses planes * expansion
    # as the block's output channel count.
    expansion = 1

    # The annotation is a string on purpose: the bare expression
    # `nn.Module | None` is evaluated at definition time and raises TypeError
    # on Python < 3.10.
    def __init__(self, in_planes: int, planes: int, stride: int = 1,
                 downsample: "nn.Module | None" = None):
        super().__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            # Reshape the shortcut so it can be added to the main branch.
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
27
+
28
def _make_layer(block: Callable[..., nn.Module], in_planes: int, planes: int, blocks: int, stride: int) -> nn.Sequential:
    """Stack ``blocks`` residual blocks into one ``nn.Sequential`` stage.

    Only the first block receives ``stride`` and, when needed, a projection
    shortcut; the remaining blocks are built with the block's default stride
    and no ``downsample`` module.

    Args:
        block: Block class. It is called as
            ``block(in_planes, planes, stride, downsample)`` and must expose
            an ``expansion`` attribute.
        in_planes: Channel count entering the stage.
        planes: Base channel count of the stage; the stage outputs
            ``planes * block.expansion`` channels.
        blocks: Number of blocks in the stage (must be at least 1).
        stride: Stride applied by the first block.

    Raises:
        ValueError: If ``blocks`` is smaller than 1. Previously such a value
            silently produced a one-block stage.
    """
    if blocks < 1:
        raise ValueError(f"blocks must be >= 1, got {blocks}")

    out_planes = planes * block.expansion

    downsample = None
    if stride != 1 or in_planes != out_planes:
        # The shortcut must match the main branch in length and channels,
        # so project it with a strided 1x1 convolution followed by BatchNorm.
        downsample = nn.Sequential(
            nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm1d(out_planes),
        )

    layers: List[nn.Module] = [block(in_planes, planes, stride, downsample)]
    # Every following block consumes the first block's output channels.
    layers.extend(block(out_planes, planes) for _ in range(1, blocks))
    return nn.Sequential(*layers)
40
+
41
class ResNet18Vision(nn.Module):
    """ResNet-18 style classifier built from 1D convolutions.

    Layout: a strided 7-wide convolution stem with max pooling, four stages of
    two ``BasicBlock1D`` blocks each (64, 128, 256, 512 base channels), global
    average pooling and a linear classifier.

    Args:
        input_length: Accepted but not referenced anywhere in this class; the
            adaptive average pool collapses the length axis to size 1.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, input_length: int = 500, num_classes: int = 2):
        super().__init__()

        # Stem: single input channel -> 64 feature channels.
        self.conv1 = nn.Conv1d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Four residual stages, two blocks each, registered as layer1..layer4
        # in that order.
        stage_specs = (
            (64, 64, 1),
            (64, 128, 2),
            (128, 256, 2),
            (256, 512, 2),
        )
        for index, (channels_in, channels_out, stage_stride) in enumerate(stage_specs, start=1):
            stage = _make_layer(BasicBlock1D, channels_in, channels_out, blocks=2, stride=stage_stride)
            setattr(self, f"layer{index}", stage)

        # Head: collapse the length axis, then classify.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        feature_dim = 512 * BasicBlock1D.expansion
        self.fc = nn.Linear(feature_dim, num_classes)

        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-initialise convolutions; set norm layers to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                continue
            if isinstance(module, (nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an input of shape (B, 1, L) to logits of shape (B, num_classes)."""
        features = self.conv1(x)
        features = self.bn1(features)
        features = self.relu(features)
        features = self.maxpool(features)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)    # (B, C, 1)
        flat = torch.flatten(pooled, 1)    # (B, C)
        return self.fc(flat)               # (B, num_classes)
scripts/run_inference.py CHANGED
@@ -9,9 +9,11 @@ import logging
9
 
10
  import numpy as np
11
  import torch
 
12
 
13
- from models.figure2_cnn import Figure2CNN
14
  from scripts.preprocess_dataset import resample_spectrum, label_file
 
 
15
 
16
 
17
  # =============================================
@@ -49,6 +51,8 @@ if __name__ == "__main__":
49
  parser = argparse.ArgumentParser(
50
  description="Run inference on a single Raman spectrum (.txt file)."
51
  )
 
 
52
  parser.add_argument(
53
  "--target-len", type=int, required=True,
54
  help="Target length to match model input"
@@ -96,18 +100,17 @@ if __name__ == "__main__":
96
 
97
  data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
98
  # Shape = (1, 1, target_len) — valid input for Raman inference
99
- input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
 
100
 
101
- # 2. Load Model
102
- model = Figure2CNN(
103
- input_length=args.target_len,
104
- input_channels=1
105
- )
106
  if args.model != "random":
107
- model.load_state_dict(
108
- torch.load(args.model, map_location="cpu", weights_only=True)
109
- )
110
  model.eval()
 
 
111
 
112
  # 3. Inference
113
  with torch.no_grad():
 
9
 
10
  import numpy as np
11
  import torch
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
 
14
  from scripts.preprocess_dataset import resample_spectrum, label_file
15
+ from models.registry import choices as model_choices, build as build_model
16
+
17
 
18
 
19
  # =============================================
 
51
  parser = argparse.ArgumentParser(
52
  description="Run inference on a single Raman spectrum (.txt file)."
53
  )
54
+ parser.add_argument("--arch", type=str, default="figure2", choices=model_choices(),
55
+ help="Model architecture (must match the provided weights).") # NEW
56
  parser.add_argument(
57
  "--target-len", type=int, required=True,
58
  help="Target length to match model input"
 
100
 
101
  data = resample_spectrum(x_raw, y_raw, target_len=args.target_len)
102
  # Shape = (1, 1, target_len) — valid input for Raman inference
103
+ input_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
104
+
105
 
106
+ # 2. Load Model (via shared model registry)
107
+ model = build_model(args.arch, args.target_len).to(DEVICE)
 
 
 
108
  if args.model != "random":
109
+ state = torch.load(args.model, map_location="cpu") # broad compatibility
110
+ model.load_state_dict(state)
 
111
  model.eval()
112
+
113
+
114
 
115
  # 3. Inference
116
  with torch.no_grad():