CatPtain committed on
Commit
b1224fd
·
verified ·
1 Parent(s): e94c687

Upload 1285 files

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +6 -0
  2. api/core/model_runtime/README.md +70 -0
  3. api/core/model_runtime/README_CN.md +89 -0
  4. api/core/model_runtime/__init__.py +0 -0
  5. api/core/model_runtime/model_providers/__base/__init__.py +0 -0
  6. api/core/model_runtime/model_providers/__base/ai_model.py +334 -0
  7. api/core/model_runtime/model_providers/__base/audio.mp3 +3 -0
  8. api/core/model_runtime/model_providers/__base/large_language_model.py +904 -0
  9. api/core/model_runtime/model_providers/__base/model_provider.py +121 -0
  10. api/core/model_runtime/model_providers/__base/moderation_model.py +49 -0
  11. api/core/model_runtime/model_providers/__base/rerank_model.py +69 -0
  12. api/core/model_runtime/model_providers/__base/speech2text_model.py +59 -0
  13. api/core/model_runtime/model_providers/__base/text2img_model.py +54 -0
  14. api/core/model_runtime/model_providers/__base/text_embedding_model.py +111 -0
  15. api/core/model_runtime/model_providers/__base/tokenizers/gpt2/merges.txt +0 -0
  16. api/core/model_runtime/model_providers/__base/tokenizers/gpt2/special_tokens_map.json +23 -0
  17. api/core/model_runtime/model_providers/__base/tokenizers/gpt2/tokenizer_config.json +33 -0
  18. api/core/model_runtime/model_providers/__base/tokenizers/gpt2/vocab.json +0 -0
  19. api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +51 -0
  20. api/core/model_runtime/model_providers/__base/tts_model.py +179 -0
  21. api/core/model_runtime/model_providers/__init__.py +3 -0
  22. api/core/model_runtime/model_providers/_position.yaml +43 -0
  23. api/core/model_runtime/model_providers/anthropic/__init__.py +0 -0
  24. api/core/model_runtime/model_providers/anthropic/_assets/icon_l_en.svg +78 -0
  25. api/core/model_runtime/model_providers/anthropic/_assets/icon_s_en.svg +4 -0
  26. api/core/model_runtime/model_providers/anthropic/anthropic.py +28 -0
  27. api/core/model_runtime/model_providers/anthropic/anthropic.yaml +39 -0
  28. api/core/model_runtime/model_providers/anthropic/llm/__init__.py +0 -0
  29. api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +10 -0
  30. api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml +36 -0
  31. api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml +37 -0
  32. api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml +38 -0
  33. api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml +40 -0
  34. api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml +40 -0
  35. api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml +39 -0
  36. api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml +39 -0
  37. api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml +39 -0
  38. api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml +36 -0
  39. api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml +36 -0
  40. api/core/model_runtime/model_providers/anthropic/llm/llm.py +654 -0
  41. api/core/model_runtime/model_providers/azure_ai_studio/__init__.py +0 -0
  42. api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png +0 -0
  43. api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png +0 -0
  44. api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py +17 -0
  45. api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml +99 -0
  46. api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py +0 -0
  47. api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py +345 -0
  48. api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py +0 -0
  49. api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py +164 -0
  50. api/core/model_runtime/model_providers/azure_openai/__init__.py +0 -0
.gitattributes CHANGED
@@ -26,3 +26,9 @@ api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144229650.png fil
26
  api/core/model_runtime/docs/zh_Hans/images/index/image-20231210144814617.png filter=lfs diff=lfs merge=lfs -text
27
  api/core/model_runtime/docs/zh_Hans/images/index/image-20231210165243632.png filter=lfs diff=lfs merge=lfs -text
28
  api/core/model_runtime/docs/zh_Hans/images/index/image.png filter=lfs diff=lfs merge=lfs -text
29
+ api/core/model_runtime/model_providers/__base/audio.mp3 filter=lfs diff=lfs merge=lfs -text
30
+ api/core/model_runtime/model_providers/gpustack/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
31
+ api/core/model_runtime/model_providers/leptonai/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
32
+ api/core/model_runtime/model_providers/mixedbread/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
33
+ api/core/model_runtime/model_providers/nvidia_nim/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
34
+ api/core/model_runtime/model_providers/nvidia/_assets/icon_l_en.png filter=lfs diff=lfs merge=lfs -text
api/core/model_runtime/README.md ADDED
@@ -0,0 +1,70 @@
1
+ # Model Runtime
2
+
3
+ This module provides the interfaces for invoking models and validating their credentials, and gives Dify unified model provider information and credential form rules.
4
+
5
+ - On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
6
+ - On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
7
+
8
+ ## Features
9
+
10
+ - Supports capability invocation for 6 types of models
11
+
12
+ - `LLM` - LLM text completion, chat, and pre-computed token counting capabilities
13
+ - `Text Embedding Model` - text embedding and pre-computed token counting capabilities
14
+ - `Rerank Model` - segment rerank capability
15
+ - `Speech-to-text Model` - speech-to-text capability
16
+ - `Text-to-speech Model` - text-to-speech capability
17
+ - `Moderation` - moderation capability
18
+
19
+ - Model provider display
20
+
21
+ ![image-20231210143654461](./docs/en_US/images/index/image-20231210143654461.png)
22
+
23
+ Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. For detailed rule design, see: [Schema](./docs/en_US/schema.md).
24
+
25
+ - Selectable model list display
26
+
27
+ ![image-20231210144229650](./docs/en_US/images/index/image-20231210144229650.png)
28
+
29
+ After configuring provider/model credentials, the dropdown (application orchestration interface / default model) shows the available LLMs. Greyed-out items are predefined models from providers whose credentials have not been configured, so users can still see which models are supported.
30
+
31
+ This list also returns the configurable parameter information and rules for each LLM, as shown below:
32
+
33
+ ![image-20231210144814617](./docs/en_US/images/index/image-20231210144814617.png)
34
+
35
+ These parameters are all defined in the backend, allowing different settings for various parameters supported by different models, as detailed in: [Schema](./docs/en_US/schema.md#ParameterRule).
36
+
37
+ - Provider/model credential authentication
38
+
39
+ ![image-20231210151548521](./docs/en_US/images/index/image-20231210151548521.png)
40
+
41
+ ![image-20231210151628992](./docs/en_US/images/index/image-20231210151628992.png)
42
+
43
+ The provider list returns the configuration information for the credentials form, and the credentials can be validated through the Runtime interface. The first image above is a provider credential demo, and the second is a model credential demo.
44
+
45
+ ## Structure
46
+
47
+ ![](./docs/en_US/images/index/image-20231210165243632.png)
48
+
49
+ Model Runtime is divided into three layers:
50
+
51
+ - The outermost layer is the factory method
52
+
53
+ It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
54
+
55
+ - The second layer is the provider layer
56
+
57
+ It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
58
+
59
+ - The bottom layer is the model layer
60
+
61
+ It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential validation methods. Different models also provide special methods, such as the LLM's pre-computed token counting and cost-information methods, **allowing horizontal expansion** for different models under the same provider (within the supported model types). A minimal sketch of this three-layer flow appears after this file's diff.
62
+
63
+
64
+
65
+ ## Next Steps
66
+
67
+ - Add new provider configuration: [Link](./docs/en_US/provider_scale_out.md)
68
+ - Add new models for existing providers: [Link](./docs/en_US/provider_scale_out.md#AddModel)
69
+ - View YAML configuration rules: [Link](./docs/en_US/schema.md)
70
+ - Implement interface methods: [Link](./docs/en_US/interfaces.md)
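To make the three-layer structure described in the Structure section concrete, here is a minimal, self-contained sketch. The stub classes and names below are invented for illustration only; they are not the real Dify factory, provider, or model classes.

```python
# Self-contained sketch of the factory -> provider -> model layering described above.
# StubFactory/StubProvider/StubLLM are illustrative stand-ins, not the real Dify classes.

class StubLLM:
    """Model layer: direct invocation, token counting, and pricing live here."""

    def invoke(self, model: str, credentials: dict, prompt: str) -> str:
        return f"[{model}] echo: {prompt}"


class StubProvider:
    """Provider layer: model lists, model instances, credential validation."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        if "api_key" not in credentials:
            raise ValueError("api_key is required")

    def get_model_instance(self) -> StubLLM:
        return StubLLM()


class StubFactory:
    """Outermost layer: look up provider instances by name."""

    _providers = {"stub_provider": StubProvider()}

    def get_provider_instance(self, name: str) -> StubProvider:
        return self._providers[name]


factory = StubFactory()
provider = factory.get_provider_instance("stub_provider")
provider.validate_provider_credentials({"api_key": "demo"})
llm = provider.get_model_instance()
print(llm.invoke("stub-model", {"api_key": "demo"}, "hello"))
```

In the actual runtime, the model layer's `invoke` additionally takes `prompt_messages`, `model_parameters`, `tools`, `stop`, `stream`, and `user`, as shown in `large_language_model.py` later in this diff.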
api/core/model_runtime/README_CN.md ADDED
@@ -0,0 +1,89 @@
1
+ # Model Runtime
2
+
3
+ This module provides the invocation and credential-validation interfaces for the various models, and provides Dify with unified model provider information and credential form rules.
4
+
5
+ - On one hand, it decouples models from upstream and downstream processes, making it easy for developers to extend models horizontally,
6
+ - On the other hand, providers and models only need to be defined in the backend to be displayed directly on the frontend page, with no changes to frontend logic.
7
+
8
+ ## Features
9
+
10
+ - Supports capability invocation for 6 types of models
11
+
12
+ - `LLM` - LLM text completion, chat, and pre-computed token counting capabilities
13
+ - `Text Embedding Model` - text embedding and pre-computed token counting capabilities
14
+ - `Rerank Model` - segment rerank capability
15
+ - `Speech-to-text Model` - speech-to-text capability
16
+ - `Text-to-speech Model` - text-to-speech capability
17
+ - `Moderation` - moderation capability
18
+
19
+ - Model provider display
20
+
21
+ ![image-20231210143654461](./docs/zh_Hans/images/index/image-20231210143654461.png)
22
+
23
+ Displays a list of all supported providers; besides the provider name and icon, it also returns the list of supported model types, the predefined model list, the configuration method, and the credential form rules. For detailed rule design, see: [Schema](./docs/zh_Hans/schema.md).
24
+
25
+ - Selectable model list display
26
+
27
+ ![image-20231210144229650](./docs/zh_Hans/images/index/image-20231210144229650.png)
28
+
29
+ After configuring provider/model credentials, the dropdown (application orchestration interface / default model) shows the available LLMs. Greyed-out items are predefined models from providers whose credentials have not been configured, so users can still see which models are supported.
30
+
31
+ This list also returns the configurable parameter information and rules for the LLM, as shown below:
32
+
33
+ ![image-20231210144814617](./docs/zh_Hans/images/index/image-20231210144814617.png)
34
+
35
+ These parameters are all defined in the backend. Unlike before, when there were only 5 fixed parameters, the parameters supported by each model can now be configured individually. See: [Schema](./docs/zh_Hans/schema.md#ParameterRule).
36
+
37
+ - Provider/model credential validation
38
+
39
+ ![image-20231210151548521](./docs/zh_Hans/images/index/image-20231210151548521.png)
40
+
41
+ ![image-20231210151628992](./docs/zh_Hans/images/index/image-20231210151628992.png)
42
+
43
+ The provider list returns the configuration information for the credentials form, and credentials can be validated through the interface provided by the Runtime. The first image above is a provider credential demo, and the second is a model credential demo.
44
+
45
+ ## Structure
46
+
47
+ ![](./docs/zh_Hans/images/index/image-20231210165243632.png)
48
+
49
+ Model Runtime is divided into three layers:
50
+
51
+ - The outermost layer is the factory method
52
+
53
+ It provides methods for obtaining all providers, all model lists, getting provider instances, and validating provider/model credentials.
54
+
55
+ - The second layer is the provider layer
56
+
57
+ It provides the current provider's model list, model instance retrieval, provider credential validation, and provider configuration rule information, and **can be extended horizontally** to support different providers.
58
+
59
+ For provider/model credentials, there are two cases:
60
+ - Centralized providers such as OpenAI require authentication credentials such as **api_key**.
61
+ - Locally deployed providers such as [**Xinference**](https://github.com/xorbitsai/inference) require address credentials such as **server_url**, and sometimes model credentials such as **model_uid**, as shown below. Once these credentials are defined at the provider layer, they can be displayed directly on the frontend page with no changes to frontend logic.
62
+ ![Alt text](docs/zh_Hans/images/index/image.png)
63
+
64
+ Once the credentials are configured, the **Schema** (credential form rules) required by the corresponding provider can be obtained directly through DifyRuntime's external interface, so new providers/models can be supported without modifying frontend logic.
65
+
66
+ - The bottom layer is the model layer
67
+
68
+ It offers direct invocation of the various model types, predefined model configuration information, retrieval of predefined/remote model lists, and model credential validation methods. Different models also provide special methods, such as the LLM's pre-computed token counting and cost-information methods, and it **can be extended horizontally** for different models under the same provider (within the supported model types).
69
+
70
+ Here we first need to distinguish between model parameters and model credentials (a short sketch of this split follows after this file's diff).
71
+
72
+ - Model parameters (**defined at this layer**): parameters that change frequently and are adjusted at any time, such as the LLM's **max_tokens**, **temperature**, etc. These are adjusted by users on the frontend page, so parameter rules must be defined in the backend for the frontend to display and adjust them. In DifyRuntime, their parameter name is generally **model_parameters: dict[str, any]**.
73
+
74
+ - Model credentials (**defined at the provider layer**): parameters that rarely change and are generally fixed once configured, such as **api_key**, **server_url**, etc. In DifyRuntime, their parameter name is generally **credentials: dict[str, any]**; the Provider layer's credentials are passed directly to this layer and do not need to be defined again.
75
+
76
+ ## Next Steps
77
+
78
+ ### [Add a new provider configuration 👈🏻](./docs/zh_Hans/provider_scale_out.md)
79
+ Once added, a new provider will appear here.
80
+
81
+ ![Alt text](docs/zh_Hans/images/index/image-1.png)
82
+
83
+ ### [Add models to an existing provider 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
84
+ Once added, a new predefined model will appear in the corresponding provider's model list for users to choose from, such as GPT-3.5, GPT-4, ChatGLM3-6b, etc. Providers that support custom models do not need new models added this way.
85
+
86
+ ![Alt text](docs/zh_Hans/images/index/image-2.png)
87
+
88
+ ### [Implement the interfaces 👈🏻](./docs/zh_Hans/interfaces.md)
89
+ Here you can find the concrete implementation of the interface you want to inspect, as well as the exact meaning of its parameters and return values.
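As a concrete illustration of the model_parameters / credentials split described above, here is a short sketch. The key names follow the Xinference example (`server_url`, `model_uid`) and the values are placeholders; the commented call at the end mirrors the `invoke()` signature shown later in `large_language_model.py`.

```python
# Credentials are defined once at the provider layer and passed through unchanged;
# model parameters are adjusted per request from the frontend.
credentials = {
    "server_url": "http://localhost:9997",  # address credential (provider layer, placeholder)
    "model_uid": "my-model-uid",            # model credential (provider layer, placeholder)
}

model_parameters = {
    "temperature": 0.7,   # adjusted freely by the user (model layer)
    "max_tokens": 512,
}

# A LargeLanguageModel implementation then receives both dicts, e.g.:
# llm.invoke(model="...", credentials=credentials,
#            prompt_messages=[...], model_parameters=model_parameters)
print(credentials, model_parameters)
```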
api/core/model_runtime/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/__base/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/__base/ai_model.py ADDED
@@ -0,0 +1,334 @@
1
+ import decimal
2
+ import os
3
+ from abc import ABC, abstractmethod
4
+ from typing import Optional
5
+
6
+ from pydantic import ConfigDict
7
+
8
+ from core.helper.position_helper import get_position_map, sort_by_position_map
9
+ from core.model_runtime.entities.common_entities import I18nObject
10
+ from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
11
+ from core.model_runtime.entities.model_entities import (
12
+ AIModelEntity,
13
+ DefaultParameterName,
14
+ FetchFrom,
15
+ ModelType,
16
+ PriceConfig,
17
+ PriceInfo,
18
+ PriceType,
19
+ )
20
+ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
21
+ from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
22
+ from core.tools.utils.yaml_utils import load_yaml_file
23
+
24
+
25
+ class AIModel(ABC):
26
+ """
27
+ Base class for all models.
28
+ """
29
+
30
+ model_type: ModelType
31
+ model_schemas: Optional[list[AIModelEntity]] = None
32
+ started_at: float = 0
33
+
34
+ # pydantic configs
35
+ model_config = ConfigDict(protected_namespaces=())
36
+
37
+ @abstractmethod
38
+ def validate_credentials(self, model: str, credentials: dict) -> None:
39
+ """
40
+ Validate model credentials
41
+
42
+ :param model: model name
43
+ :param credentials: model credentials
44
+ :return:
45
+ """
46
+ raise NotImplementedError
47
+
48
+ @property
49
+ @abstractmethod
50
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
51
+ """
52
+ Map model invoke error to unified error
53
+ The key is the error type thrown to the caller
54
+ The value is the error type thrown by the model,
55
+ which needs to be converted into a unified error type for the caller.
56
+
57
+ :return: Invoke error mapping
58
+ """
59
+ raise NotImplementedError
60
+
61
+ def _transform_invoke_error(self, error: Exception) -> InvokeError:
62
+ """
63
+ Transform invoke error to unified error
64
+
65
+ :param error: model invoke error
66
+ :return: unified error
67
+ """
68
+ provider_name = self.__class__.__module__.split(".")[-3]
69
+
70
+ for invoke_error, model_errors in self._invoke_error_mapping.items():
71
+ if isinstance(error, tuple(model_errors)):
72
+ if invoke_error == InvokeAuthorizationError:
73
+ return invoke_error(
74
+ description=(
75
+ f"[{provider_name}] Incorrect model credentials provided, please check and try again."
76
+ )
77
+ )
78
+
79
+ return invoke_error(description=f"[{provider_name}] {invoke_error.description}, {str(error)}")
80
+
81
+ return InvokeError(description=f"[{provider_name}] Error: {str(error)}")
82
+
83
+ def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo:
84
+ """
85
+ Get price for given model and tokens
86
+
87
+ :param model: model name
88
+ :param credentials: model credentials
89
+ :param price_type: price type
90
+ :param tokens: number of tokens
91
+ :return: price info
92
+ """
93
+ # get model schema
94
+ model_schema = self.get_model_schema(model, credentials)
95
+
96
+ # get price info from predefined model schema
97
+ price_config: Optional[PriceConfig] = None
98
+ if model_schema and model_schema.pricing:
99
+ price_config = model_schema.pricing
100
+
101
+ # get unit price
102
+ unit_price = None
103
+ if price_config:
104
+ if price_type == PriceType.INPUT:
105
+ unit_price = price_config.input
106
+ elif price_type == PriceType.OUTPUT and price_config.output is not None:
107
+ unit_price = price_config.output
108
+
109
+ if unit_price is None:
110
+ return PriceInfo(
111
+ unit_price=decimal.Decimal("0.0"),
112
+ unit=decimal.Decimal("0.0"),
113
+ total_amount=decimal.Decimal("0.0"),
114
+ currency="USD",
115
+ )
116
+
117
+ # calculate total amount
118
+ if not price_config:
119
+ raise ValueError(f"Price config not found for model {model}")
120
+ total_amount = tokens * unit_price * price_config.unit
121
+ total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
122
+
123
+ return PriceInfo(
124
+ unit_price=unit_price,
125
+ unit=price_config.unit,
126
+ total_amount=total_amount,
127
+ currency=price_config.currency,
128
+ )
129
+
130
+ def predefined_models(self) -> list[AIModelEntity]:
131
+ """
132
+ Get all predefined models for given provider.
133
+
134
+ :return:
135
+ """
136
+ if self.model_schemas:
137
+ return self.model_schemas
138
+
139
+ model_schemas = []
140
+
141
+ # get module name
142
+ model_type = self.__class__.__module__.split(".")[-1]
143
+
144
+ # get provider name
145
+ provider_name = self.__class__.__module__.split(".")[-3]
146
+
147
+ # get the path of current classes
148
+ current_path = os.path.abspath(__file__)
149
+ # get parent path of the current path
150
+ provider_model_type_path = os.path.join(
151
+ os.path.dirname(os.path.dirname(current_path)), provider_name, model_type
152
+ )
153
+
154
+ # get all yaml files path under provider_model_type_path that do not start with __
155
+ model_schema_yaml_paths = [
156
+ os.path.join(provider_model_type_path, model_schema_yaml)
157
+ for model_schema_yaml in os.listdir(provider_model_type_path)
158
+ if not model_schema_yaml.startswith("__")
159
+ and not model_schema_yaml.startswith("_")
160
+ and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
161
+ and model_schema_yaml.endswith(".yaml")
162
+ ]
163
+
164
+ # get _position.yaml file path
165
+ position_map = get_position_map(provider_model_type_path)
166
+
167
+ # traverse all model_schema_yaml_paths
168
+ for model_schema_yaml_path in model_schema_yaml_paths:
169
+ # read yaml data from yaml file
170
+ yaml_data = load_yaml_file(model_schema_yaml_path)
171
+
172
+ new_parameter_rules = []
173
+ for parameter_rule in yaml_data.get("parameter_rules", []):
174
+ if "use_template" in parameter_rule:
175
+ try:
176
+ default_parameter_name = DefaultParameterName.value_of(parameter_rule["use_template"])
177
+ default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
178
+ copy_default_parameter_rule = default_parameter_rule.copy()
179
+ copy_default_parameter_rule.update(parameter_rule)
180
+ parameter_rule = copy_default_parameter_rule
181
+ except ValueError:
182
+ pass
183
+
184
+ if "label" not in parameter_rule:
185
+ parameter_rule["label"] = {"zh_Hans": parameter_rule["name"], "en_US": parameter_rule["name"]}
186
+
187
+ new_parameter_rules.append(parameter_rule)
188
+
189
+ yaml_data["parameter_rules"] = new_parameter_rules
190
+
191
+ if "label" not in yaml_data:
192
+ yaml_data["label"] = {"zh_Hans": yaml_data["model"], "en_US": yaml_data["model"]}
193
+
194
+ yaml_data["fetch_from"] = FetchFrom.PREDEFINED_MODEL.value
195
+
196
+ try:
197
+ # yaml_data to entity
198
+ model_schema = AIModelEntity(**yaml_data)
199
+ except Exception as e:
200
+ model_schema_yaml_file_name = os.path.basename(model_schema_yaml_path).removesuffix(".yaml")
201
+ raise Exception(
202
+ f"Invalid model schema for {provider_name}.{model_type}.{model_schema_yaml_file_name}: {str(e)}"
203
+ )
204
+
205
+ # cache model schema
206
+ model_schemas.append(model_schema)
207
+
208
+ # resort model schemas by position
209
+ model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)
210
+
211
+ # cache model schemas
212
+ self.model_schemas = model_schemas
213
+
214
+ return model_schemas
215
+
216
+ def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
217
+ """
218
+ Get model schema by model name and credentials
219
+
220
+ :param model: model name
221
+ :param credentials: model credentials
222
+ :return: model schema
223
+ """
224
+ # Try to get model schema from predefined models
225
+ for predefined_model in self.predefined_models():
226
+ if model == predefined_model.model:
227
+ return predefined_model
228
+
229
+ # Try to get model schema from credentials
230
+ if credentials:
231
+ model_schema = self.get_customizable_model_schema_from_credentials(model, credentials)
232
+ if model_schema:
233
+ return model_schema
234
+
235
+ return None
236
+
237
+ def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
238
+ """
239
+ Get customizable model schema from credentials
240
+
241
+ :param model: model name
242
+ :param credentials: model credentials
243
+ :return: model schema
244
+ """
245
+ return self._get_customizable_model_schema(model, credentials)
246
+
247
+ def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
248
+ """
249
+ Get customizable model schema and fill in the template
250
+ """
251
+ schema = self.get_customizable_model_schema(model, credentials)
252
+
253
+ if not schema:
254
+ return None
255
+
256
+ # fill in the template
257
+ new_parameter_rules = []
258
+ for parameter_rule in schema.parameter_rules:
259
+ if parameter_rule.use_template:
260
+ try:
261
+ default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
262
+ default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
263
+ if not parameter_rule.max and "max" in default_parameter_rule:
264
+ parameter_rule.max = default_parameter_rule["max"]
265
+ if not parameter_rule.min and "min" in default_parameter_rule:
266
+ parameter_rule.min = default_parameter_rule["min"]
267
+ if not parameter_rule.default and "default" in default_parameter_rule:
268
+ parameter_rule.default = default_parameter_rule["default"]
269
+ if not parameter_rule.precision and "precision" in default_parameter_rule:
270
+ parameter_rule.precision = default_parameter_rule["precision"]
271
+ if not parameter_rule.required and "required" in default_parameter_rule:
272
+ parameter_rule.required = default_parameter_rule["required"]
273
+ if not parameter_rule.help and "help" in default_parameter_rule:
274
+ parameter_rule.help = I18nObject(
275
+ en_US=default_parameter_rule["help"]["en_US"],
276
+ )
277
+ if (
278
+ parameter_rule.help
279
+ and not parameter_rule.help.en_US
280
+ and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
281
+ ):
282
+ parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
283
+ if (
284
+ parameter_rule.help
285
+ and not parameter_rule.help.zh_Hans
286
+ and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
287
+ ):
288
+ parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
289
+ "zh_Hans", default_parameter_rule["help"]["en_US"]
290
+ )
291
+ except ValueError:
292
+ pass
293
+
294
+ new_parameter_rules.append(parameter_rule)
295
+
296
+ schema.parameter_rules = new_parameter_rules
297
+
298
+ return schema
299
+
300
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
301
+ """
302
+ Get customizable model schema
303
+
304
+ :param model: model name
305
+ :param credentials: model credentials
306
+ :return: model schema
307
+ """
308
+ return None
309
+
310
+ def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
311
+ """
312
+ Get default parameter rule for given name
313
+
314
+ :param name: parameter name
315
+ :return: parameter rule
316
+ """
317
+ default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)
318
+
319
+ if not default_parameter_rule:
320
+ raise Exception(f"Invalid model parameter rule name {name}")
321
+
322
+ return default_parameter_rule
323
+
324
+ def _get_num_tokens_by_gpt2(self, text: str) -> int:
325
+ """
326
+ Get number of tokens for given prompt messages by gpt2
327
+ Some provider models do not provide an interface for obtaining the number of tokens.
328
+ Here, the gpt2 tokenizer is used to calculate the number of tokens.
329
+ This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
330
+
331
+ :param text: plain text of prompt. You need to convert the original message to plain text
332
+ :return: number of tokens
333
+ """
334
+ return GPT2Tokenizer.get_num_tokens(text)
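The pricing arithmetic in `AIModel.get_price` above (total = tokens × unit_price × unit, rounded half-up to 7 decimal places) can be reproduced standalone. The numbers below are made-up example values, not real model prices.

```python
import decimal

# Example values only: 1200 tokens at a unit price of 3.00 per 1,000,000 tokens
tokens = 1200
unit_price = decimal.Decimal("3.00")
unit = decimal.Decimal("0.000001")  # PriceConfig.unit: scales unit_price down to a per-token price

# Same arithmetic and rounding as AIModel.get_price
total_amount = tokens * unit_price * unit
total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP)
print(total_amount)  # 0.0036000
```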
api/core/model_runtime/model_providers/__base/audio.mp3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29b714073410fefc10ecb80526b5c7c33df73b0830ff0e7778d5065a6cfcae3e
3
+ size 218880
api/core/model_runtime/model_providers/__base/large_language_model.py ADDED
@@ -0,0 +1,904 @@
1
+ import logging
2
+ import re
3
+ import time
4
+ from abc import abstractmethod
5
+ from collections.abc import Generator, Sequence
6
+ from typing import Optional, Union
7
+
8
+ from pydantic import ConfigDict
9
+
10
+ from configs import dify_config
11
+ from core.model_runtime.callbacks.base_callback import Callback
12
+ from core.model_runtime.callbacks.logging_callback import LoggingCallback
13
+ from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
14
+ from core.model_runtime.entities.message_entities import (
15
+ AssistantPromptMessage,
16
+ PromptMessage,
17
+ PromptMessageContentType,
18
+ PromptMessageTool,
19
+ SystemPromptMessage,
20
+ UserPromptMessage,
21
+ )
22
+ from core.model_runtime.entities.model_entities import (
23
+ ModelPropertyKey,
24
+ ModelType,
25
+ ParameterRule,
26
+ ParameterType,
27
+ PriceType,
28
+ )
29
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ class LargeLanguageModel(AIModel):
35
+ """
36
+ Model class for large language model.
37
+ """
38
+
39
+ model_type: ModelType = ModelType.LLM
40
+
41
+ # pydantic configs
42
+ model_config = ConfigDict(protected_namespaces=())
43
+
44
+ def invoke(
45
+ self,
46
+ model: str,
47
+ credentials: dict,
48
+ prompt_messages: list[PromptMessage],
49
+ model_parameters: Optional[dict] = None,
50
+ tools: Optional[list[PromptMessageTool]] = None,
51
+ stop: Optional[list[str]] = None,
52
+ stream: bool = True,
53
+ user: Optional[str] = None,
54
+ callbacks: Optional[list[Callback]] = None,
55
+ ) -> Union[LLMResult, Generator]:
56
+ """
57
+ Invoke large language model
58
+
59
+ :param model: model name
60
+ :param credentials: model credentials
61
+ :param prompt_messages: prompt messages
62
+ :param model_parameters: model parameters
63
+ :param tools: tools for tool calling
64
+ :param stop: stop words
65
+ :param stream: is stream response
66
+ :param user: unique user id
67
+ :param callbacks: callbacks
68
+ :return: full response or stream response chunk generator result
69
+ """
70
+ # validate and filter model parameters
71
+ if model_parameters is None:
72
+ model_parameters = {}
73
+
74
+ model_parameters = self._validate_and_filter_model_parameters(model, model_parameters, credentials)
75
+
76
+ self.started_at = time.perf_counter()
77
+
78
+ callbacks = callbacks or []
79
+
80
+ if dify_config.DEBUG:
81
+ callbacks.append(LoggingCallback())
82
+
83
+ # trigger before invoke callbacks
84
+ self._trigger_before_invoke_callbacks(
85
+ model=model,
86
+ credentials=credentials,
87
+ prompt_messages=prompt_messages,
88
+ model_parameters=model_parameters,
89
+ tools=tools,
90
+ stop=stop,
91
+ stream=stream,
92
+ user=user,
93
+ callbacks=callbacks,
94
+ )
95
+
96
+ try:
97
+ if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}:
98
+ result = self._code_block_mode_wrapper(
99
+ model=model,
100
+ credentials=credentials,
101
+ prompt_messages=prompt_messages,
102
+ model_parameters=model_parameters,
103
+ tools=tools,
104
+ stop=stop,
105
+ stream=stream,
106
+ user=user,
107
+ callbacks=callbacks,
108
+ )
109
+ else:
110
+ result = self._invoke(
111
+ model=model,
112
+ credentials=credentials,
113
+ prompt_messages=prompt_messages,
114
+ model_parameters=model_parameters,
115
+ tools=tools,
116
+ stop=stop,
117
+ stream=stream,
118
+ user=user,
119
+ )
120
+ except Exception as e:
121
+ self._trigger_invoke_error_callbacks(
122
+ model=model,
123
+ ex=e,
124
+ credentials=credentials,
125
+ prompt_messages=prompt_messages,
126
+ model_parameters=model_parameters,
127
+ tools=tools,
128
+ stop=stop,
129
+ stream=stream,
130
+ user=user,
131
+ callbacks=callbacks,
132
+ )
133
+
134
+ raise self._transform_invoke_error(e)
135
+
136
+ if stream and isinstance(result, Generator):
137
+ return self._invoke_result_generator(
138
+ model=model,
139
+ result=result,
140
+ credentials=credentials,
141
+ prompt_messages=prompt_messages,
142
+ model_parameters=model_parameters,
143
+ tools=tools,
144
+ stop=stop,
145
+ stream=stream,
146
+ user=user,
147
+ callbacks=callbacks,
148
+ )
149
+ elif isinstance(result, LLMResult):
150
+ self._trigger_after_invoke_callbacks(
151
+ model=model,
152
+ result=result,
153
+ credentials=credentials,
154
+ prompt_messages=prompt_messages,
155
+ model_parameters=model_parameters,
156
+ tools=tools,
157
+ stop=stop,
158
+ stream=stream,
159
+ user=user,
160
+ callbacks=callbacks,
161
+ )
162
+
163
+ return result
164
+
165
+ def _code_block_mode_wrapper(
166
+ self,
167
+ model: str,
168
+ credentials: dict,
169
+ prompt_messages: list[PromptMessage],
170
+ model_parameters: dict,
171
+ tools: Optional[list[PromptMessageTool]] = None,
172
+ stop: Optional[Sequence[str]] = None,
173
+ stream: bool = True,
174
+ user: Optional[str] = None,
175
+ callbacks: Optional[list[Callback]] = None,
176
+ ) -> Union[LLMResult, Generator]:
177
+ """
178
+ Code block mode wrapper, ensure the response is a code block with output markdown quote
179
+
180
+ :param model: model name
181
+ :param credentials: model credentials
182
+ :param prompt_messages: prompt messages
183
+ :param model_parameters: model parameters
184
+ :param tools: tools for tool calling
185
+ :param stop: stop words
186
+ :param stream: is stream response
187
+ :param user: unique user id
188
+ :param callbacks: callbacks
189
+ :return: full response or stream response chunk generator result
190
+ """
191
+
192
+ block_prompts = """You should always follow the instructions and output a valid {{block}} object.
193
+ The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
194
+ if you are not sure about the structure.
195
+
196
+ <instructions>
197
+ {{instructions}}
198
+ </instructions>
199
+ """ # noqa: E501
200
+
201
+ code_block = model_parameters.get("response_format", "")
202
+ if not code_block:
203
+ return self._invoke(
204
+ model=model,
205
+ credentials=credentials,
206
+ prompt_messages=prompt_messages,
207
+ model_parameters=model_parameters,
208
+ tools=tools,
209
+ stop=stop,
210
+ stream=stream,
211
+ user=user,
212
+ )
213
+
214
+ model_parameters.pop("response_format")
215
+ stop = list(stop) if stop is not None else []
216
+ stop.extend(["\n```", "```\n"])
217
+ block_prompts = block_prompts.replace("{{block}}", code_block)
218
+
219
+ # check if there is a system message
220
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
221
+ # override the system message
222
+ prompt_messages[0] = SystemPromptMessage(
223
+ content=block_prompts.replace("{{instructions}}", str(prompt_messages[0].content))
224
+ )
225
+ else:
226
+ # insert the system message
227
+ prompt_messages.insert(
228
+ 0,
229
+ SystemPromptMessage(
230
+ content=block_prompts.replace("{{instructions}}", f"Please output a valid {code_block} object.")
231
+ ),
232
+ )
233
+
234
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
235
+ # add ```JSON\n to the last text message
236
+ if isinstance(prompt_messages[-1].content, str):
237
+ prompt_messages[-1].content += f"\n```{code_block}\n"
238
+ elif isinstance(prompt_messages[-1].content, list):
239
+ for i in range(len(prompt_messages[-1].content) - 1, -1, -1):
240
+ if prompt_messages[-1].content[i].type == PromptMessageContentType.TEXT:
241
+ prompt_messages[-1].content[i].data += f"\n```{code_block}\n"
242
+ break
243
+ else:
244
+ # append a user message
245
+ prompt_messages.append(UserPromptMessage(content=f"```{code_block}\n"))
246
+
247
+ response = self._invoke(
248
+ model=model,
249
+ credentials=credentials,
250
+ prompt_messages=prompt_messages,
251
+ model_parameters=model_parameters,
252
+ tools=tools,
253
+ stop=stop,
254
+ stream=stream,
255
+ user=user,
256
+ )
257
+
258
+ if isinstance(response, Generator):
259
+ first_chunk = next(response)
260
+
261
+ def new_generator():
262
+ yield first_chunk
263
+ yield from response
264
+
265
+ if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
266
+ return self._code_block_mode_stream_processor_with_backtick(
267
+ model=model, prompt_messages=prompt_messages, input_generator=new_generator()
268
+ )
269
+ else:
270
+ return self._code_block_mode_stream_processor(
271
+ model=model, prompt_messages=prompt_messages, input_generator=new_generator()
272
+ )
273
+
274
+ return response
275
+
276
+ def _code_block_mode_stream_processor(
277
+ self, model: str, prompt_messages: list[PromptMessage], input_generator: Generator[LLMResultChunk, None, None]
278
+ ) -> Generator[LLMResultChunk, None, None]:
279
+ """
280
+ Code block mode stream processor, ensure the response is a code block with output markdown quote
281
+
282
+ :param model: model name
283
+ :param prompt_messages: prompt messages
284
+ :param input_generator: input generator
285
+ :return: output generator
286
+ """
287
+ state = "normal"
288
+ backtick_count = 0
289
+ for piece in input_generator:
290
+ if piece.delta.message.content:
291
+ content = piece.delta.message.content
292
+ piece.delta.message.content = ""
293
+ yield piece
294
+ content_piece = content
295
+ else:
296
+ yield piece
297
+ continue
298
+ new_piece: str = ""
299
+ for char in content_piece:
300
+ char = str(char)
301
+ if state == "normal":
302
+ if char == "`":
303
+ state = "in_backticks"
304
+ backtick_count = 1
305
+ else:
306
+ new_piece += char
307
+ elif state == "in_backticks":
308
+ if char == "`":
309
+ backtick_count += 1
310
+ if backtick_count == 3:
311
+ state = "skip_content"
312
+ backtick_count = 0
313
+ else:
314
+ new_piece += "`" * backtick_count + char
315
+ state = "normal"
316
+ backtick_count = 0
317
+ elif state == "skip_content":
318
+ if char.isspace():
319
+ state = "normal"
320
+
321
+ if new_piece:
322
+ yield LLMResultChunk(
323
+ model=model,
324
+ prompt_messages=prompt_messages,
325
+ delta=LLMResultChunkDelta(
326
+ index=0,
327
+ message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
328
+ ),
329
+ )
330
+
331
+ def _code_block_mode_stream_processor_with_backtick(
332
+ self, model: str, prompt_messages: list, input_generator: Generator[LLMResultChunk, None, None]
333
+ ) -> Generator[LLMResultChunk, None, None]:
334
+ """
335
+ Code block mode stream processor, ensure the response is a code block with output markdown quote.
336
+ This version skips the language identifier that follows the opening triple backticks.
337
+
338
+ :param model: model name
339
+ :param prompt_messages: prompt messages
340
+ :param input_generator: input generator
341
+ :return: output generator
342
+ """
343
+ state = "search_start"
344
+ backtick_count = 0
345
+
346
+ for piece in input_generator:
347
+ if piece.delta.message.content:
348
+ content = piece.delta.message.content
349
+ # Reset content to ensure we're only processing and yielding the relevant parts
350
+ piece.delta.message.content = ""
351
+ # Yield a piece with cleared content before processing it to maintain the generator structure
352
+ yield piece
353
+ content_piece = content
354
+ else:
355
+ # Yield pieces without content directly
356
+ yield piece
357
+ continue
358
+
359
+ if state == "done":
360
+ continue
361
+
362
+ new_piece: str = ""
363
+ for char in content_piece:
364
+ if state == "search_start":
365
+ if char == "`":
366
+ backtick_count += 1
367
+ if backtick_count == 3:
368
+ state = "skip_language"
369
+ backtick_count = 0
370
+ else:
371
+ backtick_count = 0
372
+ elif state == "skip_language":
373
+ # Skip everything until the first newline, marking the end of the language identifier
374
+ if char == "\n":
375
+ state = "in_code_block"
376
+ elif state == "in_code_block":
377
+ if char == "`":
378
+ backtick_count += 1
379
+ if backtick_count == 3:
380
+ state = "done"
381
+ break
382
+ else:
383
+ if backtick_count > 0:
384
+ # If backticks were counted but we're still collecting content, it was a false start
385
+ new_piece += "`" * backtick_count
386
+ backtick_count = 0
387
+ new_piece += str(char)
388
+
389
+ elif state == "done":
390
+ break
391
+
392
+ if new_piece:
393
+ # Only yield content collected within the code block
394
+ yield LLMResultChunk(
395
+ model=model,
396
+ prompt_messages=prompt_messages,
397
+ delta=LLMResultChunkDelta(
398
+ index=0,
399
+ message=AssistantPromptMessage(content=new_piece, tool_calls=[]),
400
+ ),
401
+ )
402
+
403
+ def _wrap_thinking_by_reasoning_content(self, delta: dict, is_reasoning: bool) -> tuple[str, bool]:
404
+ """
405
+ If the reasoning response is from delta.get("reasoning_content"), we wrap
406
+ it with HTML think tag.
407
+
408
+ :param delta: delta dictionary from LLM streaming response
409
+ :param is_reasoning: is reasoning
410
+ :return: tuple of (processed_content, is_reasoning)
411
+ """
412
+
413
+ content = delta.get("content") or ""
414
+ reasoning_content = delta.get("reasoning_content")
415
+
416
+ if reasoning_content:
417
+ if not is_reasoning:
418
+ content = "<think>\n" + reasoning_content
419
+ is_reasoning = True
420
+ else:
421
+ content = reasoning_content
422
+ elif is_reasoning and content:
423
+ # do not end reasoning when content is empty
424
+ # there may be more reasoning_content later that follows previous reasoning closely
425
+ content = "\n</think>" + content
426
+ is_reasoning = False
427
+ return content, is_reasoning
428
+
429
+ def _invoke_result_generator(
430
+ self,
431
+ model: str,
432
+ result: Generator,
433
+ credentials: dict,
434
+ prompt_messages: list[PromptMessage],
435
+ model_parameters: dict,
436
+ tools: Optional[list[PromptMessageTool]] = None,
437
+ stop: Optional[Sequence[str]] = None,
438
+ stream: bool = True,
439
+ user: Optional[str] = None,
440
+ callbacks: Optional[list[Callback]] = None,
441
+ ) -> Generator:
442
+ """
443
+ Invoke result generator
444
+
445
+ :param result: result generator
446
+ :return: result generator
447
+ """
448
+ callbacks = callbacks or []
449
+ prompt_message = AssistantPromptMessage(content="")
450
+ usage = None
451
+ system_fingerprint = None
452
+ real_model = model
453
+
454
+ try:
455
+ for chunk in result:
456
+ yield chunk
457
+
458
+ self._trigger_new_chunk_callbacks(
459
+ chunk=chunk,
460
+ model=model,
461
+ credentials=credentials,
462
+ prompt_messages=prompt_messages,
463
+ model_parameters=model_parameters,
464
+ tools=tools,
465
+ stop=stop,
466
+ stream=stream,
467
+ user=user,
468
+ callbacks=callbacks,
469
+ )
470
+
471
+ prompt_message.content += chunk.delta.message.content
472
+ real_model = chunk.model
473
+ if chunk.delta.usage:
474
+ usage = chunk.delta.usage
475
+
476
+ if chunk.system_fingerprint:
477
+ system_fingerprint = chunk.system_fingerprint
478
+ except Exception as e:
479
+ raise self._transform_invoke_error(e)
480
+
481
+ self._trigger_after_invoke_callbacks(
482
+ model=model,
483
+ result=LLMResult(
484
+ model=real_model,
485
+ prompt_messages=prompt_messages,
486
+ message=prompt_message,
487
+ usage=usage or LLMUsage.empty_usage(),
488
+ system_fingerprint=system_fingerprint,
489
+ ),
490
+ credentials=credentials,
491
+ prompt_messages=prompt_messages,
492
+ model_parameters=model_parameters,
493
+ tools=tools,
494
+ stop=stop,
495
+ stream=stream,
496
+ user=user,
497
+ callbacks=callbacks,
498
+ )
499
+
500
+ @abstractmethod
501
+ def _invoke(
502
+ self,
503
+ model: str,
504
+ credentials: dict,
505
+ prompt_messages: list[PromptMessage],
506
+ model_parameters: dict,
507
+ tools: Optional[list[PromptMessageTool]] = None,
508
+ stop: Optional[Sequence[str]] = None,
509
+ stream: bool = True,
510
+ user: Optional[str] = None,
511
+ ) -> Union[LLMResult, Generator]:
512
+ """
513
+ Invoke large language model
514
+
515
+ :param model: model name
516
+ :param credentials: model credentials
517
+ :param prompt_messages: prompt messages
518
+ :param model_parameters: model parameters
519
+ :param tools: tools for tool calling
520
+ :param stop: stop words
521
+ :param stream: is stream response
522
+ :param user: unique user id
523
+ :return: full response or stream response chunk generator result
524
+ """
525
+ raise NotImplementedError
526
+
527
+ @abstractmethod
528
+ def get_num_tokens(
529
+ self,
530
+ model: str,
531
+ credentials: dict,
532
+ prompt_messages: list[PromptMessage],
533
+ tools: Optional[list[PromptMessageTool]] = None,
534
+ ) -> int:
535
+ """
536
+ Get number of tokens for given prompt messages
537
+
538
+ :param model: model name
539
+ :param credentials: model credentials
540
+ :param prompt_messages: prompt messages
541
+ :param tools: tools for tool calling
542
+ :return:
543
+ """
544
+ raise NotImplementedError
545
+
546
+ def enforce_stop_tokens(self, text: str, stop: list[str]) -> str:
547
+ """Cut off the text as soon as any stop words occur."""
548
+ return re.split("|".join(stop), text, maxsplit=1)[0]
549
+
550
+ def get_parameter_rules(self, model: str, credentials: dict) -> list[ParameterRule]:
551
+ """
552
+ Get parameter rules
553
+
554
+ :param model: model name
555
+ :param credentials: model credentials
556
+ :return: parameter rules
557
+ """
558
+ model_schema = self.get_model_schema(model, credentials)
559
+ if model_schema:
560
+ return model_schema.parameter_rules
561
+
562
+ return []
563
+
564
+ def get_model_mode(self, model: str, credentials: Optional[dict] = None) -> LLMMode:
565
+ """
566
+ Get model mode
567
+
568
+ :param model: model name
569
+ :param credentials: model credentials
570
+ :return: model mode
571
+ """
572
+ model_schema = self.get_model_schema(model, credentials)
573
+
574
+ mode = LLMMode.CHAT
575
+ if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
576
+ mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE])
577
+
578
+ return mode
579
+
580
+ def _calc_response_usage(
581
+ self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
582
+ ) -> LLMUsage:
583
+ """
584
+ Calculate response usage
585
+
586
+ :param model: model name
587
+ :param credentials: model credentials
588
+ :param prompt_tokens: prompt tokens
589
+ :param completion_tokens: completion tokens
590
+ :return: usage
591
+ """
592
+ # get prompt price info
593
+ prompt_price_info = self.get_price(
594
+ model=model,
595
+ credentials=credentials,
596
+ price_type=PriceType.INPUT,
597
+ tokens=prompt_tokens,
598
+ )
599
+
600
+ # get completion price info
601
+ completion_price_info = self.get_price(
602
+ model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens
603
+ )
604
+
605
+ # transform usage
606
+ usage = LLMUsage(
607
+ prompt_tokens=prompt_tokens,
608
+ prompt_unit_price=prompt_price_info.unit_price,
609
+ prompt_price_unit=prompt_price_info.unit,
610
+ prompt_price=prompt_price_info.total_amount,
611
+ completion_tokens=completion_tokens,
612
+ completion_unit_price=completion_price_info.unit_price,
613
+ completion_price_unit=completion_price_info.unit,
614
+ completion_price=completion_price_info.total_amount,
615
+ total_tokens=prompt_tokens + completion_tokens,
616
+ total_price=prompt_price_info.total_amount + completion_price_info.total_amount,
617
+ currency=prompt_price_info.currency,
618
+ latency=time.perf_counter() - self.started_at,
619
+ )
620
+
621
+ return usage
622
+
623
+ def _trigger_before_invoke_callbacks(
624
+ self,
625
+ model: str,
626
+ credentials: dict,
627
+ prompt_messages: list[PromptMessage],
628
+ model_parameters: dict,
629
+ tools: Optional[list[PromptMessageTool]] = None,
630
+ stop: Optional[Sequence[str]] = None,
631
+ stream: bool = True,
632
+ user: Optional[str] = None,
633
+ callbacks: Optional[list[Callback]] = None,
634
+ ) -> None:
635
+ """
636
+ Trigger before invoke callbacks
637
+
638
+ :param model: model name
639
+ :param credentials: model credentials
640
+ :param prompt_messages: prompt messages
641
+ :param model_parameters: model parameters
642
+ :param tools: tools for tool calling
643
+ :param stop: stop words
644
+ :param stream: is stream response
645
+ :param user: unique user id
646
+ :param callbacks: callbacks
647
+ """
648
+ if callbacks:
649
+ for callback in callbacks:
650
+ try:
651
+ callback.on_before_invoke(
652
+ llm_instance=self,
653
+ model=model,
654
+ credentials=credentials,
655
+ prompt_messages=prompt_messages,
656
+ model_parameters=model_parameters,
657
+ tools=tools,
658
+ stop=stop,
659
+ stream=stream,
660
+ user=user,
661
+ )
662
+ except Exception as e:
663
+ if callback.raise_error:
664
+ raise e
665
+ else:
666
+ logger.warning(f"Callback {callback.__class__.__name__} on_before_invoke failed with error {e}")
667
+
668
+ def _trigger_new_chunk_callbacks(
669
+ self,
670
+ chunk: LLMResultChunk,
671
+ model: str,
672
+ credentials: dict,
673
+ prompt_messages: list[PromptMessage],
674
+ model_parameters: dict,
675
+ tools: Optional[list[PromptMessageTool]] = None,
676
+ stop: Optional[Sequence[str]] = None,
677
+ stream: bool = True,
678
+ user: Optional[str] = None,
679
+ callbacks: Optional[list[Callback]] = None,
680
+ ) -> None:
681
+ """
682
+ Trigger new chunk callbacks
683
+
684
+ :param chunk: chunk
685
+ :param model: model name
686
+ :param credentials: model credentials
687
+ :param prompt_messages: prompt messages
688
+ :param model_parameters: model parameters
689
+ :param tools: tools for tool calling
690
+ :param stop: stop words
691
+ :param stream: is stream response
692
+ :param user: unique user id
693
+ """
694
+ if callbacks:
695
+ for callback in callbacks:
696
+ try:
697
+ callback.on_new_chunk(
698
+ llm_instance=self,
699
+ chunk=chunk,
700
+ model=model,
701
+ credentials=credentials,
702
+ prompt_messages=prompt_messages,
703
+ model_parameters=model_parameters,
704
+ tools=tools,
705
+ stop=stop,
706
+ stream=stream,
707
+ user=user,
708
+ )
709
+ except Exception as e:
710
+ if callback.raise_error:
711
+ raise e
712
+ else:
713
+ logger.warning(f"Callback {callback.__class__.__name__} on_new_chunk failed with error {e}")
714
+
715
+ def _trigger_after_invoke_callbacks(
716
+ self,
717
+ model: str,
718
+ result: LLMResult,
719
+ credentials: dict,
720
+ prompt_messages: list[PromptMessage],
721
+ model_parameters: dict,
722
+ tools: Optional[list[PromptMessageTool]] = None,
723
+ stop: Optional[Sequence[str]] = None,
724
+ stream: bool = True,
725
+ user: Optional[str] = None,
726
+ callbacks: Optional[list[Callback]] = None,
727
+ ) -> None:
728
+ """
729
+ Trigger after invoke callbacks
730
+
731
+ :param model: model name
732
+ :param result: result
733
+ :param credentials: model credentials
734
+ :param prompt_messages: prompt messages
735
+ :param model_parameters: model parameters
736
+ :param tools: tools for tool calling
737
+ :param stop: stop words
738
+ :param stream: is stream response
739
+ :param user: unique user id
740
+ :param callbacks: callbacks
741
+ """
742
+ if callbacks:
743
+ for callback in callbacks:
744
+ try:
745
+ callback.on_after_invoke(
746
+ llm_instance=self,
747
+ result=result,
748
+ model=model,
749
+ credentials=credentials,
750
+ prompt_messages=prompt_messages,
751
+ model_parameters=model_parameters,
752
+ tools=tools,
753
+ stop=stop,
754
+ stream=stream,
755
+ user=user,
756
+ )
757
+ except Exception as e:
758
+ if callback.raise_error:
759
+ raise e
760
+ else:
761
+ logger.warning(f"Callback {callback.__class__.__name__} on_after_invoke failed with error {e}")
762
+
763
+ def _trigger_invoke_error_callbacks(
764
+ self,
765
+ model: str,
766
+ ex: Exception,
767
+ credentials: dict,
768
+ prompt_messages: list[PromptMessage],
769
+ model_parameters: dict,
770
+ tools: Optional[list[PromptMessageTool]] = None,
771
+ stop: Optional[Sequence[str]] = None,
772
+ stream: bool = True,
773
+ user: Optional[str] = None,
774
+ callbacks: Optional[list[Callback]] = None,
775
+ ) -> None:
776
+ """
777
+ Trigger invoke error callbacks
778
+
779
+ :param model: model name
780
+ :param ex: exception
781
+ :param credentials: model credentials
782
+ :param prompt_messages: prompt messages
783
+ :param model_parameters: model parameters
784
+ :param tools: tools for tool calling
785
+ :param stop: stop words
786
+ :param stream: is stream response
787
+ :param user: unique user id
788
+ :param callbacks: callbacks
789
+ """
790
+ if callbacks:
791
+ for callback in callbacks:
792
+ try:
793
+ callback.on_invoke_error(
794
+ llm_instance=self,
795
+ ex=ex,
796
+ model=model,
797
+ credentials=credentials,
798
+ prompt_messages=prompt_messages,
799
+ model_parameters=model_parameters,
800
+ tools=tools,
801
+ stop=stop,
802
+ stream=stream,
803
+ user=user,
804
+ )
805
+ except Exception as e:
806
+ if callback.raise_error:
807
+ raise e
808
+ else:
809
+ logger.warning(f"Callback {callback.__class__.__name__} on_invoke_error failed with error {e}")
810
+
811
+ def _validate_and_filter_model_parameters(self, model: str, model_parameters: dict, credentials: dict) -> dict:
812
+ """
813
+ Validate model parameters
814
+
815
+ :param model: model name
816
+ :param model_parameters: model parameters
817
+ :param credentials: model credentials
818
+ :return:
819
+ """
820
+ parameter_rules = self.get_parameter_rules(model, credentials)
821
+
822
+ # validate model parameters
823
+ filtered_model_parameters = {}
824
+ for parameter_rule in parameter_rules:
825
+ parameter_name = parameter_rule.name
826
+ parameter_value = model_parameters.get(parameter_name)
827
+ if parameter_value is None:
828
+ if parameter_rule.use_template and parameter_rule.use_template in model_parameters:
829
+ # if parameter value is None, use template value variable name instead
830
+ parameter_value = model_parameters[parameter_rule.use_template]
831
+ else:
832
+ if parameter_rule.required:
833
+ if parameter_rule.default is not None:
834
+ filtered_model_parameters[parameter_name] = parameter_rule.default
835
+ continue
836
+ else:
837
+ raise ValueError(f"Model Parameter {parameter_name} is required.")
838
+ else:
839
+ continue
840
+
841
+ # validate parameter value type
842
+ if parameter_rule.type == ParameterType.INT:
843
+ if not isinstance(parameter_value, int):
844
+ raise ValueError(f"Model Parameter {parameter_name} should be int.")
845
+
846
+ # validate parameter value range
847
+ if parameter_rule.min is not None and parameter_value < parameter_rule.min:
848
+ raise ValueError(
849
+ f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
850
+ )
851
+
852
+ if parameter_rule.max is not None and parameter_value > parameter_rule.max:
853
+ raise ValueError(
854
+ f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
855
+ )
856
+ elif parameter_rule.type == ParameterType.FLOAT:
857
+ if not isinstance(parameter_value, float | int):
858
+ raise ValueError(f"Model Parameter {parameter_name} should be float.")
859
+
860
+ # validate parameter value precision
861
+ if parameter_rule.precision is not None:
862
+ if parameter_rule.precision == 0:
863
+ if parameter_value != int(parameter_value):
864
+ raise ValueError(f"Model Parameter {parameter_name} should be int.")
865
+ else:
866
+ if parameter_value != round(parameter_value, parameter_rule.precision):
867
+ raise ValueError(
868
+ f"Model Parameter {parameter_name} should be round to {parameter_rule.precision}"
869
+ f" decimal places."
870
+ )
871
+
872
+ # validate parameter value range
873
+ if parameter_rule.min is not None and parameter_value < parameter_rule.min:
874
+ raise ValueError(
875
+ f"Model Parameter {parameter_name} should be greater than or equal to {parameter_rule.min}."
876
+ )
877
+
878
+ if parameter_rule.max is not None and parameter_value > parameter_rule.max:
879
+ raise ValueError(
880
+ f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}."
881
+ )
882
+ elif parameter_rule.type == ParameterType.BOOLEAN:
883
+ if not isinstance(parameter_value, bool):
884
+ raise ValueError(f"Model Parameter {parameter_name} should be bool.")
885
+ elif parameter_rule.type == ParameterType.STRING:
886
+ if not isinstance(parameter_value, str):
887
+ raise ValueError(f"Model Parameter {parameter_name} should be string.")
888
+
889
+ # validate options
890
+ if parameter_rule.options and parameter_value not in parameter_rule.options:
891
+ raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
892
+ elif parameter_rule.type == ParameterType.TEXT:
893
+ if not isinstance(parameter_value, str):
894
+ raise ValueError(f"Model Parameter {parameter_name} should be text.")
895
+
896
+ # validate options
897
+ if parameter_rule.options and parameter_value not in parameter_rule.options:
898
+ raise ValueError(f"Model Parameter {parameter_name} should be one of {parameter_rule.options}.")
899
+ else:
900
+ raise ValueError(f"Model Parameter {parameter_name} type {parameter_rule.type} is not supported.")
901
+
902
+ filtered_model_parameters[parameter_name] = parameter_value
903
+
904
+ return filtered_model_parameters
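Illustrative sketch (not part of this diff) of how the filtering above behaves for a single float rule; the rule values are assumptions chosen for the example, not taken from any provider YAML, and the helper mirrors the precision, range and default handling of _validate_and_filter_model_parameters.

# Hypothetical rule shaped like a ParameterRule for "temperature".
rule = {"name": "temperature", "type": "float", "required": True,
        "default": 0.7, "min": 0.0, "max": 1.0, "precision": 2}

def filter_float(rule: dict, params: dict) -> dict:
    value = params.get(rule["name"])
    if value is None:
        # required + default -> fall back to the default value
        return {rule["name"]: rule["default"]} if rule["required"] else {}
    if not isinstance(value, (float, int)):
        raise ValueError(f"{rule['name']} should be float.")
    if value != round(value, rule["precision"]):
        raise ValueError(f"{rule['name']} should be round to {rule['precision']} decimal places.")
    if not (rule["min"] <= value <= rule["max"]):
        raise ValueError(f"{rule['name']} is out of range.")
    return {rule["name"]: value}

print(filter_float(rule, {}))                     # {'temperature': 0.7}
print(filter_float(rule, {"temperature": 0.25}))  # {'temperature': 0.25}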
api/core/model_runtime/model_providers/__base/model_provider.py ADDED
@@ -0,0 +1,121 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import Optional
4
+
5
+ from core.helper.module_import_helper import get_subclasses_from_module, import_module_from_source
6
+ from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
7
+ from core.model_runtime.entities.provider_entities import ProviderEntity
8
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
9
+ from core.tools.utils.yaml_utils import load_yaml_file
10
+
11
+
12
+ class ModelProvider(ABC):
13
+ provider_schema: Optional[ProviderEntity] = None
14
+ model_instance_map: dict[str, AIModel] = {}
15
+
16
+ @abstractmethod
17
+ def validate_provider_credentials(self, credentials: dict) -> None:
18
+ """
19
+ Validate provider credentials
20
+ You can either reuse the validate_credentials method of any model type or implement your own
21
+ validation, for example by calling the provider's model list API.
22
+
23
+ If validation fails, raise an exception.
24
+
25
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
26
+ """
27
+ raise NotImplementedError
28
+
29
+ def get_provider_schema(self) -> ProviderEntity:
30
+ """
31
+ Get provider schema
32
+
33
+ :return: provider schema
34
+ """
35
+ if self.provider_schema:
36
+ return self.provider_schema
37
+
38
+ # derive the provider name from the module path
39
+ provider_name = self.__class__.__module__.split(".")[-1]
40
+
41
+ # get the path of the model_provider classes
42
+ base_path = os.path.abspath(__file__)
43
+ current_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name)
44
+
45
+ # read provider schema from yaml file
46
+ yaml_path = os.path.join(current_path, f"{provider_name}.yaml")
47
+ yaml_data = load_yaml_file(yaml_path)
48
+
49
+ try:
50
+ # yaml_data to entity
51
+ provider_schema = ProviderEntity(**yaml_data)
52
+ except Exception as e:
53
+ raise Exception(f"Invalid provider schema for {provider_name}: {str(e)}")
54
+
55
+ # cache schema
56
+ self.provider_schema = provider_schema
57
+
58
+ return provider_schema
59
+
60
+ def models(self, model_type: ModelType) -> list[AIModelEntity]:
61
+ """
62
+ Get all models for given model type
63
+
64
+ :param model_type: model type defined in `ModelType`
65
+ :return: list of models
66
+ """
67
+ provider_schema = self.get_provider_schema()
68
+ if model_type not in provider_schema.supported_model_types:
69
+ return []
70
+
71
+ # get model instance of the model type
72
+ model_instance = self.get_model_instance(model_type)
73
+
74
+ # get predefined models (predefined_models)
75
+ models = model_instance.predefined_models()
76
+
77
+ # return models
78
+ return models
79
+
80
+ def get_model_instance(self, model_type: ModelType) -> AIModel:
81
+ """
82
+ Get model instance
83
+
84
+ :param model_type: model type defined in `ModelType`
85
+ :return:
86
+ """
87
+ # derive the provider name from the module path
88
+ provider_name = self.__class__.__module__.split(".")[-1]
89
+
90
+ if f"{provider_name}.{model_type.value}" in self.model_instance_map:
91
+ return self.model_instance_map[f"{provider_name}.{model_type.value}"]
92
+
93
+ # get the path of the model type classes
94
+ base_path = os.path.abspath(__file__)
95
+ model_type_name = model_type.value.replace("-", "_")
96
+ model_type_path = os.path.join(os.path.dirname(os.path.dirname(base_path)), provider_name, model_type_name)
97
+ model_type_py_path = os.path.join(model_type_path, f"{model_type_name}.py")
98
+
99
+ if not os.path.isdir(model_type_path) or not os.path.exists(model_type_py_path):
100
+ raise Exception(f"Invalid model type {model_type} for provider {provider_name}")
101
+
102
+ # Dynamic loading {model_type_name}.py file and find the subclass of AIModel
103
+ parent_module = ".".join(self.__class__.__module__.split(".")[:-1])
104
+ mod = import_module_from_source(
105
+ module_name=f"{parent_module}.{model_type_name}.{model_type_name}", py_file_path=model_type_py_path
106
+ )
107
+ # FIXME "type" has no attribute "__abstractmethods__" ignore it for now fix it later
108
+ model_class = next(
109
+ filter(
110
+ lambda x: x.__module__ == mod.__name__ and not x.__abstractmethods__, # type: ignore
111
+ get_subclasses_from_module(mod, AIModel),
112
+ ),
113
+ None,
114
+ )
115
+ if not model_class:
116
+ raise Exception(f"Missing AIModel Class for model type {model_type} in {model_type_py_path}")
117
+
118
+ model_instance = model_class()
119
+ self.model_instance_map[f"{provider_name}.{model_type.value}"] = model_instance
120
+
121
+ return model_instance
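A minimal usage sketch (not part of this diff); it assumes the Dify api package is importable and the provider YAML files are on disk, and it uses the AnthropicProvider defined later in this commit as the concrete subclass.

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider

provider = AnthropicProvider()
schema = provider.get_provider_schema()           # parsed from anthropic/anthropic.yaml
llm = provider.get_model_instance(ModelType.LLM)  # dynamically loads anthropic/llm/llm.py
llm_again = provider.get_model_instance(ModelType.LLM)
assert llm is llm_again                           # instances are cached in model_instance_map
predefined = provider.models(ModelType.LLM)       # AIModelEntity list built from the llm/*.yaml files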
api/core/model_runtime/model_providers/__base/moderation_model.py ADDED
@@ -0,0 +1,49 @@
1
+ import time
2
+ from abc import abstractmethod
3
+ from typing import Optional
4
+
5
+ from pydantic import ConfigDict
6
+
7
+ from core.model_runtime.entities.model_entities import ModelType
8
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
9
+
10
+
11
+ class ModerationModel(AIModel):
12
+ """
13
+ Model class for moderation model.
14
+ """
15
+
16
+ model_type: ModelType = ModelType.MODERATION
17
+
18
+ # pydantic configs
19
+ model_config = ConfigDict(protected_namespaces=())
20
+
21
+ def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
22
+ """
23
+ Invoke moderation model
24
+
25
+ :param model: model name
26
+ :param credentials: model credentials
27
+ :param text: text to moderate
28
+ :param user: unique user id
29
+ :return: false if text is safe, true otherwise
30
+ """
31
+ self.started_at = time.perf_counter()
32
+
33
+ try:
34
+ return self._invoke(model, credentials, text, user)
35
+ except Exception as e:
36
+ raise self._transform_invoke_error(e)
37
+
38
+ @abstractmethod
39
+ def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
40
+ """
41
+ Invoke moderation model
42
+
43
+ :param model: model name
44
+ :param credentials: model credentials
45
+ :param text: text to moderate
46
+ :param user: unique user id
47
+ :return: false if text is safe, true otherwise
48
+ """
49
+ raise NotImplementedError
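A minimal sketch (not part of this diff) of the subclassing pattern this base class expects; the keyword-matching "moderation" below is purely illustrative, not a real provider implementation.

from typing import Optional
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel

class KeywordModerationModel(ModerationModel):
    def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
        # Return True when the text should be flagged, False when it is safe,
        # matching the contract documented on invoke().
        blocked = credentials.get("blocked_words", [])
        return any(word in text.lower() for word in blocked)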
api/core/model_runtime/model_providers/__base/rerank_model.py ADDED
@@ -0,0 +1,69 @@
1
+ import time
2
+ from abc import abstractmethod
3
+ from typing import Optional
4
+
5
+ from core.model_runtime.entities.model_entities import ModelType
6
+ from core.model_runtime.entities.rerank_entities import RerankResult
7
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
8
+
9
+
10
+ class RerankModel(AIModel):
11
+ """
12
+ Base Model class for rerank model.
13
+ """
14
+
15
+ model_type: ModelType = ModelType.RERANK
16
+
17
+ def invoke(
18
+ self,
19
+ model: str,
20
+ credentials: dict,
21
+ query: str,
22
+ docs: list[str],
23
+ score_threshold: Optional[float] = None,
24
+ top_n: Optional[int] = None,
25
+ user: Optional[str] = None,
26
+ ) -> RerankResult:
27
+ """
28
+ Invoke rerank model
29
+
30
+ :param model: model name
31
+ :param credentials: model credentials
32
+ :param query: search query
33
+ :param docs: docs for reranking
34
+ :param score_threshold: score threshold
35
+ :param top_n: top n
36
+ :param user: unique user id
37
+ :return: rerank result
38
+ """
39
+ self.started_at = time.perf_counter()
40
+
41
+ try:
42
+ return self._invoke(model, credentials, query, docs, score_threshold, top_n, user)
43
+ except Exception as e:
44
+ raise self._transform_invoke_error(e)
45
+
46
+ @abstractmethod
47
+ def _invoke(
48
+ self,
49
+ model: str,
50
+ credentials: dict,
51
+ query: str,
52
+ docs: list[str],
53
+ score_threshold: Optional[float] = None,
54
+ top_n: Optional[int] = None,
55
+ user: Optional[str] = None,
56
+ ) -> RerankResult:
57
+ """
58
+ Invoke rerank model
59
+
60
+ :param model: model name
61
+ :param credentials: model credentials
62
+ :param query: search query
63
+ :param docs: docs for reranking
64
+ :param score_threshold: score threshold
65
+ :param top_n: top n
66
+ :param user: unique user id
67
+ :return: rerank result
68
+ """
69
+ raise NotImplementedError
api/core/model_runtime/model_providers/__base/speech2text_model.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from typing import IO, Optional
4
+
5
+ from pydantic import ConfigDict
6
+
7
+ from core.model_runtime.entities.model_entities import ModelType
8
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
9
+
10
+
11
+ class Speech2TextModel(AIModel):
12
+ """
13
+ Model class for speech2text model.
14
+ """
15
+
16
+ model_type: ModelType = ModelType.SPEECH2TEXT
17
+
18
+ # pydantic configs
19
+ model_config = ConfigDict(protected_namespaces=())
20
+
21
+ def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
22
+ """
23
+ Invoke speech-to-text model
24
+
25
+ :param model: model name
26
+ :param credentials: model credentials
27
+ :param file: audio file
28
+ :param user: unique user id
29
+ :return: text for given audio file
30
+ """
31
+ try:
32
+ return self._invoke(model, credentials, file, user)
33
+ except Exception as e:
34
+ raise self._transform_invoke_error(e)
35
+
36
+ @abstractmethod
37
+ def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
38
+ """
39
+ Invoke speech-to-text model
40
+
41
+ :param model: model name
42
+ :param credentials: model credentials
43
+ :param file: audio file
44
+ :param user: unique user id
45
+ :return: text for given audio file
46
+ """
47
+ raise NotImplementedError
48
+
49
+ def _get_demo_file_path(self) -> str:
50
+ """
51
+ Get demo file for given model
52
+
53
+ :return: demo file
54
+ """
55
+ # Get the directory of the current file
56
+ current_dir = os.path.dirname(os.path.abspath(__file__))
57
+
58
+ # Construct the path to the audio file
59
+ return os.path.join(current_dir, "audio.mp3")
api/core/model_runtime/model_providers/__base/text2img_model.py ADDED
@@ -0,0 +1,54 @@
1
+ from abc import abstractmethod
2
+ from typing import IO, Optional
3
+
4
+ from pydantic import ConfigDict
5
+
6
+ from core.model_runtime.entities.model_entities import ModelType
7
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
8
+
9
+
10
+ class Text2ImageModel(AIModel):
11
+ """
12
+ Model class for text2img model.
13
+ """
14
+
15
+ model_type: ModelType = ModelType.TEXT2IMG
16
+
17
+ # pydantic configs
18
+ model_config = ConfigDict(protected_namespaces=())
19
+
20
+ def invoke(
21
+ self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
22
+ ) -> list[IO[bytes]]:
23
+ """
24
+ Invoke Text2Image model
25
+
26
+ :param model: model name
27
+ :param credentials: model credentials
28
+ :param prompt: prompt for image generation
29
+ :param model_parameters: model parameters
30
+ :param user: unique user id
31
+
32
+ :return: image bytes
33
+ """
34
+ try:
35
+ return self._invoke(model, credentials, prompt, model_parameters, user)
36
+ except Exception as e:
37
+ raise self._transform_invoke_error(e)
38
+
39
+ @abstractmethod
40
+ def _invoke(
41
+ self, model: str, credentials: dict, prompt: str, model_parameters: dict, user: Optional[str] = None
42
+ ) -> list[IO[bytes]]:
43
+ """
44
+ Invoke Text2Image model
45
+
46
+ :param model: model name
47
+ :param credentials: model credentials
48
+ :param prompt: prompt for image generation
49
+ :param model_parameters: model parameters
50
+ :param user: unique user id
51
+
52
+ :return: image bytes
53
+ """
54
+ raise NotImplementedError
api/core/model_runtime/model_providers/__base/text_embedding_model.py ADDED
@@ -0,0 +1,111 @@
1
+ import time
2
+ from abc import abstractmethod
3
+ from typing import Optional
4
+
5
+ from pydantic import ConfigDict
6
+
7
+ from core.entities.embedding_type import EmbeddingInputType
8
+ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
9
+ from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
10
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
11
+
12
+
13
+ class TextEmbeddingModel(AIModel):
14
+ """
15
+ Model class for text embedding model.
16
+ """
17
+
18
+ model_type: ModelType = ModelType.TEXT_EMBEDDING
19
+
20
+ # pydantic configs
21
+ model_config = ConfigDict(protected_namespaces=())
22
+
23
+ def invoke(
24
+ self,
25
+ model: str,
26
+ credentials: dict,
27
+ texts: list[str],
28
+ user: Optional[str] = None,
29
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
30
+ ) -> TextEmbeddingResult:
31
+ """
32
+ Invoke text embedding model
33
+
34
+ :param model: model name
35
+ :param credentials: model credentials
36
+ :param texts: texts to embed
37
+ :param user: unique user id
38
+ :param input_type: input type
39
+ :return: embeddings result
40
+ """
41
+ self.started_at = time.perf_counter()
42
+
43
+ try:
44
+ return self._invoke(model, credentials, texts, user, input_type)
45
+ except Exception as e:
46
+ raise self._transform_invoke_error(e)
47
+
48
+ @abstractmethod
49
+ def _invoke(
50
+ self,
51
+ model: str,
52
+ credentials: dict,
53
+ texts: list[str],
54
+ user: Optional[str] = None,
55
+ input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
56
+ ) -> TextEmbeddingResult:
57
+ """
58
+ Invoke text embedding model
59
+
60
+ :param model: model name
61
+ :param credentials: model credentials
62
+ :param texts: texts to embed
63
+ :param user: unique user id
64
+ :param input_type: input type
65
+ :return: embeddings result
66
+ """
67
+ raise NotImplementedError
68
+
69
+ @abstractmethod
70
+ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
71
+ """
72
+ Get number of tokens for given prompt messages
73
+
74
+ :param model: model name
75
+ :param credentials: model credentials
76
+ :param texts: texts to embed
77
+ :return:
78
+ """
79
+ raise NotImplementedError
80
+
81
+ def _get_context_size(self, model: str, credentials: dict) -> int:
82
+ """
83
+ Get context size for given embedding model
84
+
85
+ :param model: model name
86
+ :param credentials: model credentials
87
+ :return: context size
88
+ """
89
+ model_schema = self.get_model_schema(model, credentials)
90
+
91
+ if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties:
92
+ content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE]
93
+ return content_size
94
+
95
+ return 1000
96
+
97
+ def _get_max_chunks(self, model: str, credentials: dict) -> int:
98
+ """
99
+ Get max chunks for given embedding model
100
+
101
+ :param model: model name
102
+ :param credentials: model credentials
103
+ :return: max chunks
104
+ """
105
+ model_schema = self.get_model_schema(model, credentials)
106
+
107
+ if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
108
+ max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
109
+ return max_chunks
110
+
111
+ return 1
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 1024,
23
+ "pad_token": null,
24
+ "tokenizer_class": "GPT2Tokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
api/core/model_runtime/model_providers/__base/tokenizers/gpt2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py ADDED
@@ -0,0 +1,51 @@
1
+ import logging
2
+ from threading import Lock
3
+ from typing import Any
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ _tokenizer: Any = None
8
+ _lock = Lock()
9
+
10
+
11
+ class GPT2Tokenizer:
12
+ @staticmethod
13
+ def _get_num_tokens_by_gpt2(text: str) -> int:
14
+ """
15
+ use gpt2 tokenizer to get num tokens
16
+ """
17
+ _tokenizer = GPT2Tokenizer.get_encoder()
18
+ tokens = _tokenizer.encode(text)
19
+ return len(tokens)
20
+
21
+ @staticmethod
22
+ def get_num_tokens(text: str) -> int:
23
+ # Because this process needs more cpu resource, we turn this back before we find a better way to handle it.
24
+ #
25
+ # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text)
26
+ # result = future.result()
27
+ # return cast(int, result)
28
+ return GPT2Tokenizer._get_num_tokens_by_gpt2(text)
29
+
30
+ @staticmethod
31
+ def get_encoder() -> Any:
32
+ global _tokenizer, _lock
33
+ with _lock:
34
+ if _tokenizer is None:
35
+ # Try to use tiktoken to get the tokenizer because it is faster
36
+ #
37
+ try:
38
+ import tiktoken
39
+
40
+ _tokenizer = tiktoken.get_encoding("gpt2")
41
+ except Exception:
42
+ from os.path import abspath, dirname, join
43
+
44
+ from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore
45
+
46
+ base_path = abspath(__file__)
47
+ gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
48
+ _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
49
+ logger.info("Falling back from tiktoken to the Transformers GPT-2 tokenizer")
50
+
51
+ return _tokenizer
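A small usage sketch (not part of this diff); it only requires tiktoken or transformers to be installed, since get_encoder() falls back from one to the other.

from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer

n = GPT2Tokenizer.get_num_tokens("Hello, world!")
print(n)  # token count according to the GPT-2 vocabulary, used as a rough estimate across providers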
api/core/model_runtime/model_providers/__base/tts_model.py ADDED
@@ -0,0 +1,179 @@
1
+ import logging
2
+ import re
3
+ from abc import abstractmethod
4
+ from collections.abc import Iterable
5
+ from typing import Any, Optional
6
+
7
+ from pydantic import ConfigDict
8
+
9
+ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
10
+ from core.model_runtime.model_providers.__base.ai_model import AIModel
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class TTSModel(AIModel):
16
+ """
17
+ Model class for TTS model.
18
+ """
19
+
20
+ model_type: ModelType = ModelType.TTS
21
+
22
+ # pydantic configs
23
+ model_config = ConfigDict(protected_namespaces=())
24
+
25
+ def invoke(
26
+ self,
27
+ model: str,
28
+ tenant_id: str,
29
+ credentials: dict,
30
+ content_text: str,
31
+ voice: str,
32
+ user: Optional[str] = None,
33
+ ) -> Iterable[bytes]:
34
+ """
35
+ Invoke text-to-speech model
36
+
37
+ :param model: model name
38
+ :param tenant_id: user tenant id
39
+ :param credentials: model credentials
40
+ :param voice: model timbre
41
+ :param content_text: text to convert to speech
42
+ :param user: unique user id
43
+ :return: generated audio as a byte stream
45
+ """
46
+ try:
47
+ return self._invoke(
48
+ model=model,
49
+ credentials=credentials,
50
+ user=user,
51
+ content_text=content_text,
52
+ voice=voice,
53
+ tenant_id=tenant_id,
54
+ )
55
+ except Exception as e:
56
+ raise self._transform_invoke_error(e)
57
+
58
+ @abstractmethod
59
+ def _invoke(
60
+ self,
61
+ model: str,
62
+ tenant_id: str,
63
+ credentials: dict,
64
+ content_text: str,
65
+ voice: str,
66
+ user: Optional[str] = None,
67
+ ) -> Iterable[bytes]:
68
+ """
69
+ Invoke large language model
70
+
71
+ :param model: model name
72
+ :param tenant_id: user tenant id
73
+ :param credentials: model credentials
74
+ :param voice: model timbre
75
+ :param content_text: text content to be translated
76
+ :param streaming: output is streaming
77
+ :param user: unique user id
78
+ :return: translated audio file
79
+ """
80
+ raise NotImplementedError
81
+
82
+ def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
83
+ """
84
+ Retrieves the list of voices supported by a given text-to-speech (TTS) model.
85
+
86
+ :param language: The language for which the voices are requested.
87
+ :param model: The name of the TTS model.
88
+ :param credentials: The credentials required to access the TTS model.
89
+ :return: A list of voices supported by the TTS model.
90
+ """
91
+ model_schema = self.get_model_schema(model, credentials)
92
+
93
+ if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
94
+ raise ValueError("this model does not support voice")
95
+
96
+ voices = model_schema.model_properties[ModelPropertyKey.VOICES]
97
+ if language:
98
+ return [
99
+ {"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
100
+ ]
101
+ else:
102
+ return [{"name": d["name"], "value": d["mode"]} for d in voices]
103
+
104
+ def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
105
+ """
106
+ Get voice for given tts model
107
+
108
+ :param model: model name
109
+ :param credentials: model credentials
110
+ :return: voice
111
+ """
112
+ model_schema = self.get_model_schema(model, credentials)
113
+
114
+ if model_schema and ModelPropertyKey.DEFAULT_VOICE in model_schema.model_properties:
115
+ return model_schema.model_properties[ModelPropertyKey.DEFAULT_VOICE]
116
+
117
+ def _get_model_audio_type(self, model: str, credentials: dict) -> str:
118
+ """
119
+ Get audio type for given tts model
120
+
121
+ :param model: model name
122
+ :param credentials: model credentials
123
+ :return: voice
124
+ """
125
+ model_schema = self.get_model_schema(model, credentials)
126
+
127
+ if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
128
+ raise ValueError("this model does not support audio type")
129
+
130
+ audio_type: str = model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
131
+ return audio_type
132
+
133
+ def _get_model_word_limit(self, model: str, credentials: dict) -> int:
134
+ """
135
+ Get audio type for given tts model
136
+ :return: audio type
137
+ """
138
+ model_schema = self.get_model_schema(model, credentials)
139
+
140
+ if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
141
+ raise ValueError("this model does not support word limit")
142
+ world_limit: int = model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
143
+
144
+ return world_limit
145
+
146
+ def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
147
+ """
148
+ Get audio max workers for given tts model
149
+ :return: audio type
150
+ """
151
+ model_schema = self.get_model_schema(model, credentials)
152
+
153
+ if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
154
+ raise ValueError("this model does not support max workers")
155
+ workers_limit: int = model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
156
+
157
+ return workers_limit
158
+
159
+ @staticmethod
160
+ def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
161
+ match = re.compile(pattern)
162
+ tx = match.finditer(org_text)
163
+ start = 0
164
+ result = []
165
+ one_sentence = ""
166
+ for i in tx:
167
+ end = i.regs[0][1]
168
+ tmp = org_text[start:end]
169
+ if len(one_sentence + tmp) > max_length:
170
+ result.append(one_sentence)
171
+ one_sentence = ""
172
+ one_sentence += tmp
173
+ start = end
174
+ last_sens = org_text[start:]
175
+ if last_sens:
176
+ one_sentence += last_sens
177
+ if one_sentence != "":
178
+ result.append(one_sentence)
179
+ return result
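An illustrative sketch (not part of this diff) of how the splitter above behaves: text is cut at sentence-ending punctuation and sentences are accumulated until adding the next one would exceed max_length, at which point a new chunk starts. _split_text_into_sentences is a staticmethod, so it can be called on the class directly.

from core.model_runtime.model_providers.__base.tts_model import TTSModel

text = "First sentence. Second sentence! Third sentence?"
chunks = TTSModel._split_text_into_sentences(text, max_length=20)
print(chunks)  # ['First sentence.', ' Second sentence!', ' Third sentence?']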
api/core/model_runtime/model_providers/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
2
+
3
+ model_provider_factory = ModelProviderFactory()
api/core/model_runtime/model_providers/_position.yaml ADDED
@@ -0,0 +1,43 @@
1
+ - openai
2
+ - deepseek
3
+ - anthropic
4
+ - azure_openai
5
+ - google
6
+ - vertex_ai
7
+ - nvidia
8
+ - nvidia_nim
9
+ - cohere
10
+ - upstage
11
+ - bedrock
12
+ - togetherai
13
+ - openrouter
14
+ - ollama
15
+ - mistralai
16
+ - groq
17
+ - replicate
18
+ - huggingface_hub
19
+ - xinference
20
+ - triton_inference_server
21
+ - zhipuai
22
+ - baichuan
23
+ - spark
24
+ - minimax
25
+ - tongyi
26
+ - wenxin
27
+ - moonshot
28
+ - tencent
29
+ - jina
30
+ - chatglm
31
+ - yi
32
+ - openllm
33
+ - localai
34
+ - volcengine_maas
35
+ - openai_api_compatible
36
+ - hunyuan
37
+ - siliconflow
38
+ - perfxcloud
39
+ - zhinao
40
+ - fireworks
41
+ - mixedbread
42
+ - nomic
43
+ - voyage
api/core/model_runtime/model_providers/anthropic/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/anthropic/_assets/icon_l_en.svg ADDED
api/core/model_runtime/model_providers/anthropic/_assets/icon_s_en.svg ADDED
api/core/model_runtime/model_providers/anthropic/anthropic.py ADDED
@@ -0,0 +1,28 @@
1
+ import logging
2
+
3
+ from core.model_runtime.entities.model_entities import ModelType
4
+ from core.model_runtime.errors.validate import CredentialsValidateFailedError
5
+ from core.model_runtime.model_providers.__base.model_provider import ModelProvider
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ class AnthropicProvider(ModelProvider):
11
+ def validate_provider_credentials(self, credentials: dict) -> None:
12
+ """
13
+ Validate provider credentials
14
+
15
+ If validation fails, raise an exception.
16
+
17
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
18
+ """
19
+ try:
20
+ model_instance = self.get_model_instance(ModelType.LLM)
21
+
22
+ # Use the `claude-3-opus-20240229` model for validation.
23
+ model_instance.validate_credentials(model="claude-3-opus-20240229", credentials=credentials)
24
+ except CredentialsValidateFailedError as ex:
25
+ raise ex
26
+ except Exception as ex:
27
+ logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
28
+ raise ex
api/core/model_runtime/model_providers/anthropic/anthropic.yaml ADDED
@@ -0,0 +1,39 @@
1
+ provider: anthropic
2
+ label:
3
+ en_US: Anthropic
4
+ description:
5
+ en_US: Anthropic’s powerful models, such as Claude 3.
6
+ zh_Hans: Anthropic 的强大模型,例如 Claude 3。
7
+ icon_small:
8
+ en_US: icon_s_en.svg
9
+ icon_large:
10
+ en_US: icon_l_en.svg
11
+ background: "#F0F0EB"
12
+ help:
13
+ title:
14
+ en_US: Get your API Key from Anthropic
15
+ zh_Hans: 从 Anthropic 获取 API Key
16
+ url:
17
+ en_US: https://console.anthropic.com/account/keys
18
+ supported_model_types:
19
+ - llm
20
+ configurate_methods:
21
+ - predefined-model
22
+ provider_credential_schema:
23
+ credential_form_schemas:
24
+ - variable: anthropic_api_key
25
+ label:
26
+ en_US: API Key
27
+ type: secret-input
28
+ required: true
29
+ placeholder:
30
+ zh_Hans: 在此输入您的 API Key
31
+ en_US: Enter your API Key
32
+ - variable: anthropic_api_url
33
+ label:
34
+ en_US: API URL
35
+ type: text-input
36
+ required: false
37
+ placeholder:
38
+ zh_Hans: 在此输入您的 API URL
39
+ en_US: Enter your API URL
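A sketch (not part of this diff) of the credentials dict this schema produces and how it is consumed; the key names come from the `variable` fields above, while the values are placeholders.

credentials = {
    "anthropic_api_key": "sk-ant-...",                  # required secret-input
    "anthropic_api_url": "https://api.anthropic.com",   # optional text-input
}

# AnthropicProvider.validate_provider_credentials(credentials) forwards this dict to the LLM
# model class, which maps it to Anthropic client kwargs in _to_credential_kwargs().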
api/core/model_runtime/model_providers/anthropic/llm/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/anthropic/llm/_position.yaml ADDED
@@ -0,0 +1,10 @@
1
+ - claude-3-5-haiku-20241022
2
+ - claude-3-5-sonnet-20241022
3
+ - claude-3-5-sonnet-20240620
4
+ - claude-3-haiku-20240307
5
+ - claude-3-opus-20240229
6
+ - claude-3-sonnet-20240229
7
+ - claude-2.1
8
+ - claude-instant-1.2
9
+ - claude-2
10
+ - claude-instant-1
api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml ADDED
@@ -0,0 +1,36 @@
1
+ model: claude-2.1
2
+ label:
3
+ en_US: claude-2.1
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ model_properties:
8
+ mode: chat
9
+ context_size: 200000
10
+ parameter_rules:
11
+ - name: temperature
12
+ use_template: temperature
13
+ - name: top_p
14
+ use_template: top_p
15
+ - name: top_k
16
+ label:
17
+ zh_Hans: 取样数量
18
+ en_US: Top k
19
+ type: int
20
+ help:
21
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
22
+ en_US: Only sample from the top K options for each subsequent token.
23
+ required: false
24
+ - name: max_tokens_to_sample
25
+ use_template: max_tokens
26
+ required: true
27
+ default: 4096
28
+ min: 1
29
+ max: 4096
30
+ - name: response_format
31
+ use_template: response_format
32
+ pricing:
33
+ input: '8.00'
34
+ output: '24.00'
35
+ unit: '0.000001'
36
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml ADDED
@@ -0,0 +1,37 @@
1
+ model: claude-2
2
+ label:
3
+ en_US: claude-2
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ model_properties:
8
+ mode: chat
9
+ context_size: 100000
10
+ parameter_rules:
11
+ - name: temperature
12
+ use_template: temperature
13
+ - name: top_p
14
+ use_template: top_p
15
+ - name: top_k
16
+ label:
17
+ zh_Hans: 取样数量
18
+ en_US: Top k
19
+ type: int
20
+ help:
21
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
22
+ en_US: Only sample from the top K options for each subsequent token.
23
+ required: false
24
+ - name: max_tokens_to_sample
25
+ use_template: max_tokens
26
+ required: true
27
+ default: 4096
28
+ min: 1
29
+ max: 4096
30
+ - name: response_format
31
+ use_template: response_format
32
+ pricing:
33
+ input: '8.00'
34
+ output: '24.00'
35
+ unit: '0.000001'
36
+ currency: USD
37
+ deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml ADDED
@@ -0,0 +1,38 @@
1
+ model: claude-3-5-haiku-20241022
2
+ label:
3
+ en_US: claude-3-5-haiku-20241022
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - tool-call
8
+ - stream-tool-call
9
+ model_properties:
10
+ mode: chat
11
+ context_size: 200000
12
+ parameter_rules:
13
+ - name: temperature
14
+ use_template: temperature
15
+ - name: top_p
16
+ use_template: top_p
17
+ - name: top_k
18
+ label:
19
+ zh_Hans: 取样数量
20
+ en_US: Top k
21
+ type: int
22
+ help:
23
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
24
+ en_US: Only sample from the top K options for each subsequent token.
25
+ required: false
26
+ - name: max_tokens
27
+ use_template: max_tokens
28
+ required: true
29
+ default: 8192
30
+ min: 1
31
+ max: 8192
32
+ - name: response_format
33
+ use_template: response_format
34
+ pricing:
35
+ input: '1.00'
36
+ output: '5.00'
37
+ unit: '0.000001'
38
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml ADDED
@@ -0,0 +1,40 @@
1
+ model: claude-3-5-sonnet-20240620
2
+ label:
3
+ en_US: claude-3-5-sonnet-20240620
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - vision
8
+ - tool-call
9
+ - stream-tool-call
10
+ - document
11
+ model_properties:
12
+ mode: chat
13
+ context_size: 200000
14
+ parameter_rules:
15
+ - name: temperature
16
+ use_template: temperature
17
+ - name: top_p
18
+ use_template: top_p
19
+ - name: top_k
20
+ label:
21
+ zh_Hans: 取样数量
22
+ en_US: Top k
23
+ type: int
24
+ help:
25
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
26
+ en_US: Only sample from the top K options for each subsequent token.
27
+ required: false
28
+ - name: max_tokens
29
+ use_template: max_tokens
30
+ required: true
31
+ default: 8192
32
+ min: 1
33
+ max: 8192
34
+ - name: response_format
35
+ use_template: response_format
36
+ pricing:
37
+ input: '3.00'
38
+ output: '15.00'
39
+ unit: '0.000001'
40
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml ADDED
@@ -0,0 +1,40 @@
1
+ model: claude-3-5-sonnet-20241022
2
+ label:
3
+ en_US: claude-3-5-sonnet-20241022
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - vision
8
+ - tool-call
9
+ - stream-tool-call
10
+ - document
11
+ model_properties:
12
+ mode: chat
13
+ context_size: 200000
14
+ parameter_rules:
15
+ - name: temperature
16
+ use_template: temperature
17
+ - name: top_p
18
+ use_template: top_p
19
+ - name: top_k
20
+ label:
21
+ zh_Hans: 取样数量
22
+ en_US: Top k
23
+ type: int
24
+ help:
25
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
26
+ en_US: Only sample from the top K options for each subsequent token.
27
+ required: false
28
+ - name: max_tokens
29
+ use_template: max_tokens
30
+ required: true
31
+ default: 8192
32
+ min: 1
33
+ max: 8192
34
+ - name: response_format
35
+ use_template: response_format
36
+ pricing:
37
+ input: '3.00'
38
+ output: '15.00'
39
+ unit: '0.000001'
40
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-haiku-20240307.yaml ADDED
@@ -0,0 +1,39 @@
1
+ model: claude-3-haiku-20240307
2
+ label:
3
+ en_US: claude-3-haiku-20240307
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - vision
8
+ - tool-call
9
+ - stream-tool-call
10
+ model_properties:
11
+ mode: chat
12
+ context_size: 200000
13
+ parameter_rules:
14
+ - name: temperature
15
+ use_template: temperature
16
+ - name: top_p
17
+ use_template: top_p
18
+ - name: top_k
19
+ label:
20
+ zh_Hans: 取样数量
21
+ en_US: Top k
22
+ type: int
23
+ help:
24
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
25
+ en_US: Only sample from the top K options for each subsequent token.
26
+ required: false
27
+ - name: max_tokens
28
+ use_template: max_tokens
29
+ required: true
30
+ default: 4096
31
+ min: 1
32
+ max: 4096
33
+ - name: response_format
34
+ use_template: response_format
35
+ pricing:
36
+ input: '0.25'
37
+ output: '1.25'
38
+ unit: '0.000001'
39
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-opus-20240229.yaml ADDED
@@ -0,0 +1,39 @@
1
+ model: claude-3-opus-20240229
2
+ label:
3
+ en_US: claude-3-opus-20240229
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - vision
8
+ - tool-call
9
+ - stream-tool-call
10
+ model_properties:
11
+ mode: chat
12
+ context_size: 200000
13
+ parameter_rules:
14
+ - name: temperature
15
+ use_template: temperature
16
+ - name: top_p
17
+ use_template: top_p
18
+ - name: top_k
19
+ label:
20
+ zh_Hans: 取样数量
21
+ en_US: Top k
22
+ type: int
23
+ help:
24
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
25
+ en_US: Only sample from the top K options for each subsequent token.
26
+ required: false
27
+ - name: max_tokens
28
+ use_template: max_tokens
29
+ required: true
30
+ default: 4096
31
+ min: 1
32
+ max: 4096
33
+ - name: response_format
34
+ use_template: response_format
35
+ pricing:
36
+ input: '15.00'
37
+ output: '75.00'
38
+ unit: '0.000001'
39
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-3-sonnet-20240229.yaml ADDED
@@ -0,0 +1,39 @@
1
+ model: claude-3-sonnet-20240229
2
+ label:
3
+ en_US: claude-3-sonnet-20240229
4
+ model_type: llm
5
+ features:
6
+ - agent-thought
7
+ - vision
8
+ - tool-call
9
+ - stream-tool-call
10
+ model_properties:
11
+ mode: chat
12
+ context_size: 200000
13
+ parameter_rules:
14
+ - name: temperature
15
+ use_template: temperature
16
+ - name: top_p
17
+ use_template: top_p
18
+ - name: top_k
19
+ label:
20
+ zh_Hans: 取样数量
21
+ en_US: Top k
22
+ type: int
23
+ help:
24
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
25
+ en_US: Only sample from the top K options for each subsequent token.
26
+ required: false
27
+ - name: max_tokens
28
+ use_template: max_tokens
29
+ required: true
30
+ default: 4096
31
+ min: 1
32
+ max: 4096
33
+ - name: response_format
34
+ use_template: response_format
35
+ pricing:
36
+ input: '3.00'
37
+ output: '15.00'
38
+ unit: '0.000001'
39
+ currency: USD
api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.2.yaml ADDED
@@ -0,0 +1,36 @@
1
+ model: claude-instant-1.2
2
+ label:
3
+ en_US: claude-instant-1.2
4
+ model_type: llm
5
+ features: [ ]
6
+ model_properties:
7
+ mode: chat
8
+ context_size: 100000
9
+ parameter_rules:
10
+ - name: temperature
11
+ use_template: temperature
12
+ - name: top_p
13
+ use_template: top_p
14
+ - name: top_k
15
+ label:
16
+ zh_Hans: 取样数量
17
+ en_US: Top k
18
+ type: int
19
+ help:
20
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
21
+ en_US: Only sample from the top K options for each subsequent token.
22
+ required: false
23
+ - name: max_tokens
24
+ use_template: max_tokens
25
+ required: true
26
+ default: 4096
27
+ min: 1
28
+ max: 4096
29
+ - name: response_format
30
+ use_template: response_format
31
+ pricing:
32
+ input: '1.63'
33
+ output: '5.51'
34
+ unit: '0.000001'
35
+ currency: USD
36
+ deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml ADDED
@@ -0,0 +1,36 @@
1
+ model: claude-instant-1
2
+ label:
3
+ en_US: claude-instant-1
4
+ model_type: llm
5
+ features: [ ]
6
+ model_properties:
7
+ mode: chat
8
+ context_size: 100000
9
+ parameter_rules:
10
+ - name: temperature
11
+ use_template: temperature
12
+ - name: top_p
13
+ use_template: top_p
14
+ - name: top_k
15
+ label:
16
+ zh_Hans: 取样数量
17
+ en_US: Top k
18
+ type: int
19
+ help:
20
+ zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
21
+ en_US: Only sample from the top K options for each subsequent token.
22
+ required: false
23
+ - name: max_tokens_to_sample
24
+ use_template: max_tokens
25
+ required: true
26
+ default: 4096
27
+ min: 1
28
+ max: 4096
29
+ - name: response_format
30
+ use_template: response_format
31
+ pricing:
32
+ input: '1.63'
33
+ output: '5.51'
34
+ unit: '0.000001'
35
+ currency: USD
36
+ deprecated: true
api/core/model_runtime/model_providers/anthropic/llm/llm.py ADDED
@@ -0,0 +1,654 @@
1
+ import base64
2
+ import json
3
+ from collections.abc import Generator, Sequence
4
+ from typing import Optional, Union, cast
5
+
6
+ import anthropic
7
+ import requests
8
+ from anthropic import Anthropic, Stream
9
+ from anthropic.types import (
10
+ ContentBlockDeltaEvent,
11
+ Message,
12
+ MessageDeltaEvent,
13
+ MessageStartEvent,
14
+ MessageStopEvent,
15
+ MessageStreamEvent,
16
+ completion_create_params,
17
+ )
18
+ from anthropic.types.beta.tools import ToolsBetaMessage
19
+ from httpx import Timeout
20
+
21
+ from core.model_runtime.callbacks.base_callback import Callback
22
+ from core.model_runtime.entities import (
23
+ AssistantPromptMessage,
24
+ DocumentPromptMessageContent,
25
+ ImagePromptMessageContent,
26
+ PromptMessage,
27
+ PromptMessageContentType,
28
+ PromptMessageTool,
29
+ SystemPromptMessage,
30
+ TextPromptMessageContent,
31
+ ToolPromptMessage,
32
+ UserPromptMessage,
33
+ )
34
+ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
35
+ from core.model_runtime.errors.invoke import (
36
+ InvokeAuthorizationError,
37
+ InvokeBadRequestError,
38
+ InvokeConnectionError,
39
+ InvokeError,
40
+ InvokeRateLimitError,
41
+ InvokeServerUnavailableError,
42
+ )
43
+ from core.model_runtime.errors.validate import CredentialsValidateFailedError
44
+ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
45
+
46
+ ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
47
+ The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure
48
+ if you are not sure about the structure.
49
+
50
+ <instructions>
51
+ {{instructions}}
52
+ </instructions>
53
+ """ # noqa: E501
54
+
55
+
56
+ class AnthropicLargeLanguageModel(LargeLanguageModel):
57
+ def _invoke(
58
+ self,
59
+ model: str,
60
+ credentials: dict,
61
+ prompt_messages: list[PromptMessage],
62
+ model_parameters: dict,
63
+ tools: Optional[list[PromptMessageTool]] = None,
64
+ stop: Optional[list[str]] = None,
65
+ stream: bool = True,
66
+ user: Optional[str] = None,
67
+ ) -> Union[LLMResult, Generator]:
68
+ """
69
+ Invoke large language model
70
+
71
+ :param model: model name
72
+ :param credentials: model credentials
73
+ :param prompt_messages: prompt messages
74
+ :param model_parameters: model parameters
75
+ :param tools: tools for tool calling
76
+ :param stop: stop words
77
+ :param stream: is stream response
78
+ :param user: unique user id
79
+ :return: full response or stream response chunk generator result
80
+ """
81
+ # invoke model
82
+ return self._chat_generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
83
+
84
+ def _chat_generate(
85
+ self,
86
+ model: str,
87
+ credentials: dict,
88
+ prompt_messages: Sequence[PromptMessage],
89
+ model_parameters: dict,
90
+ tools: Optional[list[PromptMessageTool]] = None,
91
+ stop: Optional[Sequence[str]] = None,
92
+ stream: bool = True,
93
+ user: Optional[str] = None,
94
+ ) -> Union[LLMResult, Generator]:
95
+ """
96
+ Invoke llm chat model
97
+
98
+ :param model: model name
99
+ :param credentials: credentials
100
+ :param prompt_messages: prompt messages
101
+ :param model_parameters: model parameters
102
+ :param stop: stop words
103
+ :param stream: is stream response
104
+ :param user: unique user id
105
+ :return: full response or stream response chunk generator result
106
+ """
107
+ # transform credentials to kwargs for model instance
108
+ credentials_kwargs = self._to_credential_kwargs(credentials)
109
+
110
+ # transform model parameters from completion api of anthropic to chat api
111
+ if "max_tokens_to_sample" in model_parameters:
112
+ model_parameters["max_tokens"] = model_parameters.pop("max_tokens_to_sample")
113
+
114
+ # init model client
115
+ client = Anthropic(**credentials_kwargs)
116
+
117
+ extra_model_kwargs = {}
118
+ if stop:
119
+ extra_model_kwargs["stop_sequences"] = stop
120
+
121
+ if user:
122
+ extra_model_kwargs["metadata"] = completion_create_params.Metadata(user_id=user)
123
+
124
+ system, prompt_message_dicts = self._convert_prompt_messages(prompt_messages)
125
+
126
+ if system:
127
+ extra_model_kwargs["system"] = system
128
+
129
+ # Add the new header for claude-3-5-sonnet-20240620 model
130
+ extra_headers = {}
131
+ if model == "claude-3-5-sonnet-20240620":
132
+ if model_parameters.get("max_tokens", 0) > 4096:
133
+ extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
134
+
135
+ if any(
136
+ isinstance(content, DocumentPromptMessageContent)
137
+ for prompt_message in prompt_messages
138
+ if isinstance(prompt_message.content, list)
139
+ for content in prompt_message.content
140
+ ):
141
+ extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
142
+
143
+ if tools:
144
+ extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
145
+ response = client.beta.tools.messages.create(
146
+ model=model,
147
+ messages=prompt_message_dicts,
148
+ stream=stream,
149
+ extra_headers=extra_headers,
150
+ **model_parameters,
151
+ **extra_model_kwargs,
152
+ )
153
+ else:
154
+ # chat model
155
+ response = client.messages.create(
156
+ model=model,
157
+ messages=prompt_message_dicts,
158
+ stream=stream,
159
+ extra_headers=extra_headers,
160
+ **model_parameters,
161
+ **extra_model_kwargs,
162
+ )
163
+
164
+ if stream:
165
+ return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
166
+
167
+ return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
168
+
169
+ def _code_block_mode_wrapper(
170
+ self,
171
+ model: str,
172
+ credentials: dict,
173
+ prompt_messages: list[PromptMessage],
174
+ model_parameters: dict,
175
+ tools: Optional[list[PromptMessageTool]] = None,
176
+ stop: Optional[list[str]] = None,
177
+ stream: bool = True,
178
+ user: Optional[str] = None,
179
+ callbacks: Optional[list[Callback]] = None,
180
+ ) -> Union[LLMResult, Generator]:
181
+ """
182
+ Code block mode wrapper for invoking large language model
183
+ """
184
+ if model_parameters.get("response_format"):
185
+ stop = stop or []
186
+ # chat model
187
+ self._transform_chat_json_prompts(
188
+ model=model,
189
+ credentials=credentials,
190
+ prompt_messages=prompt_messages,
191
+ model_parameters=model_parameters,
192
+ tools=tools,
193
+ stop=stop,
194
+ stream=stream,
195
+ user=user,
196
+ response_format=model_parameters["response_format"],
197
+ )
198
+ model_parameters.pop("response_format")
199
+
200
+ return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
201
+
202
+ def _transform_tool_prompt(self, tool: PromptMessageTool) -> dict:
203
+ return {"name": tool.name, "description": tool.description, "input_schema": tool.parameters}
204
+
205
+ def _transform_chat_json_prompts(
206
+ self,
207
+ model: str,
208
+ credentials: dict,
209
+ prompt_messages: list[PromptMessage],
210
+ model_parameters: dict,
211
+ tools: list[PromptMessageTool] | None = None,
212
+ stop: list[str] | None = None,
213
+ stream: bool = True,
214
+ user: str | None = None,
215
+ response_format: str = "JSON",
216
+ ) -> None:
217
+ """
218
+ Transform json prompts
219
+ """
220
+ if "```\n" not in stop:
221
+ stop.append("```\n")
222
+ if "\n```" not in stop:
223
+ stop.append("\n```")
224
+
225
+ # check if there is a system message
226
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
227
+ # override the system message
228
+ prompt_messages[0] = SystemPromptMessage(
229
+ content=ANTHROPIC_BLOCK_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content).replace(
230
+ "{{block}}", response_format
231
+ )
232
+ )
233
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
234
+ else:
235
+ # insert the system message
236
+ prompt_messages.insert(
237
+ 0,
238
+ SystemPromptMessage(
239
+ content=ANTHROPIC_BLOCK_MODE_PROMPT.replace(
240
+ "{{instructions}}", f"Please output a valid {response_format} object."
241
+ ).replace("{{block}}", response_format)
242
+ ),
243
+ )
244
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
245
+
246
+ def get_num_tokens(
247
+ self,
248
+ model: str,
249
+ credentials: dict,
250
+ prompt_messages: list[PromptMessage],
251
+ tools: Optional[list[PromptMessageTool]] = None,
252
+ ) -> int:
253
+ """
254
+ Get number of tokens for given prompt messages
255
+
256
+ :param model: model name
257
+ :param credentials: model credentials
258
+ :param prompt_messages: prompt messages
259
+ :param tools: tools for tool calling
260
+ :return:
261
+ """
262
+ prompt = self._convert_messages_to_prompt_anthropic(prompt_messages)
263
+
264
+ client = Anthropic(api_key="")
265
+ tokens = client.count_tokens(prompt)
266
+
267
+ tool_call_inner_prompts_tokens_map = {
268
+ "claude-3-opus-20240229": 395,
269
+ "claude-3-haiku-20240307": 264,
270
+ "claude-3-sonnet-20240229": 159,
271
+ }
272
+
273
+ if model in tool_call_inner_prompts_tokens_map and tools:
274
+ tokens += tool_call_inner_prompts_tokens_map[model]
275
+
276
+ return tokens
277
+
278
+ def validate_credentials(self, model: str, credentials: dict) -> None:
279
+ """
280
+ Validate model credentials
281
+
282
+ :param model: model name
283
+ :param credentials: model credentials
284
+ :return:
285
+ """
286
+ try:
287
+ self._chat_generate(
288
+ model=model,
289
+ credentials=credentials,
290
+ prompt_messages=[
291
+ UserPromptMessage(content="ping"),
292
+ ],
293
+ model_parameters={
294
+ "temperature": 0,
295
+ "max_tokens": 20,
296
+ },
297
+ stream=False,
298
+ )
299
+ except Exception as ex:
300
+ raise CredentialsValidateFailedError(str(ex))
301
+
302
+ def _handle_chat_generate_response(
303
+ self,
304
+ model: str,
305
+ credentials: dict,
306
+ response: Union[Message, ToolsBetaMessage],
307
+ prompt_messages: list[PromptMessage],
308
+ ) -> LLMResult:
309
+ """
310
+ Handle llm chat response
311
+
312
+ :param model: model name
313
+ :param credentials: credentials
314
+ :param response: response
315
+ :param prompt_messages: prompt messages
316
+ :return: llm response
317
+ """
318
+ # transform assistant message to prompt message
319
+ assistant_prompt_message = AssistantPromptMessage(content="", tool_calls=[])
320
+
321
+ for content in response.content:
322
+ if content.type == "text":
323
+ assistant_prompt_message.content += content.text
324
+ elif content.type == "tool_use":
325
+ tool_call = AssistantPromptMessage.ToolCall(
326
+ id=content.id,
327
+ type="function",
328
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
329
+ name=content.name, arguments=json.dumps(content.input)
330
+ ),
331
+ )
332
+ assistant_prompt_message.tool_calls.append(tool_call)
333
+
334
+ # calculate num tokens
335
+ prompt_tokens = (response.usage and response.usage.input_tokens) or self.get_num_tokens(
336
+ model, credentials, prompt_messages
337
+ )
338
+
339
+ completion_tokens = (response.usage and response.usage.output_tokens) or self.get_num_tokens(
340
+ model, credentials, [assistant_prompt_message]
341
+ )
342
+
343
+ # transform usage
344
+ usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
345
+
346
+ # transform response
347
+ response = LLMResult(
348
+ model=response.model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage
349
+ )
350
+
351
+ return response
352
+
353
+ def _handle_chat_generate_stream_response(
354
+ self, model: str, credentials: dict, response: Stream[MessageStreamEvent], prompt_messages: list[PromptMessage]
355
+ ) -> Generator:
356
+ """
357
+ Handle llm chat stream response
358
+
359
+ :param model: model name
360
+ :param response: response
361
+ :param prompt_messages: prompt messages
362
+ :return: llm response chunk generator
363
+ """
364
+ full_assistant_content = ""
365
+ return_model = None
366
+ input_tokens = 0
367
+ output_tokens = 0
368
+ finish_reason = None
369
+ index = 0
370
+
371
+ tool_calls: list[AssistantPromptMessage.ToolCall] = []
372
+
373
+ for chunk in response:
374
+ if isinstance(chunk, MessageStartEvent):
375
+ if hasattr(chunk, "content_block"):
376
+ content_block = chunk.content_block
377
+ if isinstance(content_block, dict):
378
+ if content_block.get("type") == "tool_use":
379
+ tool_call = AssistantPromptMessage.ToolCall(
380
+ id=content_block.get("id"),
381
+ type="function",
382
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(
383
+ name=content_block.get("name"), arguments=""
384
+ ),
385
+ )
386
+ tool_calls.append(tool_call)
387
+ elif hasattr(chunk, "delta"):
388
+ delta = chunk.delta
389
+ if isinstance(delta, dict) and len(tool_calls) > 0:
390
+ if delta.get("type") == "input_json_delta":
391
+ tool_calls[-1].function.arguments += delta.get("partial_json", "")
392
+ elif chunk.message:
393
+ return_model = chunk.message.model
394
+ input_tokens = chunk.message.usage.input_tokens
395
+ elif isinstance(chunk, MessageDeltaEvent):
396
+ output_tokens = chunk.usage.output_tokens
397
+ finish_reason = chunk.delta.stop_reason
398
+ elif isinstance(chunk, MessageStopEvent):
399
+ # transform usage
400
+ usage = self._calc_response_usage(model, credentials, input_tokens, output_tokens)
401
+
402
+ # transform empty tool call arguments to {}
403
+ for tool_call in tool_calls:
404
+ if not tool_call.function.arguments:
405
+ tool_call.function.arguments = "{}"
406
+
407
+ yield LLMResultChunk(
408
+ model=return_model,
409
+ prompt_messages=prompt_messages,
410
+ delta=LLMResultChunkDelta(
411
+ index=index + 1,
412
+ message=AssistantPromptMessage(content="", tool_calls=tool_calls),
413
+ finish_reason=finish_reason,
414
+ usage=usage,
415
+ ),
416
+ )
417
+ elif isinstance(chunk, ContentBlockDeltaEvent):
418
+ chunk_text = chunk.delta.text or ""
419
+ full_assistant_content += chunk_text
420
+
421
+ # transform assistant message to prompt message
422
+ assistant_prompt_message = AssistantPromptMessage(content=chunk_text)
423
+
424
+ index = chunk.index
425
+
426
+ yield LLMResultChunk(
427
+ model=return_model,
428
+ prompt_messages=prompt_messages,
429
+ delta=LLMResultChunkDelta(
430
+ index=chunk.index,
431
+ message=assistant_prompt_message,
432
+ ),
433
+ )
434
+
435
+ def _to_credential_kwargs(self, credentials: dict) -> dict:
436
+ """
437
+ Transform credentials to kwargs for model instance
438
+
439
+ :param credentials:
440
+ :return:
441
+ """
442
+ credentials_kwargs = {
443
+ "api_key": credentials["anthropic_api_key"],
444
+ "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
445
+ "max_retries": 1,
446
+ }
447
+
448
+ if credentials.get("anthropic_api_url"):
449
+ credentials["anthropic_api_url"] = credentials["anthropic_api_url"].rstrip("/")
450
+ credentials_kwargs["base_url"] = credentials["anthropic_api_url"]
451
+
452
+ return credentials_kwargs
453
+
454
+ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
455
+ """
456
+ Convert prompt messages to dict list and system
457
+ """
458
+ system = ""
459
+ first_loop = True
460
+ for message in prompt_messages:
461
+ if isinstance(message, SystemPromptMessage):
462
+ if isinstance(message.content, str):
463
+ message.content = message.content.strip()
464
+ elif isinstance(message.content, list):
465
+ # System prompt only supports text
466
+ message.content = "".join(
467
+ c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
468
+ )
469
+ else:
470
+ raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
471
+ if first_loop:
472
+ system = message.content
473
+ first_loop = False
474
+ else:
475
+ system += "\n"
476
+ system += message.content
477
+
478
+ prompt_message_dicts = []
479
+ for message in prompt_messages:
480
+ if not isinstance(message, SystemPromptMessage):
481
+ if isinstance(message, UserPromptMessage):
482
+ message = cast(UserPromptMessage, message)
483
+ if isinstance(message.content, str):
484
+ # Handle empty user prompts (see #10013, #10520):
485
+ # skip user prompts containing only whitespace, since the Claude API can't handle them.
486
+ if not message.content.strip():
487
+ continue
488
+ message_dict = {"role": "user", "content": message.content}
489
+ prompt_message_dicts.append(message_dict)
490
+ else:
491
+ sub_messages = []
492
+ for message_content in message.content:
493
+ if message_content.type == PromptMessageContentType.TEXT:
494
+ message_content = cast(TextPromptMessageContent, message_content)
495
+ sub_message_dict = {"type": "text", "text": message_content.data}
496
+ sub_messages.append(sub_message_dict)
497
+ elif message_content.type == PromptMessageContentType.IMAGE:
498
+ message_content = cast(ImagePromptMessageContent, message_content)
499
+ if not message_content.base64_data:
500
+ # fetch image data from url
501
+ try:
502
+ image_content = requests.get(message_content.url).content
503
+ base64_data = base64.b64encode(image_content).decode("utf-8")
504
+ except Exception as ex:
505
+ raise ValueError(
506
+ f"Failed to fetch image data from url {message_content.url}, {ex}"
507
+ )
508
+ else:
509
+ base64_data = message_content.base64_data
510
+
511
+ mime_type = message_content.mime_type
512
+ if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
513
+ raise ValueError(
514
+ f"Unsupported image type {mime_type}, "
515
+ f"only support image/jpeg, image/png, image/gif, and image/webp"
516
+ )
517
+
518
+ sub_message_dict = {
519
+ "type": "image",
520
+ "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
521
+ }
522
+ sub_messages.append(sub_message_dict)
523
+ elif isinstance(message_content, DocumentPromptMessageContent):
524
+ if message_content.mime_type != "application/pdf":
525
+ raise ValueError(
526
+ f"Unsupported document type {message_content.mime_type}, "
527
+ "only support application/pdf"
528
+ )
529
+ sub_message_dict = {
530
+ "type": "document",
531
+ "source": {
532
+ "type": "base64",
533
+ "media_type": message_content.mime_type,
534
+ "data": message_content.base64_data,
535
+ },
536
+ }
537
+ sub_messages.append(sub_message_dict)
538
+ prompt_message_dicts.append({"role": "user", "content": sub_messages})
539
+ elif isinstance(message, AssistantPromptMessage):
540
+ message = cast(AssistantPromptMessage, message)
541
+ content = []
542
+ if message.tool_calls:
543
+ for tool_call in message.tool_calls:
544
+ content.append(
545
+ {
546
+ "type": "tool_use",
547
+ "id": tool_call.id,
548
+ "name": tool_call.function.name,
549
+ "input": json.loads(tool_call.function.arguments),
550
+ }
551
+ )
552
+ if message.content:
553
+ content.append({"type": "text", "text": message.content})
554
+
555
+ if prompt_message_dicts and prompt_message_dicts[-1]["role"] == "assistant":
556
+ prompt_message_dicts[-1]["content"].extend(content)
557
+ else:
558
+ prompt_message_dicts.append({"role": "assistant", "content": content})
559
+ elif isinstance(message, ToolPromptMessage):
560
+ message = cast(ToolPromptMessage, message)
561
+ message_dict = {
562
+ "role": "user",
563
+ "content": [
564
+ {"type": "tool_result", "tool_use_id": message.tool_call_id, "content": message.content}
565
+ ],
566
+ }
567
+ prompt_message_dicts.append(message_dict)
568
+ else:
569
+ raise ValueError(f"Got unknown type {message}")
570
+
571
+ return system, prompt_message_dicts
572
+
573
+ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
574
+ """
575
+ Convert a single message to a string.
576
+
577
+ :param message: PromptMessage to convert.
578
+ :return: String representation of the message.
579
+ """
580
+ human_prompt = "\n\nHuman:"
581
+ ai_prompt = "\n\nAssistant:"
582
+ content = message.content
583
+
584
+ if isinstance(message, UserPromptMessage):
585
+ message_text = f"{human_prompt} {content}"
586
+ if not isinstance(message.content, list):
587
+ message_text = f"{human_prompt} {content}"
588
+ else:
589
+ message_text = ""
590
+ for sub_message in message.content:
591
+ if sub_message.type == PromptMessageContentType.TEXT:
592
+ message_text += f"{human_prompt} {sub_message.data}"
593
+ elif sub_message.type == PromptMessageContentType.IMAGE:
594
+ message_text += f"{human_prompt} [IMAGE]"
595
+ elif isinstance(message, AssistantPromptMessage):
596
+ if not isinstance(message.content, list):
597
+ message_text = f"{ai_prompt} {content}"
598
+ else:
599
+ message_text = ""
600
+ for sub_message in message.content:
601
+ if sub_message.type == PromptMessageContentType.TEXT:
602
+ message_text += f"{ai_prompt} {sub_message.data}"
603
+ elif sub_message.type == PromptMessageContentType.IMAGE:
604
+ message_text += f"{ai_prompt} [IMAGE]"
605
+ elif isinstance(message, SystemPromptMessage):
606
+ message_text = content
607
+ elif isinstance(message, ToolPromptMessage):
608
+ message_text = f"{human_prompt} {message.content}"
609
+ else:
610
+ raise ValueError(f"Got unknown type {message}")
611
+
612
+ return message_text
613
+
614
+ def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str:
615
+ """
616
+ Format a list of messages into a full prompt for the Anthropic model
617
+
618
+ :param messages: List of PromptMessage to combine.
619
+ :return: Combined string with necessary human_prompt and ai_prompt tags.
620
+ """
621
+ if not messages:
622
+ return ""
623
+
624
+ messages = messages.copy() # don't mutate the original list
625
+ if not isinstance(messages[-1], AssistantPromptMessage):
626
+ messages.append(AssistantPromptMessage(content=""))
627
+
628
+ text = "".join(self._convert_one_message_to_text(message) for message in messages)
629
+
630
+ # trim off the trailing ' ' that might come from the "Assistant: "
631
+ return text.rstrip()
632
+
633
+ @property
634
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
635
+ """
636
+ Map model invoke error to unified error
637
+ The key is the error type thrown to the caller
638
+ The value is the error type thrown by the model,
639
+ which needs to be converted into a unified error type for the caller.
640
+
641
+ :return: Invoke error mapping
642
+ """
643
+ return {
644
+ InvokeConnectionError: [anthropic.APIConnectionError, anthropic.APITimeoutError],
645
+ InvokeServerUnavailableError: [anthropic.InternalServerError],
646
+ InvokeRateLimitError: [anthropic.RateLimitError],
647
+ InvokeAuthorizationError: [anthropic.AuthenticationError, anthropic.PermissionDeniedError],
648
+ InvokeBadRequestError: [
649
+ anthropic.BadRequestError,
650
+ anthropic.NotFoundError,
651
+ anthropic.UnprocessableEntityError,
652
+ anthropic.APIError,
653
+ ],
654
+ }
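
For reference, below is a sketch of the `(system, prompt_message_dicts)` pair that `_convert_prompt_messages` builds for a tool-calling exchange. The ids and values are invented placeholders; only the dict shapes follow the conversion code above.

```python
# Illustrative only: the shapes mirror _convert_prompt_messages above; the values
# do not come from a real API exchange.
system = "You are a helpful assistant."  # concatenated SystemPromptMessage contents

prompt_message_dicts = [
    # UserPromptMessage with plain string content
    {"role": "user", "content": "What's the weather in Paris?"},
    # AssistantPromptMessage carrying a tool call -> "tool_use" content block
    {
        "role": "assistant",
        "content": [
            {"type": "tool_use", "id": "toolu_123", "name": "get_weather", "input": {"city": "Paris"}},
        ],
    },
    # ToolPromptMessage with the tool's output -> wrapped as a user "tool_result" block
    {
        "role": "user",
        "content": [
            {"type": "tool_result", "tool_use_id": "toolu_123", "content": "18°C, sunny"},
        ],
    },
]
```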
api/core/model_runtime/model_providers/azure_ai_studio/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_l_en.png ADDED
api/core/model_runtime/model_providers/azure_ai_studio/_assets/icon_s_en.png ADDED
api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.py ADDED
@@ -0,0 +1,17 @@
1
+ import logging
2
+
3
+ from core.model_runtime.model_providers.__base.model_provider import ModelProvider
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ class AzureAIStudioProvider(ModelProvider):
9
+ def validate_provider_credentials(self, credentials: dict) -> None:
10
+ """
11
+ Validate provider credentials
12
+
13
+ If validation fails, raise an exception.
14
+
15
+ :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
16
+ """
17
+ pass
api/core/model_runtime/model_providers/azure_ai_studio/azure_ai_studio.yaml ADDED
@@ -0,0 +1,99 @@
1
+ provider: azure_ai_studio
2
+ label:
3
+ zh_Hans: Azure AI Studio
4
+ en_US: Azure AI Studio
5
+ icon_small:
6
+ en_US: icon_s_en.png
7
+ icon_large:
8
+ en_US: icon_l_en.png
9
+ description:
10
+ en_US: Azure AI Studio
11
+ zh_Hans: Azure AI Studio
12
+ background: "#93c5fd"
13
+ help:
14
+ title:
15
+ en_US: How to deploy a customized model on Azure AI Studio
16
+ zh_Hans: 如何在 Azure AI Studio 上部署自定义模型
17
+ url:
18
+ en_US: https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models
19
+ zh_Hans: https://learn.microsoft.com/zh-cn/azure/ai-studio/how-to/deploy-models
20
+ supported_model_types:
21
+ - llm
22
+ - rerank
23
+ configurate_methods:
24
+ - customizable-model
25
+ model_credential_schema:
26
+ model:
27
+ label:
28
+ en_US: Model Name
29
+ zh_Hans: 模型名称
30
+ placeholder:
31
+ en_US: Enter your model name
32
+ zh_Hans: 输入模型名称
33
+ credential_form_schemas:
34
+ - variable: endpoint
35
+ label:
36
+ en_US: Azure AI Studio Endpoint
37
+ type: text-input
38
+ required: true
39
+ placeholder:
40
+ zh_Hans: 请输入你的Azure AI Studio推理端点
41
+ en_US: 'Enter your API Endpoint, e.g. https://example.com'
42
+ - variable: api_key
43
+ required: true
44
+ label:
45
+ en_US: API Key
46
+ zh_Hans: API Key
47
+ type: secret-input
48
+ placeholder:
49
+ en_US: Enter your Azure AI Studio API Key
50
+ zh_Hans: 在此输入您的 Azure AI Studio API Key
51
+ show_on:
52
+ - variable: __model_type
53
+ value: llm
54
+ - variable: mode
55
+ show_on:
56
+ - variable: __model_type
57
+ value: llm
58
+ label:
59
+ en_US: Completion mode
60
+ type: select
61
+ required: false
62
+ default: chat
63
+ placeholder:
64
+ zh_Hans: 选择对话类型
65
+ en_US: Select completion mode
66
+ options:
67
+ - value: completion
68
+ label:
69
+ en_US: Completion
70
+ zh_Hans: 补全
71
+ - value: chat
72
+ label:
73
+ en_US: Chat
74
+ zh_Hans: 对话
75
+ - variable: context_size
76
+ label:
77
+ zh_Hans: 模型上下文长度
78
+ en_US: Model context size
79
+ required: true
80
+ show_on:
81
+ - variable: __model_type
82
+ value: llm
83
+ type: text-input
84
+ default: "4096"
85
+ placeholder:
86
+ zh_Hans: 在此输入您的模型上下文长度
87
+ en_US: Enter your Model context size
88
+ - variable: jwt_token
89
+ required: true
90
+ label:
91
+ en_US: JWT Token
92
+ zh_Hans: JWT令牌
93
+ type: secret-input
94
+ placeholder:
95
+ en_US: Enter your Azure AI Studio JWT Token
96
+ zh_Hans: 在此输入您的 Azure AI Studio 推理 API Key
97
+ show_on:
98
+ - variable: __model_type
99
+ value: rerank
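
For illustration, credentials that satisfy the schema above might look like the following; every value is a placeholder. LLM deployments use `endpoint` plus `api_key` (with optional `mode` and `context_size`), while rerank deployments use `endpoint` plus `jwt_token`.

```python
# Hypothetical credential payloads matching the schema above; all values are placeholders.
llm_credentials = {
    "endpoint": "https://example-llm.eastus2.inference.ai.azure.com",
    "api_key": "<azure-ai-studio-api-key>",
    "mode": "chat",          # or "completion"
    "context_size": "4096",
}

rerank_credentials = {
    "endpoint": "https://example-rerank.eastus2.inference.ai.azure.com",
    "jwt_token": "<azure-ai-studio-jwt-token>",
}
```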
api/core/model_runtime/model_providers/azure_ai_studio/llm/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/llm/llm.py ADDED
@@ -0,0 +1,345 @@
1
+ import logging
2
+ from collections.abc import Generator, Sequence
3
+ from typing import Any, Optional, Union
4
+
5
+ from azure.ai.inference import ChatCompletionsClient
6
+ from azure.ai.inference.models import StreamingChatCompletionsUpdate, SystemMessage, UserMessage
7
+ from azure.core.credentials import AzureKeyCredential
8
+ from azure.core.exceptions import (
9
+ ClientAuthenticationError,
10
+ DecodeError,
11
+ DeserializationError,
12
+ HttpResponseError,
13
+ ResourceExistsError,
14
+ ResourceModifiedError,
15
+ ResourceNotFoundError,
16
+ ResourceNotModifiedError,
17
+ SerializationError,
18
+ ServiceRequestError,
19
+ ServiceResponseError,
20
+ )
21
+
22
+ from core.model_runtime.callbacks.base_callback import Callback
23
+ from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
24
+ from core.model_runtime.entities.message_entities import (
25
+ AssistantPromptMessage,
26
+ PromptMessage,
27
+ PromptMessageTool,
28
+ )
29
+ from core.model_runtime.entities.model_entities import (
30
+ AIModelEntity,
31
+ FetchFrom,
32
+ I18nObject,
33
+ ModelPropertyKey,
34
+ ModelType,
35
+ ParameterRule,
36
+ ParameterType,
37
+ )
38
+ from core.model_runtime.errors.invoke import (
39
+ InvokeAuthorizationError,
40
+ InvokeBadRequestError,
41
+ InvokeConnectionError,
42
+ InvokeError,
43
+ InvokeServerUnavailableError,
44
+ )
45
+ from core.model_runtime.errors.validate import CredentialsValidateFailedError
46
+ from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ class AzureAIStudioLargeLanguageModel(LargeLanguageModel):
52
+ """
53
+ Model class for Azure AI Studio large language model.
54
+ """
55
+
56
+ client: Any = None
57
+
59
+
60
+ def _invoke(
61
+ self,
62
+ model: str,
63
+ credentials: dict,
64
+ prompt_messages: Sequence[PromptMessage],
65
+ model_parameters: dict,
66
+ tools: Optional[Sequence[PromptMessageTool]] = None,
67
+ stop: Optional[Sequence[str]] = None,
68
+ stream: bool = True,
69
+ user: Optional[str] = None,
70
+ ) -> Union[LLMResult, Generator]:
71
+ """
72
+ Invoke large language model
73
+
74
+ :param model: model name
75
+ :param credentials: model credentials
76
+ :param prompt_messages: prompt messages
77
+ :param model_parameters: model parameters
78
+ :param tools: tools for tool calling
79
+ :param stop: stop words
80
+ :param stream: is stream response
81
+ :param user: unique user id
82
+ :return: full response or stream response chunk generator result
83
+ """
84
+
85
+ if not self.client:
86
+ endpoint = str(credentials.get("endpoint"))
87
+ api_key = str(credentials.get("api_key"))
88
+ self.client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
89
+
90
+ messages = [{"role": msg.role.value, "content": msg.content} for msg in prompt_messages]
91
+
92
+ payload = {
93
+ "messages": messages,
94
+ "max_tokens": model_parameters.get("max_tokens", 4096),
95
+ "temperature": model_parameters.get("temperature", 0),
96
+ "top_p": model_parameters.get("top_p", 1),
97
+ "stream": stream,
98
+ "model": model,
99
+ }
100
+
101
+ if stop:
102
+ payload["stop"] = stop
103
+
104
+ if tools:
105
+ payload["tools"] = [tool.model_dump() for tool in tools]
106
+
107
+ try:
108
+ response = self.client.complete(**payload)
109
+
110
+ if stream:
111
+ return self._handle_stream_response(response, model, prompt_messages)
112
+ else:
113
+ return self._handle_non_stream_response(response, model, prompt_messages, credentials)
114
+ except Exception as e:
115
+ raise self._transform_invoke_error(e)
116
+
117
+ def _handle_stream_response(self, response, model: str, prompt_messages: list[PromptMessage]) -> Generator:
118
+ for chunk in response:
119
+ if isinstance(chunk, StreamingChatCompletionsUpdate):
120
+ if chunk.choices:
121
+ delta = chunk.choices[0].delta
122
+ if delta.content:
123
+ yield LLMResultChunk(
124
+ model=model,
125
+ prompt_messages=prompt_messages,
126
+ delta=LLMResultChunkDelta(
127
+ index=0,
128
+ message=AssistantPromptMessage(content=delta.content, tool_calls=[]),
129
+ ),
130
+ )
131
+
132
+ def _handle_non_stream_response(
133
+ self, response, model: str, prompt_messages: list[PromptMessage], credentials: dict
134
+ ) -> LLMResult:
135
+ assistant_text = response.choices[0].message.content
136
+ assistant_prompt_message = AssistantPromptMessage(content=assistant_text)
137
+ usage = self._calc_response_usage(
138
+ model, credentials, response.usage.prompt_tokens, response.usage.completion_tokens
139
+ )
140
+ result = LLMResult(model=model, prompt_messages=prompt_messages, message=assistant_prompt_message, usage=usage)
141
+
142
+ if hasattr(response, "system_fingerprint"):
143
+ result.system_fingerprint = response.system_fingerprint
144
+
145
+ return result
146
+
147
+ def _invoke_result_generator(
148
+ self,
149
+ model: str,
150
+ result: Generator,
151
+ credentials: dict,
152
+ prompt_messages: list[PromptMessage],
153
+ model_parameters: dict,
154
+ tools: Optional[list[PromptMessageTool]] = None,
155
+ stop: Optional[list[str]] = None,
156
+ stream: bool = True,
157
+ user: Optional[str] = None,
158
+ callbacks: Optional[list[Callback]] = None,
159
+ ) -> Generator:
160
+ """
161
+ Invoke result generator
162
+
163
+ :param result: result generator
164
+ :return: result generator
165
+ """
166
+ callbacks = callbacks or []
167
+ prompt_message = AssistantPromptMessage(content="")
168
+ usage = None
169
+ system_fingerprint = None
170
+ real_model = model
171
+
172
+ try:
173
+ for chunk in result:
174
+ if isinstance(chunk, dict):
175
+ content = chunk["choices"][0]["message"]["content"]
176
+ usage = chunk["usage"]
177
+ chunk = LLMResultChunk(
178
+ model=model,
179
+ prompt_messages=prompt_messages,
180
+ delta=LLMResultChunkDelta(
181
+ index=0,
182
+ message=AssistantPromptMessage(content=content, tool_calls=[]),
183
+ ),
184
+ system_fingerprint=chunk.get("system_fingerprint"),
185
+ )
186
+
187
+ yield chunk
188
+
189
+ self._trigger_new_chunk_callbacks(
190
+ chunk=chunk,
191
+ model=model,
192
+ credentials=credentials,
193
+ prompt_messages=prompt_messages,
194
+ model_parameters=model_parameters,
195
+ tools=tools,
196
+ stop=stop,
197
+ stream=stream,
198
+ user=user,
199
+ callbacks=callbacks,
200
+ )
201
+
202
+ prompt_message.content += chunk.delta.message.content
203
+ real_model = chunk.model
204
+ if hasattr(chunk.delta, "usage"):
205
+ usage = chunk.delta.usage
206
+
207
+ if chunk.system_fingerprint:
208
+ system_fingerprint = chunk.system_fingerprint
209
+ except Exception as e:
210
+ raise self._transform_invoke_error(e)
211
+
212
+ self._trigger_after_invoke_callbacks(
213
+ model=model,
214
+ result=LLMResult(
215
+ model=real_model,
216
+ prompt_messages=prompt_messages,
217
+ message=prompt_message,
218
+ usage=usage or LLMUsage.empty_usage(),
219
+ system_fingerprint=system_fingerprint,
220
+ ),
221
+ credentials=credentials,
222
+ prompt_messages=prompt_messages,
223
+ model_parameters=model_parameters,
224
+ tools=tools,
225
+ stop=stop,
226
+ stream=stream,
227
+ user=user,
228
+ callbacks=callbacks,
229
+ )
230
+
231
+ def get_num_tokens(
232
+ self,
233
+ model: str,
234
+ credentials: dict,
235
+ prompt_messages: list[PromptMessage],
236
+ tools: Optional[list[PromptMessageTool]] = None,
237
+ ) -> int:
238
+ """
239
+ Get number of tokens for given prompt messages
240
+
241
+ :param model: model name
242
+ :param credentials: model credentials
243
+ :param prompt_messages: prompt messages
244
+ :param tools: tools for tool calling
245
+ :return:
246
+ """
247
+ # Token counting is not implemented for Azure AI Studio models;
248
+ # a model-specific tokenizer would be required, so 0 is returned as a placeholder.
249
+ return 0
250
+
251
+ def validate_credentials(self, model: str, credentials: dict) -> None:
252
+ """
253
+ Validate model credentials
254
+
255
+ :param model: model name
256
+ :param credentials: model credentials
257
+ :return:
258
+ """
259
+ try:
260
+ endpoint = str(credentials.get("endpoint"))
261
+ api_key = str(credentials.get("api_key"))
262
+ client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(api_key))
263
+ client.complete(
264
+ messages=[
265
+ SystemMessage(content="I say 'ping', you say 'pong'"),
266
+ UserMessage(content="ping"),
267
+ ],
268
+ model=model,
269
+ )
270
+ except Exception as ex:
271
+ raise CredentialsValidateFailedError(str(ex))
272
+
273
+ @property
274
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
275
+ """
276
+ Map model invoke error to unified error
277
+ The key is the error type thrown to the caller
278
+ The value is the error type thrown by the model,
279
+ which needs to be converted into a unified error type for the caller.
280
+
281
+ :return: Invoke error mapping
282
+ """
283
+ return {
284
+ InvokeConnectionError: [
285
+ ServiceRequestError,
286
+ ],
287
+ InvokeServerUnavailableError: [
288
+ ServiceResponseError,
289
+ ],
290
+ InvokeAuthorizationError: [
291
+ ClientAuthenticationError,
292
+ ],
293
+ InvokeBadRequestError: [
294
+ HttpResponseError,
295
+ DecodeError,
296
+ ResourceExistsError,
297
+ ResourceNotFoundError,
298
+ ResourceModifiedError,
299
+ ResourceNotModifiedError,
300
+ SerializationError,
301
+ DeserializationError,
302
+ ],
303
+ }
304
+
305
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
306
+ """
307
+ Used to define customizable model schema
308
+ """
309
+ rules = [
310
+ ParameterRule(
311
+ name="temperature",
312
+ type=ParameterType.FLOAT,
313
+ use_template="temperature",
314
+ label=I18nObject(zh_Hans="温度", en_US="Temperature"),
315
+ ),
316
+ ParameterRule(
317
+ name="top_p",
318
+ type=ParameterType.FLOAT,
319
+ use_template="top_p",
320
+ label=I18nObject(zh_Hans="Top P", en_US="Top P"),
321
+ ),
322
+ ParameterRule(
323
+ name="max_tokens",
324
+ type=ParameterType.INT,
325
+ use_template="max_tokens",
326
+ min=1,
327
+ default=512,
328
+ label=I18nObject(zh_Hans="最大生成长度", en_US="Max Tokens"),
329
+ ),
330
+ ]
331
+
332
+ entity = AIModelEntity(
333
+ model=model,
334
+ label=I18nObject(en_US=model),
335
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
336
+ model_type=ModelType.LLM,
337
+ features=[],
338
+ model_properties={
339
+ ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "4096")),
340
+ ModelPropertyKey.MODE: credentials.get("mode", LLMMode.CHAT),
341
+ },
342
+ parameter_rules=rules,
343
+ )
344
+
345
+ return entity
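
A minimal, non-streaming usage sketch of the class above. The endpoint, key, and model name are placeholders, and in practice the model runtime constructs the model instance and calls `invoke` for you; `_invoke` is shown directly only to illustrate the parameters it expects.

```python
# Minimal sketch with placeholder credentials; normally the runtime performs this call.
from core.model_runtime.entities.message_entities import UserPromptMessage

llm = AzureAIStudioLargeLanguageModel()
result = llm._invoke(
    model="my-deployment",  # deployment/model name as configured in Azure AI Studio (placeholder)
    credentials={
        "endpoint": "https://example-llm.eastus2.inference.ai.azure.com",
        "api_key": "<azure-ai-studio-api-key>",
    },
    prompt_messages=[UserPromptMessage(content="ping")],
    model_parameters={"temperature": 0, "max_tokens": 64},
    stream=False,
)
print(result.message.content)
```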
api/core/model_runtime/model_providers/azure_ai_studio/rerank/__init__.py ADDED
File without changes
api/core/model_runtime/model_providers/azure_ai_studio/rerank/rerank.py ADDED
@@ -0,0 +1,164 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import ssl
5
+ import urllib.error
+ import urllib.request
6
+ from typing import Optional
7
+
8
+ from core.model_runtime.entities.common_entities import I18nObject
9
+ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
10
+ from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
11
+ from core.model_runtime.errors.invoke import (
12
+ InvokeAuthorizationError,
13
+ InvokeBadRequestError,
14
+ InvokeConnectionError,
15
+ InvokeError,
16
+ InvokeRateLimitError,
17
+ InvokeServerUnavailableError,
18
+ )
19
+ from core.model_runtime.errors.validate import CredentialsValidateFailedError
20
+ from core.model_runtime.model_providers.__base.rerank_model import RerankModel
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AzureRerankModel(RerankModel):
26
+ """
27
+ Model class for Azure AI Studio rerank model.
28
+ """
29
+
30
+ def _allow_self_signed_https(self, allowed):
31
+ # bypass the server certificate verification on client side
32
+ if allowed and not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None):
33
+ ssl._create_default_https_context = ssl._create_unverified_context
34
+
35
+ def _azure_rerank(self, query_input: str, docs: list[str], endpoint: str, api_key: str):
36
+ # self._allow_self_signed_https(True) # Enable if using self-signed certificate
37
+
38
+ data = {"inputs": query_input, "docs": docs}
39
+
40
+ body = json.dumps(data).encode("utf-8")
41
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
42
+
43
+ req = urllib.request.Request(endpoint, body, headers)
44
+
45
+ try:
46
+ with urllib.request.urlopen(req) as response:
47
+ result = response.read()
48
+ return json.loads(result)
49
+ except urllib.error.HTTPError as error:
50
+ logger.exception(f"The request failed with status code: {error.code}")
51
+ logger.exception(error.info())
52
+ logger.exception(error.read().decode("utf8", "ignore"))
53
+ raise
54
+
55
+ def _invoke(
56
+ self,
57
+ model: str,
58
+ credentials: dict,
59
+ query: str,
60
+ docs: list[str],
61
+ score_threshold: Optional[float] = None,
62
+ top_n: Optional[int] = None,
63
+ user: Optional[str] = None,
64
+ ) -> RerankResult:
65
+ """
66
+ Invoke rerank model
67
+
68
+ :param model: model name
69
+ :param credentials: model credentials
70
+ :param query: search query
71
+ :param docs: docs for reranking
72
+ :param score_threshold: score threshold
73
+ :param top_n: top n
74
+ :param user: unique user id
75
+ :return: rerank result
76
+ """
77
+ try:
78
+ if len(docs) == 0:
79
+ return RerankResult(model=model, docs=[])
80
+
81
+ endpoint = credentials.get("endpoint")
82
+ api_key = credentials.get("jwt_token")
83
+
84
+ if not endpoint or not api_key:
85
+ raise ValueError("Azure endpoint and API key must be provided in credentials")
86
+
87
+ result = self._azure_rerank(query, docs, endpoint, api_key)
88
+ logger.info(f"Azure rerank result: {result}")
89
+
90
+ rerank_documents = []
91
+ for idx, (doc, score_dict) in enumerate(zip(docs, result)):
92
+ score = score_dict["score"]
93
+ rerank_document = RerankDocument(index=idx, text=doc, score=score)
94
+
95
+ if score_threshold is None or score >= score_threshold:
96
+ rerank_documents.append(rerank_document)
97
+
98
+ rerank_documents.sort(key=lambda x: x.score, reverse=True)
99
+
100
+ if top_n:
101
+ rerank_documents = rerank_documents[:top_n]
102
+
103
+ return RerankResult(model=model, docs=rerank_documents)
104
+
105
+ except Exception as e:
106
+ logger.exception(f"Failed to invoke rerank model, model: {model}")
107
+ raise
108
+
109
+ def validate_credentials(self, model: str, credentials: dict) -> None:
110
+ """
111
+ Validate model credentials
112
+
113
+ :param model: model name
114
+ :param credentials: model credentials
115
+ :return:
116
+ """
117
+ try:
118
+ self._invoke(
119
+ model=model,
120
+ credentials=credentials,
121
+ query="What is the capital of the United States?",
122
+ docs=[
123
+ "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
124
+ "Census, Carson City had a population of 55,274.",
125
+ "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
126
+ "are a political division controlled by the United States. Its capital is Saipan.",
127
+ ],
128
+ score_threshold=0.8,
129
+ )
130
+ except Exception as ex:
131
+ raise CredentialsValidateFailedError(str(ex))
132
+
133
+ @property
134
+ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
135
+ """
136
+ Map model invoke error to unified error
137
+ The key is the error type thrown to the caller
138
+ The value is the error type thrown by the model,
139
+ which needs to be converted into a unified error type for the caller.
140
+
141
+ :return: Invoke error mapping
142
+ """
143
+ return {
144
+ InvokeConnectionError: [urllib.error.URLError],
145
+ InvokeServerUnavailableError: [urllib.error.HTTPError],
146
+ InvokeRateLimitError: [InvokeRateLimitError],
147
+ InvokeAuthorizationError: [InvokeAuthorizationError],
148
+ InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError],
149
+ }
150
+
151
+ def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
152
+ """
153
+ used to define customizable model schema
154
+ """
155
+ entity = AIModelEntity(
156
+ model=model,
157
+ label=I18nObject(en_US=model),
158
+ fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
159
+ model_type=ModelType.RERANK,
160
+ model_properties={},
161
+ parameter_rules=[],
162
+ )
163
+
164
+ return entity
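
For reference, `_azure_rerank` posts a JSON body of the form `{"inputs": <query>, "docs": [...]}` and, judging from the `zip(docs, result)` in `_invoke`, expects one score object per document back in the same order. Below is a sketch of both shapes; the values are invented, and the exact response contract depends on the deployed scoring endpoint rather than anything confirmed here.

```python
# Request body built by _azure_rerank (placeholder values), and the response shape
# _invoke assumes: a list of {"score": float} objects aligned with the input docs.
request_body = {
    "inputs": "What is the capital of the United States?",
    "docs": [
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
}

assumed_response = [
    {"score": 0.12},
    {"score": 0.93},
]
```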
api/core/model_runtime/model_providers/azure_openai/__init__.py ADDED
File without changes