File size: 17,930 Bytes
8aec19e
1b82d4c
a0d1776
365213e
9d9ac18
c160ff7
1b82d4c
6fe5041
8e698eb
3a7ead9
8e698eb
295e94f
8ef9348
d18c569
b4c6c2b
 
6fe5041
c304855
 
 
1b82d4c
 
 
 
09305ff
1b82d4c
 
 
 
 
09305ff
1b82d4c
 
a0d1776
 
 
 
1457d21
 
a0d1776
1b82d4c
1457d21
5df90ab
1457d21
 
 
 
 
 
 
9d9ac18
1457d21
 
 
 
 
 
 
 
 
 
 
 
5df90ab
 
 
 
 
 
 
 
 
 
 
1457d21
5df90ab
1457d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d9ac18
8aec19e
c304855
1b82d4c
8aec19e
1457d21
c160ff7
 
 
1b82d4c
1457d21
 
 
 
 
 
2617428
1457d21
09305ff
1b82d4c
 
 
c304855
 
8ef9348
 
018eb6d
 
 
 
a0d1776
1b82d4c
1457d21
1b82d4c
8ef9348
c304855
ae495a3
1b82d4c
ae495a3
1b82d4c
 
 
 
1457d21
 
 
 
 
 
 
 
 
 
 
 
 
1b82d4c
ab295c7
6eb659a
 
 
 
c160ff7
 
 
a0d1776
1457d21
8e698eb
8aec19e
a0d1776
8ef9348
 
1457d21
 
8e698eb
 
365213e
 
8e698eb
 
1457d21
 
 
 
8e698eb
2dc9347
1457d21
 
 
 
 
 
 
 
 
5df90ab
1457d21
 
 
 
 
 
 
 
 
 
8e698eb
 
2dc9347
 
 
 
1457d21
 
 
 
 
 
 
 
2dc9347
 
 
8e698eb
365213e
8e698eb
 
365213e
8e698eb
2dc9347
 
 
 
 
1457d21
 
 
 
 
 
 
 
 
2dc9347
1457d21
 
 
 
2dc9347
8e698eb
 
 
 
2617428
 
 
 
 
 
 
 
 
 
c160ff7
 
8e698eb
 
 
c160ff7
8e698eb
 
 
 
a0d1776
1457d21
1b82d4c
c160ff7
1b82d4c
1457d21
 
 
 
 
 
 
 
 
8e698eb
c304855
1457d21
 
 
 
8aec19e
2617428
 
 
c160ff7
8aec19e
018eb6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import gradio as gr
import os
import openai
from auto_backgrounds import generate_backgrounds, generate_draft
from utils.file_operations import hash_name, list_folders
from references_generator import generate_top_k_references

# todo:
#   6. get logs when the procedure is not completed. *
#   7. 自己的文件库; 更多的prompts
#   2. 实现别的功能
#   3. Check API Key GPT-4 Support.
# future:
#   generation.log sometimes disappears (ignore this)
#   1. Check if there are any duplicated citations
#   2. Remove potential thebibliography and bibitem in .tex file

#######################################################################################################################
# Check if openai and cloud storage available
#######################################################################################################################
openai_key = os.getenv("OPENAI_API_KEY")
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
if access_key_id is None or secret_access_key is None:
    print("Access keys are not provided. Outputs cannot be saved to AWS Cloud Storage.\n")
    IS_CACHE_AVAILABLE = False
else:
    IS_CACHE_AVAILABLE = True

if openai_key is None:
    print("OPENAI_API_KEY is not found in environment variables. The output may not be generated.\n")
    IS_OPENAI_API_KEY_AVAILABLE = False
else:
    openai.api_key = openai_key
    try:
        openai.Model.list()
        IS_OPENAI_API_KEY_AVAILABLE = True
    # except Exception as e:
    except openai.error.AuthenticationError:
        IS_OPENAI_API_KEY_AVAILABLE = False

DEFAULT_MODEL = "gpt-4" if IS_OPENAI_API_KEY_AVAILABLE else "gpt-3.5-turbo"
GPT4_INTERACTIVE = True if IS_OPENAI_API_KEY_AVAILABLE else False
DEFAULT_SECTIONS = ["introduction", "related works", "backgrounds", "methodology", "experiments",
                    "conclusion", "abstract"] if IS_OPENAI_API_KEY_AVAILABLE \
    else ["introduction", "related works"]

