Rifqi Hafizuddin committed on
Commit
e49db60
·
1 Parent(s): 8218650

[NOTICKET] add db_client for querying

Browse files
src/api/v1/db_client.py CHANGED
@@ -458,7 +458,7 @@ async def ingest_database_client(
458
  db_type=client.db_type,
459
  credentials=creds,
460
  ) as engine:
461
- total = await db_pipeline_service.run(user_id=user_id, engine=engine)
462
  except NotImplementedError as e:
463
  raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
464
  except Exception as e:
 
458
  db_type=client.db_type,
459
  credentials=creds,
460
  ) as engine:
461
+ total = await db_pipeline_service.run(user_id=user_id, client_id=client_id, engine=engine)
462
  except NotImplementedError as e:
463
  raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail=str(e))
464
  except Exception as e:
src/pipeline/db_pipeline/db_pipeline_service.py CHANGED
@@ -148,7 +148,7 @@ class DbPipelineService:
148
  engine.dispose()
149
 
150
  def _to_document(
151
- self, user_id: str, table_name: str, entry: dict, updated_at: str
152
  ) -> LangChainDocument:
153
  col = entry["col"]
154
  return LangChainDocument(
@@ -156,6 +156,7 @@ class DbPipelineService:
156
  metadata={
157
  "user_id": user_id,
158
  "source_type": "database",
 
159
  "updated_at": updated_at,
160
  "data": {
161
  "table_name": table_name,
@@ -170,6 +171,7 @@ class DbPipelineService:
170
  async def run(
171
  self,
172
  user_id: str,
 
173
  engine: Engine,
174
  exclude_tables: Optional[frozenset[str]] = None,
175
  ) -> int:
@@ -202,7 +204,7 @@ class DbPipelineService:
202
  for table_name, columns in schema.items():
203
  logger.info("profiling table", table=table_name, columns=len(columns))
204
  entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
205
- docs = [self._to_document(user_id, table_name, e, updated_at) for e in entries]
206
  if docs:
207
  await vector_store.aadd_documents(docs)
208
  total += len(docs)
 
148
  engine.dispose()
149
 
150
  def _to_document(
151
+ self, user_id: str, client_id: str, table_name: str, entry: dict, updated_at: str
152
  ) -> LangChainDocument:
153
  col = entry["col"]
154
  return LangChainDocument(
 
156
  metadata={
157
  "user_id": user_id,
158
  "source_type": "database",
159
+ "database_client_id": client_id,
160
  "updated_at": updated_at,
161
  "data": {
162
  "table_name": table_name,
 
171
  async def run(
172
  self,
173
  user_id: str,
174
+ client_id: str,
175
  engine: Engine,
176
  exclude_tables: Optional[frozenset[str]] = None,
177
  ) -> int:
 
204
  for table_name, columns in schema.items():
205
  logger.info("profiling table", table=table_name, columns=len(columns))
206
  entries = await asyncio.to_thread(profile_table, engine, table_name, columns)
207
+ docs = [self._to_document(user_id, client_id, table_name, e, updated_at) for e in entries]
208
  if docs:
209
  await vector_store.aadd_documents(docs)
210
  total += len(docs)