import torch
import gradio as gr
import json
from torchvision import transforms
from torchvision.ops import nms
from PIL import Image, ImageDraw, ImageFont
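# Paths to the exported TorchScript detector and its class-index-to-label map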
TORCHSCRIPT_PATH = "res/screenrecognition-web350k-vins.torchscript"
LABELS_PATH = "res/class_map_vins_manual.json"
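# Load the TorchScript model and the label mapping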
model = torch.jit.load(TORCHSCRIPT_PATH)
with open(LABELS_PATH, "r") as f:
    idx2Label = json.load(f)["idx2Label"]
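# Transform used to convert input PIL images to tensors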
img_transforms = transforms.ToTensor()
def inter_class_nms(boxes, scores, labels, iou_threshold=0.5):
    # Perform non-maximum suppression across all classes (class-agnostic)
    keep = nms(boxes, scores, iou_threshold)
    # Keep only the boxes, scores, and labels that survive suppression
    new_boxes = boxes[keep]
    new_scores = scores[keep]
    new_labels = labels[keep]
    # Return the result in a dictionary
    return {'boxes': new_boxes, 'scores': new_scores, 'labels': new_labels}
def predict(img, conf_thresh=0.4):
    # Convert the PIL image to a tensor and run the detector
    img_input = [img_transforms(img)]
    _, pred = model(img_input)
    # Suppress overlapping detections across classes
    pred = [inter_class_nms(pred[0]['boxes'], pred[0]['scores'], pred[0]['labels'])]

    # Draw the detections that clear the confidence threshold on a copy of the input
    out_img = img.copy()
    draw = ImageDraw.Draw(out_img)
    font = ImageFont.truetype("res/Tuffy_Bold.ttf", 25)

    for i in range(len(pred[0]['boxes'])):
        conf_score = pred[0]['scores'][i]
        if conf_score > conf_thresh:
            x1, y1, x2, y2 = (int(v) for v in pred[0]['boxes'][i])
            draw.rectangle([x1, y1, x2, y2], outline='red', width=3)
            # Label the box with the predicted class name and confidence score
            text = idx2Label[str(int(pred[0]['labels'][i]))] + " {:.2f}".format(float(conf_score))
            bbox = draw.textbbox((x1, y1), text, font=font)
            draw.rectangle(bbox, fill="red")
            draw.text((x1, y1), text, font=font, fill="black")

    return out_img
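# Example screenshots shown in the demo, each paired with a default confidence threshold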
example_imgs = [
    ["res/example.jpg", 0.4],
    ["res/screenlane-snapchat-profile.jpg", 0.4],
    ["res/screenlane-snapchat-settings.jpg", 0.4],
    ["res/example_pair1.jpg", 0.4],
    ["res/example_pair2.jpg", 0.4],
]
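# Assemble the Gradio interface: a screenshot and a confidence slider in, an annotated screenshot out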
interface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.4),
    ],
    outputs=gr.Image(type="pil", label="Annotated Screenshot").style(height=600),
    examples=example_imgs,
)
interface.launch()