beweinreich committed
Commit 21fcd81
Parent: bc3ba53

added in batch inserts

Files changed (3):
  1. algo.py +4 -11
  2. db/db_utils.py +68 -2
  3. tasks.py +0 -19
algo.py CHANGED
@@ -14,17 +14,10 @@ from db.db_utils import store_mapping_to_db, cached_get_mapping_from_db, get_dic
 from ask_gpt import query_gpt
 from multi_food_item_detector import extract_items, has_delimiters
 from mapping_template import empty_template, heterogeneous_template, multi_item_template, nonfood_template, usda_template
-# from tasks import insert_result
 from specificity_classifier import classify_text_to_specificity
 
-logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
 similarity_threshold = 0.78
-
-
-def insert_result(db_conn, run_key, mappings):
-    db_cursor = db_conn.cursor()
-    for mapping in mappings:
-        store_result_to_db(db_cursor, db_conn, run_key, mapping)
+logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
 class Algo:
@@ -347,12 +340,12 @@ class Algo:
             # store_result_to_db(self.db_cursor, self.db_conn, self.run_key, mapping)
             results.append(mapping)
 
-            if len(result_batch) >= 100:
-                insert_result(self.db_conn, self.run_key, result_batch)
+            if len(result_batch) >= 500:
+                store_batch_results_to_db(self.db_cursor, self.db_conn, self.run_key, result_batch)
                 result_batch = []
 
         if len(result_batch) > 0:
-            insert_result(self.db_conn, self.run_key, result_batch)
+            store_batch_results_to_db(self.db_cursor, self.db_conn, self.run_key, result_batch)
             result_batch = []
 
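The loop above flushes result_batch every 500 mappings, then flushes once more after the loop so a final partial batch is not dropped. Batching matters here because the old insert_result called store_result_to_db once per mapping, committing each row individually; one multi-row statement per 500 rows cuts the round trips to Postgres. A minimal sketch of the accumulate-and-flush pattern, with a stand-in flush function in place of store_batch_results_to_db (names and data here are illustrative, not part of the commit):

    # Accumulate rows and flush in fixed-size batches.
    def flush(batch):
        print(f"flushing {len(batch)} rows")  # stands in for the real DB write

    def process(rows, batch_size=500):
        batch = []
        for row in rows:
            batch.append(row)
            if len(batch) >= batch_size:
                flush(batch)
                batch = []  # start a fresh batch after each flush
        if batch:  # don't drop the final partial batch
            flush(batch)

    process(range(1234))  # flushes 500, 500, then 234 rows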
db/db_utils.py CHANGED
@@ -3,9 +3,10 @@ import psycopg2
 import logging
 from dotenv import load_dotenv
 from functools import lru_cache
+from psycopg2.extras import execute_values
 
-load_dotenv()
 
+load_dotenv()
 
 def get_connection():
     DATABASE_URL = os.environ['DATABASE_URL']
@@ -19,7 +20,6 @@ def get_connection():
         print(f"Failed to connect to database: {e}")
         raise
 
-
 def initialize_db(conn):
     cursor = conn.cursor()
     cursor.execute('''
@@ -215,3 +215,69 @@ def store_result_to_db(cursor, conn, run_key, result):
 
     conn.commit()
     return True
+
+
+def store_batch_results_to_db(cursor, conn, run_key, results):
+    values = [
+        (
+            run_key,
+            result['run_row'],
+            result['date'],
+            result['input_word'],
+            result['dictionary_word'],
+            result['is_food'],
+            result['sr_legacy_food_category'],
+            result['wweia_category'],
+            result['dry_matter_content'],
+            result['leakage'],
+            result['weight'],
+            result['weight_metric_tonnes'],
+            result['donor'],
+            result['similarity_score'],
+            result['food_nonfood_score'],
+            result['distance'],
+            result['ef'],
+            result['mt_lb_mile'],
+            result['baseline_emissions'],
+            result['leakage_emissions'],
+            result['project_emissions'],
+            result['total_emissions_reduction']
+        )
+        for result in results
+    ]
+
+    insert_query = '''
+        INSERT INTO results (
+            run_key, run_row, date, input_word, dictionary_word, is_food,
+            sr_legacy_food_category, wweia_category, dry_matter_content, leakage,
+            weight, weight_metric_tonnes, donor, similarity_score, food_nonfood_score,
+            distance, ef, mt_lb_mile, baseline_emissions, leakage_emissions,
+            project_emissions, total_emissions_reduction
+        ) VALUES %s
+        ON CONFLICT (run_key, run_row)
+        DO UPDATE SET
+            date = EXCLUDED.date,
+            input_word = EXCLUDED.input_word,
+            dictionary_word = EXCLUDED.dictionary_word,
+            is_food = EXCLUDED.is_food,
+            sr_legacy_food_category = EXCLUDED.sr_legacy_food_category,
+            wweia_category = EXCLUDED.wweia_category,
+            dry_matter_content = EXCLUDED.dry_matter_content,
+            leakage = EXCLUDED.leakage,
+            weight = EXCLUDED.weight,
+            weight_metric_tonnes = EXCLUDED.weight_metric_tonnes,
+            donor = EXCLUDED.donor,
+            similarity_score = EXCLUDED.similarity_score,
+            food_nonfood_score = EXCLUDED.food_nonfood_score,
+            distance = EXCLUDED.distance,
+            ef = EXCLUDED.ef,
+            mt_lb_mile = EXCLUDED.mt_lb_mile,
+            baseline_emissions = EXCLUDED.baseline_emissions,
+            leakage_emissions = EXCLUDED.leakage_emissions,
+            project_emissions = EXCLUDED.project_emissions,
+            total_emissions_reduction = EXCLUDED.total_emissions_reduction;
+    '''
+
+    execute_values(cursor, insert_query, values)
+    conn.commit()
+    return True
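The new helper builds one parameter tuple per result and hands the whole list to psycopg2's execute_values, which expands the single VALUES %s placeholder into a multi-row INSERT (chunked 100 rows per statement by default via its page_size argument). A self-contained sketch of the same upsert technique; the table, columns, and connection string here are hypothetical, not the project's schema:

    import psycopg2
    from psycopg2.extras import execute_values

    conn = psycopg2.connect("dbname=example")  # hypothetical DSN
    cur = conn.cursor()
    cur.execute('''
        CREATE TABLE IF NOT EXISTS demo_results (
            run_key    TEXT,
            run_row    INTEGER,
            input_word TEXT,
            PRIMARY KEY (run_key, run_row)
        )
    ''')

    rows = [("run-1", i, f"item-{i}") for i in range(3)]
    execute_values(cur, '''
        INSERT INTO demo_results (run_key, run_row, input_word) VALUES %s
        ON CONFLICT (run_key, run_row)
        DO UPDATE SET input_word = EXCLUDED.input_word
    ''', rows)
    conn.commit()

Note that ON CONFLICT (run_key, run_row) only takes effect if the results table has a unique index or constraint on that column pair.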
tasks.py CHANGED
@@ -7,30 +7,11 @@ from algo import Algo
 from dotenv import load_dotenv
 from redis import Redis
 from rq import Queue
-# from celery import Celery
 from db.db_utils import get_connection, store_result_to_db
 
 load_dotenv()
 
-# app = Celery('tasks', broker=REDIS_URL, backend=REDIS_URL)
 
-# app.conf.update(
-#     result_expires=3600,
-#     task_serializer='json',
-#     result_serializer='json',
-#     accept_content=['json'],
-#     timezone='UTC',
-#     enable_utc=True,
-#     broker_connection_retry_on_startup=True
-# )
-
-# @app.task
-# def insert_result(db_conn, run_key, mappings):
-#     db_cursor = db_conn.cursor()
-#     for mapping in mappings:
-#         store_result_to_db(db_cursor, db_conn, run_key, mapping)
-
-# @app.task
 def process_file(raw_file_name):
     print(f"Processing {raw_file_name}")
     if not raw_file_name.endswith('.csv'):
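tasks.py keeps its Redis and RQ imports after dropping the commented-out Celery experiment, so process_file is presumably enqueued along these lines (the enqueue call is not shown in this commit; the queue setup and file name below are illustrative):

    from redis import Redis
    from rq import Queue

    from tasks import process_file

    q = Queue(connection=Redis())           # default queue on a local Redis
    q.enqueue(process_file, "example.csv")  # a worker runs process_file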