import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F
def main():
    # Model selection: each crop maps to a checkpoint and its ordered class labels
    plant_models = {
        'cucumber': {
            'model_path': 'cucumber_resnet18_last_model.pth',
            'labels_list': ["healthy", "powdery mildew", "gray mold", "anthracnose",
                            "downy mildew", "Corynespora leaf spot", "gummy stem blight",
                            "bacterial spot", "CCYV", "mosaic disease", "MYSV"]
        },
        'eggplant': {
            'model_path': 'eggplant_resnet18_last_model.pth',
            'labels_list': ["healthy", "powdery mildew", "gray mold", "brown leaf spot",
                            "black leaf mold", "Verticillium wilt", "bacterial wilt"]
        },
        'strawberry': {
            'model_path': 'strawberry_resnet18_last_model.pth',
            'labels_list': ["healthy", "powdery mildew", "anthracnose", "Fusarium wilt"]
        },
        'tomato': {
            'model_path': 'tomato_resnet18_last_model.pth',
            'labels_list': ["healthy", "powdery mildew", "gray mold", "black leaf mold",
                            "leaf mold", "late blight", "brown ring spot", "bacterial wilt",
                            "bacterial canker", "yellow leaf curl"]
        },
    }
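    # Note: each 'labels_list' must stay in the exact order the corresponding
    # checkpoint was trained with, because inference() below maps output
    # index i straight to labels_list[i].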
    # examples_images=[
    #     ['image/231305_20200302150233_01.JPG'],
    #     ['image/0004_20181120084837_01.jpg'],
    #     ['image/160001_20170830173740_01.JPG'],
    #     ['image/152300_20190119175054_01.JPG'],
    # ]
    # Define a function that prepares the model
    def select_model(plant_name):
        model_ft = resnet18(num_classes=len(plant_models[plant_name]['labels_list']), pretrained=False)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_ft = model_ft.to(device)
        if torch.cuda.is_available():
            model_ft.load_state_dict(torch.load(plant_models[plant_name]['model_path']))
        else:
            model_ft.load_state_dict(
                torch.load(plant_models[plant_name]['model_path'], map_location=torch.device("cpu"))
            )
        model_ft.eval()
        return model_ft
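    # Sketch (an assumption, not part of the original app): select_model()
    # reloads the checkpoint from disk on every request. If that ever becomes
    # a bottleneck, a cached variant along these lines could be swapped in;
    # it is defined here for illustration only and is not used below.
    _model_cache = {}

    def select_model_cached(plant_name):
        if plant_name not in _model_cache:
            _model_cache[plant_name] = select_model(plant_name)
        return _model_cache[plant_name]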
    # Define the image-classification function
    @torch.no_grad()
    def inference(gr_input, gr_model_type):
        img = Image.fromarray(gr_input.astype("uint8"), "RGB")

        # Preprocessing
        img = F.resize(img, (224, 224))
        img = F.to_tensor(img)
        img = img.unsqueeze(0)

        # Model selection; move the input to the same device as the model
        model_ft = select_model(gr_model_type)
        img = img.to(next(model_ft.parameters()).device)

        # Inference
        output = model_ft(img).squeeze(0)
        probs = nn.functional.softmax(output, dim=0).cpu().numpy()
        labels_length = len(plant_models[gr_model_type]['labels_list'])

        # Return the probability of each label as a dict
        return {plant_models[gr_model_type]['labels_list'][i]: float(probs[i]) for i in range(labels_length)}
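    # Illustrative shape of the value inference() returns for the 'strawberry'
    # model (probabilities are made up): {"healthy": 0.91, "powdery mildew": 0.05,
    # "anthracnose": 0.03, "Fusarium wilt": 0.01}. gr.outputs.Label renders this
    # mapping as a ranked list of classes.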
    model_labels = list(plant_models.keys())

    # Take an image as the input
    inputs = gr.inputs.Image()

    # Radio buttons for choosing which crop model to use
    model_type = gr.inputs.Radio(model_labels, type='value', label='BASE MODEL')

    # Output as labels, showing the top 4 classes
    outputs = gr.outputs.Label(num_top_classes=4)

    # Launch the server
    interface = gr.Interface(fn=inference,
                             inputs=[inputs, model_type],
                             outputs=outputs,
                             title='Plant Diseases Diagnosis',
                             )
    interface.launch()
if __name__ == "__main__":
main() |
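# To try this Space locally (a sketch, assuming the four *_resnet18_last_model.pth
# checkpoints sit next to this file and that a Gradio version still providing the
# gr.inputs / gr.outputs API, i.e. 3.x or earlier, is installed):
#   pip install "gradio<4" torch torchvision Pillow
#   python app.py
# Gradio then prints a local URL where the interface can be opened.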