Merge branch 'langgenius:main' into bug/ui/show-error-info-for-missing-env-and-conv-var-in-prompt-editor

pull/21802/head
Minamiyama 11 months ago committed by GitHub
commit 66714d11bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
return tool_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

@ -317,8 +317,9 @@ class IndexingRunner:
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids: for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try: try:
if image_file:
storage.delete(image_file.key) storage.delete(image_file.key)
except Exception: except Exception:
logging.exception( logging.exception(

@ -41,7 +41,7 @@ GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt" GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gem_ai.completion" GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason" GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"

@ -103,7 +103,7 @@ class GraphEngine:
call_depth: int, call_depth: int,
graph: Graph, graph: Graph,
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
variable_pool: VariablePool, graph_runtime_state: GraphRuntimeState,
max_execution_steps: int, max_execution_steps: int,
max_execution_time: int, max_execution_time: int,
thread_pool_id: Optional[str] = None, thread_pool_id: Optional[str] = None,
@ -140,7 +140,7 @@ class GraphEngine:
call_depth=call_depth, call_depth=call_depth,
) )
self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time self.max_execution_time = max_execution_time

@ -1,5 +1,6 @@
import contextvars import contextvars
import logging import logging
import time
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait from concurrent.futures import Future, wait
@ -133,8 +134,11 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0]) variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine # init graph engine
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -146,7 +150,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=iteration_graph, graph=iteration_graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

@ -1,5 +1,6 @@
import json import json
import logging import logging
import time
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
@ -101,8 +102,11 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value inputs[loop_variable.label] = processed_segment.value
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
app_id=self.app_id, app_id=self.app_id,
@ -114,7 +118,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth, call_depth=self.workflow_call_depth,
graph=loop_graph, graph=loop_graph,
graph_config=self.graph_config, graph_config=self.graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,

@ -69,6 +69,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth)) raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state # init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine( self.graph_engine = GraphEngine(
tenant_id=tenant_id, tenant_id=tenant_id,
app_id=app_id, app_id=app_id,
@ -80,7 +81,7 @@ class WorkflowEntry:
call_depth=call_depth, call_depth=call_depth,
graph=graph, graph=graph,
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id, thread_pool_id=thread_pool_id,

@ -1,3 +1,4 @@
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -19,6 +20,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.code_node import CodeNode
@ -172,6 +174,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -183,7 +186,7 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -299,6 +302,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={}, user_inputs={},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -310,7 +314,7 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -479,6 +483,7 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"}, user_inputs={"uid": "takato"},
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -490,7 +495,7 @@ def test_run_branch(mock_close, mock_remove):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )
@ -813,6 +818,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
) )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
app_id="222", app_id="222",
@ -824,7 +830,7 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )

@ -1,7 +1,9 @@
import time
from unittest.mock import patch from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
@ -11,6 +13,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode
@ -163,15 +166,16 @@ class ContinueOnErrorTestHelper:
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
"""Helper method to create a graph engine instance for testing""" """Helper method to create a graph engine instance for testing"""
graph = Graph.init(graph_config=graph_config) graph = Graph.init(graph_config=graph_config)
variable_pool = { variable_pool = VariablePool(
"system_variables": { system_variables={
SystemVariableKey.QUERY: "clear", SystemVariableKey.QUERY: "clear",
SystemVariableKey.FILES: [], SystemVariableKey.FILES: [],
SystemVariableKey.CONVERSATION_ID: "abababa", SystemVariableKey.CONVERSATION_ID: "abababa",
SystemVariableKey.USER_ID: "aaa", SystemVariableKey.USER_ID: "aaa",
}, },
"user_inputs": user_inputs or {"uid": "takato"}, user_inputs=user_inputs or {"uid": "takato"},
} )
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
return GraphEngine( return GraphEngine(
tenant_id="111", tenant_id="111",
@ -184,7 +188,7 @@ class ContinueOnErrorTestHelper:
invoke_from=InvokeFrom.WEB_APP, invoke_from=InvokeFrom.WEB_APP,
call_depth=0, call_depth=0,
graph=graph, graph=graph,
variable_pool=variable_pool, graph_runtime_state=graph_runtime_state,
max_execution_steps=500, max_execution_steps=500,
max_execution_time=1200, max_execution_time=1200,
) )

