SoccerRAG / src /extractor.py
buzzCraft
Created setup.py and updated readme
be5af2d
raw
history blame
No virus
27.6 kB
from typing import Optional
from langchain.chains import create_extraction_chain_pydantic
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_extraction_chain
from copy import deepcopy
from langchain_openai import ChatOpenAI
from langchain_community.utilities import SQLDatabase
import os
import difflib
import ast
import json
import re
from thefuzz import process
# Set up logging
import logging
from dotenv import load_dotenv
load_dotenv(".env")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Save this module's log records to a file as well. The handler has to be
# attached to the logger, otherwise nothing is ever written to extractor.log.
handler = logging.FileHandler('extractor.log')
logger.addHandler(handler)

# Only copy keys that are actually set: assigning None into os.environ raises
# TypeError, which would abort the import with an unhelpful message.
_openai_api_key = os.getenv('OPENAI_API_KEY')
if _openai_api_key is not None:
    os.environ["OPENAI_API_KEY"] = _openai_api_key
# os.environ["ANTHROPIC_API_KEY"] = os.getenv('ANTHROPIC_API_KEY')
if os.getenv('LANGSMITH'):
    # Enable LangSmith tracing through the LANGCHAIN_* variables.
    os.environ['LANGCHAIN_TRACING_V2'] = 'true'
    os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
    _langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
    if _langsmith_api_key is not None:
        os.environ['LANGCHAIN_API_KEY'] = _langsmith_api_key
    _langsmith_project = os.getenv('LANGSMITH_PROJECT')
    if _langsmith_project is not None:
        os.environ['LANGCHAIN_PROJECT'] = _langsmith_project

# SQLite database file path; turned into an SQLAlchemy URI.
# NOTE(review): if DATABASE_PATH is unset this becomes "sqlite:///None".
db_uri = os.getenv('DATABASE_PATH')
db_uri = f"sqlite:///{db_uri}"
db = SQLDatabase.from_uri(db_uri)
# Number of candidate matches requested per value from find_close_matches;
# falls back to 3 (find_close_matches' own default) when FEW_SHOT is unset.
few_shot_n = int(os.getenv('FEW_SHOT', '3'))
# from langchain_anthropic import ChatAnthropic
class Extractor():
    """Wrap a LangChain extraction chain backed by an OpenAI chat model."""
    # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
    # gpt-3.5-turbo

    def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
        """
        Build the extraction chain.

        Args:
            model (str): OpenAI chat model name passed to ``ChatOpenAI``.
            schema_config (dict, optional): Schema of the properties to extract.
                Defaults to an empty schema.
            custom_extractor_prompt (str, optional): Template text for the extraction
                prompt. When omitted or empty, the chain's built-in default prompt is used.
        """
        # model = "gpt-4-0125-preview"
        # Build a template only when one was supplied. Previously the template
        # variable was bound only inside the `if`, so omitting the custom prompt
        # raised NameError. Passing prompt=None makes create_extraction_chain
        # fall back to its default prompt.
        prompt = (
            ChatPromptTemplate.from_template(custom_extractor_prompt)
            if custom_extractor_prompt
            else None
        )
        self.llm = ChatOpenAI(model=model, temperature=0)
        # self.llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
        self.schema = schema_config or {}
        self.chain = create_extraction_chain(self.schema, self.llm, prompt=prompt)

    def extract(self, query):
        """Run the extraction chain on ``query`` and return the chain's raw output."""
        return self.chain.invoke(query)
