File size: 509 Bytes
ab80e91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21

from inference import UniversalImageClassifier
import json

def load_model(
    model_path="pytorch_model.pth",
    config_path="config.json",
    class_names_path="class_names.json",
):
    """Load the model for inference.

    Reads the class names from a JSON file and builds a
    ``UniversalImageClassifier`` from them plus the given model and config
    paths. Relative paths are resolved against the current working
    directory, exactly as in the original hard-coded version.

    Args:
        model_path: Path forwarded to the classifier as ``model_path``.
        config_path: Path forwarded to the classifier as ``config_path``.
        class_names_path: Path of the JSON file whose decoded content is
            forwarded as ``class_names`` (presumably a list of labels —
            confirm against ``UniversalImageClassifier``).

    Returns:
        The constructed ``UniversalImageClassifier`` instance.

    Raises:
        FileNotFoundError: If ``class_names_path`` does not exist.
        json.JSONDecodeError: If ``class_names_path`` is not valid JSON.
    """
    # JSON text is UTF-8 by specification; pin the encoding so non-ASCII
    # class names decode the same way regardless of the platform locale.
    with open(class_names_path, "r", encoding="utf-8") as f:
        class_names = json.load(f)

    return UniversalImageClassifier(
        model_path=model_path,
        config_path=config_path,
        class_names=class_names,
    )

# Classifier shared by all predict() calls, created on first use. Without this,
# every prediction re-read class_names.json and rebuilt the classifier.
_classifier = None


def predict(image_path):
    """Predict the class of the image at ``image_path``.

    The classifier is built with ``load_model()`` on the first call and reused
    for every later call in this process.

    Args:
        image_path: Image location, forwarded unchanged to the classifier's
            ``predict`` method.

    Returns:
        Whatever ``UniversalImageClassifier.predict`` returns for the image.
    """
    global _classifier
    if _classifier is None:
        # Lazy load so importing this module stays cheap and side-effect free.
        _classifier = load_model()
    return _classifier.predict(image_path)