import os

import numpy as np
import supervision as sv
import torch
from PIL import Image

from rfdetr import RFDETRBase

# Provide single-process defaults for torch.distributed before training,
# in case no process group has been initialized yet.
if not torch.distributed.is_initialized():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["LOCAL_RANK"] = "0"

# Restore the fine-tuned weights by resuming from the saved checkpoint;
# epochs=0 is used so that no further training actually runs.
model = RFDETRBase()
model.train(dataset_dir="P&ID-Symbols-3/P&ID-Symbols-3", resume="output/checkpoint0009.pth", epochs=0)
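# Note: the zero-epoch resume above is just a way to load the fine-tuned weights.
# Recent rfdetr releases also document loading a checkpoint directly, roughly:
#   model = RFDETRBase(pretrain_weights="output/checkpoint0009.pth")
# (parameter name per the rfdetr README; treat it as an assumption if your version differs).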
# Load the COCO-formatted test split so predictions can be compared against ground truth.
ds = sv.DetectionDataset.from_coco(
    images_directory_path="P&ID-Symbols-3/P&ID-Symbols-3/test",
    annotations_path="P&ID-Symbols-3/P&ID-Symbols-3/test/_annotations.coco.json",
)
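# Optional sanity check (not in the original script): confirm the split loaded
# and inspect the class names. len(ds) and ds.classes are part of supervision's
# DetectionDataset API.
print(f"Loaded {len(ds)} test images with classes: {ds.classes}")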
import io

import streamlit as st

def detect_symbols_and_lines(image):
    # Accept either a PIL image or a file-like object (e.g. a Streamlit upload).
    if not isinstance(image, Image.Image):
        if hasattr(image, "read"):
            image = Image.open(image)

    # Upscaling parameters for small diagrams. Note that new_size is computed
    # but never applied, so boxes stay in the original image's coordinate space.
    upscale_factor = 2
    new_size = (int(image.width * upscale_factor), int(image.height * upscale_factor))

    # Run inference; predict returns an sv.Detections object containing only
    # boxes above the confidence threshold.
    detections = model.predict(image, threshold=0.5)

    # Find the ground-truth entry whose file name matches this image.
    # Images opened from an in-memory buffer may have an empty `filename`,
    # in which case no match is found and only predictions are shown.
    image_name = os.path.basename(getattr(image, "filename", "") or "")
    matching_index = None
    for idx in range(len(ds)):
        img_path, _, _ = ds[idx]
        if image_name and os.path.basename(img_path) == image_name:
            matching_index = idx
            break

    if matching_index is None:
        st.warning("No matching ground truth annotations found for this image.")
        annotations = sv.Detections.empty()
        annotations_labels = []
    else:
        _, _, annotations = ds[matching_index]
        annotations_labels = [ds.classes[class_id] for class_id in annotations.class_id]

    # Build "<class name> <confidence>" labels for the predicted boxes.
    detections_labels = [
        f"{ds.classes[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    # Scale box line thickness to the image resolution; keep label text at a fixed size.
    text_scale = 0.9
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=image.size)

    bbox_annotator = sv.BoxAnnotator(thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
        text_thickness=thickness,
        smart_position=True,
    )

    # Draw ground-truth boxes and labels on one copy of the image,
    # and model predictions on another.
    annotation_image = image.copy()
    annotation_image = bbox_annotator.annotate(annotation_image, annotations)
    annotation_image = label_annotator.annotate(annotation_image, annotations, annotations_labels)

    detections_image = image.copy()
    detections_image = bbox_annotator.annotate(detections_image, detections)
    detections_image = label_annotator.annotate(detections_image, detections, detections_labels)

    # Show ground truth and predictions side by side in the Streamlit app.
    col1, col2 = st.columns(2)
    with col1:
        st.image(annotation_image, caption="Ground Truth Annotations", use_column_width=True)
    with col2:
        st.image(detections_image, caption="Model Predictions", use_column_width=True)

    return detections, annotations, ds.classes
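# Minimal Streamlit wiring for the function above (a sketch, not part of the
# original script; the page title, uploader label, and file types are assumptions).
st.title("P&ID Symbol Detection")
uploaded_file = st.file_uploader("Upload a P&ID diagram", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    detections, annotations, class_names = detect_symbols_and_lines(uploaded_file)
    st.write(f"Detected {len(detections)} symbols across {len(class_names)} classes.")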