fix(agent): show agent run steps, fixes #21718

pull/21940/head
baonudesifeizhai 11 months ago
parent 75f232d832
commit e2533f1e6b

@ -1,4 +1,5 @@
import json import json
import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -102,14 +103,46 @@ class AgentNode(ToolNode):
try: try:
# convert tool messages # convert tool messages
agent_thoughts = []
from core.tools.entities.tool_entities import ToolInvokeMessage
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
from core.tools.entities.tool_entities import ToolInvokeMessage
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message( yield from self._transform_message(
message_stream, enhanced_message_stream(),
{ {
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
}, },
parameters_for_log, parameters_for_log,
agent_thoughts,
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
yield RunCompletedEvent( yield RunCompletedEvent(

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -188,6 +188,7 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None], messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any], tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any], parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
) -> Generator: ) -> Generator:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
@ -365,10 +366,41 @@ class ToolNode(BaseNode[ToolNodeData]):
yield agent_log yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output = json.copy()
if agent_logs:
if not json_output:
json_output = {}
elif isinstance(json_output, list) and len(json_output) == 1:
# If json is a list with only one element, convert it to a dictionary
json_output = json_output[0] if isinstance(json_output[0], dict) else {"data": json_output[0]}
elif isinstance(json_output, list):
# If json is a list with multiple elements, create a dictionary containing all data
json_output = {"data": json_output}
# Ensure json_output is a dictionary type
if not isinstance(json_output, dict):
json_output = {"data": json_output}
# Add agent_logs to json output
json_output["agent_logs"] = [
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
for log in agent_logs
]
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata={
**agent_execution_metadata, **agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

@ -29,7 +29,7 @@ class EnterpriseService:
raise ValueError("No data found.") raise ValueError("No data found.")
try: try:
# parse the UTC timestamp from the response # parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00")) return datetime.fromisoformat(data)
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e raise ValueError(f"Invalid date format: {data}") from e
@ -40,7 +40,7 @@ class EnterpriseService:
raise ValueError("No data found.") raise ValueError("No data found.")
try: try:
# parse the UTC timestamp from the response # parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00")) return datetime.fromisoformat(data)
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e raise ValueError(f"Invalid date format: {data}") from e

@ -119,11 +119,11 @@ def test_execute_llm(flask_req_ctx):
mock_usage = LLMUsage( mock_usage = LLMUsage(
prompt_tokens=30, prompt_tokens=30,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"), prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"), prompt_price=Decimal("0.00003"),
completion_tokens=20, completion_tokens=20,
completion_unit_price=Decimal("0.002"), completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"), completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"), completion_price=Decimal("0.00004"),
total_tokens=50, total_tokens=50,
total_price=Decimal("0.00007"), total_price=Decimal("0.00007"),
@ -222,11 +222,11 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
mock_usage = LLMUsage( mock_usage = LLMUsage(
prompt_tokens=30, prompt_tokens=30,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"), prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"), prompt_price=Decimal("0.00003"),
completion_tokens=20, completion_tokens=20,
completion_unit_price=Decimal("0.002"), completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"), completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"), completion_price=Decimal("0.00004"),
total_tokens=50, total_tokens=50,
total_price=Decimal("0.00007"), total_price=Decimal("0.00007"),

@ -0,0 +1,19 @@
from typing import Optional, Any
from pydantic import BaseModel, Field, model_validator
from core.prompt.entities.role_prefix import RolePrefix
from core.prompt.entities.window import Window
class MemoryConfig(BaseModel):
role_prefix: RolePrefix = Field(default_factory=RolePrefix)
window: Window = Field(default_factory=Window)
memory_key: Optional[str] = Field(None)
# The `model_validate` method is used to create a `MemoryConfig` object from a dictionary.
@model_validator(mode="before")
@classmethod
def pre_validate(cls, values: Any) -> Any:
if "role_prefix" not in values:
values["role_prefix"] = {}
if "window" not in values:
values["window"] = {}
return values

@ -0,0 +1,60 @@
\# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED
NEXT_PUBLIC_EDITION=SELF_HOSTED
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
# example: http://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# The API PREFIX for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1
# The URL for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_URL_PREFIX=https://marketplace.dify.ai
# SENTRY
NEXT_PUBLIC_SENTRY_DSN=
# Disable Next.js Telemetry (https://nextjs.org/telemetry)
NEXT_TELEMETRY_DISABLED=1
# Disable Upload Image as WebApp icon default is false
NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
# The timeout for the text generation in millisecond
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=
# Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking
NEXT_PUBLIC_ALLOW_EMBED=
# Github Access Token, used for invoking Github API
NEXT_PUBLIC_GITHUB_ACCESS_TOKEN=
# The maximum number of top-k value for RAG.
NEXT_PUBLIC_TOP_K_MAX_VALUE=10
# The maximum number of tokens for segmentation
NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Maximum loop count in the workflow
NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100
# Maximum number of tools in the agent/workflow
NEXT_PUBLIC_MAX_TOOLS_NUM=10
# Maximum number of Parallelism branches in the workflow
NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10
# The maximum number of iterations for agent setting
NEXT_PUBLIC_MAX_ITERATIONS_NUM=99
NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true
# The maximum number of tree node depth for workflow
NEXT_PUBLIC_MAX_TREE_DEPTH=50
Loading…
Cancel
Save