from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional

from pydantic import BaseModel

from core.extension.extensible import Extensible, ExtensionModule


class ModerationAction(Enum):
    # Respond to the user with the preset response directly, without further processing.
    DIRECT_OUTPUT = "direct_output"
    # Override the flagged inputs/outputs with moderated content and continue processing.
    OVERRIDDEN = "overridden"


class ModerationInputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    inputs: dict = {}
    query: str = ""


class ModerationOutputsResult(BaseModel):
    flagged: bool = False
    action: ModerationAction
    preset_response: str = ""
    text: str = ""


class Moderation(Extensible, ABC):
    """
    The base class of moderation.
    """

    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user inputs, this method will be called to perform sensitive content review
        on the user inputs and return the processed results.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.
        When LLM outputs content, the front end will pass the output content (may be segmented)
        to this method for sensitive content review, and the output content will be shielded if the review fails.

        :param text: LLM output content
        :return:
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
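        # Expected config shape (illustrative, inferred from the checks below):
        # {
        #     "inputs_config": {"enabled": True, "preset_response": "..."},
        #     "outputs_config": {"enabled": True, "preset_response": "..."},
        # }
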
        # inputs_config
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
            raise ValueError("inputs_config must be a dict")

        # outputs_config
        outputs_config = config.get("outputs_config")
        if not isinstance(outputs_config, dict):
            raise ValueError("outputs_config must be a dict")

        inputs_config_enabled = inputs_config.get("enabled")
        outputs_config_enabled = outputs_config.get("enabled")
        if not inputs_config_enabled and not outputs_config_enabled:
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        # preset_response
        if not is_preset_response_required:
            return

        if inputs_config_enabled:
            preset_response = inputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("inputs_config.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError("inputs_config.preset_response must be at most 100 characters")

        if outputs_config_enabled:
            preset_response = outputs_config.get("preset_response")
            if not preset_response:
                raise ValueError("outputs_config.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError("outputs_config.preset_response must be at most 100 characters")
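

# Illustrative sketch (not part of the original interface): a minimal keyword-based
# moderation extension showing how a subclass implements the abstract methods.
# The class name, the "keyword_moderation" identifier, and the "keywords" config
# key are assumptions for demonstration only; the sketch also assumes Extensible
# stores the constructor's config dict on self.config.
class KeywordModeration(Moderation):
    name: str = "keyword_moderation"  # hypothetical extension identifier

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        # Reuse the shared helper, then check the extension-specific key.
        cls._validate_inputs_and_outputs_config(config, is_preset_response_required=True)
        if not config.get("keywords"):
            raise ValueError("keywords is required")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        preset_response = (self.config or {}).get("inputs_config", {}).get("preset_response", "")
        return ModerationInputsResult(
            flagged=self._is_flagged(query, *map(str, inputs.values())),
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=preset_response,
            inputs=inputs,
            query=query,
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        preset_response = (self.config or {}).get("outputs_config", {}).get("preset_response", "")
        return ModerationOutputsResult(
            flagged=self._is_flagged(text),
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=preset_response,
            text=text,
        )

    def _is_flagged(self, *texts: str) -> bool:
        # Flag if any configured keyword appears in any of the given texts.
        keywords = (self.config or {}).get("keywords", [])
        return any(keyword in text for keyword in keywords for text in texts)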


class ModerationError(Exception):
    pass
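

# Usage sketch (illustrative; uses the hypothetical KeywordModeration above and
# made-up tenant/app IDs): validate the config once at save time, then run both
# moderation passes around an LLM call.
if __name__ == "__main__":
    config = {
        "keywords": ["forbidden"],
        "inputs_config": {"enabled": True, "preset_response": "Input blocked."},
        "outputs_config": {"enabled": True, "preset_response": "Output blocked."},
    }
    KeywordModeration.validate_config(tenant_id="tenant-1", config=config)

    moderation = KeywordModeration(app_id="app-1", tenant_id="tenant-1", config=config)

    inputs_result = moderation.moderation_for_inputs({"topic": "a forbidden topic"}, query="hi")
    if inputs_result.flagged and inputs_result.action == ModerationAction.DIRECT_OUTPUT:
        print(inputs_result.preset_response)  # answer with the preset text instead of calling the LLM

    outputs_result = moderation.moderation_for_outputs("some LLM output")
    if outputs_result.flagged:
        print(outputs_result.preset_response)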