# environment setup
import os
os.system("pip install torch torchvision")
os.system("git clone https://github.com/IDEA-Research/detrex.git")
os.system("python3.10 -m pip install git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2")
os.system("python3.10 -m pip install git+https://github.com/IDEA-Research/detrex.git@v0.5.0#egg=detrex")
os.system("git submodule sync")
os.system("git submodule update --init")
os.system("pip install Pillow==9.5.0")
os.system("pip install fairscale")
os.system("pip install opencv-python")
os.system("cp -rf '/home/user/app/utils/data' '/usr/local/lib/python3.10/site-packages/detrex/config/configs/common/'")
# import libs
import cv2
import json
import numpy as np
import gradio as gr
import warnings
warnings.filterwarnings("ignore")
# adapt detrex for CPU usage: keep only the first 406 lines of multi_scale_deform_attn.py
msda_path = "/usr/local/lib/python3.10/site-packages/detrex/layers/multi_scale_deform_attn.py"
with open(msda_path, "r") as f:
    lines = f.readlines()
with open(msda_path, "w") as f:
    f.writelines(lines[:406])
# external lib functions
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from demo.demo import VisualizationDemo
from detectron2.data.detection_utils import read_image
# custom lib functions, data, annotations etc.
config_file = os.getcwd() + '/projects/dino/configs/odor3_fn_l_lrf_384_fl4_5scale_50ep.py'
ckpt_pth = os.getcwd() + '/utils/focaldino_ep18.pth'
# load model/demo
try:
    cfg = LazyConfig.load(config_file)
except AssertionError as e:
    # ignore dataset-related assertions raised while loading the config
    if str(e).startswith('Dataset '):
        pass
    else:
        raise e
model = instantiate(cfg.model)
model.to(cfg.train.device)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(ckpt_pth)
model.eval()
demo = VisualizationDemo(
    model=model,
    min_size_test=800,
    max_size_test=1333,
    img_format='RGB',
    metadata_dataset='odor_test')
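# helper functions: category lookup from the COCO-style annotation file, grayscale handling, and resizing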
def read_json_categories(jsonFile):
    """Return the 'categories' list from a COCO-style annotation file."""
    categories_dict = {}
    with open(jsonFile, 'r') as file:
        data = json.load(file)
    if 'categories' in data:
        categories_dict = data['categories']
    return categories_dict
def treat_grayscale(img):
    """Stack a single-channel image into three channels so the model always receives 3-channel input."""
    if len(img.shape) == 2:
        return np.stack((img,) * 3, axis=-1)
    else:
        return img
def get_name_by_id(categories, id):
    """Return the category name for a given category id, or 'Unknown' if it is not listed."""
    for cg in categories:
        if cg['id'] == id:
            return cg['name']
    return 'Unknown'
def set_image_resolution(img, percentage):
    """Resize the image to the given fraction of its original width and height."""
    height, width = img.shape[:2]
    new_height = int(height * percentage)
    new_width = int(width * percentage)
    resized_img = cv2.resize(img, (new_width, new_height))
    return resized_img
def predict(link, url, threshold, image_resolution):
    """Run the detector on an uploaded image (or image URL) and return the visualization plus predictions."""
    categories = read_json_categories(os.getcwd() + '/annotations/instances_train2017.json')
    # prefer the uploaded file; fall back to the URL field
    if link:
        img = read_image(link)
    else:
        img = read_image(url)
    img_resized = set_image_resolution(img, image_resolution)
    img = treat_grayscale(img_resized)
    # reverse the channel order to match the format the demo pipeline expects
    img = img[:, :, ::-1]
    predictions, visualized_output = demo.run_on_image(img, threshold)
    instances = predictions["instances"]
    pred_boxes = instances.get("pred_boxes")
    scores = instances.get("scores")
    pred_classes = instances.get("pred_classes")
    # build a human-readable summary and a JSON dump of the detections
    output_text = ""
    output_json = []
    for i in range(len(pred_boxes)):
        id = pred_classes[i].item()
        class_name = get_name_by_id(categories, id)
        score = scores[i].item()
        output_text += f"{class_name}: {score:.2%}\n"
        output_json.append({
            "class_name": class_name,
            "score": score,
            "box_coordinates": pred_boxes[i].tensor.tolist(),
        })
    output_json = json.dumps(output_json, indent=4)
    return visualized_output.get_image(), output_text, output_json
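# Gradio UI: image upload or URL input, a confidence threshold, and an input-resolution slider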
gui = gr.Interface(
    predict,
    inputs=[
        gr.Image(type='filepath', label="Input Image"),
        gr.Textbox(type='text', label="Input Image (URL) - not considered if an image was uploaded"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Confidence Threshold"),
        gr.Slider(minimum=0.3, maximum=1.0, step=0.01, value=1.0, label="Image Size (30-100%)"),
    ],
    outputs=[
        gr.Image(type='pil', label="Output Image"),
        gr.Textbox(type='text', label="Predictions"),
        gr.Textbox(type='text', label="Predictions (JSON)"),
    ],
    examples=[
        ["https://puam-loris.aws.princeton.edu/loris/INV33883.jp2/full/full/0/default.jpg", "", 0.05, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fimages%2Fodeuropa-homepage%2F15.jpg&w=1920&q=75", "", 0.2, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FGrayling%252520Thymallus%252520thymallus.JPG%26width%3D300%26height%3D300&w=384&q=75", "", 0.5, 0.5],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FCigarette%252520in%252520white%252520ashtray.jpg%26width%3D300%26height%3D300&w=384&q=75", "", 0.05, 0.3],
    ],
)
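# example (hypothetical local path) of calling the callback directly, without the UI:
#   predict('/home/user/app/example.jpg', '', 0.3, 1.0)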
if __name__ == "__main__":
    gui.launch(share=True)