class Retriever():
    """
    Look up candidate database values for one extracted property.

    The ``config`` dict is read for these keys:
        db_table / db_column: table and column holding the canonical values.
        pk_column: column returned by :meth:`fetch_pk` (``None`` disables PK lookup).
        numeric: when false, standalone digit runs are stripped from the values.
        augmented_table / augmented_column / augmented_fk: optional alias table.
            ``augmented_fk`` is compared against ``pk_column`` of ``db_table``.

    ``db`` must provide ``run(query)``. That method returns the rows as the string
    repr of a list of tuples, or a falsy value when there are no rows.
    """

    def __init__(self, db, config):
        self.db = db
        self.config = config
        self.table = config.get('db_table')
        self.column = config.get('db_column')
        self.pk_column = config.get('pk_column')
        self.numeric = config.get('numeric', False)
        self.response = []
        self.query = f"SELECT {self.column} FROM {self.table}"
        self.augmented_table = config.get('augmented_table', None)
        self.augmented_column = config.get('augmented_column', None)
        self.augmented_fk = config.get('augmented_fk', None)

    @staticmethod
    def _sql_quote(value):
        """Render ``value`` as a single-quoted SQL string literal."""
        # Double embedded single quotes so values such as "O'Higgins" cannot
        # break (or inject into) the hand-built queries below.
        return "'" + str(value).replace("'", "''") + "'"

    def query_as_list(self):
        """
        Fetch the distinct, non-empty values of ``db_column`` and cache them.

        Returns:
            list: The values, also stored in ``self.response``. Order is unspecified.
        """
        raw = self.db.run(self.query)
        # An empty result cannot be parsed by literal_eval, so treat it as no rows.
        rows = ast.literal_eval(raw) if raw else []
        values = [el for row in rows for el in row if el]
        if not self.numeric:
            # Strip standalone numbers (e.g. years) so only the text part is matched,
            # and drop entries that were purely numeric.
            values = [re.sub(r"\b\d+\b", "", value).strip() for value in values]
            values = [value for value in values if value]
        self.response = list(set(values))
        return self.response

    def get_augmented_items(self, prompt):
        """
        Resolve ``prompt`` through the alias table, matching case-insensitively.

        Returns:
            The matching ``db_column`` value from ``db_table``. Returns ``None`` when
            there is no alias table, no alias match, or the alias points at a missing row.
        """
        if self.augmented_table is None:
            return None
        query = (
            f"SELECT {self.augmented_fk} FROM {self.augmented_table} "
            f"WHERE LOWER({self.augmented_column}) = LOWER({self._sql_quote(prompt)})"
        )
        fk_response = self.db.run(query)
        if not fk_response:
            return None
        fk_value = ast.literal_eval(fk_response)[0][0]
        # String keys must be quoted; numeric keys are interpolated as-is.
        fk_literal = self._sql_quote(fk_value) if isinstance(fk_value, str) else fk_value
        query = f"SELECT {self.column} FROM {self.table} WHERE {self.pk_column} = {fk_literal}"
        matching_response = self.db.run(query)
        if not matching_response:
            # The alias references a key that does not exist in the main table.
            return None
        return ast.literal_eval(matching_response)[0][0]

    def find_close_matches(self, target_string, n=3, method="difflib", threshold=70):
        """
        Return the closest candidates to ``target_string`` among the column values.

        Args:
            target_string (str): The string to match against the database values.
            n (int): Maximum number of candidates to consider or return.
            method (str): ``"fuzzy"`` uses :meth:`fuzzy_string`. Any other value uses
                ``difflib.get_close_matches`` with a 0.2 cutoff.
            threshold (int): Fuzzy score (0-100) a match needs to count as strong.

        Returns:
            With difflib, a list of matching strings, best first.
            With ``"fuzzy"``, either a single string or a list of strings
            (see :meth:`fuzzy_string`).
        """
        # Lazily load the candidate values on first use.
        if not self.response:
            self.query_as_list()
        if method == "fuzzy":
            return self.fuzzy_string(target_string, limit=n, threshold=threshold)
        return difflib.get_close_matches(target_string, self.response, n=n, cutoff=0.2)

    def fuzzy_string(self, prompt, limit, threshold=80, low_threshold=30):
        """
        Fuzzy-match ``prompt`` against the cached values.

        Returns:
            - A single string when exactly one match reaches ``threshold`` at the top score.
            - A list of the tied top-scoring strings when several reach ``threshold``.
            - Otherwise, a list of all matches scoring at least ``low_threshold``
              (possibly empty).
        """
        matches = process.extract(prompt, self.response, limit=limit)
        strong = [match for match in matches if match[1] >= threshold]
        if not strong:
            # Nothing is confident enough; offer the weaker matches instead.
            return [match[0] for match in matches if match[1] >= low_threshold]
        # process.extract sorts by score, so the first strong match has the best score.
        best_score = strong[0][1]
        tied = [match[0] for match in strong if match[1] == best_score]
        return tied[0] if len(tied) == 1 else tied

    def fetch_pk(self, property_name, property_value):
        """
        Look up the primary key for each value.

        Args:
            property_name (str): Unused; kept for interface compatibility.
            property_value: A single value or a list of values.

        Returns:
            list: One entry per value. Each entry is the raw ``db.run`` result string
            (empty or falsy when no row matched), or ``None`` for every value when no
            ``pk_column`` is configured.
        """
        values = property_value if isinstance(property_value, list) else [property_value]
        if self.pk_column is None:
            return [None for _ in values]
        return [
            self.db.run(
                f"SELECT {self.pk_column} FROM {self.table} "
                f"WHERE {self.column} = {self._sql_quote(value)} LIMIT 1"
            )
            for value in values
        ]
