Spaces:
Sleeping
Sleeping
""" | |
Created By: ishwor subedi | |
Date: 2024-07-10 | |
""" | |
import os.path | |
import cv2 | |
import numpy as np | |
import requests | |
import wget | |
from PIL import Image, ImageOps | |
from tqdm import tqdm | |
from ultralytics import YOLO | |
from segment_anything import SamPredictor, sam_model_registry | |
class Segmentation:
    """Thin wrapper around an Ultralytics YOLO segmentation model."""

    # Checkpoint loaded when the caller does not supply a path.
    DEFAULT_MODEL_PATH = "artifacts/segmentation/yolov8x-seg.pt"

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH):
        """Load the YOLO segmentation checkpoint.

        Args:
            model_path: Path of the checkpoint handed to ``YOLO``. Defaults to
                the checkpoint location this class has always used.
        """
        self.segmentation_model = YOLO(model=model_path)

    def segment_image(self, image_path: str, show: bool = True):
        """Run the segmentation model on an image.

        Args:
            image_path: Source passed straight through to the model call.
            show: Forwarded as the model's ``show`` argument. Defaults to
                ``True`` to keep the historical behaviour; pass ``False``
                when no display is available.

        Returns:
            Whatever the underlying model call returns, unmodified.
        """
        return self.segmentation_model(image_path, show=show)
class SegmentAnything:
    """Point-prompted mask generation using a Segment Anything ``vit_l`` checkpoint.

    On construction the checkpoint is downloaded into ``MODEL_DIR`` if it is
    not already there, the model is built from it and wrapped in a
    ``SamPredictor``.
    """

    # Directory (relative to the working directory) holding the checkpoint.
    MODEL_DIR = "artifacts/segmentation"
    # The checkpoint file name is appended to this URL to download it.
    MODEL_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything"

    def __init__(self, device="cpu"):
        """Ensure the checkpoint exists, then build the model and predictor.

        Args:
            device: Device the model is moved to via ``.to(device)``.
        """
        self.model_name = "sam_vit_l_0b3195.pth"
        self.model_download()
        # Derive the path from model_name so the file that is downloaded and
        # the file that is loaded can never drift apart.
        checkpoint = f"{self.MODEL_DIR}/{self.model_name}"
        self.sam = sam_model_registry["vit_l"](checkpoint=checkpoint).to(device)
        self.samPredictor = SamPredictor(self.sam)

    def model_download(self):
        """Download the checkpoint named by ``self.model_name`` if it is missing.

        The data is streamed into a temporary ``.part`` file that is renamed
        onto the final path only after the transfer has completed, so an
        interrupted download is never mistaken for a valid checkpoint on the
        next run.

        Raises:
            requests.HTTPError: The server answered with an error status.
            RuntimeError: Fewer or more bytes arrived than ``content-length``
                announced.
        """
        target = f"{self.MODEL_DIR}/{self.model_name}"
        if os.path.exists(target):
            print(f"{self.model_name} model already exists.")
            return

        print(f"Downloading {self.model_name} model...")
        # open() does not create parent directories.
        os.makedirs(self.MODEL_DIR, exist_ok=True)
        url = f"{self.MODEL_BASE_URL}/{self.model_name}"
        partial = f"{target}.part"
        block_size = 1024 * 1024
        received = 0
        try:
            # (connect, read) timeouts: fail instead of hanging on a dead link.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Without this an HTML error page would be saved as the model.
                response.raise_for_status()
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                with open(partial, 'wb') as file, tqdm(
                    total=total_size_in_bytes, unit='iB', unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(block_size):
                        file.write(data)
                        received += len(data)
                        progress_bar.update(len(data))
            # A missing content-length header (0) means there is nothing to verify.
            if total_size_in_bytes != 0 and received != total_size_in_bytes:
                raise RuntimeError(
                    f"Download of {self.model_name} incomplete: "
                    f"got {received} of {total_size_in_bytes} bytes"
                )
            os.replace(partial, target)
        finally:
            # After a successful rename the .part file is gone; on any failure
            # (including Ctrl-C) remove the leftover.
            if os.path.exists(partial):
                os.remove(partial)

    def generate_mask(self, image, selected_points, deselected_points):
        """Predict a mask from positive point prompts and return it inverted.

        Args:
            image: Image array handed to ``SamPredictor.set_image``.
                NOTE(review): no colour-order conversion happens here; callers
                loading with OpenCV (BGR) should confirm the channel order the
                predictor expects.
            selected_points: A single ``(x, y)`` point or a sequence of such
                points; every point is given label 1.
            deselected_points: Accepted for interface compatibility but
                currently NOT used in the prediction.
                NOTE(review): presumably meant as negative (label 0) prompts —
                confirm with callers before wiring it in.

        Returns:
            A PIL image: the first mask returned by the predictor, passed
            through ``ImageOps.invert``.

        Raises:
            ValueError: ``selected_points`` contains no points.
        """
        # reshape(-1, 2) gives one row per point for both a lone (x, y) pair
        # and a list of pairs.
        points = np.asarray(selected_points).reshape(-1, 2)
        if points.shape[0] == 0:
            raise ValueError("selected_points must contain at least one (x, y) point")
        label = np.ones(points.shape[0])

        self.samPredictor.set_image(image)
        mask, _, _ = self.samPredictor.predict(
            point_coords=points,
            point_labels=label,
        )
        # Only the first returned mask is used.
        mask = Image.fromarray(mask[0, :, :])
        mask_img = ImageOps.invert(mask)
        return mask_img
if __name__ == '__main__':
    # Manual smoke test: generate a mask for one point prompt on a local image,
    # print the mask's array shape and display the source image.
    segment_anything = SegmentAnything()
    image_path = "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # OpenCV decodes to BGR. Convert BEFORE prompting the predictor, which
    # treats its input as RGB by default; the original converted only after
    # the mask had already been computed.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = segment_anything.generate_mask(rgb_image, (20, 20),
                                          (20, 20))
    maskimage = np.array(mask)
    print(maskimage.shape)

    # imshow expects BGR, so display the untouched decoded image.
    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()