# OCTO / app.py
# adityas: add demo app code (commit 9798f42)
import os
# Runtime bootstrap: make sure the checkpoint directory exists, then run the
# shell commands that upgrade pip, download the sample image and the three
# model checkpoints, and install the Python packages imported further down.
if not os.path.isdir("weights"):
    os.mkdir("weights")

# Commands are executed in order; exit codes are not inspected.
_SETUP_COMMANDS = (
    "python -m pip install --upgrade pip",
    "wget https://raw.githubusercontent.com/asharma381/cs291I/main/backend/original_images/000749.png",
    "wget -q -O weights/sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "wget -q -O weights/ram_plus_swin_large_14m.pth https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth",
    "wget -q -O weights/groundingdino_swint_ogc.pth https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
    "pip install git+https://github.com/xinyu1205/recognize-anything.git",
    "pip install git+https://github.com/IDEA-Research/GroundingDINO.git",
    "pip install git+https://github.com/facebookresearch/segment-anything.git",
    "pip install openai==0.27.4",
    "pip install tenacity",
)
for _command in _SETUP_COMMANDS:
    os.system(_command)
from typing import List, Tuple

import cv2
import gradio as gr
import groundingdino.config.GroundingDINO_SwinT_OGC
import numpy as np
import openai
import torch
from groundingdino.util.inference import Model
from PIL import Image, ImageDraw
from ram import get_transform
from ram import inference_ram as inference
from ram.models import ram_plus
from scipy.spatial.distance import cdist
from segment_anything import SamPredictor, sam_model_registry
from supervision import Detections
from tenacity import retry, stop_after_attempt, wait_fixed
# Model handles start empty; get_tags_ram, get_gdino_result, get_sam_model and
# get_location_gsam fill them in on first use and reuse them afterwards.
ram_model = gdino_model = sam_model = sam_predictor = None

# Multiplier currently applied to the RAM++ class thresholds (1 = untouched).
ram_threshold_multiplier = 1

# Torch device string shared by every model in this file.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("CUDA Available:", torch.cuda.is_available())
def get_tags_ram(
    image: Image.Image, threshold_multiplier=0.8, weights_folder="weights"
) -> List[str]:
    """Return the tags the RAM++ model assigns to ``image``.

    The model is loaded from ``weights_folder`` on the first call and cached
    in the module-level ``ram_model``. Its per-class thresholds are rescaled
    so that they equal the original thresholds times ``threshold_multiplier``.

    Args:
        image: Image to tag.
        threshold_multiplier: Positive factor applied to the model's class
            thresholds.
        weights_folder: Directory containing ``ram_plus_swin_large_14m.pth``.

    Returns:
        The non-empty tags, stripped of surrounding whitespace.

    Raises:
        ValueError: If ``threshold_multiplier`` is not positive.
    """
    global ram_model, ram_threshold_multiplier

    # A zero multiplier would wipe the thresholds and make the next rescale
    # divide by zero; a negative one is meaningless.
    if threshold_multiplier <= 0:
        raise ValueError("threshold_multiplier must be positive")

    if ram_model is None:
        print("Loading RAM++ Model...")
        ram_model = ram_plus(
            pretrained=f"{weights_folder}/ram_plus_swin_large_14m.pth",
            vit="swin_l",
            image_size=384,
        )
        ram_model.eval()
        ram_model = ram_model.to(device)

    # Undo the previously applied multiplier and apply the requested one.
    ram_model.class_threshold *= threshold_multiplier / ram_threshold_multiplier
    ram_threshold_multiplier = threshold_multiplier

    transform = get_transform()
    batch = transform(image).unsqueeze(0).to(device)
    res = inference(batch, ram_model)

    # res[0] is a "|"-separated tag string; when it is empty, split() yields
    # [""], and an empty tag would match every phrase downstream, so drop it.
    stripped = (piece.strip() for piece in res[0].split("|"))
    return [tag for tag in stripped if tag]
def get_gdino_result(
    image: Image.Image,
    classes: List[str],
    box_threshold: float = 0.25,
    weights_folder="weights",
) -> Tuple[Detections, List[str]]:
    """Run GroundingDINO on ``image`` with ``classes`` as the text prompt.

    The model is built on the first call from
    ``{weights_folder}/groundingdino_swint_ogc.pth`` and cached in the
    module-level ``gdino_model``.

    Args:
        image: Image to search.
        classes: Phrases to look for; they are joined with ", " into a single
            caption.
        box_threshold: Box confidence threshold passed to the model.
        weights_folder: Directory containing the GroundingDINO checkpoint.

    Returns:
        The ``(detections, phrases)`` pair returned by
        ``Model.predict_with_caption``.
    """
    global gdino_model
    if gdino_model is None:
        print("Loading GroundingDINO Model...")
        config_path = groundingdino.config.GroundingDINO_SwinT_OGC.__file__
        gdino_model = Model(
            model_config_path=config_path,
            model_checkpoint_path=f"{weights_folder}/groundingdino_swint_ogc.pth",
            device=device,
        )

    # GroundingDINO's Model.predict_with_caption takes an OpenCV-style BGR
    # array (it converts BGR->RGB internally), so a PIL image must be
    # channel-swapped first or the network sees red and blue exchanged.
    image_bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

    return gdino_model.predict_with_caption(
        image=image_bgr,
        caption=", ".join(classes),
        box_threshold=box_threshold,
        text_threshold=0.25,
    )
