NHZ commited on
Commit
fdf7122
·
verified ·
1 Parent(s): 7efdb22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -60
app.py CHANGED
@@ -1,14 +1,18 @@
 
1
  import requests
2
  import numpy as np
3
  import faiss
4
  from PyPDF2 import PdfReader
5
  from transformers import AutoTokenizer, AutoModel
 
 
 
 
 
6
  from groq import Groq
7
  import streamlit as st
8
- import torch
9
- import os
10
 
11
- # Initialize Groq client using secret API key
12
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
13
 
14
  # Function to download and extract content from a public Google Drive PDF link
@@ -31,28 +35,8 @@ def extract_pdf_content(drive_url):
31
  text += page.extract_text()
32
  return text
33
 
34
- # Function to chunk and tokenize text
35
- def chunk_and_tokenize(text, tokenizer, chunk_size=512):
36
- tokens = tokenizer.encode(text, add_special_tokens=False)
37
- chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]
38
- return chunks
39
-
40
- # Function to compute embeddings and build FAISS index
41
- def build_faiss_index(chunks, model):
42
- embeddings = []
43
- for chunk in chunks:
44
- input_ids = torch.tensor([chunk])
45
- with torch.no_grad():
46
- embedding = model(input_ids).last_hidden_state.mean(dim=1).detach().numpy()
47
- embeddings.append(embedding)
48
- embeddings = np.vstack(embeddings)
49
-
50
- index = faiss.IndexFlatL2(embeddings.shape[1])
51
- index.add(embeddings)
52
- return index
53
-
54
  # Streamlit app
55
- st.title("RAG-based Application with Groq API")
56
 
57
  # Predefined Google Drive link
58
  drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
@@ -63,46 +47,40 @@ text = extract_pdf_content(drive_url)
63
  if text:
64
  st.write("Document extracted successfully!")
65
 
66
- # Initialize tokenizer and model
67
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
68
- model = AutoModel.from_pretrained("bert-base-uncased")
69
-
70
- st.write("Chunking and tokenizing content...")
71
- chunks = chunk_and_tokenize(text, tokenizer)
72
-
73
- st.write("Building FAISS index...")
74
- index = build_faiss_index(chunks, model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  # Query input
77
  query = st.text_input("Enter your query:")
78
  if query:
79
- st.write("Searching for the most relevant chunk...")
80
- query_tokens = tokenizer.encode(query, add_special_tokens=False)
81
- query_embedding = (
82
- model(torch.tensor([query_tokens]))
83
- .last_hidden_state.mean(dim=1)
84
- .detach().numpy()
85
- )
86
- _, indices = index.search(query_embedding, k=1)
87
-
88
- # Retrieve the most relevant chunk
89
- relevant_chunk = chunks[indices[0][0]]
90
- relevant_text = tokenizer.decode(relevant_chunk)
91
- st.write("Relevant chunk found:", relevant_text)
92
-
93
- # Interact with Groq API
94
- st.write("Querying the Groq API...")
95
- chat_completion = client.chat.completions.create(
96
- messages=[
97
- {
98
- "role": "user",
99
- "content": relevant_text,
100
- }
101
- ],
102
- model="llama-3.3-70b-versatile",
103
- )
104
- st.write("Model Response:", chat_completion.choices[0].message.content)
105
  else:
106
  st.error("Failed to extract content from the document.")
107
 
108
-
 
1
+ import os
2
  import requests
3
  import numpy as np
4
  import faiss
5
  from PyPDF2 import PdfReader
6
  from transformers import AutoTokenizer, AutoModel
7
+ from langchain.vectorstores import FAISS
8
+ from langchain.embeddings import HuggingFaceEmbeddings
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.chat_models import ChatOpenAI
12
  from groq import Groq
13
  import streamlit as st
 
 
14
 
15
+ # Initialize Groq client
16
  client = Groq(api_key=os.getenv("GROQ_API_KEY"))
17
 
18
  # Function to download and extract content from a public Google Drive PDF link
 
35
  text += page.extract_text()
36
  return text
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Streamlit app
39
+ st.title("Enhanced RAG with LangChain and Groq API")
40
 
41
  # Predefined Google Drive link
42
  drive_url = "https://drive.google.com/file/d/1XvqA1OIssRs2gbmOtKFKj-02yQ5X2yg0/view?usp=sharing"
 
47
  if text:
48
  st.write("Document extracted successfully!")
49
 
50
+ # LangChain embeddings and FAISS index setup
51
+ st.write("Building embeddings and FAISS index...")
52
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
53
+ faiss_index = FAISS.from_texts([text], embeddings)
54
+
55
+ # LangChain retriever
56
+ retriever = faiss_index.as_retriever(search_kwargs={"k": 3})
57
+
58
+ # LangChain QA chain
59
+ prompt_template = """
60
+ Use the following document excerpts to answer the user's question.
61
+ If the answer is not directly found in the document, say "The answer is not in the provided document.".
62
+
63
+ Document Excerpts:
64
+ {context}
65
+
66
+ Question:
67
+ {question}
68
+
69
+ Answer:
70
+ """
71
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
72
+ qa_chain = RetrievalQA.from_chain_type(
73
+ llm=ChatOpenAI(model_name="gpt-3.5-turbo"),
74
+ retriever=retriever,
75
+ chain_type_kwargs={"prompt": PROMPT},
76
+ )
77
 
78
  # Query input
79
  query = st.text_input("Enter your query:")
80
  if query:
81
+ st.write("Searching the document and generating a response...")
82
+ result = qa_chain.run(query)
83
+ st.write("Response:", result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  else:
85
  st.error("Failed to extract content from the document.")
86