owlv2-omega / app.py
# -*- coding: utf-8 -*-
import torch
import gradio as gr
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import spaces
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
# Select the device: use the GPU when available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Load the OWLv2 model and its processor
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
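# Note: the OWLv2 processor handles both text tokenization and image preprocessing
# (resizing and padding the image to a square model input), so no manual resizing is
# needed before calling it.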
# Run inference on the GPU allocated by Spaces
@spaces.GPU
# Inputs: image, comma-separated text queries, score threshold, compression flag
def query_image(img, text_queries, score_threshold, compress):
    # Optionally downscale the image and convert it to a NumPy array
    img = load_image_as_np_array(img, compress)
    # Split the query string into individual text queries
    text_queries = [q.strip() for q in text_queries.split(",")]
    # Build a square target size (the longer image side) used to rescale the predicted boxes
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    # Create the model inputs: tokenized queries and the preprocessed image, moved to the device
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    # Run inference with gradient computation disabled
    with torch.no_grad():
        outputs = model(**inputs)
    # Move the logits and predicted boxes back to the CPU for post-processing
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # Convert the raw outputs into boxes in image coordinates
    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
    # Unpack the bounding boxes, confidence scores, and label indices
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
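    # At this point the boxes are absolute (x1, y1, x2, y2) pixel coordinates, rescaled by
    # post_process_object_detection to the square target size computed above.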
    # Collect the detections that pass the score threshold
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        # Convert the box coordinates to integers
        box = [int(i) for i in box.tolist()]
        # Skip detections below the score threshold
        if score < score_threshold:
            continue
        result_labels.append((box, text_queries[label.item()]))
    # Return the image together with its (box, label) annotations
    return img, result_labels
# Image input: optional downscaling, returns a NumPy array
def load_image_as_np_array(img, compress=False):
    # Convert the input array to a PIL image
    img = Image.fromarray(img)
    # Optionally downscale large images
    if compress:
        # Get the image dimensions
        width, height = img.size
        # Only resize when the longer side exceeds 1024 pixels
        max_dimension = max(width, height)
        if max_dimension > 1024:
            # Calculate the new size, maintaining the aspect ratio
            scale_factor = 1024 / max_dimension
            new_width = int(width * scale_factor)
            new_height = int(height * scale_factor)
            # Resize the image
            img = img.resize((new_width, new_height))
    # Convert back to a NumPy array
    img = np.array(img)
    return img
demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(),
        gr.Text(value="insect", label="Text queries (separate multiple queries with commas)"),
        gr.Slider(0, 1, value=0.2, label="Score threshold"),
        gr.Checkbox(value=True, label="Compress image"),
    ],
    outputs=[gr.AnnotatedImage()],
    title="Zero-Shot Object Detection with OWLv2",
)
demo.launch()
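# Optional: when running outside Hugging Face Spaces, a public link can be created by
# launching with demo.launch(share=True) instead of the plain launch above.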