File size: 1,395 Bytes
711211a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32


import warnings
from typing import Optional

import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection

from .model import OwlViTForClassification

def load_xclip(device: str = "cuda:0",
               n_classes: int = 183,
               use_teacher_logits: bool = False,
               custom_box_head: bool = False,
               model_path: Optional[str] = 'data/models/peeb_pretrain.pt',
               owlvit_name: str = "google/owlvit-base-patch32",
               ):
    """Build the OWL-ViT-based classifier together with its image processor.

    A pretrained OWL-ViT detector is wrapped in ``OwlViTForClassification``.
    The processor's image normalization is switched to BirdSoup statistics.
    Finetuned weights are optionally loaded from ``model_path``.

    Args:
        device: Torch device string the detector and classifier are moved to.
        n_classes: Forwarded to ``OwlViTForClassification`` as ``num_classes``.
        use_teacher_logits: Forwarded as ``logits_from_teacher``.
        custom_box_head: Forwarded as ``custom_box_head``.
        model_path: Path to a checkpoint passed to ``torch.load`` and then to
            ``load_state_dict``. When ``None``, no finetuned weights are loaded.
        owlvit_name: Hugging Face identifier (or local path) used for both the
            pretrained ``OwlViTProcessor`` and ``OwlViTForObjectDetection``.

    Returns:
        A ``(model, processor)`` tuple. ``model`` is the classifier moved to
        ``device``; ``processor`` is the ``OwlViTProcessor`` with BirdSoup
        mean/std.

    Warns:
        UserWarning: If the checkpoint's keys do not match the model's.
        Loading uses ``strict=False``, so without this warning mismatched
        parameters would silently keep their initial values.
    """
    owlvit_det_processor = OwlViTProcessor.from_pretrained(owlvit_name)
    owlvit_det_model = OwlViTForObjectDetection.from_pretrained(owlvit_name).to(device)

    # BirdSoup mean/std: overrides the processor's default normalization so
    # images are normalized with the dataset statistics.
    owlvit_det_processor.image_processor.image_mean = [0.48168647, 0.49244233, 0.42851609]
    owlvit_det_processor.image_processor.image_std = [0.18656386, 0.18614962, 0.19659419]

    # Every loss term is weighted 0 here. Presumably this loader is meant for
    # inference rather than training — confirm against training scripts.
    weight_dict = {"loss_ce": 0, "loss_bbox": 0, "loss_giou": 0,
                   "loss_sym_box_label": 0, "loss_xclip": 0}
    model = OwlViTForClassification(owlvit_det_model=owlvit_det_model,
                                    num_classes=n_classes,
                                    device=device,
                                    weight_dict=weight_dict,
                                    logits_from_teacher=use_teacher_logits,
                                    custom_box_head=custom_box_head)

    if model_path is not None:
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # point model_path at trusted checkpoints.
        ckpt = torch.load(model_path, map_location='cpu')
        incompatible = model.load_state_dict(ckpt, strict=False)
        # strict=False reports mismatches instead of raising; surface them so
        # a wrong or partially matching checkpoint is not loaded silently.
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"Checkpoint {model_path!r} does not fully match the model: "
                f"{len(incompatible.missing_keys)} missing key(s) "
                f"(e.g. {incompatible.missing_keys[:5]}), "
                f"{len(incompatible.unexpected_keys)} unexpected key(s) "
                f"(e.g. {incompatible.unexpected_keys[:5]}).",
                stacklevel=2,
            )
    model.to(device)
    return model, owlvit_det_processor