|
import torch |
|
from PIL import Image |
|
from .preprocess import preprocess_image |
|
from .utils import load_model |
|
|
|
|
|
def predict_with_model(model, inputs) -> int:
    """Run inference and return the predicted class index.

    Args:
        model: Callable model exposing ``eval()``. It is invoked as
            ``model(**inputs)`` and the returned object must have a
            ``logits`` attribute.
        inputs: Mapping of keyword-argument names to the values that are
            unpacked into the model call.

    Returns:
        The index of the largest logit along the last dimension, as a
        Python int.

    Note:
        ``.item()`` only converts a one-element tensor, so the logits must
        describe a single example. The model is left in eval mode after
        the call.
    """
    model.eval()  # switch to evaluation mode before the forward pass

    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = logits.argmax(dim=-1).item()

    return predicted_class
|
|
|
|
|
def predict(image_path):
    """Load an image, preprocess it, run the model, and return the prediction.

    Args:
        image_path: Location of the image to classify, passed straight to
            ``Image.open``.

    Returns:
        The predicted class returned by ``predict_with_model``.
    """
    # Open inside a context manager so the underlying file is closed
    # deterministically. ``convert`` returns a new, fully loaded image, so
    # the result stays usable after the source file is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    inputs = preprocess_image(image)

    # NOTE(review): the model is loaded on every call; confirm that
    # ``load_model`` caches if this sits on a hot path.
    model = load_model()

    # The model must expose a ``device`` attribute; every input tensor is
    # moved there so the forward pass does not mix devices.
    device = model.device
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    return predict_with_model(model, inputs)
|
|