davanstrien's picture
davanstrien HF staff
improve db
33c1203
raw
history blame
6.67 kB
import json
import logging
import sqlite3
from contextlib import asynccontextmanager
from typing import List
import numpy as np
from cashews import NOT_NONE, cache
from fastapi import FastAPI, HTTPException, Query
from pandas import Timestamp
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from data_loader import refresh_data
# Configure the shared cashews cache from a "mem://" URL; the query string
# passes check_interval=10 and size=10000 to that backend.
cache.setup("mem://?check_interval=10&size=10000")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_db_connection(db_path="datasets.db"):
    """Open a SQLite connection configured for this service.

    Rows are returned as ``sqlite3.Row`` objects, and the connection is
    switched to WAL journaling with ``synchronous = NORMAL``.

    Args:
        db_path: Path of the database file. Defaults to ``"datasets.db"``,
            the location used by every caller in this module.

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for
        closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or configured.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode = WAL")
        conn.execute("PRAGMA synchronous = NORMAL")
    except sqlite3.Error:
        # Do not hand back (or leak) a half-configured connection.
        conn.close()
        raise
    return conn
def setup_database():
    """Create the ``datasets`` table and its sort index if they are missing.

    After the schema is committed, ``ANALYZE`` is run on the database. The
    connection is closed whether or not any statement fails.

    Raises:
        sqlite3.Error: If a statement cannot be executed.
    """
    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(
            """CREATE TABLE IF NOT EXISTS datasets
            (hub_id TEXT PRIMARY KEY,
            likes INTEGER,
            downloads INTEGER,
            tags JSON,
            created_at INTEGER,
            last_modified INTEGER,
            license JSON,
            language JSON,
            config_name TEXT,
            column_names JSON,
            features JSON)"""
        )
        # NOTE: there is deliberately no index over the values inside
        # column_names. An expression index may only reference columns of
        # the indexed table, so an index on ``json_each.value`` is rejected
        # by SQLite ("no such column"). Indexing individual column names
        # would need a separate (hub_id, column_name) table.
        c.execute(
            """
            CREATE INDEX IF NOT EXISTS idx_downloads_likes
            ON datasets(downloads DESC, likes DESC)
            """
        )
        conn.commit()
        c.execute("ANALYZE")
    finally:
        conn.close()
def serialize_numpy(obj):
    """``json.dumps`` ``default`` hook for NumPy and pandas values.

    Converts arrays to nested lists, NumPy integers/floats to the matching
    Python scalars, and pandas ``Timestamp`` objects to whole epoch seconds.

    Raises:
        TypeError: For any other type (the failure is also logged).
    """
    # Checked in order; the first matching type decides the conversion.
    converters = (
        (np.ndarray, lambda value: value.tolist()),
        (np.integer, int),
        (np.floating, float),
        (Timestamp, lambda value: int(value.timestamp())),
    )
    for kind, convert in converters:
        if isinstance(obj, kind):
            return convert(obj)

    message = f"Object of type {type(obj)} is not JSON serializable"
    logger.error(message)
    raise TypeError(message)
def _to_epoch_seconds(value):
    """Return a pandas ``Timestamp`` as whole epoch seconds; pass anything else through."""
    if isinstance(value, Timestamp):
        return int(value.timestamp())
    return value


def _dataset_to_row(data):
    """Build the ``INSERT`` parameter tuple for one dataset record.

    ``data`` must provide ``hub_id``; every other field falls back to a
    default (0, ``""`` or an empty JSON list) when absent.
    """

    def as_json(key):
        return json.dumps(data.get(key, []), default=serialize_numpy)

    return (
        data["hub_id"],
        data.get("likes", 0),
        data.get("downloads", 0),
        as_json("tags"),
        # .get() so a record without the key stores 0 instead of raising
        # KeyError before the fallback can apply.
        _to_epoch_seconds(data.get("created_at", 0)),
        _to_epoch_seconds(data.get("last_modified", 0)),
        as_json("license"),
        as_json("language"),
        data.get("config_name", ""),
        as_json("column_names"),
        as_json("features"),
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database and load dataset records before the app serves.

    Creates the schema, fetches records via ``refresh_data()`` and upserts
    them into ``datasets`` keyed by ``hub_id``. Nothing runs after the
    ``yield``; no shutdown work is registered.
    """
    setup_database()
    logger.info("Creating database connection")
    conn = get_db_connection()
    try:
        logger.info("Refreshing data")
        # presumably an iterable of dict-like records; verify against data_loader
        datasets = refresh_data()
        c = conn.cursor()
        c.executemany(
            """
            INSERT OR REPLACE INTO datasets
            (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
            VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
            """,
            [_dataset_to_row(data) for data in datasets],
        )
        conn.commit()
    finally:
        # Close even when the refresh or the insert fails.
        conn.close()
    logger.info("Data refreshed")
    yield
