|
from typing import Dict, Any |
|
from transformers import pipeline |
|
import holidays |
|
import PIL.Image |
|
import io |
|
import pytesseract |
|
|
|
class PreTrainedPipeline:
    """Document question answering with an optional US-holiday short-circuit.

    Wraps a Hugging Face ``document-question-answering`` pipeline. If the
    request carries a date that falls on a US holiday, a fixed message is
    returned instead of running the model.
    """

    def __init__(self, model_path="PrimWong/layout_qa_hparam_tuning"):
        """Load the QA pipeline and the US holiday calendar.

        Args:
            model_path: Hugging Face model id or local path handed to
                ``transformers.pipeline``.
        """
        self.pipeline = pipeline("document-question-answering", model=model_path)
        self.holidays = holidays.US()

    @staticmethod
    def _load_image(image):
        """Return an image object the pipeline can consume.

        Paths (``str``), raw bytes and binary file objects are decoded with
        PIL inside a ``with`` block so any file PIL opened is released; any
        other value (e.g. an already-decoded ``PIL.Image.Image``) is returned
        unchanged for the pipeline to handle.
        """
        if isinstance(image, (bytes, bytearray)):
            image = io.BytesIO(image)
        if isinstance(image, (str, io.IOBase)):
            with PIL.Image.open(image) as opened:
                # copy() forces the pixel data to load, so the returned image
                # stays usable after ``with`` closes the source image.
                return opened.copy()
        return image

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a question about a document image.

        Args:
            data: Request payload. ``data["inputs"]`` must hold ``"image"``
                (a file path, raw bytes, a binary file object, or an image
                object the pipeline accepts) and ``"question"`` (str).
                ``data["date"]`` is optional and is tested with
                ``date in holidays.US()`` (e.g. ``"2024-07-04"`` or a
                ``datetime.date``).

        Returns:
            ``"Today is a holiday!"`` when ``date`` is a US holiday; otherwise
            the top-ranked answer string, or ``""`` if the pipeline returned
            no candidates.

        Raises:
            ValueError: If ``image`` or ``question`` is missing.
        """
        inputs = data.get("inputs", {})
        date = data.get("date")

        # The holiday short-circuit runs before input validation, so a holiday
        # request does not need an image or question.
        if date and date in self.holidays:
            return "Today is a holiday!"

        image = inputs.get("image")
        question = inputs.get("question")
        if image is None or not question:
            raise ValueError("'inputs' must contain both 'image' and 'question'")

        # The document-question-answering pipeline takes the image itself
        # (per the transformers docs it runs Tesseract OCR internally for
        # layout models); it has no ``context`` argument for pre-OCR'd text.
        prediction = self.pipeline(image=self._load_image(image), question=question)

        # The pipeline returns a score-ranked list of answer dicts; a bare
        # dict is also accepted defensively.
        if isinstance(prediction, list):
            if not prediction:
                return ""
            prediction = prediction[0]
        return prediction["answer"]
|
|
|
|
|
|