clip-spanish / test_on_image.py
edugp's picture
Format with black and isort and lint with flake8
4463ade
raw
history blame contribute delete
No virus
1.47 kB
import os
import jax
import torch
from torchvision.io import ImageReadMode, read_image
from transformers import AutoTokenizer
from modeling_hybrid_clip import FlaxHybridCLIP
from run_hybrid_clip import Transform
def prepare_image(image_path, model):
    """Load an image file and turn it into a one-image pixel batch.

    Args:
        image_path: Path to an image file. It is decoded by torchvision's
            ``read_image`` as 3-channel RGB.
        model: Model whose ``config.vision_config.image_size`` is passed to
            ``Transform`` from ``run_hybrid_clip``.

    Returns:
        A numpy array holding a batch of a single preprocessed image, with
        axes permuted by ``(0, 2, 3, 1)``. Assuming ``Transform`` returns a
        (C, H, W) tensor, this is NCHW -> NHWC, presumably the layout the
        Flax vision tower expects (confirm in modeling_hybrid_clip).
    """
    rgb_tensor = read_image(image_path, mode=ImageReadMode.RGB)
    # NOTE(review): the transform is rebuilt and TorchScript-compiled on every
    # call. That is fine for a one-off script, but worth caching if called
    # in a loop.
    scripted_transform = torch.jit.script(
        Transform(model.config.vision_config.image_size)
    )
    # Add a leading batch axis of size 1.
    batch = scripted_transform(rgb_tensor).unsqueeze(0)
    # Move the channel axis (axis 1) to the end before leaving torch.
    return batch.permute(0, 2, 3, 1).numpy()
def prepare_text(text, tokenizer):
    """Tokenize one caption, or a list of captions, into numpy model inputs.

    Args:
        text: A single string, or a list of strings to batch together.
        tokenizer: A Hugging Face tokenizer (any callable accepting the
            ``padding`` and ``return_tensors`` keyword arguments).

    Returns:
        The tokenizer's encoding, with arrays as numpy. Callers read
        ``input_ids`` and ``attention_mask`` from it.
    """
    # padding=True pads every sequence to the longest one in the batch.
    # For a single string this adds nothing, so the output is unchanged.
    # For a list of captions of different lengths it is required: without
    # it the ragged token lists cannot be stacked into one numpy array and
    # the tokenizer raises.
    return tokenizer(text, padding=True, return_tensors="np")
def run_inference(image_path, text, model, tokenizer):
    """Score how well ``text`` describes the image stored at ``image_path``.

    Args:
        image_path: Path of the image to score (see ``prepare_image``).
        text: Caption to compare against the image (see ``prepare_text``).
        model: A ``FlaxHybridCLIP`` model. It is called in inference mode
            (``train=False``).
        tokenizer: Tokenizer matching the model's text encoder.

    Returns:
        ``sigmoid`` applied to the model's ``logits_per_image``, taking the
        entry at index [0, 0] (first image vs. first text).
    """
    pixels = prepare_image(image_path, model)
    encoded = prepare_text(text, tokenizer)
    outputs = model(
        encoded["input_ids"],
        pixels,
        attention_mask=encoded["attention_mask"],
        train=False,
        return_dict=True,
    )
    # NOTE(review): logits_per_image is presumably the (num_images, num_texts)
    # similarity matrix multiplied by the model's logit scale. Confirm in
    # modeling_hybrid_clip. Sigmoid is only a monotone squashing of that
    # value, not a calibrated match probability, and it may saturate near
    # 1.0 when the logit scale is large.
    probabilities = jax.nn.sigmoid(outputs["logits_per_image"])
    return probabilities[0, 0]
if __name__ == "__main__":
    # Script entry point: score one image/caption pair and print the result.
    # Every option defaults to the originally hard-coded value, so running
    # with no arguments behaves as before.
    import argparse
    import getpass

    parser = argparse.ArgumentParser(
        description="Score an image/caption pair with a FlaxHybridCLIP checkpoint."
    )
    parser.add_argument(
        "--model-dir",
        default="./",
        help="Directory passed to FlaxHybridCLIP.from_pretrained (default: ./).",
    )
    parser.add_argument(
        "--tokenizer",
        default="bertin-project/bertin-roberta-base-spanish",
        help="Tokenizer name or path for AutoTokenizer.from_pretrained.",
    )
    parser.add_argument(
        "--image-path",
        default=None,
        help="Image to score (default: "
        "/home/$USER/data/wit_scale_converted/Santuar.jpg).",
    )
    parser.add_argument(
        "--text",
        default="Fachada del Santuario",
        help="Caption to score against the image.",
    )
    args = parser.parse_args()

    image_path = args.image_path
    if image_path is None:
        # os.environ["USER"] raises KeyError when $USER is not exported
        # (e.g. cron jobs, some containers). Keep $USER when it is set, so
        # the default path is unchanged, and otherwise fall back to
        # getpass.getuser().
        user = os.environ.get("USER") or getpass.getuser()
        image_path = f"/home/{user}/data/wit_scale_converted/Santuar.jpg"

    # Fail fast on a bad path, before spending time loading the model.
    if not os.path.isfile(image_path):
        parser.error(f"image not found: {image_path}")

    model = FlaxHybridCLIP.from_pretrained(args.model_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    print(run_inference(image_path, args.text, model, tokenizer))