from dataclasses import asdict from io import StringIO import json import os import streamlit as st from data_driven_characters.character import generate_character_definition, Character from data_driven_characters.corpus import ( generate_corpus_summaries, generate_docs, ) from data_driven_characters.chatbots import ( SummaryChatBot, RetrievalChatBot, SummaryRetrievalChatBot, ) from data_driven_characters.interfaces import reset_chat, clear_user_input, converse @st.cache_resource() def create_chatbot(character_definition, corpus_summaries, chatbot_type): if chatbot_type == "summary": chatbot = SummaryChatBot(character_definition=character_definition) elif chatbot_type == "retrieval": chatbot = RetrievalChatBot( character_definition=character_definition, documents=corpus_summaries, ) elif chatbot_type == "summary with retrieval": chatbot = SummaryRetrievalChatBot( character_definition=character_definition, documents=corpus_summaries, ) else: raise ValueError(f"Unknown chatbot type: {chatbot_type}") return chatbot @st.cache_data(persist="disk") def process_corpus(corpus): # load docs docs = generate_docs( corpus=corpus, chunk_size=2048, chunk_overlap=64, ) # generate summaries corpus_summaries = generate_corpus_summaries(docs=docs, summary_type="map_reduce") return corpus_summaries @st.cache_data(persist="disk") def get_character_definition(name, corpus_summaries): character_definition = generate_character_definition( name=name, corpus_summaries=corpus_summaries, ) return asdict(character_definition) def main(): st.title("Data-Driven Characters") st.write( "Upload a corpus in the sidebar to generate a character chatbot that is grounded in the corpus content." ) openai_api_key = st.text_input( label="Your OpenAI API KEY", placeholder="Your OpenAI API KEY", type="password", ) os.environ["OPENAI_API_KEY"] = openai_api_key with st.sidebar: uploaded_file = st.file_uploader("Upload corpus") if uploaded_file is not None: corpus_name = os.path.splitext(os.path.basename(uploaded_file.name))[0] # read file stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) corpus = stringio.read() # scrollable text st.markdown( f"""