|
import streamlit as st |
|
from PIL import Image, ImageDraw |
|
import torchvision.transforms as T |
|
from transformers import DetrImageProcessor, DetrForObjectDetection |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from transformers import DetrImageProcessor |
|
import torchvision |
|
import numpy as np |
|
|
|
|
|
# Class-index -> display-name table for the detector; the position of a name
# in the sequence below is the class index it is stored under.
LABELS = dict(enumerate(("ballon",)))
|
|
|
|
|
def detect(image_tensor, model_path, device="cuda"):
    """Run DETR object detection on an image tensor.

    Args:
        image_tensor: Image as a torch tensor whose last two dimensions are
            (height, width). Values are expected to be already scaled to
            [0, 1], because the processor is called with ``do_rescale=False``.
            Only the first image of a batch is post-processed.
        model_path: Checkpoint identifier or directory handed to
            ``from_pretrained`` for both the model and the image processor.
        device: Device the model and the pixel values are moved to.

    Returns:
        Three parallel lists ``(boxes, labels, scores)``:
        boxes are ``(x0, y0, x1, y1)`` tuples of Python floats expressed in
        pixel coordinates of ``image_tensor``; labels are the names looked up
        in ``LABELS``; scores are confidences rounded to two decimals.
        Post-processing uses a score threshold of 0.8.
    """
    # NOTE: with ignore_mismatched_sizes=True a checkpoint whose class count
    # differs from len(LABELS) still loads; per the transformers docs the
    # mismatched weights are then freshly initialised rather than restored.
    model = DetrForObjectDetection.from_pretrained(
        model_path,
        revision="no_timm",
        num_labels=len(LABELS),
        ignore_mismatched_sizes=True,
    )
    model.to(device)
    processor = DetrImageProcessor.from_pretrained(model_path)
    pixel_values = processor(
        image_tensor, return_tensors="pt", padding=True, do_rescale=False
    )["pixel_values"]

    with torch.no_grad():
        outputs = model(pixel_values.to(device))

    # Scale the predicted boxes to the size of the tensor the caller passed
    # in, not to the processor's internally resized copy, so the coordinates
    # are meaningful to the caller. target_sizes wants (height, width).
    height, width = image_tensor.shape[-2:]
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=[(height, width)],
        threshold=0.8,
    )[0]
    print(f"scores: {results['scores']}, labels: {results['labels']}, boxes: {results['boxes']}")

    # The results live on the model's device; move them to the CPU before
    # converting, otherwise the conversion fails when running on CUDA.
    # tolist() also yields plain Python numbers instead of NumPy scalars.
    boxes = results["boxes"].detach().cpu().tolist()
    labels = results["labels"].detach().cpu().tolist()
    scores = results["scores"].detach().cpu().tolist()

    box = [tuple(coords) for coords in boxes]
    label = [LABELS[index] for index in labels]
    score = [round(value, 2) for value in scores]
    return box, label, score
|
|
|
|
|
# Streamlit page: upload an image, run the detector on a resized copy and
# show the original image with the detections drawn on top.
st.title("Table Detection")
file = st.file_uploader("Upload Image", type=["png", "jpeg", "jpg"])

if file is not None:
    image = Image.open(file).convert("RGB")
    # Shorter side resized to 900 px, then converted to a [0, 1] float tensor
    # with a leading batch dimension.
    image_transform = T.Compose([T.Resize(900), T.ToTensor()])
    image_tensor = image_transform(image).unsqueeze(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): detect() loads this checkpoint with num_labels=len(LABELS)
    # and ignore_mismatched_sizes=True; if the checkpoint was not trained for
    # exactly those classes the predictions are unlikely to be meaningful.
    # Confirm this is the intended (fine-tuned) model.
    model_path = 'facebook/detr-resnet-50'
    box, label, score = detect(image_tensor, model_path, device)

    # The detector saw the resized tensor, but the boxes are drawn on the
    # original-resolution image, so map them back to original pixels.
    # Assumes detect() reports boxes in pixel coordinates of the tensor it
    # was given - verify against detect().
    scale_x = image.width / image_tensor.shape[-1]
    scale_y = image.height / image_tensor.shape[-2]
    box = [
        (float(x0) * scale_x, float(y0) * scale_y,
         float(x1) * scale_x, float(y1) * scale_y)
        for x0, y0, x1, y1 in box
    ]

    print("Detected Objects:")
    for b, l, s in zip(box, label, score):
        print(f" {l} {s:.2f} at {b}")

    # Draw on a copy so the uploaded image object itself stays untouched.
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    for b, l, s in zip(box, label, score):
        draw.rectangle(b, outline="red")
        draw.text((b[0], b[1]), f"{l} {s:.2f}", fill="red")
    st.image(draw_image, caption="Detected Objects", use_column_width=True)
|
|
|
|
|
|