Spaces:
Paused
Paused
File size: 3,962 Bytes
9189e38 ee698cf 9189e38 ee698cf be02bbe 22ad617 9189e38 be02bbe 9189e38 22ad617 ee698cf 22ad617 9189e38 be02bbe 6a1a72f 9189e38 6a1a72f 9189e38 be02bbe 6a1a72f 9189e38 be02bbe 22ad617 9189e38 22ad617 ee698cf 22ad617 9189e38 be02bbe fd50ca7 f0a2141 fd50ca7 9189e38 be02bbe ee698cf be02bbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
import csv
import json
import psycopg2
import pandas as pd
from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
from psycopg2.extras import DictCursor
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from db.db_utils import get_connection
# Read settings such as the API keys from a local .env file into os.environ.
load_dotenv()

# Backend switch: False routes requests to Mistral, True to OpenAI. The rest of
# the script branches on this flag for both the API call and the DB column.
use_openai = False
api_key = os.getenv("OPENAI_API_KEY")

if not use_openai:
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    client = MistralClient(api_key=mistral_api_key)
else:
    client = OpenAI(api_key=api_key)

# DictCursor rows can be indexed by column name (row['fdc_id'], ...).
db_conn = get_connection()
db_cursor = db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor)

# Reference table of WWEIA food categories, loaded from a CSV file.
file_path = (
    './dictionary/'
    'final_corrected_wweia_food_category_complete - '
    'final_corrected_wweia_food_category_complete.csv'
)
spreadsheet = pd.read_csv(file_path)
def find_best_category(food_item, category, dataframe):
    """Ask the configured chat model to pick a WWEIA category for one food.

    Candidate options are the ``wweia_food_category_description`` values of
    the rows in ``dataframe`` whose ``closest_category`` equals ``category``;
    if no row matches, every row's description is offered instead.

    Args:
        food_item: Description of the food to classify.
        category: The food's existing category, used to narrow the options.
        dataframe: Reference table with ``closest_category`` and
            ``wweia_food_category_description`` columns.

    Returns:
        The ``guess`` extracted by ``parse_response`` from the model reply,
        or ``None`` when the reply could not be parsed.
    """
    filtered_df = dataframe[dataframe['closest_category'] == category]
    if filtered_df.empty:
        # Unknown or unmapped category: let the model choose from everything.
        filtered_df = dataframe
    # Drop missing and repeated descriptions so the model never sees "nan"
    # or the same option twice; first-appearance order is preserved.
    descriptions = (
        filtered_df['wweia_food_category_description']
        .dropna()
        .drop_duplicates()
        .tolist()
    )
    prompt = (
        f"Given the food item '{food_item}' and the category '{category}', "
        "choose the most appropriate category from the following options:\n"
        f"{descriptions}\n\n"
        "You should respond in json format with an object that has the key "
        "`guess`, and the value is the category, copied exactly from the "
        "options above."
    )
    if use_openai:
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            model="gpt-3.5-turbo-1106",
            # NOTE: JSON mode -- keep the word "json" in the prompt; OpenAI's
            # docs require the messages themselves to ask for JSON output.
            response_format={"type": "json_object"},
        )
    else:
        completion = client.chat(
            # "mistral-large-latest" is the larger alternative model.
            model="mistral-small-latest",
            response_format={"type": "json_object"},
            messages=[ChatMessage(role="user", content=prompt)],
        )
    response = completion.choices[0].message.content
    return parse_response(response)
def parse_response(response):
    """Extract the ``guess`` value from a model's JSON reply.

    Args:
        response: Raw message content from the chat API. ``None`` or other
            non-text values are tolerated and treated as a parse failure.

    Returns:
        The ``guess`` string with surrounding whitespace removed, or ``None``
        when the reply is not a JSON object, has no ``guess`` key, or the
        guess is not a non-empty string. Every failure is printed.
    """
    try:
        result = json.loads(response)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError: json.loads rejects None and other non-text inputs.
        print(f"Error parsing response: {response} - {e}")
        return None
    # Valid JSON can still be a list/str/number; only an object has keys.
    if not isinstance(result, dict) or 'guess' not in result:
        print(f"Error parsing response: {response} - missing 'guess' key")
        return None
    guess = result['guess']
    # The guess is written to a text column, so reject non-strings and blanks.
    if not isinstance(guess, str) or not guess.strip():
        print(f"Error parsing response: {response} - 'guess' is not a non-empty string")
        return None
    return guess.strip()
# Categorise every dictionary row that the active backend has not handled yet,
# writing each model guess back to that backend's own column.
from psycopg2 import sql

# Remaps from the dictionary's food_category to the label that
# find_best_category matches against the sheet's closest_category column.
# TODO: decide remappings for 'Fast Foods', 'American Indian/Alaska Native Foods',
# 'Restaurant Foods' and 'Spices and Herbs'.
CATEGORY_OVERRIDES = {
    'Breakfast Cereals': 'Cereal Grains and Pasta',
}

# Each backend has its own result column, so choose it once and compose it
# as a quoted identifier rather than duplicating every query per backend.
category_column = sql.Identifier(
    'wweia_category' if use_openai else 'wweia_category_mistral'
)
select_query = sql.SQL('SELECT * FROM dictionary WHERE {} IS NULL').format(category_column)
update_query = sql.SQL('UPDATE dictionary SET {} = %s WHERE fdc_id = %s').format(category_column)

try:
    # Only uncategorised rows are fetched, so a re-run resumes where it stopped.
    db_cursor.execute(select_query)
    rows = db_cursor.fetchall()
    for row in tqdm(rows, desc="Processing"):
        print()
        fdc_id = row['fdc_id']
        food_item = row['description']
        category = CATEGORY_OVERRIDES.get(row['food_category'], row['food_category'])
        print(f"Processing '{food_item}'")
        best_category = find_best_category(food_item, category, spreadsheet)
        print(f"Q: '{food_item}'")
        print(f"A: '{best_category}'")
        if best_category:
            db_cursor.execute(update_query, (best_category, fdc_id))
            # Commit per row so completed work survives a crash mid-run.
            db_conn.commit()
        else:
            print(f"Failed to find a category for '{food_item}'")
finally:
    # Release DB resources even if an API or DB call raises.
    db_cursor.close()
    db_conn.close()
|