# COSMO-demo / app.py
# Last commit: 9abfd25 "Update app.py" (vardaan123)
import os
import pandas as pd
import torch
import gradio as gr
from model import DistMult
from PIL import Image
from torchvision import transforms
import json
from tqdm import tqdm
# Per-channel (3 channels) mean and standard deviation passed to
# transforms.Normalize in evaluate(). The values match the widely used ImageNet
# statistics; presumably the image encoder inside DistMult was trained with
# them — confirm before changing.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]
def generate_target_list(data, entity2id):
    """Collect the distinct target entity ids of all ``image -> id`` triples.

    Args:
        data: DataFrame of triples with at least the columns ``datatype_h``,
            ``datatype_t`` and ``t``.
        entity2id: Mapping from entity key (JSON string key) to integer
            entity id.

    Returns:
        A ``torch.long`` tensor of shape ``(num_categories, 1)`` containing
        each distinct target id exactly once, in order of first appearance
        in ``data``.

    Raises:
        KeyError: if a normalised tail value has no entry in ``entity2id``.
        ValueError: if a tail value cannot be converted to an integer.
    """
    is_image_to_id = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = data.loc[is_image_to_id, "t"]

    # Tail values may be stored as floats or strings like "123.0"; normalise
    # them to the "123" form used as keys in entity2id.
    target_ids = (entity2id[str(int(float(item)))] for item in tqdm(tails))

    # dict.fromkeys de-duplicates in O(n) while preserving first-seen order
    # (the previous `not in list` scan was O(n^2)). Order matters: evaluate()
    # maps the model's argmax index back through this tensor.
    categories = list(dict.fromkeys(target_ids))
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
# Load the vocabularies and knowledge-graph triples, then build the model once
# at import time so that every request reuses it.
with open('entity2id_subtree.json', 'r', encoding='utf-8') as _fh:
    # Entity key (string) -> integer entity id.
    entity2id = json.load(_fh)
# Reverse lookup: integer entity id -> entity key, used to name predictions.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Candidate answers for (image, id) triples; evaluate() maps the argmax index
# of the model's scores to an entity id through this tensor.
target_list = generate_target_list(datacsv, entity2id)
with open('overall_id_to_name.json', encoding='utf-8') as _fh:
    # Entity key -> human-readable species name shown in the UI.
    overall_id_to_name = json.load(_fh)
del _fh

# CPU-only inference model.
# NOTE(review): no checkpoint is loaded here; presumably DistMult loads its own
# weights in __init__ — confirm, otherwise predictions come from an untrained model.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()  # eval mode: affects layers such as dropout / batch norm
# Preprocessing for every uploaded image. The steps are stateless, so the
# pipeline is built once at import time instead of on every request.
_EVAL_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(
        _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
        _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD,
    ),
])


def evaluate(img):
    """Predict the species shown in ``img``; called by the Gradio interface.

    Args:
        img: The uploaded image, in any form ``transforms.ToPILImage`` accepts.
            Presumably an HxWxC uint8 numpy array from Gradio's Image
            component — confirm against the installed Gradio version.

    Returns:
        ``{species_name: 1.0}``. This is the single top-1 prediction with a
        fixed confidence of 1.0, not a model probability.

    Raises:
        ValueError: if no image was supplied (``img`` is ``None``).
    """
    if img is None:
        raise ValueError("No image was provided.")

    h = _EVAL_TRANSFORM(img).unsqueeze(0)  # add batch dimension
    # NOTE(review): relation id 3 is presumably the image -> species-id
    # relation used in training; confirm against the dataset's relation map.
    r = torch.tensor([3]).unsqueeze(0)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
    y_pred = outputs.argmax(-1).cpu()
    pred_label = target_list[y_pred].item()
    species_label = overall_id_to_name[str(id2entity[pred_label])]
    return {species_label: 1.0}
if __name__ == '__main__':
    # Gradio UI: a single image upload in, a single classification label out.
    # NOTE(review): `gr.inputs` exists only in older Gradio releases (removed in
    # 4.x); confirm the pinned Gradio version before upgrading.
    image_input = gr.inputs.Image(shape=(200, 200))
    species_model = gr.Interface(
        fn=evaluate,
        inputs=image_input,
        outputs="label",
        title='Species Classification',
        description='Species Classification',
        article='Species Classification',
    )
    species_model.launch(share=True, debug=True)