|
import os |
|
import yaml |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import models, transforms |
|
import gradio as gr |
|
|
|
|
|
# Load deployment config and the ordered list of class labels.
# Use context managers so file handles are closed deterministically
# (the original `open()` inside safe_load leaked the config handle).
with open("config/prod.yaml") as f:
    cfg = yaml.safe_load(f)

with open("class_names.txt") as f:
    class_names = [line.strip() for line in f]
|
|
|
|
|
|
|
def build_model(num_classes: int) -> nn.Module:
    """Return an ImageNet-pretrained EfficientNet-B2 with a fresh head.

    The classifier's final Linear layer is swapped for one producing
    ``num_classes`` logits; all other weights keep their pretrained values.
    """
    backbone = models.efficientnet_b2(
        weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1
    )
    head_in_features = backbone.classifier[1].in_features
    backbone.classifier[1] = nn.Linear(head_in_features, num_classes)
    return backbone
|
|
|
|
|
def load_model(path: str, num_classes: int) -> nn.Module:
    """Build the architecture and load fine-tuned weights from ``path``.

    The checkpoint may come from a ``torch.compile``-d model, whose state
    dict keys carry an ``_orig_mod.`` prefix; that prefix is stripped so
    the weights load into the plain (uncompiled) module.

    Returns the model on CPU, in eval mode.
    """
    model = build_model(num_classes)
    # weights_only=True restricts unpickling to tensors/primitives, so a
    # tampered checkpoint cannot execute arbitrary code on load.
    state = torch.load(path, map_location="cpu", weights_only=True)
    # removeprefix (vs. replace) only touches the start of each key, so a
    # key that merely *contains* the marker elsewhere is left intact.
    state = {k.removeprefix("_orig_mod."): v for k, v in state.items()}
    model.load_state_dict(state)
    model.eval()
    return model
|
|
|
model = load_model("output/model.pth", len(class_names)) |
|
|
|
|
|
# Inference-time transform matching the training pipeline: resize, crop to
# the configured input size, then normalize with ImageNet statistics
# (required by the pretrained EfficientNet backbone).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(cfg["estimator"]["hyperparameters"]["img-size"]),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
def predict(image):
    """Classify a PIL image; return ``{class_name: probability}``, highest first.

    Gradio's ``type="pil"`` input can deliver RGBA (transparent PNG) or
    grayscale images; ``Normalize`` above assumes exactly 3 channels, so
    convert to RGB first to avoid a shape mismatch at inference time.
    """
    batch = preprocess(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Work with a plain Python list to avoid per-item 0-dim tensor indexing.
    probs = F.softmax(logits, dim=1)[0].tolist()
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return {class_names[i]: probs[i] for i in ranked}
|
|
|
|
|
# Gather bundled sample images for the Gradio "Examples" panel; leave it
# disabled (None) when no examples/ directory ships with the app.
example_dir = "examples"
if not os.path.isdir(example_dir):
    examples = None
else:
    image_suffixes = (".png", ".jpg", ".jpeg")
    examples = []
    for filename in os.listdir(example_dir):
        if filename.lower().endswith(image_suffixes):
            examples.append([os.path.join(example_dir, filename)])
|
|
|
|
|
# Assemble the web UI around the classifier and start serving requests.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=5, label="Top Classes"),
    title="Food101 Classifier",
    examples=examples,
)
demo.launch()
|
|