|
import requests |
|
from langchain_community.utilities import SQLDatabase |
|
from langchain_community.tools.sql_database.tool import ListSQLDatabaseTool, InfoSQLDatabaseTool |
|
from sqlalchemy import ( |
|
create_engine, |
|
MetaData, |
|
inspect, |
|
Table, |
|
select, |
|
distinct |
|
) |
|
from sqlalchemy.schema import CreateTable |
|
from sqlalchemy.exc import ProgrammingError |
|
from sqlalchemy.engine import Engine |
|
import re |
|
|
|
def get_all_groq_model(api_key: str | None = None, *, timeout: float = 10.0) -> list[str]:
    """Fetch the IDs of all models available to *api_key* from the Groq API.

    Args:
        api_key: Groq API key, sent as a bearer token. Required; ``None`` or
            an empty string is rejected before any request is made.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``id`` field of every entry in the response's ``data`` list.

    Raises:
        ValueError: If *api_key* is missing or empty.
        requests.HTTPError: If the API answers with a 4xx/5xx status
            (for example, an invalid key).
        requests.RequestException: On network failures or timeouts.
    """
    if not api_key:
        raise ValueError("API key is required")

    url = "https://api.groq.com/openai/v1/models"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    # Without a timeout, requests can block indefinitely on a stalled connection.
    response = requests.get(url, headers=headers, timeout=timeout)
    # Surface HTTP errors explicitly instead of failing later with a KeyError
    # on an error payload that has no "data" field.
    response.raise_for_status()

    return [model["id"] for model in response.json()["data"]]
|
|
|
def validate_api_key(api_key:str) -> bool: |
|
"""Validates the Groq API key using the get_all_groq_model function.""" |
|
if len(api_key) == 0: |
|
return False |
|
try: |
|
get_all_groq_model(api_key=api_key) |
|
return True |
|
except Exception as e: |
|
return False |
|
|
|
def validate_uri(uri: str) -> bool:
    """Report whether *uri* can be opened as a SQL database.

    The check attempts ``SQLDatabase.from_uri(uri)``. Any exception raised
    during that attempt marks the URI as invalid.

    Args:
        uri: SQLAlchemy-style database URI.

    Returns:
        True if ``SQLDatabase.from_uri`` succeeds, False if it raises.
    """
    is_valid = True
    try:
        SQLDatabase.from_uri(uri)
    except Exception:
        is_valid = False
    return is_valid
|
|
|
def get_info(uri: str) -> dict[str, str] | None:
    """Describe a database using LangChain's SQL database tools.

    Args:
        uri: SQLAlchemy-style database URI passed to ``SQLDatabase.from_uri``.

    Returns:
        A dict with these keys:
          - ``'sql_dialect'``: the ``SQLDatabase.dialect`` value.
          - ``'tables'``: output of ``ListSQLDatabaseTool`` (the accessible tables).
          - ``'tables_schema'``: output of ``InfoSQLDatabaseTool`` for those tables.
    """
    database = SQLDatabase.from_uri(uri)
    dialect_name = database.dialect

    # The table listing feeds directly into the schema tool.
    table_listing = ListSQLDatabaseTool(db=database).invoke("")
    schema_text = InfoSQLDatabaseTool(db=database).invoke(table_listing)

    return {
        'sql_dialect': dialect_name,
        'tables': table_listing,
        'tables_schema': schema_text,
    }
|
|
|
def get_sample_rows(engine: Engine, table: Table, row_count: int = 3) -> str:
    """Render up to *row_count* rows of *table* as tab-separated text.

    Each cell is converted with ``str`` and truncated to 100 characters.
    If the SELECT raises ``ProgrammingError``, the row section is left empty,
    but the header lines are still produced.

    Args:
        engine: SQLAlchemy engine used to open a connection.
        table: Table to sample.
        row_count: Maximum number of rows to fetch (applied as ``LIMIT``).

    Returns:
        Three parts joined by newlines:
          - a header line "<row_count> rows from <table name> table:";
          - the tab-separated column names;
          - the tab-separated sample rows, one per line.
        NOTE(review): the header states *row_count* even when the table
        holds fewer rows.
    """
    query = select(table).limit(row_count)
    header = "\t".join(col.name for col in table.columns)

    try:
        with engine.connect() as conn:
            # Materialize the rows while the connection is still open.
            rendered = [
                [str(cell)[:100] for cell in record]
                for record in conn.execute(query)
            ]
        body = "\n".join("\t".join(cells) for cells in rendered)
    except ProgrammingError:
        body = ""

    return f"{row_count} rows from {table.name} table:\n{header}\n{body}"
