File size: 4,383 Bytes
9189e38
 
 
ee698cf
9189e38
 
 
 
ee698cf
be02bbe
 
22ad617
9189e38
 
 
 
394e2d6
be02bbe
 
 
 
 
 
9189e38
22ad617
ee698cf
22ad617
9189e38
 
 
 
 
394e2d6
 
 
 
9189e38
6a1a72f
 
 
394e2d6
 
6a1a72f
 
394e2d6
6a1a72f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9189e38
be02bbe
6a1a72f
 
 
9189e38
 
 
 
 
 
 
 
 
 
 
be02bbe
 
 
 
22ad617
9189e38
22ad617
ee698cf
22ad617
9189e38
1c28270
9189e38
be02bbe
 
fd50ca7
394e2d6
 
f0a2141
 
 
 
 
 
 
 
 
 
fd50ca7
9189e38
 
 
 
 
 
394e2d6
 
 
 
 
be02bbe
 
 
 
ee698cf
 
 
 
be02bbe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import csv
import json
import psycopg2
import pandas as pd
from tqdm import tqdm
from openai import OpenAI
from dotenv import load_dotenv
from psycopg2.extras import DictCursor
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from db.db_utils import get_connection

# Pull API keys (and any other settings) from a local .env file into os.environ.
load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")
# Selects the LLM backend. The same flag decides which dictionary column
# (wweia_category vs wweia_category_mistral) is queried and updated below.
use_openai = True

if use_openai:
    client = OpenAI(api_key=api_key)
else:
    mistral_api_key = os.getenv("MISTRAL_API_KEY")
    client = MistralClient(api_key=mistral_api_key)

# DictCursor rows allow column access by name, e.g. row['fdc_id'].
db_conn = get_connection()
db_cursor = db_conn.cursor(cursor_factory=DictCursor)

# Load the WWEIA food-category reference table (a CSV export, not an Excel file).
file_path = './dictionary/final_corrected_wweia_food_category_complete - final_corrected_wweia_food_category_complete.csv'
spreadsheet = pd.read_csv(file_path)

def find_best_category(food_item, category, dataframe):
    """Ask the configured LLM to map a food item onto one WWEIA food category.

    Args:
        food_item: Description of the food being classified.
        category: The item's existing food category (the caller passes the
            row's ``sr_legacy_food_category``); it is only given to the model
            as a hint.
        dataframe: Reference table with a ``wweia_food_category_description``
            column. Every value in that column is offered as a candidate.

    Returns:
        The category string chosen by the model, or ``None`` when the reply is
        empty or cannot be parsed. The model may still answer with text that
        is not one of the candidates, so callers must validate the result.
    """
    # All candidates are offered; no pre-filtering by ``category`` is done.
    descriptions = dataframe['wweia_food_category_description'].tolist()

    prompt = (
        f"Given the food item '{food_item}' and the classification of '{category}', choose the most appropriate category from the following options:\n{descriptions}\n\n"
        f"Only respond with a category from the above. Do not come up with a new category. Do not respond with 'Legumes and Legume Products'.\n\n"
        f"You should respond in json format with an object that has the key `guess`, and the value is the category."
    )

    if use_openai:
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            model="gpt-3.5-turbo-1106",
            # JSON mode: the reply is constrained to a JSON object.
            response_format={"type": "json_object"},
        )
    else:
        completion = client.chat(
            # "mistral-large-latest" is a stronger (more expensive) alternative.
            model="mistral-small-latest",
            response_format={"type": "json_object"},
            messages=[
                ChatMessage(role="user", content=prompt)
            ],
        )

    response = completion.choices[0].message.content
    # Message content can be empty/None (e.g. a refusal); report it instead of
    # handing None to the JSON parser.
    if not response:
        print(f"Empty response for '{food_item}'")
        return None
    return parse_response(response)

# Define the function to parse the GPT response
def parse_response(response):
    """Extract the ``guess`` value from the model's JSON reply.

    Args:
        response: Raw message content from the chat API, expected to be a JSON
            object such as ``{"guess": "<category>"}``. May be ``None``.

    Returns:
        The ``guess`` string. Returns ``None`` (after printing a diagnostic)
        when the reply is not valid JSON, is not a JSON object, lacks a
        ``guess`` key, or the guess is not a string.
    """
    try:
        result = json.loads(response)
        guess = result['guess']
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        # TypeError covers a None response and JSON values that are not
        # objects (lists, strings, numbers cannot be indexed by 'guess').
        print(f"Error parsing response: {response} - {e}")
        return None
    if not isinstance(guess, str):
        print(f"Error parsing response: {response} - 'guess' is not a string")
        return None
    return guess

# Each backend reads and writes its own dictionary column, so pick both
# queries once up front.
if use_openai:
    select_query = 'SELECT * FROM dictionary where wweia_category is null'
    update_query = 'UPDATE dictionary SET wweia_category = %s WHERE fdc_id = %s'
else:
    select_query = 'SELECT * FROM dictionary where wweia_category_mistral is null'
    update_query = 'UPDATE dictionary SET wweia_category_mistral = %s WHERE fdc_id = %s'

# Build the set of allowed answers once instead of scanning the column per row.
valid_categories = set(spreadsheet['wweia_food_category_description'])

try:
    # Only rows not yet classified for the active backend are fetched.
    db_cursor.execute(select_query)
    rows = db_cursor.fetchall()

    for row in tqdm(rows, desc="Processing"):
        print()
        fdc_id = row['fdc_id']
        food_item = row['description']
        category = row['sr_legacy_food_category']

        print(f"Processing '{food_item}'")

        # TODO: some SR Legacy categories ('Breakfast Cereals', 'Fast Foods',
        # 'American Indian/Alaska Native Foods', 'Restaurant Foods',
        # 'Spices and Herbs') may need remapping before classification.

        best_category = find_best_category(food_item, category, spreadsheet)
        print(f"Q: '{food_item}'")
        print(f"A: '{best_category}'")

        if not best_category:
            print(f"Failed to find a category for '{food_item}'")
            continue

        # Reject answers the model invented instead of picking from the list.
        # The isinstance guard also keeps unhashable values out of the set lookup.
        if not isinstance(best_category, str) or best_category not in valid_categories:
            print(f"Error: '{best_category}' not found in the spreadsheet")
            continue

        db_cursor.execute(update_query, (best_category, fdc_id))
        # Commit per row so completed work survives an interruption; a rerun
        # only picks up rows whose column is still null.
        db_conn.commit()
finally:
    # Release DB resources even if an API or DB call raises mid-batch.
    db_cursor.close()
    db_conn.close()