Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os | |
import numpy as np | |
os.system("pip install pdfminer.six rank_bm25 torch transformers termcolor") | |
from gradio.mix import Series | |
import re | |
from rank_bm25 import BM25Okapi | |
import string | |
import torch | |
from transformers import pipeline | |
import pdfminer | |
from pdfminer.high_level import extract_text | |
#from termcolor import colored | |
def read_pdf(file, len_doc=400, overlap=50):
    """Extract the text of a PDF and split it into overlapping character chunks.

    Args:
        file: Either a path to the PDF (``str`` / ``os.PathLike``) or a file-like
            object exposing the path as ``.name`` (e.g. the temporary file some
            Gradio versions pass for an upload).
        len_doc: Length of each chunk, in characters.
        overlap: Number of characters shared by consecutive chunks, which reduces
            the chance that an answer is split across a chunk boundary.

    Returns:
        A list of text chunks, in document order. Empty if no text was extracted.

    Raises:
        ValueError: If ``len_doc`` is not positive or ``overlap`` is not in
            ``[0, len_doc)`` (the sliding window would never advance).
    """
    if len_doc <= 0 or not 0 <= overlap < len_doc:
        raise ValueError("need len_doc > 0 and 0 <= overlap < len_doc")
    # Accept plain paths as well as file objects; note that pathlib.Path also has
    # a ``.name`` attribute (just the basename), so check for paths first.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    text = extract_text(path)

    step = len_doc - overlap
    docs = []
    for start in range(0, len(text), step):
        docs.append(text[start:start + len_doc])
        # Once a chunk reaches the end of the text, any further window would only
        # repeat a tail that is already contained in this chunk.
        if start + len_doc >= len(text):
            break
    return docs
# BM25 is used as the retriever: it does a first round of candidate filtering based on word matching.
def bm25_tokenizer(text):
    """Turn ``text`` into lowercase BM25 tokens.

    The text is lowercased and split on whitespace, leading/trailing punctuation
    is stripped from each word, and empty words and a small set of stop words
    (articles, forms of "to be" and question words) are dropped.
    """
    stopwords = {'a', 'the', 'am', 'is', 'are', 'who', 'how', 'where', 'when', 'why', 'what'}
    words = (raw.strip(string.punctuation) for raw in text.lower().split())
    return [word for word in words if word and word not in stopwords]
def retrieval(query, top_k_retriver, docs, bm25_):
    """Return the best BM25 matches for ``query``, highest score first.

    Args:
        query: Free-text question; tokenized with ``bm25_tokenizer``.
        top_k_retriver: Maximum number of hits to return; values <= 0 give none.
        docs: The text chunks the BM25 index was built from, in the same order.
        bm25_: A fitted BM25 index whose ``get_scores(tokens)`` returns one score
            per entry of ``docs``.

    Returns:
        A list of dicts with keys ``corpus_id`` (index into ``docs``), ``score``
        and ``docs`` (the chunk text), ordered by descending score. Chunks whose
        score is not strictly positive are dropped.
    """
    bm25_scores = bm25_.get_scores(bm25_tokenizer(query))
    # argsort is ascending; reversing it yields best-first order. Filtering keeps
    # that order, so no second sort is needed. Clamp k so a negative value does
    # not turn the slice into "everything but the last k".
    top_n = np.argsort(bm25_scores)[::-1][:max(top_k_retriver, 0)]
    return [
        {'corpus_id': idx, 'score': bm25_scores[idx], 'docs': docs[idx]}
        for idx in top_n
        if bm25_scores[idx] > 0
    ]
# Extractive question-answering pipeline used to re-rank the BM25 candidates.
# The model id suggests RoBERTa-base fine-tuned on SQuAD 2.0 (confirm on the
# model card). It is loaded once at import time and shared by all requests.
qa_model = pipeline(
    task="question-answering",
    model="deepset/roberta-base-squad2",
)
def qa_ranker(query, docs_, top_k_ranker):
    """Run the QA model on every candidate chunk and keep the best answers.

    Each prediction returned by ``qa_model`` (with at least ``score``, ``start``,
    ``end`` and ``answer``) is annotated with the chunk it came from under the
    key ``'doc'``.

    Returns:
        The ``top_k_ranker`` predictions with the highest ``score``, best first.
    """
    predictions = []
    for context in docs_:
        prediction = qa_model(question=query, context=context)
        prediction['doc'] = context
        predictions.append(prediction)
    predictions.sort(key=lambda pred: pred['score'], reverse=True)
    return predictions[:top_k_ranker]
def cstr(s, color='black'):
    """Wrap ``s`` in a ``<text>`` element styled with the given text colour.

    ``s`` is inserted verbatim (it may already contain markup), so plain text
    must be escaped by the caller.
    """
    return f"<text style=color:{color}>{s}</text>"
