# db-ally / app.py
# Author: micpst
# Last commit 0885182 (verified): "fix: similarity indexes setup (#2)"
import asyncio
import datetime
from typing import Annotated
import dbally
import sqlalchemy
from dbally import SqlAlchemyBaseView
from dbally.audit import CLIEventHandler
from dbally.embeddings import LiteLLMEmbeddingClient
from dbally.gradio import create_gradio_interface
from dbally.llms import LiteLLM
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore
from dbally.views import decorators
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
# Send db-ally audit events to the console handler.
dbally.event_handlers = [CLIEventHandler()]

# Relative SQLite URL: clients.db is resolved against the working directory.
engine = create_engine("sqlite:///clients.db")

# Pull settings (e.g. API keys) from a .env file into the environment.
load_dotenv()

# Reflect the existing schema so the `clients` table becomes a mapped class.
Base = automap_base()
Base.prepare(autoload_with=engine)
Clients = Base.classes.clients
# Source of index entries: values of the `city` column in the clients table.
cities_fetcher = SimpleSqlAlchemyFetcher(
    column=Clients.city,
    table=Clients,
    sqlalchemy_engine=engine,
)

# Faiss-backed store for the embedded city names; the index is presumably
# persisted as "cities_index" under the relative "indexes" directory.
cities_store = FaissStore(
    embedding_client=LiteLLMEmbeddingClient("text-embedding-3-small"),
    index_name="cities_index",
    index_dir="indexes",
)

# Used via Annotated[str, CityIndex] on ClientsView.filter_by_city.
CityIndex = SimilarityIndex(store=cities_store, fetcher=cities_fetcher)
class ClientsView(SqlAlchemyBaseView):
    """Clients stored in the `clients` table, filterable by city and loyalty eligibility."""

    # NOTE(review): db-ally presumably shows each filter's docstring to the LLM
    # as that filter's description, so the filter docstrings below state the
    # exact criteria. Keep them in sync with the expressions they describe.
    # Filters deliberately carry no return annotation: db-ally inspects
    # `__annotations__` to discover filter parameters.

    def get_select(self) -> sqlalchemy.Select:
        """Return the base query selecting every row of the clients table."""
        return sqlalchemy.select(Clients)

    @decorators.view_filter()
    def filter_by_city(self, city: Annotated[str, CityIndex]):
        """Filters clients by their city."""
        # The CityIndex annotation lets db-ally map a loosely spelled city to
        # a stored value (presumably via similarity search) before matching.
        return Clients.city == city

    @decorators.view_filter()
    def eligible_for_loyalty_program(self):
        """Filters clients with more than 3 orders who joined more than 365 days ago."""
        # Naive local "now", as before; NOTE(review): confirm date_joined is
        # stored in the same (naive, local) convention.
        joined_before = datetime.datetime.now() - datetime.timedelta(days=365)
        return (Clients.total_orders > 3) & (Clients.date_joined < joined_before)
async def main() -> None:
    """Build the "clients" collection, index its similarity data, and serve a Gradio UI.

    Registers ClientsView (backed by `engine`) with a gpt-4-turbo LiteLLM
    model, refreshes the collection's similarity indexes, then launches the
    Gradio interface created for the collection.
    """
    llm = LiteLLM(model_name="gpt-4-turbo")
    collection = dbally.create_collection("clients", llm=llm)
    collection.add(ClientsView, lambda: ClientsView(engine))

    # Nothing else in this file ever populates CityIndex. Without this step
    # the Faiss store may hold no embeddings when filter_by_city values are
    # matched. It must run after the view is registered so the view's
    # indexes are discovered; embedding the cities presumably needs the same
    # API credentials that the LLM uses.
    await collection.update_similarity_indexes()

    interface = create_gradio_interface(collection)
    interface.launch()


if __name__ == "__main__":
    asyncio.run(main())