File size: 2,522 Bytes
bc38547
 
 
3f6e9c6
bc38547
 
 
93d2b0f
 
bc38547
93d2b0f
bc38547
 
 
e75805e
 
 
 
 
 
 
 
 
 
 
93d2b0f
 
 
 
 
 
7218171
bc38547
93d2b0f
2840071
93d2b0f
bc38547
93d2b0f
 
 
7218171
93d2b0f
 
 
 
bc38547
 
 
93d2b0f
 
 
bc38547
 
93d2b0f
bc38547
 
3f6e9c6
e6a3d86
bc38547
9abfd25
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import pandas as pd
import torch
import gradio as gr
from model import DistMult
from PIL import Image
from torchvision import transforms
import json
from tqdm import tqdm

# Default image tensor normalization: per-channel mean and standard deviation
# passed to transforms.Normalize in evaluate() after ToTensor().
# NOTE(review): these values match the commonly published ImageNet RGB
# statistics; presumably the image encoder expects them — confirm against the
# model's training pipeline before changing.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225]

def generate_target_list(data, entity2id):
    """Collect the distinct target entity ids of all image -> id rows.

    Rows of ``data`` whose ``datatype_h`` is ``"image"`` and whose
    ``datatype_t`` is ``"id"`` are selected; each value in their ``t`` column
    is normalised with ``str(int(float(value)))`` and looked up in
    ``entity2id``. Duplicates are dropped, keeping first-occurrence order.

    Args:
        data: pandas DataFrame with ``datatype_h``, ``datatype_t`` and ``t``
            columns.
        entity2id: Mapping from the normalised ``t`` string to an integer
            entity index.

    Returns:
        A ``torch.long`` tensor of shape ``(num_unique_targets, 1)``.

    Raises:
        KeyError: If a normalised ``t`` value is missing from ``entity2id``.
        ValueError: If a ``t`` value cannot be parsed as a number.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = list(data.loc[mask, "t"])

    # A dict keeps insertion order and gives O(1) membership, so the result is
    # identical to the previous list-based de-duplication without the
    # quadratic `in list` scan (and with a single lookup per row).
    unique_ids = {}
    for item in tqdm(tails):
        unique_ids.setdefault(entity2id[str(int(float(item)))], None)

    return torch.tensor(list(unique_ids), dtype=torch.long).unsqueeze(-1)
    
# Load necessary data and initialize the model
# Context managers close the JSON files as soon as they are parsed instead of
# leaving the handles open until garbage collection.
with open('entity2id_subtree.json', 'r', encoding='utf-8') as entity_file:
    entity2id = json.load(entity_file)
# Reverse lookup: value stored in entity2id -> its original key.
id2entity = {v: k for k, v in entity2id.items()}
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Distinct target entity ids taken from the image -> id rows of the dataset.
target_list = generate_target_list(datacsv, entity2id)
with open('overall_id_to_name.json', 'r', encoding='utf-8') as name_file:
    overall_id_to_name = json.load(name_file)

# Initialize your model here
# NOTE(review): no checkpoint / state dict is loaded in this section — confirm
# that DistMult restores its trained weights internally.
model = DistMult(num_ent_id, target_list, torch.device('cpu'))  # Update arguments as necessary
model.eval()

# Define your evaluation function
def evaluate(img):
    """Classify one image and return the predicted species as a label dict.

    Args:
        img: Image handed over by the Gradio input component. It must be
            something ``transforms.ToPILImage`` accepts (presumably an
            H x W x C uint8 array — confirm against the Gradio version in use).

    Returns:
        dict: ``{species_name: 1.0}`` for the top-1 prediction. The ``1.0`` is
        a fixed placeholder score, not a probability produced by the model.

    Raises:
        KeyError: If the predicted entity has no entry in ``id2entity`` or
            ``overall_id_to_name``.
    """
    transform_steps = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
    ])
    # Add a leading batch dimension of size 1 to both model inputs.
    h = transform_steps(img).unsqueeze(0)
    # Relation id 3 is hard-coded.
    # NOTE(review): presumably the image -> taxon-id relation index used in
    # training — confirm against the model's relation vocabulary.
    r = torch.tensor([3]).unsqueeze(0)

    # Inference only: model.eval() does not disable autograd, so switch off
    # gradient tracking explicitly to avoid building a graph on every request.
    with torch.no_grad():
        outputs = model.forward_ce(h, r, triple_type=('image', 'id'))

    # argmax indexes into target_list, which maps that position back to the
    # entity id; id2entity and overall_id_to_name then resolve the name.
    y_pred = outputs.argmax(-1).cpu()
    pred_label = target_list[y_pred].item()
    species_label = overall_id_to_name[str(id2entity[pred_label])]

    return {species_label: 1.0}

if __name__ == '__main__':
    # Assemble the Gradio demo: one image input routed through `evaluate`,
    # with the returned dict rendered by the "label" output.
    app_text = 'Species Classification'
    image_input = gr.inputs.Image(shape=(200, 200))
    species_model = gr.Interface(
        evaluate,
        image_input,
        outputs="label",
        title=app_text,
        description=app_text,
        article=app_text,
    )
    # Start the app; share and debug are both switched on.
    species_model.launch(share=True, debug=True)