root
add custom pipeline
09c3d85
raw
history blame contribute delete
No virus
1.8 kB
from typing import Dict, Any
from transformers import pipeline
import holidays
import PIL.Image
import io
import pytesseract
class PreTrainedPipeline():
    """Document question-answering handler with an optional holiday short-circuit.

    Wraps a ``transformers`` "document-question-answering" pipeline. When the
    request carries a ``date`` that falls on a US holiday, a fixed holiday
    message is returned instead of running the model.
    """

    def __init__(self, model_path="PrimWong/layout_qa_hparam_tuning"):
        """Load the document-QA pipeline and the US holiday calendar.

        Args:
            model_path: Model identifier or local path passed to
                ``transformers.pipeline`` as ``model``.
        """
        self.pipeline = pipeline("document-question-answering", model=model_path)
        self.holidays = holidays.US()

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a question about a document image.

        Args:
            data: Request payload. ``data["inputs"]`` must hold ``"image"``
                (anything ``PIL.Image.open`` accepts, e.g. a file path) and
                ``"question"``. ``data["date"]`` is optional.

        Returns:
            ``"Today is a holiday!"`` when ``date`` is in the holiday
            calendar; otherwise the top answer produced by the pipeline, or
            an empty string if the pipeline returns no answers.

        Raises:
            ValueError: If ``image`` or ``question`` is missing from
                ``data["inputs"]``.
        """
        # The holiday check comes first so holiday requests never touch the
        # image or the model.
        date = data.get("date")
        if date and date in self.holidays:
            return "Today is a holiday!"

        # ``or {}`` also covers an explicit ``"inputs": None``.
        inputs = data.get("inputs") or {}
        image_path = inputs.get("image")
        question = inputs.get("question")

        # Fail early with a clear message instead of an obscure error from
        # PIL or the pipeline.
        if image_path is None:
            raise ValueError("'inputs' must contain an 'image' entry")
        if not question:
            raise ValueError("'inputs' must contain a non-empty 'question' entry")

        # The document-QA pipeline takes the image itself (not a ``context``
        # string) and performs OCR internally for LayoutLM-style models, so no
        # separate pytesseract pass is needed here. The context manager closes
        # the underlying file once inference is done.
        with PIL.Image.open(image_path) as image:
            prediction = self.pipeline(image=image, question=question)

        # The pipeline is documented to return a list of answer dicts, best
        # first; a bare dict is tolerated for safety.
        if isinstance(prediction, list):
            if not prediction:
                return ""
            prediction = prediction[0]
        return prediction["answer"]
# Note: OCR here depends on pytesseract, which needs the Tesseract binary installed and reachable on PATH (or pytesseract.pytesseract.tesseract_cmd set to its location).