# motiv-receipt-donut / handler.py
# Author: samson-s
# Commit da85024: "Add gpu specification"
from transformers import pipeline, AutoFeatureExtractor
from PIL import Image
import torch
class EndpointHandler:
    """Custom inference handler wrapping a ``transformers`` image-to-text pipeline.

    The pipeline is built once, from ``path``, when the handler is
    constructed; every call then runs it on the request's ``"inputs"`` value.
    """

    def __init__(self, path=""):
        """Build the ``image-to-text`` pipeline for the model at ``path``.

        Args:
            path: Model directory or hub identifier, passed to
                ``transformers.pipeline`` as ``model``.
        """
        # Do not pass ``feature_extractor=AutoFeatureExtractor``: that is the
        # class itself, not a loaded preprocessor instance. Leaving the
        # argument out lets the pipeline load the preprocessor that belongs
        # to the checkpoint at ``path``.
        self.pipe = pipeline(
            "image-to-text",
            model=path,
            # Use the GPU when one is visible to torch, otherwise the CPU.
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )

    def __call__(self, data) -> list:
        """Run the pipeline on a single request payload.

        Args:
            data: Request payload mapping. ``data["inputs"]`` is passed to the
                pipeline as its input. An optional ``data["parameters"]``
                mapping is forwarded to the pipeline as keyword arguments.
                ``data`` itself is not modified.

        Returns:
            The pipeline's output unchanged. For image-to-text this is
            typically a list of ``{"generated_text": ...}`` dicts.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        inputs = data["inputs"]
        # ``None`` or an absent key both mean "no extra generation options".
        parameters = data.get("parameters") or {}
        return self.pipe(inputs, **parameters)