File size: 5,452 Bytes
fb8c051
 
 
 
 
 
 
 
 
8aec19e
fb8c051
 
 
 
 
8aec19e
 
 
 
 
 
 
 
 
 
fb8c051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8aec19e
fb8c051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8aec19e
fb8c051
 
238735e
 
fb8c051
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from utils.references import References
from utils.prompts import generate_paper_prompts, generate_keywords_prompts, generate_experiments_prompts
from utils.gpt_interaction import get_responses, extract_responses, extract_keywords, extract_json
from utils.tex_processing import replace_title
from utils.figures import generate_random_figures
import datetime
import shutil
import time
import logging
import os

# Running token counters. `log_usage` adds the usage of each API call to them;
# nothing in the visible code resets them, so they accumulate for the lifetime
# of the module.
TOTAL_TOKENS = 0
TOTAL_PROMPTS_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0

def make_archive(source, destination):
    """
    Pack the folder `source` into an archive and move the archive to `destination`.

    The archive format is the extension of `destination` (e.g. "output.zip" -> "zip")
    and has to be a format name accepted by `shutil.make_archive`. The archive
    contains the folder itself as its single top-level entry.
    :param source: path of the folder to archive; a trailing path separator is allowed.
    :param destination: path of the resulting archive, including file name and extension.
    :return destination
    """
    # splitext only splits at the last dot, so file names with extra dots still work.
    archive_name, extension = os.path.splitext(os.path.basename(destination))
    archive_format = extension.lstrip(".")

    # abspath removes a trailing separator. Without this, dirname("a/b/") is "a/b"
    # and the folder would be searched for inside itself.
    source = os.path.abspath(source)
    root_dir = os.path.dirname(source)
    base_dir = os.path.basename(source)

    # `archive_name` has no directory part, so the archive is first created in the
    # current working directory and then moved to its final location.
    archive_file = shutil.make_archive(archive_name, archive_format, root_dir, base_dir)
    shutil.move(archive_file, destination)
    return destination


def log_usage(usage, generating_target, print_out=True):
    """
    Add the token usage of one request to the module-level totals and report it.
    :param usage: mapping with the keys "prompt_tokens", "completion_tokens" and "total_tokens".
    :param generating_target: label of what has been generated; used in the message.
    :param print_out: when true, the message is printed in addition to being logged.
    """
    global TOTAL_TOKENS, TOTAL_PROMPTS_TOKENS, TOTAL_COMPLETION_TOKENS

    # Read all three counts first so a missing key leaves the totals untouched.
    prompts_tokens, completion_tokens, total_tokens = (
        usage[key] for key in ("prompt_tokens", "completion_tokens", "total_tokens")
    )

    TOTAL_TOKENS += total_tokens
    TOTAL_PROMPTS_TOKENS += prompts_tokens
    TOTAL_COMPLETION_TOKENS += completion_tokens

    message = (
        f"For generating {generating_target}, {total_tokens} tokens have been used ({prompts_tokens} for prompts; {completion_tokens} for completion). "
        f"{TOTAL_TOKENS} tokens have been used in total."
    )
    if print_out:
        print(message)
    logging.info(message)

def pipeline(paper, section, save_to_path, model):
    """
    The main pipeline of generating a section.
        1. Generate prompts.
        2. Get responses from AI assistant.
        3. Extract the section text.
        4. Save the text to .tex file.
    :param paper: dict describing the paper; the generated text is stored in paper["body"][section].
    :param section: name of the section; also used as the name of the .tex file.
    :param save_to_path: output folder prefix. It is joined to the file name by plain
        string concatenation, so it has to end with a path separator.
    :param model: model name handed to `get_responses`.
    :return usage
    """
    print(f"Generating {section}...")
    prompts = generate_paper_prompts(paper, section)
    gpt_response, usage = get_responses(prompts, model)
    output = extract_responses(gpt_response)
    paper["body"][section] = output
    tex_file = save_to_path + f"{section}.tex"

    # The abstract is wrapped in its environment; every other section gets a heading.
    if section == "abstract":
        header, footer = r"\begin{abstract}", r"\end{abstract}"
    else:
        # The backslash is escaped explicitly: "\s" is not a valid escape sequence.
        header, footer = f"\\section{{{section}}}\n", ""

    # Explicit encoding so non-ASCII characters in the generated text do not
    # depend on the platform's default locale.
    with open(tex_file, "w", encoding="utf-8") as f:
        f.write(header)
        f.write(output)
        f.write(footer)

    # Pause between requests (presumably to stay below the API rate limit - confirm).
    time.sleep(5)
    print(f"{section} has been generated. Saved to {tex_file}.")
    return usage



