File size: 5,240 Bytes
abc928c
19da527
9a5a690
3ba9cd3
 
 
3bd289d
 
e2bc8f2
abc928c
744d2f2
3bd289d
616d967
 
 
 
 
 
 
 
 
abc928c
 
4944874
 
 
3bd289d
4944874
 
 
abc928c
08b6600
4944874
08b6600
 
 
 
e2bc8f2
08b6600
4944874
 
 
3ba9cd3
1b4d7ee
4944874
3bd289d
 
 
744d2f2
60d0552
 
 
 
 
 
 
 
 
 
 
 
744d2f2
e2bc8f2
60d0552
 
19da527
1b4d7ee
6d001a9
60d0552
 
 
 
 
616d967
 
 
 
 
c6ecfc8
616d967
 
 
 
c6ecfc8
616d967
c6ecfc8
616d967
 
 
 
 
19da527
616d967
e2bc8f2
19da527
4efabce
e2bc8f2
 
c6ecfc8
4efabce
c6ecfc8
 
3bd289d
4efabce
 
 
616d967
e2bc8f2
60d0552
e2bc8f2
3bd289d
4944874
 
abc928c
616d967
97be9e5
e2bc8f2
4efabce
c6ecfc8
 
4efabce
c6ecfc8
 
616d967
e2bc8f2
abc928c
 
6ca00a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import re

import gradio as gr
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from pdfminer.high_level import extract_text
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from rank_bm25 import BM25Okapi

# Sentence-splitter data for nltk.sent_tokenize. Older NLTK releases load the
# pickled "punkt" resource; recent releases (3.8.2+) load "punkt_tab" instead
# and raise LookupError if it is missing, so fetch both. quiet=True keeps the
# download progress log out of stdout on every start-up.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Extractive question-answering pipeline. The model and tokenizer are loaded
# explicitly so the same checkpoint backs both.
qa_model_name = "deepset/roberta-large-squad2"
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)

# Abstractive summarization pipeline, used both for "summarize" requests and
# to shorten QA answers that exceed the requested word count.
summarization_model_name = "facebook/bart-large-cnn"
summarizer = pipeline("summarization", model=summarization_model_name)

def read_pdf(file, *, raise_errors=False):
    """Extract the plain text of a PDF with pdfminer.

    Args:
        file: Path or binary file object accepted by
            ``pdfminer.high_level.extract_text``.
        raise_errors: When True, propagate extraction failures to the caller
            instead of converting them to a string. Defaults to False to keep
            the original behaviour.

    Returns:
        The extracted text. When ``raise_errors`` is False, a failure is
        reported by returning the error message instead. Both outcomes are
        ``str``, so callers cannot tell them apart by type; new code should
        pass ``raise_errors=True``.

    Raises:
        ValueError: If ``raise_errors`` is True and the PDF yields no
            non-whitespace text.
        Exception: Any error raised by pdfminer, if ``raise_errors`` is True.
    """
    try:
        text = extract_text(file)
        # pdfminer separates pages with form feeds, so a PDF with no
        # extractable text can still produce a non-empty, whitespace-only
        # string. Treat that as a failed extraction too.
        if not text or not text.strip():
            raise ValueError("PDF extraction failed.")
        return text
    except Exception as e:
        if raise_errors:
            raise
        return str(e)

def retrieve_relevant_text_bm25(question, sentences, top_n=3):
    try:
        tokenized_corpus = [sent.split() for sent in sentences]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_query = question.split()
        doc_scores = bm25.get_scores(tokenized_query)
        top_n_indices = np.argsort(doc_scores)[::-1][:top_n]
        relevant_texts = [sentences[i] for i in top_n_indices]
        return " ".join(relevant_texts)
    except Exception as e:
        return str(e)

def answer_question(pdf, question, num_words):
    try:
        text = read_pdf(pdf)
        if isinstance(text, str):
            return text

        if "summarize" in question.lower():
            try:
                summarized_text = summarizer(text, max_length=num_words, min_length=1)
                return summarized_text[0]['summary_text'].strip()
            except RuntimeError as e:
                if "Input length of input_ids is" in str(e) and "but `max_length` is set to" in str(e):
                    return "PDF is too long for summarization. Please provide a shorter PDF or ask a more specific question."
                else:
                    return f"Summarization Error: {e}"
            except Exception as e:
                return f"Summarization Error: {e}"

        sentences = sent_tokenize(text)
        relevant_text = retrieve_relevant_text_bm25(question, sentences)

        if not relevant_text:
            return "Could not find relevant information in the PDF."

        response = qa_pipeline(question=question, context=relevant_text)
        answer = response.get('answer')

        if not answer:
            return "Could not find an answer in the relevant text."

        answer = answer.strip()
        answer = " ".join(answer.split())

        if len(answer.split()) > num_words:
            try:
                summarized_answer = summarizer(answer, max_length=num_words + 10, min_length=1)
                answer = summarized_answer[0]['summary_text']
                answer = answer.strip()
                answer = " ".join(answer.split())
                if len(answer.split()) > num_words:
                    answer = " ".join(answer.split()[:num_words])
            except RuntimeError as e:
                if "Input length of input_ids is" in str(e) and "but `max_length` is set to" in str(e):
                    answer = " ".join(answer.split()[:num_words])
                else:
                    return f"Summarization Error: {e}"
            except Exception as e:
                return f"Summarization Error: {e}"
        elif len(answer.split()) < num_words and relevant_text:
            remaining_words = num_words - len(answer.split())
            added_words = 0
            added_sentences = []
            for sentence in sent_tokenize(relevant_text):
                sentence_words = sentence.split()
                words_to_add = min(remaining_words - added_words, len(sentence_words))
                if words_to_add > 0:
                    added_sentences.append(" ".join(sentence_words[:words_to_add]))
                    added_words += words_to_add
                if added_words >= remaining_words:
                    break
            answer += " " + " ".join(added_sentences)
            answer = answer.strip()
            answer = " ".join(answer.split())
            if len(answer.split()) > num_words:
                answer = " ".join(answer.split()[:num_words])

        return answer.strip()

    except Exception as e:
        return str(e)

# Gradio UI. Components are placed in the order they are created inside the
# nested Blocks/Row/Column contexts, so statement order here defines the layout.
with gr.Blocks() as iface:
    gr.Markdown("PDF Q&A with RoBERTa | made by NP")
    with gr.Row():
        # First column (scale=2, i.e. twice the relative width of the second):
        # question box and submit button.
        with gr.Column(scale=2):
            question_input = gr.Textbox(lines=2, placeholder="Ask a question", label="Question")
            btn = gr.Button("Submit")
        # Second column (scale=1): PDF upload and the answer-length control.
        with gr.Column(scale=1):
            # type="filepath": the callback receives a file path, not file bytes.
            pdf_input = gr.File(type="filepath", label="Upload PDF")
            # Target answer length in words; delivered to answer_question as num_words.
            num_words_slider = gr.Slider(minimum=1, maximum=500, value=100, step=1, label="Number of Words")
    answer_output = gr.Textbox(label="Answer", lines=5)
    # Inputs are passed positionally, matching answer_question(pdf, question, num_words).
    btn.click(fn=answer_question, inputs=[pdf_input, question_input, num_words_slider], outputs=answer_output)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()