initial commit
- app.py +202 -0
- backend_utils/file_handlers.py +51 -0
- backend_utils/text_processor.py +45 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,202 @@
import sys
import os

sys.path.append(os.path.abspath('.'))

import streamlit as st
import time
import openai
from typing import List, Optional, Tuple, IO

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks import get_openai_callback
from backend_utils.file_handlers import FileHandlerFactory
from backend_utils.text_processor import DefaultTextProcessor


MODELS = {
    'gpt-3.5': 'openai',
    'gpt-4': 'openai',
}

openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ""


def set_api_key(api_provider, api_key):
    """Set the API key in the respective environment variable."""
    if api_provider == 'openai':
        os.environ["OPENAI_API_KEY"] = api_key
        openai.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(f"Unknown API provider: {api_provider}")


def load_chain(selected_model):
    """Logic for loading the chain you want to use should go here."""
    if selected_model == 'gpt-4':
        llm = ChatOpenAI(temperature=0, model="gpt-4")
    else:
        llm = ChatOpenAI(temperature=0)
    return llm


def answer_question(knowledge_base, user_question, llm):
    """Retrieve the most relevant chunks and run a 'stuff' QA chain over them."""
    try:
        retrieved_docs = knowledge_base.similarity_search(
            user_question,
            k=10
        )
    except Exception as e:
        print(f"Error finding relevant chunks: {e}")
        return ""  # caller calls .split() on the result, so return a string
    print(retrieved_docs)
    try:
        chain = load_qa_chain(
            llm,
            chain_type="stuff"
        )
        with get_openai_callback() as callback:
            # Note: chain.run only accepts the chain's input keys
            # (input_documents, question); passing max_tokens here would raise.
            response = chain.run(
                input_documents=retrieved_docs,
                question=user_question
            )
        print(callback)
        return response
    except Exception as e:
        print(f"Error running QA chain: {e}")
        return ""


def read_files(files: List[IO]) -> Optional[str]:
    """Reads the uploaded files and returns the combined text."""
    combined_text = ""
    file_factory = FileHandlerFactory()
    for file in files:
        if file is None:
            continue
        handler = file_factory.get_file_handler(file.type)
        text = handler.read_file(file)
        if not text:
            print(f"No text could be extracted from {file.name}. "
                  "Please ensure the file is not encrypted or corrupted.")
            return None
        combined_text += text
    return combined_text


def chunk_text(combined_text: str) -> Optional[Tuple[List[str], DefaultTextProcessor]]:
    """Split the combined text into chunks; returns (chunks, processor) or None."""
    processor = DefaultTextProcessor(chunk_size=500, chunk_overlap=0)
    chunks = processor.split_text(combined_text)
    if not chunks:
        print("Couldn't split the text into chunks. Please try again with different text.")
        return None
    return chunks, processor


def create_embeddings(chunks: List[str], processor: DefaultTextProcessor):
    """Takes chunks and creates embeddings in a knowledge base (a FAISS store)."""
    knowledge_base = processor.create_embeddings(chunks)
    if not knowledge_base:
        print("Couldn't create embeddings from the text. Please try again.")
        return None
    return knowledge_base


def load_documents(files):
    """Read, chunk, and embed the uploaded files into a knowledge base."""
    print(files)
    combined_text = read_files(files)
    if combined_text is None:
        return None
    result = chunk_text(combined_text)
    if result is None:
        return None
    chunks, processor = result
    knowledge_base = create_embeddings(chunks, processor)
    print("ALL DONE")
    return knowledge_base


def get_text():
    # Note: currently unused; input is taken via st.chat_input below.
    input_text = st.text_input("You: ", "Hello, how are you?", key="input")
    return input_text


if __name__ == "__main__":
    st.set_page_config(
        page_title="Chat with your documents demo:",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Dropdown to select model
    selected_model = st.sidebar.selectbox("Select a model", list(MODELS.keys()))

    # Input box to enter API key
    api_key = st.sidebar.text_input(f"Enter API key for {MODELS[selected_model]}", type="password")

    # Set the API key for the selected model
    if api_key:
        set_api_key(MODELS[selected_model], api_key)

    llm = load_chain(selected_model)

    if "loaded" not in st.session_state:
        st.session_state["loaded"] = False
    if "knowledge_base" not in st.session_state:
        st.session_state["knowledge_base"] = None

    ResumePDF = st.sidebar.file_uploader(
        "Upload your documents", type=['pdf'], help="Help message goes here",
        key="uploaded_file", accept_multiple_files=True
    )
    if ResumePDF:
        print("ResumePDF", ResumePDF)
        if not st.session_state["loaded"]:
            with st.spinner('Loading files 📖'):
                st.session_state["knowledge_base"] = load_documents(ResumePDF)
                st.session_state["loaded"] = True

    st.header("📖 Chat with your documents demo:")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}]

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if user_input := st.chat_input("What is your question?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            with st.spinner('Thinking ...'):
                ai_message = answer_question(st.session_state["knowledge_base"], user_input, llm)
                # ai_message = llm.predict_messages([HumanMessage(content=user_input)])
            # Simulate a stream of the response with a small delay
            print(ai_message)
            for chunk in ai_message.split():
                full_response += chunk + " "
                time.sleep(0.05)
                # Add a blinking cursor to simulate typing
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
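
For reference, the same retrieve-then-answer flow can be exercised without the Streamlit UI. This is a minimal sketch, not part of the commit: it assumes a valid OpenAI key and uses placeholder text in place of an uploaded PDF.

# sketch_qa.py -- hypothetical smoke test, not part of this commit
import os
from backend_utils.text_processor import DefaultTextProcessor
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain

os.environ["OPENAI_API_KEY"] = "sk-..."  # assumption: a valid key

processor = DefaultTextProcessor(500, 0)
chunks = processor.split_text("Some long placeholder document text. " * 100)
kb = processor.create_embeddings(chunks)      # FAISS index, or None on failure

question = "What is the document about?"
docs = kb.similarity_search(question, k=10)   # same k as answer_question above
chain = load_qa_chain(ChatOpenAI(temperature=0), chain_type="stuff")
print(chain.run(input_documents=docs, question=question))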
backend_utils/file_handlers.py
ADDED
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from PyPDF2 import PdfReader


class FileHandler(ABC):
    """Abstract base class for file handlers."""

    @abstractmethod
    def read_file(self, file):
        """Read the file and extract the text.

        Parameters:
            file (UploadedFile): The file to read.

        Returns:
            str: The extracted text.
        """
        pass


class PDFHandler(FileHandler):
    """Reads PDF files with PyPDF2, page by page."""

    def read_file(self, file):
        try:
            pdf_reader = PdfReader(file)
            text = ""
            for page in pdf_reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text += page_text
            return text
        except Exception as e:
            print(f"Error reading file: {e}")
            return ""  # return an empty string if an error occurs


class FileHandlerFactory:
    """Maps MIME types to the matching FileHandler."""

    @staticmethod
    def get_file_handler(file_type):
        if file_type == "application/pdf":
            return PDFHandler()
        else:
            raise ValueError(f"Invalid file type: {file_type}")
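
As a quick check, PdfReader accepts any binary file-like object, so a local file opened in "rb" mode can stand in for Streamlit's UploadedFile here (the path below is hypothetical):

from backend_utils.file_handlers import FileHandlerFactory

handler = FileHandlerFactory.get_file_handler("application/pdf")
with open("example.pdf", "rb") as f:   # hypothetical local PDF
    text = handler.read_file(f)
print(text[:200])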
backend_utils/text_processor.py
ADDED
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS


class TextProcessor(ABC):
    """Abstract base class for splitting text and building embeddings."""

    @abstractmethod
    def split_text(self, text):
        pass

    @abstractmethod
    def create_embeddings(self, chunks):
        pass


class DefaultTextProcessor(TextProcessor):
    """Splits text into overlapping chunks and indexes them in a FAISS store."""

    def __init__(self, chunk_size, chunk_overlap):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_text(self, text):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separators=[" ", ",", "\n"],
            length_function=len,
        )
        chunks = text_splitter.split_text(text)
        return chunks

    def create_embeddings(self, chunks):
        if not chunks:
            return None
        embeddings = OpenAIEmbeddings()
        try:
            return FAISS.from_texts(chunks, embeddings)
        except Exception as e:
            print(f"Error creating embeddings: {e}")
            return None
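
Splitting alone needs no API key (only create_embeddings does), so the chunking parameters can be sanity-checked in isolation. A minimal sketch with placeholder input:

from backend_utils.text_processor import DefaultTextProcessor

processor = DefaultTextProcessor(chunk_size=500, chunk_overlap=0)
chunks = processor.split_text("word " * 1000)
print(len(chunks), max(len(c) for c in chunks))   # every chunk is <= 500 chars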
requirements.txt
ADDED
@@ -0,0 +1,8 @@
openai==0.27.8
transformers==4.30.2
langchain==0.0.250
streamlit==1.25.0
torch==2.0.1
faiss-cpu==1.7.4
tiktoken==0.4.0
PyPDF2