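# Gradio demo for SCA (Segment and Caption Anything). "Prompt Mode" captions
# the region selected by a point or box prompt; "Anything Mode" generates and
# captions masks over the whole image automatically.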
import sys

sys.path.append(".")
sys.path.append("./scripts/notebooks")

import io
import logging
import os
import time

import dotenv
import gradio as gr
import hydra
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from transformers import set_seed

from amcg import ScaAutomaticMaskCaptionGenerator
from src.arguments import global_setup
from src.train import prepare_model, prepare_processor


logger = logging.getLogger(__name__)

# Populated by `main()` under Hydra and used by the module-level demo code below.
model = None
processor = None


@hydra.main(version_base="1.3", config_path="../../src/conf", config_name="conf")
def main(args: DictConfig) -> None:
    global model, processor

    logger.info(OmegaConf.to_yaml(args))
    args, training_args, model_args = global_setup(args)

    set_seed(args.training.seed)

    logger.info(f"Trying to load sas_key from .env file: {dotenv.load_dotenv('.env')}.")
    # May be a token string from the environment, or False if unset.
    use_auth_token = os.getenv("USE_AUTH_TOKEN", False)

    processor = prepare_processor(model_args, use_auth_token)
    model = prepare_model(model_args, use_auth_token)
    # No return value: Hydra discards it. The loaded objects are handed off
    # through the module-level globals.


if __name__ == "__main__":
    main()

cache_dir = ".model.cache"
device = "cuda" if torch.cuda.is_available() else "cpu"


clip_model = "openai/clip-vit-base-patch32"

img_url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg"
raw_image = Image.open(requests.get(img_url, stream=True).raw)

model = model.to(device)
dtype = model.dtype


NUM_OUTPUT_HEADS = 3
LIBRARIES = ["multimask_output"]
DEFAULT_LIBRARIES = ["multimask_output"]

auto_mask_caption_generator = ScaAutomaticMaskCaptionGenerator(model, processor)

|
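# Map a click on the input image to the point textbox, or to one corner of the
# box textbox, depending on the selected visual prompt mode.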
def click_and_assign(args, visual_prompt_mode, input_point_text, input_boxes_text, evt: gr.SelectData):
    x, y = evt.index
    if visual_prompt_mode == "point":
        input_point_text = f"{x},{y}"
    elif visual_prompt_mode == "box":
        if len(input_boxes_text.split(",")) == 2:
            # One corner is already stored; this click completes the box.
            input_boxes_text = f"{input_boxes_text},{x},{y}"
        else:
            # Start a new box with this click as the first corner.
            input_boxes_text = f"{x},{y}"
    return input_point_text, input_boxes_text

|
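# Run Prompt Mode with a single box prompt parsed from "x,y,x2,y2" text. The
# triple nesting matches the SAM-style processor: [image, box, coordinates].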
def box_and_run(input_image, args, input_boxes_text):
    x, y, x2, y2 = map(int, input_boxes_text.split(","))
    input_boxes = [[[x, y, x2, y2]]]
    return run(args, input_image, input_boxes=input_boxes)

|
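# Run Prompt Mode with a single point prompt parsed from "x,y" text. Points
# take one more nesting level than boxes: [image, prompt, point, coordinates].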
def point_and_run(input_image, args, input_point_text):
    x, y = map(int, input_point_text.split(","))
    input_points = [[[[x, y]]]]
    return run(args, input_image, input_points=input_points)

|
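# Overlay mask annotations on the current matplotlib axes, largest first so
# smaller masks stay visible.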
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    # Start from a fully transparent RGBA canvas the size of the image.
    img = np.ones((sorted_anns[0]["segmentation"].shape[0], sorted_anns[0]["segmentation"].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann["segmentation"]
        # Random color per mask at 35% opacity.
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
        if "caption" in ann:
            caption: str = ann["caption"]

            # Pick a random pixel inside the mask as the caption anchor.
            y, x = np.where(m)
            random_index = np.random.choice(range(len(x)))
            random_position = (x[random_index], y[random_index])

            ax.text(*random_position, caption, color="white", fontsize=12, ha="center", va="center")
    ax.imshow(img)

|
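# "Anything Mode": generate masks and captions for the whole image and return
# the matplotlib rendering as a PIL image.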
def auto_mode(input_image):
    np_input_image = np.array(input_image)
    outputs = auto_mask_caption_generator.generate(np_input_image)

    # Size the figure so the rendering matches the input resolution.
    dpi = 80
    height, width, _ = np_input_image.shape
    figsize = width / float(dpi), height / float(dpi)

    plt.figure(figsize=figsize)
    plt.imshow(input_image)
    show_anns(outputs)
    plt.axis("off")

    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close()  # Release the figure so repeated runs do not leak memory.
    buf.seek(0)
    img = Image.open(buf)

    return img

|
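# Core Prompt Mode path: encode the prompt, generate masks and captions in one
# call, and format one (image, [(mask, label)]) pair per output widget.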
def run(args, input_image, input_points=None, input_boxes=None):
    if input_points is None and input_boxes is None:
        raise ValueError("input_points and input_boxes cannot be both None")

    multimask_output = "multimask_output" in args

    inputs = processor(input_image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt")
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor):
            # Move tensors to the device, casting float32 inputs to the model dtype.
            inputs[k] = v.to(device, dtype if v.dtype == torch.float32 else v.dtype)
    tic = time.perf_counter()
    with torch.inference_mode():
        model_outputs = model.generate(**inputs, multimask_output=multimask_output, num_beams=3)
    toc = time.perf_counter()
    print(f"Time taken: {(toc - tic) * 1000:0.4f} ms")

    # sequences: [batch, masks, text heads, tokens]; pred_masks: [batch, masks, mask heads, ...]
    batch_size, num_masks, num_text_heads, num_tokens = model_outputs.sequences.shape
    batch_size, num_masks, num_mask_heads, *_ = model_outputs.pred_masks.shape
    if batch_size != 1 or num_masks != 1:
        raise ValueError("batch_size and num_masks must be 1")

    captions = processor.tokenizer.batch_decode(
        model_outputs.sequences.reshape(-1, num_tokens), skip_special_tokens=True
    )

    # If there are fewer caption heads than mask heads, repeat the last caption.
    if num_text_heads < num_mask_heads:
        captions += [captions[-1]] * (num_mask_heads - num_text_heads)

    masks = processor.post_process_masks(
        model_outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
    )
    iou_scores = model_outputs.iou_scores

    outputs = []

    # First row of widgets: one annotated image per mask head, labeled with its
    # caption and IoU score; pad unused slots with a blank image.
    iou_scores = iou_scores[0][0]
    for i in range(num_mask_heads):
        output = [input_image, [[masks[0][:, i].cpu().numpy(), f"{captions[i]}|iou:{iou_scores[i]:.4f}"]]]
        outputs.append(output)
    for i in range(num_mask_heads, NUM_OUTPUT_HEADS):
        output = [np.ones((1, 1)), []]
        outputs.append(output)

    # The second row of widgets is left blank in this mode.
    for i in range(NUM_OUTPUT_HEADS):
        output = [np.ones((1, 1)), []]
        outputs.append(output)
    return outputs

|
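# Placeholder click handler that returns the input image without annotations;
# currently not wired to any UI event.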
def fake_click_and_run(input_image, args, evt: gr.SelectData):
    outputs = []

    num_heads = 1
    for i in range(num_heads):
        output = [input_image, []]
        outputs.append(output)
    for i in range(num_heads, NUM_OUTPUT_HEADS):
        output = [input_image, []]
        outputs.append(output)
    return outputs

|
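# Build the two-tab Gradio UI and wire up the handlers defined above.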
with gr.Blocks() as demo:
    gr.Markdown("Welcome to the SCA Demo! We have two modes: **Prompt Mode** and **Anything Mode**.")

    input_image = gr.Image(value=raw_image, label="Input Image", interactive=True, type="pil", height=500)

    with gr.Tab("Prompt Mode"):
        visual_prompt_mode = gr.Radio(choices=["point", "box"], value="point", label="Visual Prompt Mode")
        args = gr.CheckboxGroup(choices=LIBRARIES, value=DEFAULT_LIBRARIES, label="SCA Arguments")
        input_point_text = gr.Textbox(lines=1, label="Input Points (x,y)", value="0,0")
        input_point_button = gr.Button(value="Run with Input Points")
        input_boxes_text = gr.Textbox(lines=1, label="Input Boxes (x,y,x2,y2)", value="0,0,100,100")
        input_boxes_button = gr.Button(value="Run with Input Boxes")

        # Two rows of output widgets; `run()` returns one entry per widget.
        output_images = []
        with gr.Row():
            for i in range(NUM_OUTPUT_HEADS):
                output_images.append(gr.AnnotatedImage(label=f"Output Image {i}", height=500))
        with gr.Row():
            for i in range(NUM_OUTPUT_HEADS):
                output_images.append(gr.AnnotatedImage(label=f"Output Image {i + NUM_OUTPUT_HEADS}", height=500))
    with gr.Tab("Anything Mode"):
        run_anything_mode_button = gr.Button(value="Run Anything Mode")
        output_image_for_anything_mode = gr.Image(
            value=raw_image, label="Output Image", interactive=False, type="pil", height=500
        )

    input_image.select(
        click_and_assign,
        [args, visual_prompt_mode, input_point_text, input_boxes_text],
        [input_point_text, input_boxes_text],
    )
    input_point_button.click(point_and_run, [input_image, args, input_point_text], [*output_images])
    input_boxes_button.click(box_and_run, [input_image, args, input_boxes_text], [*output_images])

    run_anything_mode_button.click(auto_mode, [input_image], [output_image_for_anything_mode])

demo.launch()
|
|