import os
import warnings

import torch
import torchvision
|
|
def build_LemonFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None):
    """Build a ConvNeXt-Large feature extractor, optionally loading LemonFM weights.

    The last layer of torchvision's ConvNeXt classifier (``classifier[2]``)
    is replaced by ``torch.nn.Identity``. The returned network therefore
    outputs the features that would have entered that linear layer, not
    class logits.

    Args:
        nclasses: Currently unused. It is kept for interface compatibility,
            and no classification head is attached.
        pretrained: When False, no checkpoint is read and the backbone keeps
            torchvision's random initialisation.
        pretrained_weights: Path to a checkpoint file. Its ``'teacher'`` entry
            must be a mapping. Only keys that start with ``'backbone.'`` are
            kept, and that prefix is removed before loading. If the path is
            None or is not an existing file, a warning is issued and the
            weights stay randomly initialised.

    Returns:
        The modified ``convnext_large`` network. It is moved to CUDA only when
        ``torch.cuda.is_available()`` is True.

    Raises:
        KeyError: If the checkpoint has no ``'teacher'`` entry.
    """
    net = torchvision.models.convnext_large()
    # Width of the features entering the removed linear layer (printed after loading).
    input_emdim = net.classifier[2].in_features
    # Use the fully qualified name: ``nn`` is not imported in this module.
    net.classifier[2] = torch.nn.Identity()

    if pretrained:
        if pretrained_weights is not None and os.path.isfile(pretrained_weights):
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. Consider weights_only=True if the checkpoint allows it.
            checkpoint = torch.load(pretrained_weights, map_location="cpu")
            prefix = "backbone."
            # Strip only the leading prefix. str.replace would also remove
            # "backbone." if it appeared later in a key.
            state_dict = {
                key[len(prefix):]: value
                for key, value in checkpoint["teacher"].items()
                if key.startswith(prefix)
            }
            if not state_dict:
                # With strict=False, an empty dict would load nothing and report no error.
                warnings.warn(
                    f"No '{prefix}' keys found in the 'teacher' entry of "
                    f"{pretrained_weights!r}; backbone left randomly initialised."
                )
            msg = net.load_state_dict(state_dict, strict=False)
            print(msg, input_emdim)
        else:
            warnings.warn(
                f"pretrained=True but pretrained_weights={pretrained_weights!r} "
                "is not an existing file; using randomly initialised weights."
            )

    # Only move to the GPU when one exists, so the builder also works on CPU-only hosts.
    if torch.cuda.is_available():
        net.cuda()

    return net
|
|