#!/usr/bin/env python

from __future__ import annotations

import argparse
import os
import pathlib
import subprocess
import tarfile

# On Hugging Face Spaces the preinstalled packages are replaced with pinned
# versions before anything imports them (mmcv-full / opencv must match).
if os.getenv('SYSTEM') == 'spaces':
    import mim

    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.5.2', is_yes=True)

    subprocess.call('pip uninstall -y opencv-python'.split())
    subprocess.call('pip uninstall -y opencv-python-headless'.split())
    subprocess.call('pip install opencv-python-headless==4.5.5.64'.split())

import cv2
import gradio as gr
import numpy as np

from mmdet.apis import init_detector, inference_detector
from utils import show_result

import mmcv
from mmcv import Config
import os.path as osp

DESCRIPTION = '''# OpenPSG

This is an official demo for [OpenPSG](https://github.com/Jingkang50/OpenPSG).

overview
'''
FOOTER = 'visitor badge'


def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options (device, theme, sharing, port, queue)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    return parser.parse_args()


def update_input_image(image: np.ndarray | None) -> dict:
    """Downscale a newly uploaded image so its longer side is at most 1500 px.

    Images already within the limit are passed through unchanged; a cleared
    input (``None``) clears the component.
    """
    if image is None:
        return gr.Image.update(value=None)
    scale = 1500 / max(image.shape[:2])
    if scale < 1:
        image = cv2.resize(image, None, fx=scale, fy=scale)
    return gr.Image.update(value=image)


def set_example_image(example: list) -> dict:
    """Load the clicked example (its first sample field is the image path)."""
    return gr.Image.update(value=example[0])


def infer(model, input_image, num_rel):
    """Run the detector on ``input_image`` and render up to ``num_rel`` relations.

    Args:
        model: detector returned by ``init_detector``.
        input_image: image array from the ``gr.Image(type='numpy')`` input.
        num_rel: number of relations to visualise; coerced to ``int`` because
            a Gradio slider may deliver the value as a float.

    Returns:
        Whatever ``utils.show_result`` returns for display in the result gallery.
    """
    result = inference_detector(model, input_image)
    return show_result(input_image,
                       result,
                       is_one_stage=True,
                       num_rel=int(num_rel),
                       show=True)


def main():
    """Load the PSGTR model and launch the Gradio demo UI."""
    args = parse_args()

    model_ckt = 'OpenPSG/checkpoints/epoch_60.pth'
    cfg = Config.fromfile('OpenPSG/configs/psgtr/psgtr_r50_psg_inference.py')
    model = init_detector(cfg, model_ckt, device=args.device)

    def run_inference(input_image, num_rel):
        # Gradio event inputs must be UI components, so the (non-component)
        # model is bound here instead of being listed in ``inputs``.
        return infer(model, input_image, num_rel)

    with gr.Blocks(theme=args.theme, css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Group():
                    with gr.Row():
                        num_rel = gr.Slider(5,
                                            100,
                                            step=5,
                                            value=20,
                                            label='Number of Relations')
                    with gr.Row():
                        run_button = gr.Button(value='Run')
            with gr.Column():
                with gr.Row():
                    result = gr.Gallery(label='Result', type='numpy')

        with gr.Row():
            paths = sorted(pathlib.Path('images').rglob('*.jpg'))
            example_images = gr.Dataset(components=[input_image],
                                        samples=[[path.as_posix()]
                                                 for path in paths])

        gr.Markdown(FOOTER)

        input_image.change(fn=update_input_image,
                           inputs=input_image,
                           outputs=input_image)

        # The slider value is forwarded too; previously it was never passed.
        run_button.click(fn=run_inference,
                         inputs=[input_image, num_rel],
                         outputs=result)

        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=input_image)

    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()