samarthagarwal23's picture
Update app.py
8de99fc
import gradio as gr
import os
import numpy as np
os.system("pip install pdfminer.six rank_bm25 torch transformers")
from gradio.mix import Series
#import re
from rank_bm25 import BM25Okapi
import string
import torch
from transformers import pipeline
import pdfminer
from pdfminer.high_level import extract_text
# PDF chunking: length of each text chunk and overlap between consecutive
# chunks, both measured in characters of the extracted text.
len_doc, overlap = 500, 15
# Retrieval depth: number of BM25 hits handed to the QA model, and number of
# QA answers kept after ranking.
param_top_k_retriver, param_top_k_ranker = 15, 3
def read_pdf(file, chunk_size=None, chunk_overlap=None):
    """Extract the text of an uploaded PDF and split it into overlapping chunks.

    Args:
        file: Uploaded file object; its ``name`` attribute is passed to
            pdfminer's ``extract_text``.
        chunk_size: Characters per chunk. Defaults to the module-level
            ``len_doc``.
        chunk_overlap: Characters shared by consecutive chunks. Defaults to the
            module-level ``overlap``.

    Returns:
        List of text chunks in document order (empty if no text was extracted).
        The last chunk ends at the end of the text; no chunk lies entirely
        inside its predecessor.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``chunk_overlap`` is
            not smaller than ``chunk_size``. In either case the window could
            not advance and the chunking loop would never terminate.
    """
    size = len_doc if chunk_size is None else chunk_size
    shared = overlap if chunk_overlap is None else chunk_overlap
    stride = size - shared
    if size <= 0 or stride <= 0:
        raise ValueError(
            f"chunk_size ({size}) must be positive and larger than "
            f"chunk_overlap ({shared})"
        )

    text = extract_text(file.name)
    docs = []
    start = 0
    while start < len(text):
        end = start + size
        docs.append(text[start:end])
        if end >= len(text):
            # This chunk already reaches the end of the text. Another window
            # would only be a sub-span of it and add a duplicate candidate.
            break
        start += stride
    return docs
# BM25 serves as the retriever: it does the first round of candidate filtering using word-based matching.
def bm25_tokenizer(text):
    """Tokenize ``text`` for BM25 scoring.

    The text is lower-cased and split on whitespace. Punctuation is stripped
    from both ends of each token (inner punctuation such as apostrophes is
    kept). Tokens that end up empty, or that are in a small stop-word list,
    are dropped.
    """
    stop_words = {'a', 'the', 'am', 'is', 'are', 'who', 'how', 'where', 'when', 'why', 'what'}
    cleaned = (word.strip(string.punctuation) for word in text.lower().split())
    return [word for word in cleaned if word and word not in stop_words]
def retrieval(query, top_k_retriver, docs, bm25_):
    """Shortlist the chunks that best match ``query`` according to BM25.

    Args:
        query: Free-text question, tokenized with ``bm25_tokenizer``.
        top_k_retriver: Maximum number of candidates to keep.
        docs: Text chunks, index-aligned with the corpus ``bm25_`` was built on.
        bm25_: Index object exposing ``get_scores(tokens)``.

    Returns:
        List of dicts with keys ``corpus_id``, ``score`` and ``docs`` (the chunk
        text), ordered by descending score. Only chunks with a strictly positive
        score are kept, so fewer than ``top_k_retriver`` hits may be returned.
    """
    scores = bm25_.get_scores(bm25_tokenizer(query))
    # argsort is ascending; reversing it visits the highest scores first, so the
    # hits below are already in descending-score order and need no re-sort.
    best_first = np.argsort(scores)[::-1][:top_k_retriver]
    hits = []
    for corpus_id in best_first:
        score = scores[corpus_id]
        if score > 0:  # written as "> 0" so NaN scores are excluded too
            hits.append({'corpus_id': corpus_id,
                         'score': score,
                         'docs': docs[corpus_id]})
    return hits
def qa_ranker(query, docs_, top_k_ranker, qa_model):
    """Run the QA model over every candidate chunk and keep the most confident answers.

    Args:
        query: Question, passed to ``qa_model`` as ``question``.
        docs_: Candidate text chunks, each passed to ``qa_model`` as ``context``.
        top_k_ranker: Number of answers to keep.
        qa_model: Callable ``qa_model(question=..., context=...)`` returning a
            dict with at least a ``score`` key.

    Returns:
        Up to ``top_k_ranker`` answer dicts, highest ``score`` first (ties keep
        input order). Each dict gains a ``doc`` key holding the chunk the answer
        came from.
    """
    answers = [{**qa_model(question=query, context=chunk), 'doc': chunk}
               for chunk in docs_]
    answers.sort(key=lambda answer: answer['score'], reverse=True)
    return answers[:top_k_ranker]
def cstr(s, color='black'):
    """Wrap ``s`` in a ``<text>`` element with an inline text-color style."""
    return f"<text style=color:{color}>{s}</text>"
def cstr_bold(s, color='black'):
    """Like ``cstr`` but wraps ``s`` in ``<b>`` so it renders bold."""
    return f"<text style=color:{color}><b>{s}</b></text>"
def cstr_break(s, color='black'):
    """Like ``cstr`` but starts ``s`` on a new line via a leading ``<br>``."""
    return f"<text style=color:{color}><br>{s}</text>"
