from PIL import Image
import torch
from transformers import (
    VisionEncoderDecoderModel,
    ViTImageProcessor,  # image preprocessing (resize + normalize)
    ViTModel,  # type hint for the extracted encoder
)

MODEL = "kha-white/manga-ocr-base"

print("Loading models")
feature_extractor: ViTImageProcessor = ViTImageProcessor.from_pretrained(MODEL, requires_grad=False)
encoder: ViTModel = VisionEncoderDecoderModel.from_pretrained(MODEL).encoder

if torch.cuda.is_available():
    print('Using CUDA')
    encoder.cuda()
else:
    print('Using CPU')

def get_embeddings(images: list[Image.Image]) -> torch.Tensor:
    """Preprocess the images and return their embeddings."""
    # The processor expects RGB input; manga pages are often grayscale.
    images_rgb = [image.convert("RGB") for image in images]
    with torch.inference_mode():
        pixel_values: torch.Tensor = feature_extractor(images_rgb, return_tensors="pt")["pixel_values"]
        # pooler_output is the pooled [CLS] representation, shape (batch_size, hidden_size).
        return encoder(pixel_values.to(encoder.device))["pooler_output"].cpu()
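

# Usage sketch (not part of the original script; "page.png" is a hypothetical
# placeholder path): embed a batch of images and inspect the result.
if __name__ == "__main__":
    batch = [Image.open("page.png")]
    embeddings = get_embeddings(batch)
    print(embeddings.shape)  # (batch_size, hidden_size), e.g. (1, 768) for a ViT-base encoder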