dino-vitb16 / pipeline.py
agiera's picture
Update pipeline.py
3e3cd94
raw
history blame
1.65 kB
import base64
import io
import os
from typing import Dict, List, Any

import PIL
import PIL.Image
import torch
from transformers import ViTImageProcessor, ViTModel
# Select the compute device once at import time: the CUDA device when
# PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class PreTrainedPipeline:
    """Image feature-extraction pipeline built on a ViT backbone.

    Loads a ``ViTModel`` and its ``ViTImageProcessor`` from a local model
    directory; calling the instance embeds one base64-encoded image.
    """

    def __init__(self, path=""):
        """Load the model and the image processor from ``path``.

        Args:
            path: Model directory containing the weights, ``config.json``
                and ``preprocessor_config.json``.
        """
        self.model = ViTModel.from_pretrained(
            pretrained_model_name_or_path=path,
            config=os.path.join(path, 'config.json')
        )
        # Inference-only pipeline: switch to eval mode and move the weights
        # to the module-level ``device``.
        self.model.eval()
        self.model = self.model.to(device)
        self.processor = ViTImageProcessor.from_pretrained(
            pretrained_model_name_or_path=os.path.join(
                path, 'preprocessor_config.json')
        )

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Return the embedding of one base64-encoded image.

        Args:
            data: The request payload. Either a dict whose ``"inputs"``
                entry holds the inputs, or the inputs themselves. The inputs
                are either a dict with the base64-encoded image under
                ``"image"``, or the base64 string (``str`` or ``bytes``)
                itself. ``data`` is not modified.

        Returns:
            A dict like ``{"feature_vector": [0.6331..., ..., -0.7866...]}``:
            the last hidden state at batch item 0, sequence position 0
            (ViT's [CLS] position), as a list of floats.

        Raises:
            KeyError: If dict-shaped inputs have no ``"image"`` key.
            binascii.Error: If base64 decoding fails (e.g. bad padding).
        """
        # Unwrap {"inputs": ...} without mutating the caller's payload; a
        # bare base64 string/bytes payload is used directly.
        if isinstance(data, (str, bytes)):
            inputs = data
        else:
            inputs = data.get("inputs", data)
        if isinstance(inputs, (str, bytes)):
            encoded = inputs
        else:
            encoded = inputs['image']

        # Decode base64 -> PIL image and force RGB: grayscale images (2-D
        # arrays) and RGBA/palette images are otherwise rejected by the
        # processor's channel inference. The model presumably expects
        # 3-channel input (num_channels in config.json) -- confirm.
        # The context manager closes the lazily-opened source image;
        # ``convert`` returns an independent, fully loaded copy.
        with PIL.Image.open(io.BytesIO(base64.b64decode(encoded))) as raw:
            image = raw.convert("RGB")

        model_inputs = self.processor(images=image, return_tensors="pt")
        # The processor produces CPU tensors; move them next to the model so
        # the forward pass also works when ``device`` is CUDA.
        model_inputs = {
            name: tensor.to(device) for name, tensor in model_inputs.items()
        }
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        feature_vector = outputs.last_hidden_state[0, 0].tolist()
        return {"feature_vector": feature_vector}