From e2533f1e6b4f1ee1e37ce39beeb53abddcb702d7 Mon Sep 17 00:00:00 2001 From: baonudesifeizhai Date: Fri, 4 Jul 2025 14:36:20 -0400 Subject: [PATCH] fix(agent): show agent run steps, fixes #21718 --- api/core/workflow/nodes/agent/agent_node.py | 35 ++++++++++- api/core/workflow/nodes/tool/tool_node.py | 36 ++++++++++- api/services/enterprise/enterprise_service.py | 4 +- .../workflow/nodes/test_llm.py | 8 +-- .../entities/advanced_prompt_entities.py | 19 ++++++ web/.env.local.save | 60 +++++++++++++++++++ 6 files changed, 153 insertions(+), 9 deletions(-) create mode 100644 core/prompt/entities/advanced_prompt_entities.py create mode 100644 web/.env.local.save diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 2f28363955..066e52278c 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,4 +1,5 @@ import json +import uuid from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast @@ -102,14 +103,46 @@ class AgentNode(ToolNode): try: # convert tool messages + agent_thoughts = [] + + from core.tools.entities.tool_entities import ToolInvokeMessage + + thought_log_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LOG, + message=ToolInvokeMessage.LogMessage( + id=str(uuid.uuid4()), + label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}", + parent_id=None, + error=None, + status=ToolInvokeMessage.LogMessage.LogStatus.START, + data={ + "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "parameters": parameters_for_log, + "thought_process": "Agent strategy execution started", + }, + metadata={ + "icon": self.agent_strategy_icon, + "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + }, + ), + ) + + from core.tools.entities.tool_entities import ToolInvokeMessage + + def enhanced_message_stream(): + + yield thought_log_message + + yield 
from message_stream yield from self._transform_message( - message_stream, + enhanced_message_stream(), { "icon": self.agent_strategy_icon, "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, }, parameters_for_log, + agent_thoughts, ) except PluginDaemonClientSideError as e: yield RunCompletedEvent( diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index aa15d69931..df10e8c41d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -188,6 +188,7 @@ class ToolNode(BaseNode[ToolNodeData]): messages: Generator[ToolInvokeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], + agent_thoughts: Optional[list] = None, ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] @@ -365,10 +366,41 @@ class ToolNode(BaseNode[ToolNodeData]): yield agent_log + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output = json.copy() + if agent_logs: + if not json_output: + json_output = {} + elif isinstance(json_output, list) and len(json_output) == 1: + # If json is a list with only one element, convert it to a dictionary + json_output = json_output[0] if isinstance(json_output[0], dict) else {"data": json_output[0]} + elif isinstance(json_output, list): + # If json is a list with multiple elements, create a dictionary containing all data + json_output = {"data": json_output} + + # Ensure json_output is a dictionary type + if not isinstance(json_output, dict): + json_output = {"data": json_output} + + # Add agent_logs to json output + json_output["agent_logs"] = [ + { + "id": log.id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": 
log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + for log in agent_logs + ] + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, metadata={ **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 8c06ee9386..54d45f45ea 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -29,7 +29,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e @@ -40,7 +40,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 389d1071f3..fe1d357d44 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -119,11 +119,11 @@ def test_execute_llm(flask_req_ctx): mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), + prompt_price_unit=Decimal(1000), prompt_price=Decimal("0.00003"), completion_tokens=20, completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), + 
completion_price_unit=Decimal(1000), completion_price=Decimal("0.00004"), total_tokens=50, total_price=Decimal("0.00007"), @@ -222,11 +222,11 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), + prompt_price_unit=Decimal(1000), prompt_price=Decimal("0.00003"), completion_tokens=20, completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), + completion_price_unit=Decimal(1000), completion_price=Decimal("0.00004"), total_tokens=50, total_price=Decimal("0.00007"), diff --git a/core/prompt/entities/advanced_prompt_entities.py b/core/prompt/entities/advanced_prompt_entities.py new file mode 100644 index 0000000000..2c56b82465 --- /dev/null +++ b/core/prompt/entities/advanced_prompt_entities.py @@ -0,0 +1,19 @@ +from typing import Optional, Any +from pydantic import BaseModel, Field, model_validator +from core.prompt.entities.role_prefix import RolePrefix +from core.prompt.entities.window import Window + +class MemoryConfig(BaseModel): + role_prefix: RolePrefix = Field(default_factory=RolePrefix) + window: Window = Field(default_factory=Window) + memory_key: Optional[str] = Field(None) + + # Before-mode validator: inserts an empty `role_prefix` / `window` mapping when the input omits that key.
+ @model_validator(mode="before") + @classmethod + def pre_validate(cls, values: Any) -> Any: + if "role_prefix" not in values: + values["role_prefix"] = {} + if "window" not in values: + values["window"] = {} + return values \ No newline at end of file diff --git a/web/.env.local.save b/web/.env.local.save new file mode 100644 index 0000000000..a02d6cd519 --- /dev/null +++ b/web/.env.local.save @@ -0,0 +1,60 @@ +# For production release, change this to PRODUCTION +NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT +# The deployment edition, SELF_HOSTED +NEXT_PUBLIC_EDITION=SELF_HOSTED +# The base URL of console application, refers to the Console base URL of WEB service if console domain is +# different from api or web app domain. +# example: http://cloud.dify.ai/console/api +NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api +# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from +# console or api domain. +# example: http://udify.app/api +NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api +# The API PREFIX for MARKETPLACE +NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 +# The URL for MARKETPLACE +NEXT_PUBLIC_MARKETPLACE_URL_PREFIX=https://marketplace.dify.ai + +# SENTRY +NEXT_PUBLIC_SENTRY_DSN= + +# Disable Next.js Telemetry (https://nextjs.org/telemetry) +NEXT_TELEMETRY_DISABLED=1 + +# Disable Upload Image as WebApp icon; default is false +NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false + +# The timeout for the text generation in milliseconds +NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 + +# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP +NEXT_PUBLIC_CSP_WHITELIST= +# By default, embedding into an iframe is not allowed, to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking +NEXT_PUBLIC_ALLOW_EMBED= + +# Github Access Token, used for invoking Github API +NEXT_PUBLIC_GITHUB_ACCESS_TOKEN= +# The maximum number of top-k value for RAG.
+NEXT_PUBLIC_TOP_K_MAX_VALUE=10 + +# The maximum number of tokens for segmentation +NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 + +# Maximum loop count in the workflow +NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100 + +# Maximum number of tools in the agent/workflow +NEXT_PUBLIC_MAX_TOOLS_NUM=10 + +# Maximum number of Parallelism branches in the workflow +NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10 + +# The maximum number of iterations for agent setting +NEXT_PUBLIC_MAX_ITERATIONS_NUM=99 + +NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true +NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true +NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true + +# The maximum number of tree node depth for workflow +NEXT_PUBLIC_MAX_TREE_DEPTH=50