Spaces: Runtime error
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F
def main():
    # Weight file and class labels for each plant model.
    # Label order must match the order used at training time.
    plant_models = {
        'cucumber': {
            'model_path': 'cucumber_resnet18_last_model.pth',
            'labels_list': ["健全", "うどんこ病", "灰色かび病", "炭疽病", "べと病", "褐斑病", "つる枯病", "斑点細菌病", "CCYV", "モザイク病", "MYSV"]
        },
        'eggplant': {
            'model_path': 'eggplant_resnet18_last_model.pth',
            'labels_list': ["健全", "うどんこ病", "灰色かび病", "褐色円星病", "すすかび病", "半身萎凋病", "青枯病"]
        },
        'strawberry': {
            'model_path': 'strawberry_resnet18_last_model.pth',
            'labels_list': ["健全", "うどんこ病", "炭疽病", "萎黄病"]
        },
        'tomato': {
            'model_path': 'tomato_resnet18_last_model.pth',
            'labels_list': ["健全", "うどんこ病", "灰色かび病", "すすかび病", "葉かび病", "疫病", "褐色輪紋病", "青枯病", "かいよう病", "黄化葉巻病"]
        },
    }
    # examples_images = [
    #     ['image/231305_20200302150233_01.JPG'],
    #     ['image/0004_20181120084837_01.jpg'],
    #     ['image/160001_20170830173740_01.JPG'],
    #     ['image/152300_20190119175054_01.JPG'],
    # ]
    # Define a function that prepares the model.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def select_model(plant_name):
        model_ft = resnet18(weights=None, num_classes=len(plant_models[plant_name]['labels_list']))
        # map_location lets weights that were saved on a GPU load on a CPU-only machine.
        state_dict = torch.load(plant_models[plant_name]['model_path'], map_location=device)
        model_ft.load_state_dict(state_dict)
        model_ft = model_ft.to(device)
        model_ft.eval()
        return model_ft
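    # Note: select_model reloads the weights from disk on every request. A
    # simple cache, sketched here with a hypothetical module-level dict, e.g.
    #     _model_cache = {}
    #     if plant_name not in _model_cache:
    #         _model_cache[plant_name] = load_as_above(plant_name)
    #     return _model_cache[plant_name]
    # would avoid repeated file I/O; the original per-request structure is
    # kept above.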
    # Define a function that classifies an image.
    def inference(gr_input, gr_model_type):
        img = Image.fromarray(gr_input.astype("uint8"), "RGB")
        # Preprocessing
        img = F.resize(img, (224, 224))
        img = F.to_tensor(img)
        img = img.unsqueeze(0).to(device)
        # Select the model
        model_ft = select_model(gr_model_type)
        # Inference; no_grad is needed so the output tensor can be
        # converted with .numpy() without tracking gradients.
        with torch.no_grad():
            output = model_ft(img).squeeze(0)
        probs = nn.functional.softmax(output, dim=0).cpu().numpy()
        labels_list = plant_models[gr_model_type]['labels_list']
        # Return the per-label probabilities as a dict.
        return {labels_list[i]: float(probs[i]) for i in range(len(labels_list))}
    model_labels = list(plant_models.keys())
    # Take an image as the input. The gr.inputs / gr.outputs namespaces were
    # removed in current Gradio releases; use the top-level components.
    inputs = gr.Image()
    # Radio buttons to select the model type.
    model_type = gr.Radio(model_labels, type='value', label='BASE MODEL')
    # Output as labels, showing the top 4 classes.
    outputs = gr.Label(num_top_classes=4)
    # Build and launch the app.
    interface = gr.Interface(fn=inference,
                             inputs=[inputs, model_type],
                             outputs=outputs,
                             title='Plant Diseases Diagnosis',
                             )
    interface.launch()
if __name__ == "__main__":
    main()
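For a quick sanity check outside the Gradio UI (for instance, to confirm that a failure is in the UI layer rather than in the model code), a minimal sketch like the following can be run locally. It assumes the cucumber weight file named in plant_models is present in the working directory; because inference is defined inside main(), the same preprocessing and forward pass are rebuilt at module level here.

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F

# Hypothetical local smoke test: load one model on CPU, push a random image
# through the same preprocessing as inference(), and check that the softmax
# output is a valid probability distribution.
labels = ["健全", "うどんこ病", "灰色かび病", "炭疽病", "べと病", "褐斑病",
          "つる枯病", "斑点細菌病", "CCYV", "モザイク病", "MYSV"]
model = resnet18(weights=None, num_classes=len(labels))
model.load_state_dict(
    torch.load('cucumber_resnet18_last_model.pth', map_location='cpu')
)
model.eval()

img = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))
x = F.to_tensor(F.resize(img, (224, 224))).unsqueeze(0)
with torch.no_grad():
    probs = nn.functional.softmax(model(x).squeeze(0), dim=0)
assert abs(float(probs.sum()) - 1.0) < 1e-5
print(max(zip(probs.tolist(), labels)))  # highest-probability class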