"""Gradio demo app for weed classification with an EfficientNetV2-S classifier.

The Gradio Blocks UI is built at import time. When run as a script, the module
loads ``config.yaml``, restores the Lightning checkpoint onto the CPU, defines
the preprocessing pipeline and launches the app.
"""
from typing import Any
import pytorch_lightning as pl
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import torch
from torch import nn
from torchvision import transforms
import yaml
from yaml.loader import SafeLoader
import gradio as gr
import os


class WeedModel(pl.LightningModule):
    """LightningModule wrapping a torchvision EfficientNetV2-S classifier.

    Args:
        params: configuration mapping. Reads ``"model"`` (only
            ``"efficientnet"`` is supported, case-insensitive),
            ``"pretrained"`` (load ImageNet-1k V1 weights when truthy) and
            ``"n_class"`` (output size of the replaced final linear layer).

    Raises:
        ValueError: if ``params["model"]`` names an unsupported architecture.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        model_name = self.params["model"]
        if model_name.lower() != "efficientnet":
            # Previously this only printed a warning and left ``base_model``
            # undefined, so the error surfaced later and far from its cause.
            raise ValueError(
                f"Unsupported model {model_name!r}: only 'efficientnet' is implemented"
            )
        weights = (
            EfficientNet_V2_S_Weights.IMAGENET1K_V1 if self.params["pretrained"] else None
        )
        self.base_model = efficientnet_v2_s(weights=weights)
        # Swap the ImageNet head for one sized to this task's classes.
        num_ftrs = self.base_model.classifier[-1].in_features
        self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])

    def forward(self, x):
        """Return the raw (pre-softmax) class scores for input batch ``x``."""
        return self.base_model(x)

    def predict_step(
        self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Return per-sample softmax probabilities as nested Python lists.

        ``batch`` is passed straight to ``forward``; softmax is taken over the
        last dimension.
        """
        logits = self(batch)
        return torch.softmax(logits, dim=-1).tolist()


def predict(image):
    """Classify a PIL image and return ``{class_name: probability}`` for gr.Label.

    Relies on the module-level ``model``, ``transform`` and ``class_names``
    globals; ``model`` and ``transform`` are only defined when this file is
    run as ``__main__``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Clicking "predict" with an empty input would otherwise fail inside
        # the transform with an unhelpful traceback.
        raise gr.Error("Please upload an image first.")
    # Uploads may be RGBA / greyscale / palette images; the network needs 3 channels.
    tensor_image = transform(image.convert("RGB"))
    # Pure inference: skip autograd graph construction.
    with torch.inference_mode():
        outs = model.predict_step(tensor_image.unsqueeze(0))
    # NOTE(review): the last model output is excluded from the displayed
    # labels. Confirm whether this is intentional (e.g. a hidden
    # "negative"/background class) and that class_names.txt has at least
    # n_class - 1 entries.
    return {class_names[k]: float(v) for k, v in enumerate(outs[0][:-1])}


title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"
# Sorted so the example gallery order does not depend on the filesystem.
example_list = [
    [os.path.join("examples", example)] for example in sorted(os.listdir("examples"))
]

with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()

with gr.Blocks(title=title) as demo:
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(
                        type="pil", label="input image", sources=["upload", "webcam"]
                    )
                with gr.Column():
                    label_conv = gr.Label(label="Predictions", num_top_classes=4)
            btn = gr.Button(value="predict")
            btn.click(predict, inputs=im, outputs=[label_conv])
            gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])

if __name__ == "__main__":
    with open("config.yaml", encoding="utf-8") as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
    print(PARAMS)
    model = WeedModel.load_from_checkpoint(
        "model/epoch=08.ckpt", params=PARAMS, map_location=torch.device("cpu")
    )
    model.eval()
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # NOTE(review): ImageNet normalisation is disabled; confirm this
            # matches the preprocessing used when the checkpoint was trained.
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    demo.launch()