Jofthomas HF staff committed on
Commit
88768cb
1 Parent(s): 198090e

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+ import os
4
+
5
+ sys.path.append(os.path.abspath('.'))
6
+
7
+ import streamlit as st
8
+ import time
9
+ import openai
10
+ from typing import List, Optional, Tuple, Dict, IO
11
+
12
+ from langchain.chat_models import ChatOpenAI
13
+ from langchain.schema import HumanMessage, AIMessage, ChatMessage, FunctionMessage
14
+ from langchain.chains.question_answering import load_qa_chain
15
+ from langchain.callbacks import get_openai_callback
16
+ from backend_utils.file_handlers import FileHandlerFactory
17
+ from backend_utils.text_processor import DefaultTextProcessor
18
+
19
+
20
# Supported chat models, mapped to the API provider that serves them.
# The sidebar dropdown is populated from these keys.
MODELS = {
    'gpt-3.5': 'openai',
    'gpt-4': 'openai',
}

# Start with empty credentials; set_api_key() fills these in once the user
# enters a key in the sidebar.
# NOTE(review): this clobbers any OPENAI_API_KEY already present in the
# environment at import time — confirm that is intended.
openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ""
27
def set_api_key(api_provider, api_key):
    """
    Set the API key in the respective environment variable.

    Only the 'openai' provider is supported; any other provider name
    raises ValueError.
    """
    if api_provider != 'openai':
        raise ValueError(f"Unknown API provider: {api_provider}")
    # Export to the environment first, then mirror it onto the openai module.
    os.environ["OPENAI_API_KEY"] = api_key
    openai.api_key = os.environ["OPENAI_API_KEY"]
36
+
37
def load_chain(selected_model):
    """Return the chat model to use: gpt-4 when selected, default otherwise."""
    if selected_model == 'gpt-4':
        return ChatOpenAI(temperature=0, model="gpt-4")
    # Anything else falls back to ChatOpenAI's default model (gpt-3.5).
    return ChatOpenAI(temperature=0)
44
+
45
def answer_question(knowledge_base, user_question, llm):
    """Answer ``user_question`` from ``knowledge_base`` with a 'stuff' QA chain.

    Parameters:
        knowledge_base: vector store exposing ``similarity_search``.
        user_question (str): the question to answer.
        llm: chat model passed to the QA chain.

    Returns:
        str: the chain's answer, or "" when retrieval or the chain fails.
    """
    try:
        # Retrieve the 10 chunks most similar to the question.
        retrieved_docs = knowledge_base.similarity_search(
            user_question,
            k=10
        )
    except Exception as e:
        print(f"Error finding relative chunks: {e}")
        # Fix: return "" (the original returned []) so the caller's
        # ai_message.split() does not crash on the error path.
        return ""
    print(retrieved_docs)  # debug: show which chunks were retrieved
    try:
        chain = load_qa_chain(
            llm,
            chain_type="stuff"
        )
        with get_openai_callback() as callback:
            # NOTE(review): max_tokens is forwarded as a chain input here;
            # whether the chain actually honors it is not visible from this
            # file — confirm against the chain's prompt variables.
            response = chain.run(
                input_documents=retrieved_docs,
                question=user_question,
                max_tokens=50
            )
        print(callback)  # token/cost usage reported by the callback
        return response
    except Exception as e:
        print(f"Error running QA chain: {e}")
        return ""
74
+
75
+
76
def read_files(files: List[IO]) -> Optional[str]:
    """
    Read every uploaded file and return their concatenated text.

    Parameters:
        files: uploaded file objects; None entries are skipped.

    Returns:
        The combined text of all files, or None if any file yields no text
        (e.g. an encrypted or corrupted file).
    """
    combined_text = ""
    # The original special-cased len(files) == 1 with logic identical to
    # this loop; a single loop covers both cases.
    for file in files:
        if file is None:
            continue
        file_factory = FileHandlerFactory()
        handler = file_factory.get_file_handler(file.type)
        text = handler.read_file(file)
        if not text:
            print(f"No text could be extracted from {file.name}. Please ensure the file is not encrypted or corrupted.")
            return None
        combined_text += text
    return combined_text
