davanstrien HF staff committed on
Commit
33c1203
1 Parent(s): 55877ba

improve db

Browse files
Files changed (2) hide show
  1. data_loader.py +5 -2
  2. main.py +83 -67
data_loader.py CHANGED
@@ -121,8 +121,11 @@ def refresh_data() -> List[Dict[str, Any]]:
121
  df["tags"] = df["tags"].apply(ensure_list_of_strings)
122
  df["language"] = df["language"].apply(ensure_list_of_strings)
123
 
124
- # Convert 'features' column to string
125
- df["features"] = df["features"].apply(lambda x: str(x) if x is not None else None)
 
 
 
126
  df = df.astype({"hub_id": "string", "config_name": "string"})
127
 
128
  # save to parquet file with current date
 
121
  df["tags"] = df["tags"].apply(ensure_list_of_strings)
122
  df["language"] = df["language"].apply(ensure_list_of_strings)
123
 
124
+ # Ensure 'column_names' is a list
125
+ df["column_names"] = df["column_names"].apply(
126
+ lambda x: x if isinstance(x, list) else []
127
+ )
128
+
129
  df = df.astype({"hub_id": "string", "config_name": "string"})
130
 
131
  # save to parquet file with current date
main.py CHANGED
@@ -20,6 +20,8 @@ logger = logging.getLogger(__name__)
20
  def get_db_connection():
21
  conn = sqlite3.connect("datasets.db")
22
  conn.row_factory = sqlite3.Row
 
 
23
  return conn
24
 
25
 
@@ -31,17 +33,29 @@ def setup_database():
31
  (hub_id TEXT PRIMARY KEY,
32
  likes INTEGER,
33
  downloads INTEGER,
34
- tags TEXT,
35
  created_at INTEGER,
36
  last_modified INTEGER,
37
- license TEXT,
38
- language TEXT,
39
  config_name TEXT,
40
- column_names TEXT,
41
- features TEXT)"""
 
 
 
 
 
 
 
 
 
 
 
 
42
  )
43
- c.execute("CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)")
44
  conn.commit()
 
45
  conn.close()
46
 
47
 
@@ -58,56 +72,46 @@ def serialize_numpy(obj):
58
  raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
59
 
60
 
61
- def insert_data(conn, data):
62
- c = conn.cursor()
63
-
64
- created_at = data.get("created_at", 0)
65
- if isinstance(created_at, Timestamp):
66
- created_at = int(created_at.timestamp())
67
-
68
- last_modified = data.get("last_modified", 0)
69
- if isinstance(last_modified, Timestamp):
70
- last_modified = int(last_modified.timestamp())
71
-
72
- c.execute(
73
- """
74
- INSERT OR REPLACE INTO datasets
75
- (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
76
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
77
- """,
78
- (
79
- data["hub_id"],
80
- data.get("likes", 0),
81
- data.get("downloads", 0),
82
- json.dumps(data.get("tags", []), default=serialize_numpy),
83
- created_at,
84
- last_modified,
85
- json.dumps(data.get("license", []), default=serialize_numpy),
86
- json.dumps(data.get("language", []), default=serialize_numpy),
87
- data.get("config_name", ""),
88
- json.dumps(data.get("column_names", []), default=serialize_numpy),
89
- json.dumps(data.get("features", []), default=serialize_numpy),
90
- ),
91
- )
92
- conn.commit()
93
-
94
-
95
  @asynccontextmanager
96
  async def lifespan(app: FastAPI):
97
- # Startup: Load data into the database
98
  setup_database()
99
  logger.info("Creating database connection")
100
  conn = get_db_connection()
101
  logger.info("Refreshing data")
102
  datasets = refresh_data()
103
 
104
- for data in datasets:
105
- insert_data(conn, data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  conn.close()
107
  logger.info("Data refreshed")
108
  yield
109
- # Shutdown: You can add any cleanup operations here if needed
110
- # For example, closing database connections, clearing caches, etc.
111
 
112
 
113
  app = FastAPI(lifespan=lifespan)
@@ -140,43 +144,55 @@ async def search_datasets(
140
  try:
141
  if match_all:
142
  query = """
143
- SELECT COUNT(*) as total FROM datasets
144
- WHERE (SELECT COUNT(*) FROM json_each(column_names)
145
- WHERE value IN ({})) = ?
 
 
 
 
 
 
146
  """.format(",".join("?" * len(columns)))
147
- c.execute(query, (*columns, len(columns)))
148
  else:
149
  query = """
150
- SELECT COUNT(*) as total FROM datasets
151
  WHERE EXISTS (
152
- SELECT 1 FROM json_each(column_names)
153
- WHERE value IN ({})
 
154
  )
 
 
155
  """.format(",".join("?" * len(columns)))
156
- c.execute(query, columns)
157
 
158
- total = c.fetchone()["total"]
159
 
 
160
  if match_all:
161
- query = """
162
- SELECT * FROM datasets
163
- WHERE (SELECT COUNT(*) FROM json_each(column_names)
164
- WHERE value IN ({})) = ?
165
- LIMIT ? OFFSET ?
 
 
166
  """.format(",".join("?" * len(columns)))
167
- c.execute(query, (*columns, len(columns), page_size, offset))
168
  else:
169
- query = """
170
- SELECT * FROM datasets
171
  WHERE EXISTS (
172
- SELECT 1 FROM json_each(column_names)
173
- WHERE value IN ({})
 
174
  )
175
- LIMIT ? OFFSET ?
176
  """.format(",".join("?" * len(columns)))
177
- c.execute(query, (*columns, page_size, offset))
178
 
179
- results = [dict(row) for row in c.fetchall()]
180
 
181
  for result in results:
182
  result["tags"] = json.loads(result["tags"])
 
20
  def get_db_connection():
21
  conn = sqlite3.connect("datasets.db")
22
  conn.row_factory = sqlite3.Row
23
+ conn.execute("PRAGMA journal_mode = WAL")
24
+ conn.execute("PRAGMA synchronous = NORMAL")
25
  return conn
26
 
27
 
 
33
  (hub_id TEXT PRIMARY KEY,
34
  likes INTEGER,
35
  downloads INTEGER,
36
+ tags JSON,
37
  created_at INTEGER,
38
  last_modified INTEGER,
39
+ license JSON,
40
+ language JSON,
41
  config_name TEXT,
42
+ column_names JSON,
43
+ features JSON)"""
44
+ )
45
+ c.execute(
46
+ """
47
+ CREATE INDEX IF NOT EXISTS idx_column_names
48
+ ON datasets((json_each.value))
49
+ """
50
+ )
51
+ c.execute(
52
+ """
53
+ CREATE INDEX IF NOT EXISTS idx_downloads_likes
54
+ ON datasets(downloads DESC, likes DESC)
55
+ """
56
  )
 