def get_sam_model(weights_folder="weights"):
    """Return the cached SAM ViT-H model, building it on first use.

    The checkpoint is read from ``{weights_folder}/sam_vit_h_4b8939.pth`` and
    the model is moved to the module-level ``device``.
    """
    global sam_model
    if sam_model is not None:
        return sam_model

    build_vit_h = sam_model_registry["vit_h"]
    sam_model = build_vit_h(checkpoint=f"{weights_folder}/sam_vit_h_4b8939.pth")
    sam_model.to(device=device)
    return sam_model
def filter_tags_gdino(image: Image.Image, tags: List[str]) -> List[str]:
    """Keep the tags that GroundingDINO can ground in ``image``.

    A tag survives when it occurs as a substring of at least one detected
    phrase whose box covers less than 90% of the image area. The order of
    ``tags`` is preserved.
    """
    detections, phrases = get_gdino_result(image, tags)

    width, height = image.size
    area_limit = 0.9 * width * height

    # Phrases belonging to boxes that do not span (almost) the whole image.
    usable_phrases = [
        phrase
        for phrase, area in zip(phrases, detections.area)
        if area < area_limit
    ]
    return [
        tag for tag in tags if any(tag in phrase for phrase in usable_phrases)
    ]
def read_file_to_string(file_path: str) -> str:
    """Return the UTF-8 text of ``file_path``, or "" if it cannot be read.

    Failures are reported with ``print`` rather than raised.
    """
    try:
        with open(file_path, "r", encoding="utf8") as handle:
            return handle.read()
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except Exception as err:
        print(f"An error occurred while reading {file_path}: {err}")
    # Only reached when one of the handlers above ran.
    return ""
# Retry with a fixed 2-second pause, but give up after 5 attempts: without a
# stop condition a permanent failure (for example an invalid API key) would be
# retried forever and the request would never return. ``reraise=True`` surfaces
# the last underlying error instead of tenacity's RetryError wrapper.
@retry(wait=wait_fixed(2), stop=stop_after_attempt(5), reraise=True)
def completion_with_backoff(**kwargs):
    """Call ``openai.ChatCompletion.create(**kwargs)`` with bounded retries.

    Returns:
        The response object returned by the OpenAI client.

    Raises:
        Exception: The error raised by the final attempt when all attempts
            fail.
    """
    return openai.ChatCompletion.create(**kwargs)
def gpt4(
    usr_prompt: str, sys_prompt: str = "", api_key: str = "", model: str = "gpt-4"
) -> str:
    """Send one system + user message pair to the chat API and return the reply.

    Sets ``openai.api_key`` to ``api_key`` before issuing the request through
    ``completion_with_backoff``.

    Returns:
        The message content of the first choice in the response.
    """
    openai.api_key = api_key

    request = dict(
        model=model,
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": usr_prompt},
        ],
        temperature=0.2,
        max_tokens=1000,
        frequency_penalty=0.0,
    )
    reply = completion_with_backoff(**request)

    first_choice = reply["choices"][0]
    return first_choice["message"]["content"]
def select_best_tag(
    filtered_tags: List[str], object_to_place: str, api_key: str = ""
) -> str:
    """Ask the chat model which of ``filtered_tags`` suits ``object_to_place``.

    The user prompt is the text of ``user_template.txt`` with its ``{object}``
    field filled in, followed by the tags one per line; the system prompt is
    the text of ``system_template.txt``.

    Returns:
        The model's reply, unmodified.
    """
    raw_user_template = read_file_to_string("user_template.txt")
    filled_template = raw_user_template.format(object=object_to_place)
    tag_lines = "\n".join(filtered_tags)
    user_prompt = f"{filled_template}{tag_lines}"

    system_prompt = read_file_to_string("system_template.txt")
    return gpt4(user_prompt, system_prompt, api_key=api_key)
