# auto-draft / auto_generators.py
import json
import os.path
from utils.references import References
from utils.knowledge import Knowledge
from utils.file_operations import hash_name, make_archive, copy_templates
from utils.tex_processing import create_copies
from section_generator import section_generation, section_generation_bg  # figures_generation, keywords_generation
from utils.prompts import generate_paper_prompts
import logging
import time
from langchain.vectorstores import FAISS
from utils.gpt_interaction import GPTModel
from utils.prompts import SYSTEM
from models import EMBEDDINGS
TOTAL_TOKENS = 0
TOTAL_PROMPTS_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0
def log_usage(usage, generating_target, print_out=True):
global TOTAL_TOKENS
global TOTAL_PROMPTS_TOKENS
global TOTAL_COMPLETION_TOKENS
prompts_tokens = usage['prompt_tokens']
completion_tokens = usage['completion_tokens']
total_tokens = usage['total_tokens']
TOTAL_TOKENS += total_tokens
TOTAL_PROMPTS_TOKENS += prompts_tokens
TOTAL_COMPLETION_TOKENS += completion_tokens
message = f">>USAGE>> For generating {generating_target}, {total_tokens} tokens have been used " \
f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
f"{TOTAL_TOKENS} tokens have been used in total."
if print_out:
print(message)
logging.info(message)
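# Example call to `log_usage` (the `usage` payload mirrors the token-usage schema
# consumed above; the numbers are illustrative):
#     log_usage({"prompt_tokens": 120, "completion_tokens": 350, "total_tokens": 470}, "introduction")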
def _generation_setup(title, description="", template="ICLR2022",
tldr=False, max_kw_refs=10, refs=None, max_tokens_ref=2048, # generating references
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # querying from knowledge database
debug=True):
"""
This function handles the setup process for paper generation; it contains three folds
1. Copy the template to the outputs folder. Create the log file `generation.log`
2. Collect references based on the given `title` and `description`
3. Generate the basic `paper` object (a dictionary)
Parameters:
title (str): The title of the paper.
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
for the collected papers. Defaults to False.
max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
Defaults to 10.
bib_refs (path to a bibtex file, optional).
Returns:
tuple: A tuple containing the following elements:
- paper (dict): A dictionary containing the generated paper information.
- destination_folder (str): The path to the destination folder where the generation log is saved.
- all_paper_ids (list): A list of all paper IDs collected for the references.
"""
# print("Generation setup...")
# paper = {}
# paper_body = {}
llm = GPTModel(model="gpt-3.5-turbo")
# Create a copy in the outputs folder.
bibtex_path, destination_folder = copy_templates(template, title)
logging.basicConfig(level=logging.INFO, filename=os.path.join(destination_folder, "generation.log"))
###################################################################################################################
# Generate contributions
###################################################################################################################
if description:
contributions = description
else:
try:
contributions, usage = llm(systems=SYSTEM["contributions"], prompts=title, return_json=True)
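            # With `return_json=True`, the model is expected to return a dict whose
            # values carry "statement" and "reason" fields, roughly (a sketch inferred
            # from the keys used below; the exact top-level key names are up to the model):
            #     {"contribution_1": {"statement": "...", "reason": "..."}, ...}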
contributions = [f"Contribution {idx}: {contributions[contribution]['statement']}\n" \
f"Novelty of Contribution {idx}: {contributions[contribution]['reason']}\n"
for idx, contribution in enumerate(contributions)]
contributions = "".join(contributions)
log_usage(usage, "contributions")
except RuntimeError:
if debug:
raise RuntimeError("Failed to generate contributions.")
else:
print("Failed to generate contributions. Use empty contributions.")
contributions = ""
print("Contributions:\n{}".format(contributions))
###################################################################################################################
# Generate references
###################################################################################################################
# input_dict = {"title": title, "description": description}
# keywords, usage = keywords_generation(input_dict)
# log_usage(usage, "keywords")
try:
keywords, usage = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
log_usage(usage, "keywords")
keywords = {keyword: max_kw_refs for keyword in keywords}
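        # The resulting dict maps each LLM-generated keyword to its reference cap,
        # e.g. (keywords illustrative): {"reinforcement learning": 10, "markov games": 10}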
except RuntimeError:
if debug:
raise RuntimeError("Failed to generate keywords.")
else:
print("Failed to generate keywords. Use default keywords.")
keywords = {"machine learning": max_kw_refs, "artificial intelligence": max_kw_refs} # DEFAULT KEYWORDS
print("Keywords: \n", keywords)
# todo: in some rare situations, collected papers will be an empty list. handle this issue
ref = References(title, load_papers=refs)
ref.collect_papers(keywords, tldr=tldr)
references = ref.to_prompts(max_tokens=max_tokens_ref)
all_paper_ids = ref.to_bibtex(bibtex_path)
###################################################################################################################
# Generate domain knowledge
###################################################################################################################
prompts = f"Title: {title}\n Contributions: {contributions}"
preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts)
# check if the database exists or not
db_path = f"knowledge_databases/{knowledge_database}"
db_config_path = os.path.join(db_path, "db_meta.json")
db_index_path = os.path.join(db_path, "faiss_index")
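    # `db_meta.json` is assumed to record, at minimum, which embedding model was used
    # to build the index; only the "embedding_model" key is read below. A minimal
    # sketch (field value illustrative, must be a key of `models.EMBEDDINGS`):
    #     {"embedding_model": "<name of a key in EMBEDDINGS>"}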
if os.path.isdir(db_path):
try:
# load configuration file
with open(db_config_path, "r", encoding="utf-8") as f:
db_config = json.load(f)
model_name = db_config["embedding_model"]
embeddings = EMBEDDINGS[model_name]
db = FAISS.load_local(db_index_path, embeddings)
knowledge = Knowledge(db=db)
knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
domain_knowledge = knowledge.to_prompts(max_tokens_kd)
except Exception as e:
if debug:
raise RuntimeError(f"Failed to query from FAISS. Error {e}.")
else:
print(f"Failed to query from FAISS. Error {e}. Use empty domain knowledge instead.")
domain_knowledge = ""
else:
print("Selected database doesn't exist or no database is selected.")
domain_knowledge = ""
###################################################################################################################
# Generate necessary media
###################################################################################################################
prompts = f"Title: {title}\n Contributions: {contributions}"
try:
components, usage = llm(systems=SYSTEM["components"], prompts=prompts, return_json=True)
log_usage(usage, "media")
except RuntimeError:
if debug:
raise RuntimeError("Failed to generate media.")
else:
print("Failed to generate media. Use default media.")
components = {}
print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
paper = {}
paper_body = {}
paper["title"] = title
paper["description"] = contributions
paper["references"] = references
paper["body"] = paper_body
paper["bibtex"] = bibtex_path
paper["domain_knowledge"] = domain_knowledge
paper["components"] = components
# print(json.dumps(paper, indent=4))
return paper, destination_folder, all_paper_ids
# todo: use `all_paper_ids` to check if all citations are in this list
def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4"):
    # todo: to match the current generation setup
    paper, destination_folder, _ = _generation_setup(title, description, template)
for section in ["introduction", "related works", "backgrounds"]:
try:
usage = section_generation_bg(paper, section, destination_folder, model=model)
log_usage(usage, section)
except Exception as e:
message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}"
print(message)
logging.info(message)
print(f"The paper '{title}' has been generated. Saved to {destination_folder}.")
input_dict = {"title": title, "description": description, "generator": "generate_backgrounds"}
filename = hash_name(input_dict) + ".zip"
return make_archive(destination_folder, filename)
def generate_draft(title, description="", # main input
tldr=True, max_kw_refs=10, refs=None, max_tokens_ref=2048, # references
knowledge_database=None, max_tokens_kd=2048, query_counts=10, # domain knowledge
sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False, # outputs parameters
):
"""
This function generates a draft paper using the provided information; it contains three steps: 1. Pre-processing:
Initializes the setup for paper generation and filters the sections to be included in the paper. 2. Processing:
Generates each section of the paper. 3. Post-processing: Creates backup copies of the paper and returns the paper
in a zipped format.
Parameters:
title (str): The title of the paper.
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
for the collected papers. Defaults to True.
max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
Defaults to 10.
sections (list, optional): The sections to be included in the paper. If not provided, all the standard
sections are included.
bib_refs (path to a bibtex file, optional).
model (str, optional): The language model to be used for paper generation. Defaults to "gpt-4".
Returns:
str: The path to the zipped file containing the generated paper and associated files.
Note: The function also handles errors that occur during section generation and retries a maximum of 4 times
before proceeding.
"""
def _filter_sections(sections):
ordered_sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion",
"abstract"]
return [section for section in ordered_sections if section in sections]
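    # For example, _filter_sections(["abstract", "introduction"]) returns
    # ["introduction", "abstract"]: the requested sections, re-ordered to the
    # canonical paper order above.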
# pre-processing `sections` parameter;
print("================START================")
print(f"Generating the paper '{title}'.")
print("================PRE-PROCESSING================")
# make `sections` in a correct order
if sections is None:
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion",
"abstract"]
else:
sections = _filter_sections(sections)
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, refs,
max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
query_counts=query_counts,
knowledge_database=knowledge_database)
# main components
prompts_dict = {}
print(f"================PROCESSING================")
for section in sections:
prompts = generate_paper_prompts(paper, section)
prompts_dict[section] = prompts
if prompts_mode:
continue
print(f"Generate {section} part...")
max_attempts = 4
attempts_count = 0
while attempts_count < max_attempts:
try:
usage = section_generation(paper, section, destination_folder, model=model)
print(f"{section} part has been generated. ")
log_usage(usage, section)
break
except Exception as e:
message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}\n"
print(message)
logging.info(message)
attempts_count += 1
time.sleep(15)
# post-processing
print("================POST-PROCESSING================")
create_copies(destination_folder)
filename = "prompts.json"
with open(os.path.join(destination_folder, filename), "w") as f:
json.dump(prompts_dict, f)
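    # `prompts.json` maps each section name to the prompt it was generated from,
    # e.g. (values illustrative): {"introduction": "<prompt text>", "abstract": "<prompt text>"}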
print("\nMission completed.\n")
return destination_folder
# return make_archive(destination_folder, filename)
if __name__ == "__main__":
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
target_title = "Playing Atari with Decentralized Reinforcement Learning"
output = generate_draft(target_title, knowledge_database="ml_textbook_test")
print(output)
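    # A fuller invocation with every knob spelled out (values illustrative):
    # output = generate_draft(target_title, description="", tldr=True, max_kw_refs=10,
    #                         refs=None, max_tokens_ref=2048, knowledge_database="ml_textbook_test",
    #                         max_tokens_kd=2048, query_counts=10, sections=None,
    #                         model="gpt-4", template="ICLR2022", prompts_mode=False)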