#######################################################################################################################
# Load the list of templates & knowledge databases
#######################################################################################################################
ALL_TEMPLATES = list_folders("latex_templates")
ALL_DATABASES = ["(None)"] + list_folders("knowledge_databases")

#######################################################################################################################
# Gradio UI
#######################################################################################################################
theme = gr.themes.Default(font=gr.themes.GoogleFont("Questrial"))
# .set(
#     background_fill_primary='#E5E4E2',
#     background_fill_secondary = '#F6F6F6',
#     button_primary_background_fill="#281A39"
# )
ANNOUNCEMENT = """
# Auto-Draft: 学术写作辅助工具

本Demo提供对[Auto-Draft](https://github.com/CCCBora/auto-draft)的auto_draft功能的测试. 
通过输入想要生成的论文名称(比如Playing atari with deep reinforcement learning),即可由AI辅助生成论文模板.    

***2023-06-13 Update***:  
1. 新增‘高级选项-Prompts模式’. 这个模式仅会输出用于生成论文的Prompts而不会生成论文本身. 可以根据自己的需求修改Prompts, 也可以把Prompts复制给其他语言模型. 
2. 把默认的ICLR 2022模板改成了Default模板. 不再显示ICLR的页眉页尾.  
3. 使用GPT-4模型:
    - 点击Duplicate this Space, 进入Settings-Repository secrets, 点击New Secret添加OPENAI_API_KEY为自己的OpenAI API Key
    - 或者可以访问[Auto-Draft-Private](https://huggingface.co/spaces/auto-academic/auto-draft-private)
    
如果有更多想法和建议欢迎加入QQ群里交流, 如果我在Space里更新了Key我会第一时间通知大家. 群号: ***249738228***."""

ACADEMIC_PAPER = """## 一键生成论文初稿
1. 在Title文本框中输入想要生成的论文名称(比如Playing Atari with Deep Reinforcement Learning). 
2. 点击Submit. 等待大概十五分钟(全文). 
3. 在右侧下载.zip格式的输出,在Overleaf上编译浏览.  
"""

REFERENCES = """## 一键搜索相关论文
(此功能已经被整合进一键生成论文初稿)
1. 在Title文本框中输入想要搜索文献的论文(比如Playing Atari with Deep Reinforcement Learning). 
2. 点击Submit. 等待大概十分钟. 
3. 在右侧JSON处会显示相关文献.  
"""

REFERENCES_INSTRUCTION = """### References
这一栏用于定义AI如何选取参考文献. 目前是两种方式混合:
1. GPT自动根据标题生成关键字,使用Semantic Scholar搜索引擎搜索文献,利用Specter获取Paper Embedding来自动选取最相关的文献作为GPT的参考资料.
2. 用户上传bibtex文件,使用Google Scholar搜索摘要作为GPT的参考资料. 
关于有希望利用本地文件来供GPT参考的功能将在未来实装.
"""

DOMAIN_KNOWLEDGE_INSTRUCTION = """### Domain Knowledge
这一栏用于定义AI的知识库. 将提供两种选择: 
1. 各个领域内由专家预先收集资料并构建的的FAISS向量数据库. 目前实装的数据库
* (None): 不使用任何知识库
* ml_textbook_test: 包含两本机器学习教材The Elements of Statistical Learning和Reinforcement Learning Theory and Algorithms. 仅用于测试知识库Pipeline.
2. 自行构建的使用OpenAI text-embedding-ada-002模型创建的FAISS向量数据库. (暂未实装)
"""

OUTPUTS_INSTRUCTION = """### Outputs
这一栏用于定义输出的内容:
* Template: 用于填装内容的LaTeX模板.
* Models: 使用GPT-4或者GPT-3.5-Turbo生成内容.
* Prompts模式: 不生成内容, 而是生成用于生成内容的Prompts. 可以手动复制到网页版或者其他语言模型中进行使用.
"""

OTHERS_INSTRUCTION = """### Others

"""

style_mapping = {True: "color:white;background-color:green",
                 False: "color:white;background-color:red"}  # todo: to match website's style
availability_mapping = {True: "AVAILABLE", False: "NOT AVAILABLE"}
STATUS = f'''## Huggingface Space Status  
 当`OpenAI API`显示AVAILABLE的时候这个Space可以直接使用.    
 当`OpenAI API`显示NOT AVAILABLE的时候这个Space可以通过在左侧输入OPENAI KEY来使用. 需要有GPT-4的API权限. 
 当`Cache`显示AVAILABLE的时候, 所有的输入和输出会被备份到我的云储存中. 显示NOT AVAILABLE的时候不影响实际使用. 
`OpenAI API`: <span style="{style_mapping[IS_OPENAI_API_KEY_AVAILABLE]}">{availability_mapping[IS_OPENAI_API_KEY_AVAILABLE]}</span>.  `Cache`: <span style="{style_mapping[IS_CACHE_AVAILABLE]}">{availability_mapping[IS_CACHE_AVAILABLE]}</span>.'''