def setup_retrievers(db, schema_config):
    """
    Build one Retriever per property declared in ``schema_config["properties"]``.

    Each property describes an array. Its ``items`` sub-dict holds the database
    mapping (table, column, pk column, ...) that is handed to the Retriever.

    Returns:
        dict: Maps property name to its Retriever.
    """
    declared = schema_config["properties"]
    return {
        prop_name: Retriever(db=db, config=prop_spec['items'])
        for prop_name, prop_spec in declared.items()
    }
def extract_properties(prompt, schema_config, custom_extractor_prompt=None):
    """
    Run the LLM extractor over ``prompt`` and return the extracted property dicts.

    Only each property's ``type`` and ``items.type`` are forwarded to the extractor.
    The database-mapping keys are left out so the LLM sees a minimal schema.

    Returns:
        The non-empty ``text`` entry of the extraction result. If nothing was
        extracted, prints a notice and returns ``None``.
    """
    minimal_properties = {
        name: {'type': spec['type'], 'items': {'type': spec['items']['type']}}
        for name, spec in schema_config['properties'].items()
    }
    extractor = Extractor(
        schema_config={'properties': minimal_properties},
        custom_extractor_prompt=custom_extractor_prompt,
    )
    result = extractor.extract(prompt)
    has_text = 'text' in result and result['text']
    if not has_text:
        print("No properties extracted.")
        return None
    return result['text']
def recheck_property_value(properties, property_name, value, retrievers):
    """
    Repeatedly ask the user for a replacement value until they choose a match or quit.

    Args:
        properties (dict): Property dict. When a match is chosen,
            ``properties[property_name]`` is set to it (kept for backward compatibility).
        property_name (str): Name of the property being corrected.
        value: The current value, shown in the prompt.
        retrievers (Retriever): Despite the plural name, a single Retriever. Its
            ``find_close_matches`` supplies the candidates.

    Returns:
        The selected match, or ``None`` if the user quit without choosing one.
        Returning the match lets callers keep it in their own value lists.
    """
    while True:
        print(property_name)
        new_value = input(f"Enter new value for {property_name} - {value} or type 'quit' to stop: ")
        if new_value.lower() == 'quit':
            return None
        candidates = retrievers.find_close_matches(new_value, n=few_shot_n)
        if not candidates:
            print("No close matches found. Please try again or type 'quit' to stop.")
            continue
        print("\nNew close matches found:")
        for position, match in enumerate(candidates, start=1):
            print(f"[{position}] {match}")
        count = len(candidates)
        print(f"[{count + 1}] Re-enter value")
        print(f"[{count + 2}] Quit without updating")
        selection = input(f"Select the best match (1-{count}), choose {count + 1} to re-enter value, or {count + 2} to quit: ")
        if selection in {str(k) for k in range(1, count + 1)}:
            selected_match = candidates[int(selection) - 1]
            properties[property_name] = selected_match
            print(f"Updated {property_name} to {selected_match}")
            return selected_match
        if selection == str(count + 2):
            return None
        # "Re-enter" or any unrecognised input: loop and ask again.