57
  conn.commit()
58
+ c.execute("ANALYZE")
59
  conn.close()
60
 
61
 
 
72
  raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  @asynccontextmanager
76
  async def lifespan(app: FastAPI):
 
77
  setup_database()
78
  logger.info("Creating database connection")
79
  conn = get_db_connection()
80
  logger.info("Refreshing data")
81
  datasets = refresh_data()
82
 
83
+ c = conn.cursor()
84
+ c.executemany(
85
+ """
86
+ INSERT OR REPLACE INTO datasets
87
+ (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
88
+ VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
89
+ """,
90
+ [
91
+ (
92
+ data["hub_id"],
93
+ data.get("likes", 0),
94
+ data.get("downloads", 0),
95
+ json.dumps(data.get("tags", []), default=serialize_numpy),
96
+ int(data["created_at"].timestamp())
97
+ if isinstance(data["created_at"], Timestamp)
98
+ else data.get("created_at", 0),
99
+ int(data["last_modified"].timestamp())
100
+ if isinstance(data["last_modified"], Timestamp)
101
+ else data.get("last_modified", 0),
102
+ json.dumps(data.get("license", []), default=serialize_numpy),
103
+ json.dumps(data.get("language", []), default=serialize_numpy),
104
+ data.get("config_name", ""),
105
+ json.dumps(data.get("column_names", []), default=serialize_numpy),
106
+ json.dumps(data.get("features", []), default=serialize_numpy),
107
+ )
108
+ for data in datasets
109
+ ],
110
+ )
111
+ conn.commit()
112
  conn.close()
113
  logger.info("Data refreshed")
114
  yield
 
 
115
 
116
 
117
  app = FastAPI(lifespan=lifespan)
 
144
  try:
145
  if match_all:
146
  query = """
147
+ SELECT *, (
148
+ SELECT COUNT(*)
149
+ FROM json_each(column_names)
150
+ WHERE json_each.value IN ({})
151
+ ) as match_count
152
+ FROM datasets
153
+ WHERE match_count = ?
154
+ ORDER BY downloads DESC, likes DESC
155
+ LIMIT ? OFFSET ?
156
  """.format(",".join("?" * len(columns)))
157
+ c.execute(query, (*columns, len(columns), page_size, offset))
158
  else:
159
  query = """
160
+ SELECT * FROM datasets
161
  WHERE EXISTS (
162
+ SELECT 1
163
+ FROM json_each(column_names)
164
+ WHERE json_each.value IN ({})
165
  )
166
+ ORDER BY downloads DESC, likes DESC
167
+ LIMIT ? OFFSET ?
168
  """.format(",".join("?" * len(columns)))
169
+ c.execute(query, (*columns, page_size, offset))
170
 
171
+ results = [dict(row) for row in c.fetchall()]
172
 
173
+ # Get total count
174
  if match_all:
175
+ count_query = """
176
+ SELECT COUNT(*) as total FROM datasets
177
+ WHERE (
178
+ SELECT COUNT(*)
179
+ FROM json_each(column_names)
180
+ WHERE json_each.value IN ({})
181
+ ) = ?
182
  """.format(",".join("?" * len(columns)))
183
+ c.execute(count_query, (*columns, len(columns)))
184
  else:
185
+ count_query = """
186
+ SELECT COUNT(*) as total FROM datasets
187
  WHERE EXISTS (
188
+ SELECT 1
189
+ FROM json_each(column_names)
190
+ WHERE json_each.value IN ({})
191
  )
 
192
  """.format(",".join("?" * len(columns)))
193
+ c.execute(count_query, columns)
194
 
195
+ total = c.fetchone()["total"]
196
 
197
  for result in results:
198
  result["tags"] = json.loads(result["tags"])