def clear_inputs(*args):
    return "", ""


def clear_inputs_refs(*args):
    return "", 5


def wrapped_generator(
        paper_title, paper_description,  # main input
        openai_api_key=None, openai_url=None,  # key
        tldr=True, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048,  # references
        knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # domain knowledge
        paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False,  # outputs parameters
        cache_mode=IS_CACHE_AVAILABLE  # handle cache mode
):
    # if `cache_mode` is True, then follow the following steps:
    #        check if "title"+"description" have been generated before
    #        if so, download from the cloud storage, return it
    #        if not, generate the result.
    if bib_refs is not None:
        bib_refs = bib_refs.name
    if openai_api_key is not None:
        openai.api_key = openai_api_key
        try:
            openai.Model.list()
        except Exception as e:
            raise gr.Error(f"Key错误. Error: {e}")

    if cache_mode:
        from utils.storage import list_all_files, download_file
        # check if "title"+"description" have been generated before
        input_dict = {"title": paper_title, "description": paper_description,
                      "generator": "generate_draft"}
        file_name = hash_name(input_dict) + ".zip"
        file_list = list_all_files()
        # print(f"{file_name} will be generated. Check the file list {file_list}")
        if file_name in file_list:
            # download from the cloud storage, return it
            download_file(file_name)
            return file_name
    try:
        output = generate_draft(
            paper_title, description=paper_description, # main input
           tldr=tldr, max_kw_refs=max_kw_refs, bib_refs=bib_refs, max_tokens_ref=max_tokens_ref,  # references
           knowledge_database=knowledge_database, max_tokens_kd=max_tokens_kd, query_counts=query_counts, # domain knowledge
           sections=selected_sections, model=model, template=paper_template, prompts_mode=prompts_mode, # outputs parameters
           )
        if cache_mode:
            from utils.storage import upload_file
            upload_file(output)
    except Exception as e:
        raise gr.Error(f"生成失败. Error: {e}")
    return output


def wrapped_references_generator(paper_title, num_refs, openai_api_key=None):
    if openai_api_key is not None:
        openai.api_key = openai_api_key
        openai.Model.list()
    return generate_top_k_references(paper_title, top_k=num_refs)


