desco / app.py
zdou0830's picture
clean
971be17
raw
history blame
2.59 kB
# Reference: https://huggingface.co/spaces/haotiz/glip-zeroshot-demo/blob/main/app.py
import requests
import os
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import gradio as gr
import warnings
warnings.filterwarnings("ignore")
# Build/install the local package in develop mode before importing from it
# (presumably this is what makes the maskrcnn_benchmark imports below
# resolvable in a fresh environment -- confirm against the deployment setup).
os.system("python setup.py build develop --user")

from copy import deepcopy

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.engine.predictor_glip import GLIPDemo


def _build_demo(node, config_path, weights_path):
    """Merge a model config and checkpoint path into *node* and build a demo.

    *node* is modified in place: the YAML file at *config_path* is merged
    first, then the checkpoint path and the device are overridden.
    """
    node.merge_from_file(config_path)
    node.merge_from_list(["MODEL.WEIGHT", weights_path])
    node.merge_from_list(["MODEL.DEVICE", "cuda"])
    return GLIPDemo(
        node,
        min_image_size=800,
        confidence_threshold=0.7,
        show_mask_heatmaps=False,
    )


# --- DesCo-GLIP -----------------------------------------------------------
config_file = "configs/pretrain_new/desco_glip.yaml"
weight_file = "MODEL/desco_glip_tiny.pth"

# Manual overrides applied to the shared config before the file is merged.
cfg.local_rank = 0
cfg.num_gpus = 1

glip_demo = _build_demo(cfg, config_file, weight_file)

# --- DesCo-FIBER ----------------------------------------------------------
config_file = "configs/pretrain_new/desco_fiber.yaml"
weight_file = "MODEL/desco_fiber_base.pth"

# Continue on a deep copy: the FIBER settings are layered on top of the
# GLIP-merged config without touching the object handed to glip_demo.
cfg = deepcopy(cfg)

fiber_demo = _build_demo(cfg, config_file, weight_file)
def predict(image, text, ground_tokens=""):
    """Run both DesCo models on one image and return their two results.

    Args:
        image: Image array indexed as ``image[:, :, [2, 1, 0]]``; the order of
            the last axis is reversed before inference (presumably an
            RGB -> BGR swap -- confirm against the predictor).
        text: Text prompt forwarded unchanged to both models.
        ground_tokens: Optional ``;``-separated string of tokens. A blank
            (or ``None``) value means "no specified tokens" and ``None`` is
            forwarded to the models instead of a list.

    Returns:
        A ``(glip_result, fiber_result)`` tuple; each result has its last
        axis reversed again before being returned.
    """
    # The original parsed the tokens into ``ground_tokens`` but passed an
    # undefined ``specified_tokens`` to the models, raising NameError on every
    # request. Bind the parsed list to the name that is actually forwarded.
    cleaned = (ground_tokens or "").strip()
    specified_tokens = cleaned.split(";") if cleaned else None

    # Each model gets its own copy of the channel-swapped image so neither
    # call can observe changes made by the other.
    result, _ = glip_demo.run_on_web_image(
        deepcopy(image[:, :, [2, 1, 0]]), text, 0.5, specified_tokens
    )
    fiber_result, _ = fiber_demo.run_on_web_image(
        deepcopy(image[:, :, [2, 1, 0]]), text, 0.5, specified_tokens
    )
    return result[:, :, [2, 1, 0]], fiber_result[:, :, [2, 1, 0]]
image = gr.inputs.Image()

# One output image per model, in the order predict() returns them.
demo_outputs = [
    gr.outputs.Image(type="pil", label="DesCo-GLIP"),
    gr.outputs.Image(type="pil", label="DesCo-FIBER"),
]

# Each example row is [image path, text prompt, ";"-separated ground tokens].
demo_examples = [
    ["./coco_000000281759.jpg", "A green umbrella. A pink striped umbrella. A plain white umbrella.", ""],
    ["./coco_000000281759.jpg", "a flowery top. A blue dress. An orange shirt .", ""],
    ["./coco_000000281759.jpg", "a car . An electricity box .", ""],
    ["./1.jpg", "a train besides sidewalk", "train;sidewalk"],
]

# Article text shown with the demo is read from the repository docs.
demo_article = Path("docs/intro.md").read_text()

demo = gr.Interface(
    description="Object Recognition with DesCo (https://github.com/liunian-harold-li/DesCo)",
    fn=predict,
    inputs=["image", "text", "text"],
    outputs=demo_outputs,
    examples=demo_examples,
    article=demo_article,
)

demo.launch()
# Alternative launch with an explicit host/port and a share link:
# demo.launch(server_name="0.0.0.0", server_port=7000, share=True)