@ -958,7 +958,7 @@ NGINX_SSL_PROTOCOLS=TLSv1.1 TLSv1.2 TLSv1.3
# Nginx performance tuning # Nginx performance tuning
NGINX_WORKER_PROCESSES=auto NGINX_WORKER_PROCESSES=auto
NGINX_CLIENT_MAX_BODY_SIZE=15M NGINX_CLIENT_MAX_BODY_SIZE=100M
NGINX_KEEPALIVE_TIMEOUT=65 NGINX_KEEPALIVE_TIMEOUT=65
# Proxy settings # Proxy settings

@ -265,7 +265,7 @@ services:
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}

@ -420,7 +420,7 @@ x-shared-env: &shared-api-worker-env
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
@ -780,7 +780,7 @@ services:
NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key} NGINX_SSL_CERT_KEY_FILENAME: ${NGINX_SSL_CERT_KEY_FILENAME:-dify.key}
NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3} NGINX_SSL_PROTOCOLS: ${NGINX_SSL_PROTOCOLS:-TLSv1.1 TLSv1.2 TLSv1.3}
NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto} NGINX_WORKER_PROCESSES: ${NGINX_WORKER_PROCESSES:-auto}
NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-15M} NGINX_CLIENT_MAX_BODY_SIZE: ${NGINX_CLIENT_MAX_BODY_SIZE:-100M}
NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65} NGINX_KEEPALIVE_TIMEOUT: ${NGINX_KEEPALIVE_TIMEOUT:-65}
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}

