import cv2
from matplotlib import pyplot as plt
import torch
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from segmentation_mask_overlay import overlay_masks
from typing import List
import logging

logger = logging.getLogger(__name__)

# Status string returned by get_segmentation_mask when no image could be written.
_FAILURE_MESSAGE = "unable to create a mask image "


class CLIPSEG:
    """Text-prompted image segmentation with a CLIPSeg model.

    Every text prompt produces one binary mask.  All masks are overlaid on the
    (resized) input image and the overlay is written to disk.
    """

    def __init__(self, model_name="CIDAS/clipseg-rd64-refined", threshould=0.60, device="cpu"):
        """Load the CLIPSeg processor and model.

        Args:
            model_name: Hugging Face model id or local path given to
                ``from_pretrained``.
            threshould: Probability cutoff; a pixel belongs to a mask when its
                sigmoid score is strictly greater than this value.  (The
                misspelled name is kept for backward compatibility.)
            device: torch device the model and its inputs are placed on.
        """
        self.clip_processor = CLIPSegProcessor.from_pretrained(model_name)
        self.clip_model = CLIPSegForImageSegmentation.from_pretrained(model_name)
        self.threshould = threshould
        self.device = device
        self.clip_model.to(device)

    @staticmethod
    def create_rgb_mask(mask, color=None):
        """Paint the pixels of a 0/255 single-channel mask with one colour.

        Args:
            mask: 2-D array whose foreground pixels equal 255.
            color: RGB triple to paint the foreground with; a random colour is
                chosen only when this is None.

        Returns:
            (H, W, 3) uint8 array: foreground pixels set to ``color``, all other
            pixels keep the mask value replicated over three channels.
        """
        if color is None:
            color = tuple(np.random.choice(range(0, 256), size=3))
        # np.dstack == cv2.merge for a 2-D mask, but has no OpenCV dtype
        # restrictions (e.g. int64 masks produced by np.where).
        gray_3_channel = np.dstack((mask, mask, mask))
        gray_3_channel[mask == 255] = color
        return gray_3_channel.astype(np.uint8)

    def _predict_masks(self, image, object_prompts):
        """Run CLIPSeg once per prompt and threshold the result.

        Returns:
            Boolean array of shape (len(object_prompts), H, W), one mask per
            prompt, where H x W is the model's output resolution.
        """
        inputs = self.clip_processor(
            text=object_prompts,
            images=[image] * len(object_prompts),
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():  # inference only: no gradient bookkeeping
            logits = self.clip_model(**inputs).logits
        if logits.ndim == 2:
            # NOTE(review): some transformers versions squeeze the batch
            # dimension when only one prompt is given; restore it so the
            # first axis always indexes prompts.
            logits = logits.unsqueeze(0)
        probs = torch.sigmoid(logits).cpu().numpy()
        return probs > self.threshould

    @staticmethod
    def _save_overlay(overlay, output_path):
        """Write ``overlay`` to ``output_path`` and return a status message."""
        try:
            if isinstance(overlay, np.ndarray) and overlay.ndim == 3 and overlay.shape[-1] == 3:
                # The overlay is presumably RGB (it was built from an RGB image
                # and RGB colours), while cv2.imwrite expects BGR order.
                overlay = np.ascontiguousarray(overlay[..., ::-1])
            # cv2.imwrite reports many failures by returning False, not raising.
            if not cv2.imwrite(output_path, overlay):
                logger.error("Error while saving the final mask : cv2.imwrite failed for %s", output_path)
                return _FAILURE_MESSAGE
        except Exception:
            logger.exception("Error while saving the final mask :")
            return _FAILURE_MESSAGE
        return f"Segmentation image created : {output_path}"

    def get_segmentation_mask(self, image_path: str, object_prompts: List, output_path: str = "final_mask.png"):
        """Segment ``object_prompts`` in an image and save a labelled overlay.

        Args:
            image_path: Path of the image to read with OpenCV.
            object_prompts: Text prompts, one mask is produced per prompt.
            output_path: Where the overlay image is written.

        Returns:
            A status string: ``"Segmentation image created : <output_path>"``
            on success, otherwise ``"unable to create a mask image "``.

        Raises:
            FileNotFoundError: if ``image_path`` cannot be read.
        """
        bgr_image = cv2.imread(image_path)
        if bgr_image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        logger.info("objects found out from the image :%s", object_prompts)
        if not object_prompts:
            logger.error("No object prompts given; nothing to segment")
            return _FAILURE_MESSAGE

        masks = self._predict_masks(image, object_prompts)
        # Resize the image to the mask resolution so they can be overlaid.
        mask_height, mask_width = masks.shape[1:]
        resize_image = cv2.resize(image, (mask_width, mask_height))

        mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(object_prompts)]
        # One RGB colour per mask from tab20 (alpha dropped); cycle the
        # palette so more than 20 prompts do not all share the last colour.
        palette = plt.cm.tab20
        cmap = palette(np.arange(len(mask_labels)) % palette.N)[..., :-1]

        final_mask = overlay_masks(
            resize_image,
            np.stack(list(masks), axis=-1),  # (H, W, num_prompts)
            labels=mask_labels,
            colors=cmap,
            alpha=0.5,
            beta=0.7,
        )
        return self._save_overlay(final_mask, output_path)