"""Gradio demo that classifies species images with a DistMult model."""
import json
import os

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from model import DistMult

# Default image tensor normalization: per-channel mean/std passed to
# transforms.Normalize (these are the widely used ImageNet values).
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]

# Relation index fed to model.forward_ce alongside the image.
# NOTE(review): presumably the id of the image->id relation used at training
# time — confirm against the training code.
_IMAGE_TO_ID_RELATION = 3

# Preprocessing for every uploaded image. Built once at import time rather
# than on every request.
_EVAL_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
                         _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
])


def _load_json(path):
    """Parse the UTF-8 JSON file at *path*, closing the handle afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def generate_target_list(data, entity2id):
    """Collect the distinct target-entity ids of image->id triples.

    Args:
        data: Triples DataFrame with columns ``datatype_h``, ``datatype_t`` and
            ``t``. Only rows whose head is an image and whose tail is an id are
            used.
        entity2id: Mapping from the string form of an integer entity id to the
            model's entity index.

    Returns:
        A ``torch.long`` tensor of shape ``(num_categories, 1)``. Entries are in
        first-seen order, without duplicates.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tail_values = data.loc[mask, 't'].tolist()

    categories = []
    seen = set()  # O(1) membership instead of scanning `categories`
    for item in tqdm(tail_values):
        # Tail values may be stored as float-like strings/numbers (e.g. "12.0").
        # Normalise them to the integer string used as the entity2id key.
        category = entity2id[str(int(float(item)))]
        if category not in seen:
            seen.add(category)
            categories.append(category)
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)


# Load necessary data and initialize the model.
entity2id = _load_json('entity2id_subtree.json')
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
target_list = generate_target_list(datacsv, entity2id)
overall_id_to_name = _load_json('overall_id_to_name.json')

# NOTE(review): no checkpoint is loaded here (no load_state_dict call), so the
# model runs with whatever weights DistMult's constructor produces. Unless
# DistMult loads trained weights internally, predictions are untrained.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()


def evaluate(img):
    """Classify one image and return its predicted species label.

    Args:
        img: Image from the Gradio ``Image`` input. Presumably an H x W x C
            uint8 numpy array (Gradio's default) — confirm for the installed
            Gradio version.

    Returns:
        ``{species_name: 1.0}`` for Gradio's ``label`` output. The 1.0 is a
        fixed placeholder, not a model confidence.

    Raises:
        ValueError: If no image was supplied.
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")

    h = _EVAL_TRANSFORM(img).unsqueeze(0)  # prepend a batch dimension
    r = torch.tensor([_IMAGE_TO_ID_RELATION]).unsqueeze(0)

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        outputs = model.forward_ce(h, r, triple_type=('image', 'id'))

    # argmax yields a position in target_list; map it back to the entity index,
    # then to the original id string, then to the human-readable name.
    y_pred = outputs.argmax(-1).cpu()
    pred_label = target_list[y_pred].item()
    species_label = overall_id_to_name[str(id2entity[pred_label])]
    return {species_label: 1.0}


if __name__ == '__main__':
    # NOTE(review): `gr.inputs` was deprecated in Gradio 3.0 and removed in
    # 4.0; migrate to `gr.Image(...)` when upgrading. Also, `shape=(200, 200)`
    # shrinks the upload before _EVAL_TRANSFORM upsizes it to 448x448 —
    # confirm this loss of detail is intended.
    species_model = gr.Interface(
        evaluate,
        gr.inputs.Image(shape=(200, 200)),
        outputs="label",
        title='Species Classification',
        description='Species Classification',
        article='Species Classification',
    )
    # share=True also publishes a public Gradio link, not just a local server.
    species_model.launch(share=True, debug=True)