def cstr_bold(s, color='black'):
    """Like ``cstr`` but renders ``s`` in bold (``<b>``) inside the element.

    ``s`` is inserted verbatim; escape plain text before calling.
    """
    return f"<text style=color:{color}><b>{s}</b></text>"
def cstr_break(s, color='black'):
    """Like ``cstr`` but appends a ``<br>`` line break after ``s``.

    ``s`` is inserted verbatim; escape plain text before calling.
    """
    return f"<text style=color:{color}>{s}<br></text>"
def print_colored(text, start_idx, end_idx, confidence):
    """Render a passage as HTML with the answer span highlighted.

    The characters ``text[start_idx:end_idx]`` are shown in bold red and the rest
    of the passage in black. This is followed by a line break and a
    ``Confidence: <confidence>`` line.

    Args:
        text: The passage (QA context) the answer was extracted from.
        start_idx: Character offset in ``text`` where the answer starts.
        end_idx: Character offset in ``text`` where the answer ends (exclusive).
        confidence: Value shown after "Confidence: " (converted with ``str``).

    Returns:
        An HTML string for a Gradio HTML output.
    """
    import html

    # Escape each slice separately: the offsets refer to the raw text, so
    # slicing must happen before escaping changes character positions.
    before = html.escape(text[:start_idx])
    answer = html.escape(text[start_idx:end_idx])
    after = html.escape(text[end_idx:])
    # Join with '' rather than ' ': the slices are contiguous pieces of the
    # original passage, and extra separators would add spaces around the answer.
    passage = cstr(''.join([before, cstr_bold(answer, color='red'), after]), color='black')
    confidence_line = cstr(f"Confidence: {html.escape(str(confidence))}", color='black')
    return cstr_break(passage, color='black') + confidence_line
def final_qa_pipeline(file, query):
    """Answer ``query`` from an uploaded PDF with BM25 retrieval + extractive QA.

    The PDF is split into overlapping text chunks. The chunks that BM25 scores
    highest against the question are kept, the QA model is run over them, and
    the single best answer is shown highlighted inside its chunk.

    Args:
        file: The uploaded PDF as passed by the Gradio ``File`` input (may be
            ``None`` when nothing was uploaded).
        query: The user's question.

    Returns:
        A 2-tuple matching the Interface's two outputs: ``(html, confidence)``.
        ``html`` is the passage with the answer highlighted, and ``confidence``
        is the QA score as a percentage string such as ``"87.53%"``. When
        nothing can be matched, ``("No match", "0")`` is returned.
    """
    no_match = ("No match", "0")
    if file is None or not query:
        return no_match

    docs = read_pdf(file)
    tokenized_corpus = [bm25_tokenizer(doc) for doc in docs]
    # BM25Okapi cannot be built from a corpus without a single token: it divides
    # by the corpus/vocabulary size (ZeroDivisionError in rank_bm25 0.2.x). This
    # happens for empty PDFs or ones without a text layer (e.g. scans).
    if not any(tokenized_corpus):
        return no_match
    bm25 = BM25Okapi(tokenized_corpus)

    top_k_retriver, top_k_ranker = 20, 1
    lvl1 = retrieval(query, top_k_retriver, docs, bm25)
    if not lvl1:
        return no_match

    best = qa_ranker(query, [hit["docs"] for hit in lvl1], top_k_ranker)[0]
    confidence = str(np.round(100 * best["score"], 2)) + "%"
    # Return one value per declared output (HTML passage, confidence textbox).
    return (print_colored(best['doc'], best['start'], best['end'], confidence), confidence)
# Example (pdf path, question) pairs offered below the Gradio interface. The PDF
# names are resolved against the current working directory (presumably the
# files ship alongside this script — confirm in the deployment).
examples = [
    [os.path.abspath(pdf_name), question]
    for pdf_name, question in (
        ("dbs-annual-report-2020.pdf", "how much dividend was paid to shareholders ?"),
        ("dbs-annual-report-2020.pdf", "what are the key risks ?"),
        ("dbs-annual-report-2020.pdf", "what is the sustainability focus ?"),
        ("NASDAQ_AAPL_2020.pdf", "how much are the outstanding shares ?"),
        ("NASDAQ_AAPL_2020.pdf", "How high is shareholders equity ?"),
        ("NASDAQ_AAPL_2020.pdf", "what is competitors strategy ?"),
    )
]
# Gradio UI. Inputs are a PDF upload and a question. Outputs are an HTML passage
# with the highlighted answer and a confidence string; final_qa_pipeline must
# therefore return two values.
iface = gr.Interface(
    final_qa_pipeline,
    [
        gr.inputs.File(label="input pdf file"),
        gr.inputs.Textbox(label="Question:"),
    ],
    [
        gr.outputs.HTML(label="Predicted answer"),
        gr.outputs.Textbox(label="Confidence"),
    ],
    title="Question Answering on company annual reports",
    description="Simply upload any annual report pdf you are interested in and ask model a question OR load an example from below.",
    examples=examples,
)
iface.launch()