# FAU_PR_Detrex / app.py
# environment setup
import os
os.system("pip install torch torchvision")
os.system("git clone https://github.com/IDEA-Research/detrex.git")
os.system("python3.10 -m pip install git+https://github.com/facebookresearch/detectron2.git@v0.6#egg=detectron2")
os.system("python3.10 -m pip install git+https://github.com/IDEA-Research/detrex.git@v0.5.0#egg=detrex")
os.system("git submodule sync")
os.system("git submodule update --init")
os.system("pip install Pillow==9.5.0")
os.system("pip install fairscale")
os.system("pip install opencv-python")
os.system("cp -rf '/home/user/app/utils/data' '/usr/local/lib/python3.10/site-packages/detrex/config/configs/common/'")
# import libs
import cv2
import json
import numpy as np
import gradio as gr
import warnings
warnings.filterwarnings("ignore")
# adapt files for cpu usage:
# keep only the first 406 lines of multi_scale_deform_attn.py, dropping the
# trailing part that is not needed for CPU-only inference
attn_path = "/usr/local/lib/python3.10/site-packages/detrex/layers/multi_scale_deform_attn.py"
with open(attn_path, "r") as f:
    lines = f.readlines()
with open(attn_path, "w") as f:
    f.writelines(lines[:406])
# external lib functions
from detectron2.config import LazyConfig, instantiate
from detectron2.checkpoint import DetectionCheckpointer
from demo.demo import VisualizationDemo
from detectron2.data.detection_utils import read_image
# custom lib functions, data, annotations etc.
config_file = os.getcwd() + '/projects/dino/configs/odor3_fn_l_lrf_384_fl4_5scale_50ep.py'
ckpt_pth = os.getcwd() + '/utils/focaldino_ep18.pth'
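# (the names suggest a DINO detector with a FocalNet-Large backbone, with
# focaldino_ep18.pth holding the fine-tuned weights)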
# load model/demo
try:
    cfg = LazyConfig.load(config_file)
except AssertionError as e:
    # tolerate only assertion errors about an unregistered dataset;
    # anything else is re-raised
    if str(e).startswith('Dataset '):
        pass
    else:
        raise
model = instantiate(cfg.model)
model.to(cfg.train.device)
checkpointer = DetectionCheckpointer(model)
checkpointer.load(ckpt_pth)
model.eval()
# demo wraps the model for inference and visualization
demo = VisualizationDemo(
    model=model,
    min_size_test=800,
    max_size_test=1333,
    img_format='RGB',
    metadata_dataset='odor_test')
# read the category list from a COCO-style annotation file
def read_json_categories(jsonFile):
    categories_dict = {}
    with open(jsonFile, 'r') as file:
        data = json.load(file)
        if 'categories' in data:
            categories_dict = data['categories']
    return categories_dict
# expand single-channel grayscale images to three channels
def treat_grayscale(img):
    if len(img.shape) == 2:
        return np.stack((img,) * 3, axis=-1)
    else:
        return img
# look up a category name by its id
def get_name_by_id(categories, id):
    for cg in categories:
        if cg['id'] == id:
            return cg['name']
    return 'Unknown'
# rescale an image to the given fraction of its original size
def set_image_resolution(img, percentage):
    height, width = img.shape[:2]
    new_height = int(height * percentage)
    new_width = int(width * percentage)
    resized_img = cv2.resize(img, (new_width, new_height))
    return resized_img
def predict(link, url, threshold, image_resolution):
    categories = read_json_categories(os.getcwd() + '/annotations/instances_train2017.json')
    # prefer the uploaded image; fall back to the URL input
    if link:
        img = read_image(link)
    else:
        img = read_image(url)
    img_resized = set_image_resolution(img, image_resolution)
    img = treat_grayscale(img_resized)
    # reverse the channel order (RGB <-> BGR) before inference
    img = img[:, :, ::-1]
    predictions, visualized_output = demo.run_on_image(img, threshold)
    instances = predictions["instances"]
    pred_boxes = instances.get("pred_boxes")
    scores = instances.get("scores")
    pred_classes = instances.get("pred_classes")
    # plain-text summary: one "class: score" line per detection
    output_text = ""
    for i in range(len(pred_boxes)):
        id = pred_classes[i].item()
        class_name = get_name_by_id(categories, id)
        score = scores[i].item()
        output_text += f"{class_name}: {score:.2%}\n"
    # JSON output that additionally includes the box coordinates
    output_json = []
    for i in range(len(pred_boxes)):
        id = pred_classes[i].item()
        class_name = get_name_by_id(categories, id)
        score = scores[i].item()
        box_coords = pred_boxes[i].tensor.tolist()
        output_json.append({
            "class_name": class_name,
            "score": score,
            "box_coordinates": box_coords
        })
    output_json = json.dumps(output_json, indent=4)
    return visualized_output.get_image(), output_text, output_json
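# gradio interface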
gui = gr.Interface(
    predict,
    inputs=[
        gr.Image(type='filepath', label="Input Image"),
        gr.Textbox(type='text', label="Input Image (URL) - not considered if an image was uploaded"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.05, label="Confidence Threshold"),
        gr.Slider(minimum=0.3, maximum=1.0, step=0.01, value=1.0, label="Image Size (30-100%)")
    ],
    outputs=[
        gr.Image(type='pil', label="Output Image"),
        gr.Textbox(type='text', label="Predictions"),
        gr.Textbox(type='text', label="Predictions (JSON)")
    ],
    examples=[
        ["https://puam-loris.aws.princeton.edu/loris/INV33883.jp2/full/full/0/default.jpg", "", 0.05, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fimages%2Fodeuropa-homepage%2F15.jpg&w=1920&q=75", "", 0.2, 1],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FGrayling%252520Thymallus%252520thymallus.JPG%26width%3D300%26height%3D300&w=384&q=75", "", 0.5, 0.5],
        ["https://explorer.odeuropa.eu/_next/image?url=%2Fapi%2Fmedia%3Furl%3Dhttps%253A%252F%252Fcommons.wikimedia.org%252Fwiki%252FSpecial%253AFilePath%252FCigarette%252520in%252520white%252520ashtray.jpg%26width%3D300%26height%3D300&w=384&q=75", "", 0.05, 0.3]
    ],
)
if __name__ == "__main__":
    gui.launch(share=True)