Spaces:
Build error
Build error
from dataclasses import dataclass | |
from typing import Literal | |
import streamlit as st | |
from langchain import OpenAI | |
from langchain.callbacks import get_openai_callback | |
from langchain.chains import ConversationChain | |
from langchain.chains.conversation.memory import ConversationSummaryMemory | |
import streamlit.components.v1 as components | |
import streamlit as st | |
from langchain.chat_models import ChatOpenAI | |
from langchain.document_loaders import PyPDFLoader | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
from langchain.embeddings.cohere import CohereEmbeddings | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch | |
from langchain.vectorstores import Chroma | |
from langchain.chat_models import ChatOpenAI | |
from langchain.chains import RetrievalQAWithSourcesChain | |
from dotenv import load_dotenv | |
import os | |
import openai | |
import time | |
# Load environment variables through python-dotenv before anything reads them.
load_dotenv()

# Re-export the value stored under "my_secret" as OPENAI_API_KEY (presumably the
# name the OpenAI/LangChain clients look up — verify). A missing "my_secret"
# raises KeyError here, at import time.
_openai_api_key = os.environ['my_secret']
os.environ["OPENAI_API_KEY"] = _openai_api_key
@dataclass
class Message:
    """A single chat message kept in the conversation history.

    The ``@dataclass`` decorator generates the positional ``__init__`` that
    callers use, e.g. ``Message("ai", text)``; without it the class accepts
    no constructor arguments.
    """

    # Author of the message; either "human" or "ai".
    origin: Literal["human", "ai"]
    # The message text.
    message: str
def load_css():
    """Inject the contents of ``styles.css`` into the page inside a ``<style>`` tag.

    Raises:
        OSError: If ``styles.css`` cannot be opened. The path is relative, so
            it is resolved against the current working directory.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open("styles.css", "r", encoding="utf-8") as f:
        css = f"<style>{f.read()}</style>"
    # The raw <style> tag is only passed through when HTML is allowed.
    st.markdown(css, unsafe_allow_html=True)
def initialize_session_state():
    """Fill ``st.session_state`` with defaults, touching only missing keys.

    Keys set when absent:
        history: list seeded with one AI greeting ``Message``.
        token_count: integer counter starting at 0.
        chain: a ``RetrievalQAWithSourcesChain`` built over
            ``/home/user/app/docs.pdf``.
    """
    state = st.session_state

    if "token_count" not in state:
        state.token_count = 0

    if "history" not in state:
        # NOTE(review): this greeting looks like UTF-8 Korean text that was
        # decoded with the wrong code page; it is reproduced unchanged —
        # confirm against the original file.
        state.history = [Message("ai", "์๋ ํ์ธ์, ์ด๋ป๊ฒ ๋์๋๋ฆด๊น์?")]

    # Everything below builds the chain, so skip it once a chain is stored.
    if 'chain' in state:
        return

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    # Load the PDF, split it into 800-character chunks without overlap,
    # embed the chunks and store them in Chroma.
    pages = PyPDFLoader("/home/user/app/docs.pdf").load()
    splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    chunks = splitter.split_documents(pages)
    store = Chroma.from_documents(chunks, OpenAIEmbeddings())
    retriever = store.as_retriever(search_kwargs={"k": 3})

    # Imported here, as in the original, rather than at module top.
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        HumanMessagePromptTemplate,
        SystemMessagePromptTemplate,
    )

    # {summaries} and {question} are template placeholders, left for the
    # prompt templates below to fill in.
    system_template = """You act like a successful teacher. Talk to students about the career path of a teacher.
Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer.
If someone talks about something that isn't job-related, just say "I don't know" and don't try to make up an answer.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format"""
    prompt = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(system_template),
        HumanMessagePromptTemplate.from_template("{question}"),
    ])

    state['chain'] = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=True,
        reduce_k_below_max_tokens=True,
        verbose=True,
    )
def generate_response(user_input):
    """Run the chain stored in session state on ``user_input``.

    Args:
        user_input: The question handed to the chain.

    Returns:
        The value stored under ``'answer'`` in the chain's result.
    """
    qa_chain = st.session_state['chain']
    return qa_chain(user_input)['answer']
def on_click_callback():
    """Handle a form submit: ask the chain and record the exchange.

    Reads the prompt from ``st.session_state.human_prompt``, appends the human
    message followed by the AI reply to ``st.session_state.history``, and adds
    the callback's ``total_tokens`` to ``st.session_state.token_count``.
    """
    state = st.session_state
    with get_openai_callback() as usage:
        question = state.human_prompt
        answer = generate_response(question)
        # Both messages are added only after the chain call has returned, so
        # a failed call leaves the history untouched.
        state.history.extend(
            [Message("human", question), Message("ai", answer)]
        )
        state.token_count += usage.total_tokens
# --- Page setup -------------------------------------------------------------
# Inject the stylesheet, then make sure history, token count and chain exist
# in session state before the layout below reads them.
load_css()
initialize_session_state()

