File size: 5,578 Bytes
bc7eeae
a4273c0
7a3a9e6
 
 
 
 
 
 
 
4a302c4
39a2576
1250ac1
a4273c0
 
 
 
 
 
 
 
a9d3066
39a2576
a9d3066
 
39a2576
a9d3066
 
 
 
 
bc7eeae
1250ac1
 
bc7eeae
1250ac1
 
bc7eeae
a4f6a62
5ef2f9f
1250ac1
bc7eeae
1250ac1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac2aecb
 
 
 
 
 
 
 
1250ac1
 
 
 
 
 
ac2aecb
 
 
 
 
ea66648
4e43b2f
bc7eeae
 
9f0ad39
bc7eeae
 
 
 
 
9f0ad39
8df5b71
bc7eeae
 
9f0ad39
c1ce821
 
 
8df5b71
 
192cd9b
8df5b71
9c4307a
9b234af
9c4307a
0550f3f
 
 
 
 
 
 
 
ac2aecb
 
0550f3f
bc7eeae
0550f3f
8f2c4bb
 
ac2aecb
 
8f2c4bb
 
 
 
 
 
 
 
 
 
1250ac1
f254fdd
 
9b234af
 
bd9dfee
bc7eeae
9c4307a
ea66648
 
 
8f2c4bb
21a7f51
9b234af
f254fdd
1c3a97b
 
3cd452d
 
f254fdd
 
f7d0b4b
 
4f3b9d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# environment setup
import os
# Shell commands executed one after another through os.system, in this exact
# order. Exit statuses are not checked, so a failing step does not stop the
# script (e.g. a `git clone` into an already existing directory).
_setup_commands = (
    "pip install torch torchvision",
    "git clone https://github.com/IDEA-Research/detrex.git",
    "python3.10 -m pip install git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2",
    "python3.10 -m pip install git+https://github.com/IDEA-Research/detrex.git@v0.5.0#egg=detrex",
    "git submodule sync",
    "git submodule update --init",
    "pip install Pillow==9.5.0",
    "pip install fairscale",
    "pip install opencv-python",
    "cp -rf '/home/user/app/utils/data' '/usr/local/lib/python3.10/site-packages/detrex/config/configs/common/'",
)
for _setup_cmd in _setup_commands:
    os.system(_setup_cmd)

# import libs
import cv2
import json
import numpy as np
import gradio as gr
import warnings
# Install a process-wide "ignore" filter: from here on every warning category
# from every module is suppressed (including deprecation warnings from the
# libraries imported below).
warnings.filterwarnings("ignore")

# adapt files for cpu usage: rewrite detrex's installed
# multi_scale_deform_attn.py so that only its first 406 lines remain.
# Presumably the removed tail is GPU-only code — confirm against the detrex
# v0.5.0 source if the pinned version ever changes, since the cut-off line
# number is hard-coded. Re-running is harmless: a file that already has
# 406 lines is rewritten unchanged.
_msda_path = "/usr/local/lib/python3.10/site-packages/detrex/layers/multi_scale_deform_attn.py"
with open(_msda_path, "r") as src:
    _msda_lines = src.readlines()
with open(_msda_path, "w") as dst:
    dst.writelines(_msda_lines[:406])

# external lib functions
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from demo.demo import VisualizationDemo 
from detectron2.data.detection_utils import read_image

# custom lib functions, data, annotations etc.
config_file = os.path.join(os.getcwd(), 'projects/dino/configs/odor3_fn_l_lrf_384_fl4_5scale_50ep.py')
ckpt_pth = os.path.join(os.getcwd(), 'utils/focaldino_ep18.pth')

# load model/demo
try:
    cfg = LazyConfig.load(config_file)
except AssertionError as err:
    if str(err).startswith('Dataset '):
        # This cannot be swallowed: LazyConfig.load raised before returning,
        # so `cfg` is unbound and instantiate(cfg.model) below would only fail
        # later with an unrelated-looking NameError. Surface the real cause.
        # (Presumably a dataset-registration assertion raised while the config
        # executes — confirm.)
        raise RuntimeError(f"Failed to load config {config_file!r}: {err}") from err
    raise
model = instantiate(cfg.model)
model.to(cfg.train.device)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(ckpt_pth)
model.eval()  # inference mode for the demo
demo = VisualizationDemo(
    model=model,
    min_size_test=800,
    max_size_test=1333,
    img_format='RGB',
    metadata_dataset='odor_test')

