# Scraped page header (not code): "blitzkrieg0000's picture", "Upload 17 files",
# commit ac42789 (verified), raw / history / blame links, file size 2.6 kB.
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from ultralytics import YOLO
from ultralytics.engine.results import Masks
class CablePoleSegmentation:
    """Wrap an Ultralytics YOLO segmentation model for cable/pole masks.

    The model is loaded once in the constructor. ``Process`` runs inference on
    a single image and returns filled binary masks, bounding boxes, class ids
    and the Ultralytics-rendered result image.
    """

    def __init__(self, MODEL_PATH=None, retina_mask=False):
        """Load the segmentation model.

        Args:
            MODEL_PATH: Path to the model weights. Falls back to
                ``./weight/yolov8l-seg-pre100.onnx`` when falsy.
            retina_mask: Forwarded to the model call as ``retina_masks``.
        """
        if not MODEL_PATH:
            MODEL_PATH = "./weight/yolov8l-seg-pre100.onnx"
        self._RetinaMask = retina_mask
        self.Model = YOLO(MODEL_PATH)  # load a custom model

    def RescaleTheMask(self, orijinal_image, masks):
        """Rasterise polygon contours into filled binary masks.

        Args:
            orijinal_image: Array whose first two dimensions (height, width)
                define the size of every output mask.
            masks: Iterable of polygons, each an ``(N, 2)`` array of x/y pixel
                coordinates.

        Returns:
            List of ``uint8`` arrays of shape ``orijinal_image.shape[:2]``, one
            per polygon, holding 1 inside the polygon and 0 elsewhere. A polygon
            with no points yields an all-zero mask, so the output stays aligned
            index-for-index with ``masks``.
        """
        canvas_shape = orijinal_image.shape[:2]
        _masks = []
        for contour in masks:
            b_mask = np.zeros(canvas_shape, np.uint8)
            contour = np.asarray(contour)
            if contour.size == 0:
                # cv2.drawContours asserts npoints > 0, and segment lists may
                # contain empty arrays (e.g. for degenerate masks). Keep an
                # empty mask instead of crashing.
                _masks.append(b_mask)
                continue
            # OpenCV expects int32 points shaped (N, 1, 2).
            contour = contour.astype(np.int32).reshape(-1, 1, 2)
            mask = cv2.drawContours(b_mask, [contour], -1, (1, 1, 1), cv2.FILLED)
            _masks.append(mask)
        return _masks

    def Process(self, image, conf=0.5):
        """Segment a single image.

        Args:
            image: Input handed directly to the YOLO model (e.g. the BGR array
                returned by ``cv2.imread``). Only the first result is used.
            conf: Minimum detection confidence forwarded to the model
                (default 0.5, the previously hard-coded value).

        Returns:
            ``(masks, boxes, classes, plot)``. ``masks`` is a list of 0/1
            ``uint8`` arrays sized like the original image. ``boxes`` is the
            ``xyxy`` boxes as an integer numpy array. ``classes`` is the
            class-id numpy array. ``plot`` is the value of ``result.plot()``.
            When nothing is detected, ``masks`` is an empty list.

        Raises:
            RuntimeError: If the model yields no result at all for ``image``.
        """
        with torch.no_grad():
            results = self.Model(
                image,
                save=False,
                show_boxes=False,
                project="./result/",
                conf=conf,
                retina_masks=self._RetinaMask,
                stream=True,
            )
            # With stream=True, inference happens lazily while iterating, so
            # the loop must run inside the no_grad context.
            for result in results:
                boxes = result.boxes.xyxy.int().cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()
                if result.masks is None:
                    # No detections: there are no contours to rasterise.
                    rescaledMasks = []
                else:
                    # orig_img is the decoded input, so this also works when
                    # `image` is a path rather than an array.
                    rescaledMasks = self.RescaleTheMask(result.orig_img, result.masks.xy)
                return rescaledMasks, boxes, classes, result.plot()
        raise RuntimeError("The model produced no result for the given image")

    def PlotResults(self, masks, boxes, classes, original_image, result_image, mask, cable_mask):
        """Show a 2x2 figure: input, segmentation mask, selected mask, result.

        Side effect: writes ``cable_mask`` to ``cable_mask.png`` in the current
        working directory. ``masks``, ``boxes`` and ``classes`` are accepted but
        not used.

        NOTE(review): matplotlib interprets 3-channel arrays as RGB. If these
        images come from cv2 (BGR), colours will appear swapped — confirm with
        callers.
        """
        fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27, 15))
        axs[0][0].imshow(original_image)
        axs[0][0].set_title("Orijinal Görüntü")  # "Original Image"
        axs[0][1].imshow(mask)
        axs[0][1].set_title("Segmentasyon Maskesi")  # "Segmentation Mask"
        cv2.imwrite("cable_mask.png", cable_mask)
        axs[1][0].imshow(cable_mask)
        axs[1][0].set_title("Seçilen")  # "Selected"
        axs[1][1].imshow(result_image)
        axs[1][1].set_title("Sonuç")  # "Result"
        plt.show()
if "__main__" == __name__:
    # Demo: segment one sample image and show it next to the union of all
    # predicted instance masks.
    test = "data/16_3450.png"
    image = cv2.imread(test)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising.
        raise FileNotFoundError(f"Could not read image: {test}")
    model = CablePoleSegmentation(retina_mask=True)
    masks, boxes, classes, result_plot = model.Process(image)
    # np.any over an empty list gives a scalar that imshow cannot draw, so
    # fall back to an empty canvas when nothing was detected.
    if masks:
        combined_mask = np.any(masks, axis=0)
    else:
        combined_mask = np.zeros(image.shape[:2], dtype=bool)
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27, 15))
    # cv2 loads images as BGR; matplotlib expects RGB.
    axs[0][0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0][0].set_title("Orijinal Görüntü")  # "Original Image"
    axs[1][1].imshow(combined_mask)
    axs[1][1].set_title("Sonuç")  # "Result"
    plt.show()