File size: 3,194 Bytes
3dba732
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import pandas as pd
import torch
import torch.nn.functional as F
import gradio as gr
from model import DistMult
from PIL import Image
from torchvision import transforms
import json
from tqdm import tqdm

# Per-channel mean / std passed to transforms.Normalize in evaluate().
# These are the commonly used ImageNet statistics, listed in tensor channel
# order (R, G, B when the input image is RGB).
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [
    0.485,
    0.456,
    0.406,
]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [
    0.229,
    0.224,
    0.225,
]

def generate_target_list(data, entity2id):
    """Collect the distinct entity ids that are targets of image -> id triples.

    Args:
        data: DataFrame of triples with at least the columns ``datatype_h``,
            ``datatype_t`` and ``t``.
        entity2id: Mapping from entity key (string) to integer id.

    Returns:
        A ``torch.long`` tensor of shape ``(num_categories, 1)`` holding each
        distinct mapped id exactly once, in order of first appearance in
        ``data``. An empty selection yields shape ``(0, 1)``.

    Raises:
        KeyError: if a target value has no entry in ``entity2id``.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    targets = list(data.loc[mask, "t"])
    # 't' may be read as a float or a float-like string (e.g. "12.0");
    # normalise it to the integer-string form used as entity2id keys.
    ids = (entity2id[str(int(float(item)))] for item in tqdm(targets))
    # dict.fromkeys de-duplicates in O(n) while keeping first-seen order,
    # replacing the former O(n^2) "not in list" scan.
    categories = list(dict.fromkeys(ids))
    return torch.tensor(categories, dtype=torch.long).unsqueeze(-1)
    
# Load the entity vocabulary, the triple dataset and the id -> display-name table.
with open('entity2id_subtree.json', 'r', encoding='utf-8') as fh:
    entity2id = json.load(fh)
# Reverse lookup (id -> entity key), used by evaluate() to map predictions back.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
target_list = generate_target_list(datacsv, entity2id)
with open('overall_id_to_name.json', encoding='utf-8') as fh:
    overall_id_to_name = json.load(fh)

# Build the model on CPU and switch it to inference mode.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))
model.eval()

# NOTE(review): torch.load unpickles the checkpoint file; only load trusted files.
ckpt = torch.load('species_class_model.pt', map_location=torch.device('cpu'))
# strict=False tolerates key mismatches between checkpoint and model; report
# them rather than ignoring them so a partially-loaded model is noticed.
_load_result = model.load_state_dict(ckpt['model'], strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print('warning: checkpoint/model key mismatch - missing: {}, unexpected: {}'.format(
        _load_result.missing_keys, _load_result.unexpected_keys))
print('ckpt loaded...')

# Define your evaluation function
def evaluate(img):
    """Classify one image and return the top-scoring species with probabilities.

    Args:
        img: The image delivered by the Gradio ``Image`` input. It must be
            accepted by ``transforms.ToPILImage`` (presumably an ``HxWxC``
            uint8 numpy array — TODO confirm for the installed Gradio version).
            ``None`` when no image has been supplied.

    Returns:
        dict mapping display name (looked up in ``overall_id_to_name``) to
        softmax probability, for at most five highest-scoring target classes.
        An empty dict when ``img`` is ``None``.
    """
    if img is None:
        # Nothing to classify (e.g. the input was cleared in the UI).
        return {}

    transform_steps = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
    ])
    # Add a leading batch dimension of size 1 to both inputs.
    h = transform_steps(img).unsqueeze(0)
    # NOTE(review): relation index 3 is presumably the image -> id relation
    # used during training; confirm against the training code.
    r = torch.tensor([3]).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model.forward_ce(h, r, triple_type=('image', 'id'))
        outputs = F.softmax(logits, dim=-1)

    # topk raises if k exceeds the number of scored classes, so clamp it.
    k = min(5, outputs.size(-1))
    predictions = torch.topk(outputs, k=k, dim=-1).indices.squeeze(0).tolist()

    result = {}
    for i in predictions:
        # Column i of the output corresponds to row i of target_list.
        pred_label = target_list[i].item()
        label = overall_id_to_name[str(id2entity[pred_label])]
        result[label] = outputs[0, i].item()
    return result

# Gradio UI: a single image input feeding evaluate(), whose class -> probability
# dict is rendered by the "label" output component.
# NOTE(review): gr.inputs.Image(shape=...) is the legacy Gradio input API;
# confirm it is available in the installed Gradio version.
image_input = gr.inputs.Image(shape=(200, 200))
species_model = gr.Interface(
    fn=evaluate,
    inputs=image_input,
    outputs="label",
    title='Camera Trap Species Classification demo',
)
# Serve on port 8977; see Gradio's launch() documentation for share/debug.
species_model.launch(
    server_port=8977,
    share=True,
    debug=True,
)