TusharGoel's picture
Update README.md
17a9f25
|
raw
history blame
2.07 kB
metadata
license: mit
language:
  - en
library_name: transformers
pipeline_tag: document-question-answering

Fine-tuned on 40,000 questions from the DocVQA dataset.

import json
import logging
from glob import glob

import numpy as np
import torch
from transformers import AutoModelForDocumentQuestionAnswering, AutoProcessor

# Checkpoint identifier shared by the processor and the model; both are loaded
# once at import time and reused by every `pipeline` call below.
model_name = "TusharGoel/LayoutLMv2-finetuned-docvqa"
processor, model = (
    AutoProcessor.from_pretrained(model_name),
    AutoModelForDocumentQuestionAnswering.from_pretrained(model_name),
)


logger = logging.getLogger(__name__)


def _softmax(logits, mask):
    """Return the softmax of 1-D ``logits`` with ``mask``-ed positions suppressed.

    Positions where ``mask`` is True are replaced by -10000 before
    normalising, so they receive (numerically) zero probability. The maximum
    is subtracted first so ``np.exp`` cannot overflow on large logits.
    """
    masked = np.where(mask, -10000.0, logits)
    exp = np.exp(masked - masked.max())
    return exp / exp.sum()


def _best_span(start_logits, end_logits, mask, max_answer_len):
    """Return ``(start, end, score)`` token indices of the most probable span.

    ``score`` is ``P(start) * P(end)``. Only spans with
    ``start <= end <= start + max_answer_len - 1`` are eligible.
    """
    start_probs = _softmax(start_logits, mask)
    end_probs = _softmax(end_logits, mask)
    # outer[i, j] = P(start=i) * P(end=j); triu drops spans with j < i and
    # tril drops spans longer than max_answer_len tokens.
    candidates = np.tril(np.triu(np.outer(start_probs, end_probs)), max_answer_len - 1)
    start, end = np.unravel_index(np.argmax(candidates), candidates.shape)
    return int(start), int(end), float(candidates[start, end])


def pipeline(question, words, boxes, **kwargs):
    """Answer ``question`` about a document page using the module-level model.

    Args:
        question: The natural-language question.
        words: Words of the document page, joined with spaces to form the answer.
        boxes: Bounding boxes for ``words`` in the format ``processor`` expects
            (presumably one box per word — confirm against the processor docs).
        **kwargs: Must contain ``images``, passed to ``processor`` as the page
            image(s). May contain ``max_answer_len`` (default 20), the maximum
            answer length in tokens.

    Returns:
        ``(answer, score)``: the answer text and the probability of the chosen
        span. ``("", 0.0)`` is returned when no answer can be produced
        (no words, no document tokens survive truncation, or inference fails).

    Raises:
        KeyError: if ``images`` is not supplied.
    """
    images = kwargs["images"]
    max_answer_len = kwargs.get("max_answer_len", 20)
    if not words:
        return "", 0.0

    try:
        encoding = processor(
            images,
            question,
            words,
            boxes=boxes,
            return_token_type_ids=True,
            return_tensors="pt",
            truncation=True,
        )
        word_ids = encoding.word_ids(0)
        sequence_ids = encoding.sequence_ids(0)

        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = model(**encoding)

        start_logits = outputs.start_logits[0].cpu().numpy()
        end_logits = outputs.end_logits[0].cpu().numpy()

        # Only tokens of the document (second sequence) may be part of the
        # answer. word_ids() numbers question words and document words
        # independently, so an unmasked question token would map to an
        # unrelated document word; special tokens map to None.
        mask = np.array([sid != 1 for sid in sequence_ids], dtype=bool)
        mask |= encoding["attention_mask"][0].cpu().numpy() == 0
        if mask.all():
            return "", 0.0

        # Pick the span once, and derive BOTH the answer text and its score
        # from it so the two always agree.
        start, end, score = _best_span(start_logits, end_logits, mask, max_answer_len)
        start_word, end_word = word_ids[start], word_ids[end]
        if start_word is None or end_word is None:
            return "", 0.0
        answer = " ".join(words[start_word : end_word + 1])
    except Exception:
        # Best-effort contract: callers get ("", 0.0) on failure, but the
        # cause is logged instead of being silently discarded.
        logger.exception("Document question answering failed")
        return "", 0.0
    return answer, score