# Scraped from a Hugging Face Space (status: Sleeping); file size 2,908 bytes, commit 36cd99b.
"""
Created By: ishwor subedi
Date: 2024-07-10
"""
import os.path
import cv2
import numpy as np
import requests
import wget
from PIL import Image, ImageOps
from tqdm import tqdm
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry
class Segmentation:
    """Image segmentation backed by an Ultralytics YOLO segmentation model."""

    def __init__(self, model_path="artifacts/segmentation/yolov8x-seg.pt"):
        """Load the YOLO segmentation weights.

        Args:
            model_path: Path to the YOLO weights file. Defaults to the
                YOLOv8x-seg checkpoint under ``artifacts/segmentation``.
        """
        self.segmentation_model = YOLO(model=model_path)

    def segment_image(self, image_path: str, show: bool = True):
        """Run the segmentation model on one image.

        Args:
            image_path: Image source passed straight to the YOLO model call.
            show: Forwarded to the model call as ``show=``. ``True`` (the
                default, matching the original behaviour) asks Ultralytics to
                display the results; pass ``False`` when no display is wanted
                (e.g. on a server).

        Returns:
            Whatever the YOLO model call returns (presumably a list of
            Ultralytics ``Results`` objects; not verified here).
        """
        return self.segmentation_model(image_path, show=show)
class SegmentAnything:
    """Point-prompted segmentation with Segment Anything (``vit_l`` checkpoint).

    The checkpoint is downloaded into ``artifacts/segmentation`` on first use.
    """

    def __init__(self, device="cpu"):
        """Ensure the checkpoint is on disk and build a SAM predictor.

        Args:
            device: Device the SAM model is moved to via ``.to(device)``.
        """
        self.model_name = "sam_vit_l_0b3195.pth"
        self.model_download()
        self.sam = sam_model_registry["vit_l"](checkpoint=self._checkpoint_path()).to(device)
        self.samPredictor = SamPredictor(self.sam)

    def _checkpoint_path(self):
        """Return the on-disk location of the checkpoint named by ``self.model_name``."""
        return f"artifacts/segmentation/{self.model_name}"

    def model_download(self):
        """Download the SAM checkpoint unless it already exists on disk.

        The data is streamed into a temporary ``.part`` file. That file is
        renamed to the final name only after the download completes, so an
        interrupted or truncated transfer is never mistaken for a valid
        checkpoint on the next run.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            IOError: If fewer bytes arrived than the Content-Length advertised.
        """
        checkpoint_path = self._checkpoint_path()
        if os.path.exists(checkpoint_path):
            print(f"{self.model_name} model already exists.")
            return

        print(f"Downloading {self.model_name} model...")
        # The artifacts directory may not exist on a fresh checkout.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        url = f"https://dl.fbaipublicfiles.com/segment_anything/{self.model_name}"
        partial_path = checkpoint_path + ".part"

        with requests.get(url, stream=True, timeout=30) as response:
            # Without this, an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as file, tqdm(
                total=total_size_in_bytes, unit='iB', unit_scale=True
            ) as progress_bar:
                # 1 MiB chunks: the checkpoint is large, so 1 KiB chunks waste time.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    progress_bar.update(len(data))
                    file.write(data)
                downloaded = progress_bar.n

        if total_size_in_bytes != 0 and downloaded != total_size_in_bytes:
            os.remove(partial_path)
            raise IOError(
                f"Incomplete download of {self.model_name}: "
                f"received {downloaded} of {total_size_in_bytes} bytes"
            )
        os.replace(partial_path, checkpoint_path)

    @staticmethod
    def _as_point_array(points):
        """Normalise ``None``, one ``(x, y)`` point, or a sequence of points to an (N, 2) float array."""
        if points is None:
            return np.empty((0, 2), dtype=float)
        return np.asarray(points, dtype=float).reshape(-1, 2)

    def generate_mask(self, image, selected_points, deselected_points):
        """Segment the object indicated by clicked points.

        Args:
            image: Image passed to ``SamPredictor.set_image``. SAM's default
                image format is RGB, so images read with cv2 (BGR) should be
                converted first.
            selected_points: One ``(x, y)`` point or a sequence of points on the
                object. They are sent to SAM as foreground (label 1).
            deselected_points: One ``(x, y)`` point, a sequence of points, or an
                empty value / ``None``. They are sent to SAM as background
                (label 0) to exclude regions from the mask.

        Returns:
            A PIL image of the first predicted mask, inverted. The segmented
            region becomes 0 and everything else the maximum value.

        Raises:
            ValueError: If no selected point is given.
        """
        positive = self._as_point_array(selected_points)
        negative = self._as_point_array(deselected_points)
        if len(positive) == 0:
            raise ValueError("At least one selected point is required.")

        points = np.concatenate([positive, negative])
        labels = np.concatenate([np.ones(len(positive)), np.zeros(len(negative))])

        self.samPredictor.set_image(image)
        mask, _, _ = self.samPredictor.predict(
            point_coords=points,
            point_labels=labels,
        )
        mask = Image.fromarray(mask[0, :, :])
        mask_img = ImageOps.invert(mask)
        return mask_img
if __name__ == '__main__':
    import sys

    segment_anything = SegmentAnything()
    # Optional CLI argument; falls back to the original sample image path.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png"
    )
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on failure
        raise SystemExit(f"Could not read image: {image_path}")
    # cv2 loads BGR, while SamPredictor.set_image expects RGB by default.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # One foreground click and no background clicks. Passing the same pixel as
    # both selected and deselected would give SAM a contradictory prompt.
    mask = segment_anything.generate_mask(rgb_image, (20, 20), [])
    # Convert to 8-bit grayscale so cv2.imshow can display it.
    maskimage = np.array(mask.convert("L"))
    print(maskimage.shape)
    # cv2.imshow expects BGR, so show the original frame, not the RGB copy.
    cv2.imshow("image", image)
    cv2.imshow("mask", maskimage)
    cv2.waitKey(0)
    cv2.destroyAllWindows()