Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import cv2 | |
| import argparse | |
| import warnings | |
# Import the deep-learning stack up front so a missing dependency fails with
# an actionable message instead of a bare traceback.
try:
    import torch as th
    from transformers import AutoImageProcessor, Mask2FormerModel, Mask2FormerForUniversalSegmentation
except ImportError as error:
    # Raising a plain string is itself a TypeError in Python 3 and would hide
    # this hint; raise a real ImportError and keep the original cause chained.
    raise ImportError(
        'Try installing the torch and transformers modules using pip.'
    ) from error
| warnings.filterwarnings("ignore") | |
class MASK2FORMER:
    """Semantic-segmentation helper built on a Hugging Face Mask2Former checkpoint.

    The instance loads an image processor and a
    ``Mask2FormerForUniversalSegmentation`` model, and exposes helpers that
    turn an RGB image into a binary mask for one target class (or the
    ``"vehicle"`` group) plus a background-removed copy of the image.
    """

    # Label ids merged together when ``class_id == "vehicle"``.
    # NOTE(review): 2/5/7 are carried over unchanged from the original code;
    # confirm they are vehicle classes in the label map of the checkpoint used.
    VEHICLE_CLASS_IDS = (2, 5, 7)

    def __init__(self, model_name="facebook/mask2former-swin-small-ade-semantic", class_id=6):  ## use large
        """Load the processor and segmentation model for ``model_name``.

        args : model_name -> str - Hugging Face checkpoint name or path
               class_id -> int | "vehicle" - label id to extract; the string
                           "vehicle" selects every id in VEHICLE_CLASS_IDS
        """
        self._model_name = model_name
        # The processor must come from the same checkpoint as the model so
        # that pre/post-processing match (it used to be hard-coded).
        self.image_processor = AutoImageProcessor.from_pretrained(model_name)
        # The bare Mask2FormerModel is never used by this class; it is loaded
        # lazily by the ``maskformer_processor`` property to avoid holding a
        # second copy of the weights in memory.
        self._maskformer_processor = None
        self.maskformer_model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name)
        self.DEVICE = "cuda" if th.cuda.is_available() else 'cpu'
        self.segment_id = class_id
        self.maskformer_model.to(self.DEVICE)

    @property
    def maskformer_processor(self):
        """Bare ``Mask2FormerModel`` for the checkpoint, loaded on first access."""
        if self._maskformer_processor is None:
            self._maskformer_processor = Mask2FormerModel.from_pretrained(self._model_name)
        return self._maskformer_processor

    @maskformer_processor.setter
    def maskformer_processor(self, model):
        # Keep plain attribute assignment working for existing callers.
        self._maskformer_processor = model

    def create_rgb_mask(self, mask, value=255):
        """
        Replicate a single-channel mask into three channels.
        args : mask -> ndarray - 2D mask
               value -> pixels of ``mask`` equal to this become (255,255,255)
        return : ndarray -> uint8 3-channel mask
        """
        gray_3_channel = cv2.merge((mask, mask, mask))
        gray_3_channel[mask == value] = (255, 255, 255)
        return gray_3_channel.astype(np.uint8)

    def get_mask(self, segmentation):
        """
        Mask out the pixels whose label matches ``self.segment_id``.
        args : segmentation -> torch tensor - per-pixel label map produced by
                               the image processor's semantic post-processing
        return : ndarray -> 2D uint8 mask, 255 where the label matches, else 0
        """
        # Move to host memory once instead of once per comparison.
        labels = segmentation.cpu().numpy()
        if self.segment_id == "vehicle":
            wanted_ids = self.VEHICLE_CLASS_IDS
        else:
            # Honour the configured class id (previously hard-coded to 6,
            # which is still the default).
            wanted_ids = self.segment_id
        selected = np.isin(labels, wanted_ids)
        visual_mask = (selected * 255).astype(np.uint8)
        return visual_mask

    def generate_road_mask(self, img):
        """
        Extract the semantic mask of the configured class from a raw image.
        args : img -> np.array - input image; its first two dimensions are
                      used as the (height, width) of the returned mask
        return : ndarray -> 2D uint8 mask (255 = selected class, 0 = other)
        """
        inputs = self.image_processor(img, return_tensors="pt")
        inputs = inputs.to(self.DEVICE)
        # Inference only: no gradients needed.
        with th.no_grad():
            outputs = self.maskformer_model(**inputs)
        # Resize the prediction back to the input resolution.
        segmentation = self.image_processor.post_process_semantic_segmentation(
            outputs, target_sizes=[(img.shape[0], img.shape[1])]
        )[0]
        segmented_mask = self.get_mask(segmentation=segmentation)
        return segmented_mask

    def get_rgb_mask(self, img, segmented_mask):
        """
        Keep the image pixels under the mask and zero out the background.
        args: img -> ndarray - raw image
              segmented_mask -> ndarray - binary mask from the semantic segmentation
        return : ndarray -> image with background pixels set to 0
        """
        predicted_rgb_mask = self.create_rgb_mask(segmented_mask)
        rgb_mask_img = cv2.bitwise_and(img, predicted_rgb_mask)
        return rgb_mask_img

    def run_inference(self, image_name):
        """
        Read an image from disk, segment the configured class and report how
        much of the image it covers.
        args: image_name -> str - path of the image to read
        return : float -> percentage (rounded to 1 decimal) of non-zero values
                 in the background-removed image
        raises : OSError - if the image cannot be read by cv2.imread
        """
        bgr_image = cv2.imread(image_name)
        if bgr_image is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque cvtColor assertion.
            raise OSError(f"Could not read image: {image_name!r}")
        input_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        road_mask = self.generate_road_mask(input_image)
        road_image = self.get_rgb_mask(input_image, road_mask)
        # NOTE(review): this counts non-zero channel values of the masked RGB
        # image, so black pixels inside the mask are not counted; counting on
        # road_mask would give the pure area share - confirm intended metric.
        obj_prop = round((np.count_nonzero(road_image) / np.size(road_image)) * 100, 1)
        ## empty gpu cache
        if th.cuda.is_available():
            th.cuda.empty_cache()
        return obj_prop
def main(args):
    """Segment the image at ``args.image_path`` and measure the masked area.

    args : args -> namespace with an ``image_path`` attribute (path to the image)
    return : (mask, masked_image, percent) -> the 2D mask, the image with the
             background zeroed, and the percentage (1 decimal) of non-zero
             values in that masked image
    raises : OSError - if the image cannot be read by cv2.imread
    """
    # The class defined in this file is MASK2FORMER; the previous name
    # (ROADMASK_WITH_MASK2FORMER) does not exist and raised NameError.
    mask2former = MASK2FORMER()
    bgr_image = cv2.imread(args.image_path)
    if bgr_image is None:
        # cv2.imread returns None on failure instead of raising.
        raise OSError(f"Could not read image: {args.image_path!r}")
    # OpenCV loads BGR; the rest of the pipeline is fed RGB.
    input_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    road_mask = mask2former.generate_road_mask(input_image)
    road_image = mask2former.get_rgb_mask(input_image, road_mask)
    obj_prop = round(np.count_nonzero(road_image) / np.size(road_image) * 100, 1)
    return road_mask, road_image, obj_prop
# Command-line entry point: segment one image and report the covered area.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Keep the original single-dash spelling working and also accept the
    # conventional double-dash form.
    parser.add_argument('-image_path', '--image_path', help='raw_image_path', required=True)
    args = parser.parse_args()
    _, _, obj_prop = main(args)
    # The result used to be discarded, so the script produced no output.
    print(f"Segmented area: {obj_prop}% of the image")