buzzCraft commited on
Commit
68f18b5
1 Parent(s): 9baf55e
.env_demo ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ OPENAI_API_KEY=API_KEY_HERE
2
+ LANGSMITH = False
3
+ LANGSMITH_API_KEY=API_KEY_HERE -NOT NEEDED IF LANGSMITH IS FALSE
4
+ ```
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SoccerRAG: Multimodal Soccer Information Retrieval via Natural Queries
2
+
3
+ ## Setup
4
+ ````bash
5
+ pip install -r requirements.txt
6
+ ````
7
+ Rename .env_demo to .env and fill in the required fields.
8
+
9
+ ## Required data
10
+ The data required to run the code is not included in this repository.
11
+ The data can be downloaded from the [Soccernet](https://www.soccer-net.org/data).
12
+ Files needed are:
13
+ * Labels-v2.json
14
+ * Labels-captions.json
15
+
16
+ ## Running the code
17
+ To run the code, execute the following command:
18
+ ````bash
19
+ python main.py
20
+ ````
21
+ The code will prompt you to enter a natural language query.
22
+
23
+ ## Results
24
+ ..
25
+
26
+ ## Acknowledgements
27
+ ..
28
+
29
+ ## Citation
30
+ ..
31
+
data/Dataset/augmented.csv ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ name,augmented_name
2
+ Manchester United, ManU
3
+ Manchester United, Man U
4
+ Manchester United, ManUnt
5
+ Manchester United, Manchester U
6
+ Manchester United, Manchester Unt
7
+ Manchester United, Man United
data/Dataset/augmented_leauges.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name,augmented_name
2
+ england_epl, epl
3
+ england_epl, premier league
4
+ england_epl, english premier league
5
+ england_epl, english premier
6
+ europe_uefa-champions-league, uefa champions league
7
+ europe_uefa-champions-league, champions league
8
+ europe_uefa-champions-league, cl
9
+ europe_uefa-champions-league, ucl
10
+ france_ligue-1, ligue 1
11
+ france_ligue-1, ligue1
12
+ germany_bundesliga, bundesliga
13
+ germany_bundesliga, bundes liga
14
+ germany_bundesliga, bundes
15
+ italy_serie-a, serie a
16
+ italy_serie-a, seriea
17
+ italy_serie-a, serie-a
18
+ spain_laliga, la liga
19
+ spain_laliga, laliga
20
+ spain_laliga, la-liga
21
+
main.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.extractor import create_extractor
2
+ from src.sql_chain import create_agent
3
+ ex = create_extractor()
4
+ ag = create_agent(llm_model="gpt-3.5-turbo-0125", verbose=False)
5
+ # ag = create_agent(llm_model = "gpt-4-0125-preview")
6
+
7
+ def query(prompt):
8
+ clean = ex.clean(prompt)
9
+ return ag.ask(clean)
10
+
11
+
12
+ if __name__ == "__main__":
13
+ while True:
14
+ inp = input("Enter a query: ")
15
+ if inp == "exit":
16
+ break
17
+ ans, _ = query(inp)
18
+ print(ans["output"])
19
+ exit(0)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai
2
+ langchainhub
3
+ langchain == 0.1.4
4
+ langchain_openai
5
+ langchain_experimental
6
+ sqlalchemy
7
+ python-dotenv
8
+ chromadb
9
+ python-Levenshtein
10
+ rapidfuzz
11
+ thefuzz
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (180 Bytes). View file
 
src/__pycache__/extractor.cpython-311.pyc ADDED
Binary file (26.1 kB). View file
 
src/__pycache__/sql_chain.cpython-311.pyc ADDED
Binary file (9.27 kB). View file
 
src/conf/extractor_prompt.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "extract_prompt": "Extract and save the relevant entities mentioned in the following passage together with their properties.\\n\\n Only extract the properties mentioned in the 'information_extraction' function.\\n\\n The questions are football related. game_event can be things like yellow cards, goals, assists etc.\\n\\n If a property is not present and is not required in the function parameters, do not include it in the output.\\n\\n Passage:\\n {input}\\n ",
3
+
4
+ }
src/conf/schema.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "properties": {
3
+ "person_name": {
4
+ "type": "array",
5
+ "items": {
6
+ "type": "string",
7
+ "db_table": "players",
8
+ "db_column": "name",
9
+ "pk_column": "hash",
10
+ "numeric": false
11
+ }
12
+ },
13
+ "team_name": {
14
+ "type": "array",
15
+ "items": {
16
+ "type": "string",
17
+ "db_table": "teams",
18
+ "db_column": "name",
19
+ "pk_column": "id",
20
+ "numeric": false,
21
+ "augmented_table": "augmented_teams",
22
+ "augmented_column": "augmented_name",
23
+ "augmented_fk": "team_id"
24
+ }
25
+ },
26
+ "year_season": {
27
+ "type": "array",
28
+ "items": {
29
+ "type": "string",
30
+ "db_table": "games",
31
+ "db_column": "season",
32
+ "pk_column": null,
33
+ "numeric": true
34
+ }
35
+ },
36
+ "in_game_event": {
37
+ "type": "array",
38
+ "items": {
39
+ "type": "string",
40
+ "db_table": "events",
41
+ "db_column": "label",
42
+ "pk_column": null,
43
+ "numeric": false
44
+ }
45
+ },
46
+ "league": {
47
+ "type": "array",
48
+ "items": {
49
+ "type": "string",
50
+ "db_table": "leagues",
51
+ "db_column": "name",
52
+ "pk_column": "id",
53
+ "numeric": false,
54
+ "augmented_table": "augmented_leagues",
55
+ "augmented_column": "augmented_name",
56
+ "augmented_fk": "league_id"
57
+ }
58
+ }
59
+ },
60
+ "required": []
61
+ }
src/conf/sqls.json ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "input": "List all teams",
4
+ "query": "SELECT * FROM teams;"
5
+ },
6
+ {
7
+ "input": "Find a player by name",
8
+ "query": "SELECT * FROM players WHERE name = 'name';"
9
+ },
10
+ {
11
+ "input": "Select the names of teams and calculate their total home and away goals in the 2016-2017 season. Count the total matches played and calculate the average goals per match. Order the teams by their total goals scored and limit the results to the top three.",
12
+ "query": "SELECT teams.name, SUM(CASE WHEN games.home_team_id = teams.id THEN games.goal_home ELSE 0 END) AS home_goals, SUM(CASE WHEN games.away_team_id = teams.id THEN games.goal_away ELSE 0 END) AS away_goals, COUNT(*) AS matches_played, (SUM(CASE WHEN games.home_team_id = teams.id THEN games.goal_home ELSE 0 END) + SUM(CASE WHEN games.away_team_id = teams.id THEN games.goal_away ELSE 0 END)) * 1.0 / COUNT(*) AS avg_goals_per_match FROM games INNER JOIN teams ON teams.id = games.home_team_id OR teams.id = games.away_team_id WHERE games.season = '2016-2017' GROUP BY teams.name ORDER BY (home_goals + away_goals) DESC LIMIT 3;');"
13
+ },
14
+ {
15
+ "input": "Retrieve the name and country of a player identified by a specific hash value.",
16
+ "query": "SELECT players.name, players.country FROM players WHERE players.hash = 'hash';"
17
+ },
18
+ {
19
+ "input": "Get information about what happened in a time period",
20
+ "query": "SELECT event_time_start, event_time_end, period, description FROM Commentary WHERE game_id = game_id AND period = period AND ABS(event_time_start - start_time) <= duration;"
21
+ },
22
+ {
23
+ "input": "For the a game with ID X, list the home team's name, players' names (with a captain indicator), and shirt numbers.",
24
+ "query": "SELECT t.name AS team_name, p.name || CASE WHEN l.captain THEN ' (C)' ELSE '' END AS player_name, l.shirt_number FROM games g JOIN teams t ON g.home_team_id = t.id JOIN game_lineup l ON t.id = l.team_id AND l.game_id = g.id JOIN players p ON l.player_id = p.hash WHERE g.id = X;"
25
+ },
26
+ {
27
+ "input": "Who was the home team, and away team in game X?",
28
+ "query": "SELECT home_team.name AS home_team, away_team.name AS away_team FROM games JOIN teams AS home_team ON games.home_team_id = home_team.id JOIN teams AS away_team ON games.away_team_id = away_team.id WHERE games.id = X;"
29
+ },
30
+ {
31
+ "input": "For game X, list all Shots on targets and goals, for each team (using their name not id) for each period",
32
+ "query": "SELECT t.name AS team_name, a.period, SUM(CASE WHEN a.label = 'Shots on target' THEN 1 ELSE 0 END) AS shots_on_target, SUM(CASE WHEN a.label = 'Goal' THEN 1 ELSE 0 END) AS goals FROM annotations a JOIN teams t ON a.team_id = t.id WHERE a.game_id = X AND (a.label = 'Shots on target' OR a.label = 'Goal') GROUP BY t.name, a.period ORDER BY t.name, a.period;"
33
+ },
34
+ {
35
+ "input": "How many offsides were caused by the away team in game X, also get the time of the event",
36
+ "query": "SELECT a.game_id, a.label, a.position, a.period FROM annotations a JOIN games g ON a.game_id = g.id WHERE a.game_id = X AND a.label = 'Offside' AND a.team_id = g.away_team_id;"
37
+ },
38
+ {
39
+ "input": "all goals scored by <team> in <season>",
40
+ "query": "SELECT t.name AS TeamName, g.season, SUM(CASE WHEN g.home_team_id = t.id THEN g.goal_home ELSE 0 END + CASE WHEN g.away_team_id = t.id THEN g.goal_away ELSE 0 END) AS TotalGoals FROM games g JOIN teams t ON g.home_team_id = t.id OR g.away_team_id = t.id WHERE t.name = '<team>' AND g.season = '<season>' GROUP BY t.name, g.season;"
41
+ },
42
+ {
43
+ "input": "All games played by <team> in <season> in <league>",
44
+ "query": "SELECT g.id, g.date, g.season, l.name AS LeagueName, ht.name AS HomeTeam, at.name AS AwayTeam, g.score FROM games g JOIN teams ht ON g.home_team_id = ht.id JOIN teams at ON g.away_team_id = at.id JOIN leagues l ON g.league_id = l.id WHERE (ht.name = '<team>' OR at.name = '<team>') AND l.name = '<league>' AND g.season = '<season>';"
45
+ },
46
+ {
47
+ "input": "List all teams that played against <team> in season <season> and league <league>",
48
+ "query": "SELECT DISTINCT CASE WHEN ht.name = '<team>' THEN at.name ELSE ht.name END AS OpponentTeam FROM games g JOIN teams ht ON g.home_team_id = ht.id JOIN teams at ON g.away_team_id = at.id JOIN leagues l ON g.league_id = l.id WHERE (ht.name = '<team>' OR at.name = '<team>') AND l.name = '<league>' AND g.season = '<season>' ORDER BY OpponentTeam;"
49
+ },
50
+ {
51
+ "input": "Get home and away stats for <team> in <season>",
52
+ "query": "WITH home_games AS (SELECT g.id, g.season, g.home_team_id AS team_id, CASE WHEN g.goal_home > g.goal_away THEN 1 ELSE 0 END AS won, CASE WHEN g.goal_home = g.goal_away THEN 1 ELSE 0 END AS draw, CASE WHEN g.goal_home < g.goal_away THEN 1 ELSE 0 END AS lost FROM games g JOIN teams t ON g.home_team_id = t.id WHERE t.name = '<team>' AND g.season = '<season>'), away_games AS (SELECT g.id, g.season, g.away_team_id AS team_id, CASE WHEN g.goal_away > g.goal_home THEN 1 ELSE 0 END AS won, CASE WHEN g.goal_away = g.goal_home THEN 1 ELSE 0 END AS draw, CASE WHEN g.goal_away < g.goal_home THEN 1 ELSE 0 END AS lost FROM games g JOIN teams t ON g.away_team_id = t.id WHERE t.name = '<team>' AND g.season = '<season>'), home_stats AS (SELECT COUNT(*) AS total_home_games, SUM(won) AS home_wins, SUM(draw) AS home_draws, SUM(lost) AS home_losses FROM home_games), away_stats AS (SELECT COUNT(*) AS total_away_games, SUM(won) AS away_wins, SUM(draw) AS away_draws, SUM(lost) AS away_losses FROM away_games) SELECT hs.total_home_games, hs.home_wins, hs.home_draws, hs.home_losses, as_stats.total_away_games, as_stats.away_wins, as_stats.away_draws, as_stats.away_losses FROM home_stats hs, away_stats as_stats;"
53
+ },
54
+ {
55
+ "input": "How many goals did <player> score in <season> in <league>?",
56
+ "query": "SELECT COUNT(*) AS goal_count FROM player_events pe JOIN players p ON pe.player_id = p.hash JOIN games g ON pe.game_id = g.id JOIN leagues l ON g.league_id = l.id JOIN player_event_labels pel ON pe.type = pel.id WHERE p.name = <player> AND g.season = <season> AND l.name = <league> AND pel.label = 'Goal';"
57
+ },
58
+ {
59
+ "input": "How many goals did <player> score in <season>?",
60
+ "query": "SELECT COUNT(*) AS goal_count FROM player_events pe JOIN players p ON pe.player_id = p.hash JOIN games g ON pe.game_id = g.id JOIN player_event_labels pel ON pe.type = pel.id WHERE p.name = <player> AND g.season = <season> AND pel.label = 'Goal';"
61
+ },
62
+ {
63
+ "input": "List all teams that played against <team> in season <season>",
64
+ "query": "SELECT DISTINCT opponent.name AS opponent_name FROM games JOIN teams AS opponent ON (opponent.id = games.home_team_id OR opponent.id = games.away_team_id) JOIN teams AS specified_team ON (specified_team.id = games.home_team_id OR specified_team.id = games.away_team_id) WHERE (games.home_team_id = (SELECT id FROM teams WHERE name = '<team>') OR games.away_team_id = (SELECT id FROM teams WHERE name = '<team>')) AND games.season = '<season>' AND opponent.name != '<team>'"
65
+ },
66
+ {
67
+ "input": "List all teams in <league> in <season>",
68
+ "query": "SELECT DISTINCT team.name FROM games JOIN teams team ON team.id = games.home_team_id OR team.id = games.away_team_id WHERE games.league_id = (SELECT id FROM leagues WHERE name = '<league_name>') AND games.season = '<season>'"
69
+ },
70
+ {
71
+ "input": "List all games in <league> in <season> with <event> in first half",
72
+ "query": "SELECT ht.name AS home_team, at.name AS away_team, g.score, g.date FROM games g JOIN leagues l ON g.league_id = l.id JOIN events e ON g.id = e.game_id AND g.home_team_id = e.team_id JOIN teams ht ON g.home_team_id = ht.id JOIN teams at ON g.away_team_id = at.id WHERE l.name = '<leauge>' AND g.season = '<season>' AND e.period = 1 AND e.label = '<event>' GROUP BY g.id;"
73
+ },
74
+ {
75
+ "input": "List all games in <league> in <season> with <event>, and include the number of times the event occurred",
76
+ "query": "SELECT ht.name AS home_team, at.name AS away_team, g.score, g.date, COUNT(e.id) AS event_count FROM games g JOIN leagues l ON g.league_id = l.id JOIN events e ON g.id = e.game_id AND g.home_team_id = e.team_id JOIN teams ht ON g.home_team_id = ht.id JOIN teams at ON g.away_team_id = at.id WHERE l.name = '<leauge>' AND g.season = '<season>' AND e.label = '<event>' GROUP BY g.id;"
77
+ },
78
+ {
79
+ "input": "What teams and in what season did <player> play in?",
80
+ "query": "SELECT DISTINCT p.name AS player_name, t.name AS team_name, g.season, l.name AS league_name FROM game_lineup gl JOIN players p ON gl.player_id = p.hash JOIN teams t ON gl.team_id = t.id JOIN games g ON gl.game_id = g.id JOIN leagues l ON g.league_id = l.id WHERE p.name = '<player>' ORDER BY p.name, t.name, g.season, l.name;"
81
+ },
82
+ {
83
+ "input": "List all players in <team> in <season>",
84
+ "query": "SELECT DISTINCT p.name AS player_name FROM game_lineup gl JOIN players p ON gl.player_id = p.hash JOIN teams t ON gl.team_id = t.id JOIN games g ON gl.game_id = g.id WHERE t.name = '<team>' AND g.season = '<season>' ORDER BY p.name;"
85
+ },
86
+ {
87
+ "input": "List all teams a player has played for",
88
+ "query": "SELECT DISTINCT t.name AS team_name FROM game_lineup gl JOIN players p ON gl.player_id = p.hash JOIN teams t ON gl.team_id = t.id WHERE p.name = '<player>' ORDER BY t.name;"
89
+ },
90
+ {
91
+ "input": "List all yellow and red cards for game <game_id>, sorted by time",
92
+ "query": "SELECT p.name AS player_name, pel.label AS card_type, pe.time AS event_time FROM player_events pe JOIN players p ON pe.player_id = p.hash JOIN player_event_labels pel ON pe.type = pel.id WHERE pe.game_id = <game_id> AND (pel.label = 'Yellow card' OR pel.label = 'Red card') ORDER BY CAST(pe.time AS UNSIGNED) ASC;"
93
+ },
94
+ {
95
+ "input": "What player had the first <event> in league <league> in season <season>?",
96
+ "query": "SELECT p.name AS player_name, pe.game_id, pe.time AS event_time FROM player_events pe JOIN players p ON pe.player_id = p.hash JOIN (SELECT g.id FROM games g JOIN leagues l ON g.league_id = l.id WHERE g.season = '<season>' AND l.id = <leauge_id> ORDER BY g.id LIMIT 1) AS first_game ON pe.game_id = first_game.id JOIN player_event_labels pel ON pe.type = pel.id WHERE pel.label = <event> ORDER BY CAST(pe.time AS UNSIGNED) ASC LIMIT 1;"
97
+ },
98
+ {
99
+ "input": "How many times did <player> get substituted in <season>?",
100
+ "query": "SELECT COUNT(*) AS substitution_count FROM player_events pe JOIN players p ON pe.player_id = p.hash JOIN games g ON pe.game_id = g.id WHERE p.hash = <player_hash> AND g.season = <season> AND (pe.type = 6 or pe.type = 7)"
101
+ }
102
+
103
+ ]
104
+
105
+
src/database/database.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine, Column, Integer, String, ForeignKey, Text, Float, Boolean, UniqueConstraint
2
+ from sqlalchemy.orm import declarative_base, sessionmaker
3
+ import pandas as pd
4
+ import os
5
+ import json
6
+
7
+ engine = create_engine('sqlite:///../../data/games.db', echo=False)
8
+ Base = declarative_base()
9
+
10
+
11
+ class Game(Base):
12
+ __tablename__ = 'games'
13
+ id = Column(Integer, primary_key=True)
14
+ timestamp = Column(String)
15
+ score = Column(String)
16
+ goal_home = Column(Integer)
17
+ goal_away = Column(Integer)
18
+ round = Column(String)
19
+ home_team_id = Column(Integer, ForeignKey('teams.id'))
20
+ away_team_id = Column(Integer, ForeignKey('teams.id'))
21
+ venue = Column(String)
22
+ referee = Column(String)
23
+ attendance = Column(String)
24
+ date = Column(String)
25
+ season = Column(String)
26
+ league_id = Column(Integer, ForeignKey('leagues.id'))
27
+
28
+ class GameLineup(Base):
29
+ __tablename__ = 'game_lineup'
30
+ id = Column(Integer, primary_key=True)
31
+ game_id = Column(Integer, ForeignKey('games.id'))
32
+ team_id = Column(Integer, ForeignKey('teams.id'))
33
+ player_id = Column(Integer, ForeignKey('players.hash'))
34
+ shirt_number = Column(String)
35
+ position = Column(String)
36
+ starting = Column(Boolean)
37
+ captain = Column(Boolean)
38
+ coach = Column(Boolean)
39
+ tactics = Column(String)
40
+ # Add a unique constraint on game_id and player_id
41
+ __table_args__ = (UniqueConstraint('game_id', 'player_id', name='uc_game_id_player_id'),)
42
+
43
+
44
+ class Team(Base):
45
+ __tablename__ = 'teams'
46
+ id = Column(Integer, primary_key=True)
47
+ name = Column(String)
48
+
49
+ class Player(Base):
50
+ __tablename__ = 'players'
51
+ hash = Column(String, primary_key=True)
52
+ name = Column(String)
53
+ country = Column(String)
54
+
55
+
56
+ class Caption(Base):
57
+ __tablename__ = 'captions'
58
+ id = Column(Integer, primary_key=True)
59
+ game_id = Column(Integer, ForeignKey('games.id'))
60
+ game_time = Column(String)
61
+ period = Column(Integer)
62
+ label = Column(String)
63
+ description = Column(Text)
64
+ important = Column(Boolean)
65
+ visibility = Column(Boolean)
66
+ frame_stamp = Column(Integer)
67
+
68
+
69
+ class Commentary(Base):
70
+ __tablename__ = 'commentary'
71
+ id = Column(Integer, primary_key=True)
72
+ game_id = Column(Integer, ForeignKey('games.id'))
73
+ period = Column(Integer)
74
+ event_time_start = Column(Float)
75
+ event_time_end = Column(Float)
76
+ description = Column(Text)
77
+
78
+ class League(Base):
79
+ __tablename__ = 'leagues'
80
+ id = Column(Integer, primary_key=True)
81
+ name = Column(String)
82
+
83
+ class Event(Base):
84
+ __tablename__ = 'events'
85
+ id = Column(Integer, primary_key=True)
86
+ game_id = Column(Integer, ForeignKey('games.id'))
87
+ period = Column(Integer)
88
+ # half = Column(Integer)
89
+ game_time = Column(Integer)
90
+ team_id = Column(Integer, ForeignKey('teams.id'))
91
+ frame_stamp = Column(Integer)
92
+ label = Column(String)
93
+ visibility = Column(Boolean)
94
+
95
+ class Augmented_Team(Base):
96
+ __tablename__ = 'augmented_teams'
97
+ id = Column(Integer, primary_key=True)
98
+ team_id = Column(Integer, ForeignKey('teams.id'))
99
+ augmented_name = Column(String)
100
+
101
+ class Augmented_League(Base):
102
+ __tablename__ = 'augmented_leagues'
103
+ id = Column(Integer, primary_key=True)
104
+ league_id = Column(Integer, ForeignKey('leagues.id'))
105
+ augmented_name = Column(String)
106
+
107
+ class Player_Event_Label(Base):
108
+ __tablename__ = 'player_event_labels'
109
+ id = Column(Integer, primary_key=True)
110
+ label = Column(String)
111
+
112
+ class Player_Event(Base):
113
+ __tablename__ = 'player_events'
114
+ id = Column(Integer, primary_key=True)
115
+ game_id = Column(Integer, ForeignKey('games.id'))
116
+ player_id = Column(Integer, ForeignKey('players.hash'))
117
+ time = Column(String) # Time in minutes of the game
118
+ type = Column(Integer, ForeignKey('player_event_labels.id'))
119
+ linked_player = Column(Integer, ForeignKey('players.hash')) # If the event is linked to another player, for example a substitution
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+ # Create Tables
128
+ Base.metadata.create_all(engine)
129
+
130
+ # Session setup
131
+ Session = sessionmaker(bind=engine)
132
+
133
+ def extract_time_from_player_event(time:str)->str:
134
+ # Extract the time from the string
135
+ time = time.split("'")[0] # Need to keep it str because of overtime eg. (45+2)
136
+ return time
137
+
138
+ def get_or_create(session, model, **kwargs):
139
+ instance = session.query(model).filter_by(**kwargs).first()
140
+ if instance:
141
+ return instance
142
+ else:
143
+ instance = model(**kwargs)
144
+ session.add(instance)
145
+ session.commit()
146
+ return instance
147
+
148
+ def process_game_data(data,data2, league, season):
149
+ session = Session()
150
+ # Caption = d and v2 = d2
151
+ home_team = data["gameHomeTeam"]
152
+ away_team = data["gameAwayTeam"]
153
+ score = data["score"]
154
+ home_score = score[0]
155
+ away_score = score[-1]
156
+ round_ = data["round"]
157
+ venue = data["venue"][0]
158
+ referee = data.get("referee_found", None)
159
+ referee = referee[0] if referee else data.get("referee", None)
160
+ date = data["gameDate"]
161
+ timestamp = data["timestamp"]
162
+ attendance = data.get("attendance", [])
163
+ attendance = attendance[0] if attendance else None
164
+
165
+ home_team = get_or_create(session, Team, name=home_team)
166
+ away_team = get_or_create(session, Team, name=away_team)
167
+ # Check if the game already exists
168
+ game = session.query(Game).filter_by(timestamp=timestamp, home_team_id=home_team.id).first()
169
+ # Check if league exists
170
+ league = get_or_create(session, League, name=league)
171
+ if not game:
172
+ game = Game(timestamp=timestamp, score=score, goal_home=home_score, goal_away=away_score, round=round_, home_team_id=home_team.id, away_team_id=away_team.id,
173
+ venue=venue, date=date, attendance=attendance, season=season, league_id=league.id, referee=referee)
174
+ session.add(game)
175
+ session.commit()
176
+
177
+ teams = ["home", "away"]
178
+ # Lets add lineup data
179
+ for team in teams:
180
+ if team == "home":
181
+ team_id = home_team.id
182
+ else:
183
+ team_id = away_team.id
184
+ team_lineup = data["lineup"][team]
185
+ tactic = team_lineup["tactic"]
186
+
187
+ for player_data in team_lineup["players"]:
188
+ player_hash = player_data["hash"]
189
+ name = player_data["long_name"]
190
+ if " " not in name: # Since some players are missing their first name, do this to help with the search
191
+ name = "NULL " + name
192
+ number = player_data["shirt_number"]
193
+ captain = player_data["captain"] == "(C)"
194
+ starting = player_data["starting"]
195
+ country = player_data["country"]
196
+ position = player_data["lineup"]
197
+ facts = player_data.get("facts", None) # Facts might be empty
198
+
199
+
200
+
201
+
202
+
203
+ player = get_or_create(session, Player, hash=player_hash, name=name, country=country)
204
+ game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=player.hash,
205
+ shirt_number=number, position=position, starting=starting, captain=captain, coach=False, tactics=tactic)
206
+ if facts:
207
+ for fact in facts:
208
+ type = fact["type"]
209
+ time = extract_time_from_player_event(fact["time"])
210
+ event = get_or_create(session, Player_Event_Label, id=int(type))
211
+ linked_player = fact.get("linked_player_hash", None)
212
+
213
+ player_event = Player_Event(game_id=game.id, player_id=player.hash, time=time, type=event.id, linked_player=linked_player)
214
+ session.add(player_event)
215
+ session.add(game_lineup)
216
+
217
+ # Get the coach
218
+ coach = team_lineup["coach"][0]
219
+ coach_hash = coach["hash"]
220
+ coach_name = coach["long_name"]
221
+ if " " not in coach_name: # Since some players are missing their first name, do this to help with the search
222
+ name = "NULL " + coach_name
223
+ coach_country = coach["country"]
224
+ coach_player = get_or_create(session, Player, hash=coach_hash, name=coach_name, country=coach_country)
225
+ game_lineup = GameLineup(game_id=game.id, team_id=team_id, player_id=coach_player.hash,
226
+ shirt_number=None, position=None, starting=None, captain=False, coach=True, tactics=tactic)
227
+ session.add(game_lineup)
228
+
229
+ # Commit all changes at once
230
+ session.commit()
231
+
232
+ # Start parsing the events
233
+ events = data["annotations"]
234
+ for event in events:
235
+ period, time = convert_to_seconds(event["gameTime"])
236
+ label = event["label"]
237
+ # Renaming labels
238
+ if label == "soccer-ball":
239
+ label = "goal"
240
+ elif label == "y-card":
241
+ label = "yellow card"
242
+ elif label == "r-card":
243
+ label = "red card"
244
+
245
+ description = event["description"]
246
+ important = event["important"] == "true"
247
+ visible = event["visibility"]
248
+ # Convert to boolean
249
+ # True if shown, False if not
250
+ visible = visible == "shown"
251
+ position = int(event["position"])
252
+
253
+ event = Caption(game_id=game.id, game_time=time, period=period, label=label, description=description,
254
+ important=important, visibility=visible, frame_stamp=position)
255
+ session.add(event)
256
+ session.commit()
257
+
258
+ return game.id, home_team.id, away_team.id
259
+
260
+ def process_player_data(data):
261
+ pass
262
+
263
+ def process_ASR_data(data, game_id, period):
264
+ session = Session()
265
+ seg = data["segments"]
266
+ commentary_events = [] # Store the events in a list
267
+
268
+ for k, v in seg.items():
269
+ start = float(v[0])
270
+ end = float(v[1])
271
+ desc = v[2]
272
+ event = Commentary(game_id=game_id, period=period, event_time_start=start, event_time_end=end, description=desc)
273
+ commentary_events.append(event)
274
+
275
+ # Bulk save objects
276
+ session.bulk_save_objects(commentary_events)
277
+ session.commit()
278
+ session.close()
279
+
280
+ def convert_to_seconds(time_str):
281
+ # Split the string into its components
282
+ period, time = time_str.split(" - ")
283
+ minutes, seconds = time.split(":")
284
+
285
+ # Convert the components to integers
286
+ period = int(period)
287
+ minutes = int(minutes)
288
+ seconds = int(seconds)
289
+ # Calculate the time in seconds
290
+
291
+ total_seconds = (minutes * 60) + seconds
292
+ return period, total_seconds
293
+
294
+
295
+ def parse_labels_v2(data, session, home_team_id, away_team_id, game_id):
296
+ annotations_data = data["annotations"]
297
+ no_team = get_or_create(session, Team, name="not applicable")
298
+
299
+ for annotation in annotations_data:
300
+ period, game_time = convert_to_seconds(annotation["gameTime"])
301
+
302
+ # Determine which team the annotation belongs to
303
+ if annotation["team"] == "home":
304
+ team_id = home_team_id
305
+ elif annotation["team"] == "away":
306
+ team_id = away_team_id
307
+ else:
308
+ team_id = no_team.id
309
+
310
+ position = annotation.get("position", None) # Assuming position can be null
311
+ visibility = annotation["visibility"] == "visible"
312
+ # Convert to boolean
313
+ # True if visible, False if not
314
+ visibility = visibility == "visible"
315
+ label = annotation["label"]
316
+
317
+ # Create and add the Annotations instance
318
+ annotation_entry = Event(
319
+ game_id=game_id,
320
+ period=period, # periode
321
+ game_time=game_time, # Already in seconds
322
+ frame_stamp=position, # Make sure this is an integer or None
323
+ team_id=team_id, # Integer ID of the team
324
+ visibility=visibility, # Boolean
325
+ label=label # String with information
326
+ )
327
+ session.add(annotation_entry)
328
+
329
+ session.commit()
330
+
331
+
332
+
333
+
334
+
335
+ def process_json_files(directory):
336
+ session = Session()
337
+ fill_player_events(session)
338
+ for root, dirs, files in os.walk(directory):
339
+ print(root)
340
+ labels_file = None
341
+ asr_files = []
342
+ path_parts = root.split("\\")
343
+ if len(path_parts) > 2:
344
+ league = path_parts[-3].split("/")[-1]
345
+ season = path_parts[-2]
346
+ # Need the labels-v2 first as it contains the game ID
347
+ for file in files:
348
+ if 'Labels-caption.json' in file:
349
+ labels_file = file
350
+ elif file.endswith('.json'):
351
+ asr_files.append(file)
352
+
353
+ if labels_file:
354
+ with open(os.path.join(root, labels_file), 'r') as f:
355
+ lb_cap = json.load(f)
356
+ with open(os.path.join(root, "Labels-v2.json"), 'r') as f:
357
+ lb_v2 = json.load(f)
358
+ game_id, home_team_id, away_team_id = process_game_data(lb_cap,lb_v2, league, season)
359
+
360
+ for file in asr_files:
361
+ with open(os.path.join(root, file), 'r') as f:
362
+ asr = json.load(f)
363
+
364
+ # Determine the type of file and process accordingly
365
+ if 'Labels-v2' in file:
366
+ parse_labels_v2(asr, session, home_team_id, away_team_id, game_id)
367
+
368
+ elif '1_half-ASR' in file:
369
+ period = 1
370
+ # Parse and commit the data
371
+ process_ASR_data(data=asr, game_id = game_id, period=period)
372
+
373
+ elif '2_half-ASR' in file:
374
+ period = 2
375
+ # Parse and commit the data
376
+ process_ASR_data(data=asr, game_id = game_id, period=period)
377
+
378
+
379
+ session.commit()
380
+ session.close()
381
+
382
+ def fill_player_events(session):
383
+
384
+ fact_id2label = {
385
+ "1": "Yellow card",
386
+ # Example: "time": "71' Ivanovic B. (Unsportsmanlike conduct)", "description": "Yellow Card"
387
+ "2": "Red card", # Example: "time": "70' Matic N. (Unsportsmanlike conduct)", "description": "Red Card"
388
+ "3": "Goal", # Example: "time": "14' Ivanovic B. (Hazard E.)", "description": "Goal"
389
+ "4": "NA",
390
+ "5": "NA 2",
391
+ "6": "Substitution home", # Example: "time": "72'", "description": "Ramires"
392
+ "7": "Substitution away", # Example: "time": "86'", "description": "Filipe Luis"
393
+ "8": "Assistance" # Example: "time": "14' Ivanovic B. (Hazard E.)", "description": "Assistance"
394
+ }
395
+ for key, value in fact_id2label.items():
396
+ label = get_or_create(session, Player_Event_Label, label=value)
397
+ session.commit()
398
+
399
+
400
+
401
+ def fill_Augmented_Team(file_path):
402
+
403
+ df = pd.read_csv(file_path)
404
+ # the df should have two columns, team_name and augmented_name
405
+
406
+ session = Session()
407
+ teams = session.query(Team).all()
408
+ # For each row, find the team_id and add the augmented name
409
+ for index, row in df.iterrows():
410
+ team_name = row["name"]
411
+ augmented_name = row["augmented_name"]
412
+ # Strip leading and trailing whitespace
413
+ augmented_name = augmented_name.strip()
414
+ team = session.query(Team).filter_by(name=team_name).first()
415
+ if team:
416
+ augmented_team = get_or_create(session, Augmented_Team, team_id=team.id, augmented_name=augmented_name)
417
+ session.commit()
418
+ session.close()
419
+
420
+ def fill_Augmented_League(file_path):
421
+ # Read the csv file
422
+ df = pd.read_csv(file_path)
423
+ # the df should have two columns, team_name and augmented_name
424
+
425
+ session = Session()
426
+ leagues = session.query(League).all()
427
+ # For each row, find the team_id and add the augmented name
428
+ for index, row in df.iterrows():
429
+ league_name = row["name"]
430
+ augmented_name = row["augmented_name"]
431
+ # Strip leading and trailing whitespace
432
+ augmented_name = augmented_name.strip()
433
+ league = session.query(League).filter_by(name=league_name).first()
434
+ if league:
435
+ augmented_league = get_or_create(session, Augmented_League, league_id=league.id, augmented_name=augmented_name)
436
+ session.commit()
437
+ session.close()
438
+
439
+ if __name__ == "__main__":
440
+ # Example directory path
441
+ process_json_files('../../data/Dataset/SN-ASR_captions_and_actions/')
442
+ fill_Augmented_Team('../../data/Dataset/augmented.csv')
443
+ fill_Augmented_League('../../data/Dataset/augmented_leauges.csv')
444
+ # Rename the event/annotation table to something more descriptive. Events are fucking everything else over
445
+
src/database/readdata.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/extractor.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from langchain.chains import create_extraction_chain_pydantic
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain.chains import create_extraction_chain
6
+ from copy import deepcopy
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_community.utilities import SQLDatabase
9
+ import os
10
+ import difflib
11
+ import ast
12
+ import json
13
+ import re
14
+ from thefuzz import process
15
+ # Set up logging
16
+ import logging
17
+
18
+ from dotenv import load_dotenv
19
+
20
+ load_dotenv(".env")
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ # Save the log to a file
24
+ handler = logging.FileHandler('extractor.log')
25
+ logger = logging.getLogger(__name__)
26
+
27
+ os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
28
+ # os.environ["ANTHROPIC_API_KEY"] = os.getenv('ANTHROPIC_API_KEY')
29
+
30
+ if os.getenv('LANGSMITH'):
31
+ os.environ['LANGCHAIN_TRACING_V2'] = 'true'
32
+ os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
33
+ os.environ[
34
+ 'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
35
+ os.environ['LANGCHAIN_PROJECT'] = 'master-theses'
36
+ db = SQLDatabase.from_uri("sqlite:///data/games.db")
37
+
38
+ # from langchain_anthropic import ChatAnthropic
39
+ class Extractor():
40
+ # llm = ChatOpenAI(model_name="gpt-4-0125-preview", temperature=0)
41
+ #gpt-3.5-turbo
42
+ def __init__(self, model="gpt-3.5-turbo-0125", schema_config=None, custom_extractor_prompt=None):
43
+ # model = "gpt-4-0125-preview"
44
+ if custom_extractor_prompt:
45
+ cust_promt = ChatPromptTemplate.from_template(custom_extractor_prompt)
46
+
47
+ self.llm = ChatOpenAI(model=model, temperature=0)
48
+ # self.llm = ChatAnthropic(model="claude-3-opus-20240229", temperature=0)
49
+ self.schema = schema_config or {}
50
+ self.chain = create_extraction_chain(self.schema, self.llm, prompt=cust_promt)
51
+
52
+ def extract(self, query):
53
+ return self.chain.invoke(query)
54
+
55
+
56
+ class Retriever():
57
+ def __init__(self, db, config):
58
+ self.db = db
59
+ self.config = config
60
+ self.table = config.get('db_table')
61
+ self.column = config.get('db_column')
62
+ self.pk_column = config.get('pk_column')
63
+ self.numeric = config.get('numeric', False)
64
+ self.response = []
65
+ self.query = f"SELECT {self.column} FROM {self.table}"
66
+ self.augmented_table = config.get('augmented_table', None)
67
+ self.augmented_column = config.get('augmented_column', None)
68
+ self.augmented_fk = config.get('augmented_fk', None)
69
+
70
+ def query_as_list(self):
71
+ # Execute the query
72
+ response = self.db.run(self.query)
73
+ response = [el for sub in ast.literal_eval(response) for el in sub if el]
74
+ if not self.numeric:
75
+ response = [re.sub(r"\b\d+\b", "", string).strip() for string in response]
76
+ self.response = list(set(response))
77
+ # print(self.response)
78
+ return self.response
79
+
80
+ def get_augmented_items(self, prompt):
81
+ if self.augmented_table is None:
82
+ return None
83
+ else:
84
+ # Construct the query to search for the prompt in the augmented table
85
+ query = f"SELECT {self.augmented_fk} FROM {self.augmented_table} WHERE LOWER({self.augmented_column}) = LOWER('{prompt}')"
86
+
87
+ # Execute the query
88
+ fk_response = self.db.run(query)
89
+ if fk_response:
90
+ # Extract the FK value
91
+ fk_response = ast.literal_eval(fk_response)
92
+ fk_value = fk_response[0][0]
93
+ query = f"SELECT {self.column} FROM {self.table} WHERE {self.pk_column} = {fk_value}"
94
+ # Execute the query
95
+ matching_response = self.db.run(query)
96
+ # Extract the matching response
97
+ matching_response = ast.literal_eval(matching_response)
98
+ matching_response = matching_response[0][0]
99
+ return matching_response
100
+ else:
101
+ return None
102
+
103
+ def find_close_matches(self, target_string, n=3, method="difflib", threshold=70):
104
+ """
105
+ Find and return the top n close matches to target_string in the database query results.
106
+
107
+ Args:
108
+ - target_string (str): The string to match against the database results.
109
+ - n (int): Number of top matches to return.
110
+
111
+ Returns:
112
+ - list of tuples: Each tuple contains a match and its score.
113
+ """
114
+ # Ensure we have the response list populated
115
+ if not self.response:
116
+ self.query_as_list()
117
+
118
+ # Find top n close matches
119
+ if method == "fuzzy":
120
+ # Use the fuzzy_string method to get matches and their scores
121
+ # If the threshold is met, return the best match; otherwise, return all matches meeting the threshold
122
+ top_matches = self.fuzzy_string(target_string, limit=n, threshold=threshold)
123
+
124
+
125
+ else:
126
+ # Use difflib's get_close_matches to get the top n matches
127
+ top_matches = difflib.get_close_matches(target_string, self.response, n=n, cutoff=0.2)
128
+
129
+ return top_matches
130
+
131
+ def fuzzy_string(self, prompt, limit, threshold=80, low_threshold=30):
132
+
133
+ # Get matches and their scores, limited by the specified 'limit'
134
+ matches = process.extract(prompt, self.response, limit=limit)
135
+
136
+
137
+ filtered_matches = [match for match in matches if match[1] >= threshold]
138
+
139
+ # If no matches meet the threshold, return the list of all matches' strings
140
+ if not filtered_matches:
141
+ # Return matches above the low_threshold
142
+ # Fix for wrong properties being returned
143
+ return [match[0] for match in matches if match[1] >= low_threshold]
144
+
145
+
146
+ # If there's only one match meeting the threshold, return it as a string
147
+ if len(filtered_matches) == 1:
148
+ return filtered_matches[0][0] # Return the matched string directly
149
+
150
+ # If there's more than one match meeting the threshold or ties, return the list of matches' strings
151
+ highest_score = filtered_matches[0][1]
152
+ ties = [match for match in filtered_matches if match[1] == highest_score]
153
+
154
+ # Return the strings of tied matches directly, ignoring the scores
155
+ m = [match[0] for match in ties]
156
+ if len(m) == 1:
157
+ return m[0]
158
+ return [match[0] for match in ties]
159
+
160
+ def fetch_pk(self, property_name, property_value):
161
+ # Some properties do not have a primary key
162
+ # Return the property value if no primary key is specified
163
+ pk_list = []
164
+
165
+ # Check if the property_value is a list; if not, make it a list for uniform processing
166
+ if not isinstance(property_value, list):
167
+ property_value = [property_value]
168
+
169
+ # Some properties do not have a primary key
170
+ # Return None for each property_value if no primary key is specified
171
+ if self.pk_column is None:
172
+ return [None for _ in property_value]
173
+
174
+ for value in property_value:
175
+ query = f"SELECT {self.pk_column} FROM {self.table} WHERE {self.column} = '{value}' LIMIT 1"
176
+ response = self.db.run(query)
177
+
178
+ # Append the response (PK or None) to the pk_list
179
+ pk_list.append(response)
180
+
181
+ return pk_list
182
+
183
+
184
+ def setup_retrievers(db, schema_config):
185
+ # retrievers = {}
186
+ # for prop, config in schema_config["properties"].items():
187
+ # retrievers[prop] = Retriever(db=db, config=config)
188
+ # return retrievers
189
+
190
+ retrievers = {}
191
+ # Iterate over each property in the schema_config's properties
192
+ for prop, config in schema_config["properties"].items():
193
+ # Access the 'items' dictionary for the configuration of the array's elements
194
+ item_config = config['items']
195
+ # Create a Retriever instance using the item_config
196
+ retrievers[prop] = Retriever(db=db, config=item_config)
197
+ return retrievers
198
+
199
+
200
+ def extract_properties(prompt, schema_config, custom_extractor_prompt=None):
201
+ """Extract properties from the prompt."""
202
+ # modify schema_conf to only include the required properties
203
+ schema_stripped = {'properties': {}}
204
+ for key, value in schema_config['properties'].items():
205
+ schema_stripped['properties'][key] = {
206
+ 'type': value['type'],
207
+ 'items': {'type': value['items']['type']}
208
+ }
209
+
210
+ extractor = Extractor(schema_config=schema_stripped, custom_extractor_prompt=custom_extractor_prompt)
211
+ extraction_result = extractor.extract(prompt)
212
+ # print("Extraction Result:", extraction_result)
213
+
214
+ if 'text' in extraction_result and extraction_result['text']:
215
+ properties = extraction_result['text']
216
+ return properties
217
+ else:
218
+ print("No properties extracted.")
219
+ return None
220
+
221
+
222
+ def recheck_property_value(properties, property_name, retrievers, input_func):
223
+ while True:
224
+ new_value = input_func(f"Enter new value for {property_name} or type 'quit' to stop: ")
225
+ if new_value.lower() == 'quit':
226
+ break # Exit the loop and do not update the property
227
+
228
+ new_top_matches = retrievers[property_name].find_close_matches(new_value, n=3)
229
+ if new_top_matches:
230
+ # Display new top matches and ask for confirmation or re-entry
231
+ print("\nNew close matches found:")
232
+ for i, match in enumerate(new_top_matches, start=1):
233
+ print(f"[{i}] {match}")
234
+ print("[4] Re-enter value")
235
+ print("[5] Quit without updating")
236
+
237
+ selection = input_func("Select the best match (1-3), choose 4 to re-enter value, or 5 to quit: ")
238
+ if selection in ['1', '2', '3']:
239
+ selected_match = new_top_matches[int(selection) - 1]
240
+ properties[property_name] = selected_match # Update the dictionary directly
241
+ print(f"Updated {property_name} to {selected_match}")
242
+ break # Successfully updated, exit the loop
243
+ elif selection == '5':
244
+ break # Quit without updating
245
+ # Loop will continue if user selects 4 or inputs invalid selection
246
+ else:
247
+ print("No close matches found. Please try again or type 'quit' to stop.")
248
+
249
+
250
+ def check_and_update_properties(properties_list, retrievers, method="fuzzy", input_func=input):
251
+ """
252
+ Checks and updates the properties in the properties list based on close matches found in the database.
253
+ The function iterates through each property in each property dictionary within the list,
254
+ finds close matches for it in the database using the retrievers, and updates the property
255
+ value based on user selection.
256
+
257
+ Args:
258
+ properties_list (list of dict): A list of dictionaries, where each dictionary contains properties
259
+ to check and potentially update based on database matches.
260
+ retrievers (dict): A dictionary of Retriever objects keyed by property name, used to find close matches in the database.
261
+ input_func (function, optional): A function to capture user input. Defaults to the built-in input function.
262
+
263
+ The function updates the properties_list in place based on user choices for updating property values
264
+ with close matches found by the retrievers.
265
+ """
266
+
267
+ for index, properties in enumerate(properties_list):
268
+ for property_name, retriever in retrievers.items(): # Iterate using items to get both key and value
269
+ property_values = properties.get(property_name, [])
270
+ if not property_values: # Skip if the property is not present or is an empty list
271
+ continue
272
+
273
+ updated_property_values = [] # To store updated list of values
274
+
275
+ for value in property_values:
276
+ if retriever.augmented_table:
277
+ augmented_value = retriever.get_augmented_items(value)
278
+ if augmented_value:
279
+ updated_property_values.append(augmented_value)
280
+ continue
281
+ # Since property_value is now expected to be a list, we handle each value individually
282
+ top_matches = retriever.find_close_matches(value, method=method, n=3)
283
+
284
+ # Check if the closest match is the same as the current value
285
+ if top_matches and top_matches[0] == value:
286
+ updated_property_values.append(value)
287
+ continue
288
+
289
+ if not top_matches:
290
+ updated_property_values.append(value) # Keep the original value if no matches found
291
+ continue
292
+
293
+ if type(top_matches) == str and method == "fuzzy":
294
+ # If the top_matches is a string, it means that the threshold was met and only one item was returned
295
+ # In this case, we can directly update the property with the top match
296
+ updated_property_values.append(top_matches)
297
+ properties[property_name] = updated_property_values
298
+ continue
299
+
300
+ print(f"\nCurrent {property_name}: {value}")
301
+ for i, match in enumerate(top_matches, start=1):
302
+ print(f"[{i}] {match}")
303
+ print("[4] Enter new value")
304
+
305
+ # hmm = input_func(f"Fix for Pycharm, press enter to continue")
306
+
307
+ choice = input_func(f"Select the best match for {property_name} (1-4): ")
308
+ if choice in ['1', '2', '3']:
309
+ selected_match = top_matches[int(choice) - 1]
310
+ updated_property_values.append(selected_match) # Update with the selected match
311
+ print(f"Updated {property_name} to {selected_match}")
312
+ elif choice == '4':
313
+ # Allow re-entry of value for this specific item
314
+ recheck_property_value(properties, property_name, value, retrievers, input_func)
315
+ # Note: Implement recheck_property_value to handle individual value updates within the list
316
+ else:
317
+ print("Invalid selection. Property not updated.")
318
+ updated_property_values.append(value) # Keep the original value
319
+
320
+ # Update the entire list for the property after processing all values
321
+ properties[property_name] = updated_property_values
322
+
323
+
324
+ # Function to remove duplicates
325
+ def remove_duplicates(dicts):
326
+ seen = {} # Dictionary to keep track of seen values for each key
327
+ for d in dicts:
328
+ for key in list(d.keys()): # Use list to avoid RuntimeError for changing dict size during iteration
329
+ value = d[key]
330
+ if key in seen and value == seen[key]:
331
+ del d[key] # Remove key-value pair if duplicate is found
332
+ else:
333
+ seen[key] = value # Update seen values for this key
334
+ return dicts
335
+
336
+
337
+ def fetch_pks(properties_list, retrievers):
338
+ all_pk_attributes = [] # Initialize a list to store dictionaries of _pk attributes for each item in properties_list
339
+
340
+ # Iterate through each properties dictionary in the list
341
+ for properties in properties_list:
342
+ pk_attributes = {} # Initialize a dictionary for the current set of properties
343
+ for property_name, property_value in properties.items():
344
+ if property_name in retrievers:
345
+ # Fetch the primary key using the retriever for the current property
346
+ pk = retrievers[property_name].fetch_pk(property_name, property_value)
347
+ # Store it in the dictionary with a modified key name
348
+ pk_attributes[f"{property_name}_pk"] = pk
349
+
350
+ # Add the dictionary of _pk attributes for the current set of properties to the list
351
+ all_pk_attributes.append(pk_attributes)
352
+
353
+ # Return a list of dictionaries, where each dictionary contains _pk attributes for a set of properties
354
+ return all_pk_attributes
355
+
356
+
357
+ def update_prompt(prompt, properties, pk, properties_original):
358
+ # Replace the original prompt with the updated properties and pk
359
+ prompt = prompt.replace("{{properties}}", str(properties))
360
+ prompt = prompt.replace("{{pk}}", str(pk))
361
+ return prompt
362
+
363
+
364
+ def update_prompt_enhanced(prompt, properties, pk, properties_original):
365
+ updated_info = ""
366
+ for prop, pk_info, prop_orig in zip(properties, pk, properties_original):
367
+ for key in prop.keys():
368
+ # Extract original and updated values
369
+ orig_values = prop_orig.get(key, [])
370
+ updated_values = prop.get(key, [])
371
+
372
+ # Ensure both original and updated values are lists for uniform processing
373
+ if not isinstance(orig_values, list):
374
+ orig_values = [orig_values]
375
+ if not isinstance(updated_values, list):
376
+ updated_values = [updated_values]
377
+
378
+ # Extract primary key detail for this key, handling various pk formats carefully
379
+ pk_key = f"{key}_pk" # Construct pk key name based on the property key
380
+ pk_details = pk_info.get(pk_key, [])
381
+ if not isinstance(pk_details, list):
382
+ pk_details = [pk_details]
383
+
384
+ for orig_value, updated_value, pk_detail in zip(orig_values, updated_values, pk_details):
385
+ pk_value = None
386
+ if isinstance(pk_detail, str):
387
+ pk_value = pk_detail.strip("[]()").split(",")[0].replace("'", "").replace('"', '')
388
+
389
+ update_statement = ""
390
+ # Skip updating if there's no change in value to avoid redundant info
391
+ if orig_value != updated_value and pk_value:
392
+ update_statement = f"\n- {orig_value} (now referred to as {updated_value}) has a primary key: {pk_value}."
393
+ elif orig_value != updated_value:
394
+ update_statement = f"\n- {orig_value} (now referred to as {updated_value})."
395
+ elif pk_value:
396
+ update_statement = f"\n- {orig_value} has a primary key: {pk_value}."
397
+
398
+ updated_info += update_statement
399
+
400
+ if updated_info:
401
+ prompt += "\nUpdated Information:" + updated_info
402
+
403
+ return prompt
404
+
405
+
406
+ def prompt_cleaner(prompt, db, schema_config):
407
+ """Main function to clean the prompt."""
408
+
409
+ retrievers = setup_retrievers(db, schema_config)
410
+
411
+ properties = extract_properties(prompt, schema_config)
412
+ # Keep original properties for later use
413
+ properties_original = deepcopy(properties)
414
+ # Remove duplicates - Happens when there are more than one player or team in the prompt
415
+ properties = remove_duplicates(properties)
416
+ if properties:
417
+ check_and_update_properties(properties, retrievers)
418
+
419
+ pk = fetch_pks(properties, retrievers)
420
+ properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
421
+
422
+ return properties, pk
423
+
424
+
425
+ class PromptCleaner:
426
+ """
427
+ A class designed to clean and process prompts by extracting properties, removing duplicates,
428
+ and updating these properties based on a predefined schema configuration and database interactions.
429
+
430
+ Attributes:
431
+ db: A database connection object used to execute queries and fetch data.
432
+ schema_config: A dictionary defining the schema configuration for the extraction process.
433
+ schema_config = {
434
+ "properties": {
435
+ # Property name
436
+ "person_name": {"type": "string", "db_table": "players", "db_column": "name", "pk_column": "hash",
437
+ # if mostly numeric, such as 2015-2016 set true
438
+ "numeric": False},
439
+ "team_name": {"type": "string", "db_table": "teams", "db_column": "name", "pk_column": "id",
440
+ "numeric": False},
441
+ # Add more as needed
442
+ },
443
+ # Parameter to extractor, if person_name is required, add it here and the extractor will
444
+ # return an error if it is not found
445
+ "required": [],
446
+ }
447
+
448
+ Methods:
449
+ clean(prompt): Cleans the given prompt by extracting and updating properties based on the database.
450
+ Returns a tuple containing the updated properties and their primary keys.
451
+ """
452
+
453
+ def __init__(self, db=db, schema_config=None, custom_extractor_prompt=None):
454
+ """
455
+ Initializes the PromptCleaner with a database connection and a schema configuration.
456
+
457
+ Args:
458
+ db: The database connection object to be used for querying. (if none, it will use the default db)
459
+ schema_config: A dictionary defining properties and their database mappings for extraction and updating.
460
+ """
461
+ self.db = db
462
+ self.schema_config = schema_config
463
+ self.retrievers = setup_retrievers(self.db, self.schema_config)
464
+ self.cust_extractor_prompt = custom_extractor_prompt
465
+
466
+ def clean(self, prompt, return_pk=False, test=False, verbose = False):
467
+ """
468
+ Processes the given prompt to extract properties, remove duplicates, update the properties
469
+ based on close matches within the database, and fetch primary keys for these properties.
470
+
471
+ The method first extracts properties from the prompt using the schema configuration,
472
+ then checks these properties against the database to find and update close matches.
473
+ It also fetches primary keys for the updated properties where applicable.
474
+
475
+ Args:
476
+ prompt (str): The prompt text to be cleaned and processed.
477
+ return_pk (bool): A flag to indicate whether to return primary keys along with the properties.
478
+ test (bool): A flag to indicate whether to return the original properties for testing purposes.
479
+ verbose (bool): A flag to indicate whether to return the original properties for debugging.
480
+
481
+ Returns:
482
+ tuple: A tuple containing two elements:
483
+ - The first element is the original prompt, with updated information that excist in the db.
484
+ - The second element is a list of dictionaries, each containing primary keys for the properties,
485
+ where applicable.
486
+
487
+ """
488
+ if self.cust_extractor_prompt:
489
+
490
+ properties = extract_properties(prompt, self.schema_config, self.cust_extractor_prompt)
491
+
492
+ else:
493
+ properties = extract_properties(prompt, self.schema_config)
494
+ # Keep original properties for later use
495
+ properties_original = deepcopy(properties)
496
+ if test:
497
+ return properties_original
498
+ # Remove duplicates - Happens when there are more than one player or team in the prompt
499
+ # properties = remove_duplicates(properties)
500
+ pk = None
501
+ if properties:
502
+ check_and_update_properties(properties, self.retrievers)
503
+ pk = fetch_pks(properties, self.retrievers)
504
+ properties = update_prompt_enhanced(prompt, properties, pk, properties_original)
505
+
506
+
507
+
508
+ if return_pk:
509
+ return properties, pk
510
+ elif verbose:
511
+ return properties, properties_original
512
+ else:
513
+ return properties
514
+
515
+
516
+ def load_json(file_path: str) -> dict:
517
+ with open(file_path, 'r') as file:
518
+ return json.load(file)
519
+
520
+
521
+ def create_extractor(schema: str = "src/conf/schema.json", db: SQLDatabase = "sqlite:///data/games.db", ):
522
+ schema_config = load_json(schema)
523
+ db = SQLDatabase.from_uri(db)
524
+ pre_prompt = """Extract and save the relevant entities mentioned \
525
+ in the following passage together with their properties.
526
+
527
+ Only extract the properties mentioned in the 'information_extraction' function.
528
+
529
+ The questions are soccer related. game_event are things like yellow cards, goals, assists, freekick ect.
530
+ Generic properties like, "description", "home team", "away team", "game" ect should NOT be extracted.
531
+
532
+ If a property is not present and is not required in the function parameters, do not include it in the output.
533
+ If no properties are found, return an empty list.
534
+
535
+ Here are some exampels:
536
+ 'How many goals did Henry score for Arsnl in the 2015 season?'
537
+ person_name': ['Henry'], 'team_name': [Arsnl],'year_season': ['2015'],
538
+
539
+ Passage:
540
+ {input}
541
+ """
542
+
543
+ return PromptCleaner(db, schema_config, custom_extractor_prompt=pre_prompt)
544
+
545
+
546
+ if __name__ == "__main__":
547
+
548
+
549
+ schema_config = load_json("src/conf/schema.json")
550
+ # Add game and league to the schema_config
551
+
552
+ # prompter = PromptCleaner(db, schema_config, custom_extractor_prompt=extract_prompt)
553
+ prompter = create_extractor("src/conf/schema.json", "sqlite:///data/games.db")
554
+ prompt= prompter.clean("Give me goals, shots on target, shots off target and corners from the game between ManU and Swansa")
555
+
556
+
557
+ print(prompt)
558
+
src/sql_chain.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ import os
4
+ from langchain_community.vectorstores import FAISS
5
+ from langchain_core.example_selectors import SemanticSimilarityExampleSelector
6
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
7
+ from langchain_community.agent_toolkits import create_sql_agent
8
+ from langchain_core.prompts import (
9
+ ChatPromptTemplate,
10
+ FewShotPromptTemplate,
11
+ MessagesPlaceholder,
12
+ PromptTemplate,
13
+ SystemMessagePromptTemplate,
14
+ )
15
+ from langchain_community.utilities import SQLDatabase
16
+ from dotenv import load_dotenv
17
+
18
+ load_dotenv(".env")
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+ # Save the log to a file
22
+ handler = logging.FileHandler('extractor.log')
23
+ logger = logging.getLogger(__name__)
24
+
25
+ os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY')
26
+
27
+ if os.getenv('LANGSMITH'):
28
+ os.environ['LANGCHAIN_TRACING_V2'] = 'true'
29
+ os.environ['LANGCHAIN_ENDPOINT'] = 'https://api.smith.langchain.com'
30
+ os.environ[
31
+ 'LANGCHAIN_API_KEY'] = os.getenv("LANGSMITH_API_KEY")
32
+ os.environ['LANGCHAIN_PROJECT'] = 'master-theses'
33
+
34
+
35
+ def load_json(file_path: str) -> dict:
36
+ with open(file_path, 'r') as file:
37
+ return json.load(file)
38
+
39
+
40
+ class SqlChain:
41
+ def __init__(self, few_shot_prompts: str, llm_model="gpt-3.5-turbo", db_uri="sqlite:///data/games.db", few_shot_k=2, verbose=True):
42
+ self.llm = ChatOpenAI(model=llm_model, temperature=0)
43
+ self.db = SQLDatabase.from_uri(db_uri)
44
+ self.few_shot_k = few_shot_k
45
+ self.few_shot = self._set_up_few_shot_prompts(load_json(few_shot_prompts))
46
+ self.full_prompt = None
47
+
48
+ self.agent = create_sql_agent(
49
+ llm=self.llm,
50
+ db=self.db,
51
+ prompt=self.full_prompt,
52
+ max_iterations=10,
53
+ verbose=verbose,
54
+ agent_type="openai-tools",
55
+ # Default to 10 examples - Can be overwritten with the prompt
56
+ top_k=30,
57
+ )
58
+
59
+
60
+ def _set_up_few_shot_prompts(self, few_shot_prompts: dict) -> None:
61
+ few_shots = SemanticSimilarityExampleSelector.from_examples(
62
+ few_shot_prompts,
63
+ OpenAIEmbeddings(),
64
+ FAISS,
65
+ k=self.few_shot_k,
66
+ input_keys=["input"],
67
+ )
68
+ return few_shots
69
+
70
+ def few_prompt_construct(self, query: str, top_k=5, dialect="SQLite") -> str:
71
+ system_prefix = """You are an agent designed to interact with a SQL database.
72
+ Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
73
+ ALWAYS query the database before returning an answer.
74
+ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
75
+ You can order the results by a relevant column to return the most interesting examples in the database.
76
+ Never query for all the columns from a specific table, only ask for the relevant columns given the question.
77
+ You have access to tools for interacting with the database.
78
+ Only use the given tools. Only use the information returned by the tools to construct your final answer.
79
+ You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
80
+
81
+ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
82
+
83
+ If the question does not seem related to the database, just return 'I don't know' as the answer.
84
+ DO NOT include information that is not present in the database in your answer.
85
+
86
+ Here are some examples of user inputs and their corresponding SQL queries. They are tested and works.
87
+ Use them as a guide when creating your own queries:"""
88
+
89
+ SUFFIX = """Begin!
90
+
91
+ Question: {input}
92
+ Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.
93
+ I will not stop until I query the database and return the answer.
94
+ {agent_scratchpad}"""
95
+
96
+ few_shot_prompt = FewShotPromptTemplate(
97
+ example_selector=self.few_shot,
98
+ example_prompt=PromptTemplate.from_template(
99
+ "User input: {input}\nSQL query: {query}"
100
+ ),
101
+ input_variables=["input", "dialect", "top_k"],
102
+ prefix=system_prefix,
103
+ suffix=SUFFIX,
104
+ )
105
+ full_prompt = ChatPromptTemplate.from_messages(
106
+ [
107
+ SystemMessagePromptTemplate(prompt=few_shot_prompt),
108
+ ("human", "{input}"),
109
+ MessagesPlaceholder("agent_scratchpad"),
110
+ ]
111
+ )
112
+ self.full_prompt = full_prompt.invoke(
113
+ {
114
+ "input": query,
115
+ "top_k": top_k,
116
+ "dialect": dialect,
117
+ "agent_scratchpad": [],
118
+ }
119
+ )
120
+ def prompt_no_few_shot(self, query: str, dialect="SQLite") -> str:
121
+ system_prefix = """You are an agent designed to interact with a SQL database.
122
+ Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
123
+ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
124
+ You can order the results by a relevant column to return the most interesting examples in the database.
125
+ Never query for all the columns from a specific table, only ask for the relevant columns given the question.
126
+ You have access to tools for interacting with the database.
127
+ Only use the given tools. Only use the information returned by the tools to construct your final answer.
128
+ You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
129
+
130
+ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
131
+
132
+ If the question does not seem related to the database, just return 'I don't know' as the answer.
133
+ DO NOT include information that is not present in the database in your answer."""
134
+
135
+ return f"{system_prefix}\n{query}"
136
+
137
+
138
+
139
+
140
+ def ask(self, query: str, few_prompt:bool=True) -> str:
141
+ if few_prompt:
142
+ self.few_prompt_construct(query)
143
+ return self.agent.invoke({"input": self.full_prompt}), self.full_prompt
144
+ else:
145
+
146
+ return self.agent.invoke(self.prompt_no_few_shot(query)), self.prompt_no_few_shot(query)
147
+
148
+
149
+
150
+
151
+ def create_agent(few_shot_prompts: str = "src/conf/sqls.json", llm_model="gpt-3.5-turbo-0125",
152
+ db_uri="sqlite:///data/games.db", few_shot_k=2, verbose=True):
153
+ """ Create an agent with the given few_shot_prompts, llm_model and db_uri
154
+ Call it with agent.ask(prompt)"""
155
+ return SqlChain(few_shot_prompts, llm_model, db_uri, few_shot_k, verbose)
156
+
157
+
158
+ if __name__ == "__main__":
159
+ chain = SqlChain("src/conf/sqls.json")
160
+ chain.ask("Is Manchester United in the database?", False)