def generate_draft(title, description="", template="ICLR2022", model="gpt-4"):
    """
    The main pipeline of generating a paper.
        1. Copy everything to the output folder.
        2. Create references.
        3. Generate each section using `pipeline`.
        4. Post-processing: check common errors, fill the title, ...
    :param title: title of the paper.
    :param description: optional description passed to the keyword prompts.
    :param template: name of the folder under "latex_templates/" to copy.
    :param model: model name passed to every `get_responses` call.
    :return path of the archive of the output folder, as returned by `make_archive`.
    """
    # Create a copy in the outputs folder.
    now = datetime.datetime.now()
    target_name = now.strftime("outputs_%Y%m%d_%H%M%S")
    source_folder = f"latex_templates/{template}"
    destination_folder = f"outputs/{target_name}"
    shutil.copytree(source_folder, destination_folder)

    bibtex_path = destination_folder + "/ref.bib"
    save_to_path = destination_folder + "/"
    replace_title(save_to_path, title)
    # NOTE: basicConfig has no effect once the root logger has handlers, so only
    # the first call in a process selects the log file.
    logging.basicConfig(level=logging.INFO, filename=save_to_path + "generation.log")

    # Generate keywords and references
    print("Initialize the paper information ...")
    prompts = generate_keywords_prompts(title, description)
    gpt_response, usage = get_responses(prompts, model)
    keywords = extract_keywords(gpt_response)
    log_usage(usage, "keywords")

    ref = References(load_papers="")
    ref.collect_papers(keywords, method="arxiv")
    all_paper_ids = ref.to_bibtex(bibtex_path)  # todo: this will used to check if all citations are in this list

    print(f"The paper information has been initialized. References are saved to {bibtex_path}.")

    paper = {
        "title": title,
        "description": description,
        "references": ref.to_prompts(),  # to_prompts(top_papers)
        "body": {},
        "bibtex": bibtex_path,
    }

    print("Generating figures ...")
    prompts = generate_experiments_prompts(paper)
    gpt_response, usage = get_responses(prompts, model)
    list_of_methods = list(extract_json(gpt_response))
    log_usage(usage, "figures")
    generate_random_figures(list_of_methods, save_to_path + "comparison.png")

    # The abstract comes last, after all other sections have been stored in paper["body"].
    for section in ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]:
        try:
            usage = pipeline(paper, section, save_to_path, model=model)
            log_usage(usage, section)
        except Exception as e:
            # Best effort: a failed section must not abort the whole draft, but the
            # traceback is kept in the log file so the failure can be diagnosed.
            print(f"Failed to generate {section} due to the error: {e}")
            logging.exception("Failed to generate %s", section)
    print(f"The paper {title} has been generated. Saved to {save_to_path}.")
    # Pass the folder without the trailing separator: with it, make_archive's
    # dirname/basename split resolves to the wrong directories.
    return make_archive(destination_folder, save_to_path + "output.zip")

# Script entry point: generate one draft with the settings below.
# The commented-out lines are alternative values for `title` and `model`.
if __name__ == "__main__":
    # title = "Training Adversarial Generative Neural Network with Adaptive Dropout Rate"
    title = "Playing Atari Game with Deep Reinforcement Learning"
    description = ""
    template = "ICLR2022"
    model = "gpt-4"
    # model = "gpt-3.5-turbo"

    generate_draft(title, description, template, model)