104
+
105
def chunk_text(combined_text: str) -> Optional[Tuple[List[str], "DefaultTextProcessor"]]:
    """
    Split ``combined_text`` into chunks with a default text processor.

    Returns:
        A ``(chunks, processor)`` tuple — the processor is returned so the
        caller can reuse it to build embeddings — or None if no chunks
        could be produced.  (The original annotation claimed
        ``Optional[List[str]]``, which did not match the tuple return.)
    """
    # 500-character chunks with no overlap.
    processor = DefaultTextProcessor(500, 0)
    chunks = processor.split_text(combined_text)
    if not chunks:
        print("Couldn't split the text into chunks. Please try again with different text.")
        return None
    return chunks, processor
112
def create_embeddings( chunks: List[str], processor) -> Optional[Dict]:
    """
    Build a knowledge base by embedding the given chunks.

    Delegates to ``processor.create_embeddings``; prints a message and
    returns None when no knowledge base could be produced.
    """
    knowledge_base = processor.create_embeddings(chunks)
    if knowledge_base:
        return knowledge_base
    print("Couldn't create embeddings from the text. Please try again.")
    return None
121
def load_documents(files):
    """
    Build a knowledge base (vector store) from uploaded files.

    Pipeline: read files -> chunk text -> embed chunks.

    Returns:
        The knowledge base, or None when any stage fails.  (The original
        crashed here: it unpacked ``chunk_text``'s result without checking
        for None, and passed a possibly-None text into ``chunk_text``.)
    """
    print(files)  # debug: show what was uploaded
    combined_text = read_files(files)
    if combined_text is None:
        return None
    chunked = chunk_text(combined_text)
    if chunked is None:
        return None
    chunks, processor = chunked
    knowledge_base = create_embeddings(chunks, processor)
    print("ALL DONE")
    return knowledge_base
129
def get_text():
    """Render the chat text input widget and return the user's message."""
    return st.text_input("You: ", "Hello, how are you?", key="input")
132
+
133
+
134
+ if __name__ == "__main__":
135
+ st.set_page_config(
136
+ page_title="Chat with your documents demo:",
137
+ page_icon="📖",
138
+ layout="wide",
139
+ initial_sidebar_state="expanded", )
140
+ # Dropdown to select model
141
+ selected_model = st.sidebar.selectbox("Select a model", list(MODELS.keys()))
142
+
143
+ # Input box to enter API key
144
+ api_key = st.sidebar.text_input(f"Enter API key for {MODELS[selected_model]}", type="password")
145
+
146
+ # Set the API key for the selected model
147
+ if api_key:
148
+ set_api_key(MODELS[selected_model], api_key)
149
+
150
+ llm = load_chain(selected_model)
151
+ if "loaded" not in st.session_state:
152
+ st.session_state["loaded"] = False
153
+ if "knowledge_base" not in st.session_state:
154
+ st.session_state["knowledge_base"] = None
155
+
156
+ ResumePDF = st.sidebar.file_uploader(
157
+ "Upload your documents", type=['pdf'], help="Help message goes here", key="uploaded_file", accept_multiple_files=True
158
+ )
159
+ if ResumePDF :
160
+
161
+ print("ResumePDF",ResumePDF)
162
+
163
+ if not st.session_state["loaded"]:
164
+ with st.spinner('Loading files 📖'):
165
+ st.session_state["knowledge_base"] = load_documents(ResumePDF)
166
+ st.session_state["loaded"] = True
167
+
168
+ st.header("📖 Chat with your documents demo:")
169
+
170
+ if "messages" not in st.session_state:
171
+ st.session_state["messages"] = [
172
+ {"role": "assistant", "content": "How can I help you?"}]
173
+
174
+ # Display chat messages from history on app rerun
175
+ for message in st.session_state.messages:
176
+ with st.chat_message(message["role"]):
177
+ st.markdown(message["content"])
178
+
179
+ if user_input := st.chat_input("What is your question?"):
180
+ # Add user message to chat history
181
+ st.session_state.messages.append({"role": "user", "content": user_input})
182
+ # Display user message in chat message container
183
+ with st.chat_message("user"):
184
+ st.markdown(user_input)
185
+
186
+ with st.chat_message("assistant"):
187
+ message_placeholder = st.empty()
188
+ full_response = ""
189
+
190
+ with st.spinner('Thinking ...'):
191
+ ai_message=answer_question(st.session_state["knowledge_base"],user_input,llm)
192
+ # ai_message = llm.predict_messages([HumanMessage(content=user_input)])
193
+ # Simulate stream of response with milliseconds delay
194
+ print(ai_message)
195
+ for chunk in ai_message.split():
196
+ full_response += chunk + " "
197
+ time.sleep(0.05)
198
+ # Add a blinking cursor to simulate typing
199
+ message_placeholder.markdown(full_response + "▌")
200
+ message_placeholder.markdown(full_response)
201
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
202
+
backend_utils/file_handlers.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from PyPDF2 import PdfReader
3
+
4
+
5
+
6
class FileHandler(ABC):
    """Interface for objects that extract text from an uploaded file."""

    @abstractmethod
    def read_file(self, file):
        """Extract and return the text content of *file*.

        Parameters:
            file (UploadedFile): The file to read.

        Returns:
            str: The extracted text.
        """
        ...
