# SoM / app.py (author: SkalskiP)
# Commit 017684f: deleted 'gpt4v.py' and moved its functionality to 'utils.py' and 'app.py'.
import os
from typing import List, Dict, Tuple, Any, Optional
import cv2
import gradio as gr
import numpy as np
import som
import supervision as sv
import torch
from segment_anything import sam_model_registry
from sam_utils import sam_interactive_inference, sam_inference
from utils import postprocess_masks, Visualizer
# Base directory for model weights. Fall back to the user's home directory
# when $HOME is unset, since os.path.join(None, ...) would raise TypeError.
HOME = os.getenv("HOME") or os.path.expanduser("~")
# Run SAM on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Path to the SAM ViT-H checkpoint. Can be overridden through the
# SAM_CHECKPOINT environment variable, e.g. "weights/sam_vit_h_4b8939.pth"
# for a local run outside the hosted app directory.
SAM_CHECKPOINT = os.getenv(
    "SAM_CHECKPOINT",
    os.path.join(HOME, "app/weights/sam_vit_h_4b8939.pth"))
# Key into segment_anything's sam_model_registry; must match the checkpoint's
# architecture (ViT-H).
SAM_MODEL_TYPE = "vit_h"
# Keys of the session-state dict written by `inference` and read by
# `prompt` and `highlight`.
ANNOTATED_IMAGE_KEY = "annotated_image"
DETECTIONS_KEY = "detections"
# Page header rendered by gr.Markdown: title, reference links and roadmap.
MARKDOWN = """
<div align='center'>
<h1>
<img
src='https://som-gpt4v.github.io/website/img/som_logo.png'
style='height:50px; display:inline-block'
/>
Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V
</h1>
<br>
[<a href="https://arxiv.org/abs/2310.11441"> arXiv paper </a>]
[<a href="https://som-gpt4v.github.io"> project page </a>]
[<a href="https://github.com/roboflow/set-of-mark"> python package </a>]
[<a href="https://github.com/microsoft/SoM"> code </a>]
</div>
## 🚧 Roadmap
- [ ] Support for alphabetic labels
- [ ] Support for Semantic-SAM (multi-level)
- [ ] Support for mask filtering based on granularity
"""
# Build the SAM model selected by SAM_MODEL_TYPE from the checkpoint on disk,
# then move its weights to DEVICE (Module.to returns the same module).
SAM = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
SAM = SAM.to(device=DEVICE)
def inference(
    image_and_mask: Dict[str, np.ndarray],
    annotation_mode: List[str],
    mask_alpha: float
) -> Tuple[Tuple[np.ndarray, List[Tuple[np.ndarray, str]]], Dict[str, Any]]:
    """Segment the input image with SAM and render the Set-of-Mark prompt.

    Args:
        image_and_mask: Value of the sketch-enabled ``gr.Image``. The
            ``'image'`` entry is the RGB input and ``'mask'`` holds the
            user's brush strokes.
        annotation_mode: Selected subset of "Mark", "Polygon", "Mask", "Box".
        mask_alpha: Mask opacity forwarded to ``Visualizer``.

    Returns:
        ``((annotated_image, []), state)``. The first item is the RGB
        annotated image with an empty annotation list (for
        ``gr.AnnotatedImage``). ``state`` maps ANNOTATED_IMAGE_KEY and
        DETECTIONS_KEY to the annotated image and the detections.
    """
    rgb_image = image_and_mask['image']
    brush = cv2.cvtColor(image_and_mask['mask'], cv2.COLOR_RGB2GRAY)
    # Any non-zero brush pixel switches from automatic to interactive mode.
    user_drew = np.any(brush != 0)
    visualizer = Visualizer(mask_opacity=mask_alpha)

    if user_drew:
        # Segment only the region the user sketched.
        detections = sam_interactive_inference(
            image=rgb_image, mask=brush, model=SAM)
    else:
        # Segment the whole image automatically, then post-process the masks.
        # NOTE(review): post-processing runs on this path only; confirm this
        # is also intended for interactive detections.
        detections = postprocess_masks(
            detections=sam_inference(image=rgb_image, model=SAM))

    draw_flags = {
        "with_box": "Box" in annotation_mode,
        "with_mask": "Mask" in annotation_mode,
        "with_polygon": "Polygon" in annotation_mode,
        "with_label": "Mark" in annotation_mode,
    }
    # The visualizer is fed a BGR image; its output is converted back to RGB.
    annotated_bgr = visualizer.visualize(
        image=cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR),
        detections=detections,
        **draw_flags)
    annotated_image = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)

    state = {ANNOTATED_IMAGE_KEY: annotated_image, DETECTIONS_KEY: detections}
    return (annotated_image, []), state
def prompt(
    message: str,
    history: List[List[str]],
    state: Dict[str, Any],
    api_key: Optional[str]
) -> str:
    """Send the user's message together with the SoM image to GPT-4V.

    Used as the ``fn`` of ``gr.ChatInterface``. The chat interface always
    passes ``history``, but this handler does not use it.

    Args:
        message: The user's chat message, forwarded as the prompt.
        history: Chat history supplied by ``gr.ChatInterface`` (unused).
        state: Session state produced by ``inference``; may be ``None``.
        api_key: OpenAI API key from the password textbox; may be ``None``.

    Returns:
        The text returned by ``som.prompt_image``, or a warning string when
        the API key is missing/blank or no SoM visual prompt exists yet.
    """
    # None and whitespace-only input are treated the same as an empty key,
    # so no request is made with an unusable key.
    if not api_key or not api_key.strip():
        return "⚠️ Please set your OpenAI API key first"
    if state is None or ANNOTATED_IMAGE_KEY not in state:
        return "⚠️ Please generate SoM visual prompt first"
    # The stored annotated image is RGB. Swap it to BGR before handing it
    # over; presumably som.prompt_image encodes with OpenCV (confirm).
    return som.prompt_image(
        api_key=api_key.strip(),
        image=cv2.cvtColor(state[ANNOTATED_IMAGE_KEY], cv2.COLOR_RGB2BGR),
        prompt=message
    )