def read_json_categories(jsonFile):
    """Return the top-level ``'categories'`` entry of a JSON annotation file.

    Args:
        jsonFile: path to the JSON file.

    Returns:
        The value stored under the top-level ``'categories'`` key. Callers
        (see ``get_name_by_id``) expect a list of dicts with ``'id'`` and
        ``'name'`` keys. An empty dict is returned when the file's top level
        is not an object or has no ``'categories'`` key; it is kept for
        backward compatibility and iterates as empty.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the locale's default encoding.
    with open(jsonFile, 'r', encoding='utf-8') as file:
        data = json.load(file)
    if isinstance(data, dict) and 'categories' in data:
        return data['categories']
    return {}

def treat_grayscale(img):
    """Return *img* as a 3-channel image by replicating a grayscale channel.

    A 2-D ``(H, W)`` array or a single-channel ``(H, W, 1)`` array is turned
    into ``(H, W, 3)`` with the same dtype. Any other array is returned
    unchanged (the same object).
    """
    if img.ndim == 2:
        return np.stack((img,) * 3, axis=-1)
    if img.ndim == 3 and img.shape[2] == 1:
        # explicit single channel axis: still grayscale, still needs 3 channels
        return np.repeat(img, 3, axis=-1)
    return img

def get_name_by_id(categories, id):
    """Look up a category's display name by its id.

    Scans *categories* in order and returns the ``'name'`` of the first entry
    whose ``'id'`` equals *id*; returns ``'Unknown'`` when nothing matches.
    """
    matching_names = (entry['name'] for entry in categories if entry['id'] == id)
    return next(matching_names, 'Unknown')

def set_image_resolution(img, percentage):
    """Scale *img* by the factor *percentage* in both dimensions.

    Args:
        img: image array; its first two dimensions are taken as (height, width).
        percentage: scale factor applied to each side (e.g. 0.5 halves both).

    Returns:
        The array produced by ``cv2.resize`` at the scaled size.
    """
    height, width = img.shape[:2]

    # round() rather than int(): int() truncates float products such as
    # 0.57 * 100 == 56.99999999999999 down to 56. Clamp to at least 1 px
    # because cv2.resize rejects a zero-area target size.
    new_height = max(1, round(height * percentage))
    new_width = max(1, round(width * percentage))

    # cv2.resize takes dsize as (width, height)
    return cv2.resize(img, (new_width, new_height))

# Parsed category lists keyed by annotation-file path, so the (potentially
# large) annotation JSON is read once per process instead of on every request.
_categories_cache = {}


def _cached_categories(json_path):
    """Return ``read_json_categories(json_path)``, reading each path only once."""
    if json_path not in _categories_cache:
        _categories_cache[json_path] = read_json_categories(json_path)
    return _categories_cache[json_path]


def predict(link, url, threshold, image_resolution):
    """Run the detector on an uploaded image or on an image URL.

    Args:
        link: path of the uploaded image; falsy when nothing was uploaded.
            When truthy it takes precedence over *url*.
        url: image location typed by the user; surrounding whitespace is ignored.
        threshold: confidence threshold forwarded to ``demo.run_on_image``.
        image_resolution: scale factor applied to both image sides before inference.

    Returns:
        A 3-tuple:
        - the result of ``visualized_output.get_image()``;
        - one ``"<class name>: <score %>"`` line per detection;
        - a JSON string listing class name, score and box coordinates of each detection.

    Raises:
        gr.Error: if neither an upload nor a URL was provided.
    """
    categories = _cached_categories(
        os.path.join(os.getcwd(), 'annotations', 'instances_train2017.json'))

    source = link or (url.strip() if url else "")
    if not source:
        raise gr.Error("Please upload an image or enter an image URL.")

    # Decode straight to 3-channel RGB so RGBA, CMYK, palette and grayscale
    # files all yield the (H, W, 3) layout the channel flip below assumes.
    img = read_image(source, format="RGB")

    img = set_image_resolution(img, image_resolution)
    img = treat_grayscale(img)  # no-op after RGB decoding; kept as a cheap guard
    # Reverse channel order RGB -> BGR (assumes run_on_image expects BGR, as in
    # detectron2-style demos — confirm against demo/demo.py).
    img = img[:, :, ::-1]

    predictions, visualized_output = demo.run_on_image(img, threshold)

    instances = predictions["instances"]
    pred_boxes = instances.get("pred_boxes")
    scores = instances.get("scores")
    pred_classes = instances.get("pred_classes")

    # Single pass: each detection's class name and score feed both outputs.
    text_lines = []
    detections = []
    for i in range(len(pred_boxes)):
        class_id = pred_classes[i].item()
        class_name = get_name_by_id(categories, class_id)
        score = scores[i].item()
        text_lines.append(f"{class_name}: {score:.2%}\n")
        detections.append({
            "class_name": class_name,
            "score": score,
            # indexing yields a one-box container, so this is a nested list
            "box_coordinates": pred_boxes[i].tensor.tolist(),
        })

    output_text = "".join(text_lines)
    output_json = json.dumps(detections, indent=4)

    return visualized_output.get_image(), output_text, output_json

# Input components, in the same order as predict's parameters:
# (link, url, threshold, image_resolution).
_interface_inputs = [
    gr.Image(type='filepath', label="Input Image"),
    gr.Textbox(type='text', label="Input Image (URL) - not considered if image was uploaded"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.3, maximum=1.0, step=0.01, value=1.0, label="Image Size (30-100%)"),
]

# Output components, matching predict's (image, text, json) return tuple.
_interface_outputs = [
    gr.Image(type='pil', label="Output Image"),
    gr.Textbox(type='text', label="Predictions"),
    gr.Textbox(type='text', label="Predictions (JSON)"),
]

# Each row: [image link, URL text, confidence threshold, image size factor].
_interface_examples = [
    ["https://puam-loris.aws.princeton.edu/loris/INV33883.jp2/full/full/0/default.jpg", "", 0.05, 1],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fimages%2Fodeuropa-homepage%2F15.jpg&w=1920&q=75", "", 0.2, 1],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FGrayling%252520Thymallus%252520thymallus.JPG%26width%3D300%26height%3D300&w=384&q=75", "", 0.5, 0.5],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FCigarette%252520in%252520white%252520ashtray.jpg%26width%3D300%26height%3D300&w=384&q=75", "", 0.05, 0.3],
]

gui = gr.Interface(
    fn=predict,
    inputs=_interface_inputs,
    outputs=_interface_outputs,
    examples=_interface_examples,
)

if __name__ == "__main__":
    # start the Gradio server only when run as a script, with share=True
    gui.launch(share=True)