celery22's picture
Update app.py
06de145
import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F
def main():
# ใƒขใƒ‡ใƒซใฎ้ธๆŠž
plant_models = {
'cucumber':{
'model_path':'cucumber_resnet18_last_model.pth',
'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","็‚ญ็–ฝ็—…","ในใจ็—…","่คๆ–‘็—…","ใคใ‚‹ๆžฏ็—…","ๆ–‘็‚น็ดฐ่Œ็—…","CCYV","ใƒขใ‚ถใ‚คใ‚ฏ็—…","MYSV"]
},
'eggplant':{
'model_path':'eggplant_resnet18_last_model.pth',
'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","่ค่‰ฒๅ††ๆ˜Ÿ็—…","ใ™ใ™ใ‹ใณ็—…","ๅŠ่บซ่Žๅ‡‹็—…","้’ๆžฏ็—…"]
},
'strawberry':{
'model_path':'strawberry_resnet18_last_model.pth',
'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็‚ญ็–ฝ็—…","่Ž้ป„็—…"]
},
'tomato':{
'model_path':'tomato_resnet18_last_model.pth',
'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","ใ™ใ™ใ‹ใณ็—…","่‘‰ใ‹ใณ็—…","็–ซ็—…","่ค่‰ฒ่ผช็ด‹็—…","้’ๆžฏ็—…","ใ‹ใ„ใ‚ˆใ†็—…","้ป„ๅŒ–่‘‰ๅทป็—…"]
},
}
# examples_images=[
# ['image/231305_20200302150233_01.JPG'],
# ['image/0004_20181120084837_01.jpg'],
# ['image/160001_20170830173740_01.JPG'],
# ['image/152300_20190119175054_01.JPG'],
# ]
# ใƒขใƒ‡ใƒซใฎๆบ–ๅ‚™ใ™ใ‚‹้–ขๆ•ฐใ‚’ๅฎš็พฉ
def select_model(plant_name):
model_ft = resnet18(num_classes = len(plant_models[plant_name]['labels_list']),pretrained=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft = model_ft.to(device)
if torch.cuda.is_available():
model_ft.load_state_dict(torch.load(plant_models[plant_name]['model_path']))
else:
model_ft.load_state_dict(
torch.load(plant_models[plant_name]['model_path'], map_location=torch.device("cpu"))
)
model_ft.eval()
return model_ft
# ็”ปๅƒๅˆ†้กžใ‚’่กŒใ†้–ขๆ•ฐใ‚’ๅฎš็พฉ
@torch.no_grad()
def inference(gr_input,gr_model_type):
img = Image.fromarray(gr_input.astype("uint8"), "RGB")
# ๅ‰ๅ‡ฆ็†
img = F.resize(img, (224, 224))
img = F.to_tensor(img)
img = img.unsqueeze(0)
# ใƒขใƒ‡ใƒซ้ธๆŠž
model_ft = select_model(gr_model_type)
# ๆŽจ่ซ–
output = model_ft(img).squeeze(0)
probs = nn.functional.softmax(output, dim=0).numpy()
labels_lenght =len(plant_models[gr_model_type]['labels_list'])
# ใƒฉใƒ™ใƒซใ”ใจใฎ็ขบ็Ž‡ใ‚’dictใจใ—ใฆ่ฟ”ใ™
return {plant_models[gr_model_type]['labels_list'][i]: float(probs[i]) for i in range(labels_lenght)}
model_labels = list(plant_models.keys())
# ๅ…ฅๅŠ›ใฎๅฝขๅผใ‚’็”ปๅƒใจใ™ใ‚‹
inputs = gr.inputs.Image()
# ใƒขใƒ‡ใƒซใฎ็จฎ้กžใ‚’้ธๆŠžใ™ใ‚‹
model_type = gr.inputs.Radio(model_labels, type='value', label='BASE MODEL')
# ๅ‡บๅŠ›ใฏใƒฉใƒ™ใƒซๅฝขๅผใง๏ผŒtop4ใพใง่กจ็คบใ™ใ‚‹
outputs = gr.outputs.Label(num_top_classes=4)
# ใ‚ตใƒผใƒใƒผใฎ็ซ‹ใกไธŠใ’
interface = gr.Interface(fn=inference,
inputs=[inputs, model_type],
outputs=outputs,
title='Plant Diseases Diagnosis',
)
interface.launch()
if __name__ == "__main__":
main()