Spaces:
Running
Running
shaocongma
commited on
Commit
•
c42190b
1
Parent(s):
acf8a73
Bug fix: error when abstract is None.
Browse files- api_wrapper.py +13 -5
- auto_backgrounds.py +22 -10
- utils/references.py +6 -2
- worker.py +172 -0
api_wrapper.py
CHANGED
@@ -12,18 +12,26 @@ todo:
|
|
12 |
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
13 |
Change Task status from Running to Failed.
|
14 |
'''
|
|
|
15 |
|
16 |
from auto_backgrounds import generate_draft
|
17 |
-
import json
|
|
|
18 |
|
19 |
|
20 |
-
GENERATOR_MAPPING = {"draft": generate_draft}
|
|
|
21 |
|
22 |
def generator_wrapper(path_to_config_json):
|
23 |
# Read configuration file and call corresponding function
|
24 |
with open(path_to_config_json, "r", encoding='utf-8') as f:
|
25 |
config = json.load(f)
|
26 |
-
|
27 |
-
generator = GENERATOR_MAPPING.get(config["generator"])
|
|
|
28 |
if generator is None:
|
29 |
-
|
|
|
|
|
|
|
|
|
|
12 |
If `generator_wrapper` returns nothing or Timeout, or raise any error:
|
13 |
Change Task status from Running to Failed.
|
14 |
'''
|
15 |
+
import os.path
|
16 |
|
17 |
from auto_backgrounds import generate_draft
|
18 |
+
import json, time
|
19 |
+
from utils.file_operations import make_archive
|
20 |
|
21 |
|
22 |
+
# GENERATOR_MAPPING = {"draft": generate_draft}
|
23 |
+
GENERATOR_MAPPING = {"draft": None}
|
24 |
|
25 |
def generator_wrapper(path_to_config_json):
|
26 |
# Read configuration file and call corresponding function
|
27 |
with open(path_to_config_json, "r", encoding='utf-8') as f:
|
28 |
config = json.load(f)
|
29 |
+
print("Configuration:", config)
|
30 |
+
# generator = GENERATOR_MAPPING.get(config["generator"])
|
31 |
+
generator = None
|
32 |
if generator is None:
|
33 |
+
# generate a fake ZIP file and upload
|
34 |
+
time.sleep(150)
|
35 |
+
zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
|
36 |
+
return make_archive(path_to_config_json, zip_path)
|
37 |
+
|
auto_backgrounds.py
CHANGED
@@ -3,7 +3,6 @@ from utils.references import References
|
|
3 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
4 |
from utils.tex_processing import create_copies
|
5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
6 |
-
from references_generator import generate_top_k_references
|
7 |
import logging
|
8 |
import time
|
9 |
|
@@ -26,12 +25,14 @@ def log_usage(usage, generating_target, print_out=True):
|
|
26 |
TOTAL_PROMPTS_TOKENS += prompts_tokens
|
27 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
28 |
|
29 |
-
message = f"For generating {generating_target}, {total_tokens} tokens have been used
|
|
|
30 |
f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
|
31 |
if print_out:
|
32 |
print(message)
|
33 |
logging.info(message)
|
34 |
|
|
|
35 |
def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
36 |
max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
|
37 |
"""
|
@@ -44,9 +45,12 @@ def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
|
44 |
title (str): The title of the paper.
|
45 |
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
|
46 |
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
|
47 |
-
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
50 |
bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
|
51 |
|
52 |
Returns:
|
@@ -111,21 +115,29 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
|
|
111 |
def generate_draft(title, description="", template="ICLR2022",
|
112 |
tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
|
113 |
# pre-processing `sections` parameter;
|
|
|
|
|
114 |
print("================PRE-PROCESSING================")
|
115 |
if sections is None:
|
116 |
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
|
117 |
|
118 |
# todo: add more parameters; select which section to generate; select maximum refs.
|
119 |
-
|
|
|
|
|
|
|
|
|
120 |
|
121 |
# main components
|
|
|
122 |
for section in sections:
|
123 |
-
print(f"
|
124 |
max_attempts = 4
|
125 |
attempts_count = 0
|
126 |
while attempts_count < max_attempts:
|
127 |
try:
|
128 |
usage = section_generation(paper, section, destination_folder, model=model)
|
|
|
129 |
log_usage(usage, section)
|
130 |
break
|
131 |
except Exception as e:
|
@@ -153,7 +165,7 @@ if __name__ == "__main__":
|
|
153 |
import openai
|
154 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
output = generate_draft(
|
159 |
print(output)
|
|
|
3 |
from utils.file_operations import hash_name, make_archive, copy_templates
|
4 |
from utils.tex_processing import create_copies
|
5 |
from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
|
|
|
6 |
import logging
|
7 |
import time
|
8 |
|
|
|
25 |
TOTAL_PROMPTS_TOKENS += prompts_tokens
|
26 |
TOTAL_COMPLETION_TOKENS += completion_tokens
|
27 |
|
28 |
+
message = f"For generating {generating_target}, {total_tokens} tokens have been used " \
|
29 |
+
f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
|
30 |
f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
|
31 |
if print_out:
|
32 |
print(message)
|
33 |
logging.info(message)
|
34 |
|
35 |
+
|
36 |
def _generation_setup(title, description="", template="ICLR2022", tldr=False,
|
37 |
max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
|
38 |
"""
|
|
|
45 |
title (str): The title of the paper.
|
46 |
description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
|
47 |
template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
|
48 |
+
tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
|
49 |
+
for the collected papers. Defaults to False.
|
50 |
+
max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
|
51 |
+
Defaults to 10.
|
52 |
+
max_num_refs (int, optional): The maximum number of references that can be included in the paper.
|
53 |
+
Defaults to 50.
|
54 |
bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
|
55 |
|
56 |
Returns:
|
|
|
115 |
def generate_draft(title, description="", template="ICLR2022",
|
116 |
tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
|
117 |
# pre-processing `sections` parameter;
|
118 |
+
print("================START================")
|
119 |
+
print(f"Generating {title}.")
|
120 |
print("================PRE-PROCESSING================")
|
121 |
if sections is None:
|
122 |
sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
|
123 |
|
124 |
# todo: add more parameters; select which section to generate; select maximum refs.
|
125 |
+
if model == "gpt-4":
|
126 |
+
max_tokens = 4096
|
127 |
+
else:
|
128 |
+
max_tokens = 2048
|
129 |
+
paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, max_num_refs, bib_refs, max_tokens=max_tokens)
|
130 |
|
131 |
# main components
|
132 |
+
print(f"================PROCESSING================")
|
133 |
for section in sections:
|
134 |
+
print(f"Generate {section} part...")
|
135 |
max_attempts = 4
|
136 |
attempts_count = 0
|
137 |
while attempts_count < max_attempts:
|
138 |
try:
|
139 |
usage = section_generation(paper, section, destination_folder, model=model)
|
140 |
+
print(f"{section} part has been generated. ")
|
141 |
log_usage(usage, section)
|
142 |
break
|
143 |
except Exception as e:
|
|
|
165 |
import openai
|
166 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
167 |
|
168 |
+
target_title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
|
169 |
+
target_description = ""
|
170 |
+
output = generate_draft(target_title, target_description, tldr=True, max_kw_refs=10)
|
171 |
print(output)
|
utils/references.py
CHANGED
@@ -334,8 +334,12 @@ class References:
|
|
334 |
prompts = {}
|
335 |
tokens = 0
|
336 |
for paper in result:
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
339 |
if tokens >= max_tokens:
|
340 |
break
|
341 |
return prompts
|
|
|
334 |
prompts = {}
|
335 |
tokens = 0
|
336 |
for paper in result:
|
337 |
+
abstract = paper.get("abstract")
|
338 |
+
if abstract is not None and isinstance(abstract, str):
|
339 |
+
prompts[paper["paper_id"]] = paper["abstract"]
|
340 |
+
tokens += tiktoken_len(paper["abstract"])
|
341 |
+
else:
|
342 |
+
prompts[paper["paper_id"]] = " "
|
343 |
if tokens >= max_tokens:
|
344 |
break
|
345 |
return prompts
|
worker.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
This script is only used for service-side host.
|
3 |
+
'''
|
4 |
+
import boto3
|
5 |
+
import os, time
|
6 |
+
from api_wrapper import generator_wrapper
|
7 |
+
from sqlalchemy import create_engine, Table, MetaData, update, select
|
8 |
+
from sqlalchemy.orm import sessionmaker
|
9 |
+
from sqlalchemy import inspect
|
10 |
+
|
11 |
+
QUEUE_URL = os.getenv('QUEUE_URL')
|
12 |
+
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
|
13 |
+
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
|
14 |
+
BUCKET_NAME = os.getenv('BUCKET_NAME')
|
15 |
+
DB_STRING = os.getenv('DATABASE_STRING')
|
16 |
+
|
17 |
+
# Create engine
|
18 |
+
ENGINE = create_engine(DB_STRING)
|
19 |
+
SESSION = sessionmaker(bind=ENGINE)
|
20 |
+
|
21 |
+
|
22 |
+
#######################################################################################################################
|
23 |
+
# Amazon SQS Handler
|
24 |
+
#######################################################################################################################
|
25 |
+
def get_sqs_client():
|
26 |
+
sqs = boto3.client('sqs', region_name="us-east-2",
|
27 |
+
aws_access_key_id=AWS_ACCESS_KEY_ID,
|
28 |
+
aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
|
29 |
+
return sqs
|
30 |
+
|
31 |
+
|
32 |
+
def receive_message():
|
33 |
+
sqs = get_sqs_client()
|
34 |
+
message = sqs.receive_message(QueueUrl=QUEUE_URL)
|
35 |
+
if message.get('Messages') is not None:
|
36 |
+
receipt_handle = message['Messages'][0]['ReceiptHandle']
|
37 |
+
else:
|
38 |
+
receipt_handle = None
|
39 |
+
return message, receipt_handle
|
40 |
+
|
41 |
+
|
42 |
+
def delete_message(receipt_handle):
|
43 |
+
sqs = get_sqs_client()
|
44 |
+
response = sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=receipt_handle)
|
45 |
+
return response
|
46 |
+
|
47 |
+
|
48 |
+
#######################################################################################################################
|
49 |
+
# AWS S3 Handler
|
50 |
+
#######################################################################################################################
|
51 |
+
def get_s3_client():
|
52 |
+
access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
|
53 |
+
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
|
54 |
+
session = boto3.Session(
|
55 |
+
aws_access_key_id=access_key_id,
|
56 |
+
aws_secret_access_key=secret_access_key,
|
57 |
+
)
|
58 |
+
s3 = session.resource('s3')
|
59 |
+
bucket = s3.Bucket(BUCKET_NAME)
|
60 |
+
return s3, bucket
|
61 |
+
|
62 |
+
|
63 |
+
def upload_file(file_name, target_name=None):
|
64 |
+
s3, _ = get_s3_client()
|
65 |
+
|
66 |
+
if target_name is None:
|
67 |
+
target_name = file_name
|
68 |
+
s3.meta.client.upload_file(Filename=file_name, Bucket=BUCKET_NAME, Key=target_name)
|
69 |
+
print(f"The file {file_name} has been uploaded!")
|
70 |
+
|
71 |
+
|
72 |
+
def download_file(file_name):
|
73 |
+
""" Download `file_name` from the bucket.
|
74 |
+
Bucket (str) – The name of the bucket to download from.
|
75 |
+
Key (str) – The name of the key to download from.
|
76 |
+
Filename (str) – The path to the file to download to.
|
77 |
+
"""
|
78 |
+
s3, _ = get_s3_client()
|
79 |
+
s3.meta.client.download_file(Bucket=BUCKET_NAME, Key=file_name, Filename=os.path.basename(file_name))
|
80 |
+
print(f"The file {file_name} has been downloaded!")
|
81 |
+
|
82 |
+
|
83 |
+
#######################################################################################################################
|
84 |
+
# AWS SQL Handler
|
85 |
+
#######################################################################################################################
|
86 |
+
def modify_status(task_id, new_status):
|
87 |
+
session = SESSION()
|
88 |
+
metadata = MetaData()
|
89 |
+
task_to_update = task_id
|
90 |
+
task_table = Table('task', metadata, autoload_with=ENGINE)
|
91 |
+
stmt = select(task_table).where(task_table.c.task_id == task_to_update)
|
92 |
+
# Execute the statement
|
93 |
+
with ENGINE.connect() as connection:
|
94 |
+
result = connection.execute(stmt)
|
95 |
+
|
96 |
+
# Fetch the first result (if exists)
|
97 |
+
task_data = result.fetchone()
|
98 |
+
|
99 |
+
# If user_data is not None, the user exists and we can update the password
|
100 |
+
if task_data:
|
101 |
+
# Update statement
|
102 |
+
stmt = (
|
103 |
+
update(task_table).
|
104 |
+
where(task_table.c.task_id == task_to_update).
|
105 |
+
values(status=new_status)
|
106 |
+
)
|
107 |
+
# Execute the statement and commit
|
108 |
+
result = connection.execute(stmt)
|
109 |
+
connection.commit()
|
110 |
+
# Close the session
|
111 |
+
session.close()
|
112 |
+
|
113 |
+
#######################################################################################################################
|
114 |
+
# Pipline
|
115 |
+
#######################################################################################################################
|
116 |
+
def pipeline(message_count=0, query_interval=10):
|
117 |
+
# status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed
|
118 |
+
|
119 |
+
# Query a message from SQS
|
120 |
+
msg, handle = receive_message()
|
121 |
+
if handle is None:
|
122 |
+
print("No message in SQS. ")
|
123 |
+
time.sleep(query_interval)
|
124 |
+
else:
|
125 |
+
print("===============================================================================================")
|
126 |
+
print(f"MESSAGE COUNT: {message_count}")
|
127 |
+
print("===============================================================================================")
|
128 |
+
config_s3_path = msg['Messages'][0]['Body']
|
129 |
+
config_s3_dir = os.path.dirname(config_s3_path)
|
130 |
+
config_local_path = os.path.basename(config_s3_path)
|
131 |
+
task_id, _ = os.path.splitext(config_local_path)
|
132 |
+
|
133 |
+
print("Initializing ...")
|
134 |
+
print("Configuration file on S3: ", config_s3_path)
|
135 |
+
print("Configuration file on S3 (Directory): ", config_s3_dir)
|
136 |
+
print("Local file path: ", config_local_path)
|
137 |
+
print("Task id: ", task_id)
|
138 |
+
|
139 |
+
print(f"Success in receiving message: {msg}")
|
140 |
+
print(f"Configuration file path: {config_s3_path}")
|
141 |
+
|
142 |
+
# Process the downloaded configuration file
|
143 |
+
download_file(config_s3_path)
|
144 |
+
modify_status(task_id, 1) # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed
|
145 |
+
delete_message(handle)
|
146 |
+
print(f"Success in the initialization. Message deleted.")
|
147 |
+
|
148 |
+
print("Running ...")
|
149 |
+
# try:
|
150 |
+
zip_path = generator_wrapper(config_local_path)
|
151 |
+
# Upload the generated file to S3
|
152 |
+
upload_to = os.path.join(config_s3_dir, zip_path).replace("\\", "/")
|
153 |
+
|
154 |
+
print("Local file path (ZIP): ", zip_path)
|
155 |
+
print("Upload to S3: ", upload_to)
|
156 |
+
upload_file(zip_path, upload_to)
|
157 |
+
modify_status(task_id, 2) # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed, 4 - deleted
|
158 |
+
print(f"Success in generating the paper.")
|
159 |
+
|
160 |
+
# Complete.
|
161 |
+
print("Task completed.")
|
162 |
+
|
163 |
+
|
164 |
+
def initialize_everything():
|
165 |
+
# Clear S3
|
166 |
+
|
167 |
+
# Clear SQS
|
168 |
+
pass
|
169 |
+
|
170 |
+
|
171 |
+
if __name__ == "__main__":
|
172 |
+
pipeline()
|