def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func="input"):
    """
    Validate extracted property values against the database and update them in place.

    For every dict in ``properties_list`` and every property that has a retriever,
    each value is resolved in this order:
      1. Through the retriever's alias table, if one is configured.
      2. Through ``find_close_matches``: a single confident fuzzy match is taken
         directly, and an exact top match keeps the value.
      3. Otherwise the user chooses, depending on ``input_func``.

    Args:
        properties_list (list of dict): Dicts mapping property name to a list of values.
            A non-list value is treated as a one-element list.
        retrievers (dict): Retriever objects keyed by property name.
        method (str): Matching method passed to ``find_close_matches``.
        input_func (str): ``"input"`` asks on the console. ``"chainlit"`` collects
            ``{property_name: value, "top_matches": [...]}`` entries for a UI and
            leaves those values out of the updated list. Any other value keeps the
            original value.

    Returns:
        The last dict processed (``{}`` for an empty list), kept for backward
        compatibility. In ``"chainlit"`` mode, a ``(last_dict, options_list)`` tuple.
        ``properties_list`` itself is updated in place.
    """
    return_list = []
    last_properties = {}
    for properties in properties_list or []:
        last_properties = properties
        for property_name, retriever in retrievers.items():
            property_values = properties.get(property_name, [])
            if not property_values:  # property absent or empty
                continue
            if not isinstance(property_values, list):
                # Avoid iterating a bare string character by character.
                property_values = [property_values]
            updated_property_values = []
            for value in property_values:
                if retriever.augmented_table:
                    augmented_value = retriever.get_augmented_items(value)
                    if augmented_value:
                        updated_property_values.append(augmented_value)
                        continue
                top_matches = retriever.find_close_matches(value, method=method, n=few_shot_n)
                if not top_matches:
                    updated_property_values.append(value)  # nothing better found
                    continue
                if isinstance(top_matches, str):
                    # fuzzy_string returns a bare string when exactly one match met the
                    # threshold. Check this before indexing, since top_matches[0] would
                    # otherwise be its first character.
                    updated_property_values.append(top_matches)
                    continue
                if top_matches[0] == value:
                    updated_property_values.append(value)
                    continue
                if input_func == "input":
                    print(f"\nCurrent {property_name}: {value}")
                    for position, match in enumerate(top_matches, start=1):
                        print(f"[{position}] {match}")
                    count = len(top_matches)
                    print(f"[{count + 1}] Enter new value")
                    choice = input(f"Select the best match for {property_name} (1-{count + 1}): ")
                    if choice in {str(k) for k in range(1, count + 1)}:
                        selected_match = top_matches[int(choice) - 1]
                        updated_property_values.append(selected_match)
                        print(f"Updated {property_name} to {selected_match}")
                    elif choice == str(count + 1):
                        rechecked = recheck_property_value(properties, property_name, value, retriever)
                        # Keep the re-entered match. If the user quit (or the helper
                        # returns nothing), keep the original value instead of dropping it.
                        updated_property_values.append(rechecked if rechecked is not None else value)
                    else:
                        print("Invalid selection. Property not updated.")
                        updated_property_values.append(value)
                elif input_func == "chainlit":
                    # The UI makes the choice; the value is left out of the list here.
                    options = {property_name: value, "top_matches": top_matches}
                    return_list.append(options)
                else:
                    updated_property_values.append(value)
            properties[property_name] = updated_property_values
    if input_func == "chainlit":
        return last_properties, return_list
    return last_properties
# Function to remove duplicates
def remove_duplicates(dicts):
    """
    Delete key/value pairs whose value already appeared under the same key earlier.

    Each value is compared against every value previously kept for that key, not
    just the most recent one. The dicts are modified in place and the same list
    is returned. ``None`` or an empty list is returned unchanged.
    """
    if not dicts:
        return dicts
    seen = {}  # key -> values already kept for that key
    for d in dicts:
        for key in list(d):  # snapshot the keys because entries are deleted while looping
            value = d[key]
            kept = seen.setdefault(key, [])
            # Values may be lists (unhashable), so membership is checked against a list.
            if value in kept:
                del d[key]
            else:
                kept.append(value)
    return dicts
def fetch_pks(properties_list, retrievers):
    """
    Look up primary keys for every property that has a retriever.

    Args:
        properties_list (list of dict): Property dicts to resolve.
        retrievers (dict): Objects with a ``fetch_pk(name, value)`` method, keyed by
            property name.

    Returns:
        list of dict: One dict per entry of ``properties_list``. Each maps
        ``"<property>_pk"`` to whatever ``fetch_pk`` returned for that property.
    """
    return [
        {
            f"{name}_pk": retrievers[name].fetch_pk(name, value)
            for name, value in entry.items()
            if name in retrievers
        }
        for entry in properties_list
    ]
