Jsonwu's picture
Update app.py
f2fd017 verified
raw
history blame contribute delete
No virus
2.22 kB
import torch
import gradio as gr
import json
from torchvision import transforms
from torchvision.ops import nms
from PIL import Image, ImageDraw, ImageFont
# Locations of the serialized detector and its class-label map, relative to
# the working directory the app is started from.
TORCHSCRIPT_PATH = "res/screenrecognition-web350k-vins.torchscript"
LABELS_PATH = "res/class_map_vins_manual.json"

# Load the TorchScript model once at import time so every request reuses it.
model = torch.jit.load(TORCHSCRIPT_PATH)

# Read the label file as UTF-8 explicitly: without it, open() uses the
# platform's locale encoding, which can mis-decode non-ASCII label names.
with open(LABELS_PATH, "r", encoding="utf-8") as f:
    # predict() looks labels up by the stringified class index.
    idx2Label = json.load(f)["idx2Label"]

# Converts a PIL image into the tensor form that is passed to the model.
img_transforms = transforms.ToTensor()
def inter_class_nms(boxes, scores, labels, iou_threshold=0.5):
    """Run a single NMS pass over all detections, ignoring their class.

    Because the labels are not passed to ``nms``, boxes of different classes
    suppress one another as well.

    Args:
        boxes: Tensor of candidate boxes, passed straight to ``nms``.
        scores: Tensor of confidence scores, one per box.
        labels: Tensor of class indices, one per box.
        iou_threshold: Overlap threshold forwarded to ``nms``.

    Returns:
        A dict with keys ``'boxes'``, ``'scores'`` and ``'labels'``, each
        holding the corresponding input indexed by the entries ``nms`` kept.
    """
    kept = nms(boxes, scores, iou_threshold)
    fields = {'boxes': boxes, 'scores': scores, 'labels': labels}
    # Index every per-detection tensor with the same surviving indices.
    return {key: tensor[kept] for key, tensor in fields.items()}
def predict(img, conf_thresh=0.4):
    """Detect UI elements in a screenshot and return an annotated copy.

    Args:
        img: PIL image of the screenshot, or ``None`` when no image was
            supplied (e.g. the form was submitted empty).
        conf_thresh: Only detections whose score is strictly greater than
            this value are drawn.

    Returns:
        A copy of ``img`` with a red outline and a ``"<label> <score>"``
        caption for each drawn detection, or ``None`` when ``img`` is
        ``None``. The input image itself is not modified.
    """
    # Nothing to annotate; avoid crashing on a missing upload.
    if img is None:
        return None

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        _, pred = model([img_transforms(img)])
    detections = inter_class_nms(
        pred[0]['boxes'], pred[0]['scores'], pred[0]['labels']
    )

    out_img = img.copy()
    draw = ImageDraw.Draw(out_img)
    font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)

    for box, score, label in zip(
        detections['boxes'], detections['scores'], detections['labels']
    ):
        if score > conf_thresh:
            # Truncate the float coordinates to whole pixels.
            x1, y1, x2, y2 = (int(coord) for coord in box.tolist())
            draw.rectangle([x1, y1, x2, y2], outline='red', width=3)

            text = f"{idx2Label[str(int(label))]} {float(score):.2f}"
            # Paint a filled background behind the caption so it stays
            # readable on top of the screenshot.
            draw.rectangle(draw.textbbox((x1, y1), text, font=font), fill="red")
            draw.text((x1, y1), text, font=font, fill="black")
    return out_img
# Example rows shown under the demo: each is [screenshot path, threshold],
# matching the order of the interface inputs below.
example_imgs = [
    [path, 0.4]
    for path in (
        "res/example.jpg",
        "res/screenlane-snapchat-profile.jpg",
        "res/screenlane-snapchat-settings.jpg",
        "res/example_pair1.jpg",
        "res/example_pair2.jpg",
    )
]

# Wire predict() to a screenshot upload plus a confidence slider, and show
# the annotated image as the result.
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.4),
    ],
    outputs=gr.Image(type="pil", label="Annotated Screenshot").style(height=600),
    examples=example_imgs,
)

# Start serving the demo as soon as the script is run.
interface.launch()