# ASGI application object; `lifespan` is registered as its lifespan handler,
# so the database setup and data load in it run as part of app startup.
app = FastAPI(lifespan=lifespan)
@app.get("/", include_in_schema=False)
def root():
    """Send visitors of the bare root URL to ``/docs``."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
class SearchResponse(BaseModel):
    # Response body returned by the /search endpoint.

    # Number of datasets matching the filter, across all pages.
    total: int
    # 1-based page number that was requested.
    page: int
    # Maximum number of results per page that was requested.
    page_size: int
    # One dict per matching dataset row for the requested page.
    results: List[dict]
# Columns of `datasets` that hold JSON text and are decoded before returning.
_JSON_RESULT_FIELDS = ("tags", "license", "language", "column_names", "features")


def _build_search_queries(columns, match_all):
    """Return ``(page_sql, count_sql, filter_params)`` for a column search.

    ``page_sql`` expects ``filter_params`` followed by the LIMIT and OFFSET
    values; ``count_sql`` expects ``filter_params`` alone. ``columns`` must
    already be free of duplicates. Only ``?`` placeholders are interpolated
    into the SQL text; all values are bound as parameters.
    """
    placeholders = ",".join("?" * len(columns))
    if match_all:
        # DISTINCT so that a name repeated inside a stored column_names
        # array is counted once and cannot stand in for a missing name.
        matched = f"""(
                SELECT COUNT(DISTINCT json_each.value)
                FROM json_each(column_names)
                WHERE json_each.value IN ({placeholders})
            )"""
        page_sql = f"""
            SELECT *, {matched} as match_count
            FROM datasets
            WHERE match_count = ?
            ORDER BY downloads DESC, likes DESC
            LIMIT ? OFFSET ?
        """
        count_sql = f"""
            SELECT COUNT(*) as total FROM datasets
            WHERE {matched} = ?
        """
        return page_sql, count_sql, (*columns, len(columns))

    any_match = f"""EXISTS (
                SELECT 1
                FROM json_each(column_names)
                WHERE json_each.value IN ({placeholders})
            )"""
    page_sql = f"""
            SELECT * FROM datasets
            WHERE {any_match}
            ORDER BY downloads DESC, likes DESC
            LIMIT ? OFFSET ?
        """
    count_sql = f"""
            SELECT COUNT(*) as total FROM datasets
            WHERE {any_match}
        """
    return page_sql, count_sql, tuple(columns)


# The route decorator must be outermost: FastAPI serves whatever callable it
# is handed, so the cache wrapper has to be applied before registration.
@app.get("/search", response_model=SearchResponse)
@cache(ttl="1h", condition=NOT_NONE)
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    """Search datasets by column name.

    Returns datasets having any of the given columns, or all of them when
    `match_all` is true, ordered by downloads and then likes.
    """
    offset = (page - 1) * page_size
    # Repeated names would make the match_all count unreachable; keep the
    # first occurrence of each name, preserving order.
    unique_columns = list(dict.fromkeys(columns))
    page_sql, count_sql, filter_params = _build_search_queries(
        unique_columns, match_all
    )

    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.execute(page_sql, (*filter_params, page_size, offset))
        results = [dict(row) for row in c.fetchall()]

        # Total number of matches, independent of pagination.
        c.execute(count_sql, filter_params)
        total = c.fetchone()["total"]

        for result in results:
            for field in _JSON_RESULT_FIELDS:
                result[field] = json.loads(result[field])

        return SearchResponse(
            total=total, page=page, page_size=page_size, results=results
        )
    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)