import argparse
from dataclasses import asdict
import json
import os

import streamlit as st

from data_driven_characters.character import get_character_definition
from data_driven_characters.corpus import (
    get_corpus_summaries,
    load_docs,
)
from data_driven_characters.chatbots import (
    SummaryChatBot,
    RetrievalChatBot,
    SummaryRetrievalChatBot,
)
from data_driven_characters.interfaces import CommandLine, Streamlit

OUTPUT_ROOT = "output"


def _build_retrieval_documents(corpus, retrieval_docs, corpus_summaries):
    """Return the documents a retrieval chatbot should search over.

    ``"raw"`` re-chunks the corpus into small pieces (chunk_size=256,
    chunk_overlap=16); ``"summarized"`` reuses the corpus summaries.
    ``retrieval_docs`` is assumed to have been validated by the caller.
    """
    if retrieval_docs == "raw":
        return [
            doc.page_content
            for doc in load_docs(corpus_path=corpus, chunk_size=256, chunk_overlap=16)
        ]
    return corpus_summaries


def create_chatbot(corpus, character_name, chatbot_type, retrieval_docs, summary_type):
    """Build a character chatbot grounded in ``corpus``.

    Args:
        corpus: Path to the corpus file. Its basename without extension names
            the output/cache directory under ``OUTPUT_ROOT``.
        character_name: Name of the character to build a definition for.
        chatbot_type: One of ``"summary"``, ``"retrieval"`` or
            ``"summary_retrieval"``.
        retrieval_docs: ``"raw"`` to retrieve over small raw corpus chunks, or
            ``"summarized"`` to retrieve over the corpus summaries. Only used
            by the retrieval chatbot types, but always validated.
        summary_type: Summarization strategy forwarded to
            ``get_corpus_summaries`` (the CLI offers ``"map_reduce"`` and
            ``"refine"``); also part of the output directory name.

    Returns:
        A ``SummaryChatBot``, ``RetrievalChatBot`` or
        ``SummaryRetrievalChatBot`` instance.

    Raises:
        ValueError: If ``retrieval_docs`` or ``chatbot_type`` is not
            recognized. Raised before any directories are created or any
            summarization work starts.
    """
    # Fail fast on bad arguments, before the summarization/definition steps.
    # retrieval_docs is checked first, so it is the error reported when both
    # arguments are invalid.
    if retrieval_docs not in ("raw", "summarized"):
        raise ValueError(f"Unknown retrieval docs type: {retrieval_docs}")
    if chatbot_type not in ("summary", "retrieval", "summary_retrieval"):
        raise ValueError(f"Unknown chatbot type: {chatbot_type}")

    # Per-corpus, per-summary-type directories passed to the helpers as caches.
    # makedirs creates missing parents, so output_dir itself is created too.
    corpus_name = os.path.splitext(os.path.basename(corpus))[0]
    output_dir = f"{OUTPUT_ROOT}/{corpus_name}/summarytype_{summary_type}"
    summaries_dir = f"{output_dir}/summaries"
    character_definitions_dir = f"{output_dir}/character_definitions"
    for cache_dir in (summaries_dir, character_definitions_dir):
        os.makedirs(cache_dir, exist_ok=True)

    # Larger chunks are used as input to summarization.
    docs = load_docs(corpus_path=corpus, chunk_size=2048, chunk_overlap=64)
    corpus_summaries = get_corpus_summaries(
        docs=docs, summary_type=summary_type, cache_dir=summaries_dir
    )

    character_definition = get_character_definition(
        name=character_name,
        corpus_summaries=corpus_summaries,
        cache_dir=character_definitions_dir,
    )
    # Show the resulting definition on stdout for inspection.
    print(json.dumps(asdict(character_definition), indent=4))

    if chatbot_type == "summary":
        # SummaryChatBot takes no retrieval documents, so don't build them
        # (avoids re-chunking the corpus when retrieval_docs == "raw").
        return SummaryChatBot(character_definition=character_definition)

    documents = _build_retrieval_documents(corpus, retrieval_docs, corpus_summaries)
    chatbot_cls = (
        RetrievalChatBot if chatbot_type == "retrieval" else SummaryRetrievalChatBot
    )
    return chatbot_cls(
        character_definition=character_definition,
        documents=documents,
    )


def main():
    """Parse command-line options, build the chatbot and run the chosen UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--corpus", type=str, default="data/everything_everywhere_all_at_once.txt"
    )
    parser.add_argument("--character_name", type=str, default="Evelyn")
    parser.add_argument(
        "--chatbot_type",
        type=str,
        default="summary_retrieval",
        choices=["summary", "retrieval", "summary_retrieval"],
    )
    parser.add_argument(
        "--summary_type",
        type=str,
        default="map_reduce",
        choices=["map_reduce", "refine"],
    )
    parser.add_argument(
        "--retrieval_docs",
        type=str,
        default="summarized",
        choices=["raw", "summarized"],
    )
    parser.add_argument(
        "--interface", type=str, default="cli", choices=["cli", "streamlit"]
    )
    args = parser.parse_args()

    # Positional arguments for create_chatbot, shared by both interfaces.
    chatbot_args = (
        args.corpus,
        args.character_name,
        args.chatbot_type,
        args.retrieval_docs,
        args.summary_type,
    )

    if args.interface == "cli":
        chatbot = create_chatbot(*chatbot_args)
        app = CommandLine(chatbot=chatbot)
    elif args.interface == "streamlit":
        # Streamlit re-executes the script on each interaction; cache_resource
        # keeps one chatbot per distinct argument tuple instead of rebuilding it.
        chatbot = st.cache_resource(create_chatbot)(*chatbot_args)
        st.title("Data Driven Characters")
        st.write("Create your own character chatbots, grounded in existing corpora.")
        st.divider()
        st.markdown(f"**chatbot type**: *{args.chatbot_type}*")
        if "retrieval" in args.chatbot_type:
            st.markdown(f"**retrieving from**: *{args.retrieval_docs} corpus*")
        app = Streamlit(chatbot=chatbot)
    else:
        raise ValueError(f"Unknown interface: {args.interface}")

    app.run()


if __name__ == "__main__":
    main()