import gradio as gr import requests import torch import torch.nn as nn from PIL import Image from torchvision.models import resnet18 from torchvision.transforms import functional as F def main(): # モデルの選択 plant_models = { 'cucumber':{ 'model_path':'cucumber_resnet18_last_model.pth', 'labels_list' : ["健全","うどんこ病","灰色かび病","炭疽病","べと病","褐斑病","つる枯病","斑点細菌病","CCYV","モザイク病","MYSV"] }, 'eggplant':{ 'model_path':'eggplant_resnet18_last_model.pth', 'labels_list' : ["健全","うどんこ病","灰色かび病","褐色円星病","すすかび病","半身萎凋病","青枯病"] }, 'strawberry':{ 'model_path':'strawberry_resnet18_last_model.pth', 'labels_list' : ["健全","うどんこ病","炭疽病","萎黄病"] }, 'tomato':{ 'model_path':'tomato_resnet18_last_model.pth', 'labels_list' : ["健全","うどんこ病","灰色かび病","すすかび病","葉かび病","疫病","褐色輪紋病","青枯病","かいよう病","黄化葉巻病"] }, } # examples_images=[ # ['image/231305_20200302150233_01.JPG'], # ['image/0004_20181120084837_01.jpg'], # ['image/160001_20170830173740_01.JPG'], # ['image/152300_20190119175054_01.JPG'], # ] # モデルの準備する関数を定義 def select_model(plant_name): model_ft = resnet18(num_classes = len(plant_models[plant_name]['labels_list']),pretrained=False) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_ft = model_ft.to(device) if torch.cuda.is_available(): model_ft.load_state_dict(torch.load(plant_models[plant_name]['model_path'])) else: model_ft.load_state_dict( torch.load(plant_models[plant_name]['model_path'], map_location=torch.device("cpu")) ) model_ft.eval() return model_ft # 画像分類を行う関数を定義 @torch.no_grad() def inference(gr_input,gr_model_type): img = Image.fromarray(gr_input.astype("uint8"), "RGB") # 前処理 img = F.resize(img, (224, 224)) img = F.to_tensor(img) img = img.unsqueeze(0) # モデル選択 model_ft = select_model(gr_model_type) # 推論 output = model_ft(img).squeeze(0) probs = nn.functional.softmax(output, dim=0).numpy() labels_lenght =len(plant_models[gr_model_type]['labels_list']) # ラベルごとの確率をdictとして返す return {plant_models[gr_model_type]['labels_list'][i]: float(probs[i]) for i in range(labels_lenght)} model_labels = list(plant_models.keys()) # 入力の形式を画像とする inputs = gr.inputs.Image() # モデルの種類を選択する model_type = gr.inputs.Radio(model_labels, type='value', label='BASE MODEL') # 出力はラベル形式で,top4まで表示する outputs = gr.outputs.Label(num_top_classes=4) # サーバーの立ち上げ interface = gr.Interface(fn=inference, inputs=[inputs, model_type], outputs=outputs, title='Plant Diseases Diagnosis', ) interface.launch() if __name__ == "__main__": main()