fix: align core folder with main

pull/19726/head
GareArc 1 year ago
parent 70c826cae3
commit ffcb561b6f
No known key found for this signature in database

@ -2,23 +2,25 @@ import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from flask import request
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import (account_initialization_required,
cloud_edition_billing_resource_check)
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from flask import request
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from libs.login import login_required from libs.login import login_required
from models import App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
from services.app_service import AppService from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService from services.feature_service import FeatureService
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
logger = logging.getLogger(__name__)
class InstalledAppsListApi(Resource): class InstalledAppsListApi(Resource):
@ -65,7 +67,7 @@ class InstalledAppsListApi(Resource):
): ):
res.append(installed_app) res.append(installed_app)
installed_app_list = res installed_app_list = res
logging.info(f"installed_app_list: {installed_app_list}, user_id: {user_id}") logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}")
installed_app_list.sort( installed_app_list.sort(
key=lambda app: ( key=lambda app: (

@ -119,5 +119,3 @@ api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout") # api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

@ -108,6 +108,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# recalc llm max tokens # recalc llm max tokens
prompt_messages = self._organize_prompt_messages() prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model # invoke model
chunks = model_instance.invoke_llm( chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,

@ -86,6 +86,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# recalc llm max tokens # recalc llm max tokens
prompt_messages = self._organize_prompt_messages() prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model # invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,

@ -15,12 +15,14 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
@ -33,6 +35,106 @@ if TYPE_CHECKING:
class AppRunner: class AppRunner:
def get_pre_calculate_rest_tokens(
    self,
    app_record: App,
    model_config: ModelConfigWithCredentialsEntity,
    prompt_template_entity: PromptTemplateEntity,
    inputs: Mapping[str, str],
    files: Sequence["File"],
    query: Optional[str] = None,
) -> int:
    """
    Pre-calculate how many context tokens remain after the base prompt.

    The prompt is organized without memory and context, its token count is
    taken from the model instance, and that count plus the configured
    ``max_tokens`` is subtracted from the model's context size.

    :param app_record: app record
    :param model_config: model config entity
    :param prompt_template_entity: prompt template entity
    :param inputs: inputs
    :param files: files
    :param query: query
    :return: remaining token budget, or -1 when the model schema declares no
        context size
    :raises InvokeBadRequestError: when the prompt tokens plus ``max_tokens``
        exceed the model's context size
    """
    model_instance = ModelInstance(
        provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
    )

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

    # The limit may be configured under the rule's own name or under the
    # template it uses; a missing or falsy value counts as 0.
    max_tokens = 0
    for parameter_rule in model_config.model_schema.parameter_rules:
        if parameter_rule.name == "max_tokens" or parameter_rule.use_template == "max_tokens":
            max_tokens = (
                model_config.parameters.get(parameter_rule.name)
                or model_config.parameters.get(parameter_rule.use_template or "")
                or 0
            )

    # Without a declared context size there is nothing to subtract from.
    if model_context_tokens is None:
        return -1

    # get prompt messages without memory and context
    prompt_messages, _ = self.organize_prompt_messages(
        app_record=app_record,
        model_config=model_config,
        prompt_template_entity=prompt_template_entity,
        inputs=inputs,
        files=files,
        query=query,
    )

    prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

    rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
    if rest_tokens < 0:
        raise InvokeBadRequestError(
            "Query or prefix prompt is too long, you can reduce the prefix prompt, "
            "or shrink the max token, or switch to a llm with a larger token limit size."
        )

    return rest_tokens
def recalc_llm_max_tokens(
    self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
) -> Optional[int]:
    """
    Shrink the configured ``max_tokens`` when prompt + completion would overflow.

    If the prompt token count plus the configured ``max_tokens`` exceeds the
    model's context size, ``max_tokens`` is lowered to the remaining room
    (never below 16) and written back into ``model_config.parameters`` under
    the name of every matching parameter rule. ``model_config`` is mutated in
    place.

    :param model_config: model config entity whose parameters may be adjusted
    :param prompt_messages: prompt messages that will be sent to the model
    :return: -1 when the model schema declares no context size, otherwise None
    """
    model_instance = ModelInstance(
        provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
    )

    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

    # Single pass: remember which rules carry the limit so the write-back
    # below does not have to re-evaluate the same predicate.
    max_tokens = 0
    max_token_rule_names = []
    for parameter_rule in model_config.model_schema.parameter_rules:
        if parameter_rule.name == "max_tokens" or parameter_rule.use_template == "max_tokens":
            max_token_rule_names.append(parameter_rule.name)
            max_tokens = (
                model_config.parameters.get(parameter_rule.name)
                or model_config.parameters.get(parameter_rule.use_template or "")
                or 0
            )

    # Kept as -1 (not None) so existing callers that inspect the result still work.
    if model_context_tokens is None:
        return -1

    prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

    if prompt_tokens + max_tokens > model_context_tokens:
        # Leave whatever room remains, but never go below 16 tokens.
        max_tokens = max(model_context_tokens - prompt_tokens, 16)

        for rule_name in max_token_rule_names:
            model_config.parameters[rule_name] = max_tokens

    return None
def organize_prompt_messages( def organize_prompt_messages(
self, self,
app_record: App, app_record: App,
@ -338,4 +440,4 @@ class AppRunner:
annotation_reply_feature = AnnotationReplyFeature() annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query( return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
) )

@ -194,6 +194,9 @@ class ChatAppRunner(AppRunner):
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
@ -213,4 +216,4 @@ class ChatAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
) )

@ -152,6 +152,9 @@ class CompletionAppRunner(AppRunner):
if hosting_moderation_result: if hosting_moderation_result:
return return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model # Invoke model
model_instance = ModelInstance( model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle, provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
@ -171,4 +174,4 @@ class CompletionAppRunner(AppRunner):
# handle invoke result # handle invoke result
self._handle_invoke_result( self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
) )

@ -26,7 +26,7 @@ class TokenBufferMemory:
self.model_instance = model_instance self.model_instance = model_instance
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 100000, message_limit: Optional[int] = None self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> Sequence[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
@ -169,4 +169,4 @@ class TokenBufferMemory:
message = f"{role}: {m.content}" message = f"{role}: {m.content}"
string_messages.append(message) string_messages.append(message)
return "\n".join(string_messages) return "\n".join(string_messages)

@ -199,7 +199,7 @@ class CodeNode(BaseNode[CodeNodeData]):
if output_config.type == "object": if output_config.type == "object":
# check if output is object # check if output is object
if not isinstance(result.get(output_name), dict): if not isinstance(result.get(output_name), dict):
if result.get(output_name) is None: if result[output_name] is None:
transformed_result[output_name] = None transformed_result[output_name] = None
else: else:
raise OutputValidationError( raise OutputValidationError(
@ -333,4 +333,4 @@ class CodeNode(BaseNode[CodeNodeData]):
return { return {
node_id + "." + variable_selector.variable: variable_selector.value_selector node_id + "." + variable_selector.variable: variable_selector.value_selector
for variable_selector in node_data.variables for variable_selector in node_data.variables
} }

@ -1315,12 +1315,14 @@ def _handle_memory_chat_mode(
*, *,
memory: TokenBufferMemory | None, memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None, memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity, # TODO(-LAN-): Needs to remove model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]: ) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = [] memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model # Get messages from memory for chat model
if memory and memory_config: if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages( memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None, message_limit=memory_config.window.size if memory_config.window.enabled else None,
) )
return memory_messages return memory_messages
@ -1428,4 +1430,4 @@ def convert_boolean_to_string(schema: dict) -> None:
elif isinstance(value, list): elif isinstance(value, list):
for item in value: for item in value:
if isinstance(item, dict): if isinstance(item, dict):
convert_boolean_to_string(item) convert_boolean_to_string(item)
Loading…
Cancel
Save