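"""Streamlit app: retrieval-augmented generation (RAG) over a single PDF.

Downloads "Attention Is All You Need" (arXiv:1706.03762), splits it into
chunks, embeds the chunks with DPR, indexes them in FAISS, and answers
questions with a BART generator conditioned on the retrieved chunks.
"""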
import streamlit as st
import torch
import faiss
import PyPDF2
import os

from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    print("Running on GPU")
    device = torch.device("cuda:0")

file_url = "https://arxiv.org/pdf/1706.03762.pdf"
file_path = "assets/attention.pdf"

os.makedirs('assets', exist_ok=True)

if not os.path.isfile(file_path):
    os.system(f'curl -o {file_path} {file_url}')
else:
    print("File already exists!")

class Retriever:
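  """Chunks a PDF, embeds the chunks with DPR, and retrieves the chunks
  nearest to a question via a FAISS L2 index."""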

  def __init__(self, file_path, device, context_model_name, question_model_name):
    self.file_path = file_path
    self.device = device

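    # DPR uses separate encoders for passages ("context") and questions.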
    self.context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_model_name)
    self.context_model = DPRContextEncoder.from_pretrained(context_model_name).to(device)

    self.question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_model_name)
    self.question_model = DPRQuestionEncoder.from_pretrained(question_model_name).to(device)

  def token_len(self, text):
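    """Return the number of DPR context-tokenizer tokens in the text."""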
    tokens = self.context_tokenizer.encode(text)
    return len(tokens)

  def extract_text_from_pdf(self, file_path):
    """Concatenate the extracted text of every page in the PDF."""
    with open(file_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ''
        for page in reader.pages:
            # extract_text() can return None for image-only pages
            text += page.extract_text() or ''
    return text

  def get_text(self):
    # Same extraction as above, bound to this instance's file.
    return self.extract_text_from_pdf(self.file_path)
  def load_chunks(self):
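    """Split the PDF text into ~300-token chunks with 20-token overlap."""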
    self.text = self.extract_text_from_pdf(self.file_path)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=300,
        chunk_overlap=20,
        length_function=self.token_len,
        separators=["\n\n", " ", ".", ""]
    )

    self.chunks = text_splitter.split_text(self.text)

  def load_context_embeddings(self):
    """Embed every chunk with the DPR context encoder and index the
    embeddings in an exact L2 FAISS index."""
    encoded_input = self.context_tokenizer(
        self.chunks, return_tensors='pt', padding=True,
        truncation=True, max_length=100).to(self.device)

    with torch.no_grad():
      model_output = self.context_model(**encoded_input)
      self.token_embeddings = model_output.pooler_output.cpu().detach().numpy()

    self.index = faiss.IndexFlatL2(self.token_embeddings.shape[1])
    self.index.add(self.token_embeddings)

  def retrieve_top_k(self, query_prompt, k=10):
    """Return the k chunks whose DPR embeddings are closest to the query."""
    encoded_query = self.question_tokenizer(
        query_prompt, return_tensors="pt", truncation=True,
        padding=True).to(self.device)

    with torch.no_grad():
      model_output = self.question_model(**encoded_query)
      query_vector = model_output.pooler_output

    # D holds squared L2 distances, I the indices of the nearest chunks.
    query_vector_np = query_vector.cpu().numpy()
    D, I = self.index.search(query_vector_np, k)

    retrieved_texts = [self.chunks[i] for i in I[0]]
    return retrieved_texts

class RAG:
    """Ties the DPR/FAISS retriever to a BART generator."""

    def __init__(self,
                 file_path,
                 device,
                 context_model_name="facebook/dpr-ctx_encoder-multiset-base",
                 question_model_name="facebook/dpr-question_encoder-multiset-base",
                 generator_name="a-ware/bart-squadv2"):
        # Other generators tried: "facebook/bart-large",
        # "valhalla/bart-large-finetuned-squadv1", "vblagoje/bart_lfqa".
        self.device = device

        self.generator_tokenizer = BartTokenizer.from_pretrained(generator_name)
        self.generator_model = BartForConditionalGeneration.from_pretrained(generator_name).to(device)

        self.retriever = Retriever(file_path, device, context_model_name, question_model_name)
        self.retriever.load_chunks()
        self.retriever.load_context_embeddings()

    def get_answer(self, question, context):
        """Generate an answer for a single (question, context) pair."""
        input_text = "question: %s context: %s" % (question, context)
        features = self.generator_tokenizer([input_text], return_tensors='pt')
        out = self.generator_model.generate(
            input_ids=features['input_ids'].to(self.device),
            attention_mask=features['attention_mask'].to(self.device))
        return self.generator_tokenizer.decode(out[0], skip_special_tokens=True)

    def query(self, question):
        """Retrieve the top chunks for the question and generate an answer."""
        context = self.retriever.retrieve_top_k(question, k=5)

        # Prompt = retrieved chunks followed by the question itself.
        input_text = "answer: " + " ".join(context) + " " + question

        inputs = self.generator_tokenizer.encode(
            input_text, return_tensors='pt', max_length=1024,
            truncation=True).to(self.device)
        outputs = self.generator_model.generate(
            inputs, max_length=150, min_length=2, length_penalty=2.0,
            num_beams=4, early_stopping=True)

        answer = self.generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return answer


context_model_name = "facebook/dpr-ctx_encoder-multiset-base"
question_model_name = "facebook/dpr-question_encoder-multiset-base"


# Cache the pipeline so Streamlit reruns don't reload every model on each
# interaction (st.cache_resource keeps one instance per file path).
@st.cache_resource
def load_rag(path):
    return RAG(path, device,
               context_model_name=context_model_name,
               question_model_name=question_model_name)

rag = load_rag(file_path)

st.title("RAG Model Query Interface")

query = st.text_area("Enter your question:")

# If a question is given, retrieve context and generate an answer
if query:
    answer = rag.query(query)
    st.write(answer)