with gr.Blocks(theme=theme) as demo:
    gr.Markdown(ANNOUNCEMENT)

    with gr.Row():
        with gr.Column(scale=2):
            key = gr.Textbox(value=openai_key, lines=1, max_lines=1, label="OpenAI Key",
                             visible=not IS_OPENAI_API_KEY_AVAILABLE)
            url = gr.Textbox(value=None, lines=1, max_lines=1, label="URL",
                             visible=False)
            # 每个功能做一个tab
            with gr.Tab("学术论文"):
                gr.Markdown(ACADEMIC_PAPER)

                title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
                                   label="Title", info="论文标题")

                description_pp = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
                                            info="这篇论文的主要贡献和创新点. (生成所有章节时共享这个信息, 保持生成的一致性.)")

                with gr.Accordion("高级设置", open=False):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown(OUTPUTS_INSTRUCTION)
                        with gr.Column(scale=2):
                            with gr.Row():
                                template = gr.Dropdown(label="Template", choices=ALL_TEMPLATES, value="Default",
                                                       interactive=True,
                                                       info="生成论文的模板.")
                                model_selection = gr.Dropdown(label="Model", choices=["gpt-4", "gpt-3.5-turbo"],
                                                              value=DEFAULT_MODEL,
                                                              interactive=GPT4_INTERACTIVE,
                                                              info="生成论文用到的语言模型.")
                                prompts_mode = gr.Checkbox(value=False, visible=True, interactive=True,
                                                             label="Prompts模式",
                                                             info="只输出用于生成论文的Prompts, 可以复制到别的地方生成论文.")

                            sections = gr.CheckboxGroup(
                                choices=["introduction", "related works", "backgrounds", "methodology", "experiments",
                                         "conclusion", "abstract"],
                                type="value", label="生成章节", interactive=True, info="选择生成论文的哪些章节.",
                                value=DEFAULT_SECTIONS)

                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown(REFERENCES_INSTRUCTION)

                        with gr.Column(scale=2):
                            max_kw_ref_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
                                                       interactive=True, label="MAX_KW_REFS",
                                                       info="每个Keyword搜索几篇参考文献", visible=False)

                            max_tokens_ref_slider = gr.Slider(minimum=256, maximum=4096, value=2048, step=2,
                                                       interactive=True, label="MAX_TOKENS",
                                                       info="参考文献内容占用Prompts中的Token数")

                            tldr_checkbox = gr.Checkbox(value=True, label="TLDR;",
                                                        info="选择此筐表示将使用Semantic Scholar的TLDR作为文献的总结.",
                                                        interactive=True)
                            gr.Markdown('''
                            上传.bib文件提供AI需要参考的文献. 
                            ''')
                            bibtex_file = gr.File(label="Upload .bib file", file_types=["text"],
                                                  interactive=True)

                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown(DOMAIN_KNOWLEDGE_INSTRUCTION)

                        with gr.Column(scale=2):
                            query_counts_slider = gr.Slider(minimum=1, maximum=20, value=10, step=1,
                                                       interactive=True, label="QUERY_COUNTS",
                                                       info="从知识库内检索多少条内容", visible=False)
                            max_tokens_kd_slider = gr.Slider(minimum=256, maximum=4096, value=2048, step=2,
                                                      interactive=True, label="MAX_TOKENS",
                                                      info="知识库内容占用Prompts中的Token数")
                            # template = gr.Dropdown(label="Template", choices=ALL_TEMPLATES, value="Default",
                            #                        interactive=True,
                            #                        info="生成论文的参考模板.")
                            domain_knowledge = gr.Dropdown(label="预载知识库",
                                                           choices=ALL_DATABASES,
                                                           value="(None)",
                                                           interactive=True,
                                                           info="使用预先构建的知识库.")
                            local_domain_knowledge = gr.File(label="本地知识库 (暂未实装)", interactive=False)
                with gr.Row():
                    clear_button_pp = gr.Button("Clear")
                    submit_button_pp = gr.Button("Submit", variant="primary")

            # with gr.Tab("文献搜索"):
            #     gr.Markdown(REFERENCES)
            #
            #     title_refs = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
            #                             label="Title", info="论文标题")
            #     slider_refs = gr.Slider(minimum=1, maximum=100, value=5, step=1,
            #                             interactive=True, label="最相关的参考文献数目")
            #     with gr.Row():
            #         clear_button_refs = gr.Button("Clear")
            #         submit_button_refs = gr.Button("Submit", variant="primary")

            with gr.Tab("文献综述 (Coming soon!)"):
                gr.Markdown('''
                <h1  style="text-align: center;">Coming soon!</h1>
                ''')
            with gr.Tab("Github文档 (Coming soon!)"):
                gr.Markdown('''
                <h1  style="text-align: center;">Coming soon!</h1>
                ''')

        with gr.Column(scale=1):
            gr.Markdown(STATUS)
            file_output = gr.File(label="Output")
            json_output = gr.JSON(label="References")


    # def wrapped_generator(
    #         paper_title, paper_description,  # main input
    #         openai_api_key=None, openai_url=None,  # key
    #         tldr=True, max_kw_refs=10, bib_refs=None, max_tokens_ref=2048,  # references
    #         knowledge_database=None, max_tokens_kd=2048, query_counts=10,  # domain knowledge
    #         paper_template="ICLR2022", selected_sections=None, model="gpt-4", prompts_mode=False,  # outputs parameters
    #         cache_mode=IS_CACHE_AVAILABLE  # handle cache mode
    # ):
    clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
    submit_button_pp.click(fn=wrapped_generator,
                           inputs=[title, description_pp, key, url,
                                   tldr_checkbox, max_kw_ref_slider,  bibtex_file, max_tokens_ref_slider,
                                   domain_knowledge, max_tokens_kd_slider, query_counts_slider,
                                   template, sections, model_selection, prompts_mode], outputs=file_output)

    # clear_button_refs.click(fn=clear_inputs_refs, inputs=[title_refs, slider_refs], outputs=[title_refs, slider_refs])
    # submit_button_refs.click(fn=wrapped_references_generator,
    #                          inputs=[title_refs, slider_refs, key], outputs=json_output)

demo.queue(concurrency_count=1, max_size=5, api_open=False)
demo.launch(show_error=True)