Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
'''The models use the sample code provided by the instructor; only three version-related changes were made.'''
def weights_init(m):
    """Initialise a convolutional or linear layer in place.

    Meant to be handed to ``nn.Module.apply``, which calls it once for every
    sub-module. ``nn.Conv2d`` and ``nn.Linear`` layers receive Xavier-uniform
    weights and a constant bias of 0.1; any other module type is left
    untouched.

    Args:
        m: The module to initialise.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers created with ``bias=False`` expose ``bias`` as None; filling it
    # would raise, so only touch a bias that actually exists.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.1)
class LossFn:
    """Weighted classification, box-regression and landmark losses.

    Each loss first selects the samples that are relevant for it, based on
    the ground-truth label, and then applies the criterion to that subset:

    * ``cls_loss``      uses samples whose label is ``>= 0`` (binary cross entropy)
    * ``box_loss``      uses samples whose label is ``!= 0`` (mean squared error)
    * ``landmark_loss`` uses samples whose label is ``== -2`` (mean squared error)

    Inputs are flattened to one row per sample instead of being squeezed, so
    a batch of size one keeps its batch axis. When no sample qualifies, a
    zero loss that is still attached to the prediction's autograd graph is
    returned instead of the NaN that the mean over an empty tensor would give.

    Args:
        cls_factor: Multiplier applied to the classification loss.
        box_factor: Multiplier applied to the box-regression loss.
        landmark_factor: Multiplier applied to the landmark loss.
    """

    def __init__(self, cls_factor=1, box_factor=1, landmark_factor=1):
        self.cls_factor = cls_factor
        self.box_factor = box_factor
        self.land_factor = landmark_factor
        self.loss_cls = nn.BCELoss()       # binary cross entropy
        self.loss_box = nn.MSELoss()       # mean squared error
        self.loss_landmark = nn.MSELoss()

    @staticmethod
    def _zero_loss(pred):
        """Return a scalar zero that stays connected to ``pred``'s graph."""
        return pred.sum() * 0.0

    @staticmethod
    def _select_rows(mask, pred, target):
        """Flatten ``pred``/``target`` to one row per label and keep masked rows."""
        rows = mask.numel()
        return pred.reshape(rows, -1)[mask], target.reshape(rows, -1)[mask]

    def cls_loss(self, gt_label, pred_label):
        """Binary cross entropy over the samples whose label is 0 or above.

        Args:
            gt_label: Ground-truth labels, one value per sample.
            pred_label: Predicted probabilities, one value per sample
                (singleton axes are flattened away).

        Returns:
            Scalar loss tensor scaled by ``cls_factor``; zero when no sample
            has a label ``>= 0``.
        """
        gt_label = gt_label.reshape(-1)
        pred_label = pred_label.reshape(-1)
        # Negative labels do not take part in the classification loss.
        mask = gt_label >= 0
        if not mask.any():
            return self._zero_loss(pred_label)
        valid_pred = pred_label[mask]
        # BCELoss needs the target in the prediction's floating dtype.
        valid_gt = gt_label[mask].to(valid_pred.dtype)
        return self.loss_cls(valid_pred, valid_gt) * self.cls_factor

    def box_loss(self, gt_label, gt_offset, pred_offset):
        """Mean squared error over the samples whose label is non-zero.

        Args:
            gt_label: Ground-truth labels, one value per sample.
            gt_offset: Ground-truth offsets, one row per sample.
            pred_offset: Predicted offsets, one row per sample (trailing
                singleton axes are flattened away).

        Returns:
            Scalar loss tensor scaled by ``box_factor``; zero when every
            label is 0.
        """
        mask = gt_label.reshape(-1) != 0
        if not mask.any():
            return self._zero_loss(pred_offset)
        valid_pred, valid_gt = self._select_rows(mask, pred_offset, gt_offset)
        return self.loss_box(valid_pred, valid_gt) * self.box_factor

    def landmark_loss(self, gt_label, gt_landmark, pred_landmark):
        """Mean squared error over the samples whose label equals -2.

        Args:
            gt_label: Ground-truth labels, one value per sample.
            gt_landmark: Ground-truth landmark coordinates, one row per sample.
            pred_landmark: Predicted landmark coordinates, one row per sample.

        Returns:
            Scalar loss tensor scaled by ``landmark_factor``; zero when no
            label equals -2.
        """
        mask = gt_label.reshape(-1) == -2
        if not mask.any():
            return self._zero_loss(pred_landmark)
        valid_pred, valid_gt = self._select_rows(mask, pred_landmark, gt_landmark)
        return self.loss_landmark(valid_pred, valid_gt) * self.land_factor
