Spaces:
Runtime error
Runtime error
import torch | |
def extract_text_feature(prompt, model, processor, device=None):
    """Encode a text prompt into an L2-normalized query embedding.

    Args:
        prompt: a single text query
        model: OwlViT model; must expose ``owlvit.text_model`` and
            ``owlvit.text_projection``
        processor: OwlViT processor; called as ``processor(text=prompt)`` and
            expected to return a mapping with an ``'input_ids'`` entry
        device (str or torch.device, optional): device the token ids are moved
            to before the forward pass. Defaults to None, which selects
            ``'cuda'`` when CUDA is available and ``'cpu'`` otherwise. It
            should match the device the model is on.

    Returns:
        tuple: ``(input_ids, query_embeds)`` -- the token-id tensor on the
        selected device and the projected, L2-normalized text embeddings.
    """
    # Only auto-detect when the caller expressed no preference; an explicit
    # device is honored instead of being silently overwritten.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with torch.no_grad():
        token_ids = processor(text=prompt)['input_ids']
        input_ids = torch.as_tensor(token_ids).to(device)
        text_outputs = model.owlvit.text_model(
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
        )
        # Index 1 of the text-model output is used as the per-prompt feature
        # (presumably the pooled output -- confirm against the OwlViT docs).
        text_embeds = model.owlvit.text_projection(text_outputs[1])
        # The epsilon keeps the division finite for an all-zero embedding.
        norms = text_embeds.norm(p=2, dim=-1, keepdim=True) + 1e-6
        query_embeds = text_embeds / norms
    return input_ids, query_embeds
def prompt2vec(prompt: str, model, processor):
    """Turn a text prompt into NumPy token ids and a NumPy query vector.

    Args:
        prompt (str): Text to be tokenized and embedded
        model: model handed through to `extract_text_feature`
        processor: processor handed through to `extract_text_feature`
    Returns:
        tuple: ``(input_ids, xq)`` -- the token ids and the query embedding
        produced for the prompt, both detached, moved to the CPU and
        converted to NumPy arrays
    """
    features = extract_text_feature(prompt, model, processor)
    # Both tensors get the same treatment: drop autograd, go to CPU, to NumPy.
    token_ids, query_vec = (t.detach().cpu().numpy() for t in features)
    return token_ids, query_vec
def tune(clf, X, y, iters=2):
    """Fit the zero-shot classifier on user feedback and return its weights.

    Args:
        clf: classifier exposing ``fit(X, y, iters=...)`` and ``get_weights()``
        X (numpy.ndarray): Input vectors (retrieved vectors)
        y (list of floats or numpy.ndarray): Scores given by user, one per
            row of X
        iters (int, optional): number of update iterations to run
    Returns:
        Whatever ``clf.get_weights()`` returns after fitting.
    """
    n_vectors, n_scores = len(X), len(y)
    assert n_vectors == n_scores

    # run the updates, then hand back the refreshed query weights
    clf.fit(X, y, iters=iters)
    updated_weights = clf.get_weights()
    return updated_weights
class Classifier:
    """Multi-class zero-shot classifier refined from user feedback.

    The classifier acts as a proxy for the user's reaction to the retrieved
    images. Its weight matrix starts out as the prompt-derived query vectors
    and is updated from user scores, so the learned rows can replace the
    original queries in later retrieval rounds -- much like a recommender
    that gets more precise as it accumulates user activity.

    For N queries, labels ``0 .. N-1`` mark a sample as a positive for the
    matching query. Any other label (for example ``N``) marks the sample as
    a negative for every query.
    """

    def __init__(self, xq):
        """Initialise the weights from the query vectors.

        Args:
            xq: array-like of shape ``(N, D)``, one query vector per row
        """
        # Clone: `as_tensor` may share memory with the caller's array, and
        # the weights are updated in place during `fit`.
        init_weight = torch.as_tensor(xq, dtype=torch.float32).detach().clone()
        self.num_class, dims = init_weight.shape
        # note that the bias is ignored, as we only focus on the inner product result
        self.model = torch.nn.Linear(dims, self.num_class, bias=False)
        # use the initial queries as the starting weights
        self.model.weight = torch.nn.Parameter(init_weight)
        # init loss and optimizer
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def fit(self, X, y, iters: int = 5):
        """Update the weights from scored samples.

        Args:
            X: array-like of sample vectors, one per row; not modified
            y: array-like of labels, one per row of X (see class docstring)
            iters (int, optional): number of gradient steps to take
        """
        X = torch.as_tensor(X, dtype=torch.float32)
        if X.numel() == 0:
            # Nothing to learn from; a mean loss over zero samples would be NaN.
            return
        # Out-of-place, so the caller's data is left untouched; zero rows stay
        # zero instead of turning into NaN.
        X = torch.nn.functional.normalize(X, p=2, dim=-1)
        y = torch.as_tensor(y).long()

        # One binary target per class. Labels outside 0..N-1 get an all-zero
        # row; the modulo only keeps `one_hot` from rejecting those labels.
        outbound = (y < 0) | (y >= self.num_class)
        targets = torch.nn.functional.one_hot(
            y % self.num_class, num_classes=self.num_class
        ).float()
        targets[outbound] = 0

        for _ in range(iters):
            self.optimizer.zero_grad()
            # Normalize the weight before inference
            # This will constrain the gradient or you will have an explosion on query vector
            with torch.no_grad():
                self.model.weight.copy_(
                    torch.nn.functional.normalize(self.model.weight, p=2, dim=-1)
                )
            out = self.model(X)
            loss = self.loss(out, targets)
            loss.backward()
            self.optimizer.step()

    def get_weights(self):
        """Return the current weight matrix as a NumPy array of shape (N, D).

        The result is a copy, so later calls to `fit` do not change it.
        """
        return self.model.weight.detach().cpu().numpy().copy()
class SplitLayer(torch.nn.Module):
    """Module that cuts its input into width-1 pieces along the last axis."""

    def forward(self, x):
        """Return the tuple of size-1 chunks of `x` taken along its last dimension."""
        chunk_width = 1
        pieces = torch.split(x, chunk_width, dim=-1)
        return pieces