File size: 2,598 Bytes
ac42789 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from ultralytics import YOLO
from ultralytics.engine.results import Masks
class CablePoleSegmentation:
    """Instance-segmentation wrapper around an Ultralytics ``YOLO`` model.

    Runs the model on an image and converts its polygon mask output into
    per-instance 0/1 masks at the input image's resolution.
    """

    def __init__(self, MODEL_PATH=None, retina_mask=False):
        """Load the segmentation model.

        Args:
            MODEL_PATH: Path to the model weights. When falsy, defaults to
                ``./weight/yolov8l-seg-pre100.onnx``.
            retina_mask: Forwarded to the predictor's ``retina_masks`` option.
        """
        if not MODEL_PATH:
            MODEL_PATH = "./weight/yolov8l-seg-pre100.onnx"
        self._RetinaMask = retina_mask
        self.Model = YOLO(MODEL_PATH)  # load a custom model

    def RescaleTheMask(self, orijinal_image, masks):
        """Rasterize polygon contours into filled binary masks.

        Args:
            orijinal_image: Image whose ``shape[:2]`` sets each mask's size.
            masks: Iterable of polygon contours, each an array of (x, y)
                points (e.g. ``result.masks.xy``). Coordinates are truncated
                to int32 before drawing.

        Returns:
            A list of ``uint8`` arrays shaped ``orijinal_image.shape[:2]``,
            one per contour: 1 inside the filled polygon, 0 elsewhere. An
            empty contour yields an all-zero mask, so list indices stay
            aligned with the detections.
        """
        mask_shape = orijinal_image.shape[:2]
        rasterized = []
        for contour in masks:
            b_mask = np.zeros(mask_shape, np.uint8)
            points = np.asarray(contour).astype(np.int32).reshape(-1, 1, 2)
            if len(points):
                # Fill with 1 (not 255) so the result is a 0/1 mask.
                b_mask = cv2.drawContours(b_mask, [points], -1, (1, 1, 1), cv2.FILLED)
            rasterized.append(b_mask)
        return rasterized

    def Process(self, image):
        """Run the model on ``image`` and return the first result's outputs.

        Args:
            image: Input passed straight to the YOLO model.

        Returns:
            Tuple ``(masks, boxes, classes, plot)``:
              - masks: list of 0/1 ``uint8`` masks sized like ``image``
                (empty when nothing is detected);
              - boxes: int numpy array of ``xyxy`` boxes, one row per detection;
              - classes: numpy array of class ids, one per detection;
              - plot: the annotated image returned by ``result.plot()``.

        Raises:
            RuntimeError: If the model yields no result at all.
        """
        with torch.no_grad():
            # stream=True returns a lazy generator: inference runs while it is
            # iterated, so the loop must also be inside the no_grad context.
            results = self.Model(
                image,
                save=False,
                show_boxes=False,
                project="./result/",
                conf=0.5,
                retina_masks=self._RetinaMask,
                stream=True,
            )
            for result in results:
                # Only the first result is used (one input image).
                # Ultralytics leaves `masks` as None when nothing is detected.
                mask_contours = [] if result.masks is None else result.masks.xy
                boxes = result.boxes.xyxy.int().cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()
                rescaled_masks = self.RescaleTheMask(image, mask_contours)
                return rescaled_masks, boxes, classes, result.plot()
        raise RuntimeError("YOLO model returned no results for the input image")

    @staticmethod
    def _to_display_rgb(img):
        """Return a 3-channel BGR image in RGB order for matplotlib display.

        Arrays that are not ``(H, W, 3)`` (e.g. single-channel masks) are
        returned unchanged.
        """
        img = np.asarray(img)
        if img.ndim == 3 and img.shape[2] == 3:
            return img[..., ::-1]
        return img

    def PlotResults(self, masks, boxes, classes, original_image, result_image, mask, cable_mask):
        """Show a 2x2 figure of input, mask, selected mask and annotated result.

        Also writes ``cable_mask`` to ``cable_mask.png`` in the working directory.

        Args:
            masks, boxes, classes: Accepted for call-site compatibility; unused.
            original_image: Input image in BGR order (cv2 convention); top-left.
            result_image: Annotated image in BGR order; bottom-right.
            mask: Segmentation mask; top-right.
            cable_mask: Selected mask; saved to disk and shown bottom-left.
        """
        fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27, 15))
        # cv2/ultralytics images are BGR; matplotlib expects RGB.
        axs[0][0].imshow(self._to_display_rgb(original_image))
        axs[0][0].set_title("Orijinal Görüntü")
        axs[0][1].imshow(mask)
        axs[0][1].set_title("Segmentasyon Maskesi")
        cv2.imwrite("cable_mask.png", cable_mask)
        axs[1][0].imshow(cable_mask)
        axs[1][0].set_title("Seçilen")
        axs[1][1].imshow(self._to_display_rgb(result_image))
        axs[1][1].set_title("Sonuç")
        plt.show()
if __name__ == "__main__":
    test = "data/16_3450.png"
    image = cv2.imread(test)
    if image is None:
        # cv2.imread returns None (instead of raising) for a missing or
        # unreadable file; fail here with a clear message.
        raise SystemExit(f"Could not read image: {test}")
    model = CablePoleSegmentation(retina_mask=True)
    masks, boxes, classes, result_plot = model.Process(image)
    fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27, 15))
    # cv2 loads images as BGR; matplotlib expects RGB.
    axs[0][0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0][0].set_title("Orijinal Görüntü")
    # Union of all instance masks. With no detections, np.any over an empty
    # list yields a scalar that imshow cannot draw, so use a blank canvas.
    if masks:
        union = np.any(masks, axis=0)
    else:
        union = np.zeros(image.shape[:2], dtype=bool)
    axs[1][1].imshow(union)
    axs[1][1].set_title("Sonuç")
    plt.show()
|