|
|
|
def get_unique_values(engine: Engine, table: Table, *, max_values: int = 20) -> str:
    """Summarize the distinct values of every column in *table*.

    Runs one ``SELECT DISTINCT <column>`` per column. Columns whose query
    raises ``ProgrammingError`` are left out of the summary.

    Args:
        engine: SQLAlchemy engine used to run the queries.
        table: Table whose columns are summarized.
        max_values: Number of distinct values listed per column. The reported
            count always covers all distinct values.

    Returns:
        A header line "Unique values of each column in <table>: ", followed by
        one line per column of the form "<column> has <n> unique values: v1 v2 ...".
        ", ...." is appended when more than *max_values* values exist.
        Every line, including the last, ends with a newline.
    """
    unique_values: dict[str, list[str]] = {}
    for column in table.c:
        command = select(distinct(column))
        try:
            # Use a fresh connection per column. On some backends (e.g.
            # PostgreSQL) a failed statement aborts the enclosing transaction,
            # so reusing one connection would make later columns fail as well.
            with engine.connect() as connection:
                # scalars() yields the bare column values; iterating the Result
                # directly yields Row objects whose str() looks like "('value',)".
                unique_values[column.name] = [
                    str(value) for value in connection.execute(command).scalars()
                ]
        except ProgrammingError:
            # Skip columns that cannot be queried with DISTINCT.
            continue

    lines = [f"Unique values of each column in {table.name}: "]
    for column_name, values in unique_values.items():
        shown = " ".join(values[:max_values])
        line = f"{column_name} has {len(values)} unique values: {shown}"
        if len(values) > max_values:
            line += ", ...."
        lines.append(line)

    return "\n".join(lines) + "\n"
|
|
|
def get_info_sqlalchemy(uri: str) -> dict[str, str]:
    """Describe a database by reflecting its schema with SQLAlchemy.

    For every table reflected by ``MetaData.reflect``, the description holds
    its ``CREATE TABLE`` DDL, followed by a ``/* ... */`` comment containing
    sample rows (``get_sample_rows``) and distinct-value summaries
    (``get_unique_values``).

    Args:
        uri: SQLAlchemy database URI passed to ``create_engine``.

    Returns:
        A dict with these keys:
          - ``'sql_dialect'``: the engine's dialect name.
          - ``'tables'``: comma-separated table names.
          - ``'tables_schema'``: per-table descriptions separated by blank lines.
    """
    engine = create_engine(uri)
    try:
        dialect = engine.dialect.name

        metadata = MetaData()
        metadata.reflect(engine)

        tables: dict[str, str] = {}
        for table in metadata.tables.values():
            ddl = str(CreateTable(table).compile(engine)).rstrip()
            tables[table.name] = (
                f"{ddl}\n\n/*"
                f"\n{get_sample_rows(engine, table)}\n"
                f"\n{get_unique_values(engine, table)}\n"
                "*/"
            )
    finally:
        # The engine is local to this call, so release its pooled connections.
        engine.dispose()

    return {
        'sql_dialect': dialect,
        'tables': ", ".join(tables),
        'tables_schema': "\n\n".join(tables.values()),
    }
|
|
|
# Opening fence: ``` plus an optional info string on the same line. The info
# string may be any language tag such as "sql", "c++" or "python3", and may
# carry trailing spaces or "\r". The body is captured lazily up to the next
# newline followed by ```. The "\r?" keeps a CRLF line ending out of the
# captured body.
_CODE_BLOCK_PATTERN = re.compile(r"```[^\n`]*\n(.*?)\r?\n```", re.DOTALL)


def extract_code_blocks(text: str) -> list[str]:
    """Return the bodies of all fenced (```) code blocks in *text*.

    Args:
        text: Markdown-like text, e.g. an LLM response.

    Returns:
        The contents of each fenced block, in order of appearance, without
        the fence lines. The list is empty if no block is found.
    """
    return _CODE_BLOCK_PATTERN.findall(text)
|
|
|
if __name__ == "__main__":
    # Manual smoke test: describe the database named by POSTGRES_URI, read
    # from the environment (load_dotenv also reads a local .env file).
    from dotenv import load_dotenv
    import os

    load_dotenv()

    uri = os.getenv("POSTGRES_URI")
    if not uri:
        # Fail with a clear message instead of an opaque create_engine error.
        raise SystemExit("POSTGRES_URI is not set (environment or .env file).")

    print(get_info_sqlalchemy(uri))
|
|