from typing import Any
import pytorch_lightning as pl
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import torch
from torch import nn
from torchvision import transforms
import yaml
from yaml.loader import SafeLoader
import gradio as gr
import os


class WeedModel(pl.LightningModule):
    """LightningModule wrapping an EfficientNetV2-S backbone for weed classification."""

    def __init__(self, params):
        super().__init__()
        self.params = params
        model = self.params["model"]

        if model.lower() == "efficientnet":
            if self.params["pretrained"]:
                self.base_model = efficientnet_v2_s(
                    weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1
                )
            else:
                self.base_model = efficientnet_v2_s(weights=None)
            # Replace the final classifier layer to match the number of target classes.
            num_ftrs = self.base_model.classifier[-1].in_features
            self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])
        else:
            print("Model not supported yet!")

    def forward(self, x):
        return self.base_model(x)

    def predict_step(
        self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        y_hat = self(batch)
        preds = torch.softmax(y_hat, dim=-1).tolist()
        # preds = torch.argmax(preds, dim=-1)
        return preds


def predict(image):
    """Run the classifier on a PIL image and return a class-name -> probability mapping."""
    tensor_image = transform(image)
    outs = model.predict_step(tensor_image.unsqueeze(0))
    # Map class names to predicted probabilities (the last output entry is dropped).
    labels = {class_names[k]: float(v) for k, v in enumerate(outs[0][:-1])}
    return labels


title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"

example_list = [["examples/" + example] for example in os.listdir("examples")]

with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(type="pil", label="input image", sources=["upload", "webcam"])
                with gr.Column():
                    label_conv = gr.Label(label="Predictions", num_top_classes=4)
                    btn = gr.Button(value="predict")
                    btn.click(predict, inputs=im, outputs=[label_conv])
            gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])

if __name__ == "__main__":
    with open("config.yaml") as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
        print(PARAMS)

    model = WeedModel.load_from_checkpoint(
        "model/epoch=08.ckpt", params=PARAMS, map_location=torch.device("cpu")
    )
    model.eval()

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    demo.launch()