# crop-health / app.py
# Author: adityapathakk
# Commit 4a401bd: "Create app.py" (1.97 kB)
import copy
import os
import subprocess

import gradio
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fine-tuned checkpoint for the 7-class classifier built below.
MODEL_URL = "https://www.dropbox.com/s/3us120bz5lhoh0t/model_best.pt?dl=0"
# Local filename matches what a bare `wget MODEL_URL` produces, so a
# checkpoint downloaded by earlier runs is reused instead of re-fetched.
MODEL_PATH = "./model_best.pt?dl=0"
# NOTE(review): Dropbox "?dl=0" links may serve an HTML preview page rather
# than the raw file ("?dl=1" forces a direct download) -- confirm.

if not os.path.exists(MODEL_PATH):
    # Download to a temporary name and rename only on success, so a failed or
    # interrupted download never leaves a truncated checkpoint behind.
    # Passing an argument list (no shell) also keeps the '?' in the URL from
    # being interpreted as a glob character.
    _partial_path = MODEL_PATH + ".part"
    subprocess.run(["wget", "-O", _partial_path, MODEL_URL], check=True)
    os.replace(_partial_path, MODEL_PATH)

# Build the architecture without ImageNet weights: the strict
# load_state_dict below overwrites every parameter and buffer, so fetching
# pretrained weights first was wasted work.
model = models.resnet50()
num_ftrs = model.fc.in_features
# Replace the ImageNet classification head with one output per class.
# 7 must equal len(labels) and the head size stored in the checkpoint.
model.fc = nn.Linear(num_ftrs, 7)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Inference only: BatchNorm must use its stored running statistics.
model.eval()
# Preprocessing applied to every input image before it reaches the model.
# Resize(256) scales the shorter side to 256 px (aspect ratio preserved);
# the mean/std are the standard ImageNet per-channel statistics.
transforms2 = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Class names; index i names the model's i-th output logit.
# NOTE(review): this order must match the class order used in training
# (presumably alphabetical folder order, as with ImageFolder) -- confirm.
labels = [
    "Bacterialblight",
    "Blast",
    "Brownspot",
    "Healthy",
    "Hispa",
    "LeafBlast",
    "Tungro",
]
def recognize_digit(image):
    """Classify an image and return per-class confidences for Gradio.

    The name is kept for compatibility with the interface wiring; the
    function classifies the image into one of the classes in ``labels``.

    Args:
        image: PIL image from the Gradio image input (``type="pil"``).

    Returns:
        dict mapping each name in ``labels`` to its softmax probability,
        the format consumed by the Gradio ``Label`` output.
    """
    # Inference mode: BatchNorm must use its running statistics rather than
    # per-request batch statistics (and must not update them). The call is
    # cheap and idempotent, so it is safe on every request.
    model.eval()
    # Put the input wherever the model's weights live.
    param_device = next(model.parameters()).device
    # (C, H, W) tensor -> (1, C, H, W) batch of one.
    batch = transforms2(image).unsqueeze(0).to(param_device)
    with torch.no_grad():  # no autograd graph needed for inference
        logits = model(batch).flatten()
    # Convert raw logits into probabilities, so Label shows meaningful
    # confidences that sum to 1 (the ranking is unchanged).
    probs = torch.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(labels, probs.tolist())}
# Gradio UI: a 256x256 RGB PIL image in, the three highest-scoring classes out.
im = gr.inputs.Image(shape=(256, 256), image_mode="RGB", type="pil")
top3_output = gr.outputs.Label(num_top_classes=3)

iface = gr.Interface(
    fn=recognize_digit,
    inputs=im,
    outputs=top3_output,
    live=True,  # re-run the prediction automatically when the input changes
    interpretation="default",
    capture_session=True,
)

# NOTE(review): test_launch / gradio.inputs / gradio.outputs belong to the
# legacy Gradio API -- confirm the pinned Gradio version still provides them.
iface.test_launch()
iface.launch()