# taroii's picture
# Update app.py
# 81d6d36
# raw
# history blame
# 1.47 kB
import gradio as gr
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor, AutoModel
import supervision as sv
from supervision.detection.annotate import BoxAnnotator
from supervision.utils.notebook import plot_image
import cv2
# Base DETR checkpoint; only its image-preprocessing config is reused here.
og_model = 'facebook/detr-resnet-50'
image_processor = DetrImageProcessor.from_pretrained(og_model)
# Load with DetrForObjectDetection rather than AutoModel: AutoModel maps a DETR
# config to the headless DetrModel, whose outputs lack the `logits` and
# `pred_boxes` that `post_process_object_detection` in predict() consumes.
model = DetrForObjectDetection.from_pretrained("taroii/notfinetuned-detr-50")
def predict(image_path, *, confidence_threshold=0.5, iou_threshold=0.5):
    """Run DETR object detection on an image file and return an annotated frame.

    Args:
        image_path: Path to the image on disk. The Gradio input component
            uses ``type="filepath"``, so it passes a path string.
        confidence_threshold: Minimum score a detection needs to survive
            ``post_process_object_detection``.
        iou_threshold: IoU threshold passed to ``Detections.with_nms``.

    Returns:
        The image as an RGB numpy array, with bounding boxes and
        ``"<label> <score>"`` captions drawn on it.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"Could not read image from {image_path!r}")
    # OpenCV decodes to BGR. The DETR processor and Gradio's image output
    # both expect RGB channel order.
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        inputs = image_processor(images=image, return_tensors="pt")
        outputs = model(**inputs)

    # Rescale predicted boxes to the image's original (height, width).
    target_sizes = torch.tensor([image.shape[:2]])
    results = image_processor.post_process_object_detection(
        outputs=outputs,
        threshold=confidence_threshold,
        target_sizes=target_sizes,
    )[0]

    detections = sv.Detections.from_transformers(
        transformers_results=results
    ).with_nms(threshold=iou_threshold)

    # Class-id -> name mapping comes from the loaded model's config.
    id2label = model.config.id2label
    # Read the arrays directly. The tuple layout yielded by iterating a
    # Detections object differs between supervision releases.
    labels = [
        f"{id2label[int(class_id)]} {confidence:.2f}"
        for confidence, class_id in zip(detections.confidence, detections.class_id)
    ]
    box_annotator = BoxAnnotator()
    # cvtColor already returned a fresh array, so drawing on it in place is safe.
    return box_annotator.annotate(scene=image, detections=detections, labels=labels)
# Gradio UI: upload an image and get back the annotated detection result.
demo = gr.Interface(
    predict,
    # gr.inputs.* was removed in Gradio 4. gr.Image is the current component
    # and matches the output side. type="filepath" hands predict() a path.
    inputs=gr.Image(label="Upload hot dog candidate", type="filepath"),
    outputs=gr.Image(),
    title="Non-Fine-Tuned Model",
)
demo.launch()