# db-ally / app.py
# Last commit: 935c9e3 "fix indexes" by micpst (file size 1.98 kB).
import asyncio
import datetime
from typing import Annotated
import dbally
import sqlalchemy
from dbally import SqlAlchemyBaseView
from dbally.audit import CLIEventHandler
from dbally.embeddings import LiteLLMEmbeddingClient
from dbally.gradio import create_gradio_interface
from dbally.llms import LiteLLM
from dbally.similarity import SimilarityIndex, SimpleSqlAlchemyFetcher, FaissStore
from dbally.views import decorators
from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
# Load variables from a .env file before constructing anything else, so every
# component created below (including the LiteLLM-based embedding client) sees
# them in the process environment.
load_dotenv()

# Send db-ally audit events to the command-line event handler.
dbally.event_handlers = [CLIEventHandler()]

# Relative SQLite URL: clients.db is resolved against the current working directory.
engine = create_engine('sqlite:///clients.db')

# Reflect the existing database schema and generate mapped classes for its tables.
Base = automap_base()
Base.prepare(autoload_with=engine)
# Mapped class for the reflected `clients` table.
Clients = Base.classes.clients

# Supplies the values of clients.city to the similarity index (presumably the
# city names that get embedded — confirm in dbally.similarity).
cities_fetcher = SimpleSqlAlchemyFetcher(
    sqlalchemy_engine=engine,
    table=Clients,
    column=Clients.city,
)

# Faiss vector store named "cities_index" under the "indexes" directory; values
# are embedded with the "text-embedding-3-small" model through LiteLLM.
cities_store = FaissStore(
    index_dir="indexes",
    index_name="cities_index",
    embedding_client=LiteLLMEmbeddingClient("text-embedding-3-small"),
)

# Similarity index attached to ClientsView.filter_by_city's `city` parameter via
# Annotated, presumably so city names produced by the LLM can be matched to
# values actually stored in the database.
# NOTE(review): nothing in this file calls CityIndex.update(); lookups presumably
# rely on an index already persisted under indexes/ — confirm it is rebuilt
# whenever clients.city changes.
CityIndex = SimilarityIndex(
    fetcher=cities_fetcher,
    store=cities_store,
)
class ClientsView(SqlAlchemyBaseView):
    """
    Client records, including each client's city, total number of orders and the date they joined.
    """

    # NOTE(review): db-ally presumably shows this class docstring and the
    # docstrings of the @view_filter methods to the LLM as descriptions, so they
    # are phrased for the model; maintainer notes are kept in `#` comments.
    # Return annotations are deliberately not added to the filters because
    # db-ally introspects their __annotations__ — confirm before adding any.

    def get_select(self) -> sqlalchemy.Select:
        """Return the base query selecting every row of the `clients` table."""
        return sqlalchemy.select(Clients)

    @decorators.view_filter()
    def filter_by_city(self, city: Annotated[str, CityIndex]):
        """
        Filters clients located in the given city.
        """
        # The CityIndex annotation presumably lets db-ally snap the LLM-provided
        # city onto a value that exists in clients.city before this runs.
        return Clients.city == city

    @decorators.view_filter()
    def eligible_for_loyalty_program(self):
        """
        Filters clients eligible for the loyalty program: clients with more than 3 orders in total
        who joined more than a year (365 days) ago.
        """
        # The cut-off is taken from the local wall-clock time (naive datetime)
        # when the filter expression is built.
        # NOTE(review): assumes date_joined is stored as naive local time — confirm.
        joined_before = datetime.datetime.now() - datetime.timedelta(days=365)
        # and_() is SQLAlchemy's documented equivalent of combining with `&`.
        return sqlalchemy.and_(
            Clients.total_orders > 3,
            Clients.date_joined < joined_before,
        )
async def main() -> None:
    """
    Register ClientsView in a "clients" collection driven by the gpt-4-turbo
    model and launch a Gradio interface over that collection.
    """
    collection = dbally.create_collection("clients", llm=LiteLLM(model_name="gpt-4-turbo"))
    # The factory builds a new ClientsView bound to the module-level engine on each call.
    collection.add(ClientsView, lambda: ClientsView(engine))
    # launch() is called directly; nothing in this coroutine is awaited.
    create_gradio_interface(collection).launch()


if __name__ == '__main__':
    # Script entry point: run the main() coroutine on a fresh event loop.
    asyncio.run(main())