Spaces:
Runtime error
Runtime error
import gradio as gr | |
import requests | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision.models import resnet18 | |
from torchvision.transforms import functional as F | |
def main(): | |
# ใขใใซใฎ้ธๆ | |
plant_models = { | |
'cucumber':{ | |
'model_path':'cucumber_resnet18_last_model.pth', | |
'labels_list' : ["ๅฅๅ จ","ใใฉใใ็ ","็ฐ่ฒใใณ็ ","็ญ็ฝ็ ","ในใจ็ ","่คๆ็ ","ใคใๆฏ็ ","ๆ็น็ดฐ่็ ","CCYV","ใขใถใคใฏ็ ","MYSV"] | |
}, | |
'eggplant':{ | |
'model_path':'eggplant_resnet18_last_model.pth', | |
'labels_list' : ["ๅฅๅ จ","ใใฉใใ็ ","็ฐ่ฒใใณ็ ","่ค่ฒๅๆ็ ","ใใใใณ็ ","ๅ่บซ่ๅ็ ","้ๆฏ็ "] | |
}, | |
'strawberry':{ | |
'model_path':'strawberry_resnet18_last_model.pth', | |
'labels_list' : ["ๅฅๅ จ","ใใฉใใ็ ","็ญ็ฝ็ ","่้ป็ "] | |
}, | |
'tomato':{ | |
'model_path':'tomato_resnet18_last_model.pth', | |
'labels_list' : ["ๅฅๅ จ","ใใฉใใ็ ","็ฐ่ฒใใณ็ ","ใใใใณ็ ","่ใใณ็ ","็ซ็ ","่ค่ฒ่ผช็ด็ ","้ๆฏ็ ","ใใใใ็ ","้ปๅ่ๅทป็ "] | |
}, | |
} | |
# examples_images=[ | |
# ['image/231305_20200302150233_01.JPG'], | |
# ['image/0004_20181120084837_01.jpg'], | |
# ['image/160001_20170830173740_01.JPG'], | |
# ['image/152300_20190119175054_01.JPG'], | |
# ] | |
# ใขใใซใฎๆบๅใใ้ขๆฐใๅฎ็พฉ | |
def select_model(plant_name): | |
model_ft = resnet18(num_classes = len(plant_models[plant_name]['labels_list']),pretrained=False) | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
model_ft = model_ft.to(device) | |
if torch.cuda.is_available(): | |
model_ft.load_state_dict(torch.load(plant_models[plant_name]['model_path'])) | |
else: | |
model_ft.load_state_dict( | |
torch.load(plant_models[plant_name]['model_path'], map_location=torch.device("cpu")) | |
) | |
model_ft.eval() | |
return model_ft | |
# ็ปๅๅ้กใ่กใ้ขๆฐใๅฎ็พฉ | |
def inference(gr_input,gr_model_type): | |
img = Image.fromarray(gr_input.astype("uint8"), "RGB") | |
# ๅๅฆ็ | |
img = F.resize(img, (224, 224)) | |
img = F.to_tensor(img) | |
img = img.unsqueeze(0) | |
# ใขใใซ้ธๆ | |
model_ft = select_model(gr_model_type) | |
# ๆจ่ซ | |
output = model_ft(img).squeeze(0) | |
probs = nn.functional.softmax(output, dim=0).numpy() | |
labels_lenght =len(plant_models[gr_model_type]['labels_list']) | |
# ใฉใใซใใจใฎ็ขบ็ใdictใจใใฆ่ฟใ | |
return {plant_models[gr_model_type]['labels_list'][i]: float(probs[i]) for i in range(labels_lenght)} | |
model_labels = list(plant_models.keys()) | |
# ๅ ฅๅใฎๅฝขๅผใ็ปๅใจใใ | |
inputs = gr.inputs.Image() | |
# ใขใใซใฎ็จฎ้กใ้ธๆใใ | |
model_type = gr.inputs.Radio(model_labels, type='value', label='BASE MODEL') | |
# ๅบๅใฏใฉใใซๅฝขๅผใง๏ผtop4ใพใง่กจ็คบใใ | |
outputs = gr.outputs.Label(num_top_classes=4) | |
# ใตใผใใผใฎ็ซใกไธใ | |
interface = gr.Interface(fn=inference, | |
inputs=[inputs, model_type], | |
outputs=outputs, | |
title='Plant Diseases Diagnosis', | |
) | |
interface.launch() | |
if __name__ == "__main__": | |
main() |