|
import time
|
|
from collections.abc import Generator
|
|
from typing import Optional, Union, cast
|
|
|
|
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
|
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
|
from core.app.entities.app_invoke_entities import (
|
|
AppGenerateEntity,
|
|
EasyUIBasedAppGenerateEntity,
|
|
InvokeFrom,
|
|
ModelConfigWithCredentialsEntity,
|
|
)
|
|
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
|
|
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
|
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
|
|
from core.external_data_tool.external_data_fetch import ExternalDataFetch
|
|
from core.file.file_obj import FileVar
|
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
|
|
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
|
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
from core.moderation.input_moderation import InputModeration
|
|
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
|
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
|
from models.model import App, AppMode, Message, MessageAnnotation
|
|
|
|
|
|
class AppRunner:
    """
    Helpers used while running an app: token budgeting, prompt organization,
    publishing LLM output to the application queue, moderation, external data
    tools and annotation replies.
    """

    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigWithCredentialsEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileVar],
                                      query: Optional[str] = None) -> int:
        """
        Get pre calculate rest tokens.

        Computes how many tokens of the model's context window remain after
        reserving the configured max tokens and the tokens of the prompt built
        from the template, inputs, files and query (no context or memory is
        passed when building this prompt).

        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return: remaining tokens, or -1 when the model schema declares no context size
        :raises InvokeBadRequestError: when max tokens plus prompt tokens exceed the context size
        """
        model_type_instance = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens is None:
            # no known context size -> nothing to budget against
            return -1

        max_tokens = self._get_max_tokens_from_parameters(model_config)

        # stop words are irrelevant for token counting
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
                              prompt_messages: list[PromptMessage]):
        """
        Shrink the configured max tokens when prompt plus completion would overflow the context.

        If prompt tokens plus the configured max tokens exceed the model's
        context size, max tokens is lowered to the remaining budget (never below
        16) and written into ``model_config.parameters`` for every max-tokens
        parameter rule. Otherwise the parameters are left untouched.

        :param model_config: model config entity; its ``parameters`` dict may be updated in place
        :param prompt_messages: prompt messages whose tokens are counted
        :return: -1 when the model schema declares no context size, otherwise None
        """
        model_type_instance = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens is None:
            return -1

        max_tokens = self._get_max_tokens_from_parameters(model_config)

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        if prompt_tokens + max_tokens > model_context_tokens:
            # floor of 16 tokens, even if that still exceeds the context size
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            # Write back only when the budget was changed; writing unconditionally
            # would force an unset max_tokens (read as 0) into the parameters.
            for parameter_rule in model_config.model_schema.parameter_rules:
                if self._is_max_tokens_rule(parameter_rule):
                    model_config.parameters[parameter_rule.name] = max_tokens

    @staticmethod
    def _is_max_tokens_rule(parameter_rule) -> bool:
        """Return True if the parameter rule is named, or templated from, ``max_tokens``."""
        return parameter_rule.name == 'max_tokens' or parameter_rule.use_template == 'max_tokens'

    def _get_max_tokens_from_parameters(self, model_config: ModelConfigWithCredentialsEntity) -> int:
        """
        Read the configured max tokens from the model parameters.

        For every max-tokens parameter rule, the value is looked up in
        ``model_config.parameters`` by the rule name, falling back to its
        template name. When several rules match, the last one wins.

        :param model_config: model config entity
        :return: the configured max tokens, or 0 when unset/falsy
        """
        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if self._is_max_tokens_rule(parameter_rule):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0
        return max_tokens

    def organize_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigWithCredentialsEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileVar],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Organize prompt messages.

        Simple prompt templates are rendered by ``SimplePromptTransform``, which
        also supplies the stop words. Other templates are rendered by
        ``AdvancedPromptTransform`` using a completion or chat template depending
        on the model mode, and take their stop words from ``model_config.stop``.

        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query (None is passed on as an empty string)
        :param context: context
        :param memory: memory
        :return: tuple of (prompt messages, stop words)
        """
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_transform = SimplePromptTransform()
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=AppMode.value_of(app_record.mode),
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            # memory window is disabled for advanced prompts
            memory_config = MemoryConfig(
                window=MemoryConfig.WindowConfig(
                    enabled=False
                )
            )

            model_mode = ModelMode.value_of(model_config.mode)
            if model_mode == ModelMode.COMPLETION:
                advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
                prompt_template = CompletionModelPromptTemplate(
                    text=advanced_completion_prompt_template.prompt
                )

                # role prefixes only apply to completion templates that define them
                if advanced_completion_prompt_template.role_prefix:
                    memory_config.role_prefix = MemoryConfig.RolePrefix(
                        user=advanced_completion_prompt_template.role_prefix.user,
                        assistant=advanced_completion_prompt_template.role_prefix.assistant
                    )
            else:
                prompt_template = [
                    ChatModelMessage(text=message.text, role=message.role)
                    for message in prompt_template_entity.advanced_chat_prompt_template.messages
                ]

            prompt_transform = AdvancedPromptTransform()
            prompt_messages = prompt_transform.get_prompt(
                prompt_template=prompt_template,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory_config=memory_config,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(self, queue_manager: AppQueueManager,
                      app_generate_entity: EasyUIBasedAppGenerateEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Publish a fixed text as the model output.

        When streaming, the text is published one character per
        ``QueueLLMChunkEvent`` with a 10 ms pause between chunks. In all cases a
        final ``QueueMessageEndEvent`` carrying the whole text is published.

        :param queue_manager: application queue manager
        :param app_generate_entity: app generate entity
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage (empty usage when None)
        :return:
        """
        if stream:
            for index, token in enumerate(text):
                chunk = LLMResultChunk(
                    model=app_generate_entity.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                )

                queue_manager.publish(
                    QueueLLMChunkEvent(
                        chunk=chunk
                    ), PublishFrom.APPLICATION_MANAGER
                )
                time.sleep(0.01)

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=LLMResult(
                    model=app_generate_entity.model_config.model,
                    prompt_messages=prompt_messages,
                    message=AssistantPromptMessage(content=text),
                    usage=usage if usage else LLMUsage.empty_usage()
                ),
            ), PublishFrom.APPLICATION_MANAGER
        )

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: AppQueueManager,
                              stream: bool,
                              agent: bool = False) -> None:
        """
        Handle invoke result by dispatching to the direct or stream handler.

        :param invoke_result: invoke result (LLMResult when not streaming, generator of chunks otherwise)
        :param queue_manager: application queue manager
        :param stream: stream
        :param agent: publish chunks as agent message events instead of LLM chunk events
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                agent=agent
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                agent=agent
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: AppQueueManager,
                                     agent: bool) -> None:
        """
        Handle a non-streamed invoke result by publishing it as the message end event.

        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: accepted for signature parity with the stream handler; unused here
        :return:
        """
        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=invoke_result,
            ), PublishFrom.APPLICATION_MANAGER
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: AppQueueManager,
                                     agent: bool) -> None:
        """
        Handle a streamed invoke result.

        Each chunk is published as it arrives (agent message event when
        ``agent`` is True, LLM chunk event otherwise). The chunk texts are
        accumulated, the model/prompt messages are taken from the first chunk
        that has them, the usage from the first chunk that reports one, and a
        final ``QueueMessageEndEvent`` is published with the assembled result.

        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param agent: publish chunks as agent message events
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
        for result in invoke_result:
            if not agent:
                queue_manager.publish(
                    QueueLLMChunkEvent(
                        chunk=result
                    ), PublishFrom.APPLICATION_MANAGER
                )
            else:
                queue_manager.publish(
                    QueueAgentMessageEvent(
                        chunk=result
                    ), PublishFrom.APPLICATION_MANAGER
                )

            # guard against chunks whose message content is None
            text += result.delta.message.content or ''

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        if not usage:
            usage = LLMUsage.empty_usage()

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage
        )

        queue_manager.publish(
            QueueMessageEndEvent(
                llm_result=llm_result,
            ), PublishFrom.APPLICATION_MANAGER
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_generate_entity: AppGenerateEntity,
                              inputs: dict,
                              query: str) -> tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance via ``InputModeration.check``.

        :param app_id: app id
        :param tenant_id: tenant id
        :param app_generate_entity: app generate entity
        :param inputs: inputs
        :param query: query (falsy values are passed on as an empty string)
        :return: the result of ``InputModeration.check``
        """
        moderation_feature = InputModeration()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_config=app_generate_entity.app_config,
            inputs=inputs,
            query=query if query else ''
        )

    def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
                                 queue_manager: AppQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation.

        When the moderation check flags the prompt, a fixed apology is published
        as the output via ``direct_output``.

        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return: the moderation result (truthy when flagged)
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_generate_entity=application_generate_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, "
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetch()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply.

        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return: the matching annotation, or None
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )
|
|
|