plant-id-3 / app.py
for876543's picture
Create new file
5e0099e
raw
history blame
1.41 kB
import timm
import torch
import torch.nn.functional as nnf
import gradio as gr
import numpy as np
import json
class GELU(torch.nn.Module):
    """Stand-in for ``torch.nn.GELU`` used while unpickling the model.

    The class is installed over ``torch.nn.modules.activation.GELU`` right
    after its definition, so GELU layers inside a pickled model resolve to this
    class instead of the stock one. Presumably this works around a torch
    version mismatch (e.g. a pickle lacking the ``approximate`` attribute that
    newer torch releases read in ``forward``) — TODO confirm.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply GELU element-wise to *input*.

        If the (unpickled) layer carries an ``approximate`` attribute, it is
        forwarded so that e.g. a layer saved as ``GELU(approximate="tanh")``
        keeps its tanh approximation. Layers without the attribute use the
        exact (erf-based) GELU, exactly as before.
        """
        approximate = getattr(self, "approximate", None)
        if approximate is None:
            # Old pickles have no such attribute: plain exact GELU.
            return torch.nn.functional.gelu(input)
        return torch.nn.functional.gelu(input, approximate=approximate)


# Swap torch's GELU class for the one above so that a later torch.load of the
# pickled model instantiates GELU layers from this class.
torch.nn.modules.activation.GELU = GELU
# Load the whole pickled model object (it is called directly as model(...) at
# inference time) onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects, so only point it at a
# trusted file. torch >= 2.6 defaults to weights_only=True, which rejects a
# whole-model pickle; pass weights_only=False there if loading starts failing.
model = torch.load("/home/user/app/run45.pkl", map_location=torch.device('cpu'))
# Put the model in inference mode. Any dropout layers are disabled, and any
# batch-norm layers use their stored running statistics rather than statistics
# of each single-image request, even if the model was pickled mid-training.
model.eval()
# Build the sorted, de-duplicated list of class names from the category
# metadata, keeping only categories whose supercategory is "Plants".
# NOTE(review): position i in this list is treated as model output index i by
# the classifier; confirm this matches the label order used in training.
# Explicit UTF-8 so decoding does not depend on the host locale.
with open('/home/user/app/val.json', 'r', encoding='utf-8') as handle:
    parsed = json.load(handle)
classes = sorted({
    category['name']
    for category in parsed["categories"]
    if category['supercategory'] == 'Plants'
})
# Alias of `classes`, kept for any code that refers to it by this name.
labels = classes
def classify_image(inp):
    """Classify one image and return per-class confidences.

    Args:
        inp: image array from the Gradio ``Image`` input, expected to be
            channels-last ``(H, W, 3)``. Presumably 224 x 224 RGB given the
            input component's settings — TODO confirm for the installed
            Gradio version.

    Returns:
        dict mapping ``classes[i]`` to the softmax probability (a Python
        float) of model output ``i``, the format ``gr.Label`` consumes.

    Raises:
        ValueError: if *inp* is not a 3-channel, channels-last image.
    """
    pixels = inp.astype(np.uint8)
    if pixels.ndim != 3 or pixels.shape[2] != 3:
        raise ValueError(
            f"expected an (H, W, 3) RGB image, got shape {pixels.shape}"
        )
    # The image is channels-last (H, W, C) but the model takes (N, C, H, W).
    # A plain reshape would reinterpret the interleaved RGB bytes as planar
    # channels and scramble the image, so move the channel axis explicitly.
    pixels = np.ascontiguousarray(pixels.transpose(2, 0, 1))[np.newaxis]
    # NOTE(review): pixels are fed as raw 0-255 floats with no mean/std
    # normalization; confirm this matches the preprocessing used in training.
    batch = torch.from_numpy(pixels).float()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(batch)
    probs = nnf.softmax(logits[0], dim=0).cpu().tolist()
    return {classes[i]: p for i, p in enumerate(probs)}
# Web UI: an image input (resized via shape=(224, 224)) feeding classify_image,
# showing the three highest-confidence classes.
# NOTE(review): the `shape` keyword of gr.Image exists in Gradio 3.x but was
# removed in Gradio 4; pin gradio<4 or resize inside classify_image instead.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs=gr.Label(num_top_classes=3),
)

# Start the blocking server only when run as a script, so importing the module
# (e.g. for tests, or Gradio's reload mode, which looks up `demo`) does not
# launch it.
if __name__ == "__main__":
    demo.launch(debug=True)