diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 196fa9512f..6254d81913 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -2,23 +2,25 @@ import logging from datetime import UTC, datetime from typing import Any -from flask import request -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse -from sqlalchemy import and_ -from werkzeug.exceptions import BadRequest, Forbidden, NotFound - from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource -from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check +from controllers.console.wraps import (account_initialization_required, + cloud_edition_billing_resource_check) from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields +from flask import request +from flask_login import current_user +from flask_restful import Resource, inputs, marshal_with, reqparse from libs.login import login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +from sqlalchemy import and_ +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +logger = logging.getLogger(__name__) class InstalledAppsListApi(Resource): @@ -65,7 +67,7 @@ class InstalledAppsListApi(Resource): ): res.append(installed_app) installed_app_list = res - logging.info(f"installed_app_list: {installed_app_list}, user_id: {user_id}") + logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}") installed_app_list.sort( key=lambda app: ( diff --git a/api/controllers/web/login.py 
b/api/controllers/web/login.py index 6af8d578c5..f743b31338 100644 --- a/api/controllers/web/login.py +++ b/api/controllers/web/login.py @@ -119,5 +119,3 @@ api.add_resource(LoginApi, "/login") # api.add_resource(LogoutApi, "/logout") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") -api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") -api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0674a82c56..5212d797d8 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -108,6 +108,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): # recalc llm max tokens prompt_messages = self._organize_prompt_messages() + self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks = model_instance.invoke_llm( prompt_messages=prompt_messages, diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index faf410cfff..611a55b30a 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -86,6 +86,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): # recalc llm max tokens prompt_messages = self._organize_prompt_messages() + self.recalc_llm_max_tokens(self.model_config, prompt_messages) # invoke model chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( prompt_messages=prompt_messages, diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index eb3c84208f..18ed115874 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -15,12 +15,14 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.external_data_tool.external_data_fetch import 
ExternalDataFetch
 from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     ImagePromptMessageContent,
     PromptMessage,
 )
+from core.model_runtime.entities.model_entities import ModelPropertyKey
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 from core.moderation.input_moderation import InputModeration
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
@@ -33,6 +35,113 @@ if TYPE_CHECKING:
 
 
 class AppRunner:
