| from collections import OrderedDict |
|
|
| import torch |
| import torch.nn as nn |
| from torchvision.models import ( |
| ResNet50_Weights, |
| VGG16_BN_Weights, |
| VGG16_Weights, |
| resnet50, |
| vgg16, |
| vgg16_bn, |
| ) |
|
|
| from engine.BiRefNet.config import Config |
| from engine.BiRefNet.models.backbones.pvt_v2 import ( |
| pvt_v2_b0, |
| pvt_v2_b1, |
| pvt_v2_b2, |
| pvt_v2_b5, |
| ) |
| from engine.BiRefNet.models.backbones.swin_v1 import ( |
| swin_v1_b, |
| swin_v1_l, |
| swin_v1_s, |
| swin_v1_t, |
| ) |
|
|
# Project configuration, instantiated once at import time. load_weights() below
# reads checkpoint paths from its `weights` mapping, keyed by backbone name.
config = Config()
|
|
|
|
def _stack_stages(conv1, conv2, conv3, conv4):
    """Wrap four stage modules in an ``nn.Sequential`` keyed "conv1".."conv4"."""
    return nn.Sequential(
        OrderedDict(
            [
                ("conv1", conv1),
                ("conv2", conv2),
                ("conv3", conv3),
                ("conv4", conv4),
            ]
        )
    )


def build_backbone(bb_name, pretrained=True, params_settings=""):
    """Build a feature-extraction backbone by name.

    The torchvision backbones ("vgg16", "vgg16bn", "resnet50") are cut into
    four named stages ("conv1" .. "conv4") and wrapped in an
    ``nn.Sequential``. Any other ``bb_name`` is resolved as a callable visible
    from this module (e.g. the imported ``pvt_v2_*`` / ``swin_v1_*`` builders)
    and called with ``params_settings`` as its argument list.

    Args:
        bb_name: Backbone identifier.
        pretrained: When true, load pretrained weights. The torchvision
            backbones use their ``*_Weights.DEFAULT`` weights; every other
            backbone goes through ``load_weights(bb, bb_name)``.
        params_settings: Source text of the constructor arguments (for
            example ``"a=1, b=2"``). Only used for non-torchvision backbones.

    Returns:
        The backbone module. For non-torchvision backbones with
        ``pretrained=True`` this is whatever ``load_weights`` returns, which
        is ``None`` when no matching weights were found.
    """
    # NOTE: the torchvision builders are called with ``weights=``. The legacy
    # ``pretrained=`` keyword is deprecated since torchvision 0.13 and only
    # looks at the truthiness of its value, so passing a ``*_Weights.DEFAULT``
    # enum through it did not actually select those weights.
    if bb_name == "vgg16":
        # children()[0] is the convolutional ``features`` stack of torchvision's VGG.
        features = list(
            vgg16(weights=VGG16_Weights.DEFAULT if pretrained else None).children()
        )[0]
        bb = _stack_stages(
            features[:4], features[4:9], features[9:16], features[16:23]
        )
    elif bb_name == "vgg16bn":
        features = list(
            vgg16_bn(
                weights=VGG16_BN_Weights.DEFAULT if pretrained else None
            ).children()
        )[0]
        bb = _stack_stages(
            features[:6], features[6:13], features[13:23], features[23:33]
        )
    elif bb_name == "resnet50":
        layers = list(
            resnet50(
                weights=ResNet50_Weights.DEFAULT if pretrained else None
            ).children()
        )
        # The stem takes children 0-2 and the stages continue from child 4, so
        # child 3 is not part of the backbone.
        # NOTE(review): in torchvision's ResNet child 3 is the stem max-pool;
        # presumably it is dropped on purpose -- confirm.
        bb = _stack_stages(
            nn.Sequential(*layers[0:3]), layers[4], layers[5], layers[6]
        )
    else:
        # SECURITY: eval() executes both strings as Python source. Only pass
        # trusted, code-controlled values for ``bb_name`` and
        # ``params_settings`` -- never anything derived from user input.
        bb = eval("{}({})".format(bb_name, params_settings))
        if pretrained:
            bb = load_weights(bb, bb_name)
    return bb
|
|
|
|
def _select_matching(candidates, model_dict):
    """Pick the entries of ``candidates`` whose key also exists in ``model_dict``.

    A checkpoint tensor whose size differs from the model's is replaced by the
    model's current tensor, so size-mismatched weights are ignored rather than
    causing ``load_state_dict`` to fail.
    """
    return {
        key: value if value.size() == model_dict[key].size() else model_dict[key]
        for key, value in candidates.items()
        if key in model_dict
    }


def load_weights(model, model_name):
    """Load the checkpoint registered for ``model_name`` into ``model``.

    The checkpoint path is read from ``config.weights[model_name]`` and loaded
    on the CPU. Keys are matched against ``model.state_dict()``; if none match
    at the top level and the checkpoint has exactly one top-level key, the
    mapping stored under that key is tried instead.

    Args:
        model: Module exposing ``state_dict()`` / ``load_state_dict()``.
        model_name: Key into ``config.weights``.

    Returns:
        ``model`` with the matching weights loaded, or ``None`` (after
        printing a message) when no matching weights could be found.
    """
    checkpoint = torch.load(
        config.weights[model_name], map_location="cpu", weights_only=True
    )
    model_dict = model.state_dict()
    state_dict = _select_matching(checkpoint, model_dict)

    if not state_dict:
        # Nothing matched at the top level: the real state dict may be wrapped
        # under a single key.
        top_level_keys = list(checkpoint.keys())
        sub_item = top_level_keys[0] if len(top_level_keys) == 1 else None
        # Only index the checkpoint once we know there is a usable single key;
        # indexing with None used to raise KeyError before the failure message
        # below could be reached.
        nested = checkpoint[sub_item] if sub_item else None
        if isinstance(nested, dict):
            state_dict = _select_matching(nested, model_dict)
        if not state_dict:
            print(
                "Weights are not successully loaded. Check the state dict of weights file."
            )
            return None
        print(
            'Found correct weights in the "{}" item of loaded state_dict.'.format(
                sub_item
            )
        )

    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model
|
|