Rifqi Hafizuddin commited on
Commit
d310770
·
1 Parent(s): 52415b6

[NOTICKET] add duplicate check for storing database

Browse files
src/database_client/database_client_service.py CHANGED
@@ -16,9 +16,46 @@ from src.utils.db_credential_encryption import (
16
  logger = get_logger("database_client_service")
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
19
  class DatabaseClientService:
20
  """Service for managing user-registered external database connections."""
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  async def create(
23
  self,
24
  db: AsyncSession,
@@ -29,8 +66,17 @@ class DatabaseClientService:
29
  ) -> DatabaseClient:
30
  """Register a new database client connection.
31
 
 
 
32
  Credentials are encrypted before being stored.
33
  """
 
 
 
 
 
 
 
34
  client = DatabaseClient(
35
  id=str(uuid.uuid4()),
36
  user_id=user_id,
 
16
  logger = get_logger("database_client_service")
17
 
18
 
19
+ # Fields that identify the same physical database per db_type.
20
+ _CONNECTION_IDENTITY_KEYS: dict[str, tuple[str, ...]] = {
21
+ "postgres": ("host", "port", "database"),
22
+ "supabase": ("host", "port", "database"),
23
+ "mysql": ("host", "port", "database"),
24
+ "sqlserver": ("host", "port", "database"),
25
+ "bigquery": ("project_id", "dataset_id"),
26
+ "snowflake": ("account", "warehouse", "database"),
27
+ }
28
+
29
+
30
  class DatabaseClientService:
31
  """Service for managing user-registered external database connections."""
32
 
33
+ async def _find_duplicate(
34
+ self,
35
+ db: AsyncSession,
36
+ user_id: str,
37
+ db_type: str,
38
+ credentials: dict,
39
+ ) -> Optional[DatabaseClient]:
40
+ """Return an existing client if it points to the same physical database."""
41
+ identity_keys = _CONNECTION_IDENTITY_KEYS.get(db_type, ())
42
+ if not identity_keys:
43
+ return None
44
+
45
+ result = await db.execute(
46
+ select(DatabaseClient).where(
47
+ DatabaseClient.user_id == user_id,
48
+ DatabaseClient.db_type == db_type,
49
+ )
50
+ )
51
+ for existing in result.scalars().all():
52
+ decrypted = decrypt_credentials_dict(existing.credentials)
53
+ if all(
54
+ decrypted.get(k) == credentials.get(k) for k in identity_keys
55
+ ):
56
+ return existing
57
+ return None
58
+
59
  async def create(
60
  self,
61
  db: AsyncSession,
 
66
  ) -> DatabaseClient:
67
  """Register a new database client connection.
68
 
69
+ If a connection to the same physical database already exists for this
70
+ user, the existing record is returned instead of creating a duplicate.
71
  Credentials are encrypted before being stored.
72
  """
73
+ existing = await self._find_duplicate(db, user_id, db_type, credentials)
74
+ if existing:
75
+ logger.info(
76
+ f"Duplicate connection detected, returning existing client {existing.id}"
77
+ )
78
+ return existing
79
+
80
  client = DatabaseClient(
81
  id=str(uuid.uuid4()),
82
  user_id=user_id,