# def update_prompt(prompt, properties, pk, properties_original):
# # Replace the original prompt with the updated properties and pk
# prompt = prompt.replace("{{properties}}", str(properties))
# prompt = prompt.replace("{{pk}}", str(pk))
# return prompt
def update_prompt(prompt, properties, pk, properties_original, retrievers):
    """
    Append an "Updated Information" section describing validated values.

    For each property that has a retriever, one line is written per value. The
    line mentions the original value, the corrected value if it changed, and
    the primary key if one was found.

    Args:
        prompt (str): The user prompt to extend.
        properties (list of dict): Validated properties.
        pk (list of dict): Primary-key dicts in the shape produced by ``fetch_pks``.
        properties_original (list of dict): The properties before validation.
        retrievers (dict): Only its keys are used, to select which properties to report.

    Returns:
        str: ``prompt``, with the extra section appended when there is anything to
        report. ``None`` inputs are treated as empty.
    """
    lines = []
    for prop, pk_info, prop_orig in zip(properties or [], pk or [], properties_original or []):
        for key in prop:
            if key not in retrievers:
                continue
            orig_values = prop_orig.get(key, [])
            updated_values = prop.get(key, [])
            # Normalise scalars to lists so both sides can be zipped together.
            if not isinstance(orig_values, list):
                orig_values = [orig_values]
            if not isinstance(updated_values, list):
                updated_values = [updated_values]
            pk_details = pk_info.get(f"{key}_pk", [])
            if not isinstance(pk_details, list):
                pk_details = [pk_details]
            for orig_value, updated_value, pk_detail in zip(orig_values, updated_values, pk_details):
                pk_value = None
                if isinstance(pk_detail, str):
                    # pk_detail is typically a raw db.run string such as "[('abc',)]";
                    # keep only the first field, without brackets or quotes.
                    pk_value = pk_detail.strip("[]()").split(",")[0].replace("'", "").replace('"', '')
                if orig_value != updated_value:
                    if pk_value:
                        lines.append(f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}.")
                    else:
                        lines.append(f"\n- {orig_value} (now referred to as {updated_value}.)")
                elif pk_value:
                    lines.append(f"\n- {orig_value} has a primary key: {pk_value}.")
                else:
                    lines.append(f"\n- {orig_value}.")
    if lines:
        prompt += "\nUpdated Information:" + "".join(lines)
    return prompt
def prompt_cleaner(prompt, db, schema_config):
    """
    Clean a prompt in one call: extract, deduplicate, validate and annotate properties.

    Args:
        prompt (str): The user prompt.
        db: Database object handed to the retrievers.
        schema_config (dict): Schema with a ``properties`` mapping.

    Returns:
        tuple: ``(result, pk)``. When properties were extracted, ``result`` is the
        prompt with an "Updated Information" section appended and ``pk`` is the
        list of primary-key dicts. Otherwise ``result`` is the empty extraction
        result and ``pk`` is ``None``.
    """
    retrievers = setup_retrievers(db, schema_config)
    properties = extract_properties(prompt, schema_config)
    # Keep the original properties to report what changed.
    properties_original = deepcopy(properties)
    pk = None
    if properties:
        # Remove duplicates, which happen when several players or teams are mentioned.
        properties = remove_duplicates(properties)
        check_and_update_properties(properties, retrievers)
        pk = fetch_pks(properties, retrievers)
        properties = update_prompt(prompt, properties, pk, properties_original, retrievers)
    return properties, pk
