gajavegs commited on
Commit
7c50ef1
·
verified ·
1 Parent(s): cd73891

Update model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +14 -21
model_loader.py CHANGED
@@ -2,7 +2,6 @@ import torch
2
  import torch.nn as nn
3
  from torchvision import models, transforms
4
  from PIL import Image
5
- import os
6
 
7
  def build_alexnet(num_classes=2):
8
  model = models.alexnet(pretrained=False)
@@ -10,26 +9,20 @@ def build_alexnet(num_classes=2):
10
  model.classifier[6] = nn.Linear(in_features, num_classes)
11
  return model
12
 
13
- def load_alexnet_model(model_path):
14
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
15
- model = build_alexnet(len(checkpoint['classes']))
16
- model.load_state_dict(checkpoint['model_state'])
 
 
 
17
  model.eval()
18
- return model, checkpoint['classes']
19
 
20
- DEFAULT_IMG_SIZE = int(os.getenv("IMG_SIZE", "224")) # changed default to match training (224x224)
21
-
22
- def preprocess_image(image: Image.Image, normalize: bool = True) -> torch.Tensor:
23
- transform_list = [
24
- transforms.Resize((DEFAULT_IMG_SIZE, DEFAULT_IMG_SIZE), Image.Resampling.LANCZOS),
25
  transforms.ToTensor(),
26
- ]
27
-
28
- # default to ImageNet stats used in training; can override with IMG_MEAN / IMG_STD env vars
29
- IMAGENET_MEAN = list(map(float, os.getenv("IMG_MEAN", "0.485,0.456,0.406").split(",")))
30
- IMAGENET_STD = list(map(float, os.getenv("IMG_STD", "0.229,0.224,0.225").split(",")))
31
-
32
- if normalize:
33
- transform_list.append(transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD))
34
- transform = transforms.Compose(transform_list)
35
- return transform(image).unsqueeze(0)
 
2
  import torch.nn as nn
3
  from torchvision import models, transforms
4
  from PIL import Image
 
5
 
6
  def build_alexnet(num_classes=2):
7
  model = models.alexnet(pretrained=False)
 
9
  model.classifier[6] = nn.Linear(in_features, num_classes)
10
  return model
11
 
12
+ def load_alexnet_model(model_path, device=None):
13
+ # Load weights on CPU first (safer with CUDA init)
14
+ checkpoint = torch.load(model_path, map_location="cpu")
15
+ model = build_alexnet(len(checkpoint["classes"]))
16
+ model.load_state_dict(checkpoint["model_state"])
17
+ if device is not None:
18
+ model.to(device)
19
  model.eval()
20
+ return model, checkpoint["classes"]
21
 
22
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
23
+ transform = transforms.Compose([
24
+ transforms.Resize((224, 224)),
 
 
25
  transforms.ToTensor(),
26
+ transforms.Normalize([0.4914,0.4822,0.4465], [0.2470,0.2435,0.2616]), # CIFAR MEAN and STD
27
+ ])
28
+ return transform(image).unsqueeze(0)