diff --git a/api/.env.example b/api/.env.example
index 502461f658..01ddb4adfd 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -424,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
+# Workflow storage configuration
+# Options: rdbms, hybrid
+# rdbms: Use only the relational database (default)
+# hybrid: Save new data to object storage, read from both object storage and RDBMS
+WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
diff --git a/api/app_factory.py b/api/app_factory.py
index 1c886ac5c7..586f2ded9e 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
+ ext_repositories,
ext_sentry,
ext_set_secretkey,
ext_storage,
@@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
+ ext_repositories,
ext_celery,
ext_login,
ext_mail,
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index d35a74e3ee..f498dccbbc 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -12,7 +12,7 @@ from pydantic import (
)
from pydantic_settings import BaseSettings
-from configs.feature.hosted_service import HostedServiceConfig
+from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings):
@@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100,
)
+ WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
+ default="rdbms",
+ description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
+ )
+
class AuthConfig(BaseSettings):
"""
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 8518d34a8e..4046417076 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -85,5 +85,46 @@ class RuleCodeGenerateApi(Resource):
return code_result
+class RuleStructuredOutputGenerateApi(Resource):
+ """Console API resource that generates structured output (a JSON Schema) from an instruction."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ """
+ Generate structured output for the current user's tenant.
+
+ Requires a JSON body with ``instruction`` (str) and ``model_config`` (dict).
+ Provider, quota and model errors are translated into the console error types raised below.
+ :return: the result of ``LLMGenerator.generate_structured_output``
+ """
+ parser = reqparse.RequestParser()
+ parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ account = current_user
+ try:
+ structured_output = LLMGenerator.generate_structured_output(
+ tenant_id=account.current_tenant_id,
+ instruction=args["instruction"],
+ model_config=args["model_config"],
+ )
+ except ProviderTokenNotInitError as ex:
+ raise ProviderNotInitializeError(ex.description)
+ except QuotaExceededError:
+ raise ProviderQuotaExceededError()
+ except ModelCurrentlyNotSupportError:
+ raise ProviderModelCurrentlyNotSupportError()
+ except InvokeError as e:
+ # NOTE(review): generate_structured_output catches InvokeError from the model call itself and
+ # reports it in the "error" field, so this branch is presumably defensive only - confirm.
+ raise CompletionRequestError(e.description)
+
+ return structured_output
+
+
api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
+api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index e911c9a5e5..b4bd80fe2f 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider:
return {"error": "Invalid provider"}, 400
if "code" in request.args:
- code = request.args.get("code")
+ code = request.args.get("code", "")
+ if not code:
+ return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index dc0009f36e..d4a33645ab 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -16,7 +16,7 @@ from controllers.console.auth.error import (
PasswordMismatchError,
)
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import email, extract_remote_ip
@@ -30,6 +30,7 @@ from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource):
@setup_required
+ @email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
@@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource):
@setup_required
+ @email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json")
@@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource):
AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError()
+ # Verified, revoke the first token
+ AccountService.revoke_reset_password_token(args["token"])
+
+ # Refresh token data by generating a new token
+ _, new_token = AccountService.generate_reset_password_token(
+ user_email, code=args["code"], additional_data={"phase": "reset"}
+ )
+
AccountService.reset_forgot_password_error_rate_limit(args["email"])
- return {"is_valid": True, "email": token_data.get("email")}
+ return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource):
@setup_required
+ @email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource):
reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
raise InvalidTokenError()
+ # Must use token in reset phase
+ if reset_data.get("phase", "") != "reset":
+ raise InvalidTokenError()
# Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 41362e9fa2..16c1dcc441 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -22,7 +22,7 @@ from controllers.console.error import (
EmailSendIpLimitError,
NotAllowedCreateWorkspace,
)
-from controllers.console.wraps import setup_required
+from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.password import valid_password
@@ -38,6 +38,7 @@ class LoginApi(Resource):
"""Resource for user login."""
@setup_required
+ @email_password_login_enabled
def post(self):
"""Authenticate user and login."""
parser = reqparse.RequestParser()
@@ -110,6 +111,7 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource):
@setup_required
+ @email_password_login_enabled
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json")
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 6caaae87f4..e5e8038ad7 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -210,3 +210,20 @@ def enterprise_license_required(view):
return view(*args, **kwargs)
return decorated
+
+
+def email_password_login_enabled(view):
+ """
+ Decorator that only runs ``view`` when email/password login is enabled in the system features.
+ When ``enable_email_password_login`` is falsy the request is rejected with ``abort(403)``.
+ """
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if features.enable_email_password_login:
+ return view(*args, **kwargs)
+
+ # otherwise, return 403
+ abort(403)
+
+ return decorated
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 494b357d46..17e9a3990f 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
+ "metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 66f2c754bb..3bf6c330db 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
- session=session, workflow_run=workflow_run, event=event
+ workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
- session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
- session=session, workflow_run=workflow_run, event=event
+ workflow_run=workflow_run, event=event
)
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
- session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
- session=session, event=event
+ event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
- session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
- session=session, event=event
- )
+ workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+ event=event
+ )
- node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
- session=session,
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
if node_finish_resp:
yield node_finish_resp
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 14441ada40..1f998edb6a 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
- session=session, workflow_run=workflow_run, event=event
+ workflow_run=workflow_run, event=event
)
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
- session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
- session=session, workflow_run=workflow_run, event=event
+ workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
- session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
- session=session, event=event
- )
- node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
- session=session,
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
+ event=event
+ )
+ node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
if node_success_response:
yield node_success_response
@@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
- session=session,
- event=event,
- )
- node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
- session=session,
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
+ event=event,
+ )
+ node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
if node_failed_response:
yield node_failed_response
@@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id
session.add(workflow_app_log)
+ session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 4d629ca186..5ce9f737d1 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
+from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
+from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
@@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
+ # Initialize the session factory and repository
+ # We use the global db engine instead of the session passed to methods
+ # Disable expire_on_commit to avoid the need for merging objects
+ self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": self._application_generate_entity.app_config.tenant_id,
+ "app_id": self._application_generate_entity.app_config.app_id,
+ "session_factory": self._session_factory,
+ }
+ )
+
+ # We'll still keep the cache for backward compatibility and performance
+ # but use the repository for database operations
+
def _handle_workflow_run_start(
self,
*,
@@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
- stmt = select(WorkflowNodeExecution.node_execution_id).where(
- WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
- WorkflowNodeExecution.app_id == workflow_run.app_id,
- WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
- WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
- WorkflowNodeExecution.workflow_run_id == workflow_run.id,
- WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
+ # Use the instance repository to find running executions for a workflow run
+ running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
+ workflow_run_id=workflow_run.id
)
- ids = session.scalars(stmt).all()
- # Use self._get_workflow_node_execution here to make sure the cache is updated
- running_workflow_node_executions = [
- self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
- ]
+
+ # Update the cache with the retrieved executions
+ for execution in running_workflow_node_executions:
+ if execution.node_execution_id:
+ self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
@@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run
def _handle_node_execution_start(
- self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
+ self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
@@ -315,17 +328,14 @@ class WorkflowCycleManage:
)
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
- session.add(workflow_node_execution)
+ # Use the instance repository to save the workflow node execution
+ self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
- def _handle_workflow_node_execution_success(
- self, *, session: Session, event: QueueNodeSucceededEvent
- ) -> WorkflowNodeExecution:
- workflow_node_execution = self._get_workflow_node_execution(
- session=session, node_execution_id=event.node_execution_id
- )
+ def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
+ workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
@@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
- workflow_node_execution = session.merge(workflow_node_execution)
+ # Use the instance repository to update the workflow node execution
+ self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self,
*,
- session: Session,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
@@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event
:return:
"""
- workflow_node_execution = self._get_workflow_node_execution(
- session=session, node_execution_id=event.node_execution_id
- )
+ workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
@@ -387,14 +395,17 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
- workflow_node_execution = session.merge(workflow_node_execution)
+ # Persist the updated execution through the instance repository, mirroring the success
+ # handler; nothing else replaces the session.merge() removed here.
+ self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
- self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+ self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
+ :param workflow_run: workflow run
:param event: queue node failed event
:return:
"""
@@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index
- session.add(workflow_node_execution)
+ # Use the instance repository to save the workflow node execution
+ self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
- #################################################
- # to stream responses #
- #################################################
-
def _workflow_start_to_stream_response(
self,
*,
@@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return WorkflowStartStreamResponse(
task_id=task_id,
@@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response(
self,
*,
- session: Session,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
- _ = session
-
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
*,
- session: Session,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
@@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
- _ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response(
self,
*,
- session: Session,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
- _ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return ParallelBranchStartStreamResponse(
task_id=task_id,
@@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
@@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeStartStreamResponse(
task_id=task_id,
@@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeNextStreamResponse(
task_id=task_id,
@@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeCompletedStreamResponse(
task_id=task_id,
@@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeStartStreamResponse(
task_id=task_id,
@@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeNextStreamResponse(
task_id=task_id,
@@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse:
- # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeCompletedStreamResponse(
task_id=task_id,
@@ -934,11 +920,30 @@ class WorkflowCycleManage:
return workflow_run
- def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
- if node_execution_id not in self._workflow_node_executions:
+ def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
+ """
+ Return the workflow node execution for the given node execution id.
+
+ Looks in the in-memory cache first, then asks the repository and caches what it returns.
+ :param node_execution_id: node execution id
+ :return: the WorkflowNodeExecution
+ :raises ValueError: if the repository returns no execution for the id
+ """
+ # First check the cache for performance
+ if node_execution_id in self._workflow_node_executions:
+ cached_execution = self._workflow_node_executions[node_execution_id]
+ # No need to merge with session since expire_on_commit=False
+ return cached_execution
+
+ # If not in cache, use the instance repository to get by node_execution_id
+ execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
+
+ if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
- cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
- return session.merge(cached_workflow_node_execution)
+
+ # Update cache
+ self._workflow_node_executions[node_execution_id] = execution
+ return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index 64c734f626..56859df7f4 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
-from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
@@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
- if resource and len(resource) > 0:
- for item in resource:
- dataset_retriever_resource = DatasetRetrieverResource(
- message_id=self._message_id,
- position=item.get("position") or 0,
- dataset_id=item.get("dataset_id"),
- dataset_name=item.get("dataset_name"),
- document_id=item.get("document_id"),
- document_name=item.get("document_name"),
- data_source_type=item.get("data_source_type"),
- segment_id=item.get("segment_id"),
- score=item.get("score") if "score" in item else None,
- hit_count=item.get("hit_count") if "hit_count" in item else None,
- word_count=item.get("word_count") if "word_count" in item else None,
- segment_position=item.get("segment_position") if "segment_position" in item else None,
- index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
- content=item.get("content"),
- retriever_from=item.get("retriever_from"),
- created_by=self._user_id,
- )
- db.session.add(dataset_retriever_resource)
- db.session.commit()
-
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index 75687f9ae3..d5d2ca60fa 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
+ SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
)
from core.model_manager import ModelManager
@@ -340,3 +341,47 @@ class LLMGenerator:
answer = cast(str, response.message.content)
return answer.strip()
+
+ @classmethod
+ def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
+ """
+ Ask an LLM to turn ``instruction`` into a JSON Schema.
+
+ :param tenant_id: tenant id passed to ``ModelManager.get_model_instance``
+ :param instruction: text sent as the user prompt message
+ :param model_config: dict read for the "provider", "name" and "model_parameters" keys
+ :return: dict with "output" (the model's message content) and "error" ("" on success);
+ failures while invoking the model are reported in "error" instead of being raised
+ """
+ model_manager = ModelManager()
+ # Resolved outside the try block below, so errors from this lookup propagate to the caller.
+ model_instance = model_manager.get_model_instance(
+ tenant_id=tenant_id,
+ model_type=ModelType.LLM,
+ provider=model_config.get("provider", ""),
+ model=model_config.get("name", ""),
+ )
+
+ prompt_messages = [
+ SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
+ UserPromptMessage(content=instruction),
+ ]
+ model_parameters = model_config.get("model_parameters", {})
+
+ try:
+ response = cast(
+ LLMResult,
+ model_instance.invoke_llm(
+ prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
+ ),
+ )
+
+ generated_json_schema = cast(str, response.message.content)
+ return {"output": generated_json_schema, "error": ""}
+
+ except InvokeError as e:
+ error = str(e)
+ return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
+ except Exception as e:
+ logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
+ return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index cf20e60c82..82d22d7f89 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}}
You just need to generate the output
""" # noqa: E501
+
+SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
+Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
+
+## Instructions:
+
+1. Analyze the user's description of their data needs
+2. Identify each property that should be included in the schema
+3. Determine the appropriate data type for each property
+4. Decide which properties should be required
+5. Generate a complete JSON Schema with proper syntax
+6. Include appropriate constraints when specified (min/max values, patterns, formats)
+7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
+8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
+
+## Examples:
+
+### Example 1:
+**User Input:** I need name and age
+**JSON Schema Output:**
+{
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" },
+ "age": { "type": "number" }
+ },
+ "required": ["name", "age"]
+}
+
+### Example 2:
+**User Input:** I want to store information about books including title, author, publication year and optional page count
+**JSON Schema Output:**
+{
+ "type": "object",
+ "properties": {
+ "title": { "type": "string" },
+ "author": { "type": "string" },
+ "publicationYear": { "type": "integer" },
+ "pageCount": { "type": "integer" }
+ },
+ "required": ["title", "author", "publicationYear"]
+}
+
+### Example 3:
+**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
+**JSON Schema Output:**
+{
+ "type": "object",
+ "properties": {
+ "email": {
+ "type": "string",
+ "format": "email"
+ },
+ "password": {
+ "type": "string",
+ "minLength": 8
+ },
+ "age": {
+ "type": "integer",
+ "minimum": 18
+ }
+ },
+ "required": ["email", "password", "age"]
+}
+
+### Example 4:
+**User Input:** I need album schema, the album has songs, and each song has name, duration, and artist.
+**JSON Schema Output:**
+{
+ "type": "object",
+ "properties": {
+ "properties": {
+ "songs": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "id": {
+ "type": "string"
+ },
+ "duration": {
+ "type": "string"
+ },
+ "artist": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "name",
+ "id",
+ "duration",
+ "artist"
+ ]
+ }
+ }
+ }
+ },
+ "required": [
+ "songs"
+ ]
+}
+
+Now, generate a JSON Schema based on my description
+""" # noqa: E501
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 977678b893..3bed2460dd 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,8 +1,8 @@
from collections.abc import Sequence
from enum import Enum, StrEnum
-from typing import Optional
+from typing import Any, Optional, Union
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum):
@@ -135,6 +135,16 @@ class PromptMessage(BaseModel):
"""
return not self.content
+ @field_serializer("content")
+ def serialize_content(
+ self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
+ ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
+ if content is None or isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
+ return content
+
class UserPromptMessage(PromptMessage):
"""
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 3225f03fbd..373ef2bbe2 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -2,7 +2,7 @@ from decimal import Decimal
from enum import Enum, StrEnum
from typing import Any, Optional
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject
@@ -85,6 +85,7 @@ class ModelFeature(Enum):
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
+ STRUCTURED_OUTPUT = "structured-output"
class DefaultParameterName(StrEnum):
@@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None
+ @model_validator(mode="after")
+ def validate_model(self):
+ supported_schema_keys = ["json_schema"]
+ schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
+ if not schema_key:
+ return self
+ if self.features is None:
+ self.features = [ModelFeature.STRUCTURED_OUTPUT]
+ else:
+ if ModelFeature.STRUCTURED_OUTPUT not in self.features:
+ self.features.append(ModelFeature.STRUCTURED_OUTPUT)
+ return self
+
class ModelUsage(BaseModel):
pass
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 53de16d621..1b799131e7 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -1,5 +1,6 @@
import logging
import time
+import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union
@@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__)
+def _gen_tool_call_id() -> str:
+ return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
+
+
+def _increase_tool_call(
+ new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
+):
+ """
+ Merge incremental tool call updates into existing tool calls.
+
+ :param new_tool_calls: List of new tool call deltas to be merged.
+ :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
+ """
+
+ def get_tool_call(tool_call_id: str):
+ """
+ Get or create a tool call by ID
+
+ :param tool_call_id: tool call ID
+ :return: existing or new tool call
+ """
+ if not tool_call_id:
+ return existing_tools_calls[-1]
+
+ _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
+ if _tool_call is None:
+ _tool_call = AssistantPromptMessage.ToolCall(
+ id=tool_call_id,
+ type="function",
+ function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
+ )
+ existing_tools_calls.append(_tool_call)
+
+ return _tool_call
+
+ for new_tool_call in new_tool_calls:
+ # generate ID for tool calls with function name but no ID to track them
+ if new_tool_call.function.name and not new_tool_call.id:
+ new_tool_call.id = _gen_tool_call_id()
+ # get tool call
+ tool_call = get_tool_call(new_tool_call.id)
+ # update tool call
+ if new_tool_call.id:
+ tool_call.id = new_tool_call.id
+ if new_tool_call.type:
+ tool_call.type = new_tool_call.type
+ if new_tool_call.function.name:
+ tool_call.function.name = new_tool_call.function.name
+ if new_tool_call.function.arguments:
+ tool_call.function.arguments += new_tool_call.function.arguments
+
+
class LargeLanguageModel(AIModel):
"""
Model class for large language model.
@@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
- def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
- def get_tool_call(tool_name: str):
- if not tool_name:
- return tools_calls[-1]
-
- tool_call = next(
- (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
- )
- if tool_call is None:
- tool_call = AssistantPromptMessage.ToolCall(
- id="",
- type="",
- function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
- )
- tools_calls.append(tool_call)
-
- return tool_call
-
- for new_tool_call in new_tool_calls:
- # get tool call
- tool_call = get_tool_call(new_tool_call.function.name)
- # update tool call
- if new_tool_call.id:
- tool_call.id = new_tool_call.id
- if new_tool_call.type:
- tool_call.type = new_tool_call.type
- if new_tool_call.function.name:
- tool_call.function.name = new_tool_call.function.name
- if new_tool_call.function.arguments:
- tool_call.function.arguments += new_tool_call.function.arguments
-
for chunk in result:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls:
- increase_tool_call(chunk.delta.message.tool_calls)
+ _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index f67e270ab1..fa78b7b8e9 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional
from langfuse import Langfuse # type: ignore
+from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
+from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser
-from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_trace(langfuse_trace_data=trace_data)
- # through workflow_run_id get all_nodes_execution
- workflow_nodes_execution_id_records = (
- db.session.query(WorkflowNodeExecution.id)
- .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
- .all()
+ # Fetch all node executions for this workflow_run_id through the repository
+ session_factory = sessionmaker(bind=db.engine)
+ workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
)
- for node_execution_id_record in workflow_nodes_execution_id_records:
- node_execution = (
- db.session.query(
- WorkflowNodeExecution.id,
- WorkflowNodeExecution.tenant_id,
- WorkflowNodeExecution.app_id,
- WorkflowNodeExecution.title,
- WorkflowNodeExecution.node_type,
- WorkflowNodeExecution.status,
- WorkflowNodeExecution.inputs,
- WorkflowNodeExecution.outputs,
- WorkflowNodeExecution.created_at,
- WorkflowNodeExecution.elapsed_time,
- WorkflowNodeExecution.process_data,
- WorkflowNodeExecution.execution_metadata,
- )
- .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
- .first()
- )
-
- if not node_execution:
- continue
+ # Get all executions for this workflow run
+ workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+ workflow_run_id=trace_info.workflow_run_id
+ )
+ for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index e3494e2f23..85a0eafdc1 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client
from langsmith.schemas import RunBase
+from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
@@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
+from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run)
- # through workflow_run_id get all_nodes_execution
- workflow_nodes_execution_id_records = (
- db.session.query(WorkflowNodeExecution.id)
- .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
- .all()
+ # Fetch all node executions for this workflow_run_id through the repository
+ session_factory = sessionmaker(bind=db.engine)
+ workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": trace_info.tenant_id,
+ "app_id": trace_info.metadata.get("app_id"),
+ "session_factory": session_factory,
+ },
)
- for node_execution_id_record in workflow_nodes_execution_id_records:
- node_execution = (
- db.session.query(
- WorkflowNodeExecution.id,
- WorkflowNodeExecution.tenant_id,
- WorkflowNodeExecution.app_id,
- WorkflowNodeExecution.title,
- WorkflowNodeExecution.node_type,
- WorkflowNodeExecution.status,
- WorkflowNodeExecution.inputs,
- WorkflowNodeExecution.outputs,
- WorkflowNodeExecution.created_at,
- WorkflowNodeExecution.elapsed_time,
- WorkflowNodeExecution.process_data,
- WorkflowNodeExecution.execution_metadata,
- )
- .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
- .first()
- )
-
- if not node_execution:
- continue
+ # Get all executions for this workflow run
+ workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+ workflow_run_id=trace_info.workflow_run_id
+ )
+ for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index fabf38fbd6..923b9a24ed 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
+from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
@@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
+from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
-from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
}
self.add_trace(trace_data)
- # through workflow_run_id get all_nodes_execution
- workflow_nodes_execution_id_records = (
- db.session.query(WorkflowNodeExecution.id)
- .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
- .all()
+ # Fetch all node executions for this workflow_run_id through the repository
+ session_factory = sessionmaker(bind=db.engine)
+ workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": trace_info.tenant_id,
+ "app_id": trace_info.metadata.get("app_id"),
+ "session_factory": session_factory,
+ },
)
- for node_execution_id_record in workflow_nodes_execution_id_records:
- node_execution = (
- db.session.query(
- WorkflowNodeExecution.id,
- WorkflowNodeExecution.tenant_id,
- WorkflowNodeExecution.app_id,
- WorkflowNodeExecution.title,
- WorkflowNodeExecution.node_type,
- WorkflowNodeExecution.status,
- WorkflowNodeExecution.inputs,
- WorkflowNodeExecution.outputs,
- WorkflowNodeExecution.created_at,
- WorkflowNodeExecution.elapsed_time,
- WorkflowNodeExecution.process_data,
- WorkflowNodeExecution.execution_metadata,
- )
- .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
- .first()
- )
-
- if not node_execution:
- continue
+ # Get all executions for this workflow run
+ workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+ workflow_run_id=trace_info.workflow_run_id
+ )
+ for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id
diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py
index f402da030f..db07e52f3f 100644
--- a/api/core/plugin/backwards_invocation/node.py
+++ b/api/core/plugin/backwards_invocation/node.py
@@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str
:return: dict
"""
+ # FIXME(-LAN-): Avoid importing services into core
workflow_service = WorkflowService()
node_id = "1919810"
node_data = ParameterExtractorNodeData(
@@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str
:return: dict
"""
+ # FIXME(-LAN-): Avoid importing services into core
workflow_service = WorkflowService()
node_id = "1919810"
node_data = QuestionClassifierNodeData(
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 099acfd7f4..7570200175 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
+ # Ensure that both the original provider name and its ModelProviderID string representation
+ # are present in the dictionary to handle cases where either form might be used
+ for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
+ provider_id = ModelProviderID(provider_name)
+ if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
+ # Add the ModelProviderID string representation if it's not already present
+ provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
+ provider_name_to_preferred_model_provider_records_dict[provider_name]
+ )
# Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod
def _init_trial_provider_records(
- tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
- ) -> dict[str, list]:
+ tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
+ ) -> dict[str, list[Provider]]:
"""
Initialize trial provider records if not exists.
@@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
- provider_record = Provider(
+ new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
@@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0,
is_valid=True,
)
- db.session.add(provider_record)
+ db.session.add(new_provider_record)
db.session.commit()
+ provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
- provider_record = (
+ existed_provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == tenant_id,
@@ -556,11 +566,14 @@ class ProviderManager:
)
.first()
)
- if provider_record and not provider_record.is_valid:
- provider_record.is_valid = True
+ if not existed_provider_record:
+ continue
+
+ if not existed_provider_record.is_valid:
+ existed_provider_record.is_valid = True
db.session.commit()
- provider_name_to_provider_records_dict[provider_name].append(provider_record)
+ provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index 778e8a07d8..c1792943bb 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -246,7 +246,7 @@ class AnalyticdbVectorBySql:
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
- ORDER BY (score,id) DESC
+ ORDER BY score DESC, id DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 4af2578197..63695e6f3f 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -2,12 +2,12 @@ import array
import json
import re
import uuid
-from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
+from oracledb.connection import Connection
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
+ self.config = config
def get_type(self) -> str:
return VectorType.ORACLE
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
outconverter=self.numpy_converter_out,
)
+ def _get_connection(self) -> Connection:
+ connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
+ return connection
+
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
- "max": 50,
+ "max": 5,
"increment": 1,
}
-
if config.is_autonomous:
pool_params.update(
{
@@ -125,22 +129,8 @@ class OracleVector(BaseVector):
"wallet_password": config.wallet_password,
}
)
-
return oracledb.create_pool(**pool_params)
- @contextmanager
- def _get_cursor(self):
- conn = self.pool.acquire()
- conn.inputtypehandler = self.input_type_handler
- conn.outputtypehandler = self.output_type_handler
- cur = conn.cursor()
- try:
- yield cur
- finally:
- cur.close()
- conn.commit()
- conn.close()
-
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
@@ -162,41 +152,68 @@ class OracleVector(BaseVector):
numpy.array(embeddings[i]),
)
)
- # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
- with self._get_cursor() as cur:
- cur.executemany(
- f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
- )
+ with self._get_connection() as conn:
+ conn.inputtypehandler = self.input_type_handler
+ conn.outputtypehandler = self.output_type_handler
+ # with conn.cursor() as cur:
+ # cur.executemany(
+ # f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
+ # )
+ # conn.commit()
+ for value in values:
+ with conn.cursor() as cur:
+ try:
+ cur.execute(
+ f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+ VALUES (:1, :2, :3, :4)""",
+ value,
+ )
+ conn.commit()
+ except Exception as e:
+ print(e)
+ conn.close()
return pks
def text_exists(self, id: str) -> bool:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
- return cur.fetchone() is not None
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
+ return cur.fetchone() is not None
+ conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
- docs = []
- for record in cur:
- docs.append(Document(page_content=record[1], metadata=record[0]))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+ docs = []
+ for record in cur:
+ docs.append(Document(page_content=record[1], metadata=record[0]))
+ self.pool.release(connection=conn)
+ conn.close()
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+ conn.commit()
+ conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+ conn.commit()
+ conn.close()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
+ :param top_k: The number of nearest neighbors to return, default is 4.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
@@ -205,20 +222,25 @@ class OracleVector(BaseVector):
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
- with self._get_cursor() as cur:
- cur.execute(
- f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
- f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
- [numpy.array(query_vector)],
- )
- docs = []
- score_threshold = float(kwargs.get("score_threshold") or 0.0)
- for record in cur:
- metadata, text, distance = record
- score = 1 - distance
- metadata["score"] = score
- if score > score_threshold:
- docs.append(Document(page_content=text, metadata=metadata))
+ with self._get_connection() as conn:
+ conn.inputtypehandler = self.input_type_handler
+ conn.outputtypehandler = self.output_type_handler
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
+ AS distance FROM {self.table_name}
+ {where_clause} ORDER BY distance fetch first {top_k} rows only""",
+ [numpy.array(query_vector)],
+ )
+ docs = []
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ for record in cur:
+ metadata, text, distance = record
+ score = 1 - distance
+ metadata["score"] = score
+ if score > score_threshold:
+ docs.append(Document(page_content=text, metadata=metadata))
+ conn.close()
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -228,7 +250,7 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
- # score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@@ -239,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
- if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
+ if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:
@@ -260,30 +282,35 @@ class OracleVector(BaseVector):
for token in all_tokens:
if token not in stop_words:
entities.append(token)
- with self._get_cursor() as cur:
- document_ids_filter = kwargs.get("document_ids_filter")
- where_clause = ""
- if document_ids_filter:
- document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
- where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
- cur.execute(
- f"select meta, text, embedding FROM {self.table_name}"
- f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
- f"order by score(1) desc fetch first {top_k} rows only",
- [" ACCUM ".join(entities)],
- )
- docs = []
- for record in cur:
- metadata, text, embedding = record
- docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+ document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+ where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
+ cur.execute(
+ f"""select meta, text, embedding FROM {self.table_name}
+ WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
+ order by score(1) desc fetch first {top_k} rows only""",
+ kk=" ACCUM ".join(entities),
+ )
+ docs = []
+ for record in cur:
+ metadata, text, embedding = record
+ docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+ conn.close()
return docs
else:
return [Document(page_content="", metadata={})]
return []
def delete(self) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+ conn.commit()
+ conn.close()
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
@@ -293,11 +320,14 @@ class OracleVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
- with self._get_cursor() as cur:
- cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
- redis_client.set(collection_exist_cache_key, 1, ex=3600)
- with self._get_cursor() as cur:
- cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+ with conn.cursor() as cur:
+ cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+ conn.commit()
+ conn.close()
class OracleVectorFactory(AbstractVectorFactory):
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index 70c618a631..edaa8c92fa 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor):
db.session.add(upload_file)
db.session.commit()
- image_map[rel.target_part] = (
- f""
- )
+ image_map[rel.target_part] = f""
return image_map
diff --git a/api/core/repository/__init__.py b/api/core/repository/__init__.py
new file mode 100644
index 0000000000..253df1251d
--- /dev/null
+++ b/api/core/repository/__init__.py
@@ -0,0 +1,15 @@
+"""
+Repository interfaces for data access.
+
+This package contains repository interfaces that define the contract
+for accessing and manipulating data, regardless of the underlying
+storage mechanism.
+"""
+
+from core.repository.repository_factory import RepositoryFactory
+from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+
+__all__ = [
+ "RepositoryFactory",
+ "WorkflowNodeExecutionRepository",
+]
diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py
new file mode 100644
index 0000000000..7da7e49055
--- /dev/null
+++ b/api/core/repository/repository_factory.py
@@ -0,0 +1,97 @@
+"""
+Repository factory for creating repository instances.
+
+This module provides a simple factory interface for creating repository instances.
+It does not contain any implementation details or dependencies on specific repositories.
+"""
+
+from collections.abc import Callable, Mapping
+from typing import Any, Literal, Optional, cast
+
+from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+
+# Type for factory functions - takes a dict of parameters and returns any repository type
+RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
+
+# Type for workflow node execution factory function
+WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
+
+# Repository type literals
+_RepositoryType = Literal["workflow_node_execution"]
+
+
+class RepositoryFactory:
+    """
+    Registry-backed factory for repository instances.
+
+    Implementation-specific factory functions are registered with this class at
+    runtime; creating a repository delegates to the function registered for its type.
+    """
+
+    # Registered factory functions, keyed by repository type
+    _factory_functions: dict[str, RepositoryFactoryFunc] = {}
+
+    @classmethod
+    def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
+        """
+        Store the factory function to use for a repository type.
+        This is a private method; use the type-specific ``register_*`` methods instead.
+
+        Args:
+            repository_type: The type of repository (e.g., 'workflow_node_execution')
+            factory_func: A function that takes parameters and returns a repository instance
+        """
+        registry = cls._factory_functions
+        registry[repository_type] = factory_func
+
+    @classmethod
+    def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
+        """
+        Build a repository by calling the factory function registered for its type.
+        This is a private method; use the type-specific ``create_*`` methods instead.
+
+        Args:
+            repository_type: The type of repository to create
+            params: A mapping of parameters to pass to the factory function
+
+        Returns:
+            Whatever the registered factory function returns
+
+        Raises:
+            ValueError: If no factory function is registered for the repository type
+        """
+        registry = cls._factory_functions
+        if repository_type in registry:
+            # A missing (or empty) params mapping is handed to the factory as an empty dict
+            return registry[repository_type](params or {})
+
+        raise ValueError(f"No factory function registered for repository type '{repository_type}'")
+
+    @classmethod
+    def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
+        """
+        Register the factory function used to build workflow node execution repositories.
+
+        Args:
+            factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
+        """
+        cls._register_factory("workflow_node_execution", factory_func)
+
+    @classmethod
+    def create_workflow_node_execution_repository(
+        cls, params: Optional[Mapping[str, Any]] = None
+    ) -> WorkflowNodeExecutionRepository:
+        """
+        Build a WorkflowNodeExecutionRepository using the registered factory function.
+
+        Args:
+            params: A mapping of parameters to pass to the factory function
+
+        Returns:
+            A new instance of the WorkflowNodeExecutionRepository
+
+        Raises:
+            ValueError: If no factory function is registered for the workflow_node_execution repository type
+        """
+        repository = cls._create_repository("workflow_node_execution", params)
+        # The cast is for the type checker only: the registration method is annotated to accept
+        # a WorkflowNodeExecutionFactoryFunc for this repository type.
+        return cast(WorkflowNodeExecutionRepository, repository)
diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py
new file mode 100644
index 0000000000..9bb790cb0f
--- /dev/null
+++ b/api/core/repository/workflow_node_execution_repository.py
@@ -0,0 +1,97 @@
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Literal, Optional, Protocol
+
+from models.workflow import WorkflowNodeExecution
+
+
+@dataclass
+class OrderConfig:
+    """Configuration for ordering WorkflowNodeExecution instances."""
+
+    # Field names to order by, e.g. ["index", "created_at"]
+    order_by: list[str]
+    # Direction to order, "asc" or "desc".
+    # NOTE(review): how the default None is interpreted is left to each repository implementation - confirm.
+    order_direction: Optional[Literal["asc", "desc"]] = None
+
+
+class WorkflowNodeExecutionRepository(Protocol):
+    """
+    Repository interface for WorkflowNodeExecution.
+
+    This interface defines the contract for accessing and manipulating
+    WorkflowNodeExecution data, regardless of the underlying storage mechanism.
+    It is a structural ``Protocol``: every method body here is a ``...`` placeholder.
+
+    Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
+    and trigger sources (triggered_from) should be handled at the implementation level, not in
+    the core interface. This keeps the core domain model clean and independent of specific
+    application domains or deployment scenarios.
+    """
+
+    def save(self, execution: WorkflowNodeExecution) -> None:
+        """
+        Save a WorkflowNodeExecution instance.
+
+        Args:
+            execution: The WorkflowNodeExecution instance to save
+        """
+        ...
+
+    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
+        """
+        Retrieve a WorkflowNodeExecution by its node_execution_id.
+
+        Args:
+            node_execution_id: The node execution ID
+
+        Returns:
+            The WorkflowNodeExecution instance if found, None otherwise
+        """
+        ...
+
+    def get_by_workflow_run(
+        self,
+        workflow_run_id: str,
+        order_config: Optional[OrderConfig] = None,
+    ) -> Sequence[WorkflowNodeExecution]:
+        """
+        Retrieve all WorkflowNodeExecution instances for a specific workflow run.
+
+        Args:
+            workflow_run_id: The workflow run ID
+            order_config: Optional configuration for ordering results
+                order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
+                order_config.order_direction: Direction to order ("asc" or "desc")
+
+        Returns:
+            A sequence of WorkflowNodeExecution instances
+        """
+        ...
+
+    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
+        """
+        Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
+
+        Args:
+            workflow_run_id: The workflow run ID
+
+        Returns:
+            A sequence of running WorkflowNodeExecution instances
+        """
+        ...
+
+    def update(self, execution: WorkflowNodeExecution) -> None:
+        """
+        Update an existing WorkflowNodeExecution instance.
+
+        Args:
+            execution: The WorkflowNodeExecution instance to update
+        """
+        ...
+
+    def clear(self) -> None:
+        """
+        Clear all WorkflowNodeExecution records based on implementation-specific criteria.
+
+        This method is intended to be used for bulk deletion operations, such as removing
+        all records associated with a specific app_id and tenant_id in multi-tenant implementations.
+        The interface takes no arguments, so the criteria must come from the implementation itself.
+        """
+        ...
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index f661294ec4..f5838c3b76 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -94,7 +94,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"title": item.metadata.get("title"),
"content": item.page_content,
}
- context_list.append(source)
+ context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 7c8960fe49..da40cbcdea 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -16,7 +16,7 @@ from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
+from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent
@@ -251,7 +251,12 @@ class AgentNode(ToolNode):
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
- value["entity"] = model_schema.model_dump(mode="json") if model_schema else None
+ if model_schema:
+ # remove structured output feature to support old version agent plugin
+ model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
+ value["entity"] = model_schema.model_dump(mode="json")
+ else:
+ value["entity"] = None
result[parameter_name] = value
return result
@@ -348,3 +353,10 @@ class AgentNode(ToolNode):
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
+
+    def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
+        """
+        Drop model features whose value is not listed in AgentOldVersionModelFeatures.
+
+        The ``features`` list of the given schema is filtered in place.
+
+        :param model_schema: model schema whose features are filtered
+        :return: the same ``model_schema`` instance, with unsupported features removed
+        """
+        if model_schema.features:
+            # Compare against the enum's raw values: testing a plain string with ``in`` against a
+            # non-str Enum class raises TypeError before Python 3.12.
+            supported_values = {feature.value for feature in AgentOldVersionModelFeatures}
+            # Rebuild the list rather than calling remove() while iterating over it, which skips
+            # the element that follows every removed one. Slice assignment keeps the same list object.
+            model_schema.features[:] = [
+                feature for feature in model_schema.features if feature.value in supported_values
+            ]
+        return model_schema
diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
index 87cc7e9824..77e94375bf 100644
--- a/api/core/workflow/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(Enum):
CLOSE = 0
OPEN = 1
+
+
+class AgentOldVersionModelFeatures(Enum):
+    """
+    LLM feature values recognised by old SDK versions.
+
+    AgentNode filters a model schema's features against these values before the
+    schema is serialized for an old-version agent plugin.
+    """
+
+    TOOL_CALL = "tool-call"
+    MULTI_TOOL_CALL = "multi-tool-call"
+    AGENT_THOUGHT = "agent-thought"
+    VISION = "vision"
+    STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index bf54fdb80c..486b4b01af 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
+ structured_output: dict | None = None
+ structured_output_enabled: bool = False
@field_validator("prompt_config", mode="before")
@classmethod
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index fe0ed3e564..8db7394e54 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
+import json_repair
+
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
@@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
+from core.model_runtime.entities.model_entities import (
+ AIModelEntity,
+ ModelFeature,
+ ModelPropertyKey,
+ ModelType,
+ ParameterRule,
+)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID
@@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
RunRetrieverResourceEvent,
RunStreamChunkEvent,
)
+from core.workflow.utils.structured_output.entities import (
+ ResponseFormat,
+ SpecialModelType,
+ SupportStructuredOutputStatus,
+)
+from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.model import Conversation
@@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_type = NodeType.LLM
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
+ def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+ """Process structured output if enabled"""
+ if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
+ return None
+ return self._parse_structured_output(text)
+
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""
@@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(event, RunRetrieverResourceEvent):
context = event.context
yield event
-
if context:
node_inputs["#context#"] = context
@@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
-
+ structured_output = process_structured_output(result_text)
+ if structured_output:
+ outputs["structured_output"] = structured_output
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
-
+ support_structured_output = self._check_model_structured_output_support()
+ if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
+ completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
+ elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+ # Set appropriate response format based on model capabilities
+ self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
@@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
-
+ support_structured_output = self._check_model_structured_output_support()
+ if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
+ filtered_prompt_messages = self._handle_prompt_based_schema(
+ prompt_messages=filtered_prompt_messages,
+ )
stop = model_config.stop
return filtered_prompt_messages, stop
+    def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
+        """
+        Parse the LLM response text into a JSON object or array.
+
+        Strict parsing with ``json.loads`` is tried first; ``json_repair.loads`` is used as a
+        fallback only when the text is not valid JSON.
+
+        :param result_text: raw text returned by the model
+        :return: the parsed dict or list
+        :raises LLMNodeError: if the parsed value is neither a dict nor a list
+        """
+        try:
+            parsed = json.loads(result_text)
+        except json.JSONDecodeError:
+            # if the result_text is not a valid json, try to repair it
+            parsed = json_repair.loads(result_text)
+        # Single type check for both the strict and the repaired result
+        if not isinstance(parsed, dict | list):
+            raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+        return parsed
+
@classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
@@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
+    def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
+        """
+        Configure structured output for a model with native JSON schema support.
+
+        :param model_parameters: Model parameters; updated in place and returned
+        :param rules: Model parameter rules
+        :return: The same ``model_parameters`` dict with the JSON schema configuration applied
+        """
+        json_schema_format = ResponseFormat.JSON_SCHEMA.value
+
+        # Shape the node's schema for the current model and pass it on as a JSON string
+        prepared_schema = self._prepare_schema_for_model(self._fetch_structured_output_schema())
+        model_parameters["json_schema"] = json.dumps(prepared_schema, ensure_ascii=False)
+
+        # Switch the response format only when a "response_format" rule offers "json_schema"
+        for rule in rules:
+            if rule.name != "response_format":
+                continue
+            if json_schema_format in rule.options:
+                model_parameters["response_format"] = json_schema_format
+
+        return model_parameters
+
+    def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
+        """
+        Handle structured output for models without native JSON schema support.
+
+        The structured-output instructions (with the node's schema substituted in) become the
+        single system prompt at the front of the returned list. If the first existing system
+        prompt has string content, that content is appended after the instructions. All
+        original system prompts are dropped from the rest of the list.
+
+        Args:
+            prompt_messages: Original sequence of prompt messages
+
+        Returns:
+            list[PromptMessage]: New system prompt followed by the non-system prompts in their original order
+        """
+        schema_text = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
+        instructions = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_text)
+
+        # One pass: remember the first system prompt, collect everything that is not a system prompt
+        first_system_prompt = None
+        remaining_prompts = []
+        for message in prompt_messages:
+            if not isinstance(message, SystemPromptMessage):
+                remaining_prompts.append(message)
+            elif first_system_prompt is None:
+                first_system_prompt = message
+
+        # Only plain-string system content can be concatenated with the instructions
+        content = instructions
+        if first_system_prompt and isinstance(first_system_prompt.content, str):
+            content = instructions + "\n\n" + first_system_prompt.content
+
+        return [SystemPromptMessage(content=content), *remaining_prompts]
+
+    def _set_response_format(self, model_parameters: dict, rules: list) -> None:
+        """
+        Pick a JSON response format for the model from its parameter rules.
+
+        For every rule named "response_format", "JSON" is chosen when offered, otherwise
+        "json_object" when offered; if neither is offered the parameters are left untouched.
+
+        :param model_parameters: Model parameters to update in place
+        :param rules: Model parameter rules
+        """
+        # Listed in order of preference
+        preferred_formats = (ResponseFormat.JSON.value, ResponseFormat.JSON_OBJECT.value)
+        for rule in rules:
+            if rule.name != "response_format":
+                continue
+            chosen = next((fmt for fmt in preferred_formats if fmt in rule.options), None)
+            if chosen is not None:
+                model_parameters["response_format"] = chosen
+
+    def _prepare_schema_for_model(self, schema: dict) -> dict:
+        """
+        Prepare JSON schema based on model requirements.
+
+        Different models have different requirements for JSON schema formatting.
+        This function handles these differences. The input schema is never modified.
+
+        :param schema: The original JSON schema
+        :return: Processed schema compatible with the current model
+        """
+        from copy import deepcopy
+
+        # A real deep copy is required: the helpers below rewrite nested dicts in place, so a
+        # shallow ``dict.copy()`` would leak those edits into the caller's schema.
+        processed_schema = deepcopy(schema)
+
+        # Convert boolean types to string types (common requirement)
+        convert_boolean_to_string(processed_schema)
+
+        # Apply model-specific transformations
+        if SpecialModelType.GEMINI in self.node_data.model.name:
+            remove_additional_properties(processed_schema)
+            return processed_schema
+        elif SpecialModelType.OLLAMA in self.node_data.model.provider:
+            return processed_schema
+        else:
+            # Default format with name field
+            return {"schema": processed_schema, "name": "llm_response"}
+
+    def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
+        """
+        Look up the schema of this node's configured model under the given provider.
+
+        :param provider: provider name passed to the model manager
+        :return: the result of ``get_model_schema`` for the model, called with the instance's credentials
+        """
+        model_name = self.node_data.model.name
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=self.tenant_id,
+            model_type=ModelType.LLM,
+            provider=provider,
+            model=model_name,
+        )
+        # The cast only narrows the type for the checker; nothing is converted at runtime
+        llm = cast(LargeLanguageModel, model_instance.model_type_instance)
+        return llm.get_model_schema(model_name, model_instance.credentials)
+
+    def _fetch_structured_output_schema(self) -> dict[str, Any]:
+        """
+        Return the "schema" entry of the node's structured output configuration.
+
+        The value is serialized to JSON and parsed back, so the returned dict is a fresh
+        object independent of the node data.
+
+        Returns:
+            dict[str, Any]: The structured output schema
+
+        Raises:
+            LLMNodeError: If no structured output is configured or the schema is not a JSON object
+        """
+        missing_schema_message = "Please provide a valid structured output schema"
+
+        structured_output = self.node_data.structured_output
+        if not structured_output:
+            raise LLMNodeError(missing_schema_message)
+
+        serialized_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
+        # NOTE(review): json.dumps never returns an empty string, so this check looks unreachable
+        # and a missing "schema" key passes through as {} - confirm whether that is intended.
+        if not serialized_schema:
+            raise LLMNodeError(missing_schema_message)
+
+        try:
+            schema = json.loads(serialized_schema)
+        except json.JSONDecodeError:
+            raise LLMNodeError("structured_output_schema is not valid JSON format")
+
+        if not isinstance(schema, dict):
+            raise LLMNodeError("structured_output_schema must be a JSON object")
+        return schema
+
+    def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
+        """
+        Determine how structured output should be handled for the current model.
+
+        Returns:
+            SupportStructuredOutputStatus: DISABLED when structured output is not enabled or
+            configured on the node, or when no model schema can be fetched; SUPPORTED when the
+            model schema lists the STRUCTURED_OUTPUT feature; UNSUPPORTED otherwise
+        """
+        node_data = self.node_data
+
+        # Structured output must be both switched on and configured on an LLM node
+        structured_output_requested = (
+            isinstance(node_data, LLMNodeData)
+            and node_data.structured_output_enabled
+            and node_data.structured_output
+        )
+        if not structured_output_requested:
+            return SupportStructuredOutputStatus.DISABLED
+
+        model_schema = self._fetch_model_schema(node_data.model.provider)
+        if not model_schema:
+            return SupportStructuredOutputStatus.DISABLED
+
+        if model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features:
+            return SupportStructuredOutputStatus.SUPPORTED
+        return SupportStructuredOutputStatus.UNSUPPORTED
+
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@@ -1064,3 +1269,49 @@ def _handle_completion_template(
)
prompt_messages.append(prompt_message)
return prompt_messages
+
+
+def remove_additional_properties(schema: dict) -> None:
+    """
+    Strip every ``additionalProperties`` key from a JSON schema.
+    Used for models like Gemini that don't support this property.
+
+    Nested dicts are visited recursively, as are dicts that sit directly inside a list value.
+    Lists nested inside lists are not descended into.
+
+    :param schema: JSON schema to modify in-place; non-dict values are ignored
+    """
+    if not isinstance(schema, dict):
+        return
+
+    schema.pop("additionalProperties", None)
+
+    for value in schema.values():
+        # Treat a single value as a one-element group so dicts and lists share one loop
+        candidates = value if isinstance(value, list) else (value,)
+        for candidate in candidates:
+            # The guard at the top turns this into a no-op for anything that is not a dict
+            remove_additional_properties(candidate)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+    """
+    Rewrite every ``"type": "boolean"`` in a JSON schema to ``"type": "string"``.
+
+    Nested dicts are visited recursively, as are dicts that sit directly inside a list value.
+    Lists nested inside lists are not descended into.
+
+    :param schema: JSON schema to modify in-place; non-dict values are ignored
+    """
+    if not isinstance(schema, dict):
+        return
+
+    if schema.get("type") == "boolean":
+        schema["type"] = "string"
+
+    for value in schema.values():
+        # Treat a single value as a one-element group so dicts and lists share one loop
+        candidates = value if isinstance(value, list) else (value,)
+        for candidate in candidates:
+            # The guard at the top turns this into a no-op for anything that is not a dict
+            convert_boolean_to_string(candidate)
diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py
new file mode 100644
index 0000000000..7954acbaee
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/entities.py
@@ -0,0 +1,24 @@
+from enum import StrEnum
+
+
+class ResponseFormat(StrEnum):
+    """Constants for model response formats"""
+
+    # The model's structured output mode. Some models, such as Gemini and GPT-4o, support this mode.
+    JSON_SCHEMA = "json_schema"
+    # The model's JSON mode. Some models, such as Claude, support this mode.
+    JSON = "JSON"
+    # Another alias for JSON mode. Some models, such as deepseek-chat and Qwen, use this alias.
+    JSON_OBJECT = "json_object"
+
+
+class SpecialModelType(StrEnum):
+    """
+    Constants for identifying model types
+
+    The LLM node matches these as substrings: GEMINI against the model name and
+    OLLAMA against the provider name.
+    """
+
+    GEMINI = "gemini"
+    OLLAMA = "ollama"
+
+
+class SupportStructuredOutputStatus(StrEnum):
+    """Constants for structured output support status"""
+
+    # The model schema lists the structured output feature
+    SUPPORTED = "supported"
+    # Structured output is requested, but the model schema does not list the feature
+    UNSUPPORTED = "unsupported"
+    # Structured output is not enabled/configured on the node, or no model schema was found
+    DISABLED = "disabled"
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
new file mode 100644
index 0000000000..06d9b2056e
--- /dev/null
+++ b/api/core/workflow/utils/structured_output/prompt.py
@@ -0,0 +1,17 @@
+STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
+constraints:
+ - You must output in JSON format.
+ - Do not output boolean value, use string type instead.
+ - Do not output integer or float value, use number type instead.
+eg:
+ Here is the JSON schema:
+ {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
+
+ Here is the user's question:
+ My name is John Doe and I am 30 years old.
+
+ output:
+ {"name": "John Doe", "age": 30}
+Here is the JSON schema:
+{{schema}}
+""" # noqa: E501
diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py
index 422ec87765..aa55862b7c 100644
--- a/api/extensions/ext_logging.py
+++ b/api/extensions/ext_logging.py
@@ -26,9 +26,12 @@ def init_app(app: DifyApp):
# Always add StreamHandler to log to console
sh = logging.StreamHandler(sys.stdout)
- sh.addFilter(RequestIdFilter())
log_handlers.append(sh)
+ # Apply RequestIdFilter to all handlers
+ for handler in log_handlers:
+ handler.addFilter(RequestIdFilter())
+
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
diff --git a/api/extensions/ext_repositories.py b/api/extensions/ext_repositories.py
new file mode 100644
index 0000000000..27d8408ec1
--- /dev/null
+++ b/api/extensions/ext_repositories.py
@@ -0,0 +1,18 @@
+"""
+Extension for initializing repositories.
+
+This extension registers repository implementations with the RepositoryFactory.
+"""
+
+from dify_app import DifyApp
+from repositories.repository_registry import register_repositories
+
+
+def init_app(_app: DifyApp) -> None:
+    """
+    Initialize repository implementations.
+
+    Delegates to ``register_repositories()``; the application object itself is not used.
+
+    Args:
+        _app: The Flask application instance (unused)
+    """
+    register_repositories()
diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py
index 588bdb2d27..4c811c66ba 100644
--- a/api/extensions/ext_storage.py
+++ b/api/extensions/ext_storage.py
@@ -73,11 +73,7 @@ class Storage:
raise ValueError(f"unsupported storage type {storage_type}")
def save(self, filename, data):
- try:
- self.storage_runner.save(filename, data)
- except Exception as e:
- logger.exception(f"Failed to save file {filename}")
- raise e
+ self.storage_runner.save(filename, data)
@overload
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
@@ -86,49 +82,25 @@ class Storage:
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
- try:
- if stream:
- return self.load_stream(filename)
- else:
- return self.load_once(filename)
- except Exception as e:
- logger.exception(f"Failed to load file {filename}")
- raise e
+ if stream:
+ return self.load_stream(filename)
+ else:
+ return self.load_once(filename)
def load_once(self, filename: str) -> bytes:
- try:
- return self.storage_runner.load_once(filename)
- except Exception as e:
- logger.exception(f"Failed to load_once file {filename}")
- raise e
+ return self.storage_runner.load_once(filename)
def load_stream(self, filename: str) -> Generator:
- try:
- return self.storage_runner.load_stream(filename)
- except Exception as e:
- logger.exception(f"Failed to load_stream file {filename}")
- raise e
+ return self.storage_runner.load_stream(filename)
def download(self, filename, target_filepath):
- try:
- self.storage_runner.download(filename, target_filepath)
- except Exception as e:
- logger.exception(f"Failed to download file {filename}")
- raise e
+ self.storage_runner.download(filename, target_filepath)
def exists(self, filename):
- try:
- return self.storage_runner.exists(filename)
- except Exception as e:
- logger.exception(f"Failed to check file exists {filename}")
- raise e
+ return self.storage_runner.exists(filename)
def delete(self, filename):
- try:
- return self.storage_runner.delete(filename)
- except Exception as e:
- logger.exception(f"Failed to delete file {filename}")
- raise e
+ return self.storage_runner.delete(filename)
storage = Storage()
diff --git a/api/models/model.py b/api/models/model.py
index a826d13e7d..6577492d1b 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined]
@property
def retriever_resources(self):
- return (
- db.session.query(DatasetRetrieverResource)
- .filter(DatasetRetrieverResource.message_id == self.id)
- .order_by(DatasetRetrieverResource.position.asc())
- .all()
- )
+ return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self):
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 8b7c376e4b..5a67fa47a8 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -245,6 +245,13 @@ class Workflow(Base):
@property
def tool_published(self) -> bool:
+ """
+ DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
+ It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version
+ is the one being used by the tool.
+
+ For accurate checking, use a direct query with tenant_id, app_id, and version.
+ """
from models.tools import WorkflowToolProvider
return (
@@ -510,7 +517,7 @@ class WorkflowRun(Base):
)
-class WorkflowNodeExecutionTriggeredFrom(Enum):
+class WorkflowNodeExecutionTriggeredFrom(StrEnum):
"""
Workflow Node Execution Triggered From Enum
"""
@@ -518,21 +525,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum):
SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run"
- @classmethod
- def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom":
- """
- Get value of given mode.
-
- :param value: mode value
- :return: mode
- """
- for mode in cls:
- if mode.value == value:
- return mode
- raise ValueError(f"invalid workflow node execution triggered from value {value}")
-
-class WorkflowNodeExecutionStatus(Enum):
+class WorkflowNodeExecutionStatus(StrEnum):
"""
Workflow Node Execution Status Enum
"""
@@ -543,19 +537,6 @@ class WorkflowNodeExecutionStatus(Enum):
EXCEPTION = "exception"
RETRY = "retry"
- @classmethod
- def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
- """
- Get value of given mode.
-
- :param value: mode value
- :return: mode
- """
- for mode in cls:
- if mode.value == value:
- return mode
- raise ValueError(f"invalid workflow node execution status value {value}")
-
class WorkflowNodeExecution(Base):
"""
@@ -656,6 +637,7 @@ class WorkflowNodeExecution(Base):
@property
def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role)
+ # TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property
@@ -663,6 +645,7 @@ class WorkflowNodeExecution(Base):
from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role)
+ # TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 85679a6359..4992178423 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"gunicorn~=23.0.0",
"httpx[socks]~=0.27.0",
"jieba==0.42.1",
+ "json-repair>=0.41.1",
"langfuse~=2.51.3",
"langsmith~=0.1.77",
"mailchimp-transactional~=1.0.50",
@@ -163,10 +164,7 @@ storage = [
############################################################
# [ Tools ] dependency group
############################################################
-tools = [
- "cloudscraper~=1.2.71",
- "nltk~=3.9.1",
-]
+tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
############################################################
# [ VDB ] dependency group
@@ -180,7 +178,7 @@ vdb = [
"couchbase~=4.3.0",
"elasticsearch==8.14.0",
"opensearch-py==2.4.0",
- "oracledb~=2.2.1",
+ "oracledb==3.0.0",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5",
"pymilvus~=2.5.0",
diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py
new file mode 100644
index 0000000000..4cc339688b
--- /dev/null
+++ b/api/repositories/__init__.py
@@ -0,0 +1,6 @@
+"""
+Repository implementations for data access.
+
+This package contains concrete implementations of the repository interfaces
+defined in the core.repository package.
+"""
diff --git a/api/repositories/repository_registry.py b/api/repositories/repository_registry.py
new file mode 100644
index 0000000000..aa0a208d8e
--- /dev/null
+++ b/api/repositories/repository_registry.py
@@ -0,0 +1,87 @@
+"""
+Registry for repository implementations.
+
+This module is responsible for registering factory functions with the repository factory.
+"""
+
+import logging
+from collections.abc import Mapping
+from typing import Any
+
+from sqlalchemy.orm import sessionmaker
+
+from configs import dify_config
+from core.repository.repository_factory import RepositoryFactory
+from extensions.ext_database import db
+from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
+
+logger = logging.getLogger(__name__)
+
+# Storage type constants: the values of WORKFLOW_NODE_EXECUTION_STORAGE that
+# register_repositories() recognizes.
+STORAGE_TYPE_RDBMS = "rdbms"  # relational database only
+STORAGE_TYPE_HYBRID = "hybrid"  # recognized, but rejected as not yet implemented
+
+
+def register_repositories() -> None:
+    """
+    Wire the configured repository implementations into the RepositoryFactory.
+
+    The WorkflowNodeExecution backend is selected by
+    ``dify_config.WORKFLOW_NODE_EXECUTION_STORAGE``.
+
+    Raises:
+        NotImplementedError: If the configured storage type is "hybrid".
+        ValueError: If the configured storage type is not recognized.
+    """
+    storage_type = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
+
+    # Reject every backend we cannot serve before touching the factory.
+    if storage_type == STORAGE_TYPE_HYBRID:
+        raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
+    if storage_type != STORAGE_TYPE_RDBMS:
+        raise ValueError(
+            f"Unknown storage type '{storage_type}' for WorkflowNodeExecution repository. "
+            f"Supported types: {STORAGE_TYPE_RDBMS}"
+        )
+
+    # Only the SQLAlchemy-backed (RDBMS) implementation is left at this point.
+    logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
+    RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
+
+
+def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
+    """
+    Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
+
+    This factory function creates a repository for the RDBMS storage type.
+
+    Args:
+        params: Parameters for creating the repository, including:
+            - tenant_id: Required. The tenant ID for multi-tenancy.
+            - app_id: Optional. The application ID for filtering.
+            - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
+                a new sessionmaker will be created using the global database engine.
+
+    Returns:
+        A WorkflowNodeExecutionRepository instance
+
+    Raises:
+        ValueError: If required parameters are missing
+    """
+    # Extract required parameters
+    tenant_id = params.get("tenant_id")
+    if tenant_id is None:
+        raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
+
+    # Extract optional parameters
+    app_id = params.get("app_id")
+
+    # Use the session_factory from params if provided, otherwise create one using the global db engine
+    session_factory = params.get("session_factory")
+    if session_factory is None:
+        # expire_on_commit=False mirrors the sessionmaker the repository builds when it is
+        # handed an Engine. The repository commits and closes its session inside each call,
+        # so with the default (True) the objects it hands back would be expired and detached.
+        session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+
+    # Create and return the repository
+    return SQLAlchemyWorkflowNodeExecutionRepository(
+        session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
+    )
diff --git a/api/repositories/workflow_node_execution/__init__.py b/api/repositories/workflow_node_execution/__init__.py
new file mode 100644
index 0000000000..eed827bd05
--- /dev/null
+++ b/api/repositories/workflow_node_execution/__init__.py
@@ -0,0 +1,9 @@
+"""
+WorkflowNodeExecution repository implementations.
+"""
+
+from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
+
+__all__ = [
+ "SQLAlchemyWorkflowNodeExecutionRepository",
+]
diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py
new file mode 100644
index 0000000000..0594d816a2
--- /dev/null
+++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py
@@ -0,0 +1,192 @@
+"""
+SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
+"""
+
+import logging
+from collections.abc import Sequence
+from typing import Optional
+
+from sqlalchemy import UnaryExpression, asc, delete, desc, select
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from core.repository.workflow_node_execution_repository import OrderConfig
+from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
+
+logger = logging.getLogger(__name__)
+
+
+class SQLAlchemyWorkflowNodeExecutionRepository:
+    """
+    SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
+
+    This implementation supports multi-tenancy by filtering operations based on tenant_id.
+    Each method opens its own session; the write methods (save, update, clear) commit
+    before returning. This prevents long-running connections in the workflow core.
+    """
+
+    def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
+        """
+        Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
+
+        Args:
+            session_factory: SQLAlchemy sessionmaker or engine for creating sessions
+            tenant_id: Tenant ID for multi-tenancy
+            app_id: Optional app ID for filtering by application
+        """
+        # If an engine is provided, create a sessionmaker from it
+        if isinstance(session_factory, Engine):
+            self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
+        else:
+            self._session_factory = session_factory
+
+        self._tenant_id = tenant_id
+        self._app_id = app_id
+
+    def _apply_scope_defaults(self, execution: WorkflowNodeExecution) -> None:
+        """
+        Fill in tenant_id / app_id on the execution when they are not already set.
+
+        Values already present on the execution are left untouched.
+
+        Args:
+            execution: The WorkflowNodeExecution instance to complete in place
+        """
+        # Ensure tenant_id is set
+        if not execution.tenant_id:
+            execution.tenant_id = self._tenant_id
+
+        # Set app_id if provided and not already set
+        if self._app_id and not execution.app_id:
+            execution.app_id = self._app_id
+
+    def save(self, execution: WorkflowNodeExecution) -> None:
+        """
+        Save a WorkflowNodeExecution instance and commit changes to the database.
+
+        Args:
+            execution: The WorkflowNodeExecution instance to save
+        """
+        with self._session_factory() as session:
+            self._apply_scope_defaults(execution)
+            session.add(execution)
+            session.commit()
+
+    def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
+        """
+        Retrieve a WorkflowNodeExecution by its node_execution_id.
+
+        The lookup is restricted to this repository's tenant_id and, when set, app_id.
+
+        Args:
+            node_execution_id: The node execution ID
+
+        Returns:
+            The WorkflowNodeExecution instance if found, None otherwise
+        """
+        with self._session_factory() as session:
+            stmt = select(WorkflowNodeExecution).where(
+                WorkflowNodeExecution.node_execution_id == node_execution_id,
+                WorkflowNodeExecution.tenant_id == self._tenant_id,
+            )
+
+            if self._app_id:
+                stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
+
+            return session.scalar(stmt)
+
+    def get_by_workflow_run(
+        self,
+        workflow_run_id: str,
+        order_config: Optional[OrderConfig] = None,
+    ) -> Sequence[WorkflowNodeExecution]:
+        """
+        Retrieve all WorkflowNodeExecution instances for a specific workflow run.
+
+        Only executions triggered from a workflow run are returned.
+
+        Args:
+            workflow_run_id: The workflow run ID
+            order_config: Optional configuration for ordering results
+                order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
+                order_config.order_direction: Direction to order ("asc" or "desc")
+
+        Returns:
+            A list of WorkflowNodeExecution instances
+        """
+        with self._session_factory() as session:
+            stmt = select(WorkflowNodeExecution).where(
+                WorkflowNodeExecution.workflow_run_id == workflow_run_id,
+                WorkflowNodeExecution.tenant_id == self._tenant_id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            )
+
+            if self._app_id:
+                stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
+
+            # Apply ordering if provided
+            if order_config and order_config.order_by:
+                # Anything other than "desc" sorts ascending.
+                direction = desc if order_config.order_direction == "desc" else asc
+                order_columns: list[UnaryExpression] = []
+                for field in order_config.order_by:
+                    column = getattr(WorkflowNodeExecution, field, None)
+                    # Compare against None explicitly: truth-testing a SQLAlchemy
+                    # expression is not a reliable "attribute is missing" check.
+                    if column is None:
+                        # Unknown field names are skipped rather than rejected.
+                        continue
+                    order_columns.append(direction(column))
+
+                if order_columns:
+                    stmt = stmt.order_by(*order_columns)
+
+            return session.scalars(stmt).all()
+
+    def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
+        """
+        Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
+
+        Args:
+            workflow_run_id: The workflow run ID
+
+        Returns:
+            A list of running WorkflowNodeExecution instances
+        """
+        with self._session_factory() as session:
+            stmt = select(WorkflowNodeExecution).where(
+                WorkflowNodeExecution.workflow_run_id == workflow_run_id,
+                WorkflowNodeExecution.tenant_id == self._tenant_id,
+                WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+            )
+
+            if self._app_id:
+                stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
+
+            return session.scalars(stmt).all()
+
+    def update(self, execution: WorkflowNodeExecution) -> None:
+        """
+        Update an existing WorkflowNodeExecution instance and commit changes to the database.
+
+        Args:
+            execution: The WorkflowNodeExecution instance to update
+        """
+        with self._session_factory() as session:
+            self._apply_scope_defaults(execution)
+            # merge() rather than add(): the instance is copied into this method's own session.
+            session.merge(execution)
+            session.commit()
+
+    def clear(self) -> None:
+        """
+        Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
+
+        This method deletes all WorkflowNodeExecution records that match the tenant_id
+        and app_id (if provided) associated with this repository instance.
+        """
+        with self._session_factory() as session:
+            stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
+
+            if self._app_id:
+                stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
+
+            result = session.execute(stmt)
+            session.commit()
+
+            deleted_count = result.rowcount
+            logger.info(
+                f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+                + (f" and app {self._app_id}" if self._app_id else "")
+            )
diff --git a/api/services/account_service.py b/api/services/account_service.py
index ada8109067..f930ef910b 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -407,10 +407,8 @@ class AccountService:
raise PasswordResetRateLimitExceededError()
- code = "".join([str(random.randint(0, 9)) for _ in range(6)])
- token = TokenManager.generate_token(
- account=account, email=email, token_type="reset_password", additional_data={"code": code}
- )
+ code, token = cls.generate_reset_password_token(account_email, account)
+
send_reset_password_mail_task.delay(
language=language,
to=account_email,
@@ -419,6 +417,22 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
+    @classmethod
+    def generate_reset_password_token(
+        cls,
+        email: str,
+        account: Optional[Account] = None,
+        code: Optional[str] = None,
+        additional_data: Optional[dict[str, Any]] = None,
+    ):
+        """
+        Generate a reset-password token together with its verification code.
+
+        Args:
+            email: Email address the token is issued for.
+            account: Optional account the token is issued for.
+            code: Verification code to embed; a 6-digit numeric code is generated when omitted or empty.
+            additional_data: Optional extra data stored with the token. The "code" key is
+                always overwritten with the verification code.
+
+        Returns:
+            A ``(code, token)`` tuple.
+        """
+        # A fresh dict per call: a shared ``{}`` default would carry state across calls.
+        if additional_data is None:
+            additional_data = {}
+        if not code:
+            # NOTE(review): `random` is not a cryptographically secure source; consider `secrets` here.
+            code = "".join([str(random.randint(0, 9)) for _ in range(6)])
+        additional_data["code"] = code
+        token = TokenManager.generate_token(
+            account=account, email=email, token_type="reset_password", additional_data=additional_data
+        )
+        return code, token
+
@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py
index 778f05a0cd..07e624b4e8 100644
--- a/api/services/plugin/dependencies_analysis.py
+++ b/api/services/plugin/dependencies_analysis.py
@@ -1,3 +1,4 @@
+from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager
@@ -111,6 +112,8 @@ class DependenciesAnalysisService:
Generate the latest version of dependencies
"""
dependencies = list(set(dependencies))
+ if not dify_config.MARKETPLACE_ENABLED:
+ return []
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
return [
PluginDependency(
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 0ddd18ea27..ff3b33eecd 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -2,13 +2,14 @@ import threading
from typing import Optional
import contexts
+from core.repository import RepositoryFactory
+from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
from models.model import App
from models.workflow import (
WorkflowNodeExecution,
- WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
@@ -127,17 +128,17 @@ class WorkflowRunService:
if not workflow_run:
return []
- node_executions = (
- db.session.query(WorkflowNodeExecution)
- .filter(
- WorkflowNodeExecution.tenant_id == app_model.tenant_id,
- WorkflowNodeExecution.app_id == app_model.id,
- WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
- WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
- WorkflowNodeExecution.workflow_run_id == run_id,
- )
- .order_by(WorkflowNodeExecution.index.desc())
- .all()
+ # Use the repository to get the node executions
+ repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": app_model.tenant_id,
+ "app_id": app_model.id,
+ "session_factory": db.session.get_bind,
+ }
)
- return node_executions
+ # Use the repository to get the node executions with ordering
+ order_config = OrderConfig(order_by=["index"], order_direction="desc")
+ node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
+
+ return list(node_executions)
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 992942fc70..5cd5c55746 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
+from core.repository import RepositoryFactory
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -27,6 +28,7 @@ from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
+from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecution,
@@ -282,8 +284,15 @@ class WorkflowService:
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
- db.session.add(workflow_node_execution)
- db.session.commit()
+ # Use the repository to save the workflow node execution
+ repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": app_model.tenant_id,
+ "app_id": app_model.id,
+ "session_factory": db.session.get_bind,
+ }
+ )
+ repository.save(workflow_node_execution)
return workflow_node_execution
@@ -515,8 +524,19 @@ class WorkflowService:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
- # Check if this workflow is published as a tool
- if workflow.tool_published:
+ # Don't use workflow.tool_published as it's not accurate for specific workflow versions
+ # Check if there's a tool provider using this specific workflow version
+ tool_provider = (
+ session.query(WorkflowToolProvider)
+ .filter(
+ WorkflowToolProvider.tenant_id == workflow.tenant_id,
+ WorkflowToolProvider.app_id == workflow.app_id,
+ WorkflowToolProvider.version == workflow.version,
+ )
+ .first()
+ )
+
+ if tool_provider:
# Cannot delete a workflow that's published as a tool
raise WorkflowInUseError("Cannot delete workflow that is published as a tool")
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index c3910e2be3..4542b1b923 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -7,6 +7,7 @@ from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
+from core.repository import RepositoryFactory
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
@@ -30,7 +31,7 @@ from models.model import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
- def del_workflow_node_execution(workflow_node_execution_id: str):
- db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete(
- synchronize_session=False
- )
-
- _delete_records(
- """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
- {"tenant_id": tenant_id, "app_id": app_id},
- del_workflow_node_execution,
- "workflow node execution",
+ # Create a repository instance for WorkflowNodeExecution
+ repository = RepositoryFactory.create_workflow_node_execution_repository(
+ params={
+ "tenant_id": tenant_id,
+ "app_id": app_id,
+ "session_factory": db.session.get_bind,
+ }
)
+ # Use the clear method to delete all records for this tenant_id and app_id
+ repository.clear()
+
+ logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
+
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/tests/unit_tests/core/model_runtime/__base/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py
new file mode 100644
index 0000000000..93d8a20cac
--- /dev/null
+++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py
@@ -0,0 +1,99 @@
+from unittest.mock import MagicMock, patch
+
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
+from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
+
+# Short alias for the nested tool-call entity used throughout the fixtures below.
+ToolCall = AssistantPromptMessage.ToolCall
+
+# Each INPUTS_CASE_n is a sequence of streamed tool-call chunks; EXPECTED_CASE_n is the
+# list the test expects after _increase_tool_call has processed those chunks.
+
+# CASE 1: Single tool call
+# Only the first chunk carries the id and function name; later chunks carry argument fragments.
+INPUTS_CASE_1 = [
+    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+]
+EXPECTED_CASE_1 = [
+    ToolCall(
+        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
+    ),
+]
+
+# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
+# A chunk with a new non-empty id starts the next tool call; chunks with an empty id continue it.
+INPUTS_CASE_2 = [
+    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+]
+EXPECTED_CASE_2 = [
+    ToolCall(
+        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
+    ),
+    ToolCall(
+        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
+    ),
+]
+
+# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
+# Every chunk repeats the id of the tool call it belongs to.
+INPUTS_CASE_3 = [
+    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
+    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
+    ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
+    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
+    ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+]
+EXPECTED_CASE_3 = [
+    ToolCall(
+        id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
+    ),
+    ToolCall(
+        id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
+    ),
+]
+
+# CASE 4: Tool call sequences with no IDs
+# Only a non-empty function name marks the start of a tool call. The expected ids are
+# placeholders that the test feeds in through a patched id generator.
+INPUTS_CASE_4 = [
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
+    ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
+]
+EXPECTED_CASE_4 = [
+    ToolCall(
+        id="RANDOM_ID_1",
+        type="function",
+        function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
+    ),
+    ToolCall(
+        id="RANDOM_ID_2",
+        type="function",
+        function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
+    ),
+]
+
+
+def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
+    """Feed ``inputs`` to _increase_tool_call with an empty accumulator and compare it to ``expected``."""
+    merged: list[ToolCall] = []
+    _increase_tool_call(inputs, merged)
+    assert merged == expected
+
+
+def test__increase_tool_call():
+    """Run all four chunk-sequence cases through _increase_tool_call."""
+    # Cases 1-3 run against the unpatched module.
+    for case_inputs, case_expected in (
+        (INPUTS_CASE_1, EXPECTED_CASE_1),
+        (INPUTS_CASE_2, EXPECTED_CASE_2),
+        (INPUTS_CASE_3, EXPECTED_CASE_3),
+    ):
+        _run_case(case_inputs, case_expected)
+
+    # Case 4: the id generator is patched to hand out the expected ids in order.
+    fake_id_generator = MagicMock(side_effect=[tool_call.id for tool_call in EXPECTED_CASE_4])
+    with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", fake_id_generator):
+        _run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
diff --git a/api/tests/unit_tests/repositories/__init__.py b/api/tests/unit_tests/repositories/__init__.py
new file mode 100644
index 0000000000..bc0d6e78c9
--- /dev/null
+++ b/api/tests/unit_tests/repositories/__init__.py
@@ -0,0 +1,3 @@
+"""
+Unit tests for repositories.
+"""
diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py
new file mode 100644
index 0000000000..78815a8d1a
--- /dev/null
+++ b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py
@@ -0,0 +1,3 @@
+"""
+Unit tests for workflow_node_execution repositories.
+"""
diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
new file mode 100644
index 0000000000..36847f8a13
--- /dev/null
+++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py
@@ -0,0 +1,178 @@
+"""
+Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
+"""
+
+from unittest.mock import MagicMock
+
+import pytest
+from pytest_mock import MockerFixture
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.repository.workflow_node_execution_repository import OrderConfig
+from models.workflow import WorkflowNodeExecution
+from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
+
+
+@pytest.fixture
+def session():
+    """Provide a ``(mock_session, mock_session_factory)`` pair built from MagicMock."""
+    mock_session = MagicMock(spec=Session)
+    # Let ``with factory() as s`` yield the mock session itself.
+    mock_session.__enter__ = MagicMock(return_value=mock_session)
+    mock_session.__exit__ = MagicMock(return_value=None)
+
+    # Calling the factory hands back the mock session.
+    mock_factory = MagicMock(spec=sessionmaker, return_value=mock_session)
+    return mock_session, mock_factory
+
+
+@pytest.fixture
+def repository(session):
+    """Build a repository for a fixed test tenant and app on top of the mocked session factory."""
+    mock_factory = session[1]
+    return SQLAlchemyWorkflowNodeExecutionRepository(
+        session_factory=mock_factory, tenant_id="test-tenant", app_id="test-app"
+    )
+
+
+def test_save(repository, session):
+    """save() fills in a missing tenant_id/app_id, adds the execution and commits."""
+    session_obj, _ = session
+    # Create a mock execution with no tenant or app assigned yet
+    execution = MagicMock(spec=WorkflowNodeExecution)
+    execution.tenant_id = None
+    execution.app_id = None
+
+    # Call save method
+    repository.save(execution)
+
+    # Assert tenant_id and app_id are set from the repository's context
+    assert execution.tenant_id == repository._tenant_id
+    assert execution.app_id == repository._app_id
+
+    # Assert the execution was added and the transaction was committed
+    session_obj.add.assert_called_once_with(execution)
+    session_obj.commit.assert_called_once()
+
+
+def test_save_with_existing_tenant_id(repository, session):
+    """save() keeps a tenant_id that is already set while still filling in app_id."""
+    mock_session = session[0]
+    preset_tenant = "existing-tenant"
+    node_execution = MagicMock(spec=WorkflowNodeExecution)
+    node_execution.tenant_id = preset_tenant
+    node_execution.app_id = None
+
+    repository.save(node_execution)
+
+    # tenant_id is untouched; app_id comes from the repository
+    assert node_execution.tenant_id == preset_tenant
+    assert node_execution.app_id == repository._app_id
+
+    mock_session.add.assert_called_once_with(node_execution)
+
+
+def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
+    """get_by_node_execution_id builds one select and returns what session.scalar yields."""
+    mock_session = session[0]
+    # Replace ``select`` with a stub whose statement survives chained ``where`` calls.
+    stmt = mocker.MagicMock()
+    stmt.where.return_value = stmt
+    patched_select = mocker.patch(
+        "repositories.workflow_node_execution.sqlalchemy_repository.select",
+        return_value=stmt,
+    )
+    mock_session.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
+
+    found = repository.get_by_node_execution_id("test-node-execution-id")
+
+    patched_select.assert_called_once()
+    mock_session.scalar.assert_called_once_with(stmt)
+    assert found is not None
+
+
+def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
+    """get_by_workflow_run with an order config returns the rows from session.scalars(...).all()."""
+    mock_session = session[0]
+    # Replace ``select`` with a stub whose statement survives ``where`` and ``order_by``.
+    stmt = mocker.MagicMock()
+    stmt.where.return_value = stmt
+    stmt.order_by.return_value = stmt
+    patched_select = mocker.patch(
+        "repositories.workflow_node_execution.sqlalchemy_repository.select",
+        return_value=stmt,
+    )
+    mock_session.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    ordering = OrderConfig(order_by=["index"], order_direction="desc")
+    rows = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=ordering)
+
+    patched_select.assert_called_once()
+    mock_session.scalars.assert_called_once_with(stmt)
+    assert len(rows) == 1
+
+
+def test_get_running_executions(repository, session, mocker: MockerFixture):
+    """get_running_executions returns the rows from session.scalars(...).all()."""
+    mock_session = session[0]
+    # Replace ``select`` with a stub whose statement survives chained ``where`` calls.
+    stmt = mocker.MagicMock()
+    stmt.where.return_value = stmt
+    patched_select = mocker.patch(
+        "repositories.workflow_node_execution.sqlalchemy_repository.select",
+        return_value=stmt,
+    )
+    mock_session.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
+
+    rows = repository.get_running_executions("test-workflow-run-id")
+
+    patched_select.assert_called_once()
+    mock_session.scalars.assert_called_once_with(stmt)
+    assert len(rows) == 1
+
+
+def test_update(repository, session):
+    """update() fills in a missing tenant_id/app_id, merges the execution and commits."""
+    session_obj, _ = session
+    # Create a mock execution with no tenant or app assigned yet
+    execution = MagicMock(spec=WorkflowNodeExecution)
+    execution.tenant_id = None
+    execution.app_id = None
+
+    # Call update method
+    repository.update(execution)
+
+    # Assert tenant_id and app_id are set from the repository's context
+    assert execution.tenant_id == repository._tenant_id
+    assert execution.app_id == repository._app_id
+
+    # Assert the execution was merged and the transaction was committed
+    session_obj.merge.assert_called_once_with(execution)
+    session_obj.commit.assert_called_once()
+
+
+def test_clear(repository, session, mocker: MockerFixture):
+    """clear() issues one filtered delete through the session and commits it."""
+    mock_session = session[0]
+    # Replace ``delete`` with a stub whose statement survives chained ``where`` calls.
+    stmt = mocker.MagicMock()
+    stmt.where.return_value = stmt
+    patched_delete = mocker.patch(
+        "repositories.workflow_node_execution.sqlalchemy_repository.delete",
+        return_value=stmt,
+    )
+
+    # The execute result reports 5 deleted rows; clear() reads rowcount for its log line.
+    execute_result = mocker.MagicMock()
+    execute_result.rowcount = 5
+    mock_session.execute.return_value = execute_result
+
+    repository.clear()
+
+    patched_delete.assert_called_once_with(WorkflowNodeExecution)
+    stmt.where.assert_called()
+    mock_session.execute.assert_called_once_with(stmt)
+    mock_session.commit.assert_called_once()
diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
index 56efcccc78..223020c2c5 100644
--- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
+++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py
@@ -40,6 +40,10 @@ def workflow_setup():
def test_delete_workflow_success(workflow_setup):
# Setup mocks
+
+ # Mock the tool provider query to return None (not published as a tool)
+ workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
+
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
@@ -97,7 +101,12 @@ def test_delete_workflow_in_use_by_app_error(workflow_setup):
def test_delete_workflow_published_as_tool_error(workflow_setup):
# Setup mocks
- workflow_setup["workflow"].tool_published = True
+ from models.tools import WorkflowToolProvider
+
+ # Mock the tool provider query
+ mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
+ workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
+
workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app
diff --git a/api/uv.lock b/api/uv.lock
index 4ff9c34446..6c8699dd7c 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1,5 +1,4 @@
version = 1
-revision = 1
requires-python = ">=3.11, <3.13"
resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'",
@@ -1178,6 +1177,7 @@ dependencies = [
{ name = "gunicorn" },
{ name = "httpx", extra = ["socks"] },
{ name = "jieba" },
+ { name = "json-repair" },
{ name = "langfuse" },
{ name = "langsmith" },
{ name = "mailchimp-transactional" },
@@ -1346,6 +1346,7 @@ requires-dist = [
{ name = "gunicorn", specifier = "~=23.0.0" },
{ name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
{ name = "jieba", specifier = "==0.42.1" },
+ { name = "json-repair", specifier = ">=0.41.1" },
{ name = "langfuse", specifier = "~=2.51.3" },
{ name = "langsmith", specifier = "~=0.1.77" },
{ name = "mailchimp-transactional", specifier = "~=1.0.50" },
@@ -1470,7 +1471,7 @@ vdb = [
{ name = "couchbase", specifier = "~=4.3.0" },
{ name = "elasticsearch", specifier = "==8.14.0" },
{ name = "opensearch-py", specifier = "==2.4.0" },
- { name = "oracledb", specifier = "~=2.2.1" },
+ { name = "oracledb", specifier = "==3.0.0" },
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
{ name = "pgvector", specifier = "==0.2.5" },
{ name = "pymilvus", specifier = "~=2.5.0" },
@@ -2524,6 +2525,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
]
+[[package]]
+name = "json-repair"
+version = "0.41.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 },
+]
+
[[package]]
name = "jsonpath-python"
version = "1.0.6"
@@ -3590,23 +3600,23 @@ wheels = [
[[package]]
name = "oracledb"
-version = "2.2.1"
+version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "cryptography" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 }
+sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 },
- { url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 },
- { url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 },
- { url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 },
- { url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 },
- { url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 },
- { url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 },
- { url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 },
- { url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 },
- { url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 },
+ { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 },
+ { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 },
+ { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 },
+ { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 },
+ { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 },
+ { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 },
+ { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 },
+ { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 },
+ { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 },
+ { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 },
]
[[package]]
@@ -4074,6 +4084,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
{ url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
{ url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 },
+ { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 },
+ { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 },
]
[[package]]
diff --git a/docker/.env.example b/docker/.env.example
index 9b372dcec9..82ef4174c2 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -744,6 +744,12 @@ MAX_VARIABLE_SIZE=204800
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
WORKFLOW_FILE_UPLOAD_LIMIT=10
+# Workflow storage configuration
+# Options: rdbms, hybrid
+# rdbms: Use only the relational database (default)
+# hybrid: Save new data to object storage, read from both object storage and RDBMS
+WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index a8f7b755fb..c6d41849ef 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -130,6 +130,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+ PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index 27d6d660d0..1702a5395f 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -60,6 +60,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+ PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 486264cebd..15fab4a4bf 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -327,6 +327,7 @@ x-shared-env: &shared-api-worker-env
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
+ WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@@ -605,6 +606,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194}
+ PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes:
- ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf
diff --git a/web/README.md b/web/README.md
index 3236347e80..3d9fd2de87 100644
--- a/web/README.md
+++ b/web/README.md
@@ -7,7 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
### Run by source code
Before starting the web frontend service, please make sure the following environment is ready.
-- [Node.js](https://nodejs.org) >= v18.x
+- [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x
First, install the dependencies:
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
index f4d49425ae..d5df70f004 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx
@@ -1,11 +1,11 @@
'use client'
-import Workflow from '@/app/components/workflow'
+import WorkflowApp from '@/app/components/workflow-app'
const Page = () => {
return (
-
+
)
}
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx
index b435a9bb67..a8bb7046e6 100644
--- a/web/app/(commonLayout)/datasets/template/template.zh.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx
@@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- semantic_search 语义检索
- full_text_search 全文检索
- reranking_enable (bool) 是否开启rerank
+ - reranking_mode (string) 混合检索
+ - weighted_score 权重设置
+ - reranking_model Rerank 模型
- reranking_model (object) Rerank 模型配置
- reranking_provider_name (string) Rerank 模型的提供商
- reranking_model_name (string) Rerank 模型的名称
@@ -591,7 +594,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
检索参数(选填,如不填,按照默认方式召回)
- - search_method (text) 检索方法:以下三个关键字之一,必填
+ - search_method (text) 检索方法:以下四个关键字之一,必填
- keyword_search 关键字检索
- semantic_search 语义检索
- full_text_search 全文检索
@@ -1817,7 +1820,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
检索参数(选填,如不填,按照默认方式召回)
- - search_method (text) 检索方法:以下三个关键字之一,必填
+ - search_method (text) 检索方法:以下四个关键字之一,必填
- keyword_search 关键字检索
- semantic_search 语义检索
- full_text_search 全文检索
diff --git a/web/app/components/app/configuration/config-var/config-select/index.spec.tsx b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx
new file mode 100644
index 0000000000..18df318de3
--- /dev/null
+++ b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx
@@ -0,0 +1,82 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import ConfigSelect from './index'
+
+jest.mock('react-sortablejs', () => ({
+ ReactSortable: ({ children }: { children: React.ReactNode }) => {children}
,
+}))
+
+jest.mock('react-i18next', () => ({
+ useTranslation: () => ({
+ t: (key: string) => key,
+ }),
+}))
+
+describe('ConfigSelect Component', () => {
+ const defaultProps = {
+ options: ['Option 1', 'Option 2'],
+ onChange: jest.fn(),
+ }
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('renders all options', () => {
+ render()
+
+ defaultProps.options.forEach((option) => {
+ expect(screen.getByDisplayValue(option)).toBeInTheDocument()
+ })
+ })
+
+ it('renders add button', () => {
+ render()
+
+ expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
+ })
+
+ it('handles option deletion', () => {
+ render()
+ const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
+ const deleteButton = optionContainer?.querySelector('div[role="button"]')
+
+ if (!deleteButton) return
+ fireEvent.click(deleteButton)
+ expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2'])
+ })
+
+ it('handles adding new option', () => {
+ render()
+ const addButton = screen.getByText('appDebug.variableConfig.addOption')
+
+ fireEvent.click(addButton)
+
+ expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, ''])
+ })
+
+ it('applies focus styles on input focus', () => {
+ render()
+ const firstInput = screen.getByDisplayValue('Option 1')
+
+ fireEvent.focus(firstInput)
+
+ expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active')
+ })
+
+ it('applies delete hover styles', () => {
+ render()
+ const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
+ const deleteButton = optionContainer?.querySelector('div[role="button"]')
+
+ if (!deleteButton) return
+ fireEvent.mouseEnter(deleteButton)
+ expect(optionContainer).toHaveClass('border-components-input-border-destructive')
+ })
+
+ it('renders empty state correctly', () => {
+ render()
+
+ expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
+ expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
+ })
+})
diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx
index d2dc1662c1..40ddaef78f 100644
--- a/web/app/components/app/configuration/config-var/config-select/index.tsx
+++ b/web/app/components/app/configuration/config-var/config-select/index.tsx
@@ -51,7 +51,7 @@ const ConfigSelect: FC = ({
{
const value = e.target.value
@@ -67,6 +67,7 @@ const ConfigSelect: FC = ({
onBlur={() => setFocusID(null)}
/>
{
onChange(options.filter((_, i) => index !== i))
diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx
index 896229d433..679d616e54 100644
--- a/web/app/components/app/overview/settings/index.tsx
+++ b/web/app/components/app/overview/settings/index.tsx
@@ -162,11 +162,22 @@ const SettingsModal: FC
= ({
return check
}
+ const validatePrivacyPolicy = (privacyPolicy: string | null) => {
+ if (privacyPolicy === null || privacyPolicy?.length === 0)
+ return true
+ // Null/empty is accepted above; any other value must start with http:// or https:// (prefix check only).
+ return privacyPolicy.startsWith('http://') || privacyPolicy.startsWith('https://')
+ }
+
if (inputInfo !== null) {
if (!validateColorHex(inputInfo.chatColorTheme)) {
notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`) })
return
}
+ if (!validatePrivacyPolicy(inputInfo.privacyPolicy)) {
+ notify({ type: 'error', message: t(`${prefixSettings}.invalidPrivacyPolicy`) })
+ return
+ }
}
setSaveLoading(true)
@@ -410,7 +421,7 @@ const SettingsModal: FC = ({
}}
+ components={{ privacyPolicyLink: }}
/>
= ({
)
}
-export default memo(Answer)
+export default memo(Answer, (prevProps, nextProps) =>
+ prevProps.responding === false && nextProps.responding === false,
+)
diff --git a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx
new file mode 100644
index 0000000000..56df8db6a4
--- /dev/null
+++ b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx
@@ -0,0 +1,11 @@
+const IndeterminateIcon = () => {
+ return (
+
+ )
+}
+
+export default IndeterminateIcon
diff --git a/web/app/components/base/checkbox/assets/mixed.svg b/web/app/components/base/checkbox/assets/mixed.svg
deleted file mode 100644
index e16b8fc975..0000000000
--- a/web/app/components/base/checkbox/assets/mixed.svg
+++ /dev/null
@@ -1,5 +0,0 @@
-
diff --git a/web/app/components/base/checkbox/index.module.css b/web/app/components/base/checkbox/index.module.css
deleted file mode 100644
index d675607b46..0000000000
--- a/web/app/components/base/checkbox/index.module.css
+++ /dev/null
@@ -1,10 +0,0 @@
-.mixed {
- background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat;
- background-size: 12px 12px;
- border: none;
-}
-
-.checked.disabled {
- background-color: #d0d5dd;
- border-color: #d0d5dd;
-}
\ No newline at end of file
diff --git a/web/app/components/base/checkbox/index.spec.tsx b/web/app/components/base/checkbox/index.spec.tsx
new file mode 100644
index 0000000000..7ef901aef5
--- /dev/null
+++ b/web/app/components/base/checkbox/index.spec.tsx
@@ -0,0 +1,67 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import Checkbox from './index'
+
+describe('Checkbox Component', () => {
+ const mockProps = {
+ id: 'test',
+ }
+
+ it('renders unchecked checkbox by default', () => {
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+ expect(checkbox).toBeInTheDocument()
+ expect(checkbox).not.toHaveClass('bg-components-checkbox-bg')
+ })
+
+ it('renders checked checkbox when checked prop is true', () => {
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+ expect(checkbox).toHaveClass('bg-components-checkbox-bg')
+ expect(screen.getByTestId('check-icon-test')).toBeInTheDocument()
+ })
+
+ it('renders indeterminate state correctly', () => {
+ render()
+ expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument()
+ })
+
+ it('handles click events when not disabled', () => {
+ const onCheck = jest.fn()
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+
+ fireEvent.click(checkbox)
+ expect(onCheck).toHaveBeenCalledTimes(1)
+ })
+
+ it('does not handle click events when disabled', () => {
+ const onCheck = jest.fn()
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+
+ fireEvent.click(checkbox)
+ expect(onCheck).not.toHaveBeenCalled()
+ expect(checkbox).toHaveClass('cursor-not-allowed')
+ })
+
+ it('applies custom className when provided', () => {
+ const customClass = 'custom-class'
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+ expect(checkbox).toHaveClass(customClass)
+ })
+
+ it('applies correct styles for disabled checked state', () => {
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+ expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked')
+ expect(checkbox).toHaveClass('cursor-not-allowed')
+ })
+
+ it('applies correct styles for disabled unchecked state', () => {
+ render()
+ const checkbox = screen.getByTestId('checkbox-test')
+ expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled')
+ expect(checkbox).toHaveClass('cursor-not-allowed')
+ })
+})
diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx
index b0b0ebca7c..3e47967c62 100644
--- a/web/app/components/base/checkbox/index.tsx
+++ b/web/app/components/base/checkbox/index.tsx
@@ -1,48 +1,49 @@
import { RiCheckLine } from '@remixicon/react'
-import s from './index.module.css'
import cn from '@/utils/classnames'
+import IndeterminateIcon from './assets/indeterminate-icon'
type CheckboxProps = {
+ id?: string
checked?: boolean
onCheck?: () => void
className?: string
disabled?: boolean
- mixed?: boolean
+ indeterminate?: boolean
}
-const Checkbox = ({ checked, onCheck, className, disabled, mixed }: CheckboxProps) => {
- if (!checked) {
- return (
- {
- if (disabled)
- return
- onCheck?.()
- }}
- >
- )
- }
+const Checkbox = ({
+ id,
+ checked,
+ onCheck,
+ className,
+ disabled,
+ indeterminate,
+}: CheckboxProps) => {
+ const checkClassName = (checked || indeterminate)
+ ? 'bg-components-checkbox-bg text-components-checkbox-icon hover:bg-components-checkbox-bg-hover'
+ : 'border border-components-checkbox-border bg-components-checkbox-bg-unchecked hover:bg-components-checkbox-bg-unchecked-hover hover:border-components-checkbox-border-hover'
+ const disabledClassName = (checked || indeterminate)
+ ? 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked'
+ : 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled'
+
return (
{
if (disabled)
return
-
onCheck?.()
}}
+ data-testid={`checkbox-${id}`}
>
-
+ {!checked && indeterminate && }
+ {checked && }
)
}
diff --git a/web/app/components/base/form/components/field/checkbox.tsx b/web/app/components/base/form/components/field/checkbox.tsx
new file mode 100644
index 0000000000..855dbd80fe
--- /dev/null
+++ b/web/app/components/base/form/components/field/checkbox.tsx
@@ -0,0 +1,43 @@
+import cn from '@/utils/classnames'
+import { useFieldContext } from '../..'
+import Checkbox from '../../../checkbox'
+
+type CheckboxFieldProps = {
+ label: string;
+ labelClassName?: string;
+}
+
+const CheckboxField = ({
+ label,
+ labelClassName,
+}: CheckboxFieldProps) => {
+ const field = useFieldContext()
+
+ return (
+
+
+ {
+ field.handleChange(!field.state.value)
+ }}
+ />
+
+
+
+ )
+}
+
+export default CheckboxField
diff --git a/web/app/components/base/form/components/field/number-input.tsx b/web/app/components/base/form/components/field/number-input.tsx
new file mode 100644
index 0000000000..fce3143fe1
--- /dev/null
+++ b/web/app/components/base/form/components/field/number-input.tsx
@@ -0,0 +1,49 @@
+import React from 'react'
+import { useFieldContext } from '../..'
+import Label from '../label'
+import cn from '@/utils/classnames'
+import type { InputNumberProps } from '../../../input-number'
+import { InputNumber } from '../../../input-number'
+
+type TextFieldProps = {
+ label: string
+ isRequired?: boolean
+ showOptional?: boolean
+ tooltip?: string
+ className?: string
+ labelClassName?: string
+} & Omit
+
+const NumberInputField = ({
+ label,
+ isRequired,
+ showOptional,
+ tooltip,
+ className,
+ labelClassName,
+ ...inputProps
+}: TextFieldProps) => {
+ const field = useFieldContext()
+
+ return (
+
+
+ field.handleChange(value)}
+ onBlur={field.handleBlur}
+ {...inputProps}
+ />
+
+ )
+}
+
+export default NumberInputField
diff --git a/web/app/components/base/form/components/field/options.tsx b/web/app/components/base/form/components/field/options.tsx
new file mode 100644
index 0000000000..9ff71e50af
--- /dev/null
+++ b/web/app/components/base/form/components/field/options.tsx
@@ -0,0 +1,34 @@
+import cn from '@/utils/classnames'
+import { useFieldContext } from '../..'
+import Label from '../label'
+import ConfigSelect from '@/app/components/app/configuration/config-var/config-select'
+
+type OptionsFieldProps = {
+ label: string;
+ className?: string;
+ labelClassName?: string;
+}
+
+const OptionsField = ({
+ label,
+ className,
+ labelClassName,
+}: OptionsFieldProps) => {
+ const field = useFieldContext()
+
+ return (
+
+
+ field.handleChange(value)}
+ />
+
+ )
+}
+
+export default OptionsField
diff --git a/web/app/components/base/form/components/field/select.tsx b/web/app/components/base/form/components/field/select.tsx
new file mode 100644
index 0000000000..95af3c0116
--- /dev/null
+++ b/web/app/components/base/form/components/field/select.tsx
@@ -0,0 +1,51 @@
+import cn from '@/utils/classnames'
+import { useFieldContext } from '../..'
+import PureSelect from '../../../select/pure'
+import Label from '../label'
+
+type SelectOption = {
+ value: string
+ label: string
+}
+
+type SelectFieldProps = {
+ label: string
+ options: SelectOption[]
+ isRequired?: boolean
+ showOptional?: boolean
+ tooltip?: string
+ className?: string
+ labelClassName?: string
+}
+
+const SelectField = ({
+ label,
+ options,
+ isRequired,
+ showOptional,
+ tooltip,
+ className,
+ labelClassName,
+}: SelectFieldProps) => {
+ const field = useFieldContext()
+
+ return (
+
+
+
field.handleChange(value)}
+ />
+
+ )
+}
+
+export default SelectField
diff --git a/web/app/components/base/form/components/field/text.tsx b/web/app/components/base/form/components/field/text.tsx
new file mode 100644
index 0000000000..b2090291a0
--- /dev/null
+++ b/web/app/components/base/form/components/field/text.tsx
@@ -0,0 +1,48 @@
+import React from 'react'
+import { useFieldContext } from '../..'
+import Input, { type InputProps } from '../../../input'
+import Label from '../label'
+import cn from '@/utils/classnames'
+
+type TextFieldProps = {
+ label: string
+ isRequired?: boolean
+ showOptional?: boolean
+ tooltip?: string
+ className?: string
+ labelClassName?: string
+} & Omit
+
+const TextField = ({
+ label,
+ isRequired,
+ showOptional,
+ tooltip,
+ className,
+ labelClassName,
+ ...inputProps
+}: TextFieldProps) => {
+ const field = useFieldContext()
+
+ return (
+
+
+ field.handleChange(e.target.value)}
+ onBlur={field.handleBlur}
+ {...inputProps}
+ />
+
+ )
+}
+
+export default TextField
diff --git a/web/app/components/base/form/components/form/submit-button.tsx b/web/app/components/base/form/components/form/submit-button.tsx
new file mode 100644
index 0000000000..494d19b843
--- /dev/null
+++ b/web/app/components/base/form/components/form/submit-button.tsx
@@ -0,0 +1,25 @@
+import { useStore } from '@tanstack/react-form'
+import { useFormContext } from '../..'
+import Button, { type ButtonProps } from '../../../button'
+
+type SubmitButtonProps = Omit
+
+const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => {
+ const form = useFormContext()
+
+ const [isSubmitting, canSubmit] = useStore(form.store, state => [
+ state.isSubmitting,
+ state.canSubmit,
+ ])
+
+ return (
+
+ )
+}
+
+export default Label
diff --git a/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx
new file mode 100644
index 0000000000..9ba664fc10
--- /dev/null
+++ b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx
@@ -0,0 +1,35 @@
+import { withForm } from '../..'
+import { demoFormOpts } from './shared-options'
+import { ContactMethods } from './types'
+
+const ContactFields = withForm({
+ ...demoFormOpts,
+ render: ({ form }) => {
+ return (
+
+
Contacts
+
+
}
+ />
+ }
+ />
+ (
+
+ )}
+ />
+
+
+ )
+ },
+})
+
+export default ContactFields
diff --git a/web/app/components/base/form/form-scenarios/demo/index.tsx b/web/app/components/base/form/form-scenarios/demo/index.tsx
new file mode 100644
index 0000000000..f08edee41e
--- /dev/null
+++ b/web/app/components/base/form/form-scenarios/demo/index.tsx
@@ -0,0 +1,68 @@
+import { useStore } from '@tanstack/react-form'
+import { useAppForm } from '../..'
+import ContactFields from './contact-fields'
+import { demoFormOpts } from './shared-options'
+import { UserSchema } from './types'
+
+const DemoForm = () => {
+ const form = useAppForm({
+ ...demoFormOpts,
+ validators: {
+ onSubmit: ({ value }) => {
+ // Validate the entire form
+ const result = UserSchema.safeParse(value)
+ if (!result.success) {
+ const issues = result.error.issues
+ console.log('Validation errors:', issues)
+ return issues[0].message
+ }
+ return undefined
+ },
+ },
+ onSubmit: ({ value }) => {
+ console.log('Form submitted:', value)
+ },
+ })
+
+const name = useStore(form.store, state => state.values.name)
+
+ return (
+ (
+
+ )}
+ />
+ (
+
+ )}
+ />
+ (
+
+ )}
+ />
+ {
+ !!name && (
+
+ )
+ }
+
+ Submit
+
+
+ )
+}
+
+export default DemoForm
diff --git a/web/app/components/base/form/form-scenarios/demo/shared-options.tsx b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx
new file mode 100644
index 0000000000..8b216c8b90
--- /dev/null
+++ b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx
@@ -0,0 +1,14 @@
+import { formOptions } from '@tanstack/react-form'
+
+export const demoFormOpts = formOptions({
+ defaultValues: {
+ name: '',
+ surname: '',
+ isAcceptingTerms: false,
+ contact: {
+ email: '',
+ phone: '',
+ preferredContactMethod: 'email',
+ },
+ },
+})
diff --git a/web/app/components/base/form/form-scenarios/demo/types.ts b/web/app/components/base/form/form-scenarios/demo/types.ts
new file mode 100644
index 0000000000..c4e626ef63
--- /dev/null
+++ b/web/app/components/base/form/form-scenarios/demo/types.ts
@@ -0,0 +1,34 @@
+import { z } from 'zod'
+
+const ContactMethod = z.union([
+ z.literal('email'),
+ z.literal('phone'),
+ z.literal('whatsapp'),
+ z.literal('sms'),
+])
+
+export const ContactMethods = ContactMethod.options.map(({ value }) => ({
+ value,
+ label: value.charAt(0).toUpperCase() + value.slice(1),
+}))
+
+export const UserSchema = z.object({
+ name: z
+ .string()
+ .regex(/^[A-Z]/, 'Name must start with a capital letter')
+ .min(3, 'Name must be at least 3 characters long'),
+ surname: z
+ .string()
+ .min(3, 'Surname must be at least 3 characters long')
+ .regex(/^[A-Z]/, 'Surname must start with a capital letter'),
+ isAcceptingTerms: z.boolean().refine(val => val, {
+ message: 'You must accept the terms and conditions',
+ }),
+ contact: z.object({
+ email: z.string().email('Invalid email address'),
+ phone: z.string().optional(),
+ preferredContactMethod: ContactMethod,
+ }),
+})
+
+export type User = z.infer
diff --git a/web/app/components/base/form/index.tsx b/web/app/components/base/form/index.tsx
new file mode 100644
index 0000000000..aeb482ad02
--- /dev/null
+++ b/web/app/components/base/form/index.tsx
@@ -0,0 +1,25 @@
+import { createFormHook, createFormHookContexts } from '@tanstack/react-form'
+import TextField from './components/field/text'
+import NumberInputField from './components/field/number-input'
+import CheckboxField from './components/field/checkbox'
+import SelectField from './components/field/select'
+import OptionsField from './components/field/options'
+import SubmitButton from './components/form/submit-button'
+
+export const { fieldContext, useFieldContext, formContext, useFormContext }
+ = createFormHookContexts()
+
+export const { useAppForm, withForm } = createFormHook({
+ fieldComponents: {
+ TextField,
+ NumberInputField,
+ CheckboxField,
+ SelectField,
+ OptionsField,
+ },
+ formComponents: {
+ SubmitButton,
+ },
+ fieldContext,
+ formContext,
+})
diff --git a/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg
new file mode 100644
index 0000000000..9566fcc0c3
--- /dev/null
+++ b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg
@@ -0,0 +1,5 @@
+
diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json
new file mode 100644
index 0000000000..4e7da3c801
--- /dev/null
+++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json
@@ -0,0 +1,36 @@
+{
+ "icon": {
+ "type": "element",
+ "isRootNode": true,
+ "name": "svg",
+ "attributes": {
+ "width": "16",
+ "height": "16",
+ "viewBox": "0 0 16 16",
+ "fill": "none",
+ "xmlns": "http://www.w3.org/2000/svg"
+ },
+ "children": [
+ {
+ "type": "element",
+ "name": "g",
+ "attributes": {
+ "id": "arrow-down-round-fill"
+ },
+ "children": [
+ {
+ "type": "element",
+ "name": "path",
+ "attributes": {
+ "id": "Vector",
+ "d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z",
+ "fill": "currentColor"
+ },
+ "children": []
+ }
+ ]
+ }
+ ]
+ },
+ "name": "ArrowDownRoundFill"
+}
\ No newline at end of file
diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx
new file mode 100644
index 0000000000..c766a72b94
--- /dev/null
+++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx
@@ -0,0 +1,20 @@
+// GENERATE BY script
+// DON NOT EDIT IT MANUALLY
+
+import * as React from 'react'
+import data from './ArrowDownRoundFill.json'
+import IconBase from '@/app/components/base/icons/IconBase'
+import type { IconData } from '@/app/components/base/icons/IconBase'
+
+const Icon = (
+ {
+ ref,
+ ...props
+ }: React.SVGProps & {
+ ref?: React.RefObject>;
+ },
+) =>
+
+Icon.displayName = 'ArrowDownRoundFill'
+
+export default Icon
diff --git a/web/app/components/base/icons/src/vender/solid/general/index.ts b/web/app/components/base/icons/src/vender/solid/general/index.ts
index 52647905ab..4c4dd9a437 100644
--- a/web/app/components/base/icons/src/vender/solid/general/index.ts
+++ b/web/app/components/base/icons/src/vender/solid/general/index.ts
@@ -1,4 +1,5 @@
export { default as AnswerTriangle } from './AnswerTriangle'
+export { default as ArrowDownRoundFill } from './ArrowDownRoundFill'
export { default as CheckCircle } from './CheckCircle'
export { default as CheckDone01 } from './CheckDone01'
export { default as Download02 } from './Download02'
diff --git a/web/app/components/base/input-number/index.spec.tsx b/web/app/components/base/input-number/index.spec.tsx
new file mode 100644
index 0000000000..8dfd1184b0
--- /dev/null
+++ b/web/app/components/base/input-number/index.spec.tsx
@@ -0,0 +1,97 @@
+import { fireEvent, render, screen } from '@testing-library/react'
+import { InputNumber } from './index'
+
+jest.mock('react-i18next', () => ({
+ useTranslation: () => ({
+ t: (key: string) => key,
+ }),
+}))
+
+describe('InputNumber Component', () => {
+ const defaultProps = {
+ onChange: jest.fn(),
+ }
+
+ afterEach(() => {
+ jest.clearAllMocks()
+ })
+
+ it('renders input with default values', () => {
+ render()
+ const input = screen.getByRole('textbox')
+ expect(input).toBeInTheDocument()
+ })
+
+ it('handles increment button click', () => {
+ render()
+ const incrementBtn = screen.getByRole('button', { name: /increment/i })
+
+ fireEvent.click(incrementBtn)
+ expect(defaultProps.onChange).toHaveBeenCalledWith(6)
+ })
+
+ it('handles decrement button click', () => {
+ render()
+ const decrementBtn = screen.getByRole('button', { name: /decrement/i })
+
+ fireEvent.click(decrementBtn)
+ expect(defaultProps.onChange).toHaveBeenCalledWith(4)
+ })
+
+ it('respects max value constraint', () => {
+ render()
+ const incrementBtn = screen.getByRole('button', { name: /increment/i })
+
+ fireEvent.click(incrementBtn)
+ expect(defaultProps.onChange).not.toHaveBeenCalled()
+ })
+
+ it('respects min value constraint', () => {
+ render()
+ const decrementBtn = screen.getByRole('button', { name: /decrement/i })
+
+ fireEvent.click(decrementBtn)
+ expect(defaultProps.onChange).not.toHaveBeenCalled()
+ })
+
+ it('handles direct input changes', () => {
+ render()
+ const input = screen.getByRole('textbox')
+
+ fireEvent.change(input, { target: { value: '42' } })
+ expect(defaultProps.onChange).toHaveBeenCalledWith(42)
+ })
+
+ it('handles empty input', () => {
+ render()
+ const input = screen.getByRole('textbox')
+
+ fireEvent.change(input, { target: { value: '' } })
+ expect(defaultProps.onChange).toHaveBeenCalledWith(undefined)
+ })
+
+ it('handles invalid input', () => {
+ render()
+ const input = screen.getByRole('textbox')
+
+ fireEvent.change(input, { target: { value: 'abc' } })
+ expect(defaultProps.onChange).not.toHaveBeenCalled()
+ })
+
+ it('displays unit when provided', () => {
+ const unit = 'px'
+ render()
+ expect(screen.getByText(unit)).toBeInTheDocument()
+ })
+
+ it('disables controls when disabled prop is true', () => {
+ render()
+ const input = screen.getByRole('textbox')
+ const incrementBtn = screen.getByRole('button', { name: /increment/i })
+ const decrementBtn = screen.getByRole('button', { name: /decrement/i })
+
+ expect(input).toBeDisabled()
+ expect(incrementBtn).toBeDisabled()
+ expect(decrementBtn).toBeDisabled()
+ })
+})
diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx
index 5b88fc67f8..98efc94462 100644
--- a/web/app/components/base/input-number/index.tsx
+++ b/web/app/components/base/input-number/index.tsx
@@ -8,7 +8,7 @@ export type InputNumberProps = {
value?: number
onChange: (value?: number) => void
amount?: number
- size?: 'sm' | 'md'
+ size?: 'regular' | 'large'
max?: number
min?: number
defaultValue?: number
@@ -19,14 +19,12 @@ export type InputNumberProps = {
} & Omit
export const InputNumber: FC = (props) => {
- const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props
+ const { unit, className, onChange, amount = 1, value, size = 'regular', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props
const isValidValue = (v: number) => {
- if (max && v > max)
+ if (typeof max === 'number' && v > max)
return false
- if (min && v < min)
- return false
- return true
+ return !(typeof min === 'number' && v < min)
}
const inc = () => {
@@ -76,29 +74,39 @@ export const InputNumber: FC = (props) => {
onChange(parsed)
}}
unit={unit}
+ size={size}
/>
-
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx
index 5f059c3b7f..30fd90aff8 100644
--- a/web/app/components/base/input/index.tsx
+++ b/web/app/components/base/input/index.tsx
@@ -30,7 +30,7 @@ export type InputProps = {
wrapperClassName?: string
styleCss?: CSSProperties
unit?: string
-} & React.InputHTMLAttributes & VariantProps
+} & Omit, 'size'> & VariantProps
const Input = ({
size,
diff --git a/web/app/components/base/markdown-blocks/music.tsx b/web/app/components/base/markdown-blocks/music.tsx
new file mode 100644
index 0000000000..7edd1713c9
--- /dev/null
+++ b/web/app/components/base/markdown-blocks/music.tsx
@@ -0,0 +1,37 @@
+import abcjs from 'abcjs'
+import { useEffect, useRef } from 'react'
+import 'abcjs/abcjs-audio.css'
+
+const MarkdownMusic = ({ children }: { children: React.ReactNode }) => {
+ const containerRef = useRef(null)
+ const controlsRef = useRef(null)
+
+ useEffect(() => {
+ if (containerRef.current && controlsRef.current) {
+ if (typeof children === 'string') {
+ const visualObjs = abcjs.renderAbc(containerRef.current, children, {
+ add_classes: true, // Add classes to SVG elements for cursor tracking
+ responsive: 'resize', // Make notation responsive
+ })
+ const synthControl = new abcjs.synth.SynthController()
+ synthControl.load(controlsRef.current, {}, { displayPlay: true })
+ const synth = new abcjs.synth.CreateSynth()
+ const visualObj = visualObjs[0]
+ synth.init({ visualObj }).then(() => {
+ synthControl.setTune(visualObj, false)
+ })
+ containerRef.current.style.overflow = 'auto'
+ }
+ }
+ }, [children])
+
+ return (
+
+ )
+}
+MarkdownMusic.displayName = 'MarkdownMusic'
+
+export default MarkdownMusic
diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx
index 24ae59af73..52b880affa 100644
--- a/web/app/components/base/markdown.tsx
+++ b/web/app/components/base/markdown.tsx
@@ -23,6 +23,7 @@ import VideoGallery from '@/app/components/base/video-gallery'
import AudioGallery from '@/app/components/base/audio-gallery'
import MarkdownButton from '@/app/components/base/markdown-blocks/button'
import MarkdownForm from '@/app/components/base/markdown-blocks/form'
+import MarkdownMusic from '@/app/components/base/markdown-blocks/music'
import ThinkBlock from '@/app/components/base/markdown-blocks/think-block'
import { Theme } from '@/types/app'
import useTheme from '@/hooks/use-theme'
@@ -51,6 +52,7 @@ const capitalizationLanguageNameMap: Record = {
json: 'JSON',
latex: 'Latex',
svg: 'SVG',
+ abc: 'ABC',
}
const getCorrectCapitalizationLanguageName = (language: string) => {
if (!language)
@@ -85,9 +87,11 @@ const preprocessLaTeX = (content: string) => {
}
const preprocessThinkTag = (content: string) => {
+ const thinkOpenTagRegex = /\n/g
+ const thinkCloseTagRegex = /\n<\/think>/g
return flow([
- (str: string) => str.replace('\n', '\n'),
- (str: string) => str.replace('\n ', '\n[ENDTHINKFLAG]'),
+ (str: string) => str.replace(thinkOpenTagRegex, '\n'),
+ (str: string) => str.replace(thinkCloseTagRegex, '\n[ENDTHINKFLAG] '),
])(content)
}
@@ -135,45 +139,54 @@ const CodeBlock: any = memo(({ inline, className, children, ...props }: any) =>
const renderCodeContent = useMemo(() => {
const content = String(children).replace(/\n$/, '')
- if (language === 'mermaid' && isSVG) {
- return
- }
- else if (language === 'echarts') {
- return (
-
+ switch (language) {
+ case 'mermaid':
+ if (isSVG)
+ return
+ break
+ case 'echarts':
+ return (
+
+
+
+
+
+ )
+ case 'svg':
+ if (isSVG) {
+ return (
+
+
+
+ )
+ }
+ break
+ case 'abc':
+ return (
-
+
-
- )
- }
- else if (language === 'svg' && isSVG) {
- return (
-
-
-
- )
- }
- else {
- return (
-
- {content}
-
- )
+ )
+ default:
+ return (
+
+ {content}
+
+ )
}
- }, [language, match, props, children, chartData, isSVG])
+ }, [children, language, isSVG, chartData, props, theme, match])
if (inline || !match)
return {children}
diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx
index 4cae402e3b..03eb5a7c42 100644
--- a/web/app/components/base/param-item/index.tsx
+++ b/web/app/components/base/param-item/index.tsx
@@ -54,7 +54,7 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1,
max={max}
step={step}
amount={step}
- size='sm'
+ size='regular'
value={value}
onChange={(value) => {
onChange(id, value)
diff --git a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx
index 1a2600d568..1cb33fcc49 100644
--- a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx
+++ b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx
@@ -14,7 +14,7 @@ export class HistoryBlockNode extends DecoratorNode {
}
static clone(node: HistoryBlockNode): HistoryBlockNode {
- return new HistoryBlockNode(node.__roleName, node.__onEditRole)
+ return new HistoryBlockNode(node.__roleName, node.__onEditRole, node.__key)
}
constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) {
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx
index 2cf4c95b87..2f6c3374a7 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx
@@ -11,6 +11,7 @@ import { mergeRegister } from '@lexical/utils'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import {
RiErrorWarningFill,
+ RiMoreLine,
} from '@remixicon/react'
import { useSelectOrDelete } from '../../hooks'
import type { WorkflowNodesMap } from './node'
@@ -27,26 +28,35 @@ import { Line3 } from '@/app/components/base/icons/src/public/common'
import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
import Tooltip from '@/app/components/base/tooltip'
import { isExceptionVariable } from '@/app/components/workflow/utils'
+import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel'
+import { Type } from '@/app/components/workflow/nodes/llm/types'
+import type { ValueSelector } from '@/app/components/workflow/types'
type WorkflowVariableBlockComponentProps = {
nodeKey: string
variables: string[]
workflowNodesMap: WorkflowNodesMap
+ getVarType?: (payload: {
+ nodeId: string,
+ valueSelector: ValueSelector,
+ }) => Type
}
const WorkflowVariableBlockComponent = ({
nodeKey,
variables,
workflowNodesMap = {},
+ getVarType,
}: WorkflowVariableBlockComponentProps) => {
const { t } = useTranslation()
const [editor] = useLexicalComposerContext()
const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND)
const variablesLength = variables.length
+ const isShowAPart = variablesLength > 2
const varName = (
() => {
const isSystem = isSystemVar(variables)
- const varName = variablesLength >= 3 ? (variables).slice(-2).join('.') : variables[variablesLength - 1]
+ const varName = variables[variablesLength - 1]
return `${isSystem ? 'sys.' : ''}${varName}`
}
)()
@@ -76,7 +86,7 @@ const WorkflowVariableBlockComponent = ({
const Item = (
)}
+ {isShowAPart && (
+
+
+
+
+ )}
+
{!isEnv && !isChatVar &&
}
{isEnv &&
}
@@ -126,7 +143,27 @@ const WorkflowVariableBlockComponent = ({
)
}
- return Item
+ if (!node)
+ return null
+
+ return (
+
}
+ disabled={!isShowAPart}
+ >
+
{Item}
+
+ )
}
export default memo(WorkflowVariableBlockComponent)
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
index 05d4505e20..479dce9615 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx
@@ -9,7 +9,7 @@ import {
} from 'lexical'
import { mergeRegister } from '@lexical/utils'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
-import type { WorkflowVariableBlockType } from '../../types'
+import type { GetVarType, WorkflowVariableBlockType } from '../../types'
import {
$createWorkflowVariableBlockNode,
WorkflowVariableBlockNode,
@@ -25,11 +25,13 @@ export type WorkflowVariableBlockProps = {
getWorkflowNode: (nodeId: string) => Node
onInsert?: () => void
onDelete?: () => void
+ getVarType: GetVarType
}
const WorkflowVariableBlock = memo(({
workflowNodesMap,
onInsert,
onDelete,
+ getVarType,
}: WorkflowVariableBlockType) => {
const [editor] = useLexicalComposerContext()
@@ -48,7 +50,7 @@ const WorkflowVariableBlock = memo(({
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
(variables: string[]) => {
editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
- const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap)
+ const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType)
$insertNodes([workflowVariableBlockNode])
if (onInsert)
@@ -69,7 +71,7 @@ const WorkflowVariableBlock = memo(({
COMMAND_PRIORITY_EDITOR,
),
)
- }, [editor, onInsert, onDelete, workflowNodesMap])
+ }, [editor, onInsert, onDelete, workflowNodesMap, getVarType])
return null
})
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx
index 0564e6f16d..dce636d92d 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx
@@ -2,34 +2,39 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical'
import { DecoratorNode } from 'lexical'
import type { WorkflowVariableBlockType } from '../../types'
import WorkflowVariableBlockComponent from './component'
+import type { GetVarType } from '../../types'
export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap']
+
export type SerializedNode = SerializedLexicalNode & {
variables: string[]
workflowNodesMap: WorkflowNodesMap
+ getVarType?: GetVarType
}
export class WorkflowVariableBlockNode extends DecoratorNode
{
__variables: string[]
__workflowNodesMap: WorkflowNodesMap
+ __getVarType?: GetVarType
static getType(): string {
return 'workflow-variable-block'
}
static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode {
- return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__key)
+ return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__getVarType, node.__key)
}
isInline(): boolean {
return true
}
- constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, key?: NodeKey) {
+ constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType: any, key?: NodeKey) {
super(key)
this.__variables = variables
this.__workflowNodesMap = workflowNodesMap
+ this.__getVarType = getVarType
}
createDOM(): HTMLElement {
@@ -48,12 +53,13 @@ export class WorkflowVariableBlockNode extends DecoratorNode
nodeKey={this.getKey()}
variables={this.__variables}
workflowNodesMap={this.__workflowNodesMap}
+ getVarType={this.__getVarType!}
/>
)
}
static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode {
- const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap)
+ const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap, serializedNode.getVarType)
return node
}
@@ -64,6 +70,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode
version: 1,
variables: this.getVariables(),
workflowNodesMap: this.getWorkflowNodesMap(),
+ getVarType: this.getVarType(),
}
}
@@ -77,12 +84,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode
return self.__workflowNodesMap
}
+ getVarType(): any {
+ const self = this.getLatest()
+ return self.__getVarType
+ }
+
getTextContent(): string {
return `{{#${this.getVariables().join('.')}#}}`
}
}
-export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap): WorkflowVariableBlockNode {
- return new WorkflowVariableBlockNode(variables, workflowNodesMap)
+export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType): WorkflowVariableBlockNode {
+ return new WorkflowVariableBlockNode(variables, workflowNodesMap, getVarType)
}
export function $isWorkflowVariableBlockNode(
diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx
index 22ebc5d248..288008bbcc 100644
--- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx
+++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx
@@ -16,6 +16,7 @@ import { VAR_REGEX as REGEX, resetReg } from '@/config'
const WorkflowVariableBlockReplacementBlock = ({
workflowNodesMap,
+ getVarType,
onInsert,
}: WorkflowVariableBlockType) => {
const [editor] = useLexicalComposerContext()
@@ -30,8 +31,8 @@ const WorkflowVariableBlockReplacementBlock = ({
onInsert()
const nodePathString = textNode.getTextContent().slice(3, -3)
- return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap))
- }, [onInsert, workflowNodesMap])
+ return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType))
+ }, [onInsert, workflowNodesMap, getVarType])
const getMatch = useCallback((text: string) => {
const matchArr = REGEX.exec(text)
diff --git a/web/app/components/base/prompt-editor/types.ts b/web/app/components/base/prompt-editor/types.ts
index 6d0f307c17..0f09fb2473 100644
--- a/web/app/components/base/prompt-editor/types.ts
+++ b/web/app/components/base/prompt-editor/types.ts
@@ -1,8 +1,10 @@
+import type { Type } from '../../workflow/nodes/llm/types'
import type { Dataset } from './plugins/context-block'
import type { RoleName } from './plugins/history-block'
import type {
Node,
NodeOutPutVar,
+ ValueSelector,
} from '@/app/components/workflow/types'
export type Option = {
@@ -54,12 +56,18 @@ export type ExternalToolBlockType = {
onAddExternalTool?: () => void
}
+export type GetVarType = (payload: {
+ nodeId: string,
+ valueSelector: ValueSelector,
+}) => Type
+
export type WorkflowVariableBlockType = {
show?: boolean
variables?: NodeOutPutVar[]
workflowNodesMap?: Record>
onInsert?: () => void
onDelete?: () => void
+ getVarType?: GetVarType
}
export type MenuTextMatch = {
diff --git a/web/app/components/base/segmented-control/index.tsx b/web/app/components/base/segmented-control/index.tsx
new file mode 100644
index 0000000000..bd921e4243
--- /dev/null
+++ b/web/app/components/base/segmented-control/index.tsx
@@ -0,0 +1,68 @@
+import React from 'react'
+import classNames from '@/utils/classnames'
+import type { RemixiconComponentType } from '@remixicon/react'
+import Divider from '../divider'
+
+// Updated generic type to allow enum values
+type SegmentedControlProps = {
+ options: { Icon: RemixiconComponentType, text: string, value: T }[]
+ value: T
+ onChange: (value: T) => void
+ className?: string
+}
+
+export const SegmentedControl = ({
+ options,
+ value,
+ onChange,
+ className,
+}: SegmentedControlProps): JSX.Element => {
+ const selectedOptionIndex = options.findIndex(option => option.value === value)
+
+ return (
+
+ {options.map((option, index) => {
+ const { Icon } = option
+ const isSelected = index === selectedOptionIndex
+ const isNextSelected = index === selectedOptionIndex - 1
+ const isLast = index === options.length - 1
+ return (
+
onChange(option.value)}
+ >
+
+
+
+
+ {option.text}
+
+ {!isLast && !isSelected && !isNextSelected && (
+
+ )}
+
+ )
+ })}
+
+ )
+}
+
+export default React.memo(SegmentedControl) as typeof SegmentedControl
diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx
index 0f18bebedf..1e274515f8 100644
--- a/web/app/components/base/textarea/index.tsx
+++ b/web/app/components/base/textarea/index.tsx
@@ -8,8 +8,9 @@ const textareaVariants = cva(
{
variants: {
size: {
- regular: 'px-3 radius-md system-sm-regular',
- large: 'px-4 radius-lg system-md-regular',
+ small: 'py-1 rounded-md system-xs-regular',
+ regular: 'px-3 rounded-md system-sm-regular',
+ large: 'px-4 rounded-lg system-md-regular',
},
},
defaultVariants: {
diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx
index e9b7ab047a..e6c4de31f1 100644
--- a/web/app/components/base/tooltip/index.tsx
+++ b/web/app/components/base/tooltip/index.tsx
@@ -10,6 +10,7 @@ export type TooltipProps = {
position?: Placement
triggerMethod?: 'hover' | 'click'
triggerClassName?: string
+ triggerTestId?: string
disabled?: boolean
popupContent?: React.ReactNode
children?: React.ReactNode
@@ -24,6 +25,7 @@ const Tooltip: FC = ({
position = 'top',
triggerMethod = 'hover',
triggerClassName,
+ triggerTestId,
disabled = false,
popupContent,
children,
@@ -91,7 +93,7 @@ const Tooltip: FC = ({
onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)}
asChild={asChild}
>
- {children ||
}
+ {children ||
}
= (props) => {
}>
= (props) => {
}>
= ({
const resetList = useCallback(() => {
setSelectedSegmentIds([])
invalidSegmentList()
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [])
+ }, [invalidSegmentList])
const resetChildList = useCallback(() => {
invalidChildSegmentList()
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [])
+ }, [invalidChildSegmentList])
const onClickCard = (detail: SegmentDetailModel, isEditMode = false) => {
setCurrSegment({ segInfo: detail, showModal: true, isEditMode })
@@ -253,7 +251,7 @@ const Completed: FC = ({
const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey)
const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey)
- const refreshChunkListWithStatusChanged = () => {
+ const refreshChunkListWithStatusChanged = useCallback(() => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
@@ -262,7 +260,7 @@ const Completed: FC = ({
default:
invalidSegmentList()
}
- }
+ }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidSegmentList])
const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => {
const operationApi = enable ? enableSegment : disableSegment
@@ -280,8 +278,7 @@ const Completed: FC = ({
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
},
})
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [datasetId, documentId, selectedSegmentIds, segments])
+ }, [datasetId, documentId, selectedSegmentIds, segments, disableSegment, enableSegment, t, notify, refreshChunkListWithStatusChanged])
const { mutateAsync: deleteSegment } = useDeleteSegment()
@@ -296,12 +293,11 @@ const Completed: FC = ({
notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') })
},
})
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [datasetId, documentId, selectedSegmentIds])
+ }, [datasetId, documentId, selectedSegmentIds, deleteSegment, resetList, t, notify])
const { mutateAsync: updateSegment } = useUpdateSegment()
- const refreshChunkListDataWithDetailChanged = () => {
+ const refreshChunkListDataWithDetailChanged = useCallback(() => {
switch (selectedStatus) {
case 'all':
invalidChunkListDisabled()
@@ -316,7 +312,7 @@ const Completed: FC = ({
invalidChunkListEnabled()
break
}
- }
+ }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidChunkListAll])
const handleUpdateSegment = useCallback(async (
segmentId: string,
@@ -375,17 +371,18 @@ const Completed: FC = ({
eventEmitter?.emit('update-segment-done')
},
})
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [segments, datasetId, documentId])
+ }, [segments, datasetId, documentId, updateSegment, docForm, notify, eventEmitter, onCloseSegmentDetail, refreshChunkListDataWithDetailChanged, t])
useEffect(() => {
resetList()
+ // eslint-disable-next-line react-hooks/exhaustive-deps
}, [pathname])
useEffect(() => {
if (importStatus === ProcessStatus.COMPLETED)
resetList()
- }, [importStatus, resetList])
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [importStatus])
const onCancelBatchOperation = useCallback(() => {
setSelectedSegmentIds([])
@@ -430,8 +427,7 @@ const Completed: FC = ({
const count = segmentListData?.total || 0
return `${total} ${t('datasetDocuments.segment.searchResults', { count })}`
}
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [segmentListData?.total, mode, parentMode, searchValue, selectedStatus])
+ }, [segmentListData, mode, parentMode, searchValue, selectedStatus, t])
const toggleFullScreen = useCallback(() => {
setFullScreen(!fullScreen)
@@ -449,8 +445,7 @@ const Completed: FC = ({
resetList()
currentPage !== totalPages && setCurrentPage(totalPages)
}
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [segmentListData, limit, currentPage])
+ }, [segmentListData, limit, currentPage, resetList])
const { mutateAsync: deleteChildSegment } = useDeleteChildSegment()
@@ -470,8 +465,7 @@ const Completed: FC = ({
},
},
)
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [datasetId, documentId, parentMode])
+ }, [datasetId, documentId, parentMode, deleteChildSegment, resetList, resetChildList, t, notify])
const handleAddNewChildChunk = useCallback((parentChunkId: string) => {
setShowNewChildSegmentModal(true)
@@ -490,8 +484,7 @@ const Completed: FC = ({
else {
resetChildList()
}
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [parentMode, currChunkId, segments])
+ }, [parentMode, currChunkId, segments, refreshChunkListDataWithDetailChanged, resetChildList])
const viewNewlyAddedChildChunk = useCallback(() => {
const totalPages = childChunkListData?.total_pages || 0
@@ -505,8 +498,7 @@ const Completed: FC = ({
resetChildList()
currentPage !== totalPages && setCurrentPage(totalPages)
}
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [childChunkListData, limit, currentPage])
+ }, [childChunkListData, limit, currentPage, resetChildList])
const onClickSlice = useCallback((detail: ChildChunkDetail) => {
setCurrChildChunk({ childChunkInfo: detail, showModal: true })
@@ -560,8 +552,7 @@ const Completed: FC = ({
eventEmitter?.emit('update-child-segment-done')
},
})
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [segments, childSegments, datasetId, documentId, parentMode])
+ }, [segments, datasetId, documentId, parentMode, updateChildSegment, notify, eventEmitter, onCloseChildSegmentDetail, refreshChunkListDataWithDetailChanged, resetChildList, t])
const onClearFilter = useCallback(() => {
setInputValue('')
@@ -570,6 +561,12 @@ const Completed: FC = ({
setCurrentPage(1)
}, [])
+ const selectDefaultValue = useMemo(() => {
+ if (selectedStatus === 'all')
+ return 'all'
+ return selectedStatus ? 1 : 0
+ }, [selectedStatus])
+
return (
= ({
@@ -591,7 +588,7 @@ const Completed: FC = ({
= ({
const wordCountText = useMemo(() => {
const total = formatNumber(word_count)
return `${total} ${t('datasetDocuments.segment.characters', { count: word_count })}`
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [word_count])
+ }, [word_count, t])
const labelPrefix = useMemo(() => {
return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [isParentChildMode])
+ }, [isParentChildMode, t])
if (loading)
return
diff --git a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx
index cea3402499..d3575c18ed 100644
--- a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx
+++ b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx
@@ -86,8 +86,7 @@ const SegmentDetail: FC = ({
const titleText = useMemo(() => {
return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [isEditMode])
+ }, [isEditMode, t])
const isQAModel = useMemo(() => {
return docForm === ChunkingMode.qa
@@ -98,13 +97,11 @@ const SegmentDetail: FC = ({
const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number)
const count = isEditMode ? contentLength : segInfo!.word_count as number
return `${total} ${t('datasetDocuments.segment.characters', { count })}`
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [isEditMode, question.length, answer.length, segInfo?.word_count, isQAModel])
+ }, [isEditMode, question.length, answer.length, isQAModel, segInfo, t])
const labelPrefix = useMemo(() => {
return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
- // eslint-disable-next-line react-hooks/exhaustive-deps
- }, [isParentChildMode])
+ }, [isParentChildMode, t])
return (
diff --git a/web/app/components/datasets/documents/detail/completed/segment-list.tsx b/web/app/components/datasets/documents/detail/completed/segment-list.tsx
index b2351c1b97..f6076e5813 100644
--- a/web/app/components/datasets/documents/detail/completed/segment-list.tsx
+++ b/web/app/components/datasets/documents/detail/completed/segment-list.tsx
@@ -42,7 +42,7 @@ const SegmentList = (
embeddingAvailable,
onClearFilter,
}: ISegmentListProps & {
- ref: React.RefObject
;
+ ref: React.LegacyRef
},
) => {
const mode = useDocumentContext(s => s.mode)
diff --git a/web/app/components/datasets/documents/index.tsx b/web/app/components/datasets/documents/index.tsx
index 20e14a994b..854c984559 100644
--- a/web/app/components/datasets/documents/index.tsx
+++ b/web/app/components/datasets/documents/index.tsx
@@ -29,6 +29,8 @@ import { useChildSegmentListKey, useSegmentListKey } from '@/service/knowledge/u
import useEditDocumentMetadata from '../metadata/hooks/use-edit-dataset-metadata'
import DatasetMetadataDrawer from '../metadata/metadata-dataset/dataset-metadata-drawer'
import StatusWithAction from '../common/document-status-with-action/status-with-action'
+import { LanguagesSupported } from '@/i18n/language'
+import { getLocaleOnClient } from '@/i18n'
const FolderPlusIcon = ({ className }: React.SVGProps) => {
return