"""Streamlit demo: upload PDF documents, embed their text, and chat with them."""
import sys | |
import os | |
sys.path.append(os.path.abspath('.')) | |
import streamlit as st | |
import time | |
import openai | |
from typing import List, Optional, Tuple, Dict, IO | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import HumanMessage, AIMessage, ChatMessage, FunctionMessage | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain.callbacks import get_openai_callback | |
from backend_utils.file_handlers import FileHandlerFactory | |
from backend_utils.text_processor import DefaultTextProcessor | |
# Maps each selectable model name to the API provider whose key it requires.
MODELS = {
    'gpt-3.5': 'openai',
    'gpt-4': 'openai',
}
# Placeholder credentials at import time; the real key is injected from the
# sidebar via set_api_key() once the user enters it.
openai.api_key = ""
os.environ["OPENAI_API_KEY"]=""
def set_api_key(api_provider, api_key):
    """Store *api_key* for the given provider.

    Writes the key to the provider's environment variable and to the
    openai module-level setting.

    Raises:
        ValueError: if *api_provider* is not a known provider.
    """
    if api_provider != 'openai':
        raise ValueError(f"Unknown API provider: {api_provider}")
    os.environ["OPENAI_API_KEY"] = api_key
    openai.api_key = os.environ["OPENAI_API_KEY"]
def load_chain(selected_model):
    """Build the chat model backing the QA chain.

    Uses gpt-4 when selected, otherwise the ChatOpenAI default model;
    temperature is pinned to 0 in both cases.
    """
    model_kwargs = {"temperature": 0}
    if selected_model == 'gpt-4':
        model_kwargs["model"] = "gpt-4"
    return ChatOpenAI(**model_kwargs)
def answer_question(knowledge_base, user_question, llm):
    """Answer *user_question* from *knowledge_base* with a "stuff" QA chain.

    Parameters:
        knowledge_base: vector store exposing ``similarity_search``.
        user_question: the question text to retrieve against and answer.
        llm: chat model handed to ``load_qa_chain``.

    Returns:
        The chain's answer string, or "" when retrieval or the chain fails
        (errors are printed, not raised).
    """
    try:
        retrieved_docs = knowledge_base.similarity_search(
            user_question,
            k=10
        )
    except Exception as e:
        print(f"Error finding relative chunks: {e}")
        # Return "" (the original returned []) so the caller can always
        # .split() the result without a TypeError.
        return ""
    print(retrieved_docs)
    try:
        chain = load_qa_chain(
            llm,
            chain_type="stuff"
        )
        with get_openai_callback() as callback:
            # NOTE(review): some langchain versions reject the extra
            # max_tokens input key on a "stuff" chain — confirm this kwarg
            # is accepted before relying on it.
            response = chain.run(
                input_documents=retrieved_docs,
                question=user_question,
                max_tokens=50
            )
            print(callback)
        return response
    except Exception as e:
        print(f"Error running QA chain: {e}")
        return ""
def read_files(files: List[IO]) -> Optional[str]: | |
""" | |
Reads the files and returns the combined text. | |
""" | |
combined_text = "" | |
if len(files)==1: | |
file=files[0] | |
if file is not None: | |
file_factory=FileHandlerFactory() | |
handler = file_factory.get_file_handler(file.type) | |
text = handler.read_file(file) | |
if not text: | |
print(f"No text could be extracted from {file.name}. Please ensure the file is not encrypted or corrupted.") | |
return None | |
else: | |
combined_text += text | |
else: | |
for file in files: | |
if file is not None: | |
file_factory=FileHandlerFactory() | |
handler = file_factory.get_file_handler(file.type) | |
text = handler.read_file(file) | |
if not text: | |
print(f"No text could be extracted from {file.name}. Please ensure the file is not encrypted or corrupted.") | |
return None | |
else: | |
combined_text += text | |
return combined_text | |
def chunk_text(combined_text: str) -> "Optional[Tuple[List[str], DefaultTextProcessor]]":
    """Split *combined_text* into 500-character chunks with no overlap.

    The original annotation claimed ``Optional[List[str]]`` but the function
    actually returns a tuple; the annotation is corrected (quoted so it is
    not evaluated eagerly).

    Returns:
        ``(chunks, processor)`` on success — the processor is returned so
        the caller can reuse it for embedding creation — or None, after
        printing a message, when no chunks were produced.
    """
    processor = DefaultTextProcessor(500, 0)
    chunks = processor.split_text(combined_text)
    if not chunks:
        print("Couldn't split the text into chunks. Please try again with different text.")
        return None
    return chunks, processor
