# Source: Hugging Face Space (page status when captured: "Runtime error").
import torch | |
from kornia.morphology import dilation, closing | |
import requests | |
from transformers import SamModel, SamProcessor | |
# Load the SAM checkpoint once at import time so every mask-building call
# reuses the same model/processor pair.
_SAM_CHECKPOINT = "facebook/sam-vit-huge"

print('Loading SAM...')
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = SamModel.from_pretrained(_SAM_CHECKPOINT)
model = model.to(device)
processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)
print('DONE')
def _segment_points(image, input_points):
    """Run SAM on `image` prompted with `input_points`; return post-processed masks.

    The forward pass runs on `device` without gradient tracking. The raw
    predictions are moved to the CPU before being handed to the image
    processor's `post_process_masks`, whose result is returned unchanged.
    """
    with torch.no_grad():
        inputs = processor(image, input_points=input_points, return_tensors="pt").to(device)
        outputs = model(**inputs)
    return processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )


def build_mask(image, faces, hairs):
    """Build one boolean mask from a face prompt and a hair prompt.

    Args:
        image: Image passed unchanged to the SAM processor.
        faces: Point prompts for the face, forwarded as `input_points`.
            NOTE(review): presumably the nested-list point format that
            `SamProcessor` expects -- confirm against the caller.
        hairs: Point prompts for the second pass, forwarded the same way
            (presumably hair locations, going by the parameter name).

    Returns:
        A 2-D boolean CPU tensor: the union of the two segmentations after a
        3x3 morphological closing.
    """
    # 1. Segmentation: one SAM pass per prompt set.
    face_masks = _segment_points(image, faces)
    hair_masks = _segment_points(image, hairs)

    # 2. Post-processing. Only entry [0][0] of each result is used; `.all(0)`
    # ANDs over its leading dimension (presumably SAM's candidate masks --
    # TODO confirm), keeping only pixels on which every candidate agrees.
    # The face and hair regions are then OR-ed together.
    mask = face_masks[0][0].all(0) | hair_masks[0][0].all(0)

    # Morphological closing (not a plain dilation) with a 3x3 kernel.
    # The mask is cast to float first: the kernel is a float tensor, and
    # feeding a bool tensor would rely on implicit dtype promotion inside
    # kornia's padding/arithmetic.
    kernel = torch.ones(3, 3)
    closed = closing(mask[None, None, :, :].float(), kernel)
    return closed[0, 0].bool()
def build_mask_multi(image, faces, hairs):
    """Build a single boolean mask covering several (face, hair) prompt pairs.

    Each pair is segmented independently via `build_mask(image, [face], [hair])`
    and the resulting masks are OR-ed together.

    Args:
        image: Image passed unchanged to `build_mask` for every pair.
        faces: Iterable of per-person face prompts.
        hairs: Iterable of per-person prompts for the second pass (presumably
            hair locations). Paired with `faces` by position via `zip`, so
            surplus entries in the longer iterable are silently ignored.

    Returns:
        A 2-D boolean tensor that is True wherever any per-pair mask is True.

    Raises:
        IndexError: If `faces` / `hairs` yield no pairs. The exception type is
            kept from the previous implementation, which failed on
            `all_masks[0]`; only the message is new.
    """
    per_pair_masks = [
        build_mask(image, [face], [hair]) for face, hair in zip(faces, hairs)
    ]
    if not per_pair_masks:
        raise IndexError("build_mask_multi needs at least one (face, hair) pair")

    # All masks come from the same image, so they can be stacked; `any` over
    # the new leading dimension is the element-wise OR of every mask.
    return torch.stack(per_pair_masks).any(dim=0)