kanji_lookup / encode.py
from PIL import Image
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,  # image preprocessor (resizing and normalisation)
    ViTModel,           # used only as a type hint for the encoder
)

MODEL = "kha-white/manga-ocr-base"

print("Loading models")
# Only the image processor and the ViT encoder are needed;
# the OCR text decoder of manga-ocr is discarded.
feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL)
encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder

if torch.cuda.is_available():
    print("Using CUDA")
    encoder.cuda()
else:
    print("Using CPU")


def get_embeddings(images: list[Image.Image]) -> torch.Tensor:
    """Processes the images and returns their embeddings."""
    # ViT expects three-channel input, so convert greyscale/paletted images to RGB.
    images_rgb = [image.convert("RGB") for image in images]
    with torch.inference_mode():  # no gradients are needed for encoding
        pixel_values: torch.Tensor = feature_extractor(images_rgb, return_tensors="pt")["pixel_values"]
        # Use the pooled [CLS] representation as the image embedding and move it back to the CPU.
        return encoder(pixel_values.to(encoder.device))["pooler_output"].cpu()
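

# Usage sketch (an assumption, not part of the original script): compare two kanji
# images by the cosine similarity of their embeddings. The image paths below are
# hypothetical placeholders; substitute your own files.
if __name__ == "__main__":
    query = Image.open("query_kanji.png")          # hypothetical path
    candidate = Image.open("candidate_kanji.png")  # hypothetical path
    embeddings = get_embeddings([query, candidate])  # shape: (2, hidden_size)
    similarity = torch.nn.functional.cosine_similarity(embeddings[0:1], embeddings[1:2])
    print(f"Cosine similarity: {similarity.item():.4f}")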