# object-to-depth/app.py
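"""Gradio demo: zero-shot object detection (OWL-ViT) combined with monocular
depth estimation (DPT). For each detected object, the app reports the depth
value sampled at the center of its bounding box."""
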
import gradio as gr
from transformers import pipeline
from PIL import ImageDraw
import torch
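
# Open-vocabulary detector: matches free-text labels against image regions.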
detector = pipeline("zero-shot-object-detection", model="google/owlvit-base-patch32")
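# Monocular depth estimator; DPT predicts relative (not metric) depth.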
depth_estimator = pipeline("depth-estimation", model="Intel/dpt-large")
def visualize_preds(image, predictions):
    """Draw a labeled bounding box for each detection on a copy of the image."""
    new_image = image.copy()
    draw = ImageDraw.Draw(new_image)
    for prediction in predictions:
        box = prediction["box"]
        label = prediction["label"]
        score = prediction["score"]
        # Read the coordinates by key rather than relying on dict ordering.
        xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
        draw.rectangle((xmin, ymin, xmax, ymax), outline="red", width=1)
        draw.text((xmin, ymin), f"{label}: {round(score, 2)}", fill="white")
    return new_image

def compute_depth(image, preds):
    """Estimate a depth value at the center of each detected bounding box."""
    output = depth_estimator(image)
    # Upsample the raw depth map back to the input resolution. PIL's image.size
    # is (width, height), so reverse it for torch's (height, width) convention.
    depth_map = torch.nn.functional.interpolate(
        output["predicted_depth"].unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze().cpu().numpy()
    results = []
    for pred in preds:
        # Sample the depth map at the center of each bounding box.
        x = (pred["box"]["xmax"] + pred["box"]["xmin"]) // 2
        y = (pred["box"]["ymax"] + pred["box"]["ymin"]) // 2
        results.append({
            "class": pred["label"],
            # DPT predicts relative rather than metric depth, so these values
            # are only comparable between objects in the same image.
            "distance": float(depth_map[y, x]),
        })
    return results

def process(image, text):
    # Labels are entered as a period-separated list, e.g. "cat. remote control";
    # strip whitespace and drop empty entries before querying the detector.
    items = [item.strip() for item in text.split(".") if item.strip()]
    preds = detector(image, candidate_labels=items)
    return [visualize_preds(image, preds), compute_depth(image, preds)]

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil")
            labels = gr.Textbox(label="Objects (period-separated)")
            detect_btn = gr.Button("Detect")
        with gr.Column(scale=1):
            output_detection = gr.Image(type="pil")
            output_distance = gr.JSON(label="Distance")
    detect_btn.click(fn=process, inputs=[image, labels], outputs=[output_detection, output_distance], api_name="process")

demo.launch()
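
# The click handler is also exposed as an API endpoint named "process".
# Hypothetical client call via gradio_client (exact argument handling varies
# by Gradio version, so treat this as a sketch, not the canonical usage):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(handle_file("photo.jpg"), "cat. dog", api_name="/process")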