Omega02gdfdd commited on
Commit
f26688e
1 Parent(s): d004489

Upload app.py

Browse files

new app.py,Add compress oversized images to 1024 size.

Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import Owlv2Processor, Owlv2ForObjectDetection
6
+ import spaces
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+ from PIL import Image
10
+ import numpy as np
11
+ # 设置设备
12
+ if torch.cuda.is_available():
13
+ device = torch.device("cuda")
14
+ else:
15
+ device = torch.device("cpu")
16
+ #引入模型和推理器
17
+ model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
18
+ processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
19
+ #载入图像
20
+ @spaces.GPU
21
+ #输入图像,搜索文本,检测分数
22
+ def query_image(img, text_queries, score_threshold,compress):
23
+ img = load_image_as_np_array(img, compress)
24
+ text_queries = text_queries
25
+ #分割搜索文本
26
+ text_queries = text_queries.split(",")
27
+ #转换为正方行torch矩阵
28
+ #(长宽边最大的那个设置为size)
29
+ size = max(img.shape[:2])
30
+ target_sizes = torch.Tensor([[size, size]])
31
+ #创建输入(搜索文本和图像转换为torch张量发送到GPU)
32
+ inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
33
+ #禁用梯度计算,运行推理
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ #输出分数和边界框信息
37
+ outputs.logits = outputs.logits.cpu()
38
+ outputs.pred_boxes = outputs.pred_boxes.cpu()
39
+ #导出输出结果
40
+ results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
41
+ #分类存储输出结果的边界框,分数,标签
42
+ boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
43
+
44
+ #创建空列表
45
+ result_labels = []
46
+ # all_result = []
47
+ #遍历分类存储的输出结果
48
+ for box, score, label in zip(boxes, scores, labels):
49
+ #转换为整数
50
+ box = [int(i) for i in box.tolist()]
51
+ #过滤阈值以下的目标
52
+ if score < score_threshold:
53
+ continue
54
+ result_labels.append((box, text_queries[label.item()]))
55
+ text = len(result_labels)
56
+ return img, result_labels,img, result_labels
57
+ p
58
+
59
+ #图像输入:图像压缩
60
+ def load_image_as_np_array(img, compress=False):
61
+ # 输入图像文件
62
+ # with Image.open(image_path) as img:
63
+ # 转换为RGB
64
+ # img = img.convert("RGB")
65
+ #数组-图像
66
+ img = Image.fromarray(img)
67
+ #图像压缩
68
+ if compress:
69
+ # 获取图像尺寸
70
+ width, height = img.size
71
+
72
+ # 检查图像分辨率是否大于2048
73
+ max_dimension = max(width, height)
74
+ if max_dimension > 1024:
75
+ # Calculate the new size, maintaining the aspect ratio
76
+ scale_factor = 1024 / max_dimension
77
+ new_width = int(width * scale_factor)
78
+ new_height = int(height * scale_factor)
79
+
80
+ # 缩放图像
81
+ img = img.resize((new_width, new_height))
82
+
83
+ # 图像-数组
84
+ img = np.array(img)
85
+ return img
86
+
87
+ demo1 = gr.Interface(
88
+ query_image,
89
+ inputs=[gr.Image(),
90
+ gr.Text(value="insect",label="提示词(多个用,分开)"),
91
+ gr.Slider(0, 1, value=0.2,label="确信度阈值"),
92
+ gr.Checkbox(value=True,label="图像压缩")],
93
+ outputs=[gr.Annotatedimage()],
94
+ title="Zero-Shot Object Detection with OWLv2",
95
+ )
96
+ demo1.launch()