import json
import os

from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter

from data_driven_characters.constants import VERBOSE

def generate_docs(corpus, chunk_size, chunk_overlap):
    """Split a corpus into (possibly overlapping) document chunks."""
    # chunk sizes are measured in tokens via the tiktoken encoder
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.create_documents([corpus])
    return docs

def load_docs(corpus_path, chunk_size, chunk_overlap):
    """Load the corpus and split it into chunks."""
    with open(corpus_path) as f:
        corpus = f.read()
    docs = generate_docs(corpus, chunk_size, chunk_overlap)
    return docs

def generate_corpus_summaries(docs, summary_type="map_reduce"):
    """Generate intermediate (per-chunk) summaries of the story."""
    GPT3 = ChatOpenAI(model_name="gpt-3.5-turbo")
    chain = load_summarize_chain(
        GPT3, chain_type=summary_type, return_intermediate_steps=True, verbose=True
    )
    summary = chain({"input_documents": docs}, return_only_outputs=True)
    # keep the per-chunk (map-step) summaries rather than the final combined one
    intermediate_summaries = summary["intermediate_steps"]
    return intermediate_summaries

def get_corpus_summaries(docs, summary_type, cache_dir, force_refresh=False):
    """Load the corpus summaries from cache or generate them."""
    if not os.path.exists(cache_dir) or force_refresh:
        os.makedirs(cache_dir, exist_ok=True)
        if VERBOSE:
            print("Summaries do not exist. Generating summaries.")
        intermediate_summaries = generate_corpus_summaries(docs, summary_type)
        # cache each chunk summary as summary_{i}.txt, preserving chunk order
        for i, intermediate_summary in enumerate(intermediate_summaries):
            with open(os.path.join(cache_dir, f"summary_{i}.txt"), "w") as f:
                f.write(intermediate_summary)
    else:
        if VERBOSE:
            print("Summaries already exist. Loading summaries.")
        intermediate_summaries = []
        # read the cached summaries back in index order
        for i in range(len(os.listdir(cache_dir))):
            with open(os.path.join(cache_dir, f"summary_{i}.txt")) as f:
                intermediate_summaries.append(f.read())
    return intermediate_summaries

def generate_characters(corpus_summaries, num_characters):
    """Get a list of characters from a list of summaries."""
    GPT3 = ChatOpenAI(model_name="gpt-3.5-turbo")
    characters_prompt_template = """Consider the following corpus.
---
{corpus_summaries}
---
Give a line-separated list of all the characters, ordered by importance, without punctuation.
"""
    characters = LLMChain(
        llm=GPT3, prompt=PromptTemplate.from_template(characters_prompt_template)
    ).run(corpus_summaries="\n\n".join(corpus_summaries))
    # strip stray parentheses and quotes from each name and drop blank lines
    names = [c.strip(' ()"') for c in characters.split("\n") if c.strip()]
    return names[:num_characters]

def get_characters(corpus_summaries, num_characters, cache_dir, force_refresh=False):
    """Load the character list from cache or generate it."""
    cache_file = os.path.join(cache_dir, "characters.json")
    if not os.path.exists(cache_file) or force_refresh:
        characters = generate_characters(corpus_summaries, num_characters)
        os.makedirs(cache_dir, exist_ok=True)  # ensure the cache directory exists
        with open(cache_file, "w") as f:
            json.dump(characters, f)
    else:
        with open(cache_file, "r") as f:
            characters = json.load(f)
    return characters
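

# Minimal end-to-end usage sketch. The corpus path, chunk sizes, character
# count, and cache directories below are hypothetical placeholders, and
# OPENAI_API_KEY must be set in the environment for the ChatOpenAI calls.
if __name__ == "__main__":
    # split the raw transcript into token-sized chunks
    docs = load_docs("data/transcript.txt", chunk_size=2048, chunk_overlap=64)
    # summarize each chunk (cached under output/summaries as summary_{i}.txt)
    summaries = get_corpus_summaries(
        docs, summary_type="map_reduce", cache_dir="output/summaries"
    )
    # extract the most important characters (cached as output/characters.json)
    characters = get_characters(
        summaries, num_characters=3, cache_dir="output", force_refresh=False
    )
    print(characters)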