File size: 509 Bytes
ab80e91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import functools
import json

from inference import UniversalImageClassifier
def load_model(
    *,
    model_path="pytorch_model.pth",
    config_path="config.json",
    class_names_path="class_names.json",
):
    """Build a ``UniversalImageClassifier`` from files on disk.

    ``class_names_path`` is opened as given (a relative path resolves against
    the current working directory). ``model_path`` and ``config_path`` are
    forwarded unchanged to the classifier's constructor. The defaults match
    the file names this function has always used, so ``load_model()`` with
    no arguments behaves as before.

    Args:
        model_path: Forwarded as ``model_path`` to ``UniversalImageClassifier``.
        config_path: Forwarded as ``config_path`` to ``UniversalImageClassifier``.
        class_names_path: Path to a JSON file whose decoded contents are
            passed to the classifier as ``class_names``.

    Returns:
        A newly constructed ``UniversalImageClassifier``.

    Raises:
        OSError: If ``class_names_path`` cannot be opened.
        json.JSONDecodeError: If ``class_names_path`` does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Passing the encoding
    # explicitly keeps us from depending on the platform's locale encoding
    # (e.g. cp1252 on Windows), which can garble or reject non-ASCII class names.
    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    return UniversalImageClassifier(
        model_path=model_path,
        config_path=config_path,
        class_names=class_names,
    )
@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Return the classifier built by ``load_model()``, constructing it only once.

    ``load_model()`` re-reads the class-name JSON and constructs a new
    classifier on every call, which is far too expensive to repeat for every
    prediction. If construction raises, nothing is cached and the next call
    retries. Call ``_get_classifier.cache_clear()`` to force a reload, e.g.
    after the model files on disk change.
    """
    return load_model()


def predict(image_path):
    """Predict the class of the image at *image_path*.

    The classifier is built lazily on the first call and then reused by all
    later calls, rather than being reloaded each time.

    Args:
        image_path: Passed unchanged to the classifier's ``predict`` method.

    Returns:
        Whatever ``UniversalImageClassifier.predict`` returns for this image.
    """
    return _get_classifier().predict(image_path)
|