|
|
import json

import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import models, transforms
|
|
|
|
|
class SmallCNN(nn.Module):
    """Compact three-stage convolutional classifier for 3-channel images.

    Each stage is a 3x3 convolution + ReLU. The first two stages halve the
    spatial size with 2x2 max-pooling. The last collapses it with
    ``AdaptiveAvgPool2d(1)``, so the head always sees a ``4 * width`` vector
    (for inputs of at least 4x4 pixels).

    Defined at module level rather than inside ``build_model`` so that
    instances can be pickled / saved whole with ``torch.save``. The layer
    layout (and therefore the state-dict keys) is unchanged.
    """

    def __init__(self, num_classes=2, dropout=0.2, width=32):
        super().__init__()
        c = width
        self.features = nn.Sequential(
            nn.Conv2d(3, c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(c, 2 * c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(2 * c, 4 * c, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(nn.Flatten(), nn.Dropout(dropout), nn.Linear(4 * c, num_classes))

    def forward(self, x):
        return self.head(self.features(x))


def build_model(arch, dropout, width, freeze_backbone, num_classes=2):
    """Construct an untrained classifier of the requested architecture.

    Args:
        arch: One of ``"smallcnn"``, ``"resnet18"`` or ``"mobilenet_v3_small"``.
        dropout: Dropout probability placed before the final linear layer for
            ``"smallcnn"`` and ``"resnet18"``. Ignored for
            ``"mobilenet_v3_small"``: only that model's last classifier layer
            is replaced, and any dropout torchvision puts earlier in the
            classifier is left as-is.
        width: Base channel count for ``"smallcnn"``; ignored by the
            torchvision architectures.
        freeze_backbone: Accepted for signature/config compatibility but
            currently ignored. No parameters are frozen here.
        num_classes: Number of output logits.

    Returns:
        A freshly initialised ``nn.Module``. The torchvision models are built
        with ``weights=None``, so no pretrained weights are loaded; callers are
        expected to load a state dict.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    if arch == "smallcnn":
        return SmallCNN(num_classes=num_classes, dropout=dropout, width=width)

    if arch == "resnet18":
        m = models.resnet18(weights=None)
        # Keep the Sequential(Dropout, Linear) layout: strict state-dict loading
        # expects the head's parameters under ``fc.1.weight`` / ``fc.1.bias``.
        m.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(m.fc.in_features, num_classes))
        return m

    if arch == "mobilenet_v3_small":
        m = models.mobilenet_v3_small(weights=None)
        # Swap only the final Linear so the output size matches num_classes.
        m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, num_classes)
        return m

    raise ValueError(
        f"Unknown arch {arch!r}; expected one of 'smallcnn', 'resnet18', 'mobilenet_v3_small'"
    )
|
|
|
|
|
def load_model(model_path="model_state.pt", config_path="config.json", device="cpu"):
    """Rebuild a trained classifier and its eval-time preprocessing from disk.

    Args:
        model_path: Path to a state dict written with ``torch.save``.
        config_path: JSON config; must contain the keys ``arch``, ``dropout``,
            ``width``, ``freeze_backbone``, ``num_classes``, ``img_size``,
            ``mean`` and ``std``.
        device: Device the weights are mapped onto and the model is moved to.

    Returns:
        ``(model, tfm, cfg)``. ``model`` is in eval mode on ``device``. ``tfm``
        resizes, center-crops to ``img_size``, converts to a tensor and
        normalises with the config's ``mean``/``std``. ``cfg`` is the parsed
        config dict.
    """
    # JSON is UTF-8; don't rely on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    model = build_model(cfg["arch"], cfg["dropout"], cfg["width"], cfg["freeze_backbone"], cfg["num_classes"])

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from a
    # trusted source (consider weights_only=True where the torch version supports it).
    state = torch.load(model_path, map_location=device)
    # strict=True: every key in the checkpoint must match the rebuilt architecture.
    model.load_state_dict(state, strict=True)
    model.to(device).eval()

    size = cfg["img_size"]
    tfm = transforms.Compose([
        # Resize to ~1.14x the crop size first so the center crop keeps a margin.
        transforms.Resize(int(size * 1.14)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg["mean"], std=cfg["std"]),
    ])

    return model, tfm, cfg
|
|
|
|
|
def predict_image(image_path, model, tfm, device="cpu"):
    """Classify a single image.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        model: Classifier returning 2-D logits ``(batch, classes)``; it is
            called on a batch of one. It is not switched to eval mode here.
        tfm: Transform mapping an RGB PIL image to a tensor, e.g. the one
            returned by ``load_model``.
        device: Device the input tensor is moved to; should match the model's.

    Returns:
        ``(pred, probs)``: the arg-max class index as an ``int`` and the
        softmax probabilities as a list of Python floats.
    """
    # The context manager releases the file handle; Pillow otherwise keeps it
    # open after decoding multi-frame images. convert() returns an independent
    # image, so it stays usable after the file is closed.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    x = tfm(img).unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        logits = model(x)

    probs = torch.softmax(logits, dim=1).flatten().tolist()
    pred = int(logits.argmax(dim=1).item())
    return pred, probs
|
|
|