File size: 2,709 Bytes
868b252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import secrets
from typing import Dict, Optional

from fastapi import Depends
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.dependencies import get_db_session
from reworkd_platform.db.models.auth import OauthCredentials
from reworkd_platform.schemas import UserBase


class OAuthCrud(BaseCrud):
    """Data-access helpers for ``OauthCredentials`` rows (OAuth installations)."""

    @classmethod
    async def inject(
        cls,
        session: AsyncSession = Depends(get_db_session),
    ) -> "OAuthCrud":
        """FastAPI dependency: build an ``OAuthCrud`` bound to the injected DB session."""
        return cls(session)

    async def create_installation(
        self, user: UserBase, provider: str, redirect_uri: Optional[str]
    ) -> OauthCredentials:
        """Persist a new installation row for ``user`` and ``provider``.

        A fresh random ``state`` (32 hex characters) is generated for the row;
        ``get_installation_by_state`` looks installations up by this value.

        Args:
            user: Owner of the installation; its ``id`` and ``organization_id``
                (which may be ``None``) are stored on the row.
            provider: Provider identifier stored on the row.
            redirect_uri: Optional redirect URI stored on the row.

        Returns:
            The result of ``OauthCredentials.save`` for the new row.
        """
        return await OauthCredentials(
            user_id=user.id,
            organization_id=user.organization_id,
            provider=provider,
            # `secrets` (not `random`) so the state value is not predictable.
            state=secrets.token_hex(16),
            redirect_uri=redirect_uri,
        ).save(self.session)

    async def get_installation_by_state(self, state: str) -> Optional[OauthCredentials]:
        """Return the installation whose ``state`` equals ``state``, or ``None``.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches
                (``scalar_one_or_none`` semantics).
        """
        query = select(OauthCredentials).filter(OauthCredentials.state == state)

        return (await self.session.execute(query)).scalar_one_or_none()

    async def get_installation_by_user_id(
        self, user_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return an installation of ``provider`` owned by ``user_id`` that has a token.

        Rows without ``access_token_enc`` are ignored. There is no ORDER BY, so
        if several rows match, which one is returned is up to the database.
        Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.user_id == user_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
        )

        return (await self.session.execute(query)).scalars().first()

    async def get_installation_by_organization_id(
        self, organization_id: str, provider: str
    ) -> Optional[OauthCredentials]:
        """Return an installation of ``provider`` for ``organization_id`` that has a token.

        Rows without ``access_token_enc`` are ignored. There is no ORDER BY, so
        if several rows match, which one is returned is up to the database.
        Returns ``None`` when nothing matches.
        """
        query = select(OauthCredentials).filter(
            OauthCredentials.organization_id == organization_id,
            OauthCredentials.provider == provider,
            OauthCredentials.access_token_enc.isnot(None),
            # `column == None` compiles to `IS NULL`; this guard keeps a None
            # organization_id from matching every org-less installation.
            OauthCredentials.organization_id.isnot(None),
        )

        return (await self.session.execute(query)).scalars().first()

    async def get_all(self, user: UserBase) -> Dict[str, str]:
        """Map each provider to one stored ``access_token_enc`` visible to ``user``.

        Installations are scoped to ``user.organization_id``. If the user has no
        organization, only that user's own organization-less installations are
        considered. Rows without ``access_token_enc`` are ignored.

        Returns:
            ``{provider: access_token_enc}``. When several rows exist for a
            provider, ANY_VALUE picks one of them arbitrarily. Values are
            returned as stored; nothing is decrypted here.
        """
        if user.organization_id is None:
            # `organization_id == None` would compile to `IS NULL` and match the
            # org-less installations of *every* user, leaking their tokens.
            # Restrict to this user's own installations instead.
            owner_filters = [
                OauthCredentials.organization_id.is_(None),
                OauthCredentials.user_id == user.id,
            ]
        else:
            owner_filters = [
                OauthCredentials.organization_id == user.organization_id,
            ]

        # NOTE(review): ANY_VALUE is dialect-specific (e.g. MySQL 5.7+,
        # PostgreSQL 16+); confirm the deployed database supports it.
        query = (
            select(
                OauthCredentials.provider,
                func.any_value(OauthCredentials.access_token_enc),
            )
            .filter(
                OauthCredentials.access_token_enc.isnot(None),
                *owner_filters,
            )
            .group_by(OauthCredentials.provider)
        )

        return {
            provider: token
            for provider, token in (await self.session.execute(query)).all()
        }