"""Build binary face/hair masks from point prompts with Segment Anything (SAM).

The SAM ViT-H checkpoint ("facebook/sam-vit-huge") is loaded once at import
time and reused by every call to ``build_mask`` / ``build_mask_multi``.
"""
import torch
from kornia.morphology import dilation, closing
import requests
from transformers import SamModel, SamProcessor

# NOTE(review): `dilation` and `requests` look unused in this file -- confirm
# no other module relies on them being importable from here before removing.

print('Loading SAM...')
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
print('DONE')


def _segment(image, input_points):
    """Run SAM on ``image`` with ``input_points`` and return one 2D bool mask.

    ``input_points`` is forwarded unchanged to ``SamProcessor`` as its
    ``input_points`` argument (presumably nested lists of [x, y] pixel
    coordinates -- confirm against callers).
    """
    with torch.no_grad():
        inputs = processor(
            image, input_points=input_points, return_tensors="pt"
        ).to(device)
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0][0] is the stack of candidate masks for the first image's first
    # prompt (presumably SAM's multimask outputs); .all(0) keeps only the
    # pixels that every candidate marks, yielding a single (H, W) bool mask.
    return masks[0][0].all(0)


def _close_mask(mask):
    """Apply a 3x3 morphological closing to a 2D mask; return a 2D bool mask."""
    kernel = torch.ones(3, 3, device=mask.device)
    # kornia morphology expects (B, C, H, W) images and pads the borders with
    # large +/- sentinel values (geodesic border). A bool tensor cannot hold
    # those sentinels (the negative pad would turn into True), so run the
    # closing in float and threshold back to bool afterwards.
    closed = closing(mask[None, None].float(), kernel)
    return closed[0, 0].bool()


def build_mask(image, faces, hairs):
    """Segment face and hair regions with SAM and return their combined mask.

    Args:
        image: Image passed straight to ``SamProcessor``.
        faces: SAM ``input_points`` prompt locating the face(s).
        hairs: SAM ``input_points`` prompt locating the hair.

    Returns:
        2D bool tensor: union of the face and hair masks, smoothed with a
        3x3 morphological closing to fill small gaps.
    """
    # 1. Segmentation: one SAM pass per prompt set.
    face_mask = _segment(image, faces)
    hair_mask = _segment(image, hairs)

    # 2. Post-processing: union, then closing (dilation followed by erosion).
    return _close_mask(face_mask | hair_mask)


def build_mask_multi(image, faces, hairs):
    """Build a combined face+hair mask for several people in one image.

    Each ``(face, hair)`` pair is segmented separately (so SAM sees one
    subject at a time), closed with a 3x3 kernel, and the per-pair masks are
    OR-ed together. Pairs are formed with ``zip``, so extra entries in the
    longer of ``faces`` / ``hairs`` are ignored.

    Args:
        image: Image passed straight to ``SamProcessor``.
        faces: Sequence of per-person face prompts; each element is wrapped
            in a list and used as SAM ``input_points``.
        hairs: Sequence of per-person hair prompts, same convention.

    Returns:
        2D bool tensor: union of all per-pair masks.

    Raises:
        IndexError: if no ``(face, hair)`` pair is provided.
    """
    all_masks = []
    for face, hair in zip(faces, hairs):
        # 1. Segmentation for this subject only.
        face_mask = _segment(image, [face])
        hair_mask = _segment(image, [hair])
        # 2. Post-processing: union, then closing.
        all_masks.append(_close_mask(face_mask | hair_mask))

    if not all_masks:
        # Same exception type the previous `all_masks[0]` raised, but explicit.
        raise IndexError(
            "build_mask_multi requires at least one (face, hair) pair"
        )
    return torch.stack(all_masks).any(dim=0)