From e7bea600c6fa83cd183ad3b9119c08588c5cbde5 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Fri, 18 Jul 2025 03:08:42 +0800 Subject: [PATCH] refactor: Simplifies ModelMode instantiation Signed-off-by: -LAN- --- api/core/app/apps/base_app_runner.py | 2 +- api/core/prompt/simple_prompt_transform.py | 15 +-------------- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../knowledge_retrieval_node.py | 11 +++-------- .../parameter_extractor_node.py | 6 +++--- .../question_classifier_node.py | 2 +- 6 files changed, 10 insertions(+), 28 deletions(-) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 428db607fa..6e8c261a6a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -118,7 +118,7 @@ class AppRunner: else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 47808928f7..e19c6419ca 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" - @classmethod - def value_of(cls, value: str) -> "ModelMode": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - prompt_file_contents: dict[str, Any] = {} @@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 5c0360b064..3d0f0f97bc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1137,7 +1137,7 @@ class DatasetRetrieval: def _get_prompt_template( self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str ): - model_mode = ModelMode.value_of(mode) + model_mode = ModelMode(mode) input_text = query prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0324be8b2a..3b96cecf81 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -510,13 +510,8 @@ class KnowledgeRetrievalNode(BaseNode): # get all metadata field metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] - # get metadata model config - metadata_model_config = node_data.metadata_model_config - if metadata_model_config is None: - raise ValueError("metadata_model_config is required") - # get metadata model instance - # fetch model config - model_instance, model_config = self.get_model_config(metadata_model_config) + # get metadata model instance and fetch model config + model_instance, model_config = self.get_model_config(node_data.metadata_model_config) # fetch prompt messages prompt_template = self._get_prompt_template( node_data=node_data, @@ -707,7 +702,7 @@ class KnowledgeRetrievalNode(BaseNode): ) def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore + model_mode = ModelMode(node_data.metadata_model_config.mode) input_text = query prompt_messages: list[LLMNodeChatModelMessage] = [] diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 7352a0c136..a23d284626 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -420,7 +420,7 @@ class ParameterExtractorNode(BaseNode): """ Generate prompt engineering prompt. """ - model_mode = ModelMode.value_of(data.model.mode) + model_mode = ModelMode(data.model.mode) if model_mode == ModelMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( @@ -716,7 +716,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ) -> list[ChatModelMessage]: - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -743,7 +743,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 84983f8ad8..15012fa48d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -334,7 +334,7 @@ class QuestionClassifierNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: