File size: 2,974 Bytes
211120c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import streamlit as st
from PIL import Image, ImageDraw
import torchvision.transforms as T
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from torch.utils.data import DataLoader
from transformers import DetrImageProcessor
import torchvision
import numpy as np
# Class-id -> display-name mapping used by detect(); the position of each name
# in the tuple is its class id.
# Table-transformer label set, if switching checkpoints: ("table", "column", "row", "cell")
LABELS = dict(enumerate(("ballon",)))
def detect(image_tensor, model_path, device="cuda", threshold=0.8):
    """Run a DETR object detector on an image tensor and return its detections.

    Args:
        image_tensor: Image tensor whose last two dimensions are
            (height, width), e.g. a (1, C, H, W) batch. Pixel values are
            expected to be already scaled to [0, 1], because the processor is
            called with ``do_rescale=False``.
        model_path: Hugging Face model id or local directory passed to
            ``from_pretrained`` for both the model and its image processor
            (e.g. ``"facebook/detr-resnet-50"``).
        device: Torch device (or device string) the model runs on.
        threshold: Minimum score a detection needs in order to be returned.

    Returns:
        A tuple ``(boxes, labels, scores)`` of parallel lists for the first
        image only. ``boxes`` are ``(x0, y0, x1, y1)`` float tuples in pixel
        coordinates of ``image_tensor``'s (height, width) frame. ``labels``
        are names looked up in ``LABELS``. ``scores`` are floats rounded to
        two decimals.
    """
    # NOTE: the model and processor are reloaded on every call; cache them if
    # detect() is called repeatedly.
    # The classification head is sized to len(LABELS). With
    # ignore_mismatched_sizes=True, a checkpoint whose head has a different
    # size gets a freshly initialised head instead of an error. Presumably
    # model_path is meant to be a checkpoint fine-tuned on LABELS — TODO confirm.
    model = DetrForObjectDetection.from_pretrained(
        model_path,
        revision="no_timm",
        num_labels=len(LABELS),
        ignore_mismatched_sizes=True,
    )
    model.to(device)
    processor = DetrImageProcessor.from_pretrained(model_path)
    pixel_values = processor(
        image_tensor, return_tensors="pt", padding=True, do_rescale=False
    )["pixel_values"]

    with torch.no_grad():
        outputs = model(pixel_values.to(device))

    # Scale the normalised predictions to the caller's tensor frame, not to the
    # processor's internally resized pixel_values. post_process_object_detection
    # takes target sizes as (height, width), so the order matters here.
    height, width = image_tensor.shape[-2:]
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=[(int(height), int(width))],
        threshold=threshold,
    )[0]
    print(f"scores: {results['scores']}, labels: {results['labels']}, boxes: {results['boxes']}")

    # The results live on `device`; Tensor.numpy() only works on CPU tensors.
    boxes = results["boxes"].detach().cpu().numpy()
    labels = results["labels"].detach().cpu().numpy()
    scores = results["scores"].detach().cpu().numpy()

    box = [tuple(float(coord) for coord in b) for b in boxes]
    label = [LABELS[int(class_id)] for class_id in labels]
    score = [round(float(s), 2) for s in scores]
    return box, label, score
st.title("Table Detection")
file = st.file_uploader("Upload Image", type=["png", "jpeg", "jpg"])
if file is not None:
    image = Image.open(file).convert("RGB")
    # Resize so the shorter side is 900 px, then convert to a [0, 1] float
    # tensor and add a leading batch dimension: (1, 3, H, W).
    image_transform = T.Compose([T.Resize(900), T.ToTensor()])
    image_tensor = image_transform(image).unsqueeze(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Alternative checkpoint: "microsoft/table-transformer-detection"
    model_path = 'facebook/detr-resnet-50'
    box, label, score = detect(image_tensor, model_path, device)

    # detect() only sees the resized tensor, but the overlay is drawn on the
    # original upload, so map every box from the tensor's (H, W) frame back to
    # the image's frame.
    # NOTE(review): relies on detect() reporting boxes in image_tensor's frame.
    tensor_height, tensor_width = image_tensor.shape[-2:]
    scale_x = image.width / tensor_width
    scale_y = image.height / tensor_height
    box = [
        (x0 * scale_x, y0 * scale_y, x1 * scale_x, y1 * scale_y)
        for x0, y0, x1, y1 in box
    ]

    print("Detected Objects:")
    for b, name, s in zip(box, label, score):
        print(f" {name} {s:.2f} at {b}")

    # Draw on a copy so the uploaded image itself is left untouched.
    draw_image = image.copy()
    draw = ImageDraw.Draw(draw_image)
    for b, name, s in zip(box, label, score):
        draw.rectangle(b, outline="red")
        draw.text((b[0], b[1]), f"{name} {s:.2f}", fill="red")
    st.image(draw_image, caption="Detected Objects", use_column_width=True)
|