fix(agent): show agent run steps, fixes #21718

pull/21940/head
baonudesifeizhai 11 months ago
parent 75f232d832
commit e2533f1e6b

@ -1,4 +1,5 @@
import json
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
@ -102,14 +103,46 @@ class AgentNode(ToolNode):
try:
# convert tool messages
agent_thoughts = []
from core.tools.entities.tool_entities import ToolInvokeMessage
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
from core.tools.entities.tool_entities import ToolInvokeMessage
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message(
message_stream,
enhanced_message_stream(),
{
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
parameters_for_log,
agent_thoughts,
)
except PluginDaemonClientSideError as e:
yield RunCompletedEvent(

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -188,6 +188,7 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@ -365,10 +366,41 @@ class ToolNode(BaseNode[ToolNodeData]):
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output = json.copy()
if agent_logs:
if not json_output:
json_output = {}
elif isinstance(json_output, list) and len(json_output) == 1:
# If json is a list with only one element, convert it to a dictionary
json_output = json_output[0] if isinstance(json_output[0], dict) else {"data": json_output[0]}
elif isinstance(json_output, list):
# If json is a list with multiple elements, create a dictionary containing all data
json_output = {"data": json_output}
# Ensure json_output is a dictionary type
if not isinstance(json_output, dict):
json_output = {"data": json_output}
# Add agent_logs to json output
json_output["agent_logs"] = [
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
for log in agent_logs
]
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables},
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

@ -29,7 +29,7 @@ class EnterpriseService:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
@ -40,7 +40,7 @@ class EnterpriseService:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00"))
return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e

@ -119,11 +119,11 @@ def test_execute_llm(flask_req_ctx):
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
@ -222,11 +222,11 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),

@ -0,0 +1,19 @@
from typing import Optional, Any
from pydantic import BaseModel, Field, model_validator
from core.prompt.entities.role_prefix import RolePrefix
from core.prompt.entities.window import Window
class MemoryConfig(BaseModel):
role_prefix: RolePrefix = Field(default_factory=RolePrefix)
window: Window = Field(default_factory=Window)
memory_key: Optional[str] = Field(None)
# The `model_validate` method is used to create a `MemoryConfig` object from a dictionary.
@model_validator(mode="before")
@classmethod
def pre_validate(cls, values: Any) -> Any:
if "role_prefix" not in values:
values["role_prefix"] = {}
if "window" not in values:
values["window"] = {}
return values

@ -0,0 +1,60 @@
\# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED
NEXT_PUBLIC_EDITION=SELF_HOSTED
# The base URL of console application, refers to the Console base URL of WEB service if console domain is
# different from api or web app domain.
# example: http://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# The API PREFIX for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1
# The URL for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_URL_PREFIX=https://marketplace.dify.ai
# SENTRY
NEXT_PUBLIC_SENTRY_DSN=
# Disable Next.js Telemetry (https://nextjs.org/telemetry)
NEXT_TELEMETRY_DISABLED=1
# Disable Upload Image as WebApp icon default is false
NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false
# The timeout for the text generation in millisecond
NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000
# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP
NEXT_PUBLIC_CSP_WHITELIST=
# Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking
NEXT_PUBLIC_ALLOW_EMBED=
# Github Access Token, used for invoking Github API
NEXT_PUBLIC_GITHUB_ACCESS_TOKEN=
# The maximum number of top-k value for RAG.
NEXT_PUBLIC_TOP_K_MAX_VALUE=10
# The maximum number of tokens for segmentation
NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
# Maximum loop count in the workflow
NEXT_PUBLIC_LOOP_NODE_MAX_COUNT=100
# Maximum number of tools in the agent/workflow
NEXT_PUBLIC_MAX_TOOLS_NUM=10
# Maximum number of Parallelism branches in the workflow
NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10
# The maximum number of iterations for agent setting
NEXT_PUBLIC_MAX_ITERATIONS_NUM=99
NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true
# The maximum number of tree node depth for workflow
NEXT_PUBLIC_MAX_TREE_DEPTH=50
Loading…
Cancel
Save