import torch


def extract_text_feature(prompt, model, processor, device=None):
    """Extract text features with the OwlViT text encoder

    Args:
        prompt: a single text query
        model: OwlViT model
        processor: OwlViT processor
        device (str, optional): device to run on. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        input_ids: token ids produced by the processor
        query_embeds: L2-normalized text embedding for the prompt
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with torch.no_grad():
        input_ids = torch.as_tensor(
            processor(text=prompt)['input_ids']).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # pooled output of the text encoder, projected into the joint space
        text_embeds = text_outputs[1]
        text_embeds = model.owlvit.text_projection(text_embeds)
        # L2-normalize; the epsilon guards against division by zero
        text_embeds /= text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds
    return input_ids, query_embeds


def prompt2vec(prompt: str, model, processor):
    """ Convert a text prompt into a query vector

    Args:
        prompt (str): text to be tokenized and embedded
        model: OwlViT model
        processor: OwlViT processor

    Returns:
        input_ids: token ids for the prompt, as a numpy array
        xq: vector representing the original prompt, as a numpy array
    """
    input_ids, xq = extract_text_feature(prompt, model, processor)
    input_ids = input_ids.detach().cpu().numpy()
    xq = xq.detach().cpu().numpy()
    return input_ids, xq
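
# A minimal usage sketch, not part of the pipeline above: the HuggingFace
# `transformers` checkpoint name and the helper `_demo_prompt2vec` are
# assumptions for illustration; any OwlViT checkpoint should work.
def _demo_prompt2vec():
    from transformers import OwlViTForObjectDetection, OwlViTProcessor
    name = 'google/owlvit-base-patch32'
    processor = OwlViTProcessor.from_pretrained(name)
    model = OwlViTForObjectDetection.from_pretrained(name)
    # keep the model on the same device extract_text_feature auto-detects
    if torch.cuda.is_available():
        model = model.cuda()
    input_ids, xq = prompt2vec('a photo of a cat', model, processor)
    print(input_ids.shape, xq.shape)  # e.g. (1, 16) and (1, 512)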


def tune(clf, X, y, iters=2):
    """ Train the Zero-shot Classifier

    Args:
        clf (Classifier): classifier to be trained
        X (numpy.ndarray): input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): scores given by the user
        iters (int, optional): number of update iterations to run

    Returns:
        numpy.ndarray: the updated query vectors (classifier weights)
    """
    assert len(X) == len(y)
    # train the classifier
    clf.fit(X, y, iters=iters)
    # extract new vector
    return clf.get_weights()


class Classifier:
    """Multi-Class Zero-shot Classifier
    This classifier acts as a proxy for the user's reactions to the probed
    images. Its weights replace the original prompt-generated query vector,
    steering retrieval toward results that satisfy the user.

    The same pattern is common in recommendation systems: the classifier
    recommends increasingly precise results as it accumulates user activity.

    This is a multi-class classifier. For N queries, labels 0..N-1 map to
    the queries themselves, and the label N marks a negative example.
    """

    def __init__(self, xq):
        # `xq` is expected to be a 2-D array of shape (num_queries, dims)
        init_weight = torch.Tensor(xq)
        self.num_class = xq.shape[0]
        DIMS = xq.shape[1]
        # note that the bias is ignored, as we only focus on the inner product result
        self.model = torch.nn.Linear(DIMS, self.num_class, bias=False)
        # convert initial query `xq` to tensor parameter to init weights
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X: list, y: list, iters: int = 5):
        # convert X and y to tensors and L2-normalize the input vectors
        X = torch.Tensor(X)
        X /= torch.norm(X, p=2, dim=-1, keepdim=True)
        y = torch.Tensor(y).long()
        # Build one-hot targets for binary classification; labels at or
        # beyond `num_class` mark negative examples, so their rows are
        # zeroed out
        non_ind = y >= self.num_class
        y = torch.nn.functional.one_hot(
            y % self.num_class, num_classes=self.num_class).float()
        y[non_ind] = 0
        for i in range(iters):
            # zero gradients
            self.optimizer.zero_grad()
            # Normalize the weights before the forward pass; this constrains
            # the gradient and keeps the query vectors from blowing up
            self.model.weight.data /= torch.norm(self.model.weight.data, p=2, dim=-1, keepdim=True)
            # forward pass
            out = self.model(X)
            # compute loss
            loss = self.loss(out, y)
            # backward pass
            loss.backward()
            # update weights
            self.optimizer.step()

    def get_weights(self):
        """Return the tuned query vectors as a numpy array."""
        xq = self.model.weight.detach().numpy()
        return xq
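
# A minimal training sketch; random data stands in for real retrieved
# vectors and user scores, and `_demo_tune` is an illustrative assumption,
# not part of the original module.
def _demo_tune():
    import numpy as np
    num_queries, dims = 2, 512
    xq = np.random.randn(num_queries, dims).astype('float32')
    clf = Classifier(xq)
    # 8 retrieved vectors; labels 0..num_queries-1 are positives for the
    # matching query, and label == num_queries marks a negative example
    X = np.random.randn(8, dims).astype('float32')
    y = [0, 1, 0, 1, 2, 2, 0, 1]
    new_xq = tune(clf, X, y, iters=5)
    print(new_xq.shape)  # (2, 512)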
    

class SplitLayer(torch.nn.Module):
    """Split the classifier's output into one score tensor per query."""

    def forward(self, x):
        # split the last dimension into chunks of size 1, one per class
        return torch.split(x, 1, dim=-1)
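
# Usage sketch: SplitLayer turns a (batch, num_class) score matrix into a
# tuple of per-query column tensors, e.g. for per-class post-processing.
def _demo_split_layer():
    scores = torch.randn(8, 3)
    per_class = SplitLayer()(scores)
    print(len(per_class), per_class[0].shape)  # 3 tensors of shape (8, 1)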