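"""Utilities for chunking a corpus, summarizing the chunks with an LLM,
and extracting a ranked character list, with simple on-disk caching."""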
import json
import os

from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter

from data_driven_characters.constants import VERBOSE


def generate_docs(corpus, chunk_size, chunk_overlap):
"""Generate docs from a corpus."""
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
docs = text_splitter.create_documents([corpus])
return docs


def load_docs(corpus_path, chunk_size, chunk_overlap):
"""Load the corpus and split it into chunks."""
with open(corpus_path) as f:
corpus = f.read()
docs = generate_docs(corpus, chunk_size, chunk_overlap)
return docs


def generate_corpus_summaries(docs, summary_type="map_reduce"):
    """Generate intermediate chunk summaries of the corpus."""
GPT3 = ChatOpenAI(model_name="gpt-3.5-turbo")
chain = load_summarize_chain(
GPT3, chain_type=summary_type, return_intermediate_steps=True, verbose=True
)
summary = chain({"input_documents": docs}, return_only_outputs=True)
intermediate_summaries = summary["intermediate_steps"]
return intermediate_summaries


def get_corpus_summaries(docs, summary_type, cache_dir, force_refresh=False):
"""Load the corpus summaries from cache or generate them."""
if not os.path.exists(cache_dir) or force_refresh:
os.makedirs(cache_dir, exist_ok=True)
if VERBOSE:
print("Summaries do not exist. Generating summaries.")
intermediate_summaries = generate_corpus_summaries(docs, summary_type)
for i, intermediate_summary in enumerate(intermediate_summaries):
with open(os.path.join(cache_dir, f"summary_{i}.txt"), "w") as f:
f.write(intermediate_summary)
else:
if VERBOSE:
print("Summaries already exist. Loading summaries.")
intermediate_summaries = []
        # count only the summary files, in case other caches share this directory
        num_summaries = len(
            [fname for fname in os.listdir(cache_dir) if fname.startswith("summary_")]
        )
        for i in range(num_summaries):
with open(os.path.join(cache_dir, f"summary_{i}.txt")) as f:
intermediate_summaries.append(f.read())
return intermediate_summaries


def generate_characters(corpus_summaries, num_characters):
    """Get a list of characters from a list of summaries."""
    GPT3 = ChatOpenAI(model_name="gpt-3.5-turbo")
characters_prompt_template = """Consider the following corpus.
---
{corpus_summaries}
---
Give a line-separated list of all the characters, ordered by importance, without punctuation.
"""
    characters = LLMChain(
        llm=GPT3, prompt=PromptTemplate.from_template(characters_prompt_template)
    ).run(corpus_summaries="\n\n".join(corpus_summaries))
    # strip stray (, ), and " characters from each extracted name
    names = [name.strip(' ()"') for name in characters.split("\n")]
    return names[:num_characters]


def get_characters(corpus_summaries, num_characters, cache_dir, force_refresh=False):
    """Load the character list from cache or generate it."""
    cache_file = os.path.join(cache_dir, "characters.json")
    if not os.path.exists(cache_file) or force_refresh:
        os.makedirs(cache_dir, exist_ok=True)
characters = generate_characters(corpus_summaries, num_characters)
with open(cache_file, "w") as f:
json.dump(characters, f)
else:
with open(cache_file, "r") as f:
characters = json.load(f)
return characters
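

# Example usage (a minimal sketch): chains the helpers above into a
# summarize-then-extract pipeline. The corpus path, chunk sizes, and cache
# directory are illustrative assumptions, and OPENAI_API_KEY must be set in
# the environment for the ChatOpenAI calls to succeed.
if __name__ == "__main__":
    docs = load_docs("data/corpus.txt", chunk_size=2048, chunk_overlap=64)
    summaries = get_corpus_summaries(
        docs, summary_type="map_reduce", cache_dir="output/summaries"
    )
    characters = get_characters(
        summaries, num_characters=3, cache_dir="output/summaries"
    )
    print(characters)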