Spaces:
Paused
Paused
import os | |
import csv | |
import json | |
import psycopg2 | |
import pandas as pd | |
from tqdm import tqdm | |
from openai import OpenAI | |
from dotenv import load_dotenv | |
from psycopg2.extras import DictCursor | |
from mistralai.client import MistralClient | |
from mistralai.models.chat_completion import ChatMessage | |
from db.db_utils import get_connection | |
# Populate the process environment (python-dotenv) before any key is read.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")

# Provider switch: True talks to OpenAI, False talks to Mistral.
use_openai = True
if not use_openai:
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    client = MistralClient(api_key=mistral_api_key)
else:
    client = OpenAI(api_key=api_key)

# Rows fetched through this cursor are indexed by column name further down.
db_conn = get_connection()
db_cursor = db_conn.cursor(cursor_factory=DictCursor)

# Load the WWEIA food category table from its CSV export.
file_path = (
    './dictionary/'
    'final_corrected_wweia_food_category_complete - '
    'final_corrected_wweia_food_category_complete.csv'
)
spreadsheet = pd.read_csv(file_path)
def find_best_category(food_item, category, dataframe):
    """Ask the configured chat model for the best-matching category description.

    Every value in the ``wweia_food_category_description`` column of
    ``dataframe`` is offered to the model as a candidate. ``category`` is
    only quoted in the prompt as a hint; it does not narrow the candidates.

    Args:
        food_item: Description of the food to classify.
        category: Existing classification of the food, quoted in the prompt.
        dataframe: Table providing the ``wweia_food_category_description``
            column.

    Returns:
        Whatever ``parse_response`` extracts from the model's reply.
    """
    options = dataframe['wweia_food_category_description'].tolist()

    # The pieces are joined without separators; each one carries its own
    # trailing newlines.
    prompt_parts = [
        f"Given the food item '{food_item}' and the classification of '{category}', choose the most appropriate category from the following options:\n{options}\n\n",
        "Only respond with a category from the above. Do not come up with a new category. Do not respond with 'Legumes and Legume Products'.\n\n",
        "You should respond in json format with an object that has the key `guess`, and the value is the categoy.",
    ]
    prompt = "".join(prompt_parts)

    # Both providers are asked for a JSON object reply.
    json_mode = {"type": "json_object"}

    if not use_openai:
        completion = client.chat(
            model="mistral-small-latest",
            response_format=json_mode,
            messages=[ChatMessage(role="user", content=prompt)],
        )
    else:
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        completion = client.chat.completions.create(
            messages=conversation,
            model="gpt-3.5-turbo-1106",
            response_format=json_mode,
        )

    reply = completion.choices[0].message.content
    return parse_response(reply)
# Define the function to parse the model response
def parse_response(response):
    """Extract the ``guess`` value from a JSON-formatted model reply.

    Args:
        response: Raw message content returned by the chat model. Expected
            to be a JSON object string containing a ``guess`` key.

    Returns:
        The value stored under ``guess``, or ``None`` when the reply is
        missing (``None``), is not valid JSON, is not a JSON object, or has
        no ``guess`` key. A diagnostic line is printed in those cases.
    """
    try:
        result = json.loads(response)
        return result['guess']
    # TypeError covers a ``None`` reply (json.loads rejects non-string input)
    # and valid JSON that is not an object (list, string, number), where the
    # ``['guess']`` lookup fails with TypeError rather than KeyError.
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        print(f"Error parsing response: {response} - {e}")
        return None
# Categorize every dictionary row that still lacks a category for the active
# provider. Each provider writes to its own column, so the matching SELECT and
# UPDATE statements are chosen once, up front.
if use_openai:
    select_sql = 'SELECT * FROM dictionary where wweia_category is null'
    update_sql = 'UPDATE dictionary SET wweia_category = %s WHERE fdc_id = %s'
else:
    select_sql = 'SELECT * FROM dictionary where wweia_category_mistral is null'
    update_sql = 'UPDATE dictionary SET wweia_category_mistral = %s WHERE fdc_id = %s'

# Build the set of acceptable answers once instead of rescanning the column
# for every row.
valid_categories = set(spreadsheet['wweia_food_category_description'])

try:
    db_cursor.execute(select_sql)
    rows = db_cursor.fetchall()

    for row in tqdm(rows, desc="Processing"):
        print()
        fdc_id = row['fdc_id']
        food_item = row['description']
        category = row['foundation_category']
        print(f"Processing '{food_item}'")

        # TODO: some foundation categories may need remapping before being
        # passed on (e.g. 'Breakfast Cereals' -> 'Cereal Grains and Pasta';
        # still undecided for 'Fast Foods', 'American Indian/Alaska Native
        # Foods', 'Restaurant Foods' and 'Spices and Herbs').

        # Find the best category for the food item
        best_category = find_best_category(food_item, category, spreadsheet)
        print(f"Q: '{food_item}'")
        print(f"A: '{best_category}'")

        if not best_category:
            print(f"Failed to find a category for '{food_item}'")
            continue

        # Only store answers that are actual entries of the category table.
        # The isinstance guard rejects non-string guesses (e.g. a JSON list)
        # without raising on an unhashable value.
        if not isinstance(best_category, str) or best_category not in valid_categories:
            print(f"Error: '{best_category}' not found in the spreadsheet")
            continue

        db_cursor.execute(update_sql, (best_category, fdc_id))
        # Commit per row so rows already categorized are kept even if a later
        # iteration fails.
        db_conn.commit()
finally:
    # Release the cursor and connection even when an API or database error
    # aborts the loop.
    db_cursor.close()
    db_conn.close()