File size: 24,903 Bytes
68f18b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
from typing import Optional

from langchain.chains import create_extraction_chain_pydantic
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain
from copy import deepcopy
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
import os
import difflib
import ast
import json
import re
from thefuzz import process
# Set up logging
import logging

from dotenv import load_dotenv

load_dotenv(".env")

logging.basicConfig(level=logging.INFO)
# Save the log to a file
handler = logging.FileHandler('extractor.log')
logger = logging.getLogger(__name__)

os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
# os.environ["ANTHROPIC_API_KEY"] = os.getenv('ANTHROPIC_API_KEY')

if os.getenv('LANGSMITH'):
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
    os.environ[
        'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
    os.environ['LANGCHAIN_PROJECT'] = 'master-theses'
db = SQLDatabase.from_uri("sqlite:///data/games.db")

# from langchain_anthropic import ChatAnthropic
class Extractor():
    # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
    #gpt-3.5-turbo
    def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
        # model = "gpt-4-0125-preview"
        if custom_extractor_prompt:
            cust_promt = ChatPromptTemplate.from_template(custom_extractor_prompt)

        self.llm = ChatOpenAI(model=model, temperature=0)
        # self.llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
        self.schema = schema_config or {}
        self.chain = create_extraction_chain(self.schema, self.llm, prompt=cust_promt)

    def extract(self, query):
        return self.chain.invoke(query)


class Retriever():
    def __init__(self, db, config):
        self.db = db
        self.config = config
        self.table = config.get('db_table')
        self.column = config.get('db_column')
        self.pk_column = config.get('pk_column')
        self.numeric = config.get('numeric', False)
        self.response = []
        self.query = f"SELECT {self.column} FROM {self.table}"
        self.augmented_table = config.get('augmented_table', None)
        self.augmented_column = config.get('augmented_column', None)
        self.augmented_fk = config.get('augmented_fk', None)

    def query_as_list(self):
        # Execute the query
        response = self.db.run(self.query)
        response = [el for sub in ast.literal_eval(response) for el in sub if el]
        if not self.numeric:
            response = [re.sub(r"\b\d+\b", "", string).strip() for string in response]
        self.response = list(set(response))
        # print(self.response)
        return self.response

    def get_augmented_items(self, prompt):
        if self.augmented_table is None:
            return None
        else:
            # Construct the query to search for the prompt in the augmented table
            query = f"SELECT {self.augmented_fk} FROM {self.augmented_table} WHERE LOWER({self.augmented_column}) = LOWER('{prompt}')"

            # Execute the query
            fk_response = self.db.run(query)
            if fk_response:
                # Extract the FK value
                fk_response = ast.literal_eval(fk_response)
                fk_value = fk_response[0][0]
                query = f"SELECT {self.column} FROM {self.table} WHERE {self.pk_column} = {fk_value}"
                # Execute the query
                matching_response = self.db.run(query)
                # Extract the matching response
                matching_response = ast.literal_eval(matching_response)
                matching_response = matching_response[0][0]
                return matching_response
            else:
                return None

    def find_close_matches(self, target_string, n=3, method="difflib", threshold=70):
        """
        Find and return the top n close matches to target_string in the database query results.

        Args:
        - target_string (str): The string to match against the database results.
        - n (int): Number of top matches to return.

        Returns:
        - list of tuples: Each tuple contains a match and its score.
        """
        # Ensure we have the response list populated
        if not self.response:
            self.query_as_list()

        # Find top n close matches
        if method == "fuzzy":
            # Use the fuzzy_string method to get matches and their scores
            # If the threshold is met, return the best match; otherwise, return all matches meeting the threshold
            top_matches = self.fuzzy_string(target_string, limit=n, threshold=threshold)


        else:
            # Use difflib's get_close_matches to get the top n matches
            top_matches = difflib.get_close_matches(target_string, self.response, n=n, cutoff=0.2)

        return top_matches

    def fuzzy_string(self, prompt, limit, threshold=80, low_threshold=30):

        # Get matches and their scores, limited by the specified 'limit'
        matches = process.extract(prompt, self.response, limit=limit)


        filtered_matches = [match for match in matches if match[1] >= threshold]

        # If no matches meet the threshold, return the list of all matches' strings
        if not filtered_matches:
            # Return matches above the low_threshold
            # Fix for wrong properties being returned
            return [match[0] for match in matches if match[1] >= low_threshold]


        # If there's only one match meeting the threshold, return it as a string
        if len(filtered_matches) == 1:
            return filtered_matches[0][0]  # Return the matched string directly

        # If there's more than one match meeting the threshold or ties, return the list of matches' strings
        highest_score = filtered_matches[0][1]
        ties = [match for match in filtered_matches if match[1] == highest_score]

        # Return the strings of tied matches directly, ignoring the scores
        m = [match[0] for match in ties]
        if len(m) == 1:
            return m[0]
        return [match[0] for match in ties]

    def fetch_pk(self, property_name, property_value):
        # Some properties do not have a primary key
        # Return the property value if no primary key is specified
        pk_list = []

        # Check if the property_value is a list; if not, make it a list for uniform processing
        if not isinstance(property_value, list):
            property_value = [property_value]

        # Some properties do not have a primary key
        # Return None for each property_value if no primary key is specified
        if self.pk_column is None:
            return [None for _ in property_value]

        for value in property_value:
            query = f"SELECT {self.pk_column} FROM {self.table} WHERE {self.column} = '{value}' LIMIT 1"
            response = self.db.run(query)

            # Append the response (PK or None) to the pk_list
            pk_list.append(response)

        return pk_list


def setup_retrievers(db, schema_config):
    # retrievers = {}
    # for prop, config in schema_config["properties"].items():
    #     retrievers[prop] = Retriever(db=db, config=config)
    # return retrievers

    retrievers = {}
    # Iterate over each property in the schema_config's properties
    for prop, config in schema_config["properties"].items():
        # Access the 'items' dictionary for the configuration of the array's elements
        item_config = config['items']
        # Create a Retriever instance using the item_config
        retrievers[prop] = Retriever(db=db, config=item_config)
    return retrievers


def extract_properties(prompt, schema_config, custom_extractor_prompt=None):
    """Extract properties from the prompt."""
    # modify schema_conf to only include the required properties
    schema_stripped = {'properties': {}}
    for key, value in schema_config['properties'].items():
        schema_stripped['properties'][key] = {
            'type': value['type'],
            'items': {'type': value['items']['type']}
        }

    extractor = Extractor(schema_config=schema_stripped, custom_extractor_prompt=custom_extractor_prompt)
    extraction_result = extractor.extract(prompt)
    # print("Extraction Result:", extraction_result)

    if 'text' in extraction_result and extraction_result['text']:
        properties = extraction_result['text']
        return properties
    else:
        print("No properties extracted.")
        return None


def recheck_property_value(properties, property_name, retrievers, input_func):
    while True:
        new_value = input_func(f"Enter new value for {property_name} or type 'quit' to stop: ")
        if new_value.lower() == 'quit':
            break  # Exit the loop and do not update the property

        new_top_matches = retrievers[property_name].find_close_matches(new_value, n=3)
        if new_top_matches:
            # Display new top matches and ask for confirmation or re-entry
            print("\nNew close matches found:")
            for i, match in enumerate(new_top_matches, start=1):
                print(f"[{i}] {match}")
            print("[4] Re-enter value")
            print("[5] Quit without updating")

            selection = input_func("Select the best match (1-3), choose 4 to re-enter value, or 5 to quit: ")
            if selection in ['1', '2', '3']:
                selected_match = new_top_matches[int(selection) - 1]
                properties[property_name] = selected_match  # Update the dictionary directly
                print(f"Updated {property_name} to {selected_match}")
                break  # Successfully updated, exit the loop
            elif selection == '5':
                break  # Quit without updating
            # Loop will continue if user selects 4 or inputs invalid selection
        else:
            print("No close matches found. Please try again or type 'quit' to stop.")


def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func=input):
    """
    Checks and updates the properties in the properties list based on close matches found in the database.
    The function iterates through each property in each property dictionary within the list,
    finds close matches for it in the database using the retrievers, and updates the property
    value based on user selection.

    Args:
        properties_list (list of dict): A list of dictionaries, where each dictionary contains properties
            to check and potentially update based on database matches.
        retrievers (dict): A dictionary of Retriever objects keyed by property name, used to find close matches in the database.
        input_func (function, optional): A function to capture user input. Defaults to the built-in input function.

    The function updates the properties_list in place based on user choices for updating property values
    with close matches found by the retrievers.
    """

    for index, properties in enumerate(properties_list):
        for property_name, retriever in retrievers.items():  # Iterate using items to get both key and value
            property_values = properties.get(property_name, [])
            if not property_values:  # Skip if the property is not present or is an empty list
                continue

            updated_property_values = []  # To store updated list of values

            for value in property_values:
                if retriever.augmented_table:
                    augmented_value = retriever.get_augmented_items(value)
                    if augmented_value:
                        updated_property_values.append(augmented_value)
                        continue
                # Since property_value is now expected to be a list, we handle each value individually
                top_matches = retriever.find_close_matches(value, method=method, n=3)

                # Check if the closest match is the same as the current value
                if top_matches and top_matches[0] == value:
                    updated_property_values.append(value)
                    continue

                if not top_matches:
                    updated_property_values.append(value)  # Keep the original value if no matches found
                    continue

                if type(top_matches) == str and method == "fuzzy":
                    # If the top_matches is a string, it means that the threshold was met and only one item was returned
                    # In this case, we can directly update the property with the top match
                    updated_property_values.append(top_matches)
                    properties[property_name] = updated_property_values
                    continue

                print(f"\nCurrent {property_name}: {value}")
                for i, match in enumerate(top_matches, start=1):
                    print(f"[{i}] {match}")
                print("[4] Enter new value")

                # hmm = input_func(f"Fix for Pycharm, press enter to continue")

                choice = input_func(f"Select the best match for {property_name} (1-4): ")
                if choice in ['1', '2', '3']:
                    selected_match = top_matches[int(choice) - 1]
                    updated_property_values.append(selected_match)  # Update with the selected match
                    print(f"Updated {property_name} to {selected_match}")
                elif choice == '4':
                    # Allow re-entry of value for this specific item
                    recheck_property_value(properties, property_name, value, retrievers, input_func)
                    # Note: Implement recheck_property_value to handle individual value updates within the list
                else:
                    print("Invalid selection. Property not updated.")
                    updated_property_values.append(value)  # Keep the original value

            # Update the entire list for the property after processing all values
            properties[property_name] = updated_property_values