def get_location_gsam(
    image: Image.Image, prompt: str, weights_folder="weights"
) -> Tuple[int, int]:
    """Find the pixel lying deepest inside the region matching ``prompt``.

    GroundingDINO proposes boxes for ``prompt`` (the box threshold starts at
    0.25 and is lowered in steps of 0.02 until something is detected), SAM
    segments each box, and the union of the best masks is searched for the
    mask pixel farthest from any non-mask pixel. The search runs on a copy of
    the mask downsampled by ``RESIZE_RATIO``; the image border counts as
    non-mask.

    Args:
        image: Image to search.
        prompt: Phrase describing the target region.
        weights_folder: Directory containing the GroundingDINO and SAM
            checkpoints.

    Returns:
        ``(x, y)`` pixel coordinates in ``image``.

    Raises:
        ValueError: If nothing is detected even with a non-positive box
            threshold.
    """
    global sam_predictor
    RESIZE_RATIO = 3
    THRESHOLD_STEP = 0.02
    box_threshold = 0.25

    def detect(threshold):
        # Forward weights_folder so both models load from the same directory.
        found, _ = get_gdino_result(
            image=image,
            classes=[prompt],
            box_threshold=threshold,
            weights_folder=weights_folder,
        )
        return found

    detections = detect(box_threshold)
    while len(detections.xyxy) == 0:
        # Guarantees termination: below zero there is nothing left to relax.
        if box_threshold <= 0:
            raise ValueError(f"No region matching {prompt!r} was detected.")
        box_threshold -= THRESHOLD_STEP
        detections = detect(box_threshold)

    sam = get_sam_model(weights_folder)
    if sam_predictor is None:
        print("Loading SAM Model...")
        sam_predictor = SamPredictor(sam)
    sam_predictor.set_image(np.array(image))

    # For every box keep the highest-scoring of SAM's candidate masks.
    best_masks = []
    for box in detections.xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])
    combined_mask = np.any(best_masks, axis=0)

    height, width = combined_mask.shape
    small_mask = cv2.resize(
        combined_mask.astype("uint8"),
        (width // RESIZE_RATIO, height // RESIZE_RATIO),
    )
    # One ring of zeros so the image border acts as mask boundary.
    padded = np.pad(small_mask, pad_width=1, mode="constant", constant_values=0)

    if not padded.any():
        # The mask vanished when downsampled; fall back to the centre of the
        # first detected box instead of failing on an empty search.
        x_min, y_min, x_max, y_max = detections.xyxy[0]
        return int((x_min + x_max) / 2), int((y_min + y_max) / 2)

    # Distance from each mask pixel to the nearest zero pixel, computed in
    # linear memory rather than via a dense pairwise distance matrix.
    depth = cv2.distanceTransform(padded, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    row, col = np.unravel_index(np.argmax(depth), depth.shape)

    # Remove the 1-pixel pad offset, scale back to full resolution and aim at
    # the centre of the RESIZE_RATIO-sized cell the small pixel stands for.
    half_cell = RESIZE_RATIO // 2
    return (
        (int(col) - 1) * RESIZE_RATIO + half_cell,
        (int(row) - 1) * RESIZE_RATIO + half_cell,
    )
def run_octo_pipeline(input_image, object, api_key):
    """Run the full pipeline and mark the chosen location on the image.

    Stage 1 tags the image with RAM++ and keeps the tags GroundingDINO can
    ground; stage 2 asks the chat model to pick one tag for ``object``;
    stage 3 locates a point inside the region matching that tag.

    Args:
        input_image: PIL image supplied by the UI.
        object: Name of the object to place.
        api_key: OpenAI API key forwarded to the chat request.

    Returns:
        A one-element list holding an RGB copy of the image with a red dot of
        radius 10 drawn at the selected location.

    Raises:
        gr.Error: If no image or no object name was provided.
    """
    # Fail with a readable message instead of an AttributeError deep inside
    # the pipeline when the form is submitted incomplete.
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if not object or not object.strip():
        raise gr.Error("Please enter the object to place.")

    print("Inside run_octo_pipeline with input_image=", input_image, "object=", object)

    print("Loading Image...")
    image = input_image.convert("RGB")

    print("Stage 1...")
    tags = get_tags_ram(image, threshold_multiplier=0.8)
    print("RAM++ Tags", tags)
    filtered_tags = filter_tags_gdino(image, tags)
    print("Filtered Tags", filtered_tags)

    print("Stage 2...")
    selected_tag = select_best_tag(filtered_tags, object, api_key=api_key)
    print("GPT-4 Selected Tag", selected_tag)

    print("Stage 3...")
    x, y = get_location_gsam(image, selected_tag)
    print("G-SAM Location", f"({x},{y})")

    # Mark the location with a filled red circle.
    radius = 10
    marker_box = (x - radius, y - radius, x + radius, y + radius)
    ImageDraw.Draw(image).ellipse(marker_box, fill="red")
    return [image]
# Gradio UI: declare the input/output components, wire them to
# run_octo_pipeline through a gr.Interface, and start the server.
# NOTE(review): the source's indentation was lost; the nesting below (two
# columns inside one row, Interface created after the Blocks context) is
# reconstructed and should be confirmed against the deployed app.
block = gr.Blocks()
with block:
    with gr.Row():
        with gr.Column():
            # Inputs; the image defaults to the sample file fetched at startup.
            input_image = gr.Image(type="pil", value="000749.png")
            object = gr.Textbox(label="Object", placeholder="Enter an object")
            api_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter OpenAI API Key")
        with gr.Column():
            # Output gallery that receives the list returned by run_octo_pipeline.
            gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                preview=True,
                object_fit="scale-down",
            )
# The three inputs are passed positionally as (input_image, object, api_key).
iface = gr.Interface(
    fn=run_octo_pipeline, inputs=[input_image, object, api_key], outputs=gallery
)
iface.launch()