# NOTE(review): the title looks like UTF-8 Korean text that was decoded with
# the wrong code page; it is reproduced unchanged — confirm against the
# original file.
st.title("๊ต์ฌ์ ์ง๋ก์๋ด์ ํด๋ณด์ธ์, \n ์ค์ ์ธํฐ๋ทฐ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค. ๐ค")

# Layout slots, created in display order: chat history, input form, and an
# empty slot that is only referenced by commented-out debug code.
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()
# --- Chat history rendering ---------------------------------------------------
# Messages are interpolated into raw HTML that is rendered with
# unsafe_allow_html=True, so the text (user input and model output alike) is
# HTML-escaped first; otherwise a message containing markup would be injected
# into the page.
import html

_ICON_BASE_URL = "https://cdn-icons-png.flaticon.com/"
_CHAT_ROW_CLOSE = "\n</div>\n</div>\n"


def _chat_row_open(origin):
    """Return the opening markup of one chat row, up to where the text goes."""
    is_ai = origin == 'ai'
    # Any origin other than 'ai' is laid out as a human message.
    row_class = "chat-row" if is_ai else "chat-row row-reverse"
    bubble_class = "chat-bubble ai-bubble" if is_ai else "chat-bubble human-bubble"
    # The paths carry no leading slash: the base URL already ends with one
    # (the AI icon previously produced ".com//512/...").
    icon_path = '512/3058/3058838.png' if is_ai else '512/1177/1177568.png'
    # The markup lines are deliberately unindented, and the bubble starts with
    # the zero-width space (U+200B) that the original markup placed there.
    return (
        f'\n<div class="{row_class}">\n'
        f'<img class="chat-icon" src="{_ICON_BASE_URL}{icon_path}"\n'
        'width=32 height=32>\n'
        f'<div class="{bubble_class}">\n'
        '\u200b'
    )


def _chat_row_html(origin, text):
    """Return the complete markup for one chat row with ``text`` escaped."""
    # quote=False: the text lands in element content, not in an attribute.
    return _chat_row_open(origin) + html.escape(text, quote=False) + _CHAT_ROW_CLOSE


# NOTE(review): the source had lost its indentation; nesting everything below
# under `with chat_placeholder:` is a reconstruction — confirm.
with chat_placeholder:
    # Every message except the newest is rendered in one go.
    for chat in st.session_state.history[:-1]:
        st.markdown(
            _chat_row_html(chat.origin, chat.message),
            unsafe_allow_html=True,
        )

    # The newest message is revealed one character at a time by repeatedly
    # overwriting a single placeholder.
    if st.session_state.history:
        last_chat = st.session_state.history[-1]
        row_open = _chat_row_open(last_chat.origin)
        new_placeholder = st.empty()
        for shown in range(1, len(last_chat.message) + 1):
            # Escape each prefix separately so a slice never cuts an HTML
            # entity in half.
            visible = html.escape(last_chat.message[:shown], quote=False)
            new_placeholder.markdown(
                row_open + visible + _CHAT_ROW_CLOSE,
                unsafe_allow_html=True,
            )
            time.sleep(0.05)

    # Three empty markdown elements after the last message.
    for _ in range(3):
        st.markdown("")
# --- Prompt form --------------------------------------------------------------
with prompt_placeholder:
    st.markdown("**Chat**")
    # Two columns with relative widths 6 and 1: text input, then submit button.
    input_col, submit_col = st.columns((6, 1))
    # The widget value is stored under the "human_prompt" key, which
    # on_click_callback reads from session state.
    # NOTE(review): the default text looks like UTF-8 Korean text that was
    # decoded with the wrong code page; it is reproduced unchanged — confirm
    # against the original file.
    input_col.text_input(
        "Chat",
        key="human_prompt",
        value="๊ต์ฌ๊ฐ ๋๋ ค๋ฉด ๋ฌด์์ ํด์ผ ํ๋์?",
        label_visibility="collapsed",
    )
    submit_col.form_submit_button(
        "Submit",
        on_click=on_click_callback,
        type="primary",
    )
# credit_card_placeholder.caption(f""" | |
# Used {st.session_state.token_count} tokens \n | |
# Debug Langchain conversation: | |
# {st.session_state.chain.memory.buffer} | |
# """) | |
# --- Enter-to-submit helper ---------------------------------------------------
# A zero-size HTML component whose script reaches into the parent document,
# looks for the '.stButton > button' whose text is 'Submit', and clicks it
# whenever Enter is pressed anywhere in that document.
# NOTE(review): if no button matches, `submitButton` is undefined and the
# click() call throws inside the key handler.
_ENTER_TO_SUBMIT_HTML = """
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
switch (e.key) {
case 'Enter':
submitButton.click();
break;
}
});
</script>
"""
components.html(_ENTER_TO_SUBMIT_HTML, height=0, width=0)