import sys
import os

# Make the repo root importable so the backend_utils package resolves
sys.path.append(os.path.abspath('.'))

import streamlit as st
import time
import openai
from typing import List, Optional, Tuple, Dict, IO

from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks import get_openai_callback
from backend_utils.file_handlers import FileHandlerFactory
from backend_utils.text_processor import DefaultTextProcessor
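
# Streamlit demo app: upload PDFs, chunk and embed their text, then answer
# questions over the retrieved chunks with a LangChain "stuff" QA chain.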


MODELS = {
    'gpt-3.5': 'openai',
    'gpt-4': 'openai',
}

# Placeholder values; the real key is supplied through the sidebar at runtime.
openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ""

def set_api_key(api_provider, api_key):
    """
    Set the API key in the respective environment variable
    """
    if api_provider == 'openai':
        os.environ["OPENAI_API_KEY"] = api_key
        openai.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(f"Unknown API provider: {api_provider}")

def load_chain(selected_model):
    """Instantiate the chat model for the selected model name."""
    if selected_model == 'gpt-4':
        llm = ChatOpenAI(temperature=0, model="gpt-4")
    else:
        # ChatOpenAI defaults to gpt-3.5-turbo
        llm = ChatOpenAI(temperature=0)
    return llm

def answer_question(knowledge_base, user_question, llm):
    """Retrieve the most relevant chunks and answer the question over them."""
    try:
        retrieved_docs = knowledge_base.similarity_search(user_question, k=10)
    except Exception as e:
        print(f"Error finding relevant chunks: {e}")
        # Return a string so the caller's .split() doesn't fail
        return ""
    print(retrieved_docs)
    try:
        # "stuff" packs all retrieved chunks into a single prompt
        chain = load_qa_chain(llm, chain_type="stuff")
        with get_openai_callback() as callback:
            # max_tokens is not a chain input; cap output length on ChatOpenAI instead
            response = chain.run(
                input_documents=retrieved_docs,
                question=user_question,
            )
            print(callback)
        return response
    except Exception as e:
        print(f"Error running QA chain: {e}")
        return ""


def read_files(files: List[IO]) -> Optional[str]:
    """Reads the uploaded files and returns the combined text."""
    combined_text = ""
    file_factory = FileHandlerFactory()
    for file in files:
        if file is None:
            continue
        handler = file_factory.get_file_handler(file.type)
        text = handler.read_file(file)
        if not text:
            print(f"No text could be extracted from {file.name}. Please ensure the file is not encrypted or corrupted.")
            return None
        combined_text += text
    return combined_text

def chunk_text(combined_text: str) -> Optional[Tuple[List[str], DefaultTextProcessor]]:
    """Split the combined text into chunks, returning them with the processor."""
    processor = DefaultTextProcessor(500, 0)
    chunks = processor.split_text(combined_text)
    if not chunks:
        print("Couldn't split the text into chunks. Please try again with different text.")
        return None
    return chunks, processor

def create_embeddings(chunks: List[str], processor) -> Optional[Dict]:
    """
    Takes chunks and creates embeddings in a knowledge base.
    """
    knowledge_base = processor.create_embeddings(chunks)
    if not knowledge_base:
        print("Couldn't create embeddings from the text. Please try again.")
        return None
    return knowledge_base

def load_documents(files):
    """Full pipeline: read the files, chunk the text, and embed the chunks."""
    combined_text = read_files(files)
    if combined_text is None:
        return None
    chunked = chunk_text(combined_text)
    if chunked is None:
        return None
    chunks, processor = chunked
    knowledge_base = create_embeddings(chunks, processor)
    print("ALL DONE")
    return knowledge_base
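
# NOTE: get_text below is currently unused; the chat UI reads input via st.chat_input.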
def get_text():
    input_text = st.text_input("You: ", "Hello, how are you?", key="input")
    return input_text


if __name__ == "__main__":
    st.set_page_config(
        page_title="Chat with your documents demo:",
        page_icon="📖",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    # Dropdown to select model
    selected_model = st.sidebar.selectbox("Select a model", list(MODELS.keys()))

    # Input box to enter API key
    api_key = st.sidebar.text_input(f"Enter API key for {MODELS[selected_model]}", type="password")

    # Set the API key for the selected model
    if api_key:
        set_api_key(MODELS[selected_model], api_key)

    llm = load_chain(selected_model)
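
    # st.session_state survives Streamlit reruns, so documents are ingested only once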
    if "loaded" not in st.session_state:
        st.session_state["loaded"] = False
    if "knowledge_base" not in st.session_state:
        st.session_state["knowledge_base"] = None

    uploaded_files = st.sidebar.file_uploader(
        "Upload your documents",
        type=['pdf'],
        help="Upload one or more PDF files to chat with.",
        key="uploaded_file",
        accept_multiple_files=True,
    )
    if uploaded_files:
        print("uploaded_files", uploaded_files)
        # NOTE: once loaded, new uploads are ignored until the session restarts
        if not st.session_state["loaded"]:
            with st.spinner('Loading files 📖'):
                st.session_state["knowledge_base"] = load_documents(uploaded_files)
                st.session_state["loaded"] = True

        st.header("📖 Chat with your documents demo:")

        if "messages" not in st.session_state:
            st.session_state["messages"] = [
                {"role": "assistant", "content": "How can I help you?"}]

        # Display chat messages from history on app rerun
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if user_input := st.chat_input("What is your question?"):
            # Add user message to chat history
            st.session_state.messages.append({"role": "user", "content": user_input})
            # Display user message in chat message container
            with st.chat_message("user"):
                st.markdown(user_input)

            with st.chat_message("assistant"):
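                # Empty placeholder that the simulated stream overwrites in place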
                message_placeholder = st.empty()
                full_response = ""

                with st.spinner('Thinking ...'):
                    ai_message = answer_question(st.session_state["knowledge_base"], user_input, llm)
                    print(ai_message)
                    # Simulate a streaming response with a short delay per word
                    for chunk in ai_message.split():
                        full_response += chunk + " "
                        time.sleep(0.05)
                        # Add a blinking cursor to simulate typing
                        message_placeholder.markdown(full_response + "▌")
                    message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})