File size: 893 Bytes
d2ff88f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# ------------------------------------------------------------------------------
import torch
import logging
log = logging.getLogger(__name__)
from mmseg.ops import resize
from mmseg.models import EncoderDecoder

class DinoCLIP_Infrencer(EncoderDecoder):
    """Inference-only segmentor that wraps an already-built ``model``.

    The wrapped ``model`` is called directly on the input image batch and its
    output is resized to the input's spatial size, so that the result can be
    consumed through the ``encode_decode`` entry point of ``EncoderDecoder``.
    """

    def __init__(
        self,
        model,
        num_classes,
        test_cfg=None,
        **kwargs,
    ):
        """Store the wrapped model and the test-time configuration.

        Args:
            model: Callable invoked as ``model(img)`` in ``encode_decode``.
            num_classes: Number of classes; stored as ``self.num_classes``.
            test_cfg: Mapping with at least a ``'mode'`` entry. ``None``
                (the default) is treated as an empty mapping.
            **kwargs: Accepted for signature compatibility and ignored.

        Raises:
            KeyError: If ``test_cfg`` has no ``'mode'`` entry (including when
                ``test_cfg`` is omitted).
        """
        # ``super(EncoderDecoder, self)`` starts the MRO lookup *after*
        # EncoderDecoder, so EncoderDecoder.__init__ itself is not executed.
        # NOTE(review): presumably this avoids building a backbone/decode
        # head that this wrapper does not need -- confirm.
        super(EncoderDecoder, self).__init__()

        # Use a fresh dict per call rather than a shared mutable default.
        if test_cfg is None:
            test_cfg = {}

        self.mode = test_cfg['mode']
        self.num_classes = num_classes
        self.model = model
        self.test_cfg = test_cfg
        # Set explicitly because EncoderDecoder.__init__ is skipped above;
        # read by ``encode_decode`` when resizing.
        self.align_corners = False

    @torch.no_grad()
    def encode_decode(self, img, meta_data):
        """Run the wrapped model on ``img`` and resize its output to ``img``.

        Gradient tracking is disabled for the whole call.

        Args:
            img: Input batch; its last two dimensions are used as the target
                size (assumes a ``(..., H, W)`` layout -- TODO confirm).
            meta_data: Unused here; kept so the method matches the
                two-argument call signature.

        Returns:
            The output of ``self.model(img)`` bilinearly resized to
            ``img.shape[-2:]``.
        """
        masks = self.model(img)
        masks = resize(
            input=masks,
            size=img.shape[-2:],
            mode='bilinear',
            align_corners=self.align_corners)
        return masks