@ -179,7 +179,7 @@ const ConfigPopup: FC<PopupProps> = ({
onConfig={handleOnConfig(TracingProvider.aliyun)} onConfig={handleOnConfig(TracingProvider.aliyun)}
isChosen={chosenProvider === TracingProvider.aliyun} isChosen={chosenProvider === TracingProvider.aliyun}
onChoose={handleOnChoose(TracingProvider.aliyun)} onChoose={handleOnChoose(TracingProvider.aliyun)}
key="alyun-provider-panel" key="aliyun-provider-panel"
/> />
) )
const configuredProviderPanel = () => { const configuredProviderPanel = () => {

@ -72,7 +72,7 @@ const ProviderList = () => {
className='relative flex grow flex-col overflow-y-auto bg-background-body' className='relative flex grow flex-col overflow-y-auto bg-background-body'
> >
<div className={cn( <div className={cn(
'sticky top-0 z-20 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]', 'sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]',
currentProviderId && 'pr-6', currentProviderId && 'pr-6',
)}> )}>
<TabSliderNew <TabSliderNew

@ -9,6 +9,7 @@ import type {
CommonNodeType, CommonNodeType,
Edge, Edge,
Node, Node,
ValueSelector,
} from '../types' } from '../types'
import { BlockEnum } from '../types' import { BlockEnum } from '../types'
import { useStore } from '../store' import { useStore } from '../store'
@ -33,6 +34,8 @@ import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/ty
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets' import { fetchDatasets } from '@/service/datasets'
import { MAX_TREE_DEPTH } from '@/config' import { MAX_TREE_DEPTH } from '@/config'
import useNodesAvailableVarList from './use-nodes-available-var-list'
import { getNodeUsedVars, isConversationVar, isENV, isSystemVar } from '../nodes/_base/components/variable/utils'
export const useChecklist = (nodes: Node[], edges: Edge[]) => { export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const { t } = useTranslation() const { t } = useTranslation()
@ -45,6 +48,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const { data: strategyProviders } = useStrategyProviders() const { data: strategyProviders } = useStrategyProviders()
const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail) const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail)
const map = useNodesAvailableVarList(nodes)
const getCheckData = useCallback((data: CommonNodeType<{}>) => { const getCheckData = useCallback((data: CommonNodeType<{}>) => {
let checkData = data let checkData = data
if (data.type === BlockEnum.KnowledgeRetrieval) { if (data.type === BlockEnum.KnowledgeRetrieval) {
@ -70,6 +75,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
const node = nodes[i] const node = nodes[i]
let toolIcon let toolIcon
let moreDataForCheckValid let moreDataForCheckValid
let usedVars: ValueSelector[] = []
if (node.data.type === BlockEnum.Tool) { if (node.data.type === BlockEnum.Tool) {
const { provider_type } = node.data const { provider_type } = node.data
@ -84,8 +90,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
if (provider_type === CollectionType.workflow) if (provider_type === CollectionType.workflow)
toolIcon = workflowTools.find(tool => tool.id === node.data.provider_id)?.icon toolIcon = workflowTools.find(tool => tool.id === node.data.provider_id)?.icon
} }
else if (node.data.type === BlockEnum.Agent) {
if (node.data.type === BlockEnum.Agent) {
const data = node.data as AgentNodeType const data = node.data as AgentNodeType
const isReadyForCheckValid = !!strategyProviders const isReadyForCheckValid = !!strategyProviders
const provider = strategyProviders?.find(provider => provider.declaration.identity.name === data.agent_strategy_provider_name) const provider = strategyProviders?.find(provider => provider.declaration.identity.name === data.agent_strategy_provider_name)
@ -97,10 +102,34 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
isReadyForCheckValid, isReadyForCheckValid,
} }
} }
else {
usedVars = getNodeUsedVars(node).filter(v => v.length > 0)
}
if (node.type === CUSTOM_NODE) { if (node.type === CUSTOM_NODE) {
const checkData = getCheckData(node.data) const checkData = getCheckData(node.data)
const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid) let { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid)
if (!errorMessage) {
const availableVars = map[node.id].availableVars
for (const variable of usedVars) {
const isEnv = isENV(variable)
const isConvVar = isConversationVar(variable)
const isSysVar = isSystemVar(variable)
if (!isEnv && !isConvVar && !isSysVar) {
const usedNode = availableVars.find(v => v.nodeId === variable?.[0])
if (usedNode) {
const usedVar = usedNode.vars.find(v => v.variable === variable?.[1])
if (!usedVar)
errorMessage = t('workflow.errorMsg.invalidVariable')
}
else {
errorMessage = t('workflow.errorMsg.invalidVariable')
}
}
}
}
if (errorMessage || !validNodes.find(n => n.id === node.id)) { if (errorMessage || !validNodes.find(n => n.id === node.id)) {
list.push({ list.push({

@ -0,0 +1,75 @@
import {
useIsChatMode,
useWorkflow,
useWorkflowVariables,
} from '@/app/components/workflow/hooks'
import { BlockEnum, type Node, type NodeOutPutVar, type ValueSelector, type Var } from '@/app/components/workflow/types'
type Params = {
onlyLeafNodeVar?: boolean
hideEnv?: boolean
hideChatVar?: boolean
filterVar: (payload: Var, selector: ValueSelector) => boolean
passedInAvailableNodes?: Node[]
}
const getNodeInfo = (nodeId: string, nodes: Node[]) => {
const allNodes = nodes
const node = allNodes.find(n => n.id === nodeId)
const isInIteration = !!node?.data.isInIteration
const isInLoop = !!node?.data.isInLoop
const parentNodeId = node?.parentId
const parentNode = allNodes.find(n => n.id === parentNodeId)
return {
node,
isInIteration,
isInLoop,
parentNode,
}
}
// TODO: loop type?
const useNodesAvailableVarList = (nodes: Node[], {
onlyLeafNodeVar,
filterVar,
hideEnv = false,
hideChatVar = false,
passedInAvailableNodes,
}: Params = {
onlyLeafNodeVar: false,
filterVar: () => true,
}) => {
const { getTreeLeafNodes, getBeforeNodesInSameBranchIncludeParent } = useWorkflow()
const { getNodeAvailableVars } = useWorkflowVariables()
const isChatMode = useIsChatMode()
const nodeAvailabilityMap: { [key: string ]: { availableVars: NodeOutPutVar[], availableNodes: Node[] } } = {}
nodes.forEach((node) => {
const nodeId = node.id
const availableNodes = passedInAvailableNodes || (onlyLeafNodeVar ? getTreeLeafNodes(nodeId) : getBeforeNodesInSameBranchIncludeParent(nodeId))
if (node.data.type === BlockEnum.Loop)
availableNodes.push(node)
const {
parentNode: iterationNode,
} = getNodeInfo(nodeId, nodes)
const availableVars = getNodeAvailableVars({
parentNode: iterationNode,
beforeNodes: availableNodes,
isChatMode,
filterVar,
hideEnv,
hideChatVar,
})
const result = {
node,
availableVars,
availableNodes,
}
nodeAvailabilityMap[nodeId] = result
})
return nodeAvailabilityMap
}
export default useNodesAvailableVarList

@ -948,9 +948,7 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
break break
} }
case BlockEnum.Answer: { case BlockEnum.Answer: {
res = (data as AnswerNodeType).variables?.map((v) => { res = matchNotSystemVars([(data as AnswerNodeType).answer])
return v.value_selector
})
break break
} }
case BlockEnum.LLM: { case BlockEnum.LLM: {

Loading…
Cancel
Save