# Spaces: Runtime error
import gradio as gr | |
import requests | |
import torch | |
import torch.nn as nn | |
from PIL import Image | |
from torchvision.models import resnet18 | |
from torchvision.transforms import functional as F | |
def main():
    """Build and launch the Gradio demo for cucumber disease classification.

    Loads ResNet-18 weights from ``cucumber_resnet18_last_model.pth``, wires
    an image input and a label output to the classifier inside a
    ``gr.Blocks`` layout, registers one example image per class, and starts
    the Gradio server.
    """
    # Path to the trained weights (loaded below as a state dict).
    model_path = 'cucumber_resnet18_last_model.pth'
    # Class labels shown to the user.
    # NOTE(review): assumed to be in the same index order as the model's
    # output logits -- confirm against the training script.
    class_name = ["健全", "うどんこ病", "灰色かび病", "炭疽病", "べと病", "褐斑病",
                  "つる枯病", "斑点細菌病", "CCYV", "モザイク病", "MYSV"]
    # One example image per class, listed in the same order as class_name.
    example = [
        'image/healthy.jpg',
        'image/powdery.JPG',
        'image/graymold.JPG',
        'image/anthracnose.JPG',
        'image/downy.JPG',
        'image/cornespora.JPG',
        'image/gummy.JPG',
        'image/bacterial.JPG',
        'image/ccyv.jpg',
        'image/mosaic.jpg',
        'image/mysv.jpg',
    ]

    # Model definition. No `pretrained`/`weights` argument is passed: the
    # default is "no pretrained weights" in both old and new torchvision,
    # which avoids the deprecation warning for `pretrained=False`.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_ft = resnet18(num_classes=len(class_name))
    # map_location=device covers both the CPU-only and the CUDA case.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model_ft.load_state_dict(torch.load(model_path, map_location=device))
    model_ft = model_ft.to(device)
    model_ft.eval()

    def inference(gr_input):
        """Classify one image and return a ``{label: probability}`` dict.

        ``gr_input`` is the value Gradio passes from the image component
        (used here as a NumPy array). Returns ``None`` when no image was
        supplied, so clicking the button on an empty input does not raise.
        """
        if gr_input is None:
            return None
        # Let Pillow infer the mode from the array shape, then force RGB so
        # grayscale or RGBA inputs still yield the 3 channels ResNet expects.
        img = Image.fromarray(gr_input.astype("uint8")).convert("RGB")
        # Preprocessing: resize, convert to a tensor, add the batch dimension
        # and move the batch to the device the model lives on.
        img = F.resize(img, (224, 224))
        batch = F.to_tensor(img).unsqueeze(0).to(device)
        # Inference without autograd: the output would otherwise require
        # grad and could not be converted to plain numbers via .numpy().
        with torch.no_grad():
            output = model_ft(batch).squeeze(0)
        probs = nn.functional.softmax(output, dim=0).tolist()
        # Return the probability of every label as a dict.
        return dict(zip(class_name, probs))

    with gr.Blocks(title="Cucumber Diseases Diagnosis",
                   css=".gradio-container {background:white;}"
                   ) as demo:
        gr.HTML("""<div style="font-family:'Arial', 'Serif'; font-size:18pt; text-align:center; color:black;">Cucumber Diseases Diagnosis</div>""")
        with gr.Row():
            # gr.inputs / gr.outputs are deprecated namespaces (removed in
            # Gradio 4); use the top-level components instead.
            input_image = gr.Image()
            output_label = gr.Label(num_top_classes=4)
        send_btn = gr.Button("識別")
        send_btn.click(fn=inference, inputs=input_image, outputs=output_label)
        with gr.Row():
            # One single-image example widget per class, labelled with the
            # class name.
            for path, label in zip(example, class_name):
                gr.Examples([path], label=label, inputs=input_image)

    # Start the server.
    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()