+    def get_pre_calculate_rest_tokens(
+        self,
+        app_record: App,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: Mapping[str, str],
+        files: Sequence["File"],
+        query: Optional[str] = None,
+    ) -> int:
+        """
+        Compute how many context tokens remain once the configured completion budget
+        (``max_tokens``) and the base prompt have been accounted for.
+
+        :param app_record: app record
+        :param model_config: model config entity
+        :param prompt_template_entity: prompt template entity
+        :param inputs: inputs
+        :param files: files
+        :param query: query
+        :return: remaining token count, or -1 when the model schema declares no context size
+        :raises InvokeBadRequestError: when the prompt plus ``max_tokens`` exceeds the context size
+        """
+        # The model instance is used here only to count prompt tokens.
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+        )
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens is None:
+            return -1
+
+        max_tokens = self._get_configured_max_tokens(model_config)
+
+        # get prompt messages without memory and context
+        prompt_messages, _ = self.organize_prompt_messages(
+            app_record=app_record,
+            model_config=model_config,
+            prompt_template_entity=prompt_template_entity,
+            inputs=inputs,
+            files=files,
+            query=query,
+        )
+
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+        rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
+        if rest_tokens < 0:
+            raise InvokeBadRequestError(
+                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                "or shrink the max token, or switch to a llm with a larger token limit size."
+            )
+
+        return rest_tokens
+
+    def recalc_llm_max_tokens(
+        self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
+    ) -> Optional[int]:
+        """
+        Shrink the configured ``max_tokens`` in place when prompt tokens plus ``max_tokens``
+        would exceed the model context size.
+
+        :param model_config: model config entity; ``model_config.parameters`` is updated in place
+        :param prompt_messages: prompt messages whose token count is measured
+        :return: -1 when the model schema declares no context size, otherwise None
+        """
+        model_instance = ModelInstance(
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
+        )
+
+        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+        if model_context_tokens is None:
+            # Sentinel kept so existing callers see the same return value as before.
+            return -1
+
+        max_tokens = self._get_configured_max_tokens(model_config)
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+        if prompt_tokens + max_tokens <= model_context_tokens:
+            return None
+
+        # Never go below 16 tokens, even when the prompt alone already fills the context size.
+        max_tokens = max(model_context_tokens - prompt_tokens, 16)
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if self._is_max_tokens_rule(parameter_rule):
+                model_config.parameters[parameter_rule.name] = max_tokens
+        return None
+
+    @staticmethod
+    def _is_max_tokens_rule(parameter_rule) -> bool:
+        """Return True when the rule is named ``max_tokens`` or uses the ``max_tokens`` template."""
+        return parameter_rule.name == "max_tokens" or parameter_rule.use_template == "max_tokens"
+
+    def _get_configured_max_tokens(self, model_config: ModelConfigWithCredentialsEntity) -> int:
+        """
+        Read the configured ``max_tokens`` value from the model parameters.
+
+        The value is looked up under the rule's own name first, then under its template name;
+        a missing or falsy value counts as 0. If several rules match, the last one wins.
+        """
+        max_tokens = 0
+        for parameter_rule in model_config.model_schema.parameter_rules:
+            if self._is_max_tokens_rule(parameter_rule):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(parameter_rule.use_template or "")
+                ) or 0
+        return max_tokens
+
     def organize_prompt_messages(
         self,
         app_record: App,
@@ -338,4 +447,4 @@ class AppRunner:
annotation_reply_feature = AnnotationReplyFeature() return annotation_reply_feature.query( app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from - ) + ) \ No newline at end of file diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 64724e9d71..69a66375db 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -194,6 +194,9 @@ class ChatAppRunner(AppRunner): if hosting_moderation_result: return + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) + # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, @@ -213,4 +216,4 @@ class ChatAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream - ) + ) \ No newline at end of file diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 8440686c9d..a5c393a27a 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -152,6 +152,9 @@ class CompletionAppRunner(AppRunner): if hosting_moderation_result: return + # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit + self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages) + # Invoke model model_instance = ModelInstance( provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, @@ -171,4 +174,4 @@ class CompletionAppRunner(AppRunner): # handle invoke result self._handle_invoke_result( invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream - ) + ) \ No newline at 
end of file diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 7cccc7120e..6230100c64 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -26,7 +26,7 @@ class TokenBufferMemory: self.model_instance = model_instance def get_history_prompt_messages( - self, max_token_limit: int = 100000, message_limit: Optional[int] = None + self, max_token_limit: int = 2000, message_limit: Optional[int] = None ) -> Sequence[PromptMessage]: """ Get history prompt messages. @@ -169,4 +169,4 @@ class TokenBufferMemory: message = f"{role}: {m.content}" string_messages.append(message) - return "\n".join(string_messages) + return "\n".join(string_messages) \ No newline at end of file diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index f34869cdef..335ba3f3fc 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -199,7 +199,7 @@ class CodeNode(BaseNode[CodeNodeData]): if output_config.type == "object": # check if output is object if not isinstance(result.get(output_name), dict): - if result.get(output_name) is None: + if result[output_name] is None: transformed_result[output_name] = None else: raise OutputValidationError( @@ -333,4 +333,4 @@ class CodeNode(BaseNode[CodeNodeData]): return { node_id + "." 
+ variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } + } \ No newline at end of file diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index c25aed72e8..57771f8f64 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1315,12 +1315,14 @@ def _handle_memory_chat_mode( *, memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, # TODO(-LAN-): Needs to remove + model_config: ModelConfigWithCredentialsEntity, ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, ) return memory_messages @@ -1428,4 +1430,4 @@ def convert_boolean_to_string(schema: dict) -> None: elif isinstance(value, list): for item in value: if isinstance(item, dict): - convert_boolean_to_string(item) + convert_boolean_to_string(item) \ No newline at end of file