# Function to remove duplicates
def remove_duplicates(dicts):
    seen = {}  # Dictionary to keep track of seen values for each key
    for d in dicts:
        for key in list(d.keys()):  # Use list to avoid RuntimeError for changing dict size during iteration
            value = d[key]
            if key in seen and value == seen[key]:
                del d[key]  # Remove key-value pair if duplicate is found
            else:
                seen[key] = value  # Update seen values for this key
    return dicts


def fetch_pks(properties_list, retrievers):
    all_pk_attributes = []  # Initialize a list to store dictionaries of _pk attributes for each item in properties_list

    # Iterate through each properties dictionary in the list
    for properties in properties_list:
        pk_attributes = {}  # Initialize a dictionary for the current set of properties
        for property_name, property_value in properties.items():
            if property_name in retrievers:
                # Fetch the primary key using the retriever for the current property
                pk = retrievers[property_name].fetch_pk(property_name, property_value)
                # Store it in the dictionary with a modified key name
                pk_attributes[f"{property_name}_pk"] = pk

        # Add the dictionary of _pk attributes for the current set of properties to the list
        all_pk_attributes.append(pk_attributes)

    # Return a list of dictionaries, where each dictionary contains _pk attributes for a set of properties
    return all_pk_attributes


def update_prompt(prompt, properties, pk, properties_original):
    # Replace the original prompt with the updated properties and pk
    prompt = prompt.replace("{{properties}}", str(properties))
    prompt = prompt.replace("{{pk}}", str(pk))
    return prompt


