import torch
import gradio as gr
import json
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
TORCHSCRIPT_PATH = "res/screenrecognition-web350k-vins.torchscript"
LABELS_PATH = "res/class_map_vins_manual.json"
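# torch.jit.load restores the scripted detector without needing the original
# model class definition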
model = torch.jit.load(TORCHSCRIPT_PATH)
with open(LABELS_PATH, "r") as f:
    idx2Label = json.load(f)["idx2Label"]
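
# ToTensor converts a PIL image to a float CHW tensor in [0, 1], the input
# format the detection model expects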
img_transforms = transforms.ToTensor()
# inter_class_nms and iou functions implemented by GPT
def inter_class_nms(boxes, scores, iou_threshold=0.5):
    # Convert boxes and scores to torch tensors if they are not already
    boxes = torch.as_tensor(boxes)
    scores = torch.as_tensor(scores)
    # Reduce the [N, C] per-class score matrix to the best score and class
    # for each box; suppression then runs across classes ("inter-class" NMS)
    scores, class_indices = scores.max(dim=1)
    # Keep track of final boxes, scores, and class labels
    final_boxes = []
    final_scores = []
    final_labels = []
    # Indices of boxes sorted by score (highest first)
    sorted_indices = torch.argsort(scores, descending=True)
    while len(sorted_indices) > 0:
        # Take the box with the highest remaining score
        highest_index = sorted_indices[0]
        highest_box = boxes[highest_index]
        # Add the highest box, score, and label to the final lists
        final_boxes.append(highest_box)
        final_scores.append(scores[highest_index])
        final_labels.append(class_indices[highest_index])
        # Remove the highest box from the list
        sorted_indices = sorted_indices[1:]
        # Compute IoU of the highest box with the remaining boxes
        ious = iou(boxes[sorted_indices], highest_box.unsqueeze(0)).squeeze(1)
        # Keep only boxes with IoU less than the threshold
        sorted_indices = sorted_indices[ious < iou_threshold]
    return {"boxes": final_boxes, "scores": final_scores, "labels": final_labels}
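
# For reference, a quick sanity check of inter_class_nms (hypothetical values,
# not part of the app): two heavily overlapping boxes collapse to the
# higher-scoring one even though their best classes differ.
#   boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]])
#   scores = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
#   inter_class_nms(boxes, scores)  # keeps only the first box (IoU ~0.68 > 0.5)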
def iou(boxes1, boxes2):
    """
    Compute the Intersection over Union (IoU) of two sets of boxes.

    Args:
    - boxes1 (Tensor[N, 4]): first set of boxes in (x1, y1, x2, y2) format
    - boxes2 (Tensor[M, 4]): second set of boxes in (x1, y1, x2, y2) format

    Returns:
    - iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values
      for every element in boxes1 and boxes2
    """
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    # Top-left and bottom-right corners of every pairwise intersection
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
    iou = inter / (area1[:, None] + area2 - inter)
    return iou
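
# For reference, a pairwise example (hypothetical values): identical boxes give
# IoU 1.0, and a 2x2 box shifted halfway across another gives 2/6 = 0.3333.
#   a = torch.tensor([[0., 0., 2., 2.]])
#   b = torch.tensor([[0., 0., 2., 2.], [1., 0., 3., 2.]])
#   iou(a, b)  # tensor([[1.0000, 0.3333]])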
def predict(img, conf_thresh=0.4):
    img_input = [img_transforms(img)]
    # The TorchScript detector returns (losses, detections), where detections
    # is a list containing one dict of 'boxes' and 'scores' per input image
    _, pred = model(img_input)
    pred = inter_class_nms(pred[0]["boxes"], pred[0]["scores"])
    out_img = img.copy()
    draw = ImageDraw.Draw(out_img)
    font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)
    for i in range(len(pred["boxes"])):
        conf_score = pred["scores"][i]
        if conf_score > conf_thresh:
            x1, y1, x2, y2 = [int(c) for c in pred["boxes"][i]]
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            # Draw the class label and confidence on a filled background
            text = idx2Label[str(int(pred["labels"][i]))] + " {:.2f}".format(float(conf_score))
            bbox = draw.textbbox((x1, y1), text, font=font)
            draw.rectangle(bbox, fill="red")
            draw.text((x1, y1), text, font=font, fill="black")
    return out_img
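
# predict can also be called directly for offline testing, e.g.:
#   predict(Image.open("res/example.jpg"), conf_thresh=0.5)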
example_imgs = [
    ["res/example.jpg", 0.4],
    ["res/screenlane-snapchat-profile.jpg", 0.4],
    ["res/screenlane-snapchat-settings.jpg", 0.4],
    ["res/example_pair1.jpg", 0.4],
    ["res/example_pair2.jpg", 0.4],
]
# Wire the model into a simple Gradio demo: a screenshot and a confidence
# threshold in, an annotated screenshot out
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.4, label="Confidence threshold"),
    ],
    outputs=gr.Image(type="pil", label="Annotated Screenshot").style(height=600),
    examples=example_imgs,
)
interface.launch()