# HuggingDocsLLM / app.py
import sys
import os

# Make the local backend_utils package importable when run from this directory.
sys.path.append(os.path.abspath('.'))

import time
from typing import IO, List, Optional, Tuple

import openai
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.callbacks import get_openai_callback

from backend_utils.file_handlers import FileHandlerFactory
from backend_utils.text_processor import DefaultTextProcessor
MODELS = {
    'gpt-3.5': 'openai',
    'gpt-4': 'openai',
}

# Placeholders; the real key is supplied at runtime through the sidebar.
openai.api_key = ""
os.environ["OPENAI_API_KEY"] = ""

def set_api_key(api_provider, api_key):
    """
    Set the API key in the environment variable for the given provider.
    """
    if api_provider == 'openai':
        os.environ["OPENAI_API_KEY"] = api_key
        openai.api_key = os.environ["OPENAI_API_KEY"]
    else:
        raise ValueError(f"Unknown API provider: {api_provider}")
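
# Usage sketch (the key below is a placeholder, not a real credential):
#
#   set_api_key("openai", "sk-...")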

def load_chain(selected_model):
    """Return the chat LLM for the selected model (despite the name, no chain is built here)."""
    if selected_model == 'gpt-4':
        llm = ChatOpenAI(temperature=0, model="gpt-4")
    else:
        # ChatOpenAI defaults to gpt-3.5-turbo.
        llm = ChatOpenAI(temperature=0)
    return llm
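
# Note: the key could also be passed to the constructor directly instead of
# through the environment, e.g. (placeholder key):
#
#   llm = ChatOpenAI(temperature=0, model="gpt-4", openai_api_key="sk-...")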

def answer_question(knowledge_base, user_question, llm):
    """Retrieve the most similar chunks and run a 'stuff' QA chain over them."""
    try:
        retrieved_docs = knowledge_base.similarity_search(
            user_question,
            k=10
        )
    except Exception as e:
        print(f"Error finding relevant chunks: {e}")
        return ""
    print(retrieved_docs)
    try:
        chain = load_qa_chain(
            llm,
            chain_type="stuff"
        )
        with get_openai_callback() as callback:
            # chain.run() only accepts the chain's input keys, so passing a
            # max_tokens kwarg here would raise; cap output length on the LLM
            # itself if needed.
            response = chain.run(
                input_documents=retrieved_docs,
                question=user_question
            )
        print(callback)
        return response
    except Exception as e:
        print(f"Error running QA chain: {e}")
        return ""

def read_files(files: List[IO]) -> Optional[str]:
    """
    Reads the files and returns the combined text.
    """
    combined_text = ""
    file_factory = FileHandlerFactory()
    for file in files:
        if file is None:
            continue
        handler = file_factory.get_file_handler(file.type)
        text = handler.read_file(file)
        if not text:
            print(f"No text could be extracted from {file.name}. Please ensure the file is not encrypted or corrupted.")
            return None
        combined_text += text
    return combined_text
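
# FileHandlerFactory lives in backend_utils and is not shown in this file. A
# minimal sketch of what its PDF handler might look like, assuming PyPDF2 (the
# real implementation may differ):
#
#   from PyPDF2 import PdfReader
#
#   class PDFHandler:
#       def read_file(self, file):
#           reader = PdfReader(file)
#           return "".join(page.extract_text() or "" for page in reader.pages)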

def chunk_text(combined_text: str) -> Optional[Tuple[List[str], DefaultTextProcessor]]:
    """Split the text into 500-character chunks with no overlap."""
    processor = DefaultTextProcessor(500, 0)
    chunks = processor.split_text(combined_text)
    if not chunks:
        print("Couldn't split the text into chunks. Please try again with different text.")
        return None
    return chunks, processor

def create_embeddings(chunks: List[str], processor: DefaultTextProcessor):
    """
    Takes chunks and creates embeddings in a knowledge base (a vector store),
    or returns None on failure.
    """
    knowledge_base = processor.create_embeddings(chunks)
    if not knowledge_base:
        print("Couldn't create embeddings from the text. Please try again.")
        return None
    return knowledge_base
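
# DefaultTextProcessor is defined in backend_utils and not shown here. A
# minimal sketch, under the assumption that it wraps LangChain's character
# splitter and a FAISS vector store (500/0 match the call in chunk_text):
#
#   from langchain.text_splitter import CharacterTextSplitter
#   from langchain.embeddings.openai import OpenAIEmbeddings
#   from langchain.vectorstores import FAISS
#
#   class DefaultTextProcessor:
#       def __init__(self, chunk_size, chunk_overlap):
#           self.splitter = CharacterTextSplitter(
#               chunk_size=chunk_size, chunk_overlap=chunk_overlap)
#
#       def split_text(self, text):
#           return self.splitter.split_text(text)
#
#       def create_embeddings(self, chunks):
#           # Embeds chunks with OpenAI and indexes them in FAISS.
#           return FAISS.from_texts(chunks, OpenAIEmbeddings())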

def load_documents(files):
    print(files)
    combined_text = read_files(files)
    if not combined_text:
        return None
    result = chunk_text(combined_text)
    if result is None:
        return None
    chunks, processor = result
    knowledge_base = create_embeddings(chunks, processor)
    print("ALL DONE")
    return knowledge_base

def get_text():
    # Currently unused; the chat UI below uses st.chat_input instead.
    input_text = st.text_input("You: ", "Hello, how are you?", key="input")
    return input_text
if __name__ == "__main__":
st.set_page_config(
page_title="Chat with your documents demo:",
page_icon="πŸ“–",
layout="wide",
initial_sidebar_state="expanded", )
# Dropdown to select model
selected_model = st.sidebar.selectbox("Select a model", list(MODELS.keys()))
# Input box to enter API key
api_key = st.sidebar.text_input(f"Enter API key for {MODELS[selected_model]}", type="password")
# Set the API key for the selected model
if api_key:
set_api_key(MODELS[selected_model], api_key)
llm = load_chain(selected_model)
if "loaded" not in st.session_state:
st.session_state["loaded"] = False
if "knowledge_base" not in st.session_state:
st.session_state["knowledge_base"] = None
ResumePDF = st.sidebar.file_uploader(
"Upload your documents", type=['pdf'], help="Help message goes here", key="uploaded_file", accept_multiple_files=True
)
if ResumePDF :
print("ResumePDF",ResumePDF)
if not st.session_state["loaded"]:
with st.spinner('Loading files πŸ“–'):
st.session_state["knowledge_base"] = load_documents(ResumePDF)
st.session_state["loaded"] = True
    st.header("📖 Chat with your documents demo:")
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            {"role": "assistant", "content": "How can I help you?"}]
    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if user_input := st.chat_input("What is your question?"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""
            if st.session_state["knowledge_base"] is None:
                # Guard against asking before any document has been indexed.
                ai_message = "Please upload a document first."
            else:
                with st.spinner('Thinking ...'):
                    ai_message = answer_question(st.session_state["knowledge_base"], user_input, llm)
            # Simulate a stream of the response with a small delay per word
            print(ai_message)
            for chunk in ai_message.split():
                full_response += chunk + " "
                time.sleep(0.05)
                # Add a blinking cursor to simulate typing
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)
            st.session_state.messages.append({"role": "assistant", "content": full_response})
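
# Run locally with:
#
#   streamlit run app.py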