def update_prompt_enhanced(prompt, properties, pk, properties_original):
    updated_info = ""
    for prop, pk_info, prop_orig in zip(properties, pk, properties_original):
        for key in prop.keys():
            # Extract original and updated values
            orig_values = prop_orig.get(key, [])
            updated_values = prop.get(key, [])

            # Ensure both original and updated values are lists for uniform processing
            if not isinstance(orig_values, list):
                orig_values = [orig_values]
            if not isinstance(updated_values, list):
                updated_values = [updated_values]

            # Extract primary key detail for this key, handling various pk formats carefully
            pk_key = f"{key}_pk"  # Construct pk key name based on the property key
            pk_details = pk_info.get(pk_key, [])
            if not isinstance(pk_details, list):
                pk_details = [pk_details]

            for orig_value, updated_value, pk_detail in zip(orig_values, updated_values, pk_details):
                pk_value = None
                if isinstance(pk_detail, str):
                    pk_value = pk_detail.strip("[]()").split(",")[0].replace("'", "").replace('"', '')

                update_statement = ""
                # Skip updating if there's no change in value to avoid redundant info
                if orig_value != updated_value and pk_value:
                    update_statement = f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
                elif orig_value != updated_value:
                    update_statement = f"\n- {orig_value} (now referred to as {updated_value})."
                elif pk_value:
                    update_statement = f"\n- {orig_value} has a primary key: {pk_value}."

                updated_info += update_statement

    if updated_info:
        prompt += "\nUpdated Information:" + updated_info

    return prompt


