# plant-id-3 / app.py
# Page header from the hosting site: committed by for876543,
# commit 7acc8fb "Update app.py" (raw | history | blame), 1.51 kB.
import timm
import torch
import torch.nn.functional as nnf
import gradio as gr
import numpy as np
import pandas as pd
import json
# Load the classifier on the CPU. The file name (and the fact that torch.save
# cannot pickle ScriptModules) indicates a TorchScript export, so load it with
# torch.jit.load directly: torch.load merely forwards TorchScript archives to
# torch.jit.load with a warning, and newer releases reject them outright when
# weights_only=True (the default since PyTorch 2.6).
model = torch.jit.load(
    "/home/user/app/model_30c4tc4y_scripted.pkl", map_location=torch.device("cpu")
)
model.eval()

# Read the dataset annotation file; its "categories" entries carry a "name" and
# a "supercategory".
with open("/home/user/app/val.json", encoding="utf-8") as handle:
    parsed = json.load(handle)

# Label list: the unique names of all "Plants" categories, sorted alphabetically.
# classify_image maps model output index i to classes[i]; this presumably
# matches the label order used at training time — TODO confirm.
classes = sorted(
    {
        category["name"]
        for category in parsed["categories"]
        if category["supercategory"] == "Plants"
    }
)
labels = classes
def classify_image(inp, *, net=None, class_names=None):
    """Classify a plant photo and return per-class confidences.

    Args:
        inp: Image array, expected as delivered by ``gr.Image`` with its
            default numpy output: ``(height, width, 3)``, channels last.
        net: Optional model override; defaults to the module-level ``model``.
            It is called with a float tensor of shape ``(1, 3, height, width)``
            holding raw 0-255 pixel values (no normalization is applied), and
            the first row of its output is used as one score per class.
        class_names: Optional label override; defaults to the module-level
            ``classes``. Entry ``i`` names output index ``i``.

    Returns:
        Dict mapping class name to probability (the values sum to 1).

    Raises:
        IndexError: If the model emits more scores than there are class names.
    """
    if net is None:
        net = model
    if class_names is None:
        class_names = classes

    # The image arrives channels-last (HWC) but the network takes NCHW. The
    # axes must be permuted: a plain reshape to (-1, 3, H, W) would keep the
    # bytes in HWC order and scramble pixels across the three channel planes.
    pixels = torch.from_numpy(np.ascontiguousarray(inp.astype(np.uint8)))
    batch = pixels.permute(2, 0, 1).unsqueeze(0).float().contiguous()

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        scores = net(batch)[0]

    # Standardize the scores and scale by 4 before the softmax; this fixes the
    # effective temperature regardless of the raw score spread. The std floor
    # avoids 0/0 (all-NaN output) when every score is identical.
    means = scores.mean(dim=0, keepdim=True)
    stds = scores.std(dim=0, keepdim=True).clamp_min(torch.finfo(scores.dtype).eps)
    probs = nnf.softmax(4 * (scores - means) / stds, dim=0).tolist()

    # Debug summary of the probability distribution.
    print(pd.Series(probs).describe())

    return {class_names[i]: p for i, p in enumerate(probs)}
# Web UI: a 300x300 image input, showing the top-3 predicted plant classes.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(shape=(300, 300)),
    outputs=gr.Label(num_top_classes=3),
)

# Launch the server only when executed as a script (`python app.py`), so that
# importing the module (e.g. for tests or Gradio's reload mode) does not block.
if __name__ == "__main__":
    demo.launch(debug=True)