shaocongma committed on
Commit
c42190b
1 Parent(s): acf8a73

Bug fix: error when abstract is None.

Browse files
Files changed (4) hide show
  1. api_wrapper.py +13 -5
  2. auto_backgrounds.py +22 -10
  3. utils/references.py +6 -2
  4. worker.py +172 -0
api_wrapper.py CHANGED
@@ -12,18 +12,26 @@ todo:
12
  If `generator_wrapper` returns nothing or Timeout, or raise any error:
13
  Change Task status from Running to Failed.
14
  '''
 
15
 
16
  from auto_backgrounds import generate_draft
17
- import json
 
18
 
19
 
20
- GENERATOR_MAPPING = {"draft": generate_draft}
 
21
 
22
  def generator_wrapper(path_to_config_json):
23
  # Read configuration file and call corresponding function
24
  with open(path_to_config_json, "r", encoding='utf-8') as f:
25
  config = json.load(f)
26
-
27
- generator = GENERATOR_MAPPING.get(config["generator"])
 
28
  if generator is None:
29
- pass
 
 
 
 
 
12
  If `generator_wrapper` returns nothing or Timeout, or raise any error:
13
  Change Task status from Running to Failed.
14
  '''
15
+ import os.path
16
 
17
  from auto_backgrounds import generate_draft
18
+ import json, time
19
+ from utils.file_operations import make_archive
20
 
21
 
22
# GENERATOR_MAPPING = {"draft": generate_draft}
# NOTE(review): mapped to None on purpose so generator_wrapper falls through to
# its fake-ZIP stub path; restore generate_draft to re-enable real generation.
GENERATOR_MAPPING = {"draft": None}
24
 
25
def generator_wrapper(path_to_config_json):
    """Read a task configuration JSON and produce a result ZIP archive.

    Currently a stub: the real generator lookup is commented out, so every
    task takes the fake path below regardless of ``config["generator"]``.

    Args:
        path_to_config_json (str): Path to the task's JSON configuration file.

    Returns:
        The return value of ``make_archive`` — presumably the ZIP path
        (TODO confirm against utils.file_operations).
    """
    # Read configuration file and call corresponding function
    with open(path_to_config_json, "r", encoding='utf-8') as f:
        config = json.load(f)
    print("Configuration:", config)
    # generator = GENERATOR_MAPPING.get(config["generator"])
    generator = None  # hard-coded so the stub branch below always runs
    if generator is None:
        # generate a fake ZIP file and upload
        time.sleep(150)  # simulate generation time — TODO remove for production
        zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
        return make_archive(path_to_config_json, zip_path)
auto_backgrounds.py CHANGED
@@ -3,7 +3,6 @@ from utils.references import References
3
  from utils.file_operations import hash_name, make_archive, copy_templates
4
  from utils.tex_processing import create_copies
5
  from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
6
- from references_generator import generate_top_k_references
7
  import logging
8
  import time
9
 
@@ -26,12 +25,14 @@ def log_usage(usage, generating_target, print_out=True):
26
  TOTAL_PROMPTS_TOKENS += prompts_tokens
27
  TOTAL_COMPLETION_TOKENS += completion_tokens
28
 
29
- message = f"For generating {generating_target}, {total_tokens} tokens have been used ({prompts_tokens} for prompts; {completion_tokens} for completion). " \
 
30
  f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
31
  if print_out:
32
  print(message)
33
  logging.info(message)
