"""Gradio demo: camera-trap species classification with a DistMult model."""
import json
import os

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from model import DistMult

# Default image tensor normalization (per-channel mean / std).
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]


def generate_target_list(data, entity2id):
    """Return the candidate target entity ids as a ``(N, 1)`` long tensor.

    Selects the tail column ``t`` of every row whose head datatype is
    ``"image"`` and whose tail datatype is ``"id"``. Each tail is mapped
    through ``entity2id`` (keyed by the tail value rendered as an integer
    string). Duplicates are dropped, keeping first-seen order.

    Args:
        data: DataFrame with ``datatype_h``, ``datatype_t`` and ``t`` columns.
        entity2id: mapping from entity string to integer entity id.

    Returns:
        ``torch.long`` tensor of shape ``(num_targets, 1)``.
    """
    tails = data.loc[
        (data["datatype_h"] == "image") & (data["datatype_t"] == "id"), "t"
    ]
    categories = []
    seen = set()  # O(1) membership instead of scanning `categories` each time
    for item in tqdm(list(tails)):
        # Tail values may be read as floats (e.g. "12.0"); normalise to "12".
        ent_id = entity2id[str(int(float(item)))]
        if ent_id not in seen:
            seen.add(ent_id)
            categories.append(ent_id)
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)


# Load necessary data and initialize the model.
with open('entity2id_subtree.json', 'r') as f:
    entity2id = json.load(f)
id2entity = {v: k for k, v in entity2id.items()}

datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
target_list = generate_target_list(datacsv, entity2id)

with open('overall_id_to_name.json') as f:
    overall_id_to_name = json.load(f)

model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
# strict=False tolerates key mismatches; report them instead of hiding them.
load_result = model.load_state_dict(ckpt['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('missing keys: {}'.format(load_result.missing_keys))
    print('unexpected keys: {}'.format(load_result.unexpected_keys))
print('ckpt loaded...')

# Preprocessing pipeline, built once rather than on every request.
_EVAL_TRANSFORM = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN,
                         _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
])


def evaluate(img, top_k=5):
    """Classify one image and return its highest-probability species.

    Args:
        img: image supplied by the Gradio Image input (presumably an HxWxC
            uint8 numpy array — confirm for the Gradio version in use), or
            ``None`` when nothing was submitted.
        top_k: maximum number of labels to return (clamped to the number
            of targets).

    Returns:
        dict mapping species name -> softmax probability. It is empty when
        ``img`` is ``None``. If two targets share a display name, the later
        one overwrites the earlier.
    """
    if img is None:
        return {}
    h = _EVAL_TRANSFORM(img).unsqueeze(0)  # add batch dimension
    # NOTE(review): relation id 3 is a magic constant queried for
    # ('image', 'id') triples — confirm against the training setup.
    r = torch.tensor([3]).unsqueeze(0)
    with torch.no_grad():
        logits = model.forward_ce(h, r, triple_type=('image', 'id'))
    outputs = F.softmax(logits, dim=-1)
    k = min(top_k, outputs.size(-1))  # topk raises if k exceeds the class count
    predictions = torch.topk(outputs, k=k, dim=-1).indices.squeeze(0).tolist()

    result = {}
    for i in predictions:
        pred_label = target_list[i].item()  # index into the candidate targets
        label = overall_id_to_name[str(id2entity[pred_label])]
        result[label] = outputs[0, i].item()
    return result


# Gradio interface.
# NOTE(review): the input is resized to 200x200 here and then upsampled to
# 448x448 by _EVAL_TRANSFORM, which discards detail. Also, `gr.inputs` was
# removed in Gradio 4 — confirm the pinned Gradio version.
species_model = gr.Interface(
    evaluate,
    gr.inputs.Image(shape=(200, 200)),
    outputs="label",
    title='Camera Trap Species Classification demo',
)

species_model.launch(server_port=8977, share=True, debug=True)