alfraser committed on
Commit
acb7b9c
·
1 Parent(s): 51d7c53

Tidied up comments

Browse files
Files changed (1) hide show
  1. src/data_synthesis/generate_data.py +53 -1
src/data_synthesis/generate_data.py CHANGED
@@ -15,12 +15,18 @@ from src.common import data_dir
15
 
16
 
17
  class Review:
 
 
 
18
  def __init__(self, stars: int, review_text: str):
19
  self.stars = stars
20
  self.review_text = review_text
21
 
22
 
23
  class Product:
 
 
 
24
  def __init__(self, category: str, name: str, description: str, price: float, features: List[str], reviews: List[Review]):
25
  self.category = category
26
  self.name = name
@@ -32,7 +38,7 @@ class Product:
32
 
33
  class DataPrompt:
34
  """
35
- Holder for static prompt generation functions
36
  """
37
  @staticmethod
38
  def prompt_setup() -> str:
@@ -94,6 +100,9 @@ Please format the response as json in this style:
94
 
95
 
96
  def generate_products(category: str, features: List[str], k: int = 20):
 
 
 
97
  prompt = DataPrompt.products_for_category(category, features, k)
98
  response = openai.ChatCompletion.create(
99
  model="gpt-3.5-turbo-16k",
@@ -108,16 +117,27 @@ def generate_products(category: str, features: List[str], k: int = 20):
108
 
109
 
110
  def category_product_file(category: str) -> str:
 
 
 
111
  output_file_name = f"products_{category.lower().replace(' ', '_')}.json"
112
  return os.path.join(data_dir, 'json', output_file_name)
113
 
114
 
115
  def category_review_file(category: str) -> str:
 
 
 
116
  output_file_name = f"reviews_{category.lower().replace(' ', '_')}.json"
117
  return os.path.join(data_dir, 'json', output_file_name)
118
 
119
 
120
  def products_for_category(category: str) -> List[Product]:
 
 
 
 
 
121
  cat_file = category_product_file(category)
122
  if not os.path.exists(cat_file):
123
  return []
@@ -141,6 +161,10 @@ def products_for_category(category: str) -> List[Product]:
141
 
142
 
143
  def product_names_for_category(category: str) -> List[str]:
 
 
 
 
144
  cat_file = category_product_file(category)
145
  if not os.path.exists(cat_file):
146
  return []
@@ -154,6 +178,10 @@ def product_names_for_category(category: str) -> List[str]:
154
 
155
 
156
  def add_products(category: str, product_json: str, k: int) -> None:
 
 
 
 
157
  cat_file = category_product_file(category)
158
  if not os.path.exists(cat_file):
159
  with open(cat_file, 'w') as f:
@@ -173,6 +201,10 @@ def add_products(category: str, product_json: str, k: int) -> None:
173
 
174
 
175
  def get_categories_and_features() -> Dict[str, List[str]]:
 
 
 
 
176
  product_features_file = os.path.join(data_dir, 'json', 'product_features.json')
177
  cats_and_feats = {}
178
  with open(product_features_file, 'r') as f:
@@ -185,6 +217,10 @@ def get_categories_and_features() -> Dict[str, List[str]]:
185
 
186
 
187
  def generate_all_products(target_count=40):
 
 
 
 
188
  product_features_file = os.path.join(data_dir, 'product_features.json')
189
 
190
  with open(product_features_file, 'r') as f:
@@ -202,6 +238,9 @@ def generate_all_products(target_count=40):
202
 
203
 
204
  def dump_products_to_csv():
 
 
 
205
  cats = get_categories_and_features().keys()
206
  cat_keys = []
207
  for cat in cats:
@@ -213,11 +252,17 @@ def dump_products_to_csv():
213
 
214
 
215
  def generate_reviews(target_count: int):
 
 
 
216
  for cat in get_categories_and_features().keys():
217
  generate_reviews_for_category(cat, target_count)
218
 
219
 
220
  def generate_reviews_for_category(category: str, target_count: int):
 
 
 
221
  batch_size = 25 # Max number of reviews to request in one go from GPT
222
 
223
  # Set up a loop to continue trying to find more work to do until complete
@@ -249,6 +294,9 @@ def generate_reviews_for_category(category: str, target_count: int):
249
 
250
 
251
  def generate_reviews_for_product(product: Product, k: int):
 
 
 
252
  prompt = DataPrompt.reviews_for_product(product, k)
253
  response = openai.ChatCompletion.create(
254
  model="gpt-3.5-turbo-16k",
@@ -263,6 +311,10 @@ def generate_reviews_for_product(product: Product, k: int):
263
 
264
 
265
  def add_reviews_to_product(reviews_json: str, product: Product):
 
 
 
 
266
  reviews_json = json.loads(reviews_json)
267
  reviews_file = category_review_file(product.category)
268
  if not os.path.exists(reviews_file):
 
15
 
16
 
17
  class Review:
18
+ """
19
+ Simple representation of a user Review of a Product
20
+ """
21
  def __init__(self, stars: int, review_text: str):
22
  self.stars = stars
23
  self.review_text = review_text
24
 
25
 
26
  class Product:
27
+ """
28
  Simple representation of a product
29
+ """
30
  def __init__(self, category: str, name: str, description: str, price: float, features: List[str], reviews: List[Review]):
31
  self.category = category
32
  self.name = name
 
38
 
39
  class DataPrompt:
40
  """
41
+ Holder for static prompt generation functions for the data generation process
42
  """
43
  @staticmethod
44
  def prompt_setup() -> str:
 
100
 
101
 
102
  def generate_products(category: str, features: List[str], k: int = 20):
103
+ """
104
+ Invoke GPT3.5 Turbo model and get it to generate some products based on a category
105
+ """
106
  prompt = DataPrompt.products_for_category(category, features, k)
107
  response = openai.ChatCompletion.create(
108
  model="gpt-3.5-turbo-16k",
 
117
 
118
 
119
  def category_product_file(category: str) -> str:
120
+ """
121
+ Utility to get the file containing products in a category
122
+ """
123
  output_file_name = f"products_{category.lower().replace(' ', '_')}.json"
124
  return os.path.join(data_dir, 'json', output_file_name)
125
 
126
 
127
  def category_review_file(category: str) -> str:
128
+ """
129
+ Utility to get the file containing reviews of products in a category
130
+ """
131
  output_file_name = f"reviews_{category.lower().replace(' ', '_')}.json"
132
  return os.path.join(data_dir, 'json', output_file_name)
133
 
134
 
135
  def products_for_category(category: str) -> List[Product]:
136
+ """
137
+ Load all the associated products which have been generated for this
138
+ category, and the reviews, then merge the two and return a list of
139
+ all the products in this category along with their reviews
140
+ """
141
  cat_file = category_product_file(category)
142
  if not os.path.exists(cat_file):
143
  return []
 
161
 
162
 
163
  def product_names_for_category(category: str) -> List[str]:
164
+ """
165
+ Get a list of just the names of the products in this category
166
+ from the generated product json file
167
+ """
168
  cat_file = category_product_file(category)
169
  if not os.path.exists(cat_file):
170
  return []
 
178
 
179
 
180
  def add_products(category: str, product_json: str, k: int) -> None:
181
+ """
182
+ Given a string of json representing newly generated products,
183
+ add those products to the existing product json file for this category
184
+ """
185
  cat_file = category_product_file(category)
186
  if not os.path.exists(cat_file):
187
  with open(cat_file, 'w') as f:
 
201
 
202
 
203
  def get_categories_and_features() -> Dict[str, List[str]]:
204
+ """
205
+ Get a dictionary with each category as a key and the list of available
206
  features for products in that category as the value
207
+ """
208
  product_features_file = os.path.join(data_dir, 'json', 'product_features.json')
209
  cats_and_feats = {}
210
  with open(product_features_file, 'r') as f:
 
217
 
218
 
219
  def generate_all_products(target_count=40):
220
+ """
221
+ Generate all products for all categories, trying to reach a given target count
222
+ of products.
223
+ """
224
  product_features_file = os.path.join(data_dir, 'product_features.json')
225
 
226
  with open(product_features_file, 'r') as f:
 
238
 
239
 
240
  def dump_products_to_csv():
241
+ """
242
+ Dump a csv file for debug, for every product showing category name and product name
243
+ """
244
  cats = get_categories_and_features().keys()
245
  cat_keys = []
246
  for cat in cats:
 
252
 
253
 
254
  def generate_reviews(target_count: int):
255
+ """
256
+ Generate reviews for each category up to a target count of reviews
257
+ """
258
  for cat in get_categories_and_features().keys():
259
  generate_reviews_for_category(cat, target_count)
260
 
261
 
262
  def generate_reviews_for_category(category: str, target_count: int):
263
+ """
264
+ Generate reviews for a specific category up to a given target number of reviews
265
+ """
266
  batch_size = 25 # Max number of reviews to request in one go from GPT
267
 
268
  # Set up a loop to continue trying to find more work to do until complete
 
294
 
295
 
296
  def generate_reviews_for_product(product: Product, k: int):
297
+ """
298
+ Generate a number of reviews from GPT3.5 for a specific product and add them to the product
299
+ """
300
  prompt = DataPrompt.reviews_for_product(product, k)
301
  response = openai.ChatCompletion.create(
302
  model="gpt-3.5-turbo-16k",
 
311
 
312
 
313
  def add_reviews_to_product(reviews_json: str, product: Product):
314
+ """
315
+ Load the reviews file containing this product category, append this review to the list and
316
+ re-save the file
317
+ """
318
  reviews_json = json.loads(reviews_json)
319
  reviews_file = category_review_file(product.category)
320
  if not os.path.exists(reviews_file):