# object-detection-safari / classifier.py
# Author: Fangrui Liu — commit 3f1124e ("init repo")
import torch
def extract_text_feature(prompt, model, processor, device=None):
    """Extract an L2-normalised OwlViT text embedding for a single prompt.

    Args:
        prompt: a single text query, tokenized with ``processor(text=prompt)``.
        model: OwlViT model; its ``owlvit.text_model`` and
            ``owlvit.text_projection`` sub-modules are used.
        processor: OwlViT processor used to tokenize ``prompt``.
        device (str or torch.device, optional): device the token ids are
            moved to. ``None`` (default) selects ``'cuda'`` when available,
            otherwise ``'cpu'``. The model should already live on this device.

    Returns:
        tuple: ``(input_ids, query_embeds)``. ``input_ids`` holds the token ids
        on ``device``. ``query_embeds`` is the projected text embedding divided
        by its L2 norm along the last dimension.
    """
    if device is None:
        # Auto-select only when the caller did not ask for a specific device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with torch.no_grad():
        input_ids = torch.as_tensor(
            processor(text=prompt)['input_ids']).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Element 1 of the text-model output is used as the pooled text
        # embedding (presumably `pooler_output` of the HF text model — confirm).
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # The epsilon avoids a division by zero for an all-zero embedding.
        text_embeds = text_embeds / (
            text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6)
    return input_ids, text_embeds
def prompt2vec(prompt: str, model, processor, device=None):
    """Convert a text prompt into a query vector.

    Args:
        prompt (str): text to be tokenized and embedded.
        model: OwlViT model passed to ``extract_text_feature``.
        processor: OwlViT processor passed to ``extract_text_feature``.
        device (optional): forwarded to ``extract_text_feature``; ``None``
            leaves the device choice to that function.

    Returns:
        tuple: ``(input_ids, xq)`` as host-side NumPy arrays. ``input_ids``
        holds the token ids of ``prompt``; ``xq`` is its normalised text
        embedding.
    """
    input_ids, xq = extract_text_feature(prompt, model, processor, device=device)
    # Move to host memory so callers get plain NumPy arrays.
    return input_ids.detach().cpu().numpy(), xq.detach().cpu().numpy()
def tune(clf, X, y, iters=2):
    """Train the zero-shot classifier on user feedback and return its weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``.
        X (numpy.ndarray): input vectors (the retrieved vectors), one per row.
        y (list of floats or numpy.ndarray): user feedback, one entry per row
            of ``X``.
        iters (int, optional): number of update iterations. Defaults to 2.

    Returns:
        The result of ``clf.get_weights()`` after fitting (the refined query
        vectors).

    Raises:
        ValueError: if ``X`` and ``y`` differ in length.
    """
    # An explicit check, unlike `assert`, still runs under `python -O`.
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}")
    clf.fit(X, y, iters=iters)
    return clf.get_weights()
class Classifier:
    """Multi-class zero-shot classifier refined by user feedback.

    The classifier acts as a proxy for the user's reaction to the probed
    images. Its weight matrix starts as the prompt-generated query vectors and
    is updated from feedback. The refined weights can then replace the original
    queries to give more relevant retrieval results, much as a recommendation
    system improves as user activity accumulates.

    With N query vectors there are N classes. Label ``i`` in ``[0, N)`` marks a
    sample as positive for query ``i``. Any other label (``N`` or above, or
    negative) marks the sample as negative for every query.
    """

    def __init__(self, xq):
        """Initialise the weights from the query vectors.

        Args:
            xq: query vectors of shape ``(N, D)`` (NumPy array, tensor or
                nested list); a single ``(D,)`` vector is treated as N = 1.
                The data is copied, so training never modifies ``xq``.
        """
        # `.clone()` guarantees a private copy: `torch.Tensor(ndarray)` /
        # `as_tensor` may share memory with a float32 array, and the weights
        # are later updated in place.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        if init_weight.dim() == 1:
            init_weight = init_weight.unsqueeze(0)
        self.num_class = init_weight.shape[0]
        dims = init_weight.shape[1]
        # No bias: only the inner product between queries and inputs matters.
        self.model = torch.nn.Linear(dims, self.num_class, bias=False)
        # Use the copied query vectors as the initial weights.
        self.model.weight = torch.nn.Parameter(init_weight)
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X, y, iters: int = 5):
        """Run ``iters`` SGD steps on labelled samples.

        Args:
            X: input vectors of shape ``(M, D)``. Rows are L2-normalised on a
                copy; the caller's array is left untouched.
            y: one label per row of ``X`` (converted with ``.long()``). Label
                ``i`` in ``[0, num_class)`` gives a one-hot target for class
                ``i``; any other value gives an all-zero (negative) target.
            iters (int, optional): number of gradient steps. Defaults to 5.
        """
        X = torch.as_tensor(X, dtype=torch.float32)
        # Out-of-place division so a memory-shared input is never modified;
        # the clamp keeps an all-zero row from turning into NaNs.
        X = X / torch.norm(X, p=2, dim=-1, keepdim=True).clamp_min(1e-12)
        labels = torch.Tensor(y).long()
        # Labels outside [0, num_class) are negatives for every class. They are
        # mapped to a valid index only so `one_hot` accepts them, then zeroed.
        out_of_range = (labels < 0) | (labels >= self.num_class)
        safe_labels = labels.masked_fill(out_of_range, 0)
        target = torch.nn.functional.one_hot(
            safe_labels, num_classes=self.num_class).float()
        target[out_of_range] = 0
        for _ in range(iters):
            self.optimizer.zero_grad()
            # Re-normalise the weights before each step. This constrains the
            # gradient and keeps the query vectors from exploding.
            self.model.weight.data /= torch.norm(
                self.model.weight.data, p=2, dim=-1, keepdim=True)
            logits = self.model(X)
            loss = self.loss(logits, target)
            loss.backward()
            self.optimizer.step()

    def get_weights(self):
        """Return a copy of the weight matrix as a ``(num_class, D)`` NumPy array."""
        # Copy so later `fit` calls do not silently change the returned array.
        return self.model.weight.detach().cpu().clone().numpy()
class SplitLayer(torch.nn.Module):
    """Module that cuts its input into width-1 slices along the last axis."""

    def forward(self, x):
        """Return a tuple of views of ``x``, each of size 1 in the last dimension."""
        pieces = x.split(1, dim=-1)
        return pieces