def on_image_input_clear() -> Tuple[None, Dict[str, Any]]:
    """Reset the outputs after the input image has been cleared.

    Returns:
        ``None`` for the annotated-image output and a new, empty dict for the
        session state.
    """
    cleared_image = None
    empty_state: Dict[str, Any] = dict()
    return cleared_image, empty_state
def highlight(
    state: Dict[str, Any],
    history: List[List[str]]
) -> Optional[Tuple[np.ndarray, List[Tuple[np.ndarray, str]]]]:
    """Overlay the masks referenced in the latest chatbot reply.

    Args:
        state: Session state produced by ``inference``. May be ``None`` or
            empty before any marks were generated.
        history: Chatbot history; the last element of the last entry is
            taken as the latest reply.

    Returns:
        One of three values:

        - ``(annotated_image, [(mask, mark), ...])`` for ``gr.AnnotatedImage``.
        - ``(annotated_image, [])`` when the latest reply is empty.
        - ``None`` when there is no chat history or no generated SoM state.
    """
    if not history:
        return None
    # Guard against a missing state as well as one without SoM results.
    if not state or DETECTIONS_KEY not in state or ANNOTATED_IMAGE_KEY not in state:
        return None

    detections: sv.Detections = state[DETECTIONS_KEY]
    annotated_image: np.ndarray = state[ANNOTATED_IMAGE_KEY]

    latest_reply = history[-1][-1]
    if not latest_reply:
        # There is no reply text yet (e.g. a None placeholder). Show the
        # marks without highlights instead of passing None to the extractor.
        return annotated_image, []

    relevant_masks = som.extract_relevant_masks(
        text=latest_reply,
        detections=detections
    )
    # The helper returns a mark -> mask mapping; gr.AnnotatedImage takes
    # (mask, label) pairs.
    return annotated_image, [
        (mask, mark)
        for mark, mask in relevant_masks.items()
    ]
# UI components. They are created up front and placed into the page layout
# later via .render() inside the gr.Blocks context.

# Input image with a sketch tool. Any strokes drawn here become the mask
# that switches `inference` to interactive segmentation.
image_input = gr.Image(
    label="Input",
    type="numpy",
    interactive=True,
    tool="sketch",
    brush_color="#FFFFFF",
    brush_radius=20.0,
    height=512
)
# Which annotation layers to draw on the SoM image.
checkbox_annotation_mode = gr.CheckboxGroup(
    label="Annotation Mode",
    choices=["Mark", "Polygon", "Mask", "Box"],
    value=['Mark'])
# Opacity of the drawn masks, in [0, 1].
slider_mask_alpha = gr.Slider(
    label="Mask Alpha",
    minimum=0,
    maximum=1,
    value=0.05)
# One colour per label "0".."63", taken from supervision's default palette.
_default_palette = sv.ColorPalette.default()
image_output = gr.AnnotatedImage(
    label="SoM Visual Prompt",
    height=512,
    color_map={
        str(idx): _default_palette.by_idx(idx).as_hex()
        for idx in range(64)
    }
)
# Masked textbox for the user's OpenAI API key.
openai_api_key = gr.Textbox(
    type="password",
    placeholder="Before you start chatting, set your OpenAI API key here",
    show_label=False,
    lines=1)
chatbot = gr.Chatbot(
    height=256,
    label="GPT-4V + SoM")
generate_button = gr.Button("Generate Marks")
highlight_button = gr.Button("Highlight Marks")
# Assemble the page. The top row holds the input column (image plus prompt
# settings) and the output column (SoM image plus action buttons). The
# API-key box and the GPT-4V chat follow, then the event handlers are wired.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    # Session state: written by `inference`, read by `prompt` and `highlight`.
    inference_state = gr.State({})
    with gr.Row():
        with gr.Column():
            image_input.render()
            with gr.Accordion(
                    label="Detailed prompt settings (e.g., mark type)",
                    open=False):
                with gr.Row():
                    checkbox_annotation_mode.render()
                with gr.Row():
                    slider_mask_alpha.render()
        with gr.Column():
            image_output.render()
            generate_button.render()
            highlight_button.render()
    with gr.Row():
        openai_api_key.render()
    with gr.Row():
        # The state and API key are passed to `prompt` as its extra
        # arguments, after (message, history).
        gr.ChatInterface(
            chatbot=chatbot,
            fn=prompt,
            additional_inputs=[inference_state, openai_api_key])

    # "Generate Marks": run SAM and render the SoM image, storing results in state.
    generate_button.click(
        fn=inference,
        inputs=[image_input, checkbox_annotation_mode, slider_mask_alpha],
        outputs=[image_output, inference_state])
    # Clearing the input image also resets the output image and the state.
    image_input.clear(
        fn=on_image_input_clear,
        outputs=[image_output, inference_state]
    )
    # "Highlight Marks": the chatbot value (history) is passed so `highlight`
    # can look up masks for the latest reply.
    highlight_button.click(
        fn=highlight,
        inputs=[inference_state, chatbot],
        outputs=[image_output])

# Enable the request queue and start the server; show_error surfaces handler
# exceptions in the UI.
demo.queue().launch(debug=False, show_error=True)