# AvianVision / app.py
# Last commit 8bbed7e by Vedmani: "fixed flagging parameter"
# (Hugging Face file-page metadata: virus scan "No virus"; file size 2.15 kB.)
from models import EfficientNet
from utils import get_device
import torch
import json
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import json
import timm
from torch import nn
import torch.nn.functional as F
def load_efficientnet_model(model_path: str, device=None):
    """
    Load an EfficientNet checkpoint and return the model in evaluation mode.
    Args:
        model_path: Path to a checkpoint file saved as a dict whose
            'model_state_dict' entry holds the model weights.
        device: Device the checkpoint tensors are deserialised onto (passed to
            ``torch.load`` as ``map_location``). Defaults to ``get_device()``,
            evaluated at call time rather than once at import time.
    Returns:
        An ``EfficientNet`` instance holding the checkpoint weights, in eval
        mode. ``load_state_dict`` copies the values into the parameters that
        ``EfficientNet()`` created, and this function does not call ``.to()``,
        so callers that need the module on ``device`` must move it themselves.
    """
    # Resolve the default lazily: a call inside a default argument is evaluated
    # only once, when the module is imported (flake8-bugbear B008).
    if device is None:
        device = get_device()
    model = EfficientNet()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode: changes the behaviour of layers such as dropout and
    # batch norm, if the architecture contains them.
    model.eval()
    return model
# Lookup table from predicted class index to the class label returned by
# predict_image(). Keys are strings because JSON object keys always are.
# Read once at import time; an explicit encoding avoids depending on the
# platform's default locale encoding.
with open('idx_to_class.json', 'r', encoding='utf-8') as f:
    idx_to_class = json.load(f)
# Checkpoint served by the demo; loaded lazily and cached by _get_model().
_MODEL_PATH = 'efficientnet_epoch=18_loss=0.0020_val_f1score=0.8993.pth'

# Preprocessing pipeline, built once instead of on every request:
# resize to 150x150, convert to a float tensor in [0, 1], then normalise
# each of the three channels with mean 0.5 / std 0.5 (i.e. to [-1, 1]).
_TRANSFORM = transforms.Compose([
    transforms.Resize(size=(150, 150)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Loaded models keyed by device, so the checkpoint is read from disk only once.
_model_cache = {}


def _get_model(device):
    """Return the classifier placed on ``device``, loading it on first use."""
    model = _model_cache.get(device)
    if model is None:
        # .to() guarantees the module lives on the same device as the input
        # tensor (it is a no-op if the module is already there).
        model = load_efficientnet_model(_MODEL_PATH, device=device).to(device)
        _model_cache[device] = model
    return model


def predict_image(array):
    """
    Predict the class of an image.
    Args:
        array: The image data as an array (what the Gradio "image" input
            passes to this function).
    Returns:
        The predicted class label, looked up in ``idx_to_class``.
    """
    device = get_device()
    # Force three channels: the Normalize step expects RGB, so greyscale or
    # RGBA inputs would otherwise fail.
    input_image = Image.fromarray(array).convert('RGB')
    model = _get_model(device)
    # Add a batch dimension and move the input next to the model.
    # Tensor.to() is not in-place, so its result must be kept.
    image = _TRANSFORM(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image)
    # Softmax turns logits into probabilities; argmax picks the top class
    # for the single image in the batch.
    probabilities = F.softmax(output, dim=1)
    predicted = probabilities.argmax(dim=1).item()
    # JSON object keys are always strings, hence str().
    return idx_to_class[str(predicted)]
# Gradio UI: an image goes in, the predicted class label comes out as text;
# user flagging is disabled.
image_classifier = gr.Interface(fn=predict_image, inputs="image", outputs="text", allow_flagging='never')
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    image_classifier.launch(share=False)