sign-search / pipeline.py
grantpitt's picture
add pipeline
8e78d2d
raw
history blame
1.05 kB
from typing import Dict, List, Any

import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPModel
class PreTrainedPipeline:
    """Text-feature pipeline built on a CLIP model.

    The model and tokenizer are loaded once in ``__init__``. Each call turns a
    single string into the vector returned by ``CLIPModel.get_text_features``.
    """

    def __init__(self, path="", *, model_name="openai/clip-vit-base-patch32"):
        """Load the CLIP model and tokenizer; all heavy I/O happens here, once.

        Args:
            path (:obj:`str`):
                Unused. Kept so existing callers that pass a path keep working;
                the weights are loaded by ``model_name``, not from ``path``.
            model_name (:obj:`str`, optional, keyword-only):
                Model id or local directory handed to ``from_pretrained`` for
                both the model and the tokenizer. Defaults to
                ``"openai/clip-vit-base-patch32"``.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)

    def __call__(self, inputs: str) -> List[float]:
        """Compute the CLIP text features for one string.

        Args:
            inputs (:obj:`str`):
                The string to get the features from. Text longer than the
                tokenizer's maximum length is truncated rather than raising.
        Return:
            A :obj:`list` of floats: the features computed by the model.
        """
        # One-element batch. Truncation keeps long queries within the text
        # encoder's fixed maximum sequence length instead of failing at inference.
        token_inputs = self.tokenizer(
            [inputs], padding=True, truncation=True, return_tensors="pt"
        )
        # Pure inference: don't build an autograd graph we would only discard.
        with torch.no_grad():
            query_embed = self.model.get_text_features(**token_inputs)
        # Row 0 of the single-item batch, moved to CPU and converted to Python floats.
        return query_embed[0].cpu().tolist()