File size: 2,908 Bytes
36cd99b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Created By: ishwor subedi
Date: 2024-07-10
"""
import os.path

import cv2
import numpy as np
import requests
import wget
from PIL import Image, ImageOps
from tqdm import tqdm
from ultralytics import YOLO
from segment_anything import SamPredictor, sam_model_registry


class Segmentation:
    """Thin wrapper around an Ultralytics YOLO segmentation model."""

    def __init__(self, model_path="artifacts/segmentation/yolov8x-seg.pt"):
        """Load the YOLO segmentation weights.

        Args:
            model_path: path to the ``.pt`` weights file. Defaults to the
                YOLOv8x-seg checkpoint this class has always used.
        """
        self.segmentation_model = YOLO(model=model_path)

    def segment_image(self, image_path: str, *, show: bool = True, **predict_kwargs):
        """Run the segmentation model on ``image_path`` and return its results.

        Args:
            image_path: image source handed straight to the YOLO model call.
            show: forwarded to YOLO. ``True`` (the historical behaviour) asks
                Ultralytics to display the prediction; pass ``False`` when
                running headless.
            **predict_kwargs: any further YOLO prediction options
                (e.g. ``conf=0.5``), forwarded unchanged.

        Returns:
            Whatever the YOLO model call returns (per the Ultralytics predict
            docs, a list of ``Results`` objects).
        """
        return self.segmentation_model(image_path, show=show, **predict_kwargs)


class SegmentAnything:
    """Point-prompted segmentation using Segment Anything (SAM) with the ViT-L backbone.

    Construction makes sure the ``sam_vit_l_0b3195.pth`` checkpoint exists under
    ``artifacts/segmentation`` (relative to the current working directory),
    downloading it if necessary. It then loads the model onto ``device`` and wraps
    it in a ``SamPredictor``.
    """

    # Directory the checkpoint is stored in / downloaded to (relative to the CWD).
    model_dir = "artifacts/segmentation"
    # Base URL under which the SAM checkpoints are published.
    download_base_url = "https://dl.fbaipublicfiles.com/segment_anything"

    def __init__(self, device="cpu"):
        """Ensure the checkpoint is present, then build the model and predictor.

        Args:
            device: target passed to ``.to()`` on the loaded SAM model.
        """
        self.model_name = "sam_vit_l_0b3195.pth"
        self.model_download()

        self.sam = sam_model_registry["vit_l"](checkpoint=self.checkpoint_path).to(device)
        self.samPredictor = SamPredictor(self.sam)

    @property
    def checkpoint_path(self):
        """Local path of the checkpoint file (``model_dir`` / ``model_name``)."""
        return os.path.join(self.model_dir, self.model_name)

    def model_download(self):
        """Download the checkpoint unless it already exists locally.

        The data is streamed into a temporary ``.part`` file. That file is moved
        into place only after the transfer completes and matches the announced
        size. This way an interrupted or truncated download never leaves behind
        a file that later runs would accept as a valid checkpoint.

        Raises:
            requests.HTTPError: if the server replies with an error status.
            IOError: if the number of bytes received differs from the
                announced ``content-length``.
        """
        path = self.checkpoint_path
        if os.path.exists(path):
            print(f"{self.model_name} model already exists.")
            return

        print(f"Downloading {self.model_name} model...")
        # The target directory may not exist yet on a fresh checkout.
        os.makedirs(self.model_dir, exist_ok=True)
        url = f"{self.download_base_url}/{self.model_name}"
        tmp_path = path + ".part"
        try:
            with requests.get(url, stream=True, timeout=60) as response:
                # Without this, an HTML error page would be saved as the checkpoint.
                response.raise_for_status()
                total_size_in_bytes = int(response.headers.get("content-length", 0))
                with open(tmp_path, "wb") as file, tqdm(
                    total=total_size_in_bytes, unit="iB", unit_scale=True
                ) as progress_bar:
                    for data in response.iter_content(chunk_size=1024 * 1024):
                        file.write(data)
                        progress_bar.update(len(data))
            received = os.path.getsize(tmp_path)
            if total_size_in_bytes and received != total_size_in_bytes:
                raise IOError(
                    f"Incomplete download of {self.model_name}: "
                    f"got {received} of {total_size_in_bytes} bytes"
                )
            os.replace(tmp_path, path)
        except BaseException:
            # Never leave a partial file behind; re-raise the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def _as_point_array(points):
        """Return ``points`` as an ``(N, 2)`` float array.

        Accepts ``None``, an empty sequence, a single ``(x, y)`` pair, or a
        sequence of pairs. ``None``/empty yields an empty ``(0, 2)`` array.
        """
        if points is None:
            return np.empty((0, 2))
        arr = np.asarray(points, dtype=float)
        if arr.size == 0:
            return np.empty((0, 2))
        if arr.ndim == 1:
            # A single (x, y) pair.
            arr = arr.reshape(1, -1)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError(
                f"expected an (x, y) point or a sequence of them, got shape {arr.shape}"
            )
        return arr

    @classmethod
    def _point_prompt(cls, selected_points, deselected_points):
        """Build ``(point_coords, point_labels)`` for ``SamPredictor.predict``.

        Selected points get label 1 and deselected points get label 0. The
        coordinates are stacked in that order.

        Raises:
            ValueError: if no points at all are given, or a point is not 2-D.
        """
        positive = cls._as_point_array(selected_points)
        negative = cls._as_point_array(deselected_points)
        if len(positive) + len(negative) == 0:
            raise ValueError("at least one selected or deselected point is required")
        coords = np.concatenate([positive, negative])
        labels = np.concatenate(
            [np.ones(len(positive), dtype=int), np.zeros(len(negative), dtype=int)]
        )
        return coords, labels

    def generate_mask(self, image, selected_points, deselected_points, image_format="RGB"):
        """Predict a mask for ``image`` from click prompts and return it inverted.

        Args:
            image: image array handed to ``SamPredictor.set_image``.
            selected_points: a single ``(x, y)`` point or a sequence of points
                to include (sent to SAM with label 1).
            deselected_points: a single point, a sequence of points, or
                ``None``/empty. These points are excluded (sent with label 0).
            image_format: channel order of ``image``, forwarded to
                ``set_image``. The default matches the previous behaviour.

        Returns:
            PIL.Image.Image: the first mask returned by ``predict``, passed
            through ``ImageOps.invert``.
        """
        point_coords, point_labels = self._point_prompt(selected_points, deselected_points)
        self.samPredictor.set_image(image, image_format=image_format)
        masks, _, _ = self.samPredictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
        )
        mask = Image.fromarray(masks[0, :, :])
        return ImageOps.invert(mask)


if __name__ == '__main__':
    # Manual smoke test: segment the object under a single click and display
    # the input image together with the resulting (inverted) mask.
    import sys

    # An image path may be given on the command line; otherwise use the original demo image.
    image_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/ishwor/Pictures/01.TEST/alia/5869473_dark_lean.png"
    )
    image = cv2.imread(image_path)  # BGR array, or None if the file can't be read
    if image is None:
        sys.exit(f"Could not read image: {image_path}")

    segment_anything = SegmentAnything()
    # SamPredictor.set_image assumes RGB by default, while cv2.imread returns BGR.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = segment_anything.generate_mask(rgb_image, (20, 20), [])
    maskimage = np.array(mask)
    print(maskimage.shape)

    # cv2.imshow expects BGR, so show the image as loaded. The mask is scaled
    # to uint8 0/255 because cv2 cannot display boolean arrays.
    cv2.imshow("image", image)
    cv2.imshow("mask", (maskimage > 0).astype(np.uint8) * 255)
    cv2.waitKey(0)
    cv2.destroyAllWindows()