def prompt_cleaner(prompt, db, schema_config):
    """Main function to clean the prompt."""

    retrievers = setup_retrievers(db, schema_config)

    properties = extract_properties(prompt, schema_config)
    # Keep original properties for later use
    properties_original = deepcopy(properties)
    # Remove duplicates - Happens when there are more than one player or team in the prompt
    properties = remove_duplicates(properties)
    if properties:
        check_and_update_properties(properties, retrievers)

        pk = fetch_pks(properties, retrievers)
    properties = update_prompt_enhanced(prompt, properties, pk, properties_original)

    return properties, pk


class PromptCleaner:
    """
    A class designed to clean and process prompts by extracting properties, removing duplicates,
    and updating these properties based on a predefined schema configuration and database interactions.

    Attributes:
        db: A database connection object used to execute queries and fetch data.
        schema_config: A dictionary defining the schema configuration for the extraction process.
        schema_config = {
            "properties": {
                # Property name
                "person_name": {"type": "string", "db_table": "players", "db_column": "name", "pk_column": "hash",
                                    # if mostly numeric, such as 2015-2016 set true
                                "numeric": False},
                "team_name": {"type": "string", "db_table": "teams", "db_column": "name", "pk_column": "id",
                              "numeric": False},
                              # Add more as needed
            },
            # Parameter to extractor, if person_name is required, add it here and the extractor will
            # return an error if it is not found
            "required": [],
        }

    Methods:
        clean(prompt): Cleans the given prompt by extracting and updating properties based on the database.
            Returns a tuple containing the updated properties and their primary keys.
    """

    def __init__(self, db=db, schema_config=None, custom_extractor_prompt=None):
        """
        Initializes the PromptCleaner with a database connection and a schema configuration.

        Args:
            db: The database connection object to be used for querying. (if none, it will use the default db)
            schema_config: A dictionary defining properties and their database mappings for extraction and updating.
        """
        self.db = db
        self.schema_config = schema_config
        self.retrievers = setup_retrievers(self.db, self.schema_config)
        self.cust_extractor_prompt = custom_extractor_prompt

    def clean(self, prompt, return_pk=False, test=False, verbose = False):
        """
        Processes the given prompt to extract properties, remove duplicates, update the properties
        based on close matches within the database, and fetch primary keys for these properties.

        The method first extracts properties from the prompt using the schema configuration,
        then checks these properties against the database to find and update close matches.
        It also fetches primary keys for the updated properties where applicable.

        Args:
            prompt (str): The prompt text to be cleaned and processed.
            return_pk (bool): A flag to indicate whether to return primary keys along with the properties.
            test (bool): A flag to indicate whether to return the original properties for testing purposes.
            verbose (bool): A flag to indicate whether to return the original properties for debugging.

        Returns:
            tuple: A tuple containing two elements:
                - The first element is the original prompt, with updated information that excist in the db.
                - The second element is a list of dictionaries, each containing primary keys for the properties,
                  where applicable.

        """
        if self.cust_extractor_prompt:

            properties = extract_properties(prompt, self.schema_config, self.cust_extractor_prompt)

        else:
            properties = extract_properties(prompt, self.schema_config)
        # Keep original properties for later use
        properties_original = deepcopy(properties)
        if test:
            return properties_original
        # Remove duplicates - Happens when there are more than one player or team in the prompt
        # properties = remove_duplicates(properties)
        pk = None
        if properties:
            check_and_update_properties(properties, self.retrievers)
            pk = fetch_pks(properties, self.retrievers)
        properties = update_prompt_enhanced(prompt, properties, pk, properties_original)



        if return_pk:
            return properties, pk
        elif verbose:
            return properties, properties_original
        else:
            return properties


def load_json(file_path: str) -> dict:
    with open(file_path, 'r') as file:
        return json.load(file)


def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = "sqlite:///data/games.db", ):
    schema_config = load_json(schema)
    db = SQLDatabase.from_uri(db)
    pre_prompt = """Extract and save the relevant entities mentioned \
                    in the following passage together with their properties.
                    
                    Only extract the properties mentioned in the 'information_extraction' function. 
                    
                    The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
                    Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
                    
                    If a property is not present and is not required in the function parameters, do not include it in the output.
                    If no properties are found, return an empty list.
                    
                    Here are some exampels:
                    'How many goals did Henry score for Arsnl in the 2015 season?'
                    person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
                    
                    Passage:
                    {input}
    """

    return PromptCleaner(db, schema_config, custom_extractor_prompt=pre_prompt)


if __name__ == "__main__":


    schema_config = load_json("src/conf/schema.json")
    # Add game and league to the schema_config

    # prompter = PromptCleaner(db, schema_config, custom_extractor_prompt=extract_prompt)
    prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
    prompt= prompter.clean("Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa")


    print(prompt)