class PromptCleaner:
    """
    Clean and process prompts.

    Properties are extracted from the prompt, validated against the database, and
    reported back as an "Updated Information" section appended to the prompt.

    Attributes:
        db: A database object used to execute queries and fetch data.
        schema_config: A dictionary defining the schema configuration for extraction.
            Each property is an array whose ``items`` dict carries the database mapping:
            schema_config = {
                "properties": {
                    # Property name
                    "person_name": {"type": "array",
                                    "items": {"type": "string", "db_table": "players",
                                              "db_column": "name", "pk_column": "hash",
                                              # if mostly numeric, such as 2015-2016 set true
                                              "numeric": False}},
                    "team_name": {"type": "array",
                                  "items": {"type": "string", "db_table": "teams",
                                            "db_column": "name", "pk_column": "id",
                                            "numeric": False}},
                    # Add more as needed
                },
                # Parameter to extractor: properties listed here are required, and the
                # extractor will return an error if one is not found
                "required": [],
            }

    Methods:
        clean(prompt): Extracts and validates properties and returns the updated prompt,
            optionally together with primary keys.
        extract_chainlit / validate_chainlit / build_prompt_chainlit: The same pipeline
            split into steps for an interactive UI.
    """

    def __init__(self, db=db, schema_config=None, custom_extractor_prompt=None):
        """
        Initialize the cleaner with a database and a schema configuration.

        Args:
            db: The database object used for querying. Defaults to the module-level db.
            schema_config: A dictionary defining properties and their database mappings.
            custom_extractor_prompt (str, optional): Prompt template for the extractor.
        """
        self.db = db
        self.schema_config = schema_config
        self.retrievers = setup_retrievers(self.db, self.schema_config)
        self.cust_extractor_prompt = custom_extractor_prompt
        self.properties_original = None

    def _extract_properties(self, prompt):
        """Extract properties, using the custom extractor prompt when one is configured."""
        # An empty or None custom prompt selects the extractor's default prompt.
        return extract_properties(prompt, self.schema_config, self.cust_extractor_prompt or None)

    def clean(self, prompt, return_pk=False, test=False, verbose=False):
        """
        Extract properties from ``prompt``, validate them against the database, and
        append what was found to the prompt.

        Args:
            prompt (str): The prompt text to be cleaned and processed.
            return_pk (bool): Also return the primary-key dicts.
            test (bool): Return only the raw extracted properties (for testing).
            verbose (bool): Also return the original, pre-validation properties.

        Returns:
            By default, the prompt with an "Updated Information" section for the values
            that exist in the database. With ``return_pk``, the tuple ``(prompt, pk)``.
            With ``verbose``, the tuple ``(prompt, original_properties)``. With both,
            ``((prompt, pk), (prompt, original_properties))``.
            NOTE(review): when nothing is extracted, the empty extraction result
            (``None``/``[]``) is returned in place of the prompt.
        """
        properties = self._extract_properties(prompt)
        # Keep the original properties to report what changed.
        properties_original = deepcopy(properties)
        if test:
            return properties_original
        # Deduplication is currently disabled:
        # properties = remove_duplicates(properties)
        pk = None
        if properties:
            check_and_update_properties(properties, self.retrievers)
            pk = fetch_pks(properties, self.retrievers)
            properties = update_prompt(prompt=prompt, properties=properties, pk=pk,
                                       properties_original=properties_original,
                                       retrievers=self.retrievers)
        if return_pk and verbose:
            return (properties, pk), (properties, properties_original)
        if return_pk:
            return properties, pk
        if verbose:
            return properties, properties_original
        return properties

    def extract_chainlit(self, prompt):
        """Extract properties and remember a copy of them for build_prompt_chainlit."""
        properties = self._extract_properties(prompt)
        self.properties_original = deepcopy(properties)
        return properties

    def validate_chainlit(self, properties):
        """Validate properties in UI mode; return ``(last_dict, values_needing_user_choice)``."""
        properties, need_val = check_and_update_properties(properties, self.retrievers, input_func="chainlit")
        return properties, need_val

    def build_prompt_chainlit(self, properties, prompt):
        """Append the validated properties and their primary keys to ``prompt``."""
        if not properties:
            # Nothing to report. Continuing would pass pk=None to update_prompt.
            return prompt
        pk = fetch_pks(properties, self.retrievers)
        return update_prompt(prompt, properties, pk, self.properties_original, self.retrievers)
def load_json(file_path: str) -> dict:
    """Read and parse the JSON file at ``file_path``, decoding it as UTF-8."""
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def create_extractor(schema: str = "src/conf/schema.json", db: str = db_uri):
    """
    Build a PromptCleaner with the soccer-specific extraction prompt.

    Args:
        schema (str): Path to the JSON schema configuration.
        db (str): SQLAlchemy database URI (a string, not an ``SQLDatabase``).
            Defaults to the URI built from DATABASE_PATH.

    Returns:
        PromptCleaner: A cleaner bound to a fresh ``SQLDatabase`` for ``db``.
    """
    schema_config = load_json(schema)
    database = SQLDatabase.from_uri(db)
    pre_prompt = """Extract and save the relevant entities mentioned \
in the following passage together with their properties.
Only extract the properties mentioned in the 'information_extraction' function.
The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
If a property is not present and is not required in the function parameters, do not include it in the output.
If no properties are found, return an empty list.
Here are some exampels:
'How many goals did Henry score for Arsnl in the 2015 season?'
person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
Passage:
{input}
"""
    return PromptCleaner(database, schema_config, custom_extractor_prompt=pre_prompt)
if __name__ == "__main__":
    # Manual smoke test: create_extractor loads the schema itself, so no separate
    # load_json call is needed here.
    prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
    prompt = prompter.clean(
        "Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa and Manchester City")
    print(prompt)
    # Example of the step-wise (chainlit) flow:
    # ex = create_extractor()
    # val_list = [{'person_name': ['Cristiano Ronaldo'], 'team_name': ['Manchester City']}]
    # user_prompt = "Did ronaldo play for city?"
    # p = ex.build_prompt_chainlit(val_list, user_prompt)
    # print(p)