import json
import os.path
from utils.references import References
from utils.knowledge import Knowledge
from utils.file_operations import hash_name, make_archive, copy_templates
from utils.tex_processing import create_copies
from section_generator import section_generation, section_generation_bg  # figures_generation, keywords_generation
from utils.prompts import generate_paper_prompts
import logging
import time
from langchain.vectorstores import FAISS
from utils.gpt_interaction import GPTModel
from utils.prompts import SYSTEM
from models import EMBEDDINGS
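# `EMBEDDINGS` is assumed to map embedding-model names to LangChain embedding
# instances (inferred from its use with FAISS in `_generation_setup` below).
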
# Global token-usage counters, updated by `log_usage` after every model call.
TOTAL_TOKENS = 0
TOTAL_PROMPTS_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0


def log_usage(usage, generating_target, print_out=True):
    """Accumulate the global token counters and log token usage for `generating_target`."""
    global TOTAL_TOKENS
    global TOTAL_PROMPTS_TOKENS
    global TOTAL_COMPLETION_TOKENS

    prompts_tokens = usage['prompt_tokens']
    completion_tokens = usage['completion_tokens']
    total_tokens = usage['total_tokens']

    TOTAL_TOKENS += total_tokens
    TOTAL_PROMPTS_TOKENS += prompts_tokens
    TOTAL_COMPLETION_TOKENS += completion_tokens

    message = f">>USAGE>> For generating {generating_target}, {total_tokens} tokens have been used " \
              f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
              f"{TOTAL_TOKENS} tokens have been used in total."
    if print_out:
        print(message)
    logging.info(message)
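
# Example call for `log_usage` above (OpenAI-style usage dict; numbers are illustrative):
#     log_usage({"prompt_tokens": 800, "completion_tokens": 200, "total_tokens": 1000},
#               "introduction")
# This adds 1000 tokens to the global counters and logs the usage message.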


def _generation_setup(title, description="", template="ICLR2022",
                      tldr=False, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048,  # generating references
                      knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # querying the knowledge database
                      debug=True):
    """
    This function handles the setup process for paper generation; it consists of three steps:
        1. Copy the template to the outputs folder and create the log file `generation.log`.
        2. Collect references based on the given `title` and `description`.
        3. Generate the basic `paper` object (a dictionary).

    Parameters:
        title (str): The title of the paper.
        description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
        template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
        tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
                               for the collected papers. Defaults to False.
        max_kw_refs (int, optional): The maximum number of references collected for each keyword. Defaults to 10.
        bib_refs (str, optional): Path to a BibTeX file of user-provided references. Defaults to None.
        max_tokens_ref (int, optional): The token budget for the references prompt. Defaults to 2048.
        knowledge_database (str, optional): Name of a local FAISS database under `knowledge_databases/`.
                                            Defaults to None.
        max_tokens_kd (int, optional): The token budget for the domain-knowledge prompt. Defaults to 2048.
        query_counts (int, optional): The maximum number of queries sent to the knowledge database. Defaults to 10.
        debug (bool, optional): If True, raise on generation failures instead of falling back to defaults.
                                Defaults to True.

    Returns:
        tuple: A tuple containing the following elements:
            - paper (dict): A dictionary containing the generated paper information.
            - destination_folder (str): The path to the destination folder where the generation log is saved.
            - all_paper_ids (list): A list of all paper IDs collected for the references.
    """
# print("Generation setup...")
# paper = {}
# paper_body = {}
llm = GPTModel()
# Create a copy in the outputs folder.
bibtex_path, destination_folder = copy_templates(template, title)
logging.basicConfig(level=logging.INFO, filename=os.path.join(destination_folder, "generation.log"))

    ###################################################################################################################
    # Generate contributions
    ###################################################################################################################
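    # `SYSTEM["contributions"]` is expected to return JSON shaped roughly like
    # {"<contribution>": {"statement": "...", "reason": "..."}, ...}
    # (shape inferred from the parsing below; not a documented schema).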
    if description:
        contributions = description
    else:
        try:
            contributions, usage = llm(systems=SYSTEM["contributions"], prompts=title, return_json=True)
            contributions = [f"Contribution {idx}: {contributions[contribution]['statement']}\n"
                             f"Novelty of Contribution {idx}: {contributions[contribution]['reason']}\n"
                             for idx, contribution in enumerate(contributions)]
            contributions = "".join(contributions)
            log_usage(usage, "contributions")
        except RuntimeError:
            if debug:
                raise RuntimeError("Failed to generate contributions.")
            else:
                print("Failed to generate contributions. Use empty contributions.")
                contributions = ""
    print("Contributions:\n{}".format(contributions))

    ###################################################################################################################
    # Generate references
    ###################################################################################################################
    # input_dict = {"title": title, "description": description}
    # keywords, usage = keywords_generation(input_dict)
    # log_usage(usage, "keywords")
    try:
        keywords, usage = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
        log_usage(usage, "keywords")
        # generate the keywords dictionary: {keyword: number of references to collect}
        keywords = {keyword: max_kw_refs for keyword in keywords}
    except RuntimeError:
        if debug:
            raise RuntimeError("Failed to generate keywords.")
        else:
            print("Failed to generate keywords. Use default keywords.")
            keywords = {"machine learning": max_kw_refs, "artificial intelligence": max_kw_refs}  # DEFAULT KEYWORDS
    print("Keywords: \n", keywords)

    # todo: in some rare situations, collected papers will be an empty list. handle this issue
    ref = References(title, bib_refs)
    ref.collect_papers(keywords, tldr=tldr)
    references = ref.to_prompts(max_tokens=max_tokens_ref)
    all_paper_ids = ref.to_bibtex(bibtex_path)

    ###################################################################################################################
    # Generate domain knowledge
    ###################################################################################################################
    prompts = f"Title: {title}\n Contributions: {contributions}"
    preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts)
    # check if the database exists or not
    db_path = f"knowledge_databases/{knowledge_database}"
    db_config_path = os.path.join(db_path, "db_meta.json")
    db_index_path = os.path.join(db_path, "faiss_index")
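    # `db_meta.json` is assumed to contain at least {"embedding_model": "<model name>"},
    # where the model name is a key of `EMBEDDINGS` (inferred from the loading code below).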
    if os.path.isdir(db_path):
        try:
            # load configuration file
            with open(db_config_path, "r", encoding="utf-8") as f:
                db_config = json.load(f)
            model_name = db_config["embedding_model"]
            embeddings = EMBEDDINGS[model_name]
            db = FAISS.load_local(db_index_path, embeddings)
            knowledge = Knowledge(db=db)
            knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
            domain_knowledge = knowledge.to_prompts(max_tokens_kd)
        except Exception as e:
            if debug:
                raise RuntimeError(f"Failed to query from FAISS. Error {e}.")
            else:
                print(f"Failed to query from FAISS. Error {e}. Use empty domain knowledge instead.")
                domain_knowledge = ""
    else:
        print("Selected database doesn't exist or no database is selected.")
        domain_knowledge = ""

    ###################################################################################################################
    # Generate necessary media
    ###################################################################################################################
    prompts = f"Title: {title}\n Contributions: {contributions}"
    try:
        components, usage = llm(systems=SYSTEM["components"], prompts=prompts, return_json=True)
        log_usage(usage, "media")
    except RuntimeError:
        if debug:
            raise RuntimeError("Failed to generate media.")
        else:
            print("Failed to generate media. Use default media.")
            components = {}
    print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
    paper = {
        "title": title,
        "description": contributions,
        "references": references,
        "body": {},  # the paper body; filled in section by section during generation
        "bibtex": bibtex_path,
        "domain_knowledge": domain_knowledge,
        "components": components,
    }
    # print(json.dumps(paper, indent=4))
    return paper, destination_folder, all_paper_ids
    # todo: use `all_paper_ids` to check if all citations are in this list
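
# A minimal sketch of calling `_generation_setup` directly (the title is hypothetical;
# "ml_textbook_test" is the demo database name used in __main__ below):
#     paper, folder, paper_ids = _generation_setup("My Paper Title",
#                                                  knowledge_database="ml_textbook_test")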


def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-4"):
    # todo: to match the current generation setup
    paper, destination_folder, _ = _generation_setup(title, description, template)
    for section in ["introduction", "related works", "backgrounds"]:
        try:
            usage = section_generation_bg(paper, section, destination_folder, model=model)
            log_usage(usage, section)
        except Exception as e:
            message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}"
            print(message)
            logging.info(message)
    print(f"The paper '{title}' has been generated. Saved to {destination_folder}.")

    input_dict = {"title": title, "description": description, "generator": "generate_backgrounds"}
    filename = hash_name(input_dict) + ".zip"
    return make_archive(destination_folder, filename)


def generate_draft(title, description="",  # main input
                   tldr=True, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048,  # references
                   knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # domain knowledge
                   sections=None, model="gpt-4", template="ICLR2022", prompts_mode=False,  # output parameters
                   ):
    """
    This function generates a draft paper using the provided information. It consists of three steps:
        1. Pre-processing: initialize the setup for paper generation and filter the sections to be included.
        2. Processing: generate each section of the paper.
        3. Post-processing: create backup copies of the paper and return it in a zipped format.

    Parameters:
        title (str): The title of the paper.
        description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
        tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
                               for the collected papers. Defaults to True.
        max_kw_refs (int, optional): The maximum number of references collected for each keyword. Defaults to 10.
        bib_refs (str, optional): Path to a BibTeX file of user-provided references. Defaults to None.
        max_tokens_ref (int, optional): The token budget for the references prompt. Defaults to 2048.
        knowledge_database (str, optional): Name of a local FAISS database under `knowledge_databases/`.
                                            Defaults to None.
        max_tokens_kd (int, optional): The token budget for the domain-knowledge prompt. Defaults to 2048.
        query_counts (int, optional): The maximum number of queries sent to the knowledge database. Defaults to 10.
        sections (list, optional): The sections to be included in the paper. If not provided, all the standard
                                   sections are included.
        model (str, optional): The language model to be used for paper generation. Defaults to "gpt-4".
        template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
        prompts_mode (bool, optional): If True, skip generation and save the prompts for each section to a JSON
                                       file instead. Defaults to False.

    Returns:
        str: The path to the zipped file containing the generated paper and associated files (or, in
             `prompts_mode`, the path to a JSON file of section prompts).

    Note: If a section fails to generate, the function makes up to 4 attempts before moving on.
    """
    def _filter_sections(sections):
        ordered_sections = ["introduction", "related works", "backgrounds", "methodology", "experiments",
                            "conclusion", "abstract"]
        return [section for section in ordered_sections if section in sections]
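    # For example, _filter_sections(["conclusion", "introduction"]) returns
    # ["introduction", "conclusion"]: the requested sections, in generation order.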

    # pre-processing the `sections` parameter
    print("================START================")
    print(f"Generating the paper '{title}'.")
    print("================PRE-PROCESSING================")
    # put `sections` in the correct order
    if sections is None:
        sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion",
                    "abstract"]
    else:
        sections = _filter_sections(sections)

    paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, bib_refs,
                                                     max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
                                                     query_counts=query_counts,
                                                     knowledge_database=knowledge_database)

    # main components
    prompts_dict = {}
    print("================PROCESSING================")
    for section in sections:
        if prompts_mode:
            prompts = generate_paper_prompts(paper, section)
            prompts_dict[section] = prompts
            continue

        print(f"Generate {section} part...")
        max_attempts = 4
        attempts_count = 0
        while attempts_count < max_attempts:
            try:
                usage = section_generation(paper, section, destination_folder, model=model)
                print(f"{section} part has been generated. ")
                log_usage(usage, section)
                break
            except Exception as e:
                message = f"Failed to generate {section}. {type(e).__name__} was raised: {e}\n"
                print(message)
                logging.info(message)
                attempts_count += 1
                time.sleep(15)

    # post-processing
    print("================POST-PROCESSING================")
    create_copies(destination_folder)
    input_dict = {"title": title, "description": description, "generator": "generate_draft"}
    filename = hash_name(input_dict) + ".zip"
    print("\nMission completed.\n")

    if prompts_mode:
        filename = hash_name(input_dict) + ".json"
        with open(filename, "w") as f:
            json.dump(prompts_dict, f)
        return filename
    else:
        return make_archive(destination_folder, filename)
if __name__ == "__main__":
import openai
openai.api_key = os.getenv("OPENAI_API_KEY")
target_title = "Playing Atari with Decentralized Reinforcement Learning"
output = generate_draft(target_title, knowledge_database="ml_textbook_test")
print(output)