Spaces:
Runtime error
Runtime error
# environment setup | |
import os | |
# Shell commands run once, in this order, when the app starts. They install
# the runtime dependencies and copy the bundled data configs into the
# installed detrex package. Exit codes are not checked.
_SETUP_COMMANDS = (
    "pip install torch torchvision",
    "git clone https://github.com/IDEA-Research/detrex.git",
    "python3.10 -m pip install git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2",
    "python3.10 -m pip install git+https://github.com/IDEA-Research/detrex.git@v0.5.0#egg=detrex",
    "git submodule sync",
    "git submodule update --init",
    "pip install Pillow==9.5.0",
    "pip install fairscale",
    "pip install opencv-python",
    "cp -rf '/home/user/app/utils/data' '/usr/local/lib/python3.10/site-packages/detrex/config/configs/common/'",
)
for _setup_command in _SETUP_COMMANDS:
    os.system(_setup_command)
# import libs | |
import cv2 | |
import json | |
import numpy as np | |
import gradio as gr | |
import warnings | |
warnings.filterwarnings("ignore")

# adapt files for cpu usage
# The installed detrex module is rewritten in place, keeping only its first
# 406 lines and dropping everything after them.
# NOTE(review): the cut-off is tied to the pinned detrex release; presumably
# the dropped tail is the GPU-only part -- confirm when upgrading detrex.
_DEFORM_ATTN_FILE = "/usr/local/lib/python3.10/site-packages/detrex/layers/multi_scale_deform_attn.py"
_DEFORM_ATTN_KEEP_LINES = 406

with open(_DEFORM_ATTN_FILE, "r") as _src:
    _kept_lines = _src.readlines()[:_DEFORM_ATTN_KEEP_LINES]

with open(_DEFORM_ATTN_FILE, "w") as _dst:
    _dst.writelines(_kept_lines)
# external lib functions | |
from detectron2.config import LazyConfig, instantiate | |
from detectron2.checkpoint import DetectionCheckpointer | |
from demo.demo import VisualizationDemo | |
from detectron2.data.detection_utils import read_image | |
# custom lib functions, data, annotations etc.
# Both paths are resolved against the current working directory.
config_file = os.getcwd() + '/projects/dino/configs/odor3_fn_l_lrf_384_fl4_5scale_50ep.py'
ckpt_pth = os.getcwd() + '/utils/focaldino_ep18.pth'

# load model/demo
try:
    cfg = LazyConfig.load(config_file)
except AssertionError as e:
    if not str(e).startswith('Dataset '):
        raise
    # This case used to be silently ignored, but `cfg` is never assigned when
    # the load raises, so the very next statement failed with a misleading
    # NameError. Fail here instead, with the real cause attached.
    raise RuntimeError(
        f"Could not load config {config_file!r}: {e}"
    ) from e

# Build the model described by the config, move it to the configured device,
# restore the checkpoint weights and switch to inference mode.
model = instantiate(cfg.model)
model.to(cfg.train.device)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(ckpt_pth)
model.eval()