def create_embeddings(chunks: List[str], processor) -> Optional[Dict]:
    """Build a knowledge base of embeddings from *chunks* via *processor*.

    Returns:
        Whatever *processor.create_embeddings* produced when it is truthy;
        otherwise prints a message and returns None.
    """
    knowledge_base = processor.create_embeddings(chunks)
    if knowledge_base:
        return knowledge_base
    print("Couldn't create embeddings from the text. Please try again.")
    return None
def load_documents(files):
    """Build a searchable knowledge base from uploaded files.

    Pipeline: read file text -> split into chunks -> embed chunks.

    The original crashed when a stage failed: ``chunk_text(None)`` on a
    read failure, and a TypeError unpacking None on a split failure. Each
    stage's None result now short-circuits to ``return None`` instead.

    Returns:
        The knowledge base object, or None when any stage fails.
    """
    print(files)
    combined_text = read_files(files)
    if combined_text is None:
        return None
    chunked = chunk_text(combined_text)
    if chunked is None:
        return None
    chunks, processor = chunked
    knowledge_base = create_embeddings(chunks, processor)
    print("ALL DONE")
    return knowledge_base
def get_text():
    """Render the chat text-input widget and return its current value."""
    return st.text_input("You: ", "Hello, how are you?", key="input")
if __name__ == "__main__":
    # --- Page and sidebar setup -----------------------------------------
    st.set_page_config(
        page_title="Chat with your documents demo:",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded", )
    # Dropdown to select model
    selected_model = st.sidebar.selectbox("Select a model", list(MODELS.keys()))
    # Input box to enter API key
    api_key = st.sidebar.text_input(f"Enter API key for {MODELS[selected_model]}", type="password")
    # Set the API key for the selected model
    if api_key:
        set_api_key(MODELS[selected_model], api_key)
        llm = load_chain(selected_model)
    # NOTE(review): `llm` is bound only inside the `if api_key:` branch, yet
    # the chat handler below calls answer_question(..., llm) unconditionally.
    # Submitting a question before entering a key raises NameError — confirm
    # and guard.
    # Session flags: whether documents were already embedded, plus the store.
    if "loaded" not in st.session_state:
        st.session_state["loaded"] = False
    if "knowledge_base" not in st.session_state:
        st.session_state["knowledge_base"] = None
    # PDF uploader; accept_multiple_files=True means ResumePDF is a list.
    ResumePDF = st.sidebar.file_uploader(
        "Upload your documents", type=['pdf'], help="Help message goes here", key="uploaded_file", accept_multiple_files=True
    )
    if ResumePDF :
        print("ResumePDF",ResumePDF)
        # Streamlit reruns this whole script on every interaction; the
        # "loaded" flag ensures documents are embedded only once per session.
        if not st.session_state["loaded"]:
            with st.spinner('Loading files π'):
                st.session_state["knowledge_base"] = load_documents(ResumePDF)
                st.session_state["loaded"] = True
    st.header("π Chat with your documents demo:")
    # Seed the conversation with a greeting on the first run.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}]
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    # Handle a newly submitted question.
    if user_input := st.chat_input("What is your question?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            with st.spinner('Thinking ...'):
                # NOTE(review): knowledge_base may still be None if no file
                # was uploaded — answer_question would then fail on
                # similarity_search; confirm the intended UX.
                ai_message=answer_question(st.session_state["knowledge_base"],user_input,llm)
                # ai_message = llm.predict_messages([HumanMessage(content=user_input)])
            # Simulate stream of response with milliseconds delay
            print(ai_message)
            # Re-emit the answer word by word to fake a streaming effect.
            for chunk in ai_message.split():
                full_response += chunk + " "
                time.sleep(0.05)
                # Add a blinking cursor to simulate typing
                message_placeholder.markdown(full_response + "β")
            message_placeholder.markdown(full_response)
        st.session_state.messages.append({"role": "assistant", "content": full_response})