Allow Variable Aggregator to aggregate all

pull/22200/head
lizb 11 months ago
parent f929bfb94c
commit ef1ef08780

@ -1,6 +1,6 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -8,8 +8,12 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): class VariableAggregatorNodeData(VariableAssignerNodeData):
_node_data_cls = VariableAssignerNodeData aggregate_all: bool = False
class VariableAggregatorNode(BaseNode[VariableAggregatorNodeData]):
_node_data_cls = VariableAggregatorNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR _node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod @classmethod
@ -18,25 +22,50 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get variables # Get variables
outputs: dict[str, Segment | Mapping[str, Segment]] = {} outputs: dict[str, Any] = {}
inputs = {} inputs: dict[str, Any] = {}
# if aggregate_all is not configured, aggregate only first variables
if not getattr(self.node_data, "aggregate_all", False):
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: inputs = {".".join(selector[1:]): variable.to_object()}
for selector in self.node_data.variables: break
variable = self.graph_runtime_state.variable_pool.get(selector) else:
if variable is not None: for group in self.node_data.advanced_settings.groups:
outputs = {"output": variable} for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
inputs = {".".join(selector[1:]): variable.to_object()} if variable is not None:
break outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break
else: else:
for group in self.node_data.advanced_settings.groups: # if aggregate_all is configured, aggregate all variables
for selector in group.variables: if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
aggregated_values = []
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector) variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None: if variable is not None:
outputs[group.group_name] = {"output": variable} aggregated_values.append(variable.to_object())
inputs[".".join(selector[1:])] = variable.to_object() inputs[".".join(selector[1:])] = variable.to_object()
break
if aggregated_values:
outputs = {"output": aggregated_values}
else:
for group in self.node_data.advanced_settings.groups:
aggregated_values = []
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
aggregated_values.append(variable.to_object())
inputs[".".join(selector[1:])] = variable.to_object()
if aggregated_values:
outputs[group.group_name] = {"output": aggregated_values}
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)

@ -35,6 +35,7 @@ const Panel: FC<NodePanelProps<VariableAssignerNodeType>> = ({
onRemoveVarConfirm, onRemoveVarConfirm,
getAvailableVars, getAvailableVars,
filterVar, filterVar,
handleAggregateAllChange,
} = useConfig(id, data) } = useConfig(id, data)
return ( return (
@ -95,26 +96,41 @@ const Panel: FC<NodePanelProps<VariableAssignerNodeType>> = ({
/> />
} }
/> />
<Field
title={t(`${i18nPrefix}.aggregateAll`)}
tooltip={t(`${i18nPrefix}.aggregateAllTip`)!}
operations={
<Switch
defaultValue={inputs.aggregate_all}
onChange={handleAggregateAllChange}
size='md'
disabled={readOnly}
/>
}
/>
</div> </div>
{isEnableGroup && ( <Split />
<OutputVars>
<> <>
<Split /> {isEnableGroup
<OutputVars> ? inputs.advanced_settings?.groups.map((item, index) => (
<> <VarItem
{inputs.advanced_settings?.groups.map((item, index) => ( key={index}
<VarItem name={`${item.group_name}.output`}
key={index} type={item.output_type}
name={`${item.group_name}.output`} description={t(`${i18nPrefix}.outputVars.varDescribe`, {
type={item.output_type} groupName: item.group_name,
description={t(`${i18nPrefix}.outputVars.varDescribe`, { })}
groupName: item.group_name, />
})} ))
/> : (
))} <VarItem
</> name='output'
</OutputVars> type={inputs.output_type}
/>
)}
</> </>
)} </OutputVars>
<RemoveEffectVarConfirm <RemoveEffectVarConfirm
isShow={isShowRemoveVarConfirm} isShow={isShowRemoveVarConfirm}
onCancel={hideRemoveVarConfirm} onCancel={hideRemoveVarConfirm}

@ -5,6 +5,7 @@ export type VarGroupItem = {
variables: ValueSelector[] variables: ValueSelector[]
} }
export type VariableAssignerNodeType = CommonNodeType & VarGroupItem & { export type VariableAssignerNodeType = CommonNodeType & VarGroupItem & {
aggregate_all?: boolean
advanced_settings: { advanced_settings: {
group_enabled: boolean group_enabled: boolean
groups: ({ groups: ({

@ -33,6 +33,13 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
}) })
}, [inputs, setInputs]) }, [inputs, setInputs])
const handleAggregateAllChange = useCallback((checked: boolean) => {
const newInputs = produce(inputs, (draft) => {
draft.aggregate_all = checked
})
setInputs(newInputs)
}, [inputs, setInputs])
const handleListOrTypeChangeInGroup = useCallback((groupId: string) => { const handleListOrTypeChangeInGroup = useCallback((groupId: string) => {
return (payload: VarGroupItem) => { return (payload: VarGroupItem) => {
const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId) const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId)
@ -208,6 +215,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
onRemoveVarConfirm, onRemoveVarConfirm,
getAvailableVars, getAvailableVars,
filterVar, filterVar,
handleAggregateAllChange,
} }
} }

@ -626,6 +626,8 @@ const translation = {
}, },
aggregationGroup: 'Aggregation Group', aggregationGroup: 'Aggregation Group',
aggregationGroupTip: 'Enabling this feature allows the variable aggregator to aggregate multiple sets of variables.', aggregationGroupTip: 'Enabling this feature allows the variable aggregator to aggregate multiple sets of variables.',
aggregateAll: 'Aggregate All',
aggregateAllTip: 'Enabling this feature allows the variable aggregator to aggregate all input variables into an array, otherwise it will aggregate first arrival variables into the array.',
addGroup: 'Add Group', addGroup: 'Add Group',
outputVars: { outputVars: {
varDescribe: '{{groupName}} output', varDescribe: '{{groupName}} output',

@ -627,6 +627,8 @@ const translation = {
}, },
aggregationGroup: '聚合分组', aggregationGroup: '聚合分组',
aggregationGroupTip: '开启该功能后,变量聚合器内可以同时聚合多组变量', aggregationGroupTip: '开启该功能后,变量聚合器内可以同时聚合多组变量',
aggregateAll: '聚合所有',
aggregateAllTip: '开启该功能后,变量聚合器将聚合所有输入变量到一个数组中,否则只会聚合第一个变量',
addGroup: '添加分组', addGroup: '添加分组',
outputVars: { outputVars: {
varDescribe: '{{groupName}}的输出变量', varDescribe: '{{groupName}}的输出变量',

Loading…
Cancel
Save