# FruitQuality / app.py
# Last commit 33d6349 by lopesdri: "Update app.py" (925 bytes at that revision).
import gradio as gr
import torch
from torchvision.transforms import ToTensor
from torchvision.models import resnet50
from PIL import Image
import torch.nn as nn
# Load the trained PyTorch model once at start-up.
# Output labels, listed in the index order of the model's final layer
# ('bom' / 'ruim' are Portuguese for "good" / "bad").
classes = ['bom', 'ruim']
# ResNet-50 architecture without ImageNet weights: the fine-tuned weights are
# loaded from model.pth below. `weights=None` replaces the deprecated
# `pretrained=False` argument.
model = resnet50(weights=None)
# Replace the classifier head so its output size matches the label list and
# the checkpoint's `fc` parameters.
model.fc = nn.Linear(model.fc.in_features, len(classes))
# map_location lets a checkpoint saved on GPU load on a CPU-only host.
model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')))
# Switch to inference mode once (BatchNorm uses running stats, dropout off).
model.eval()
# Define the function for image classification
def classify_image(image):
    """Classify an uploaded image with the loaded ``model``.

    Args:
        image: The uploaded picture, either a PIL image or a uint8 array of
            shape ``HxW`` / ``HxWxC`` (``gr.Image()`` delivers a numpy array
            by default). ``None`` when the user submitted without an image.

    Returns:
        The predicted label, one of the strings in ``classes``.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable
            message instead of a ``TypeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Normalise to a 3-channel PIL image so grayscale or RGBA inputs don't
    # break the 3-channel ResNet stem. For RGB uint8 input this is a no-op
    # and yields exactly the same tensor as before.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")
    # ToTensor scales to [0, 1] and reorders to CHW; unsqueeze adds the batch dim.
    image_tensor = ToTensor()(image).unsqueeze(0)
    # Ensure inference mode even if the model was left in training mode;
    # no_grad skips autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor)
    predicted = int(logits.argmax(dim=1).item())
    return classes[predicted]
# Define the Gradio interface
# Image upload input; classify_image receives it as a numpy array by default.
inputs = gr.Image()
# Display only the single predicted label.
outputs = gr.Label(num_top_classes=1)
interface = gr.Interface(fn=classify_image, inputs=inputs, outputs=outputs)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or other tooling) does not start a server; debug=True blocks the
# main thread.
if __name__ == "__main__":
    interface.launch(debug=True)