asammoud
Re-add large CSVs using Git LFS
b265364
import os
import torch
from rfdetr import RFDETRBase
import supervision as sv
from PIL import Image
import numpy as np
# Single-process defaults for the rendezvous variables torch.distributed's
# env:// initialization reads. setdefault keeps values already exported by a
# launcher (e.g. torchrun) instead of clobbering them with rank 0 / world 1.
if not torch.distributed.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("LOCAL_RANK", "0")
# Locations used to build the detector and the ground-truth test split.
# NOTE(review): all paths are relative to the current working directory.
_DATASET_DIR = "P&ID-Symbols-3/P&ID-Symbols-3"
_CHECKPOINT_PATH = "output/checkpoint0009.pth"
_TEST_IMAGES_DIR = "P&ID-Symbols-3/P&ID-Symbols-3/test"
_TEST_ANNOTATIONS_PATH = "P&ID-Symbols-3/P&ID-Symbols-3/test/_annotations.coco.json"

model = RFDETRBase()
# train() is called with resume=<checkpoint> and epochs=0; presumably this
# restores the checkpoint weights without running further training.
# TODO confirm against rfdetr's train() semantics.
model.train(dataset_dir=_DATASET_DIR, resume=_CHECKPOINT_PATH, epochs=0)

# Ground-truth test split, loaded from a COCO-format annotation file.
ds = sv.DetectionDataset.from_coco(
    images_directory_path=_TEST_IMAGES_DIR,
    annotations_path=_TEST_ANNOTATIONS_PATH,
)
import streamlit as st
import io
def _source_basename(obj):
    """Return the bare file name associated with *obj*, or None if unknown.

    Checks ``filename`` (set by PIL on images opened from a path; empty for
    images opened from a stream), then ``name`` (exposed by file objects and,
    e.g., Streamlit's ``UploadedFile``).
    """
    for attr in ("filename", "name"):
        value = getattr(obj, attr, None)
        if isinstance(value, (str, os.PathLike)):
            base = os.path.basename(os.fspath(value))
            if base:
                return base
    return None


def _to_pil_image(image):
    """Return *image* as a ``PIL.Image.Image``, opening paths and readable streams.

    Raises:
        TypeError: if *image* is not a PIL image, a file path, or an object
            with a ``read`` method.
    """
    if isinstance(image, Image.Image):
        return image
    if isinstance(image, (str, os.PathLike)) or hasattr(image, "read"):
        return Image.open(image)
    raise TypeError(
        "expected a PIL.Image.Image, a file path, or a readable file-like "
        f"object, got {type(image).__name__}"
    )


def _find_ground_truth(source_name):
    """Return ``(annotations, labels)`` for the test image named *source_name*.

    Scans the module-level ``ds`` once. When nothing matches (or the name is
    unknown), shows a Streamlit warning and returns empty detections and an
    empty label list.
    """
    if source_name is not None:
        # NOTE(review): iterating ``ds`` may load each image from disk; if the
        # test split grows large, consider a precomputed name -> index map.
        for img_path, _, annotations in ds:
            if os.path.basename(img_path) == source_name:
                labels = [f"{ds.classes[class_id]}" for class_id in annotations.class_id]
                return annotations, labels
    st.warning("No matching ground truth annotations found for this image.")
    return sv.Detections.empty(), []


def detect_symbols_and_lines(image, *, threshold=0.5):
    """Run the detector on *image* and show predictions next to ground truth.

    Args:
        image: A ``PIL.Image.Image``, a file path, or a readable file-like
            object (anything with ``read``).
        threshold: Confidence threshold passed to ``model.predict``.

    Returns:
        Tuple ``(detections, annotations, classes)``:
        - the model predictions;
        - the ground-truth detections of the matching test-set image (empty if
          no image with the same base file name exists);
        - ``ds.classes``.

    Raises:
        TypeError: if *image* is not one of the supported input types.
    """
    # Resolve the name before opening: PIL leaves ``filename`` empty for
    # images read from a stream, so for uploads only the original object
    # carries the name.
    source_name = _source_basename(image)
    pil_image = _to_pil_image(image)
    if source_name is None:
        source_name = _source_basename(pil_image)

    detections = model.predict(pil_image, threshold=threshold)
    annotations, annotations_labels = _find_ground_truth(source_name)
    detections_labels = [
        f"{ds.classes[class_id]} {confidence:.2f}"
        for class_id, confidence in zip(detections.class_id, detections.confidence)
    ]

    thickness = sv.calculate_optimal_line_thickness(resolution_wh=pil_image.size)
    bbox_annotator = sv.BoxAnnotator(thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_scale=0.9,
        text_thickness=thickness,
        smart_position=True,
    )

    def _draw(dets, labels):
        # Annotate a fresh copy so the two panels never share drawings.
        canvas = pil_image.copy()
        canvas = bbox_annotator.annotate(canvas, dets)
        return label_annotator.annotate(canvas, dets, labels)

    annotation_image = _draw(annotations, annotations_labels)
    detections_image = _draw(detections, detections_labels)

    # Side-by-side display. NOTE(review): newer Streamlit releases deprecate
    # ``use_column_width``; kept as-is for compatibility with older versions.
    col1, col2 = st.columns(2)
    with col1:
        st.image(annotation_image, caption="Ground Truth Annotations", use_column_width=True)
    with col2:
        st.image(detections_image, caption="Model Predictions", use_column_width=True)
    return detections, annotations, ds.classes