Spaces:
Paused
Paused
Commit
•
21fcd81
1
Parent(s):
bc3ba53
added in batch inserts
Browse files
- algo.py +4 -11
- db/db_utils.py +68 -2
- tasks.py +0 -19
algo.py
CHANGED
@@ -14,17 +14,10 @@ from db.db_utils import store_mapping_to_db, cached_get_mapping_from_db, get_dic
|
|
14 |
from ask_gpt import query_gpt
|
15 |
from multi_food_item_detector import extract_items, has_delimiters
|
16 |
from mapping_template import empty_template, heterogeneous_template, multi_item_template, nonfood_template, usda_template
|
17 |
-
# from tasks import insert_result
|
18 |
from specificity_classifier import classify_text_to_specificity
|
19 |
|
20 |
-
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
|
21 |
similarity_threshold = 0.78
|
22 |
-
|
23 |
-
|
24 |
-
def insert_result(db_conn, run_key, mappings):
|
25 |
-
db_cursor = db_conn.cursor()
|
26 |
-
for mapping in mappings:
|
27 |
-
store_result_to_db(db_cursor, db_conn, run_key, mapping)
|
28 |
|
29 |
|
30 |
class Algo:
|
@@ -347,12 +340,12 @@ class Algo:
|
|
347 |
# store_result_to_db(self.db_cursor, self.db_conn, self.run_key, mapping)
|
348 |
results.append(mapping)
|
349 |
|
350 |
-
if len(result_batch) >=
|
351 |
-
|
352 |
result_batch = []
|
353 |
|
354 |
if len(result_batch) > 0:
|
355 |
-
|
356 |
result_batch = []
|
357 |
|
358 |
|
|
|
14 |
from ask_gpt import query_gpt
|
15 |
from multi_food_item_detector import extract_items, has_delimiters
|
16 |
from mapping_template import empty_template, heterogeneous_template, multi_item_template, nonfood_template, usda_template
|
|
|
17 |
from specificity_classifier import classify_text_to_specificity
|
18 |
|
|
|
19 |
similarity_threshold = 0.78
|
20 |
+
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
class Algo:
|
|
|
340 |
# store_result_to_db(self.db_cursor, self.db_conn, self.run_key, mapping)
|
341 |
results.append(mapping)
|
342 |
|
343 |
+
if len(result_batch) >= 500:
|
344 |
+
store_batch_results_to_db(self.db_conn, self.run_key, result_batch)
|
345 |
result_batch = []
|
346 |
|
347 |
if len(result_batch) > 0:
|
348 |
+
store_batch_results_to_db(self.db_conn, self.run_key, result_batch)
|
349 |
result_batch = []
|
350 |
|
351 |
|
db/db_utils.py
CHANGED
@@ -3,9 +3,10 @@ import psycopg2
|
|
3 |
import logging
|
4 |
from dotenv import load_dotenv
|
5 |
from functools import lru_cache
|
|
|
6 |
|
7 |
-
load_dotenv()
|
8 |
|
|
|
9 |
|
10 |
def get_connection():
|
11 |
DATABASE_URL = os.environ['DATABASE_URL']
|
@@ -19,7 +20,6 @@ def get_connection():
|
|
19 |
print(f"Failed to connect to database: {e}")
|
20 |
raise
|
21 |
|
22 |
-
|
23 |
def initialize_db(conn):
|
24 |
cursor = conn.cursor()
|
25 |
cursor.execute('''
|
@@ -215,3 +215,69 @@ def store_result_to_db(cursor, conn, run_key, result):
|
|
215 |
|
216 |
conn.commit()
|
217 |
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import logging
|
4 |
from dotenv import load_dotenv
|
5 |
from functools import lru_cache
|
6 |
+
from psycopg2.extras import execute_values
|
7 |
|
|
|
8 |
|
9 |
+
load_dotenv()
|
10 |
|
11 |
def get_connection():
|
12 |
DATABASE_URL = os.environ['DATABASE_URL']
|
|
|
20 |
print(f"Failed to connect to database: {e}")
|
21 |
raise
|
22 |
|
|
|
23 |
def initialize_db(conn):
|
24 |
cursor = conn.cursor()
|
25 |
cursor.execute('''
|
|
|
215 |
|
216 |
conn.commit()
|
217 |
return True
|
218 |
+
|
219 |
+
|
220 |
+
def store_batch_results_to_db(conn, run_key, results):
    """Bulk-upsert a batch of mapping results for one run.

    Builds one row tuple per mapping dict and writes the whole batch in a
    single ``execute_values`` round trip; on a ``(run_key, run_row)``
    conflict the existing row is overwritten with the new values.

    Args:
        conn: open psycopg2 connection. A cursor is created locally —
            the call sites in algo.py invoke this function as
            ``store_batch_results_to_db(self.db_conn, self.run_key, batch)``,
            passing only the connection.
        run_key: identifier of the processing run these rows belong to.
        results: iterable of dicts keyed by the column names listed in the
            INSERT statement below.

    Returns:
        True once the batch has been committed.
    """
    # FIX(review): the signature was (cursor, conn, run_key, results) while
    # every caller passes (conn, run_key, batch) — that shifted all
    # arguments by one and raised TypeError for the missing parameter.
    # Accept the connection only and open the cursor here instead.
    db_cursor = conn.cursor()

    # Tuple order must match the column list in the INSERT statement.
    values = [
        (
            run_key,
            result['run_row'],
            result['date'],
            result['input_word'],
            result['dictionary_word'],
            result['is_food'],
            result['sr_legacy_food_category'],
            result['wweia_category'],
            result['dry_matter_content'],
            result['leakage'],
            result['weight'],
            result['weight_metric_tonnes'],
            result['donor'],
            result['similarity_score'],
            result['food_nonfood_score'],
            result['distance'],
            result['ef'],
            result['mt_lb_mile'],
            result['baseline_emissions'],
            result['leakage_emissions'],
            result['project_emissions'],
            result['total_emissions_reduction']
        )
        for result in results
    ]

    insert_query = '''
        INSERT INTO results (
            run_key, run_row, date, input_word, dictionary_word, is_food,
            sr_legacy_food_category, wweia_category, dry_matter_content, leakage,
            weight, weight_metric_tonnes, donor, similarity_score, food_nonfood_score,
            distance, ef, mt_lb_mile, baseline_emissions, leakage_emissions,
            project_emissions, total_emissions_reduction
        ) VALUES %s
        ON CONFLICT (run_key, run_row)
        DO UPDATE SET
            date = EXCLUDED.date,
            input_word = EXCLUDED.input_word,
            dictionary_word = EXCLUDED.dictionary_word,
            is_food = EXCLUDED.is_food,
            sr_legacy_food_category = EXCLUDED.sr_legacy_food_category,
            wweia_category = EXCLUDED.wweia_category,
            dry_matter_content = EXCLUDED.dry_matter_content,
            leakage = EXCLUDED.leakage,
            weight = EXCLUDED.weight,
            weight_metric_tonnes = EXCLUDED.weight_metric_tonnes,
            donor = EXCLUDED.donor,
            similarity_score = EXCLUDED.similarity_score,
            food_nonfood_score = EXCLUDED.food_nonfood_score,
            distance = EXCLUDED.distance,
            ef = EXCLUDED.ef,
            mt_lb_mile = EXCLUDED.mt_lb_mile,
            baseline_emissions = EXCLUDED.baseline_emissions,
            leakage_emissions = EXCLUDED.leakage_emissions,
            project_emissions = EXCLUDED.project_emissions,
            total_emissions_reduction = EXCLUDED.total_emissions_reduction;
    '''

    # execute_values expands the single VALUES %s placeholder into one
    # multi-row statement — one round trip for the whole batch instead of
    # one INSERT per mapping (which was the point of this commit).
    execute_values(db_cursor, insert_query, values)
    conn.commit()
    return True
|
tasks.py
CHANGED
@@ -7,30 +7,11 @@ from algo import Algo
|
|
7 |
from dotenv import load_dotenv
|
8 |
from redis import Redis
|
9 |
from rq import Queue
|
10 |
-
# from celery import Celery
|
11 |
from db.db_utils import get_connection, store_result_to_db
|
12 |
|
13 |
load_dotenv()
|
14 |
|
15 |
-
# app = Celery('tasks', broker=REDIS_URL, backend=REDIS_URL)
|
16 |
|
17 |
-
# app.conf.update(
|
18 |
-
# result_expires=3600,
|
19 |
-
# task_serializer='json',
|
20 |
-
# result_serializer='json',
|
21 |
-
# accept_content=['json'],
|
22 |
-
# timezone='UTC',
|
23 |
-
# enable_utc=True,
|
24 |
-
# broker_connection_retry_on_startup=True
|
25 |
-
# )
|
26 |
-
|
27 |
-
# @app.task
|
28 |
-
# def insert_result(db_conn, run_key, mappings):
|
29 |
-
# db_cursor = db_conn.cursor()
|
30 |
-
# for mapping in mappings:
|
31 |
-
# store_result_to_db(db_cursor, db_conn, run_key, mapping)
|
32 |
-
|
33 |
-
# @app.task
|
34 |
def process_file(raw_file_name):
|
35 |
print(f"Processing {raw_file_name}")
|
36 |
if not raw_file_name.endswith('.csv'):
|
|
|
7 |
from dotenv import load_dotenv
|
8 |
from redis import Redis
|
9 |
from rq import Queue
|
|
|
10 |
from db.db_utils import get_connection, store_result_to_db
|
11 |
|
12 |
load_dotenv()
|
13 |
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def process_file(raw_file_name):
|
16 |
print(f"Processing {raw_file_name}")
|
17 |
if not raw_file_name.endswith('.csv'):
|