# ghep_image / segmentation.py
# Last commit: "Update segmentation.py" (45b9f3a, verified) - VanNguyen1214
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
# Checkpoint id shared by the image processor and the SegFormer segmentation
# model, so both are always loaded from the same repository.
_CHECKPOINT = "VanNguyen1214/get_face_and_hair"

# Instantiated once at import time; extract_hair() reuses these objects.
processor = SegformerImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)


def extract_hair(image: Image.Image) -> Image.Image:
    """Return an RGBA copy of *image* in which only hair pixels are opaque.

    The image is converted to RGB, segmented with the module-level
    ``processor``/``model`` pair, and the per-pixel class prediction is
    turned into an alpha channel.

    Args:
        image: Source picture. It is converted to RGB before processing, so
            any alpha channel it already carries is discarded.

    Returns:
        An RGBA image with the same width and height as the input. The colour
        channels are the RGB pixels of the input; alpha is 255 where the
        predicted class is the hair class and 0 everywhere else.
    """
    # Class index this function treats as "hair".
    # NOTE(review): confirm against the checkpoint's id2label mapping.
    hair_class = 2

    rgb = image.convert("RGB")
    pixels = np.array(rgb)
    height, width = pixels.shape[:2]

    # Move the preprocessed tensors to whatever device the model lives on.
    # On CPU this is a no-op; without it the forward pass fails with a device
    # mismatch as soon as the model has been moved to an accelerator.
    inputs = processor(images=rgb, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # Bring the logits back to the CPU so the NumPy conversion below works
        # regardless of where inference ran.
        logits = model(**inputs).logits.cpu()

    # Resize the logits to the original resolution *before* taking the argmax,
    # so class boundaries are interpolated rather than blocky.
    upsampled = F.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    labels = upsampled.argmax(dim=1)[0].numpy()

    # Binary mask -> 0/255 alpha channel (uint8 throughout, no extra cast).
    alpha = (labels == hair_class).astype(np.uint8) * 255

    # Stack RGB (H, W, 3) with alpha (H, W) into an (H, W, 4) RGBA array.
    rgba = np.dstack([pixels, alpha])
    return Image.fromarray(rgba)