21
+
22
class PDFHandler(FileHandler):
    """FileHandler that extracts text from PDF files via PyPDF2."""

    def read_file(self, file):
        """Return the concatenated text of every page, or "" on failure.

        Pages for which PyPDF2 extracts no text are skipped.
        """
        try:
            reader = PdfReader(file)
            page_texts = [page.extract_text() for page in reader.pages]
            return "".join(t for t in page_texts if t)
        except Exception as e:
            print(f"Error reading file: {e}")
            return ""  # return an empty string if an error occurs
38
+
39
+
40
+
41
+
42
class FileHandlerFactory:
    """Maps MIME types to the FileHandler implementation that reads them."""

    @staticmethod
    def get_file_handler(file_type):
        """Return a handler for *file_type*; raise ValueError if unsupported."""
        if file_type != "application/pdf":
            raise ValueError("Invalid file type")
        return PDFHandler()
backend_utils/text_processor.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from langchain.embeddings.openai import OpenAIEmbeddings
4
+ from langchain.vectorstores import FAISS
5
+
6
class TextProcessor(ABC):
    """Interface for splitting raw text and indexing the resulting chunks."""

    @abstractmethod
    def split_text(self, text):
        """Break *text* into a list of chunk strings."""
        ...

    @abstractmethod
    def create_embeddings(self, chunks):
        """Turn chunk strings into a searchable embedding index."""
        ...
16
+
17
class DefaultTextProcessor(TextProcessor):
    """Concrete TextProcessor: character-count splitting + FAISS indexing."""

    def __init__(self, chunk_size, chunk_overlap):
        # Sizes are measured in characters (len is the length function below).
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def split_text(self, text):
        """Split *text* into chunks of at most ``chunk_size`` characters."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            separators=[" ", ",", "\n"],
            length_function=len,
        )
        return splitter.split_text(text)

    def create_embeddings(self, chunks):
        """Embed *chunks* with OpenAI and index them in a FAISS store.

        Returns None for empty input or when indexing fails.
        """
        if not chunks:
            return None
        embeddings = OpenAIEmbeddings()
        try:
            return FAISS.from_texts(chunks, embeddings)
        except Exception as e:
            print(f"Error creating embeddings: {e}")
            return None
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ openai==0.27.8
2
+ transformers==4.30.2
3
+ langchain==0.0.250
4
+ streamlit==1.25.0
5
+ torch==2.0.1
6
+ faiss-cpu==1.7.4
7
+ tiktoken==0.4.0
8
+ PyPDF2