from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np

# Load SegFormer for hair segmentation
processor = SegformerImageProcessor.from_pretrained("VanNguyen1214/get_face_and_hair")
model     = AutoModelForSemanticSegmentation.from_pretrained("VanNguyen1214/get_face_and_hair")

def extract_hair(image: Image.Image) -> Image.Image:
    """
    Return an RGBA image where hair pixels have alpha=255 and
    all other pixels have alpha=0.
    """
    rgb = image.convert("RGB")
    arr = np.array(rgb)
    h, w = arr.shape[:2]

    # Segment hair
    inputs = processor(images=rgb, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits.cpu()
    up = F.interpolate(logits, size=(h, w), mode="bilinear", align_corners=False)
    seg = up.argmax(dim=1)[0].numpy()
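    # Class index 2 is treated here as the hair label for this checkpoint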
    hair_mask = (seg == 2).astype(np.uint8)

    # Build RGBA
    alpha = (hair_mask * 255).astype(np.uint8)
    rgba  = np.dstack([arr, alpha])
    return Image.fromarray(rgba)
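

# --- Usage sketch (illustrative; "portrait.jpg" and "hair_only.png" are example paths) ---
if __name__ == "__main__":
    portrait = Image.open("portrait.jpg")
    hair_rgba = extract_hair(portrait)
    # Save as PNG so the alpha channel (the hair mask) is preserved
    hair_rgba.save("hair_only.png")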