# environment setup
import os

# Shell commands that prepare the runtime. They are executed strictly in the
# order listed; exit codes are not inspected (best-effort setup).
SETUP_COMMANDS = (
    "pip install torch torchvision",
    "git clone https://github.com/IDEA-Research/detrex.git",
    "python3.10 -m pip install git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2",
    "python3.10 -m pip install git+https://github.com/IDEA-Research/detrex.git@v0.5.0#egg=detrex",
    "git submodule sync",
    "git submodule update --init",
    "pip install Pillow==9.5.0",
    "pip install fairscale",
    "pip install opencv-python",
    "cp -rf '/home/user/app/utils/data' '/usr/local/lib/python3.10/site-packages/detrex/config/configs/common/'",
)
for setup_command in SETUP_COMMANDS:
    os.system(setup_command)

# import libs (deliberately placed after the installs above have run)
import cv2
import json
import numpy as np
import gradio as gr
import warnings

warnings.filterwarnings("ignore")

# adapt files for cpu usage
# Only the first 406 lines of the installed module are kept; everything after
# that is discarded (presumably the CUDA-dependent part -- confirm against the
# installed detrex version).
DEFORM_ATTN_FILE = "/usr/local/lib/python3.10/site-packages/detrex/layers/multi_scale_deform_attn.py"
KEEP_LINE_COUNT = 406

with open(DEFORM_ATTN_FILE, "r") as source_file:
    module_lines = source_file.readlines()

with open(DEFORM_ATTN_FILE, "w") as target_file:
    target_file.writelines(module_lines[:KEEP_LINE_COUNT])

# external lib functions (imported only after the file above has been rewritten)
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from demo.demo import VisualizationDemo
from detectron2.data.detection_utils import read_image

# custom lib functions, data, annotations etc.
config_file = os.getcwd() + '/projects/dino/configs/odor3_fn_l_lrf_384_fl4_5scale_50ep.py'
ckpt_pth = os.getcwd() + '/utils/focaldino_ep18.pth'

# load model/demo
try:
    cfg = LazyConfig.load(config_file)
except AssertionError as e:
    # Assertions whose message starts with 'Dataset ' are tolerated; any other
    # assertion is propagated unchanged.
    # NOTE(review): when the assertion is tolerated, `cfg` is not assigned by
    # this statement, so the code below only works if `cfg` already exists
    # (e.g. from an earlier run in the same interpreter) -- confirm.
    if not str(e).startswith('Dataset '):
        raise

model = instantiate(cfg.model)
model.to(cfg.train.device)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(ckpt_pth)
model.eval()

demo = VisualizationDemo(
    model=model,
    min_size_test=800,
    max_size_test=1333,
    img_format='RGB',
    metadata_dataset='odor_test')

# Parsed category lists keyed by annotation file path, so the JSON file is
# read once instead of on every prediction request.
_categories_cache = {}


def read_json_categories(jsonFile):
    """Return the value stored under 'categories' in a JSON annotation file.

    Args:
        jsonFile: path of the JSON file to read.

    Returns:
        The object found under the 'categories' key, or an empty dict when
        the key is absent.
    """
    with open(jsonFile, 'r') as file:
        data = json.load(file)
    return data['categories'] if 'categories' in data else {}


def _load_categories():
    """Return the categories of the training annotations, parsing them once per path."""
    annotations_path = os.getcwd() + '/annotations/instances_train2017.json'
    if annotations_path not in _categories_cache:
        _categories_cache[annotations_path] = read_json_categories(annotations_path)
    return _categories_cache[annotations_path]


def treat_grayscale(img):
    """Return a 2-D image stacked into three identical channels; other arrays pass through."""
    if len(img.shape) == 2:
        return np.stack((img,) * 3, axis=-1)
    return img


def get_name_by_id(categories, id):
    """Return the 'name' of the first category whose 'id' equals ``id``.

    Args:
        categories: iterable of mappings with 'id' and 'name' keys.
        id: category id to look up.

    Returns:
        The matching name, or 'Unknown' when no category matches.
    """
    for cg in categories:
        if cg['id'] == id:
            return cg['name']
    return 'Unknown'


