davanstrien HF staff commited on
Commit
9a6e4f2
·
1 Parent(s): f06e8ae

refactor: Update main.py with scheduled data refresh and background task

Browse files
Files changed (1) hide show
  1. main.py +86 -35
main.py CHANGED
@@ -1,10 +1,15 @@
 
 
1
  import json
2
  import logging
 
3
  import sqlite3
4
  from contextlib import asynccontextmanager
5
  from typing import List
6
 
7
  import numpy as np
 
 
8
  from cashews import NOT_NONE, cache
9
  from fastapi import FastAPI, HTTPException, Query
10
  from pandas import Timestamp
@@ -13,6 +18,8 @@ from starlette.responses import RedirectResponse
13
 
14
  from data_loader import refresh_data
15
 
 
 
16
  cache.setup("mem://?check_interval=10&size=10000")
17
  logger = logging.getLogger(__name__)
18
 
@@ -72,47 +79,91 @@ def serialize_numpy(obj):
72
  raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
73
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  @asynccontextmanager
76
  async def lifespan(app: FastAPI):
77
  setup_database()
78
- logger.info("Creating database connection")
79
- conn = get_db_connection()
80
- logger.info("Refreshing data")
81
- datasets = refresh_data()
 
 
 
 
82
 
83
- c = conn.cursor()
84
- c.executemany(
85
- """
86
- INSERT OR REPLACE INTO datasets
87
- (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
88
- VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
89
- """,
90
- [
91
- (
92
- data["hub_id"],
93
- data.get("likes", 0),
94
- data.get("downloads", 0),
95
- json.dumps(data.get("tags", []), default=serialize_numpy),
96
- int(data["created_at"].timestamp())
97
- if isinstance(data["created_at"], Timestamp)
98
- else data.get("created_at", 0),
99
- int(data["last_modified"].timestamp())
100
- if isinstance(data["last_modified"], Timestamp)
101
- else data.get("last_modified", 0),
102
- json.dumps(data.get("license", []), default=serialize_numpy),
103
- json.dumps(data.get("language", []), default=serialize_numpy),
104
- data.get("config_name", ""),
105
- json.dumps(data.get("column_names", []), default=serialize_numpy),
106
- json.dumps(data.get("features", []), default=serialize_numpy),
107
- )
108
- for data in datasets
109
- ],
110
- )
111
- conn.commit()
112
- conn.close()
113
- logger.info("Data refreshed")
114
  yield
115
 
 
 
116
 
117
  app = FastAPI(lifespan=lifespan)
118
 
 
1
+ import asyncio
2
+ import concurrent.futures
3
  import json
4
  import logging
5
+ import os
6
  import sqlite3
7
  from contextlib import asynccontextmanager
8
  from typing import List
9
 
10
  import numpy as np
11
+ from apscheduler.schedulers.asyncio import AsyncIOScheduler
12
+ from apscheduler.triggers.cron import CronTrigger
13
  from cashews import NOT_NONE, cache
14
  from fastapi import FastAPI, HTTPException, Query
15
  from pandas import Timestamp
 
18
 
19
  from data_loader import refresh_data
20
 
21
+ UPDATE_SCHEDULE = {"hour": os.getenv("UPDATE_INTERVAL_HOURS", "*/6")}
22
+
23
  cache.setup("mem://?check_interval=10&size=10000")
24
  logger = logging.getLogger(__name__)
25
 
 
79
  raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
80
 
81
 
82
+ def background_refresh_data():
83
+ logger.info("Starting background data refresh")
84
+ try:
85
+ return refresh_data()
86
+ except Exception as e:
87
+ logger.error(f"Error in background data refresh: {str(e)}")
88
+ return None
89
+
90
+
91
+ async def update_database():
92
+ logger.info("Starting scheduled data refresh")
93
+
94
+ # Run refresh_data in a background thread
95
+ with concurrent.futures.ThreadPoolExecutor() as executor:
96
+ future = executor.submit(background_refresh_data)
97
+
98
+ # Wait for the background task to complete, but allow for cancellation
99
+ try:
100
+ datasets = await asyncio.get_event_loop().run_in_executor(
101
+ None, future.result
102
+ )
103
+ except asyncio.CancelledError:
104
+ future.cancel()
105
+ logger.info("Data refresh cancelled")
106
+ return
107
+
108
+ if datasets is None:
109
+ logger.error("Data refresh failed, skipping database update")
110
+ return
111
+
112
+ conn = get_db_connection()
113
+ try:
114
+ c = conn.cursor()
115
+ c.executemany(
116
+ """
117
+ INSERT OR REPLACE INTO datasets
118
+ (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
119
+ VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
120
+ """,
121
+ [
122
+ (
123
+ data["hub_id"],
124
+ data.get("likes", 0),
125
+ data.get("downloads", 0),
126
+ json.dumps(data.get("tags", []), default=serialize_numpy),
127
+ int(data["created_at"].timestamp())
128
+ if isinstance(data["created_at"], Timestamp)
129
+ else data.get("created_at", 0),
130
+ int(data["last_modified"].timestamp())
131
+ if isinstance(data["last_modified"], Timestamp)
132
+ else data.get("last_modified", 0),
133
+ json.dumps(data.get("license", []), default=serialize_numpy),
134
+ json.dumps(data.get("language", []), default=serialize_numpy),
135
+ data.get("config_name", ""),
136
+ json.dumps(data.get("column_names", []), default=serialize_numpy),
137
+ json.dumps(data.get("features", []), default=serialize_numpy),
138
+ )
139
+ for data in datasets
140
+ ],
141
+ )
142
+ conn.commit()
143
+ logger.info("Scheduled data refresh completed")
144
+ except Exception as e:
145
+ logger.error(f"Error during database update: {str(e)}")
146
+ conn.rollback()
147
+ finally:
148
+ conn.close()
149
+
150
+
151
  @asynccontextmanager
152
  async def lifespan(app: FastAPI):
153
  setup_database()
154
+ logger.info("Performing initial data refresh")
155
+ await update_database()
156
+
157
+ # Set up the scheduler
158
+ scheduler = AsyncIOScheduler()
159
+ # Schedule the update_database function using the UPDATE_SCHEDULE configuration
160
+ scheduler.add_job(update_database, CronTrigger(**UPDATE_SCHEDULE))
161
+ scheduler.start()
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  yield
164
 
165
+ scheduler.shutdown()
166
+
167
 
168
  app = FastAPI(lifespan=lifespan)
169