diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 5017835565..e1c021a44a 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + def raise_for_status(self) -> None: + """ + Check model status and raise ValueError if not active. + + :raises ValueError: When model status is not active, with a descriptive message + """ + if self.status == ModelStatus.ACTIVE: + return + + error_messages = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + if self.status in error_messages: + raise ValueError(error_messages[self.status]) + class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 373ef2bbe2..568149cc37 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): deprecated: bool = False model_config = ConfigDict(protected_namespaces=()) + @property + def support_structure_output(self) -> bool: + return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features + class ParameterRule(BaseModel): """ diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 486b4b01af..36d0688807 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: dict | None = None - structured_output_enabled: bool = False + # We used 'structured_output_enabled' in the past, but it's not a good name. + structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") @field_validator("prompt_config", mode="before") @classmethod @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): if v is None: return PromptConfig() return v + + @property + def structured_output_enabled(self) -> bool: + return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index df8f614db3..2795fe05ad 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,9 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory @@ -74,7 +72,6 @@ from core.workflow.nodes.event import ( from core.workflow.utils.structured_output.entities import ( ResponseFormat, SpecialModelType, - SupportStructuredOutputStatus, ) from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]): llm_usage=usage, ) ) - except LLMNodeError as e: + except ValueError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = node_data_model.name - provider_name = node_data_model.provider + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=node_data_model.provider, + model=node_data_model.name, ) - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials + model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM + provider_model = model.provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, model_type=ModelType.LLM ) if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() # model config - completion_params = node_data_model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise LLMModeRequiredError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + if self.node_data.structured_output_enabled: + if model_schema.support_structure_output: + node_data_model.completion_params = self._handle_native_json_schema( + node_data_model.completion_params, model_schema.parameter_rules + ) + else: + # Set appropriate response format based on model capabilities + self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) + + return model, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, + mode=node_data_model.mode, + provider_model_bundle=model.provider_model_bundle, + credentials=model.credentials, + parameters=node_data_model.completion_params, stop=stop, ) @@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) - stop = model_config.stop - return filtered_prompt_messages, stop + + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=self.node_data.model.provider, + model=self.node_data.model.name, + ) + model_schema = model.model_type_instance.get_model_schema( + model=self.node_data.model.name, + credentials=model.credentials, + ) + if not model_schema: + raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") + if self.node_data.structured_output_enabled: + if not model_schema.support_structure_output: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) + return filtered_prompt_messages, model_config.stop def _parse_structured_output(self, result_text: str) -> dict[str, Any]: structured_output: dict[str, Any] = {} @@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") - def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: - """ - Check if the current model supports structured output. - - Returns: - SupportStructuredOutput: The support status of structured output - """ - # Early return if structured output is disabled - if ( - not isinstance(self.node_data, LLMNodeData) - or not self.node_data.structured_output_enabled - or not self.node_data.structured_output - ): - return SupportStructuredOutputStatus.DISABLED - # Get model schema and check if it exists - model_schema = self._fetch_model_schema(self.node_data.model.provider) - if not model_schema: - return SupportStructuredOutputStatus.DISABLED - - # Check if model supports structured output feature - return ( - SupportStructuredOutputStatus.SUPPORTED - if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) - else SupportStructuredOutputStatus.UNSUPPORTED - ) - def _save_multimodal_output_and_convert_result_to_markdown( self, contents: str | list[PromptMessageContentUnionTypes] | None, diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py index 7954acbaee..6491042bfe 100644 --- a/api/core/workflow/utils/structured_output/entities.py +++ b/api/core/workflow/utils/structured_output/entities.py @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): GEMINI = "gemini" OLLAMA = "ollama" - - -class SupportStructuredOutputStatus(StrEnum): - """Constants for structured output support status""" - - SUPPORTED = "supported" - UNSUPPORTED = "unsupported" - DISABLED = "disabled"