Spaces:
Build error
Build error
import timm | |
import torch | |
import torch.nn.functional as nnf | |
import gradio as gr | |
import numpy as np | |
import json | |
class GELU(torch.nn.Module):
    """Stateless exact (erf-based) GELU activation.

    Substituted for the stock ``torch.nn.GELU`` before the model pickle is
    loaded. ``forward`` calls the functional GELU with its default arguments
    and reads no attributes from ``self``, so modules restored from the pickle
    need no constructor state.

    NOTE(review): presumably this works around a pickle produced by a torch
    version whose GELU lacked the ``approximate`` attribute — confirm.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # nnf is torch.nn.functional, imported at the top of the file.
        return nnf.gelu(input)


# Pickle resolves classes by module path, so rebinding this attribute makes the
# torch.load call that follows instantiate the class above for GELU layers.
setattr(torch.nn.modules.activation, "GELU", GELU)
# Load the full pickled model onto the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted checkpoint files. Newer torch releases default to weights_only=True,
# which rejects full-model pickles like this one — confirm the pinned torch
# version if loading fails.
model = torch.load(
    "/home/user/app/trained_model_2e960a3n_model.pkl",
    map_location=torch.device("cpu"),
)
# The module is used only for inference (it is called like an nn.Module in
# classify_image). Switch dropout / batch-norm layers out of training mode,
# whatever mode the pickle was saved in.
model.eval()
# Class labels: the unique names of every category whose supercategory is
# "Plants", sorted alphabetically.
# NOTE(review): assumes this sorted order matches the index order of the
# model's output logits — confirm against the training label mapping.
with open("/home/user/app/val.json", "r", encoding="utf-8") as handle:
    parsed = json.load(handle)

classes = sorted(
    {
        category["name"]
        for category in parsed["categories"]
        if category["supercategory"] == "Plants"
    }
)
labels = classes
def classify_image(inp):
    """Classify one image and return a ``{label: probability}`` mapping.

    Args:
        inp: Image array from the ``gr.Image`` input of the Interface below.
            Gradio's numpy image layout is (height, width, channels), and that
            component is configured with ``shape=(224, 224)``.

    Returns:
        dict mapping each name in ``classes`` to its softmax probability for
        the image.

    Raises:
        ValueError: if the number of model scores differs from the number of
            class labels, since the label mapping would then be meaningless.
    """
    image = np.asarray(inp).astype(np.uint8)
    if image.ndim == 2:
        # Grayscale: replicate the single channel so the model sees 3 channels.
        image = np.stack([image] * 3, axis=-1)
    # Keep RGB (drop any alpha channel), move channels first (HWC -> CHW), and
    # add a batch axis. A plain reshape to (-1, 3, 224, 224) would NOT move the
    # channel axis. It reinterprets the interleaved RGB bytes and scrambles the
    # image.
    image = np.ascontiguousarray(image[..., :3].transpose(2, 0, 1))[np.newaxis]
    # NOTE(review): pixels are fed as raw 0-255 floats with no mean/std
    # normalization, as before — confirm this matches the training pipeline.
    batch = torch.from_numpy(image).float()

    # No autograd graph is needed for inference.
    with torch.no_grad():
        logits = model(batch)
    probs = nnf.softmax(logits[0], dim=0).cpu().tolist()

    if len(probs) != len(classes):
        raise ValueError(
            f"model produced {len(probs)} scores but "
            f"{len(classes)} class labels are defined"
        )
    return dict(zip(classes, probs))
# Web UI: one image in, the three most probable labels out.
# NOTE(review): the `shape=` argument exists on gr.Image in Gradio 3.x but was
# removed in later major versions — confirm against the pinned gradio version.
image_input = gr.Image(shape=(224, 224))
top3_output = gr.Label(num_top_classes=3)
demo = gr.Interface(fn=classify_image, inputs=image_input, outputs=top3_output)
demo.launch(debug=True)