fix(agent): show agent run steps and fix style issues

pull/21899/head
baonudesifeizhai 11 months ago
parent 317d287458
commit 0057d5f72a

5
.gitignore vendored

@ -215,3 +215,8 @@ mise.toml
# AI Assistant # AI Assistant
.roo/ .roo/
api/.env.backup api/.env.backup
# custom untracked files
venv312/
web/.env.local.save
core/

@ -1,4 +1,5 @@
import json import json
import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
@ -102,14 +103,46 @@ class AgentNode(ToolNode):
try: try:
# convert tool messages # convert tool messages
agent_thoughts = []
from core.tools.entities.tool_entities import ToolInvokeMessage
thought_log_message = ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LOG,
message=ToolInvokeMessage.LogMessage(
id=str(uuid.uuid4()),
label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
parent_id=None,
error=None,
status=ToolInvokeMessage.LogMessage.LogStatus.START,
data={
"strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
"parameters": parameters_for_log,
"thought_process": "Agent strategy execution started",
},
metadata={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
},
),
)
from core.tools.entities.tool_entities import ToolInvokeMessage
def enhanced_message_stream():
yield thought_log_message
yield from message_stream
yield from self._transform_message( yield from self._transform_message(
message_stream, enhanced_message_stream(),
{ {
"icon": self.agent_strategy_icon, "icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
}, },
parameters_for_log, parameters_for_log,
agent_thoughts,
) )
except PluginDaemonClientSideError as e: except PluginDaemonClientSideError as e:
yield RunCompletedEvent( yield RunCompletedEvent(

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, Optional, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -191,6 +191,7 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None], messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any], tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any], parameters_for_log: dict[str, Any],
agent_thoughts: Optional[list] = None,
) -> Generator: ) -> Generator:
""" """
Convert ToolInvokeMessages into tuple[plain_text, files] Convert ToolInvokeMessages into tuple[plain_text, files]
@ -369,10 +370,41 @@ class ToolNode(BaseNode[ToolNodeData]):
yield agent_log yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output = json.copy()
if agent_logs:
if not json_output:
json_output = {}
elif isinstance(json_output, list) and len(json_output) == 1:
# If json is a list with only one element, convert it to a dictionary
json_output = json_output[0] if isinstance(json_output[0], dict) else {"data": json_output[0]}
elif isinstance(json_output, list):
# If json is a list with multiple elements, create a dictionary containing all data
json_output = {"data": json_output}
# Ensure json_output is a dictionary type
if not isinstance(json_output, dict):
json_output = {"data": json_output}
# Add agent_logs to json output
json_output["agent_logs"] = [
{
"id": log.id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
for log in agent_logs
]
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json, **variables}, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={ metadata={
**agent_execution_metadata, **agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,

@ -29,7 +29,7 @@ class EnterpriseService:
raise ValueError("No data found.") raise ValueError("No data found.")
try: try:
# parse the UTC timestamp from the response # parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00")) return datetime.fromisoformat(data)
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e raise ValueError(f"Invalid date format: {data}") from e
@ -40,7 +40,7 @@ class EnterpriseService:
raise ValueError("No data found.") raise ValueError("No data found.")
try: try:
# parse the UTC timestamp from the response # parse the UTC timestamp from the response
return datetime.fromisoformat(data.replace("Z", "+00:00")) return datetime.fromisoformat(data)
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e raise ValueError(f"Invalid date format: {data}") from e

@ -119,11 +119,11 @@ def test_execute_llm(flask_req_ctx):
mock_usage = LLMUsage( mock_usage = LLMUsage(
prompt_tokens=30, prompt_tokens=30,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"), prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"), prompt_price=Decimal("0.00003"),
completion_tokens=20, completion_tokens=20,
completion_unit_price=Decimal("0.002"), completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"), completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"), completion_price=Decimal("0.00004"),
total_tokens=50, total_tokens=50,
total_price=Decimal("0.00007"), total_price=Decimal("0.00007"),
@ -222,11 +222,11 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
mock_usage = LLMUsage( mock_usage = LLMUsage(
prompt_tokens=30, prompt_tokens=30,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"), prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"), prompt_price=Decimal("0.00003"),
completion_tokens=20, completion_tokens=20,
completion_unit_price=Decimal("0.002"), completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"), completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"), completion_price=Decimal("0.00004"),
total_tokens=50, total_tokens=50,
total_price=Decimal("0.00007"), total_price=Decimal("0.00007"),

@ -27,11 +27,11 @@ def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LL
return LLMUsage( return LLMUsage(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
prompt_unit_price=Decimal("0.001"), prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1"), prompt_price_unit=Decimal(1),
prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"), prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"),
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
completion_unit_price=Decimal("0.002"), completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1"), completion_price_unit=Decimal(1),
completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"), completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"),
total_tokens=prompt_tokens + completion_tokens, total_tokens=prompt_tokens + completion_tokens,
total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"), total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"),

@ -0,0 +1,19 @@
from typing import Optional, Any
from pydantic import BaseModel, Field, model_validator
from core.prompt.entities.role_prefix import RolePrefix
from core.prompt.entities.window import Window
class MemoryConfig(BaseModel):
role_prefix: RolePrefix = Field(default_factory=RolePrefix)
window: Window = Field(default_factory=Window)
memory_key: Optional[str] = Field(None)
# The `model_validate` method is used to create a `MemoryConfig` object from a dictionary.
@model_validator(mode="before")
@classmethod
def pre_validate(cls, values: Any) -> Any:
if "role_prefix" not in values:
values["role_prefix"] = {}
if "window" not in values:
values["window"] = {}
return values
Loading…
Cancel
Save