# celery22's picture
# Update app.py
# a45d25e
import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F
def _load_model(model_path, num_classes, device):
    """Build a ResNet-18 classifier and load its trained weights.

    Args:
        model_path: Path to a ``state_dict`` checkpoint saved with ``torch.save``.
        num_classes: Number of output logits; must match the checkpoint, or
            ``load_state_dict`` raises.
        device: ``torch.device`` the model should live on.

    Returns:
        The model on ``device``, switched to evaluation mode.
    """
    # No ``pretrained``/``weights`` argument: the default is an untrained
    # network, and the deprecated ``pretrained=False`` only triggers warnings
    # on torchvision >= 0.13.
    model = resnet18(num_classes=num_classes)
    # ``map_location`` remaps tensors saved on any device onto the device that
    # is actually available here (GPU checkpoint on a CPU-only host, etc.).
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def _make_inference(model, class_names, device):
    """Return a Gradio callback that classifies a single image with ``model``.

    Args:
        model: Classifier whose i-th output logit corresponds to ``class_names[i]``.
        class_names: Display names, in the model's output order.
        device: Device the model lives on; inputs are moved there before the forward pass.
    """

    # no_grad: inference only, so skip autograd bookkeeping.
    @torch.no_grad()
    def inference(gr_input):
        """Classify one image and return ``{class name: probability}``.

        ``gr_input`` is the value produced by ``gr.Image()`` (presumably an
        H x W x C array, Gradio's default ``type="numpy"``). ``None``, which
        Gradio passes when no image has been uploaded, yields an empty result
        instead of an exception.
        """
        if gr_input is None:
            return {}
        # Let Pillow infer the mode, then normalize grayscale/RGBA input to
        # the 3 channels the network expects.
        img = Image.fromarray(gr_input.astype("uint8")).convert("RGB")
        # Preprocessing: resize to 224x224 and scale pixels to [0, 1] (no
        # mean/std normalization is applied here).
        img = F.resize(img, (224, 224))
        batch = F.to_tensor(img).unsqueeze(0).to(device)
        # Forward pass; the input must be on the same device as the model.
        logits = model(batch).squeeze(0)
        # .cpu() is required before leaving torch when running on a GPU.
        probs = nn.functional.softmax(logits, dim=0).cpu().tolist()
        # Return the probability of each label as a dict.
        return dict(zip(class_names, probs))

    return inference


def _build_demo(inference, class_names, example_images):
    """Lay out the Gradio UI: image input, label output, a classify button,
    and one example gallery per class (``example_images[i]`` is shown under
    ``class_names[i]``).
    """
    with gr.Blocks(title="Cucumber Diseases Diagnosis",
                   css=".gradio-container {background:white;}"
                   ) as demo:
        gr.HTML("""<div style="font-family:'Arial', 'Serif'; font-size:18pt; text-align:center; color:black;">Cucumber Diseases Diagnosis</div>""")
        with gr.Row():
            # gr.Image / gr.Label replace the deprecated gr.inputs / gr.outputs
            # namespaces, which were removed in Gradio 4.
            input_image = gr.Image()
            output_label = gr.Label(num_top_classes=4)
        send_btn = gr.Button("識別")
        send_btn.click(fn=inference, inputs=input_image, outputs=output_label)
        with gr.Row():
            for path, label in zip(example_images, class_names):
                gr.Examples([path], label=label, inputs=input_image)
    return demo


def main():
    """Load the cucumber-disease classifier and serve it through a Gradio web UI."""
    # Trained weights (a state_dict for a ResNet-18 with len(class_name) outputs).
    model_path = 'cucumber_resnet18_last_model.pth'
    # Class names, in the order of the model's output logits.
    class_name = ["健全","うどんこ病","灰色かび病","炭疽病","べと病","褐斑病","つる枯病","斑点細菌病","CCYV","モザイク病","MYSV"]
    # One example image per class, in the same order as class_name.
    example_images = [
        'image/healthy.jpg',
        'image/powdery.JPG',
        'image/graymold.JPG',
        'image/anthracnose.JPG',
        'image/downy.JPG',
        'image/cornespora.JPG',
        'image/gummy.JPG',
        'image/bacterial.JPG',
        'image/ccyv.jpg',
        'image/mosaic.jpg',
        'image/mysv.jpg',
    ]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = _load_model(model_path, len(class_name), device)
    inference = _make_inference(model_ft, class_name, device)
    demo = _build_demo(inference, class_name, example_images)
    demo.launch()


if __name__ == "__main__":
    main()