File size: 1,417 Bytes
62dd301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import torch
import numpy as np
from PIL import Image

class SegmentAnything:
	"""Wrapper around a Segment Anything (SAM) model with mask/image helpers.

	The model is loaded once in the constructor and shared by `predict`
	(prompted segmentation) and `generate` (automatic mask generation).
	The static helpers turn a 2-D mask into an RGBA overlay and cut the
	masked region out of an image.
	"""

	def __init__(self, sam_checkpoint='checkpoint/sam_vit_h_4b8939.pth', model_type='vit_h'):
		"""Load the SAM weights and move the model to CUDA when available.

		Args:
			sam_checkpoint: Path to the SAM checkpoint file.
			model_type: Key into `sam_model_registry` selecting the
				architecture; it must match the checkpoint.
		"""
		sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
		if torch.cuda.is_available():
			sam.to(device='cuda')
		self.sam = sam

	def predict(self, image, point_coords, point_labels, box=None):
		"""Run prompted segmentation on `image`.

		Args:
			image: Anything `np.array` can convert to a uint8 array
				(e.g. a PIL image).
			point_coords: Point prompts, forwarded to `SamPredictor.predict`.
			point_labels: Labels for the point prompts, forwarded as-is.
			box: Optional box prompt, forwarded as-is.

		Returns:
			Whatever `SamPredictor.predict` returns (per the
			segment_anything API: masks, quality scores, low-res logits).
		"""
		predictor = SamPredictor(self.sam)
		predictor.set_image(np.array(image, dtype=np.uint8))
		return predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box)

	def generate(self, image):
		"""Run automatic mask generation on `image`.

		Args:
			image: Anything `np.array` can convert to a uint8 array.

		Returns:
			Whatever `SamAutomaticMaskGenerator.generate` returns.
		"""
		mask_generator = SamAutomaticMaskGenerator(self.sam)
		return mask_generator.generate(np.array(image, dtype=np.uint8))

	@staticmethod
	def makeMaskImage(mask, color):
		"""Render a 2-D mask as an RGBA image.

		Truthy mask cells are painted with `color`; all other pixels stay
		fully transparent (0, 0, 0, 0).

		Args:
			mask: 2-D array-like indexed as [row, column], i.e. its shape
				is (height, width).
			color: Fill color accepted by PIL for an RGBA image, e.g. an
				(r, g, b, a) tuple.

		Returns:
			A new RGBA `PIL.Image.Image` of size (width, height).

		Raises:
			ValueError: If `mask` is not 2-dimensional.
		"""
		selected = np.asarray(mask).astype(bool)
		if selected.ndim != 2:
			raise ValueError('mask must be a 2-D array, got shape %r' % (selected.shape,))
		# Array shape is (rows, cols) == (height, width); PIL sizes are
		# (width, height), so the two must be swapped.
		height, width = selected.shape
		image = Image.new('RGBA', (width, height))
		if width == 0 or height == 0:
			return image
		# 255 where the mask is set, 0 elsewhere: paste() copies the solid
		# color only where this stencil is 255.
		stencil = Image.fromarray(selected.astype(np.uint8) * 255)
		image.paste(Image.new('RGBA', (width, height), color), (0, 0), stencil)
		return image

	@staticmethod
	def makeNewImage(image, maskImage):
		"""Cut out the part of `image` covered by `maskImage`.

		Args:
			image: Source PIL image.
			maskImage: PIL image with an alpha channel and the same size as
				`image`; every pixel with alpha > 0 counts as selected.

		Returns:
			A new RGBA image of `image.size` containing the source pixels
			where the mask is selected and transparent pixels elsewhere.
		"""
		newImage = Image.new('RGBA', image.size)
		# Binarize the alpha channel so semi-transparent mask colors still
		# copy the source pixels at full strength instead of blending.
		stencil = maskImage.getchannel('A').point(lambda a: 255 if a > 0 else 0)
		newImage.paste(image, (0, 0), stencil)
		return newImage