# Shared inference/visualisation helper used by `predict`.
demo = VisualizationDemo(
    model=model,
    min_size_test=800,
    max_size_test=1333,
    img_format='RGB',
    metadata_dataset='odor_test',
)
def read_json_categories(jsonFile):
    """Return the value stored under ``'categories'`` in a JSON file.

    Args:
        jsonFile: Path of the JSON file to read.

    Returns:
        ``data['categories']`` when the parsed document contains that key,
        otherwise an empty dict.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with open(jsonFile, 'r', encoding='utf-8') as file:
        data = json.load(file)
    if 'categories' in data:
        return data['categories']
    return {}
def treat_grayscale(img):
    """Give a two-dimensional image a trailing axis of three identical channels.

    Args:
        img: Array-like image exposing ``shape``.

    Returns:
        *img* itself when it is not two-dimensional; otherwise a new array in
        which the input is repeated three times along a new last axis.
    """
    if len(img.shape) != 2:
        return img
    return np.stack([img, img, img], axis=-1)
def get_name_by_id(categories, id):
    """Look up the ``'name'`` of the first category whose ``'id'`` equals *id*.

    Args:
        categories: Iterable of mappings that each have ``'id'`` and
            ``'name'`` keys.
        id: Identifier to search for.

    Returns:
        The matching category's name, or ``'Unknown'`` when nothing matches.
    """
    matching_names = (entry['name'] for entry in categories if entry['id'] == id)
    return next(matching_names, 'Unknown')
def set_image_resolution(img, percentage):
    """Resize *img* by one scale factor applied to both height and width.

    Args:
        img: Image array; height and width are taken from ``img.shape[:2]``.
        percentage: Scale factor as a fraction (``1.0`` keeps the size).

    Returns:
        The image returned by ``cv2.resize`` at the scaled size.
    """
    height, width = img.shape[:2]
    # int() truncates, so a very small image scaled down could end up with a
    # 0-pixel side, which cv2.resize rejects. Keep at least one pixel per side.
    new_height = max(1, int(height * percentage))
    new_width = max(1, int(width * percentage))
    # cv2.resize expects the target size as (width, height).
    return cv2.resize(img, (new_width, new_height))
def predict(link, url, threshold, image_resolution):
    """Run the detector on one image and format the results for the UI.

    Args:
        link: File path of the uploaded image; used whenever it is truthy.
        url: Image location used when *link* is empty.
        threshold: Confidence threshold forwarded to ``demo.run_on_image``.
        image_resolution: Scale factor forwarded to ``set_image_resolution``.

    Returns:
        A 3-tuple ``(image, text, json_text)``: the visualised output image,
        one ``"<class name>: <score as percent>"`` line per detection, and the
        same detections (class name, score, box coordinates) as an indented
        JSON string.

    Raises:
        ValueError: If neither *link* nor *url* is provided.
    """
    source = link or url
    if not source:
        raise ValueError("No input image: upload an image or enter an image URL.")

    categories = read_json_categories(os.getcwd() + '/annotations/instances_train2017.json')

    img = read_image(source)
    img = set_image_resolution(img, image_resolution)
    img = treat_grayscale(img)
    # Reverse the order of the last (channel) axis before inference.
    img = img[:, :, ::-1]

    predictions, visualized_output = demo.run_on_image(img, threshold)
    instances = predictions["instances"]
    pred_boxes = instances.get("pred_boxes")
    scores = instances.get("scores")
    pred_classes = instances.get("pred_classes")

    # Build the text and JSON views in a single pass over the detections.
    text_lines = []
    records = []
    # Boxes are indexed by position on purpose: the JSON output uses
    # `pred_boxes[i].tensor.tolist()` exactly as before.
    for i in range(len(pred_boxes)):
        class_id = pred_classes[i].item()
        class_name = get_name_by_id(categories, class_id)
        score = scores[i].item()
        text_lines.append(f"{class_name}: {score:.2%}\n")
        records.append({
            "class_name": class_name,
            "score": score,
            "box_coordinates": pred_boxes[i].tensor.tolist(),
        })

    output_text = "".join(text_lines)
    output_json = json.dumps(records, indent=4)
    return visualized_output.get_image(), output_text, output_json
# Input widgets, in the order `predict` receives them:
# uploaded image path, image URL, confidence threshold, image scale.
_gui_inputs = [
    gr.Image(type='filepath', label="Input Image"),
    gr.Textbox(type='text', label="Input Image (URL) - not considered if image was uploaded"),
    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Confidence Threshold"),
    gr.Slider(minimum=0.3, maximum=1.0, step=0.01, value=1.0, label="Image Size (30-100%)"),
]

# Output widgets, in the order `predict` returns its values.
_gui_outputs = [
    gr.Image(type='pil', label="Output Image"),
    gr.Textbox(type='text', label="Predictions"),
    gr.Textbox(type='text', label="Predictions (JSON)"),
]

# Each example row supplies one value per input widget, in the same order.
_gui_examples = [
    ["https://puam-loris.aws.princeton.edu/loris/INV33883.jp2/full/full/0/default.jpg", "", 0.05, 1],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fimages%2Fodeuropa-homepage%2F15.jpg&w=1920&q=75", "", 0.2, 1],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FGrayling%252520Thymallus%252520thymallus.JPG%26width%3D300%26height%3D300&w=384&q=75", "", 0.5, 0.5],
    ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FCigarette%252520in%252520white%252520ashtray.jpg%26width%3D300%26height%3D300&w=384&q=75", "", 0.05, 0.3],
]

gui = gr.Interface(
    predict,
    inputs=_gui_inputs,
    outputs=_gui_outputs,
    examples=_gui_examples,
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    gui.launch(share=True)