def set_image_resolution(img, percentage):
    """Resize ``img`` so that both sides are scaled by ``percentage``.

    Args:
        img: image array whose first two dimensions are height and width.
        percentage: scale factor applied to both sides (1.0 keeps the size).

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    height, width = img.shape[:2]
    # Clamp to one pixel: truncation could otherwise yield a zero-sized
    # target for very small inputs, which cv2.resize rejects.
    new_height = max(1, int(height * percentage))
    new_width = max(1, int(width * percentage))
    return cv2.resize(img, (new_width, new_height))


def predict(link, url, threshold, image_resolution):
    """Run the detector on one image and format its predictions.

    Args:
        link: file path of the uploaded image; takes precedence over ``url``.
        url: image URL, used only when ``link`` is empty.
        threshold: confidence threshold forwarded to ``demo.run_on_image``.
        image_resolution: scale factor applied to the image before inference.

    Returns:
        A tuple ``(image, text, json_text)``: the visualized output image,
        one "name: score" line per detection, and the same detections
        (including box coordinates) serialized as indented JSON.

    Raises:
        gr.Error: if neither an uploaded image nor a URL is provided.
    """
    source = link or (url or "").strip()
    if not source:
        # Previously an empty path was handed to read_image, which failed
        # with an unrelated-looking error.
        raise gr.Error("Please upload an image or provide an image URL.")

    categories = _load_categories()

    img = read_image(source)
    img = set_image_resolution(img, image_resolution)
    img = treat_grayscale(img)
    # Reverse the channel axis (presumably RGB -> BGR, the order the demo
    # expects -- confirm against VisualizationDemo).
    img = img[:, :, ::-1]

    predictions, visualized_output = demo.run_on_image(img, threshold)
    instances = predictions["instances"]
    pred_boxes = instances.get("pred_boxes")
    scores = instances.get("scores")
    pred_classes = instances.get("pred_classes")

    # Build the text and JSON outputs in a single pass over the detections.
    text_lines = []
    records = []
    for i, (score_item, class_item) in enumerate(zip(scores, pred_classes)):
        class_name = get_name_by_id(categories, class_item.item())
        score = score_item.item()
        text_lines.append(f"{class_name}: {score:.2%}\n")
        records.append({
            "class_name": class_name,
            "score": score,
            # Indexing (rather than iterating) keeps the nested-list shape
            # of the coordinates in the JSON output.
            "box_coordinates": pred_boxes[i].tensor.tolist(),
        })

    output_text = "".join(text_lines)
    output_json = json.dumps(records, indent=4)

    return visualized_output.get_image(), output_text, output_json


gui = gr.Interface(
    predict,
    inputs=[
        gr.Image(type='filepath', label="Input Image"),
        gr.Textbox(type='text', label="Input Image (URL) - not considered if image was uploaded"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Confidence Threshold"),
        gr.Slider(minimum=0.3, maximum=1.0, step=0.01, value=1.0, label="Image Size (30-100%)")
    ],
    outputs=[
        gr.Image(type='pil', label="Output Image"),
        gr.Textbox(type='text', label="Predictions"),
        gr.Textbox(type='text', label="Predictions (JSON)")
    ],
    examples=[
        ["https://puam-loris.aws.princeton.edu/loris/INV33883.jp2/full/full/0/default.jpg", "", 0.05, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fimages%2Fodeuropa-homepage%2F15.jpg&w=1920&q=75", "", 0.2, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FGrayling%252520Thymallus%252520thymallus.JPG%26width%3D300%26height%3D300&w=384&q=75", "", 0.5, 0.5],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FCigarette%252520in%252520white%252520ashtray.jpg%26width%3D300%26height%3D300&w=384&q=75", "", 0.05, 0.3]
    ],
)

if __name__ == "__main__":
    gui.launch(share=True)