|
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor |
|
from core.helper.encrypter import decrypt_token, encrypt_token |
|
from extensions.ext_database import db |
|
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint |
|
|
|
|
|
class APIBasedExtensionService:
    """Service-layer operations on tenant-scoped ``APIBasedExtension`` records.

    API keys are passed through ``encrypt_token`` before a record is committed
    and through ``decrypt_token`` when records are read back.
    """

    @staticmethod
    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
        """Return all extensions of a tenant, newest first, with decrypted keys.

        Args:
            tenant_id: Tenant whose extensions are queried.

        Returns:
            The matching records ordered by ``created_at`` descending; each
            record's ``api_key`` has been replaced by its decrypted value.
        """
        extension_list = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .order_by(APIBasedExtension.created_at.desc())
            .all()
        )

        # NOTE(review): the decrypted key is assigned onto objects loaded from
        # the session; presumably a later commit in the same session would
        # persist it unencrypted -- confirm callers never commit these as-is.
        for extension in extension_list:
            extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension_list

    @classmethod
    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
        """Validate, encrypt the API key of, and persist an extension.

        Args:
            extension_data: Record to create or update; its ``api_key`` is
                expected to be the plain (unencrypted) value on entry.

        Returns:
            The same object, whose ``api_key`` now holds the encrypted value.

        Raises:
            ValueError: If validation or the ping check fails.
        """
        cls._validation(extension_data)

        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)

        db.session.add(extension_data)
        db.session.commit()
        return extension_data

    @staticmethod
    def delete(extension_data: APIBasedExtension) -> None:
        """Delete the given extension record and commit the session."""
        db.session.delete(extension_data)
        db.session.commit()

    @staticmethod
    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """Fetch a single extension of a tenant, with its key decrypted.

        Args:
            tenant_id: Tenant the extension must belong to.
            api_based_extension_id: Primary key of the extension.

        Returns:
            The matching record with ``api_key`` replaced by its decrypted value.

        Raises:
            ValueError: If no record matches both the tenant and the id.
        """
        extension = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=tenant_id)
            .filter_by(id=api_based_extension_id)
            .first()
        )

        if not extension:
            raise ValueError("API based extension is not found")

        # NOTE(review): same in-place assignment on a session-loaded object as
        # in get_all_by_tenant_id -- confirm this cannot be flushed as-is.
        extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension

    @classmethod
    def _validation(cls, extension_data: APIBasedExtension) -> None:
        """Check required fields, name uniqueness, and endpoint reachability.

        Checks run in this order: name present, name unique within the tenant,
        ``api_endpoint`` present, ``api_key`` present and at least 5
        characters long, then a ping of the endpoint.

        Raises:
            ValueError: On the first check that fails.
        """
        if not extension_data.name:
            raise ValueError("name must not be empty")

        same_name_query = (
            db.session.query(APIBasedExtension)
            .filter_by(tenant_id=extension_data.tenant_id)
            .filter_by(name=extension_data.name)
        )
        if extension_data.id:
            # The record already has an id, so its own row must not be counted
            # as a name clash.
            same_name_query = same_name_query.filter(APIBasedExtension.id != extension_data.id)

        if same_name_query.first():
            raise ValueError("name must be unique, it is already existed")

        if not extension_data.api_endpoint:
            raise ValueError("api_endpoint must not be empty")

        if not extension_data.api_key:
            raise ValueError("api_key must not be empty")

        if len(extension_data.api_key) < 5:
            raise ValueError("api_key must be at least 5 characters")

        cls._ping_connection(extension_data)

    @staticmethod
    def _ping_connection(extension_data: APIBasedExtension) -> None:
        """Send a PING request to the extension endpoint and require ``pong``.

        Raises:
            ValueError: If the request raises, or if the response's ``result``
                field is not ``"pong"``. The original failure is chained as
                the exception's ``__cause__``.
        """
        try:
            client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
            resp = client.request(point=APIBasedExtensionPoint.PING, params={})
            if resp.get("result") != "pong":
                # Deliberately raised inside the try: the handler below wraps
                # it so every failure carries the same "connection error: " prefix.
                raise ValueError(resp)
        except Exception as e:
            raise ValueError(f"connection error: {e}") from e
|
|