34
 
 
35
  def _generation_setup(title, description="", template="ICLR2022", tldr=False,
36
  max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
37
  """
@@ -44,9 +45,12 @@ def _generation_setup(title, description="", template="ICLR2022", tldr=False,
44
  title (str): The title of the paper.
45
  description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
46
  template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
47
- tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be generated for the collected papers. Defaults to False.
48
- max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword. Defaults to 10.
49
- max_num_refs (int, optional): The maximum number of references that can be included in the paper. Defaults to 50.
 
 
 
50
  bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
51
 
52
  Returns:
@@ -111,21 +115,29 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
111
  def generate_draft(title, description="", template="ICLR2022",
112
  tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
113
  # pre-processing `sections` parameter;
 
 
114
  print("================PRE-PROCESSING================")
115
  if sections is None:
116
  sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
117
 
118
  # todo: add more parameters; select which section to generate; select maximum refs.
119
- paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, max_num_refs, bib_refs)
 
 
 
 
120
 
121
  # main components
 
122
  for section in sections:
123
- print(f"================Generate {section}================")
124
  max_attempts = 4
125
  attempts_count = 0
126
  while attempts_count < max_attempts:
127
  try:
128
  usage = section_generation(paper, section, destination_folder, model=model)
 
129
  log_usage(usage, section)
130
  break
131
  except Exception as e:
@@ -153,7 +165,7 @@ if __name__ == "__main__":
153
  import openai
154
  openai.api_key = os.getenv("OPENAI_API_KEY")
155
 
156
- title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
157
- description = ""
158
- output = generate_draft(title, description, tldr=True, max_kw_refs=10)
159
  print(output)
 
3
  from utils.file_operations import hash_name, make_archive, copy_templates
4
  from utils.tex_processing import create_copies
5
  from section_generator import section_generation_bg, keywords_generation, figures_generation, section_generation
 
6
  import logging
7
  import time
8
 
 
25
  TOTAL_PROMPTS_TOKENS += prompts_tokens
26
  TOTAL_COMPLETION_TOKENS += completion_tokens
27
 
28
+ message = f"For generating {generating_target}, {total_tokens} tokens have been used " \
29
+ f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
30
  f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
31
  if print_out:
32
  print(message)
33
  logging.info(message)
34
 
35
+
36
  def _generation_setup(title, description="", template="ICLR2022", tldr=False,
37
  max_kw_refs=10, max_num_refs=50, bib_refs=None, max_tokens=2048):
38
  """
 
45
  title (str): The title of the paper.
46
  description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
47
  template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
48
+ tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
49
+ for the collected papers. Defaults to False.
50
+ max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
51
+ Defaults to 10.
52
+ max_num_refs (int, optional): The maximum number of references that can be included in the paper.
53
+ Defaults to 50.
54
  bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
55
 
56
  Returns:
 
115
  def generate_draft(title, description="", template="ICLR2022",
116
  tldr=True, max_kw_refs=10, max_num_refs=30, sections=None, bib_refs=None, model="gpt-4"):
117
  # pre-processing `sections` parameter;
118
+ print("================START================")
119
+ print(f"Generating {title}.")
120
  print("================PRE-PROCESSING================")
121
  if sections is None:
122
  sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
123
 
124
  # todo: add more parameters; select which section to generate; select maximum refs.
125
+ if model == "gpt-4":
126
+ max_tokens = 4096
127
+ else:
128
+ max_tokens = 2048
129
+ paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, max_num_refs, bib_refs, max_tokens=max_tokens)
130
 
131
  # main components
132
+ print(f"================PROCESSING================")
133
  for section in sections:
134
+ print(f"Generate {section} part...")
135
  max_attempts = 4
136
  attempts_count = 0
137
  while attempts_count < max_attempts:
138
  try:
139
  usage = section_generation(paper, section, destination_folder, model=model)
140
+ print(f"{section} part has been generated. ")
141
  log_usage(usage, section)
142
  break
143
  except Exception as e:
 
165
  import openai
166
  openai.api_key = os.getenv("OPENAI_API_KEY")
167
 
168
+ target_title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
169
+ target_description = ""
170
+ output = generate_draft(target_title, target_description, tldr=True, max_kw_refs=10)
171
  print(output)
utils/references.py CHANGED
@@ -334,8 +334,12 @@ class References:
334
  prompts = {}
335
  tokens = 0
336
  for paper in result:
337
- prompts[paper["paper_id"]] = paper["abstract"]
338
- tokens += tiktoken_len(paper["abstract"])
 
 
 
 
339
  if tokens >= max_tokens:
340
  break
341
  return prompts
 
334
  prompts = {}
335
  tokens = 0
336
  for paper in result:
337
+ abstract = paper.get("abstract")
338
+ if abstract is not None and isinstance(abstract, str):
339
+ prompts[paper["paper_id"]] = paper["abstract"]
340
+ tokens += tiktoken_len(paper["abstract"])
341
+ else:
342
+ prompts[paper["paper_id"]] = " "
343
  if tokens >= max_tokens:
344
  break
345
  return prompts
worker.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This script is only used for service-side host.
3
+ '''
4
+ import boto3
5
+ import os, time
6
+ from api_wrapper import generator_wrapper
7
+ from sqlalchemy import create_engine, Table, MetaData, update, select
8
+ from sqlalchemy.orm import sessionmaker
9
+ from sqlalchemy import inspect
10
+
11
# Service configuration — all values come from the environment; any of them
# may be None if the variable is unset (create_engine below would then fail).
QUEUE_URL = os.getenv('QUEUE_URL')
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
BUCKET_NAME = os.getenv('BUCKET_NAME')
DB_STRING = os.getenv('DATABASE_STRING')

# Create engine
# NOTE: engine/session factory are built at import time, so importing this
# module requires DATABASE_STRING to be set.
ENGINE = create_engine(DB_STRING)
SESSION = sessionmaker(bind=ENGINE)
20
+
21
+
22
+ #######################################################################################################################
23
+ # Amazon SQS Handler
24
+ #######################################################################################################################
25
def get_sqs_client():
    """Build a boto3 SQS client using the module-level AWS credentials."""
    return boto3.client(
        'sqs',
        region_name="us-east-2",
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
30
+
31
+
32
def receive_message():
    """Poll the task queue once.

    Returns:
        A ``(response, receipt_handle)`` pair: the raw SQS response and the
        receipt handle of the first message, or ``None`` when no message
        was returned.
    """
    response = get_sqs_client().receive_message(QueueUrl=QUEUE_URL)
    messages = response.get('Messages')
    handle = messages[0]['ReceiptHandle'] if messages is not None else None
    return response, handle
40
+
41
+
42
def delete_message(receipt_handle):
    """Acknowledge (remove) a queue message identified by its receipt handle."""
    return get_sqs_client().delete_message(
        QueueUrl=QUEUE_URL, ReceiptHandle=receipt_handle
    )
46
+
47
+
48
+ #######################################################################################################################
49
+ # AWS S3 Handler
50
+ #######################################################################################################################
51
def get_s3_client():
    """Create an S3 resource plus a handle to the working bucket.

    Returns:
        ``(s3, bucket)``: the boto3 S3 resource and the Bucket named by
        ``BUCKET_NAME``.
    """
    # Consistency fix: reuse the module-level credentials (as get_sqs_client
    # does) instead of re-reading the environment variables here.
    session = boto3.Session(
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    )
    s3 = session.resource('s3')
    bucket = s3.Bucket(BUCKET_NAME)
    return s3, bucket
61
+
62
+
63
def upload_file(file_name, target_name=None):
    """Upload a local file to the S3 bucket.

    Args:
        file_name: Local path of the file to upload.
        target_name: S3 key to store it under; defaults to ``file_name``.
    """
    if target_name is None:
        target_name = file_name
    s3, _ = get_s3_client()
    s3.meta.client.upload_file(Filename=file_name, Bucket=BUCKET_NAME, Key=target_name)
    print(f"The file {file_name} has been uploaded!")
70
+
71
+
72
def download_file(file_name):
    """Download the S3 object ``file_name`` into the current directory.

    The local target is the basename of the key; the bucket is the
    module-level ``BUCKET_NAME``.
    """
    local_name = os.path.basename(file_name)
    s3, _ = get_s3_client()
    s3.meta.client.download_file(Bucket=BUCKET_NAME, Key=file_name, Filename=local_name)
    print(f"The file {file_name} has been downloaded!")
81
+
82
+
83
+ #######################################################################################################################
84
+ # AWS SQL Handler
85
+ #######################################################################################################################
86
def modify_status(task_id, new_status):
    """Set the ``status`` column of task ``task_id`` in the ``task`` table.

    Status codes: 0 - pending, 1 - running, 2 - completed, 3 - failed.
    Silently does nothing if the task id does not exist.

    Args:
        task_id: Primary identifier of the task row.
        new_status (int): New status code to write.
    """
    # Fixes: the original created (and closed) an ORM session it never used —
    # the function works entirely on a Core connection — and carried stale
    # copy-pasted comments about users/passwords.
    metadata = MetaData()
    task_table = Table('task', metadata, autoload_with=ENGINE)
    select_stmt = select(task_table).where(task_table.c.task_id == task_id)
    with ENGINE.connect() as connection:
        # Verify the task exists before issuing the update.
        task_data = connection.execute(select_stmt).fetchone()
        if task_data:
            update_stmt = (
                update(task_table)
                .where(task_table.c.task_id == task_id)
                .values(status=new_status)
            )
            connection.execute(update_stmt)
            connection.commit()
112
+
113
+ #######################################################################################################################
114
+ # Pipline
115
+ #######################################################################################################################
116
def pipeline(message_count=0, query_interval=10):
    """Run one iteration of the worker loop.

    Polls SQS for a task, downloads its configuration from S3, runs the
    generator, uploads the resulting ZIP back to S3, and keeps the task's
    DB status in sync.

    Status codes: 0 - pending (default), 1 - running, 2 - completed,
    3 - failed.

    Args:
        message_count: Counter used only in log output.
        query_interval: Seconds to sleep when the queue is empty.
    """
    # Query a message from SQS
    msg, handle = receive_message()
    if handle is None:
        print("No message in SQS. ")
        time.sleep(query_interval)
    else:
        print("===============================================================================================")
        print(f"MESSAGE COUNT: {message_count}")
        print("===============================================================================================")
        config_s3_path = msg['Messages'][0]['Body']
        config_s3_dir = os.path.dirname(config_s3_path)
        config_local_path = os.path.basename(config_s3_path)
        task_id, _ = os.path.splitext(config_local_path)

        print("Initializing ...")
        print("Configuration file on S3: ", config_s3_path)
        print("Configuration file on S3 (Directory): ", config_s3_dir)
        print("Local file path: ", config_local_path)
        print("Task id: ", task_id)

        print(f"Success in receiving message: {msg}")
        print(f"Configuration file path: {config_s3_path}")

        # Process the downloaded configuration file
        download_file(config_s3_path)
        modify_status(task_id, 1)  # mark running
        delete_message(handle)
        print(f"Success in the initialization. Message deleted.")

        print("Running ...")
        # Fix: the original left this try commented out, so a failing generator
        # left the task stuck in "running" forever. Per the api_wrapper
        # contract, any error must flip the task status to failed.
        try:
            zip_path = generator_wrapper(config_local_path)
            # Upload the generated file to S3
            upload_to = os.path.join(config_s3_dir, zip_path).replace("\\", "/")

            print("Local file path (ZIP): ", zip_path)
            print("Upload to S3: ", upload_to)
            upload_file(zip_path, upload_to)
        except Exception as e:
            modify_status(task_id, 3)  # mark failed
            print(f"Task failed: {e}")
        else:
            modify_status(task_id, 2)  # mark completed
            print(f"Success in generating the paper.")

        # Complete.
        print("Task completed.")
162
+
163
+
164
def initialize_everything():
    """Reset external state (S3 bucket and SQS queue) — not yet implemented."""
    # Clear S3

    # Clear SQS
    pass
169
+
170
+
171
+ if __name__ == "__main__":
172
+ pipeline()