# face-segmenter/model.py
import torch
import transformers
from torch import nn
from transformers.modeling_outputs import SemanticSegmenterOutput


class FaceSegmenterConfig(transformers.PretrainedConfig):
    model_type = "image-segmentation"

    # CelebAMask-HQ-style face-parsing classes.
    _id2label = {
        0: "skin",
        1: "l_brow",
        2: "r_brow",
        3: "l_eye",
        4: "r_eye",
        5: "eye_g",
        6: "l_ear",
        7: "r_ear",
        8: "ear_r",
        9: "nose",
        10: "mouth",
        11: "u_lip",
        12: "l_lip",
        13: "neck",
        14: "neck_l",
        15: "cloth",
        16: "hair",
        17: "hat",
    }
    _label2id = {label: label_id for label_id, label in _id2label.items()}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.id2label = kwargs.get("id2label", self._id2label)
        # Pipeline/JSON round-trips can turn these keys into strings, so
        # normalize them back to ints. Building a new dict also avoids
        # mutating the class-level _id2label in place.
        self.id2label = {int(label_id): label for label_id, label in self.id2label.items()}
        self.label2id = kwargs.get("label2id", self._label2id)
        self.num_classes = kwargs.get("num_classes", len(self.id2label))
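

# A minimal sketch of the round-trip the __init__ above guards against
# (the stringified keys are an assumed example of what a pipeline produces):
#
#   cfg = FaceSegmenterConfig(id2label={"0": "skin", "16": "hair"})
#   assert cfg.id2label[16] == "hair"  # keys normalized back to int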


def encode_down(c_in: int, c_out: int):
    """Two 3x3 conv + BatchNorm + ReLU blocks, the basic U-Net encoder unit."""
    return nn.Sequential(
        nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=c_out),
        nn.ReLU(inplace=True),
    )


def decode_up(c: int):
    """2x2 transposed convolution that doubles spatial size and halves channels."""
    return nn.ConvTranspose2d(
        in_channels=c,
        out_channels=c // 2,
        kernel_size=2,
        stride=2,
    )
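

# Hedged shape illustration of the two helpers above:
#
#   up = decode_up(1024)
#   up(torch.randn(1, 1024, 32, 32)).shape  # -> torch.Size([1, 512, 64, 64])
#   # encode_down(1024, 512) then maps the concatenated (N, 1024, 64, 64)
#   # skip + upsample back down to 512 channels at the same resolution.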


class FaceUNet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # U-Net encoder path.
        self.down_1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            padding=1,
        )
        self.down_2 = encode_down(64, 128)
        self.down_3 = encode_down(128, 256)
        self.down_4 = encode_down(256, 512)
        self.down_5 = encode_down(512, 1024)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Decoder path. Each `encode_down` below sees twice the upsampled
        # channel count (e.g. `in_channels` is back up to 1024) because the
        # matching encoder feature map is concatenated in first.
        self.up_1 = decode_up(1024)
        self.up_c1 = encode_down(1024, 512)
        self.up_2 = decode_up(512)
        self.up_c2 = encode_down(512, 256)
        self.up_3 = decode_up(256)
        self.up_c3 = encode_down(256, 128)
        self.up_4 = decode_up(128)
        self.up_c4 = encode_down(128, 64)
        # Per-class segmentation head.
        self.segment = nn.Conv2d(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        # Encoder: keep the pre-pool activations for the skip connections.
        d1 = self.down_1(x)
        d2 = self.pool(d1)
        d3 = self.down_2(d2)
        d4 = self.pool(d3)
        d5 = self.down_3(d4)
        d6 = self.pool(d5)
        d7 = self.down_4(d6)
        d8 = self.pool(d7)
        d9 = self.down_5(d8)
        # Decoder: upsample, concatenate the matching encoder features, convolve.
        u1 = self.up_1(d9)
        x = self.up_c1(torch.cat([d7, u1], 1))
        u2 = self.up_2(x)
        x = self.up_c2(torch.cat([d5, u2], 1))
        u3 = self.up_3(x)
        x = self.up_c3(torch.cat([d3, u3], 1))
        u4 = self.up_4(x)
        x = self.up_c4(torch.cat([d1, u4], 1))
        x = self.segment(x)
        return x
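

# Shape sanity check (a sketch; the 512x512 size is an assumption -- any
# input whose height and width are divisible by 16 survives the four 2x
# poolings and comes back out at full resolution):
#
#   net = FaceUNet(num_classes=18)
#   net(torch.randn(1, 3, 512, 512)).shape  # -> torch.Size([1, 18, 512, 512])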


class Segformer(transformers.PreTrainedModel):
    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, tensor):
        # Return the raw per-class logits from the U-Net.
        return self.model(tensor)


class SegformerForSemanticSegmentation(transformers.PreTrainedModel):
    config_class = FaceSegmenterConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.model = FaceUNet(num_classes=config.num_classes)

    def forward(self, pixel_values, labels=None):
        logits = self.model(pixel_values)
        values = {"logits": logits}
        if labels is not None:
            # Per-pixel cross-entropy over the class logits.
            loss = torch.nn.functional.cross_entropy(logits, labels)
            values["loss"] = loss
        return SemanticSegmenterOutput(**values)
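

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the uploaded checkpoint's
    # interface. The 512x512 input size and random labels are assumptions.
    config = FaceSegmenterConfig()
    model = SegformerForSemanticSegmentation(config)
    pixel_values = torch.randn(1, 3, 512, 512)
    labels = torch.randint(0, config.num_classes, (1, 512, 512))
    out = model(pixel_values, labels=labels)
    print(out.logits.shape)  # torch.Size([1, 18, 512, 512])
    print(out.loss)          # scalar cross-entropy loss
    # Loading the trained weights would instead go through
    # SegformerForSemanticSegmentation.from_pretrained(...) with the Hub repo
    # id (assumed from the file path above to be "pogzyb/face-segmenter").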