import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F


def main():
    """Load the cucumber-disease ResNet-18 and serve it through a Gradio UI.

    Reads the weights from ``cucumber_resnet18_last_model.pth`` in the
    current working directory, builds a Blocks page with an image input, a
    label output, a submit button and one example image per class, then
    starts the Gradio server via ``demo.launch()``.
    """
    # Path to the trained weights (a state_dict for torchvision's resnet18).
    model_path = 'cucumber_resnet18_last_model.pth'

    # Class labels; the index in this list is used as the model output index.
    class_name = ["健全", "うどんこ病", "灰色かび病", "炭疽病", "べと病", "褐斑病",
                  "つる枯病", "斑点細菌病", "CCYV", "モザイク病", "MYSV"]

    # One example image per class, as (file path, label shown in the UI).
    examples = [
        ('image/healthy.jpg', '健全'),
        ('image/powdery.JPG', 'うどんこ病'),
        ('image/graymold.JPG', '灰色かび病'),
        ('image/anthracnose.JPG', '炭疽病'),
        ('image/downy.JPG', 'べと病'),
        ('image/cornespora.JPG', '褐斑病'),
        ('image/gummy.JPG', 'つる枯病'),
        ('image/bacterial.JPG', '斑点細菌病'),
        ('image/ccyv.jpg', 'CCYV'),
        ('image/mosaic.jpg', 'モザイク病'),
        ('image/mysv.jpg', 'MYSV'),
    ]

    # Model definition: a ResNet-18 with one output per class, on GPU if any.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = resnet18(num_classes=len(class_name), pretrained=False)
    # map_location sends the stored tensors to the chosen device, so the same
    # call works on both CPU-only and CUDA hosts.
    model_ft.load_state_dict(torch.load(model_path, map_location=device))
    model_ft = model_ft.to(device)
    model_ft.eval()

    @torch.no_grad()
    def inference(gr_input):
        """Classify one image and return a ``{label: probability}`` dict.

        Args:
            gr_input: Image array handed over by the Gradio image input, or
                ``None`` when the button is pressed without an image.
                (Converted to uint8 and interpreted as RGB below.)

        Returns:
            A dict mapping every entry of ``class_name`` to its softmax
            probability as a Python float, or ``None`` if no image was given.
        """
        # Nothing to classify; return None instead of raising on .astype().
        if gr_input is None:
            return None

        img = Image.fromarray(gr_input.astype("uint8"), "RGB")

        # Preprocessing: resize to 224x224, convert to a tensor, add a batch
        # dimension, and move to the same device as the model.
        img = F.resize(img, (224, 224))
        batch = F.to_tensor(img).unsqueeze(0).to(device)

        # Inference: softmax over the class dimension of the single sample.
        output = model_ft(batch).squeeze(0)
        # Bring the result back to the CPU before converting to Python floats.
        probs = nn.functional.softmax(output, dim=0).cpu().tolist()

        # Return the probability of each label as a dict.
        return dict(zip(class_name, probs))

    with gr.Blocks(title="Cucumber Diseases Diagnosis",
                   css=".gradio-container {background:white;}"
                   ) as demo:
        gr.HTML("""
Cucumber Diseases Diagnosis
""")
        with gr.Row():
            input_image = gr.inputs.Image()
            output_label = gr.outputs.Label(num_top_classes=4)

        send_btn = gr.Button("識別")
        send_btn.click(fn=inference, inputs=input_image, outputs=output_label)

        with gr.Row():
            # One single-image example widget per class, feeding the input.
            for example_path, example_label in examples:
                gr.Examples([example_path], label=example_label,
                            inputs=input_image)

    demo.launch()


if __name__ == "__main__":
    main()