class PNet(nn.Module):
    """PNet: fully convolutional backbone with detection and box heads.

    A landmark head (``conv4_3``) is constructed as well, but ``forward``
    does not evaluate it. ``is_train`` and ``use_cuda`` are stored on the
    instance; ``forward`` returns the same pair regardless of ``is_train``.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super().__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Shared feature extractor.
        backbone = [
            nn.Conv2d(3, 10, kernel_size=3, stride=1),    # conv1
            nn.PReLU(),                                   # prelu1
            nn.MaxPool2d(kernel_size=2, stride=2),        # pool1
            nn.Conv2d(10, 16, kernel_size=3, stride=1),   # conv2
            nn.PReLU(),                                   # prelu2
            nn.Conv2d(16, 32, kernel_size=3, stride=1),   # conv3
            nn.PReLU(),                                   # prelu3
        ]
        self.pre_layer = nn.Sequential(*backbone)

        # 1x1 convolution heads on top of the 32-channel feature map.
        self.conv4_1 = nn.Conv2d(32, 1, kernel_size=1, stride=1)    # detection
        self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)    # bounding-box regression
        self.conv4_3 = nn.Conv2d(32, 10, kernel_size=1, stride=1)   # landmark localisation

        # Xavier weight initialisation for every conv layer.
        self.apply(weights_init)

    def forward(self, x):
        """Return ``(label, offset)`` for the input batch ``x``.

        ``label`` is the sigmoid of the detection head and ``offset`` is the
        raw output of the box-regression head.
        """
        features = self.pre_layer(x)
        detection = torch.sigmoid(self.conv4_1(features))
        regression = self.conv4_2(features)
        return detection, regression
class RNet(nn.Module):
    """RNet: convolutional backbone, one hidden linear layer and linear heads.

    A landmark head (``conv5_3``) is constructed as well, but ``forward``
    does not evaluate it. ``is_train`` and ``use_cuda`` are stored on the
    instance; ``forward`` returns the same pair regardless of ``is_train``.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super().__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Convolutional feature extractor.
        backbone = [
            nn.Conv2d(3, 28, kernel_size=3, stride=1),    # conv1
            nn.PReLU(),                                   # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),        # pool1
            nn.Conv2d(28, 48, kernel_size=3, stride=1),   # conv2
            nn.PReLU(),                                   # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),        # pool2
            nn.Conv2d(48, 64, kernel_size=2, stride=1),   # conv3
            nn.PReLU(),                                   # prelu3
        ]
        self.pre_layer = nn.Sequential(*backbone)

        # Hidden layer; it expects 64 * 2 * 2 flattened backbone features.
        self.conv4 = nn.Linear(64 * 2 * 2, 128)
        self.prelu4 = nn.PReLU()

        # Output heads on top of the 128-dimensional hidden vector.
        self.conv5_1 = nn.Linear(128, 1)     # detection
        self.conv5_2 = nn.Linear(128, 4)     # bounding-box regression
        self.conv5_3 = nn.Linear(128, 10)    # landmark localisation

        # Xavier weight initialisation for every conv and linear layer.
        self.apply(weights_init)

    def forward(self, x):
        """Return ``(det, box)`` for the input batch ``x``.

        ``det`` is the sigmoid of the detection head and ``box`` is the raw
        output of the box-regression head.
        """
        features = self.pre_layer(x)
        # Flatten everything except the batch axis for the linear layers.
        flat = features.view(features.size(0), -1)
        hidden = self.prelu4(self.conv4(flat))
        detection = torch.sigmoid(self.conv5_1(hidden))
        regression = self.conv5_2(hidden)
        return detection, regression
class ONet(nn.Module):
    """ONet: convolutional backbone, one hidden linear layer and three heads.

    Unlike the other two networks in this file, ``forward`` also evaluates
    the landmark head. ``is_train`` and ``use_cuda`` are stored on the
    instance; ``forward`` returns the same triple regardless of ``is_train``.
    """

    def __init__(self, is_train=False, use_cuda=True):
        super().__init__()
        self.is_train = is_train
        self.use_cuda = use_cuda

        # Convolutional feature extractor.
        backbone = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1),     # conv1
            nn.PReLU(),                                    # prelu1
            nn.MaxPool2d(kernel_size=3, stride=2),         # pool1
            nn.Conv2d(32, 64, kernel_size=3, stride=1),    # conv2
            nn.PReLU(),                                    # prelu2
            nn.MaxPool2d(kernel_size=3, stride=2),         # pool2
            nn.Conv2d(64, 64, kernel_size=3, stride=1),    # conv3
            nn.PReLU(),                                    # prelu3
            nn.MaxPool2d(kernel_size=2, stride=2),         # pool3
            nn.Conv2d(64, 128, kernel_size=2, stride=1),   # conv4
            nn.PReLU(),                                    # prelu4
        ]
        self.pre_layer = nn.Sequential(*backbone)

        # Hidden layer; it expects 128 * 2 * 2 flattened backbone features.
        self.conv5 = nn.Linear(128 * 2 * 2, 256)
        self.prelu5 = nn.PReLU()

        # Output heads on top of the 256-dimensional hidden vector.
        self.conv6_1 = nn.Linear(256, 1)     # detection
        self.conv6_2 = nn.Linear(256, 4)     # bounding-box regression
        self.conv6_3 = nn.Linear(256, 10)    # landmark localisation

        # Xavier weight initialisation for every conv and linear layer.
        self.apply(weights_init)

    def forward(self, x):
        """Return ``(det, box, landmark)`` for the input batch ``x``.

        ``det`` is the sigmoid of the detection head; ``box`` and
        ``landmark`` are the raw outputs of their linear heads.
        """
        features = self.pre_layer(x)
        # Flatten everything except the batch axis for the linear layers.
        flat = features.view(features.size(0), -1)
        hidden = self.prelu5(self.conv5(flat))
        detection = torch.sigmoid(self.conv6_1(hidden))
        regression = self.conv6_2(hidden)
        points = self.conv6_3(hidden)
        return detection, regression, points