import PyPDF2
import faiss
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the seq2seq LLM used for answer generation (a BART model fine-tuned for summarization)
generation_model_name = 'facebook/bart-large-cnn'
generation_model = AutoModelForSeq2SeqLM.from_pretrained(generation_model_name)
tokenizer = AutoTokenizer.from_pretrained(generation_model_name)

# Specify input PDF paths
file_path1 = './AST-1.pdf'
file_path2 = './AST-2.pdf'

# Step 1: Load the document files
def read_pdf(file_path):
    with open(file_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ''
        for page in reader.pages:
            # extract_text() may return an empty string for image-only pages
            text += page.extract_text()
        return text

ast1_text = read_pdf(file_path1)
ast2_text = read_pdf(file_path2)

# Step 2: Split the loaded documents into chunks
# (split by a fixed number of words)
def chunk_text(text, chunk_size=200):
    words = text.split()
    chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
    return chunks
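
# Illustration (toy input): chunk_text('a b c d e', chunk_size=2)
# returns ['a b', 'c d', 'e']; the last chunk may be shorter than chunk_size.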

ast1_chunks = chunk_text(ast1_text, chunk_size=100)
ast2_chunks = chunk_text(ast2_text, chunk_size=150)

# Label each chunk with its source document
ast1_chunks = [(chunk, 'AST-1') for chunk in ast1_chunks]
ast2_chunks = [(chunk, 'AST-2') for chunk in ast2_chunks]
all_chunks = ast1_chunks + ast2_chunks
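
# all_chunks is now a list of (chunk_text, source_label) tuples, e.g.
# ('first 100 words of AST-1 ...', 'AST-1')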

print('Created the chunks')

# Step 3: Load the embedding model
from sentence_transformers import SentenceTransformer

# A compact pre-trained model from the MTEB leaderboard (384-dimensional embeddings)
embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Embed only the chunk text; all_chunks holds (text, label) tuples,
# and encode() expects plain strings
chunk_texts = [chunk for chunk, _ in all_chunks]
embeddings = embedding_model.encode(chunk_texts, convert_to_tensor=True)

# Convert embeddings to a numpy array for FAISS
embeddings_np = embeddings.cpu().numpy()

# Create a FAISS index (IndexFlatL2 performs exact L2-distance search)
dimension = embeddings_np.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings_np)

# Save the index to disk, then reload it to mimic a fresh session
faiss.write_index(index, 'embeddings_index.faiss')
stored_index = faiss.read_index('./embeddings_index.faiss')

print('Stored embeddings in the FAISS index')
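
# Note: the FAISS index stores only vectors; the chunk texts and labels live in
# all_chunks, which exists only in memory. A minimal sketch for persisting them
# alongside the index (pickle is an assumption, not part of the original script):
#
#     import pickle
#     with open('all_chunks.pkl', 'wb') as f:
#         pickle.dump(all_chunks, f)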
# Step 4: Retrieve the chunks most relevant to a query
def retrieve_chunks(query, top_k=10, use_mmr=False, diversity=0.5, target_doc='AST-1'):
    query_embedding = embedding_model.encode(query, convert_to_tensor=True).cpu().numpy()
    # Search a larger pool than top_k, because hits from the other document
    # are filtered out below
    pool_size = min(top_k * 4, stored_index.ntotal)
    distances, indices = stored_index.search(np.array([query_embedding]), pool_size)

    if use_mmr:
        # Maximal Marginal Relevance: trade off relevance to the query against
        # redundancy with the chunks already selected
        from sklearn.metrics.pairwise import cosine_similarity

        candidate_indices = [i for i in indices[0] if all_chunks[i][1] == target_doc]
        # FAISS returns L2 distances (smaller = more relevant); negate them so
        # that larger scores mean more relevant
        candidate_relevance = [-float(distances[0][pos]) for pos in range(len(indices[0]))
                               if all_chunks[indices[0][pos]][1] == target_doc]

        selected_indices = []
        while len(selected_indices) < top_k and candidate_indices:
            if not selected_indices:
                # Seed with the most relevant candidate
                selected_indices.append(candidate_indices.pop(0))
                candidate_relevance.pop(0)
            else:
                remaining_embeddings = embeddings_np[candidate_indices]
                selected_embeddings = embeddings_np[selected_indices]

                similarities = cosine_similarity(remaining_embeddings, selected_embeddings)
                mmr_scores = ((1 - diversity) * np.array(candidate_relevance)
                              - diversity * np.max(similarities, axis=1))

                next_index = int(np.argmax(mmr_scores))
                selected_indices.append(candidate_indices.pop(next_index))
                candidate_relevance.pop(next_index)

        return [all_chunks[i][0] for i in selected_indices]
    else:
        retrieved_chunks = []
        for idx in indices[0]:
            chunk, doc_label = all_chunks[idx]
            if doc_label == target_doc:
                retrieved_chunks.append(chunk)
            if len(retrieved_chunks) >= top_k:
                break
        return retrieved_chunks
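
# Example call (hypothetical query; applies to both retrieval modes):
#     chunks = retrieve_chunks('What is the document about?', top_k=3,
#                              use_mmr=True, target_doc='AST-2')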


# Step 5: Generate a response conditioned on the retrieved context
def generate_response(query, retrieved_chunks):
    context = " ".join(retrieved_chunks)  # retrieved_chunks is a list of strings
    input_text = f"Query: {query}\nContext: {context}"
    inputs = tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)
    summary_ids = generation_model.generate(
        inputs['input_ids'],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Step 6: End-to-end RAG pipeline
def rag_system(query, use_mmr=False):
    retrieved_chunks = retrieve_chunks(query, top_k=3, use_mmr=use_mmr)
    response = generate_response(query, retrieved_chunks)
    print(response)
    return response
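
# Quick smoke test before launching the UI (hypothetical query string):
#     rag_system('Summarize the main argument of AST-1', use_mmr=True)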
    
import gradio as gr

def query_rag_system(query, use_mmr):
    return rag_system(query, use_mmr=use_mmr)

interface = gr.Interface(
    fn=query_rag_system,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your query here..."),
        gr.Checkbox(label="Use MMR")
    ],
    outputs="text",
    title="RAG System",
    description="Enter a query to get a response from the RAG system. Optionally, enable MMR for more diverse retrieval."
)

interface.launch()
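
# launch() serves the app locally; interface.launch(share=True) would also
# create a temporary public link (standard Gradio behaviour).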