# COSMO species-classification demo (app.py)
# Hugging Face Space upload by vardaan123, revision 3dba732.
import os
import pandas as pd
import torch
import torch.nn.functional as F
import gradio as gr
from model import DistMult
from PIL import Image
from torchvision import transforms
import json
from tqdm import tqdm
# Default image tensor normalization
# Standard ImageNet per-channel mean/std (RGB order); applied in `evaluate`
# after ToTensor has scaled pixel values into [0, 1].
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
def generate_target_list(data, entity2id):
    """Collect the distinct target-entity ids for (image -> id) triples.

    Args:
        data: DataFrame with columns 'datatype_h', 'datatype_t' and 't';
            rows where datatype_h == 'image' and datatype_t == 'id' hold the
            species-entity targets in column 't' (stored as numeric strings,
            possibly with a float suffix such as "3.0").
        entity2id: dict mapping canonical entity-key strings to integer ids.

    Returns:
        torch.LongTensor of shape (num_categories, 1) containing the unique
        entity ids in first-appearance order.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    raw_targets = data.loc[mask, "t"]
    # Normalize each 't' value ("3.0" / "3" / 3.0) to the integer-string key
    # used by entity2id, and map it to its id.
    ids = (entity2id[str(int(float(item)))] for item in raw_targets)
    # dict.fromkeys dedupes in O(n) while preserving first-seen order
    # (replaces the original O(n^2) list-membership loop; the tqdm progress
    # bar was dropped along with it -- this is fast, pure data prep).
    categories = list(dict.fromkeys(ids))
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
# --- One-time startup: load vocab/dataset artifacts and model weights ---
entity2id = json.load(open('entity2id_subtree.json', 'r'))  # entity key (str) -> integer id
id2entity = {v: k for k, v in entity2id.items()}  # inverse map, used to decode predictions
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
target_list = generate_target_list(datacsv, entity2id)  # (num_categories, 1) candidate entity ids
overall_id_to_name = json.load(open('overall_id_to_name.json'))  # entity key -> human-readable species name
# CPU-only inference (checkpoint is also mapped to CPU below)
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
# strict=False tolerates extra/missing keys between checkpoint and model.
# NOTE(review): this silently ignores mismatched weights -- confirm intended.
model.load_state_dict(ckpt['model'], strict=False)
print('ckpt loaded...')
# Inference entry point used by the Gradio interface below.
def evaluate(img):
    """Classify a camera-trap image and return the top-5 species.

    Args:
        img: image array as delivered by the Gradio Image input.

    Returns:
        dict mapping species name -> softmax probability for the top-5
        predicted categories (the shape Gradio's "label" output expects).
    """
    transform_steps = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((448, 448)),
        transforms.ToTensor(),  # scales pixels into [0, 1]
        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
                             _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
    ])
    # Single-image batch; relation id 3 selects the image->species relation.
    # NOTE(review): 3 is a magic constant -- confirm against the KG schema.
    h = transform_steps(img).unsqueeze(0)
    r = torch.tensor([3]).unsqueeze(0)
    # Pure inference: no_grad avoids building the autograd graph on every
    # request (the original wasted memory tracking gradients here).
    with torch.no_grad():
        outputs = F.softmax(model.forward_ce(h, r, triple_type=('image', 'id')), dim=-1)
        top5 = torch.topk(outputs, k=5, dim=-1).indices.squeeze(0).tolist()
    result = {}
    for idx in top5:
        entity_id = target_list[idx].item()  # candidate index -> global entity id
        label = overall_id_to_name[str(id2entity[entity_id])]
        result[label] = outputs[0, idx].item()
    return result
# Gradio interface wiring: image in, ranked label probabilities out.
# NOTE(review): `gr.inputs.Image` is the legacy Gradio input API (removed in
# later Gradio versions) -- confirm the pinned gradio version supports it.
species_model = gr.Interface(
    evaluate,
    gr.inputs.Image(shape=(200, 200)),  # presumably resizes uploads to 200x200 -- verify
    outputs="label",  # evaluate() returns {name: prob}, rendered as a ranked label widget
    title='Camera Trap Species Classification demo',
    # description='Species Classification',
    # article='Species Classification'
)
# Fixed port; share=True opens a public tunnel link, debug=True blocks and streams logs.
species_model.launch(server_port=8977,share=True, debug=True)