# MTO-TCP/src/components/object_segmentation.py
# ishworrsubedii's picture
# Updated the latest changes
# 36cd99b
"""
Created By: ishwor subedi
Date: 2024-07-10
"""
import os.path
import cv2
import numpy as np
import requests
import wget
from PIL import Image, ImageOps
from tqdm import tqdm
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
class Segmentation:
    """Thin wrapper around an Ultralytics YOLO segmentation model.

    The weights location used to be hard-coded inside ``__init__``; it is now
    a constructor argument whose default is the original path, so existing
    ``Segmentation()`` callers behave exactly as before.
    """

    DEFAULT_MODEL_PATH = "artifacts/segmentation/yolov8x-seg.pt"

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        """Load the YOLO weights found at ``model_path``."""
        self.segmentation_model = YOLO(model=model_path)

    def segment_image(self, image_path: str, show: bool = True):
        """Run the segmentation model on one image.

        Args:
            image_path: Source handed straight to the YOLO model.
            show: Forwarded as the model's ``show`` argument. Defaults to
                ``True`` to keep the original behaviour; pass ``False`` when
                running without a display.

        Returns:
            Whatever the YOLO model call returns for this source.
        """
        return self.segmentation_model(image_path, show=show)
class SegmentAnything:
    """Point-prompted mask generation with a Segment Anything (ViT-L) model.

    Constructing an instance makes sure the checkpoint file is present under
    ``artifacts/segmentation`` (downloading it when missing), builds the model
    through ``sam_model_registry`` and wraps it in a ``SamPredictor``.
    """

    MODEL_DIR = "artifacts/segmentation"
    MODEL_TYPE = "vit_l"
    BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything"
    # Seconds allowed for connecting and for each read, so a dead connection
    # cannot block the constructor forever.
    DOWNLOAD_TIMEOUT = 30
    # The checkpoint is large; 1 MiB chunks avoid a Python-level loop
    # iteration for every kilobyte.
    CHUNK_SIZE = 1024 * 1024

    def __init__(self, device="cpu"):
        """Download the checkpoint if needed and load it onto ``device``."""
        self.model_name = "sam_vit_l_0b3195.pth"
        self.model_download()
        self.sam = sam_model_registry[self.MODEL_TYPE](
            checkpoint=self._checkpoint_path()
        ).to(device)
        self.samPredictor = SamPredictor(self.sam)

    def _checkpoint_path(self):
        """Return the on-disk location of the checkpoint named ``self.model_name``."""
        return f"{self.MODEL_DIR}/{self.model_name}"

    def model_download(self):
        """Fetch the checkpoint unless it is already on disk.

        The data is streamed into a ``.part`` file that is renamed onto the
        final path only after the transfer completed with the announced size,
        so an interrupted download is never mistaken for a valid checkpoint on
        the next run.

        Raises:
            requests.HTTPError: The server answered with an error status.
            OSError: Fewer/more bytes arrived than ``content-length`` announced.
        """
        checkpoint = self._checkpoint_path()
        if os.path.exists(checkpoint):
            print(f"{self.model_name} model already exists.")
            return

        print(f"Downloading {self.model_name} model...")
        # open() does not create missing parent directories.
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        url = f"{self.BASE_URL}/{self.model_name}"
        partial = f"{checkpoint}.part"
        try:
            with requests.get(url, stream=True, timeout=self.DOWNLOAD_TIMEOUT) as response:
                # Without this an HTML error page would be saved as the model.
                response.raise_for_status()
                expected = int(response.headers.get('content-length', 0))
                written = 0
                with open(partial, 'wb') as file, tqdm(
                    total=expected or None, unit='iB', unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(self.CHUNK_SIZE):
                        file.write(data)
                        written += len(data)
                        progress_bar.update(len(data))
            if expected and written != expected:
                raise OSError(
                    f"Incomplete download of {self.model_name}: "
                    f"got {written} of {expected} bytes"
                )
            os.replace(partial, checkpoint)
        finally:
            # Only still present when the download failed part-way.
            if os.path.exists(partial):
                os.remove(partial)

    @staticmethod
    def _as_point_array(points):
        """Normalise ``None``, one ``(x, y)`` pair or a sequence of pairs to an (N, 2) float array."""
        if points is None:
            return np.empty((0, 2), dtype=float)
        coords = np.asarray(points, dtype=float)
        if coords.size == 0:
            return np.empty((0, 2), dtype=float)
        if coords.size % 2:
            raise ValueError("points must be given as (x, y) pairs")
        return coords.reshape(-1, 2)

    @classmethod
    def _build_point_prompts(cls, selected_points, deselected_points):
        """Build the ``(point_coords, point_labels)`` pair for ``SamPredictor.predict``.

        Selected points are labelled 1 (foreground) and deselected points 0
        (background); foreground points come first.
        """
        foreground = cls._as_point_array(selected_points)
        background = cls._as_point_array(deselected_points)
        if len(foreground) == 0:
            raise ValueError("at least one selected point is required")
        coords = np.concatenate([foreground, background])
        labels = np.concatenate([
            np.ones(len(foreground), dtype=int),
            np.zeros(len(background), dtype=int),
        ])
        return coords, labels

    def generate_mask(self, image, selected_points, deselected_points):
        """Predict a mask from point prompts and return it inverted.

        Args:
            image: Image array handed to ``SamPredictor.set_image``.
                NOTE(review): the predictor assumes RGB channel order by
                default, so OpenCV (BGR) images should be converted by the
                caller first.
            selected_points: One ``(x, y)`` pair or a sequence of pairs that
                lie on the object (foreground prompts).
            deselected_points: ``None``, empty, one pair or a sequence of
                pairs that lie off the object (background prompts). These
                were previously accepted but ignored.

        Returns:
            A PIL image built from the logical negation of the first mask
            returned by the predictor.

        Raises:
            ValueError: No selected point was given, or a point list cannot
                be split into ``(x, y)`` pairs.
        """
        # Validate the prompts first: computing the image embedding is the
        # expensive step and pointless if the prompts are unusable.
        coords, labels = self._build_point_prompts(selected_points, deselected_points)
        self.samPredictor.set_image(image)
        masks, _, _ = self.samPredictor.predict(
            point_coords=coords,
            point_labels=labels,
        )
        # Invert on the array instead of via ImageOps.invert so the result
        # does not depend on whether the installed Pillow accepts 1-bit
        # images there.
        return Image.fromarray(np.logical_not(masks[0, :, :]))
if __name__ == '__main__':
    # Manual smoke test: predict a mask for one image and display the image.
    import sys

    # An optional command-line argument overrides the sample image path.
    default_image_path = "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png"
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image_path

    # Validate the input before the (slow) model download/load.
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise SystemExit(f"Could not read image: {image_path}")

    segment_anything = SegmentAnything()

    # OpenCV loads images as BGR, while SAM's predictor assumes RGB by
    # default, so convert a copy for the model and keep the BGR original
    # for display.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # One foreground point, no background points.
    mask = segment_anything.generate_mask(rgb_image, (20, 20), [])
    maskimage = np.array(mask)
    print(maskimage.shape)

    # cv2.imshow expects BGR, i.e. the image exactly as cv2.imread returned it.
    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()