Allow Variable Aggregator to aggregate all

pull/22200/head
lizb 11 months ago
parent f929bfb94c
commit ef1ef08780

@ -1,6 +1,6 @@
from collections.abc import Mapping
from typing import Any
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
@ -8,8 +8,12 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
_node_data_cls = VariableAssignerNodeData
class VariableAggregatorNodeData(VariableAssignerNodeData):
aggregate_all: bool = False
class VariableAggregatorNode(BaseNode[VariableAggregatorNodeData]):
_node_data_cls = VariableAggregatorNodeData
_node_type = NodeType.VARIABLE_AGGREGATOR
@classmethod
@ -18,25 +22,50 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
def _run(self) -> NodeRunResult:
# Get variables
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
outputs: dict[str, Any] = {}
inputs: dict[str, Any] = {}
# if aggregate_all is not configured, aggregate only first variables
if not getattr(self.node_data, "aggregate_all", False):
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
inputs = {".".join(selector[1:]): variable.to_object()}
break
if variable is not None:
outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break
else:
for group in self.node_data.advanced_settings.groups:
for selector in group.variables:
# if aggregate_all is configured, aggregate all variables
if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
aggregated_values = []
for selector in self.node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs[group.group_name] = {"output": variable}
aggregated_values.append(variable.to_object())
inputs[".".join(selector[1:])] = variable.to_object()
break
if aggregated_values:
outputs = {"output": aggregated_values}
else:
for group in self.node_data.advanced_settings.groups:
aggregated_values = []
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
aggregated_values.append(variable.to_object())
inputs[".".join(selector[1:])] = variable.to_object()
if aggregated_values:
outputs[group.group_name] = {"output": aggregated_values}
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs)

@ -35,6 +35,7 @@ const Panel: FC<NodePanelProps<VariableAssignerNodeType>> = ({
onRemoveVarConfirm,
getAvailableVars,
filterVar,
handleAggregateAllChange,
} = useConfig(id, data)
return (
@ -95,26 +96,41 @@ const Panel: FC<NodePanelProps<VariableAssignerNodeType>> = ({
/>
}
/>
<Field
title={t(`${i18nPrefix}.aggregateAll`)}
tooltip={t(`${i18nPrefix}.aggregateAllTip`)!}
operations={
<Switch
defaultValue={inputs.aggregate_all}
onChange={handleAggregateAllChange}
size='md'
disabled={readOnly}
/>
}
/>
</div>
{isEnableGroup && (
<Split />
<OutputVars>
<>
<Split />
<OutputVars>
<>
{inputs.advanced_settings?.groups.map((item, index) => (
<VarItem
key={index}
name={`${item.group_name}.output`}
type={item.output_type}
description={t(`${i18nPrefix}.outputVars.varDescribe`, {
groupName: item.group_name,
})}
/>
))}
</>
</OutputVars>
{isEnableGroup
? inputs.advanced_settings?.groups.map((item, index) => (
<VarItem
key={index}
name={`${item.group_name}.output`}
type={item.output_type}
description={t(`${i18nPrefix}.outputVars.varDescribe`, {
groupName: item.group_name,
})}
/>
))
: (
<VarItem
name='output'
type={inputs.output_type}
/>
)}
</>
)}
</OutputVars>
<RemoveEffectVarConfirm
isShow={isShowRemoveVarConfirm}
onCancel={hideRemoveVarConfirm}

@ -5,6 +5,7 @@ export type VarGroupItem = {
variables: ValueSelector[]
}
export type VariableAssignerNodeType = CommonNodeType & VarGroupItem & {
aggregate_all?: boolean
advanced_settings: {
group_enabled: boolean
groups: ({

@ -33,6 +33,13 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
})
}, [inputs, setInputs])
const handleAggregateAllChange = useCallback((checked: boolean) => {
const newInputs = produce(inputs, (draft) => {
draft.aggregate_all = checked
})
setInputs(newInputs)
}, [inputs, setInputs])
const handleListOrTypeChangeInGroup = useCallback((groupId: string) => {
return (payload: VarGroupItem) => {
const index = inputs.advanced_settings.groups.findIndex(item => item.groupId === groupId)
@ -208,6 +215,7 @@ const useConfig = (id: string, payload: VariableAssignerNodeType) => {
onRemoveVarConfirm,
getAvailableVars,
filterVar,
handleAggregateAllChange,
}
}

@ -626,6 +626,8 @@ const translation = {
},
aggregationGroup: 'Aggregation Group',
aggregationGroupTip: 'Enabling this feature allows the variable aggregator to aggregate multiple sets of variables.',
aggregateAll: 'Aggregate All',
aggregateAllTip: 'Enabling this feature allows the variable aggregator to aggregate all input variables into an array, otherwise it will aggregate first arrival variables into the array.',
addGroup: 'Add Group',
outputVars: {
varDescribe: '{{groupName}} output',

@ -627,6 +627,8 @@ const translation = {
},
aggregationGroup: '聚合分组',
aggregationGroupTip: '开启该功能后,变量聚合器内可以同时聚合多组变量',
aggregateAll: '聚合所有',
aggregateAllTip: '开启该功能后,变量聚合器将聚合所有输入变量到一个数组中,否则只会聚合第一个变量',
addGroup: '添加分组',
outputVars: {
varDescribe: '{{groupName}}的输出变量',

Loading…
Cancel
Save