File size: 2,019 Bytes
b4880ac
3c3d624
f43b2e5
 
 
 
 
3c3d624
b4880ac
f43b2e5
b4880ac
f43b2e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
995dcec
 
 
f43b2e5
0136d74
f43b2e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc9455d
f43b2e5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision.transforms.functional import to_tensor, normalize
from transformers import SegformerForSemanticSegmentation

# Optional Hugging Face access token (needed if the checkpoint is private/gated).
hf_token = os.environ.get("HF_TOKEN", None)
# Inference is pinned to CPU — presumably a CPU-only deployment; confirm if GPU is available.
device = torch.device("cpu")

# Segmentation classes and their integer ids; id2label is the inverse mapping
# required by the transformers config.
label2id = {"background": 0, "skin": 1, "hair": 2, "clothes": 3, "accessories": 4}
id2label = {v: k for k, v in label2id.items()}
# RGB color used to render each class in the output mask.
colors = {
    "background": (40, 40, 40),
    "skin": (255, 178, 127),
    "hair": (139, 69, 19),
    "clothes": (100, 149, 237),
    "accessories": (255, 215, 0),
}

# Load the fine-tuned SegFormer checkpoint. ignore_mismatched_sizes lets the
# classification head be re-shaped to len(label2id) outputs if the stored head
# differs from the configured number of labels.
model = SegformerForSemanticSegmentation.from_pretrained(
    "neuratech-ai/person_segmentation_v3",
    token=hf_token,
    ignore_mismatched_sizes=True,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
model.eval()  # inference mode: disables dropout / batchnorm updates
model.to(device)


def preds_to_rgb(preds):
    """Colorize a 2-D array of class ids using the module palette.

    Args:
        preds: (H, W) integer array of class ids from ``label2id``.

    Returns:
        (H, W, 3) uint8 array where each pixel carries its class color.
    """
    height, width = preds.shape[:2]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    for name, cid in label2id.items():
        rgb[preds == cid] = colors[name]

    return rgb


def query_image(img):
    """Run person segmentation on an uploaded image.

    Args:
        img: HxWxC uint8 numpy array from Gradio, or None when no image
            was provided.

    Returns:
        PIL.Image with every pixel colored by its predicted class, or
        None when the input was None.
    """
    if img is None:
        return None

    # Force 3 channels: a grayscale or RGBA upload would otherwise produce a
    # tensor whose channel count doesn't match the 3-element mean/std below.
    img = Image.fromarray(img).convert("RGB")
    # Resize so the longest side is 1024 px, preserving aspect ratio.
    scale = 1024 / max(img.size)
    img = img.resize(
        (int(img.size[0] * scale), int(img.size[1] * scale)), Image.LANCZOS
    )

    # ImageNet normalization — assumed to match the model's training
    # preprocessing; TODO confirm against the checkpoint's processor config.
    img = normalize(
        to_tensor(img),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )

    with torch.no_grad():
        outputs = model(img.unsqueeze(0))

    logits = outputs.logits.cpu()
    # shape[-2:] is (height, width) — the original code had these swapped,
    # harmless only because the x4 upsample is symmetric. SegFormer emits
    # logits at 1/4 of the input resolution, so x4 restores it.
    h, w = logits.shape[-2:]
    logits = torch.nn.functional.interpolate(
        logits, size=(h * 4, w * 4), mode="bilinear", align_corners=False
    )
    preds = torch.argmax(logits, dim=1).numpy()[0]
    return Image.fromarray(preds_to_rgb(preds))


# Gradio UI: one image input, one color-coded segmentation mask output.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="neuratech-ai person segmentation v3",
    # Example files are expected to sit next to this script in the Space repo.
    examples=[["example1.jpg"], ["example2.jpg"]],
)

demo.launch()