def print_colored(text, start_idx, end_idx, confidence):
    """Render a passage as HTML with the predicted answer span highlighted.

    Args:
        text: Passage the answer was extracted from.
        start_idx: Start offset of the answer span in ``text``.
        end_idx: End offset (exclusive) of the answer span in ``text``.
        confidence: Confidence label shown on its own line (e.g. ``"87.5%"``).
            Non-string values are converted with ``str``.

    Returns:
        HTML string: the passage in black, the answer span in bold blue, then
        ``- Confidence: <confidence>`` in grey on a new line.
    """
    from html import escape  # local import keeps this block self-contained

    # The passage comes from a user-uploaded PDF. Escape it before embedding it
    # in HTML, otherwise markup contained in the document is injected into the page.
    before = escape(text[:start_idx])
    answer = escape(text[start_idx:end_idx])
    after = escape(text[end_idx:])
    conf_str = escape('- Confidence: ' + str(confidence))
    # Join without a separator so no spaces are inserted around the answer
    # (which would render "Tim Cook ." or split a word). The confidence part
    # starts with its own <br>.
    return cstr(''.join([before,
                         cstr_bold(answer, color='blue'),
                         after,
                         cstr_break(conf_str, color='grey')]), color='black')
# Question-answering pipelines already built, keyed by model name.
_qa_models = {}


def _get_qa_model(model_nm):
    """Return the ``question-answering`` pipeline for ``deepset/<model_nm>``.

    Building a transformers pipeline instantiates the model, which is expensive.
    Each model is therefore built once and reused by later requests.
    NOTE(review): the cache is unbounded. The UI dropdown offers two models,
    but direct API callers could request others; cap it if that matters.
    """
    model_nm = str(model_nm)
    if model_nm not in _qa_models:
        _qa_models[model_nm] = pipeline("question-answering",
                                        model="deepset/" + model_nm)
    return _qa_models[model_nm]


def _format_answer(answer):
    """Render one ``qa_ranker`` result as HTML, with its score shown as a percentage."""
    confidence = str(np.round(100 * answer["score"], 1)) + "%"
    return print_colored(answer['doc'], answer['start'], answer['end'], confidence)


def final_qa_pipeline(file, query, model_nm):
    """Answer ``query`` from the uploaded PDF ``file`` (the Gradio callback).

    Steps:
        1. Split the PDF into overlapping chunks.
        2. Shortlist chunks with BM25.
        3. Run the extractive QA model ``deepset/<model_nm>`` over the shortlist.
        4. Render the two most confident answers.

    Returns:
        ``(top1_html, top2_html)``. ``top2_html`` is ``"None"`` when only one
        answer is available. Both are ``"No match"`` when the PDF has no
        searchable text or no chunk is retrieved.
    """
    no_match = ("No match", "No match")
    docs = read_pdf(file)
    tokenized_corpus = [bm25_tokenizer(doc) for doc in docs]
    if not any(tokenized_corpus):
        # No usable text: nothing extracted, or only stop words/punctuation.
        # BM25Okapi divides by the corpus size when it is built, so an empty
        # corpus would raise; an all-empty one cannot match anything anyway.
        return no_match
    bm25 = BM25Okapi(tokenized_corpus)
    top_k_retriver, top_k_ranker = param_top_k_retriver, param_top_k_ranker
    lvl1 = retrieval(query, top_k_retriver, docs, bm25)
    if not lvl1:
        # Skip loading the QA model when there is nothing to read.
        return no_match
    # Fall back to the first dropdown choice when no model was selected;
    # "deepset/None" is not a loadable model.
    qa_model = _get_qa_model(model_nm or "minilm-uncased-squad2")
    fnl_rank = qa_ranker(query, [hit["docs"] for hit in lvl1], top_k_ranker, qa_model)
    if not fnl_rank:
        return no_match
    top1 = _format_answer(fnl_rank[0])
    # Guard on the list that is actually indexed, not on the retrieval hits.
    top2 = _format_answer(fnl_rank[1]) if len(fnl_rank) > 1 else "None"
    return (top1, top2)
# Example rows for the Gradio UI: [absolute pdf path, question, model name].
# Every example uses the small MiniLM model.
examples = [
    [os.path.abspath(pdf_name), question, "minilm-uncased-squad2"]
    for pdf_name, question in (
        ("dbs-annual-report-2020.pdf", "how many times has DBS won Best bank in the world ?"),
        ("dbs-annual-report-2020.pdf", "how much dividend was paid to shareholders ?"),
        ("dbs-annual-report-2020.pdf", "what is the sustainability focus ?"),
        ("NASDAQ_AAPL_2020.pdf", "how much are the outstanding shares ?"),
        ("NASDAQ_AAPL_2020.pdf", "what is competitors strategy ?"),
        ("NASDAQ_AAPL_2020.pdf", "who is the chief executive officer ?"),
        ("NASDAQ_MSFT_2020.pdf", "How much is the guided revenue for next quarter?"),
    )
]
# Gradio front end. Inputs: PDF upload, question box, model picker.
# Outputs: the two best answers, rendered as HTML.
qa_inputs = [
    gr.inputs.File(label="input pdf file"),
    gr.inputs.Textbox(label="Question:"),
    gr.inputs.Dropdown(choices=["minilm-uncased-squad2", "roberta-base-squad2"], label="Model"),
]
qa_outputs = [
    gr.outputs.HTML(label="Top 1 answer"),
    gr.outputs.HTML(label="Top 2 answer"),
]
iface = gr.Interface(
    fn=final_qa_pipeline,
    inputs=qa_inputs,
    outputs=qa_outputs,
    examples=examples,
    theme="grass",
    title="Question Answering on annual reports",
    description=(
        "Navigate long annual reports by using Machine learning to answer your questions. \n"
        "Simply upload any annual report pdf you are interested in and ask model a question OR load an example from below."
    ),
)
# Start the web app with Gradio's request queue enabled.
iface.launch(enable_queue=True)