Merge branch 'langgenius:main' into deploy

pull/18407/head
lska367 1 year ago committed by GitHub
commit 029e72395b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -424,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800 MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0

@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
ext_otel, ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
ext_repositories,
ext_sentry, ext_sentry,
ext_set_secretkey, ext_set_secretkey,
ext_storage, ext_storage,
@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate, ext_migrate,
ext_redis, ext_redis,
ext_storage, ext_storage,
ext_repositories,
ext_celery, ext_celery,
ext_login, ext_login,
ext_mail, ext_mail,

@ -12,7 +12,7 @@ from pydantic import (
) )
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings): class SecurityConfig(BaseSettings):
@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100, default=100,
) )
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """

@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource):
return code_result return code_result
class RuleStructuredOutputGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
account = current_user
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=account.current_tenant_id,
instruction=args["instruction"],
model_config=args["model_config"],
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return structured_output
api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleGenerateApi, "/rule-generate")
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate")

@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
if "code" in request.args: if "code" in request.args:
code = request.args.get("code") code = request.args.get("code", "")
if not code:
return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:

@ -16,7 +16,7 @@ from controllers.console.auth.error import (
PasswordMismatchError, PasswordMismatchError,
) )
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
from controllers.console.wraps import setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
@ -30,6 +30,7 @@ from services.feature_service import FeatureService
class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource):
class ForgotPasswordCheckApi(Resource): class ForgotPasswordCheckApi(Resource):
@setup_required @setup_required
@email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource):
AccountService.add_forgot_password_error_rate_limit(args["email"]) AccountService.add_forgot_password_error_rate_limit(args["email"])
raise EmailCodeError() raise EmailCodeError()
# Verified, revoke the first token
AccountService.revoke_reset_password_token(args["token"])
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args["code"], additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args["email"]) AccountService.reset_forgot_password_error_rate_limit(args["email"])
return {"is_valid": True, "email": token_data.get("email")} return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
class ForgotPasswordResetApi(Resource): class ForgotPasswordResetApi(Resource):
@setup_required @setup_required
@email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource):
reset_data = AccountService.get_reset_password_data(args["token"]) reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
# Must use token in reset phase
if reset_data.get("phase", "") != "reset":
raise InvalidTokenError()
# Revoke token to prevent reuse # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"]) AccountService.revoke_reset_password_token(args["token"])

@ -22,7 +22,7 @@ from controllers.console.error import (
EmailSendIpLimitError, EmailSendIpLimitError,
NotAllowedCreateWorkspace, NotAllowedCreateWorkspace,
) )
from controllers.console.wraps import setup_required from controllers.console.wraps import email_password_login_enabled, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip from libs.helper import email, extract_remote_ip
from libs.password import valid_password from libs.password import valid_password
@ -38,6 +38,7 @@ class LoginApi(Resource):
"""Resource for user login.""" """Resource for user login."""
@setup_required @setup_required
@email_password_login_enabled
def post(self): def post(self):
"""Authenticate user and login.""" """Authenticate user and login."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -110,6 +111,7 @@ class LogoutApi(Resource):
class ResetPasswordSendEmailApi(Resource): class ResetPasswordSendEmailApi(Resource):
@setup_required @setup_required
@email_password_login_enabled
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")

@ -210,3 +210,16 @@ def enterprise_license_required(view):
return view(*args, **kwargs) return view(*args, **kwargs)
return decorated return decorated
def email_password_login_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if features.enable_email_password_login:
return view(*args, **kwargs)
# otherwise, return 403
abort(403)
return decorated

@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String, "status": fields.String,
"error": fields.String, "error": fields.String,
} }

@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( event=event
session=session, event=event )
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session, event=event,
event=event, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution,
workflow_node_execution=workflow_node_execution, )
)
session.commit()
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp

@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( event=event
session=session, event=event )
) node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( event=event,
session=session, task_id=self._application_generate_entity.task_id,
event=event, workflow_node_execution=workflow_node_execution,
task_id=self._application_generate_entity.task_id, )
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( event=event,
session=session, )
event=event, node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
) event=event,
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( task_id=self._application_generate_entity.task_id,
session=session, workflow_node_execution=workflow_node_execution,
event=event, )
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_failed_response: if node_failed_response:
yield node_failed_response yield node_failed_response
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id workflow_app_log.created_by = self._user_id
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
*, *,
@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution.node_execution_id).where( # Use the instance repository to find running executions for a workflow run
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
WorkflowNodeExecution.app_id == workflow_run.app_id, workflow_run_id=workflow_run.id
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated # Update the cache with the retrieved executions
running_workflow_node_executions = [ for execution in running_workflow_node_executions:
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id if execution.node_execution_id:
] self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None) now = datetime.now(UTC).replace(tzinfo=None)
@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start( def _handle_node_execution_start(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4()) workflow_node_execution.id = str(uuid4())
@ -315,17 +328,14 @@ class WorkflowCycleManage:
) )
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
self, *, session: Session, event: QueueNodeSucceededEvent workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution) # Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
self, self,
*, *,
session: Session,
event: QueueNodeFailedEvent event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
workflow_node_execution = self._get_workflow_node_execution( workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
@ -387,14 +395,14 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, self,
*, *,
@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response( def _workflow_node_start_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeStartedEvent, event: QueueNodeStartedEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeSucceededEvent event: QueueNodeSucceededEvent
| QueueNodeFailedEvent | QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response( def _workflow_node_retry_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeRetryEvent, event: QueueNodeRetryEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse: ) -> ParallelBranchStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchStartStreamResponse( return ParallelBranchStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchFinishedStreamResponse( return ParallelBranchFinishedStreamResponse(
task_id=task_id, task_id=task_id,
@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse: ) -> IterationNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeStartStreamResponse( return IterationNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response( def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse: ) -> IterationNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeNextStreamResponse( return IterationNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response( def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response( def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse: ) -> LoopNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeStartStreamResponse( return LoopNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response( def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse: ) -> LoopNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeNextStreamResponse( return LoopNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response( def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse: ) -> LoopNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeCompletedStreamResponse( return LoopNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -934,11 +920,22 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._workflow_node_executions: # First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}") raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution) # Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """

@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list): def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
) )

@ -10,6 +10,7 @@ from core.llm_generator.prompts import (
GENERATOR_QA_PROMPT, GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SYSTEM_STRUCTURED_OUTPUT_GENERATE,
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
) )
from core.model_manager import ModelManager from core.model_manager import ModelManager
@ -340,3 +341,37 @@ class LLMGenerator:
answer = cast(str, response.message.content) answer = cast(str, response.message.content)
return answer.strip() return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=instruction),
]
model_parameters = model_config.get("model_parameters", {})
try:
response = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
),
)
generated_json_schema = cast(str, response.message.content)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}")
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}}
You just need to generate the output You just need to generate the output
""" # noqa: E501 """ # noqa: E501
SYSTEM_STRUCTURED_OUTPUT_GENERATE = """
Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements.
## Instructions:
1. Analyze the user's description of their data needs
2. Identify each property that should be included in the schema
3. Determine the appropriate data type for each property
4. Decide which properties should be required
5. Generate a complete JSON Schema with proper syntax
6. Include appropriate constraints when specified (min/max values, patterns, formats)
7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting.
8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly.
## Examples:
### Example 1:
**User Input:** I need name and age
**JSON Schema Output:**
{
"type": "object",
"properties": {
"name": { "type": "string" },
"age": { "type": "number" }
},
"required": ["name", "age"]
}
### Example 2:
**User Input:** I want to store information about books including title, author, publication year and optional page count
**JSON Schema Output:**
{
"type": "object",
"properties": {
"title": { "type": "string" },
"author": { "type": "string" },
"publicationYear": { "type": "integer" },
"pageCount": { "type": "integer" }
},
"required": ["title", "author", "publicationYear"]
}
### Example 3:
**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18)
**JSON Schema Output:**
{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email"
},
"password": {
"type": "string",
"minLength": 8
},
"age": {
"type": "integer",
"minimum": 18
}
},
"required": ["email", "password", "age"]
}
### Example 4:
**User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist.
**JSON Schema Output:**
{
"type": "object",
"properties": {
"properties": {
"songs": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"id": {
"type": "string"
},
"duration": {
"type": "string"
},
"aritst": {
"type": "string"
}
},
"required": [
"name",
"id",
"duration",
"aritst"
]
}
}
}
},
"required": [
"songs"
]
}
Now, generate a JSON Schema based on my description
""" # noqa: E501

@ -1,8 +1,8 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Optional from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum): class PromptMessageRole(Enum):
@ -135,6 +135,16 @@ class PromptMessage(BaseModel):
""" """
return not self.content return not self.content
@field_serializer("content")
def serialize_content(
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
if content is None or isinstance(content, str):
return content
if isinstance(content, list):
return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
return content
class UserPromptMessage(PromptMessage): class UserPromptMessage(PromptMessage):
""" """

@ -2,7 +2,7 @@ from decimal import Decimal
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, model_validator
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
@ -85,6 +85,7 @@ class ModelFeature(Enum):
DOCUMENT = "document" DOCUMENT = "document"
VIDEO = "video" VIDEO = "video"
AUDIO = "audio" AUDIO = "audio"
STRUCTURED_OUTPUT = "structured-output"
class DefaultParameterName(StrEnum): class DefaultParameterName(StrEnum):
@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel):
parameter_rules: list[ParameterRule] = [] parameter_rules: list[ParameterRule] = []
pricing: Optional[PriceConfig] = None pricing: Optional[PriceConfig] = None
@model_validator(mode="after")
def validate_model(self):
supported_schema_keys = ["json_schema"]
schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None)
if not schema_key:
return self
if self.features is None:
self.features = [ModelFeature.STRUCTURED_OUTPUT]
else:
if ModelFeature.STRUCTURED_OUTPUT not in self.features:
self.features.append(ModelFeature.STRUCTURED_OUTPUT)
return self
class ModelUsage(BaseModel): class ModelUsage(BaseModel):
pass pass

@ -1,5 +1,6 @@
import logging import logging
import time import time
import uuid
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel): class LargeLanguageModel(AIModel):
""" """
Model class for large language model. Model class for large language model.
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = [] tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result: for chunk in result:
if isinstance(chunk.delta.message.content, str): if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list): elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content) content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls: if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls) _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage() usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from langfuse import Langfuse # type: ignore from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.config_entity import LangfuseConfig
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum, UnitEnum,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
) )
self.add_trace(langfuse_trace_data=trace_data) self.add_trace(langfuse_trace_data=trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
.all()
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client from langsmith import Client
from langsmith.schemas import RunBase from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.config_entity import LangSmithConfig
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run) self.add_run(langsmith_run)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7 from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig from core.ops.entities.config_entity import OpikConfig
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
} }
self.add_trace(trace_data) self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id, )
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str :param query: str
:return: dict :return: dict
""" """
# FIXME(-LAN-): Avoid import service into core
workflow_service = WorkflowService() workflow_service = WorkflowService()
node_id = "1919810" node_id = "1919810"
node_data = ParameterExtractorNodeData( node_data = ParameterExtractorNodeData(
@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param query: str :param query: str
:return: dict :return: dict
""" """
# FIXME(-LAN-): Avoid import service into core
workflow_service = WorkflowService() workflow_service = WorkflowService()
node_id = "1919810" node_id = "1919810"
node_data = QuestionClassifierNodeData( node_data = QuestionClassifierNodeData(

@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)
# Get All provider model settings # Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list]: ) -> dict[str, list[Provider]]:
""" """
Initialize trial provider records if not exists. Initialize trial provider records if not exists.
@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider( new_provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. # TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name, provider_name=ModelProviderID(provider_name).provider_name,
@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0, quota_used=0,
is_valid=True, is_valid=True,
) )
db.session.add(provider_record) db.session.add(new_provider_record)
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
@ -556,11 +566,14 @@ class ProviderManager:
) )
.first() .first()
) )
if provider_record and not provider_record.is_valid: if not existed_provider_record:
provider_record.is_valid = True continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record) provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict

@ -246,7 +246,7 @@ class AnalyticdbVectorBySql:
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name} FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY (score,id) DESC ORDER BY score DESC, id DESC
LIMIT {top_k}""", LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"), (f"'{query}'", f"'{query}'"),
) )

@ -2,12 +2,12 @@ import array
import json import json
import re import re
import uuid import uuid
from contextlib import contextmanager
from typing import Any from typing import Any
import jieba.posseg as pseg # type: ignore import jieba.posseg as pseg # type: ignore
import numpy import numpy
import oracledb import oracledb
from oracledb.connection import Connection
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from configs import dify_config from configs import dify_config
@ -70,6 +70,7 @@ class OracleVector(BaseVector):
super().__init__(collection_name) super().__init__(collection_name)
self.pool = self._create_connection_pool(config) self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}" self.table_name = f"embedding_{collection_name}"
self.config = config
def get_type(self) -> str: def get_type(self) -> str:
return VectorType.ORACLE return VectorType.ORACLE
@ -107,16 +108,19 @@ class OracleVector(BaseVector):
outconverter=self.numpy_converter_out, outconverter=self.numpy_converter_out,
) )
def _get_connection(self) -> Connection:
connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
return connection
def _create_connection_pool(self, config: OracleVectorConfig): def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = { pool_params = {
"user": config.user, "user": config.user,
"password": config.password, "password": config.password,
"dsn": config.dsn, "dsn": config.dsn,
"min": 1, "min": 1,
"max": 50, "max": 5,
"increment": 1, "increment": 1,
} }
if config.is_autonomous: if config.is_autonomous:
pool_params.update( pool_params.update(
{ {
@ -125,22 +129,8 @@ class OracleVector(BaseVector):
"wallet_password": config.wallet_password, "wallet_password": config.wallet_password,
} }
) )
return oracledb.create_pool(**pool_params) return oracledb.create_pool(**pool_params)
@contextmanager
def _get_cursor(self):
conn = self.pool.acquire()
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
conn.close()
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0]) dimension = len(embeddings[0])
self._create_collection(dimension) self._create_collection(dimension)
@ -162,41 +152,68 @@ class OracleVector(BaseVector):
numpy.array(embeddings[i]), numpy.array(embeddings[i]),
) )
) )
# print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") with self._get_connection() as conn:
with self._get_cursor() as cur: conn.inputtypehandler = self.input_type_handler
cur.executemany( conn.outputtypehandler = self.output_type_handler
f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values # with conn.cursor() as cur:
) # cur.executemany(
# f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
# )
# conn.commit()
for value in values:
with conn.cursor() as cur:
try:
cur.execute(
f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
VALUES (:1, :2, :3, :4)""",
value,
)
conn.commit()
except Exception as e:
print(e)
conn.close()
return pks return pks
def text_exists(self, id: str) -> bool: def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) with conn.cursor() as cur:
return cur.fetchone() is not None cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
return cur.fetchone() is not None
conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]: def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) with conn.cursor() as cur:
docs = [] cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
for record in cur: docs = []
docs.append(Document(page_content=record[1], metadata=record[0])) for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
self.pool.release(connection=conn)
conn.close()
return docs return docs
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids: if not ids:
return return
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
conn.commit()
conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None: def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) with conn.cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
conn.commit()
conn.close()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
""" """
Search the nearest neighbors to a vector. Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items. :param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector. :return: List of Documents that are nearest to the query vector.
""" """
top_k = kwargs.get("top_k", 4) top_k = kwargs.get("top_k", 4)
@ -205,20 +222,25 @@ class OracleVector(BaseVector):
if document_ids_filter: if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute( conn.inputtypehandler = self.input_type_handler
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" conn.outputtypehandler = self.output_type_handler
f" {where_clause} ORDER BY distance fetch first {top_k} rows only", with conn.cursor() as cur:
[numpy.array(query_vector)], cur.execute(
) f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
docs = [] AS distance FROM {self.table_name}
score_threshold = float(kwargs.get("score_threshold") or 0.0) {where_clause} ORDER BY distance fetch first {top_k} rows only""",
for record in cur: [numpy.array(query_vector)],
metadata, text, distance = record )
score = 1 - distance docs = []
metadata["score"] = score score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold: for record in cur:
docs.append(Document(page_content=text, metadata=metadata)) metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@ -228,7 +250,7 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5) top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later # just not implement fetch by score_threshold now, may be later
# score_threshold = float(kwargs.get("score_threshold") or 0.0) score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0: if len(query) > 0:
# Check which language the query is in # Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+") zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@ -239,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query) words = pseg.cut(query)
current_entity = "" current_entity = ""
for word, pos in words: for word, pos in words:
if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名ns: 地名,nt: 机构名 if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word current_entity += word
else: else:
if current_entity: if current_entity:
@ -260,30 +282,35 @@ class OracleVector(BaseVector):
for token in all_tokens: for token in all_tokens:
if token not in stop_words: if token not in stop_words:
entities.append(token) entities.append(token)
with self._get_cursor() as cur: with self._get_connection() as conn:
document_ids_filter = kwargs.get("document_ids_filter") with conn.cursor() as cur:
where_clause = "" document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter: where_clause = ""
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) if document_ids_filter:
where_clause = f" AND metadata->>'document_id' in ({document_ids}) " document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
cur.execute( where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
f"select meta, text, embedding FROM {self.table_name}" cur.execute(
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} " f"""select meta, text, embedding FROM {self.table_name}
f"order by score(1) desc fetch first {top_k} rows only", WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
[" ACCUM ".join(entities)], order by score(1) desc fetch first {top_k} rows only""",
) kk=" ACCUM ".join(entities),
docs = [] )
for record in cur: docs = []
metadata, text, embedding = record for record in cur:
docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) metadata, text, embedding = record
docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
conn.close()
return docs return docs
else: else:
return [Document(page_content="", metadata={})] return [Document(page_content="", metadata={})]
return [] return []
def delete(self) -> None: def delete(self) -> None:
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") with conn.cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
conn.commit()
conn.close()
def _create_collection(self, dimension: int): def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}" cache_key = f"vector_indexing_{self._collection_name}"
@ -293,11 +320,14 @@ class OracleVector(BaseVector):
if redis_client.get(collection_exist_cache_key): if redis_client.get(collection_exist_cache_key):
return return
with self._get_cursor() as cur: with self._get_connection() as conn:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name)) with conn.cursor() as cur:
redis_client.set(collection_exist_cache_key, 1, ex=3600) cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
with self._get_cursor() as cur: redis_client.set(collection_exist_cache_key, 1, ex=3600)
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) with conn.cursor() as cur:
cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
conn.commit()
conn.close()
class OracleVectorFactory(AbstractVectorFactory): class OracleVectorFactory(AbstractVectorFactory):

@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor):
db.session.add(upload_file) db.session.add(upload_file)
db.session.commit() db.session.commit()
image_map[rel.target_part] = ( image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)"
)
return image_map return image_map

@ -0,0 +1,15 @@
"""
Repository interfaces for data access.
This package contains repository interfaces that define the contract
for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.repository.repository_factory import RepositoryFactory
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
__all__ = [
"RepositoryFactory",
"WorkflowNodeExecutionRepository",
]

@ -0,0 +1,97 @@
"""
Repository factory for creating repository instances.
This module provides a simple factory interface for creating repository instances.
It does not contain any implementation details or dependencies on specific repositories.
"""
from collections.abc import Callable, Mapping
from typing import Any, Literal, Optional, cast
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
# Type for factory functions - takes a dict of parameters and returns any repository type
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
# Type for workflow node execution factory function
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
# Repository type literals
_RepositoryType = Literal["workflow_node_execution"]
class RepositoryFactory:
"""
Factory class for creating repository instances.
This factory delegates the actual repository creation to implementation-specific
factory functions that are registered with the factory at runtime.
"""
# Dictionary to store factory functions
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
@classmethod
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
"""
Register a factory function for a specific repository type.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository (e.g., 'workflow_node_execution')
factory_func: A function that takes parameters and returns a repository instance
"""
cls._factory_functions[repository_type] = factory_func
@classmethod
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
"""
Create a new repository instance with the provided parameters.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository to create
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the requested repository
Raises:
ValueError: If no factory function is registered for the repository type
"""
if repository_type not in cls._factory_functions:
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
# Use empty dict if params is None
params = params or {}
return cls._factory_functions[repository_type](params)
@classmethod
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
"""
Register a factory function for the workflow node execution repository.
Args:
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
"""
cls._register_factory("workflow_node_execution", factory_func)
@classmethod
def create_workflow_node_execution_repository(
cls, params: Optional[Mapping[str, Any]] = None
) -> WorkflowNodeExecutionRepository:
"""
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
Args:
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the WorkflowNodeExecutionRepository
Raises:
ValueError: If no factory function is registered for the workflow_node_execution repository type
"""
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))

@ -0,0 +1,97 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from models.workflow import WorkflowNodeExecution
@dataclass
class OrderConfig:
"""Configuration for ordering WorkflowNodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
class WorkflowNodeExecutionRepository(Protocol):
"""
Repository interface for WorkflowNodeExecution.
This interface defines the contract for accessing and manipulating
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and trigger sources (triggered_from) should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to save
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
...
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to update
"""
...
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
This method is intended to be used for bulk deletion operations, such as removing
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
"""
...

@ -94,7 +94,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
"title": item.metadata.get("title"), "title": item.metadata.get("title"),
"content": item.page_content, "content": item.page_content,
} }
context_list.append(source) context_list.append(source)
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list) hit_callback.return_retriever_resource_info(context_list)

@ -16,7 +16,7 @@ from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.event.event import RunCompletedEvent
@ -251,7 +251,12 @@ class AgentNode(ToolNode):
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
] ]
value["history_prompt_messages"] = history_prompt_messages value["history_prompt_messages"] = history_prompt_messages
value["entity"] = model_schema.model_dump(mode="json") if model_schema else None if model_schema:
# remove structured output feature to support old version agent plugin
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value result[parameter_name] = value
return result return result
@ -348,3 +353,10 @@ class AgentNode(ToolNode):
) )
model_schema = model_type_instance.get_model_schema(model_name, model_credentials) model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema return model_instance, model_schema
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features:
if feature.value not in AgentOldVersionModelFeatures:
model_schema.features.remove(feature)
return model_schema

@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData):
class ParamsAutoGenerated(Enum): class ParamsAutoGenerated(Enum):
CLOSE = 0 CLOSE = 0
OPEN = 1 OPEN = 1
class AgentOldVersionModelFeatures(Enum):
"""
Enum class for old SDK version llm feature.
"""
TOOL_CALL = "tool-call"
MULTI_TOOL_CALL = "multi-tool-call"
AGENT_THOUGHT = "agent-thought"
VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"

@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
structured_output_enabled: bool = False
@field_validator("prompt_config", mode="before") @field_validator("prompt_config", mode="before")
@classmethod @classmethod

@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage, SystemPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import (
AIModelEntity,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
@ -57,6 +65,12 @@ from core.workflow.nodes.event import (
RunRetrieverResourceEvent, RunRetrieverResourceEvent,
RunStreamChunkEvent, RunStreamChunkEvent,
) )
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Conversation from models.model import Conversation
@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_type = NodeType.LLM _node_type = NodeType.LLM
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
"""Process structured output if enabled"""
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
return None
return self._parse_structured_output(text)
node_inputs: Optional[dict[str, Any]] = None node_inputs: Optional[dict[str, Any]] = None
process_data = None process_data = None
result_text = "" result_text = ""
@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(event, RunRetrieverResourceEvent): if isinstance(event, RunRetrieverResourceEvent):
context = event.context context = event.context
yield event yield event
if context: if context:
node_inputs["#context#"] = context node_inputs["#context#"] = context
@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break break
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema: if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.") raise ModelNotExistError(f"Model {model_name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
# Set appropriate response format based on model capabilities
self._set_response_format(completion_params, model_schema.parameter_rules)
return model_instance, ModelConfigWithCredentialsEntity( return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name, provider=provider_name,
model=model_name, model=model_name,
@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
stop = model_config.stop stop = model_config.stop
return filtered_prompt_messages, stop return filtered_prompt_messages, stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
structured_output: dict[str, Any] | list[Any] = {}
try:
parsed = json.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except json.JSONDecodeError as e:
# if the result_text is not a valid json, try to repair it
parsed = json_repair.loads(result_text)
if not isinstance(parsed, (dict | list)):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
return structured_output
@classmethod @classmethod
def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle provider_model_bundle = model_instance.provider_model_bundle
@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages return prompt_messages
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
:return: Updated model parameters with JSON schema configuration
"""
# Process schema according to model requirements
schema = self._fetch_structured_output_schema()
schema_json = self._prepare_schema_for_model(schema)
# Set JSON schema in parameters
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
return model_parameters
def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
This function modifies the prompt messages to include schema-based output requirements.
Args:
prompt_messages: Original sequence of prompt messages
Returns:
list[PromptMessage]: Updated prompt messages with structured output requirements
"""
# Convert schema to string format
schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
# Find existing system prompt with schema placeholder
system_prompt = next(
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
None,
)
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
# Prepare system prompt content
system_prompt_content = (
structured_output_prompt + "\n\n" + system_prompt.content
if system_prompt and isinstance(system_prompt.content, str)
else structured_output_prompt
)
system_prompt = SystemPromptMessage(content=system_prompt_content)
# Extract content from the last user message
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
updated_prompt = [system_prompt] + filtered_prompts
return updated_prompt
def _set_response_format(self, model_parameters: dict, rules: list) -> None:
"""
Set the appropriate response format parameter based on model rules.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
def _prepare_schema_for_model(self, schema: dict) -> dict:
"""
Prepare JSON schema based on model requirements.
Different models have different requirements for JSON schema formatting.
This function handles these differences.
:param schema: The original JSON schema
:return: Processed schema compatible with the current model
"""
# Deep copy to avoid modifying the original schema
processed_schema = schema.copy()
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in self.node_data.model.name:
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in self.node_data.model.provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
model_name = self.node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
def _fetch_structured_output_schema(self) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
if not self.node_data.structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.
Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role: match role:
@ -1064,3 +1269,49 @@ def _handle_completion_template(
) )
prompt_messages.append(prompt_message) prompt_messages.append(prompt_message)
return prompt_messages return prompt_messages
def remove_additional_properties(schema: dict) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None:
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)

@ -0,0 +1,24 @@
from enum import StrEnum
class ResponseFormat(StrEnum):
"""Constants for model response formats"""
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
JSON = "JSON" # model's json mode. some model like claude support this mode.
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
class SpecialModelType(StrEnum):
"""Constants for identifying model types"""
GEMINI = "gemini"
OLLAMA = "ollama"
class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"

@ -0,0 +1,17 @@
STRUCTURED_OUTPUT_PROMPT = """Youre a helpful AI assistant. You could answer questions and output in JSON format.
constraints:
- You must output in JSON format.
- Do not output boolean value, use string type instead.
- Do not output integer or float value, use number type instead.
eg:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
Here is the user's question:
My name is John Doe and I am 30 years old.
output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{schema}}
""" # noqa: E501

@ -26,9 +26,12 @@ def init_app(app: DifyApp):
# Always add StreamHandler to log to console # Always add StreamHandler to log to console
sh = logging.StreamHandler(sys.stdout) sh = logging.StreamHandler(sys.stdout)
sh.addFilter(RequestIdFilter())
log_handlers.append(sh) log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
logging.basicConfig( logging.basicConfig(
level=dify_config.LOG_LEVEL, level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT, format=dify_config.LOG_FORMAT,

@ -0,0 +1,18 @@
"""
Extension for initializing repositories.
This extension registers repository implementations with the RepositoryFactory.
"""
from dify_app import DifyApp
from repositories.repository_registry import register_repositories
def init_app(_app: DifyApp) -> None:
"""
Initialize repository implementations.
Args:
_app: The Flask application instance (unused)
"""
register_repositories()

@ -73,11 +73,7 @@ class Storage:
raise ValueError(f"unsupported storage type {storage_type}") raise ValueError(f"unsupported storage type {storage_type}")
def save(self, filename, data): def save(self, filename, data):
try: self.storage_runner.save(filename, data)
self.storage_runner.save(filename, data)
except Exception as e:
logger.exception(f"Failed to save file {filename}")
raise e
@overload @overload
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
@ -86,49 +82,25 @@ class Storage:
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
try: if stream:
if stream: return self.load_stream(filename)
return self.load_stream(filename) else:
else: return self.load_once(filename)
return self.load_once(filename)
except Exception as e:
logger.exception(f"Failed to load file {filename}")
raise e
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
try: return self.storage_runner.load_once(filename)
return self.storage_runner.load_once(filename)
except Exception as e:
logger.exception(f"Failed to load_once file {filename}")
raise e
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
try: return self.storage_runner.load_stream(filename)
return self.storage_runner.load_stream(filename)
except Exception as e:
logger.exception(f"Failed to load_stream file {filename}")
raise e
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
try: self.storage_runner.download(filename, target_filepath)
self.storage_runner.download(filename, target_filepath)
except Exception as e:
logger.exception(f"Failed to download file {filename}")
raise e
def exists(self, filename): def exists(self, filename):
try: return self.storage_runner.exists(filename)
return self.storage_runner.exists(filename)
except Exception as e:
logger.exception(f"Failed to check file exists {filename}")
raise e
def delete(self, filename): def delete(self, filename):
try: return self.storage_runner.delete(filename)
return self.storage_runner.delete(filename)
except Exception as e:
logger.exception(f"Failed to delete file {filename}")
raise e
storage = Storage() storage = Storage()

@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined]
@property @property
def retriever_resources(self): def retriever_resources(self):
return ( return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
db.session.query(DatasetRetrieverResource)
.filter(DatasetRetrieverResource.message_id == self.id)
.order_by(DatasetRetrieverResource.position.asc())
.all()
)
@property @property
def message_files(self): def message_files(self):

@ -245,6 +245,13 @@ class Workflow(Base):
@property @property
def tool_published(self) -> bool: def tool_published(self) -> bool:
"""
DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version
is the one being used by the tool.
For accurate checking, use a direct query with tenant_id, app_id, and version.
"""
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
return ( return (
@ -510,7 +517,7 @@ class WorkflowRun(Base):
) )
class WorkflowNodeExecutionTriggeredFrom(Enum): class WorkflowNodeExecutionTriggeredFrom(StrEnum):
""" """
Workflow Node Execution Triggered From Enum Workflow Node Execution Triggered From Enum
""" """
@ -518,21 +525,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum):
SINGLE_STEP = "single-step" SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run" WORKFLOW_RUN = "workflow-run"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid workflow node execution triggered from value {value}")
class WorkflowNodeExecutionStatus(StrEnum):
class WorkflowNodeExecutionStatus(Enum):
""" """
Workflow Node Execution Status Enum Workflow Node Execution Status Enum
""" """
@ -543,19 +537,6 @@ class WorkflowNodeExecutionStatus(Enum):
EXCEPTION = "exception" EXCEPTION = "exception"
RETRY = "retry" RETRY = "retry"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid workflow node execution status value {value}")
class WorkflowNodeExecution(Base): class WorkflowNodeExecution(Base):
""" """
@ -656,6 +637,7 @@ class WorkflowNodeExecution(Base):
@property @property
def created_by_account(self): def created_by_account(self):
created_by_role = CreatedByRole(self.created_by_role) created_by_role = CreatedByRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None
@property @property
@ -663,6 +645,7 @@ class WorkflowNodeExecution(Base):
from models.model import EndUser from models.model import EndUser
created_by_role = CreatedByRole(self.created_by_role) created_by_role = CreatedByRole(self.created_by_role)
# TODO(-LAN-): Avoid using db.session.get() here.
return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
@property @property

@ -30,6 +30,7 @@ dependencies = [
"gunicorn~=23.0.0", "gunicorn~=23.0.0",
"httpx[socks]~=0.27.0", "httpx[socks]~=0.27.0",
"jieba==0.42.1", "jieba==0.42.1",
"json-repair>=0.41.1",
"langfuse~=2.51.3", "langfuse~=2.51.3",
"langsmith~=0.1.77", "langsmith~=0.1.77",
"mailchimp-transactional~=1.0.50", "mailchimp-transactional~=1.0.50",
@ -163,10 +164,7 @@ storage = [
############################################################ ############################################################
# [ Tools ] dependency group # [ Tools ] dependency group
############################################################ ############################################################
tools = [ tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"]
"cloudscraper~=1.2.71",
"nltk~=3.9.1",
]
############################################################ ############################################################
# [ VDB ] dependency group # [ VDB ] dependency group
@ -180,7 +178,7 @@ vdb = [
"couchbase~=4.3.0", "couchbase~=4.3.0",
"elasticsearch==8.14.0", "elasticsearch==8.14.0",
"opensearch-py==2.4.0", "opensearch-py==2.4.0",
"oracledb~=2.2.1", "oracledb==3.0.0",
"pgvecto-rs[sqlalchemy]~=0.2.1", "pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5", "pgvector==0.2.5",
"pymilvus~=2.5.0", "pymilvus~=2.5.0",

@ -0,0 +1,6 @@
"""
Repository implementations for data access.
This package contains concrete implementations of the repository interfaces
defined in the core.repository package.
"""

@ -0,0 +1,87 @@
"""
Registry for repository implementations.
This module is responsible for registering factory functions with the repository factory.
"""
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
# Storage type constants
STORAGE_TYPE_RDBMS = "rdbms"
STORAGE_TYPE_HYBRID = "hybrid"
def register_repositories() -> None:
"""
Register repository factory functions with the RepositoryFactory.
This function reads configuration settings to determine which repository
implementations to register.
"""
# Configure WorkflowNodeExecutionRepository factory based on configuration
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
# Check storage type and register appropriate implementation
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
# Register SQLAlchemy implementation for RDBMS storage
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
# Hybrid storage is not yet implemented
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
else:
# Unknown storage type
raise ValueError(
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
f"Supported types: {STORAGE_TYPE_RDBMS}"
)
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
"""
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
This factory function creates a repository for the RDBMS storage type.
Args:
params: Parameters for creating the repository, including:
- tenant_id: Required. The tenant ID for multi-tenancy.
- app_id: Optional. The application ID for filtering.
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
a new sessionmaker will be created using the global database engine.
Returns:
A WorkflowNodeExecutionRepository instance
Raises:
ValueError: If required parameters are missing
"""
# Extract required parameters
tenant_id = params.get("tenant_id")
if tenant_id is None:
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
# Extract optional parameters
app_id = params.get("app_id")
# Use the session_factory from params if provided, otherwise create one using the global db engine
session_factory = params.get("session_factory")
if session_factory is None:
# Create a sessionmaker using the same engine as the global db session
session_factory = sessionmaker(bind=db.engine)
# Create and return the repository
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
)

@ -0,0 +1,9 @@
"""
WorkflowNodeExecution repository implementations.
"""
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"SQLAlchemyWorkflowNodeExecutionRepository",
]

@ -0,0 +1,192 @@
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
"""
import logging
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class SQLAlchemyWorkflowNodeExecutionRepository:
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core.
"""
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
tenant_id: Tenant ID for multi-tenancy
app_id: Optional app ID for filtering by application
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
else:
self._session_factory = session_factory
self._tenant_id = tenant_id
self._app_id = app_id
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance and commit changes to the database.
Args:
execution: The WorkflowNodeExecution instance to save
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.add(execution)
session.commit()
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalar(stmt)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
# Apply ordering if provided
if order_config and order_config.order_by:
order_columns: list[UnaryExpression] = []
for field in order_config.order_by:
column = getattr(WorkflowNodeExecution, field, None)
if not column:
continue
if order_config.order_direction == "desc":
order_columns.append(desc(column))
else:
order_columns.append(asc(column))
if order_columns:
stmt = stmt.order_by(*order_columns)
return session.scalars(stmt).all()
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalars(stmt).all()
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance and commit changes to the database.
Args:
execution: The WorkflowNodeExecution instance to update
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.merge(execution)
session.commit()
def clear(self) -> None:
"""
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
This method deletes all WorkflowNodeExecution records that match the tenant_id
and app_id (if provided) associated with this repository instance.
"""
with self._session_factory() as session:
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
result = session.execute(stmt)
session.commit()
deleted_count = result.rowcount
logger.info(
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
+ (f" and app {self._app_id}" if self._app_id else "")
)

@ -407,10 +407,8 @@ class AccountService:
raise PasswordResetRateLimitExceededError() raise PasswordResetRateLimitExceededError()
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code, token = cls.generate_reset_password_token(account_email, account)
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data={"code": code}
)
send_reset_password_mail_task.delay( send_reset_password_mail_task.delay(
language=language, language=language,
to=account_email, to=account_email,
@ -419,6 +417,22 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email) cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token return token
@classmethod
def generate_reset_password_token(
cls,
email: str,
account: Optional[Account] = None,
code: Optional[str] = None,
additional_data: dict[str, Any] = {},
):
if not code:
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
additional_data["code"] = code
token = TokenManager.generate_token(
account=account, email=email, token_type="reset_password", additional_data=additional_data
)
return code, token
@classmethod @classmethod
def revoke_reset_password_token(cls, token: str): def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password") TokenManager.revoke_token(token, "reset_password")

@ -1,3 +1,4 @@
from configs import dify_config
from core.helper import marketplace from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager from core.plugin.manager.plugin import PluginInstallationManager
@ -111,6 +112,8 @@ class DependenciesAnalysisService:
Generate the latest version of dependencies Generate the latest version of dependencies
""" """
dependencies = list(set(dependencies)) dependencies = list(set(dependencies))
if not dify_config.MARKETPLACE_ENABLED:
return []
deps = marketplace.batch_fetch_plugin_manifests(dependencies) deps = marketplace.batch_fetch_plugin_manifests(dependencies)
return [ return [
PluginDependency( PluginDependency(

@ -2,13 +2,14 @@ import threading
from typing import Optional from typing import Optional
import contexts import contexts
from core.repository import RepositoryFactory
from core.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from models.model import App from models.model import App
from models.workflow import ( from models.workflow import (
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun, WorkflowRun,
) )
@ -127,17 +128,17 @@ class WorkflowRunService:
if not workflow_run: if not workflow_run:
return [] return []
node_executions = ( # Use the repository to get the node executions
db.session.query(WorkflowNodeExecution) repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter( params={
WorkflowNodeExecution.tenant_id == app_model.tenant_id, "tenant_id": app_model.tenant_id,
WorkflowNodeExecution.app_id == app_model.id, "app_id": app_model.id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, "session_factory": db.session.get_bind,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, }
WorkflowNodeExecution.workflow_run_id == run_id,
)
.order_by(WorkflowNodeExecution.index.desc())
.all()
) )
return node_executions # Use the repository to get the node executions with ordering
order_config = OrderConfig(order_by=["index"], order_direction="desc")
node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config)
return list(node_executions)

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.repository import RepositoryFactory
from core.variables import Variable from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@ -27,6 +28,7 @@ from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole from models.enums import CreatedByRole
from models.model import App, AppMode from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution, WorkflowNodeExecution,
@ -282,8 +284,15 @@ class WorkflowService:
workflow_node_execution.created_by = account.id workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id workflow_node_execution.workflow_id = draft_workflow.id
db.session.add(workflow_node_execution) # Use the repository to save the workflow node execution
db.session.commit() repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": app_model.tenant_id,
"app_id": app_model.id,
"session_factory": db.session.get_bind,
}
)
repository.save(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
@ -515,8 +524,19 @@ class WorkflowService:
# Cannot delete a workflow that's currently in use by an app # Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
# Check if this workflow is published as a tool # Don't use workflow.tool_published as it's not accurate for specific workflow versions
if workflow.tool_published: # Check if there's a tool provider using this specific workflow version
tool_provider = (
session.query(WorkflowToolProvider)
.filter(
WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version,
)
.first()
)
if tool_provider:
# Cannot delete a workflow that's published as a tool # Cannot delete a workflow that's published as a tool
raise WorkflowInUseError("Cannot delete workflow that is published as a tool") raise WorkflowInUseError("Cannot delete workflow that is published as a tool")

@ -7,6 +7,7 @@ from celery import shared_task # type: ignore
from sqlalchemy import delete from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from core.repository import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import AppDatasetJoin from models.dataset import AppDatasetJoin
from models.model import ( from models.model import (
@ -30,7 +31,7 @@ from models.model import (
) )
from models.tools import WorkflowToolProvider from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage from models.web import PinnedConversation, SavedMessage
from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun
@shared_task(queue="app_deletion", bind=True, max_retries=3) @shared_task(queue="app_deletion", bind=True, max_retries=3)
@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def del_workflow_node_execution(workflow_node_execution_id: str): # Create a repository instance for WorkflowNodeExecution
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( repository = RepositoryFactory.create_workflow_node_execution_repository(
synchronize_session=False params={
) "tenant_id": tenant_id,
"app_id": app_id,
_delete_records( "session_factory": db.session.get_bind,
"""select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", }
{"tenant_id": tenant_id, "app_id": app_id},
del_workflow_node_execution,
"workflow node execution",
) )
# Use the clear method to delete all records for this tenant_id and app_id
repository.clear()
logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green"))
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str): def del_workflow_app_log(workflow_app_log_id: str):

@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
ToolCall = AssistantPromptMessage.ToolCall
# CASE 1: Single tool call
INPUTS_CASE_1 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
]
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
ToolCall(
id="RANDOM_ID_1",
type="function",
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
),
ToolCall(
id="RANDOM_ID_2",
type="function",
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
),
]
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
actual = []
_increase_tool_call(inputs, actual)
assert actual == expected
def test__increase_tool_call():
# case 1:
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
# case 2:
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
# case 3:
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
# case 4:
mock_id_generator = MagicMock()
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)

@ -0,0 +1,3 @@
"""
Unit tests for repositories.
"""

@ -0,0 +1,3 @@
"""
Unit tests for workflow_node_execution repositories.
"""

@ -0,0 +1,178 @@
"""
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
"""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session, sessionmaker
from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
@pytest.fixture
def session():
"""Create a mock SQLAlchemy session."""
session = MagicMock(spec=Session)
# Configure the session to be used as a context manager
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=None)
# Configure the session factory to return the session
session_factory = MagicMock(spec=sessionmaker)
session_factory.return_value = session
return session, session_factory
@pytest.fixture
def repository(session):
"""Create a repository instance with test data."""
_, session_factory = session
tenant_id = "test-tenant"
app_id = "test-app"
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
)
def test_save(repository, session):
"""Test save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = None
execution.app_id = None
# Call save method
repository.save(execution)
# Assert tenant_id and app_id are set
assert execution.tenant_id == repository._tenant_id
assert execution.app_id == repository._app_id
# Assert session.add was called
session_obj.add.assert_called_once_with(execution)
def test_save_with_existing_tenant_id(repository, session):
"""Test save method with existing tenant_id."""
session_obj, _ = session
# Create a mock execution with existing tenant_id
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = "existing-tenant"
execution.app_id = None
# Call save method
repository.save(execution)
# Assert tenant_id is not changed and app_id is set
assert execution.tenant_id == "existing-tenant"
assert execution.app_id == repository._app_id
# Assert session.add was called
session_obj.add.assert_called_once_with(execution)
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
"""Test get_by_node_execution_id method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
# Call method
result = repository.get_by_node_execution_id("test-node-execution-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalar.assert_called_once_with(mock_stmt)
assert result is not None
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
# Call method
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
assert len(result) == 1
def test_get_running_executions(repository, session, mocker: MockerFixture):
"""Test get_running_executions method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
# Call method
result = repository.get_running_executions("test-workflow-run-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
assert len(result) == 1
def test_update(repository, session):
"""Test update method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = None
execution.app_id = None
# Call update method
repository.update(execution)
# Assert tenant_id and app_id are set
assert execution.tenant_id == repository._tenant_id
assert execution.app_id == repository._app_id
# Assert session.merge was called
session_obj.merge.assert_called_once_with(execution)
def test_clear(repository, session, mocker: MockerFixture):
"""Test clear method."""
session_obj, _ = session
# Set up mock
mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
mock_stmt = mocker.MagicMock()
mock_delete.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
# Mock the execute result with rowcount
mock_result = mocker.MagicMock()
mock_result.rowcount = 5 # Simulate 5 records deleted
session_obj.execute.return_value = mock_result
# Call method
repository.clear()
# Assert delete was called with correct parameters
mock_delete.assert_called_once_with(WorkflowNodeExecution)
mock_stmt.where.assert_called()
session_obj.execute.assert_called_once_with(mock_stmt)
session_obj.commit.assert_called_once()

@ -40,6 +40,10 @@ def workflow_setup():
def test_delete_workflow_success(workflow_setup): def test_delete_workflow_success(workflow_setup):
# Setup mocks # Setup mocks
# Mock the tool provider query to return None (not published as a tool)
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None
workflow_setup["session"].scalar = MagicMock( workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None] side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app ) # Return workflow first, then None for app
@ -97,7 +101,12 @@ def test_delete_workflow_in_use_by_app_error(workflow_setup):
def test_delete_workflow_published_as_tool_error(workflow_setup): def test_delete_workflow_published_as_tool_error(workflow_setup):
# Setup mocks # Setup mocks
workflow_setup["workflow"].tool_published = True from models.tools import WorkflowToolProvider
# Mock the tool provider query
mock_tool_provider = MagicMock(spec=WorkflowToolProvider)
workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider
workflow_setup["session"].scalar = MagicMock( workflow_setup["session"].scalar = MagicMock(
side_effect=[workflow_setup["workflow"], None] side_effect=[workflow_setup["workflow"], None]
) # Return workflow first, then None for app ) # Return workflow first, then None for app

@ -1,5 +1,4 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.11, <3.13" requires-python = ">=3.11, <3.13"
resolution-markers = [ resolution-markers = [
"python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'", "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'",
@ -1178,6 +1177,7 @@ dependencies = [
{ name = "gunicorn" }, { name = "gunicorn" },
{ name = "httpx", extra = ["socks"] }, { name = "httpx", extra = ["socks"] },
{ name = "jieba" }, { name = "jieba" },
{ name = "json-repair" },
{ name = "langfuse" }, { name = "langfuse" },
{ name = "langsmith" }, { name = "langsmith" },
{ name = "mailchimp-transactional" }, { name = "mailchimp-transactional" },
@ -1346,6 +1346,7 @@ requires-dist = [
{ name = "gunicorn", specifier = "~=23.0.0" }, { name = "gunicorn", specifier = "~=23.0.0" },
{ name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" },
{ name = "jieba", specifier = "==0.42.1" }, { name = "jieba", specifier = "==0.42.1" },
{ name = "json-repair", specifier = ">=0.41.1" },
{ name = "langfuse", specifier = "~=2.51.3" }, { name = "langfuse", specifier = "~=2.51.3" },
{ name = "langsmith", specifier = "~=0.1.77" }, { name = "langsmith", specifier = "~=0.1.77" },
{ name = "mailchimp-transactional", specifier = "~=1.0.50" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" },
@ -1470,7 +1471,7 @@ vdb = [
{ name = "couchbase", specifier = "~=4.3.0" }, { name = "couchbase", specifier = "~=4.3.0" },
{ name = "elasticsearch", specifier = "==8.14.0" }, { name = "elasticsearch", specifier = "==8.14.0" },
{ name = "opensearch-py", specifier = "==2.4.0" }, { name = "opensearch-py", specifier = "==2.4.0" },
{ name = "oracledb", specifier = "~=2.2.1" }, { name = "oracledb", specifier = "==3.0.0" },
{ name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" },
{ name = "pgvector", specifier = "==0.2.5" }, { name = "pgvector", specifier = "==0.2.5" },
{ name = "pymilvus", specifier = "~=2.5.0" }, { name = "pymilvus", specifier = "~=2.5.0" },
@ -2524,6 +2525,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 },
] ]
[[package]]
name = "json-repair"
version = "0.41.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 },
]
[[package]] [[package]]
name = "jsonpath-python" name = "jsonpath-python"
version = "1.0.6" version = "1.0.6"
@ -3590,23 +3600,23 @@ wheels = [
[[package]] [[package]]
name = "oracledb" name = "oracledb"
version = "2.2.1" version = "3.0.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "cryptography" }, { name = "cryptography" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 } sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 }, { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 },
{ url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 }, { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 },
{ url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 }, { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 },
{ url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 }, { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 },
{ url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 }, { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 },
{ url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 }, { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 },
{ url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 }, { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 },
{ url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 }, { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 },
{ url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 }, { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 },
{ url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 }, { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 },
] ]
[[package]] [[package]]
@ -4074,6 +4084,8 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 }, { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 },
{ url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 }, { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 },
{ url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 }, { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 },
{ url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 },
{ url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 },
] ]
[[package]] [[package]]

@ -744,6 +744,12 @@ MAX_VARIABLE_SIZE=204800
WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
WORKFLOW_FILE_UPLOAD_LIMIT=10 WORKFLOW_FILE_UPLOAD_LIMIT=10
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# HTTP request node in workflow configuration # HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576

@ -130,6 +130,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194} SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf - ./volumes/sandbox/conf:/conf

@ -60,6 +60,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194} SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf - ./volumes/sandbox/conf:/conf

@ -327,6 +327,7 @@ x-shared-env: &shared-api-worker-env
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -605,6 +606,7 @@ services:
HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128}
HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128}
SANDBOX_PORT: ${SANDBOX_PORT:-8194} SANDBOX_PORT: ${SANDBOX_PORT:-8194}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
volumes: volumes:
- ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/dependencies:/dependencies
- ./volumes/sandbox/conf:/conf - ./volumes/sandbox/conf:/conf

@ -7,7 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next
### Run by source code ### Run by source code
Before starting the web frontend service, please make sure the following environment is ready. Before starting the web frontend service, please make sure the following environment is ready.
- [Node.js](https://nodejs.org) >= v18.x - [Node.js](https://nodejs.org) >= v22.11.x
- [pnpm](https://pnpm.io) v10.x - [pnpm](https://pnpm.io) v10.x
First, install the dependencies: First, install the dependencies:

@ -1,11 +1,11 @@
'use client' 'use client'
import Workflow from '@/app/components/workflow' import WorkflowApp from '@/app/components/workflow-app'
const Page = () => { const Page = () => {
return ( return (
<div className='h-full w-full overflow-x-auto'> <div className='h-full w-full overflow-x-auto'>
<Workflow /> <WorkflowApp />
</div> </div>
) )
} }

@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>semantic_search</code> 语义检索 - <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索 - <code>full_text_search</code> 全文检索
- <code>reranking_enable</code> (bool) 是否开启rerank - <code>reranking_enable</code> (bool) 是否开启rerank
- <code>reranking_mode</code> (String) 混合检索
- <code>weighted_score</code> 权重设置
- <code>reranking_model</code> Rerank 模型
- <code>reranking_model</code> (object) Rerank 模型配置 - <code>reranking_model</code> (object) Rerank 模型配置
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商 - <code>reranking_provider_name</code> (string) Rerank 模型的提供商
- <code>reranking_model_name</code> (string) Rerank 模型的名称 - <code>reranking_model_name</code> (string) Rerank 模型的名称
@ -591,7 +594,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property> </Property>
<Property name='retrieval_model' type='object' key='retrieval_model'> <Property name='retrieval_model' type='object' key='retrieval_model'>
检索参数(选填,如不填,按照默认方式召回) 检索参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下个关键字之一,必填 - <code>search_method</code> (text) 检索方法:以下个关键字之一,必填
- <code>keyword_search</code> 关键字检索 - <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索 - <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索 - <code>full_text_search</code> 全文检索
@ -1817,7 +1820,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property> </Property>
<Property name='retrieval_model' type='object' key='retrieval_model'> <Property name='retrieval_model' type='object' key='retrieval_model'>
检索参数(选填,如不填,按照默认方式召回) 检索参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下个关键字之一,必填 - <code>search_method</code> (text) 检索方法:以下个关键字之一,必填
- <code>keyword_search</code> 关键字检索 - <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索 - <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索 - <code>full_text_search</code> 全文检索

@ -0,0 +1,82 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ConfigSelect from './index'
jest.mock('react-sortablejs', () => ({
ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('ConfigSelect Component', () => {
const defaultProps = {
options: ['Option 1', 'Option 2'],
onChange: jest.fn(),
}
afterEach(() => {
jest.clearAllMocks()
})
it('renders all options', () => {
render(<ConfigSelect {...defaultProps} />)
defaultProps.options.forEach((option) => {
expect(screen.getByDisplayValue(option)).toBeInTheDocument()
})
})
it('renders add button', () => {
render(<ConfigSelect {...defaultProps} />)
expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
})
it('handles option deletion', () => {
render(<ConfigSelect {...defaultProps} />)
const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
const deleteButton = optionContainer?.querySelector('div[role="button"]')
if (!deleteButton) return
fireEvent.click(deleteButton)
expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2'])
})
it('handles adding new option', () => {
render(<ConfigSelect {...defaultProps} />)
const addButton = screen.getByText('appDebug.variableConfig.addOption')
fireEvent.click(addButton)
expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, ''])
})
it('applies focus styles on input focus', () => {
render(<ConfigSelect {...defaultProps} />)
const firstInput = screen.getByDisplayValue('Option 1')
fireEvent.focus(firstInput)
expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active')
})
it('applies delete hover styles', () => {
render(<ConfigSelect {...defaultProps} />)
const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
const deleteButton = optionContainer?.querySelector('div[role="button"]')
if (!deleteButton) return
fireEvent.mouseEnter(deleteButton)
expect(optionContainer).toHaveClass('border-components-input-border-destructive')
})
it('renders empty state correctly', () => {
render(<ConfigSelect options={[]} onChange={defaultProps.onChange} />)
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
})
})

@ -51,7 +51,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
<RiDraggable className='handle h-4 w-4 cursor-grab text-text-quaternary' /> <RiDraggable className='handle h-4 w-4 cursor-grab text-text-quaternary' />
<input <input
key={index} key={index}
type="input" type='input'
value={o || ''} value={o || ''}
onChange={(e) => { onChange={(e) => {
const value = e.target.value const value = e.target.value
@ -67,6 +67,7 @@ const ConfigSelect: FC<IConfigSelectProps> = ({
onBlur={() => setFocusID(null)} onBlur={() => setFocusID(null)}
/> />
<div <div
role='button'
className='absolute right-1.5 top-1/2 block translate-y-[-50%] cursor-pointer rounded-md p-1 text-text-tertiary hover:bg-state-destructive-hover hover:text-text-destructive' className='absolute right-1.5 top-1/2 block translate-y-[-50%] cursor-pointer rounded-md p-1 text-text-tertiary hover:bg-state-destructive-hover hover:text-text-destructive'
onClick={() => { onClick={() => {
onChange(options.filter((_, i) => index !== i)) onChange(options.filter((_, i) => index !== i))

@ -162,11 +162,22 @@ const SettingsModal: FC<ISettingsModalProps> = ({
return check return check
} }
const validatePrivacyPolicy = (privacyPolicy: string | null) => {
if (privacyPolicy === null || privacyPolicy?.length === 0)
return true
return privacyPolicy.startsWith('http://') || privacyPolicy.startsWith('https://')
}
if (inputInfo !== null) { if (inputInfo !== null) {
if (!validateColorHex(inputInfo.chatColorTheme)) { if (!validateColorHex(inputInfo.chatColorTheme)) {
notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`) }) notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`) })
return return
} }
if (!validatePrivacyPolicy(inputInfo.privacyPolicy)) {
notify({ type: 'error', message: t(`${prefixSettings}.invalidPrivacyPolicy`) })
return
}
} }
setSaveLoading(true) setSaveLoading(true)
@ -410,7 +421,7 @@ const SettingsModal: FC<ISettingsModalProps> = ({
<p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}> <p className={cn('body-xs-regular pb-0.5 text-text-tertiary')}>
<Trans <Trans
i18nKey={`${prefixSettings}.more.privacyPolicyTip`} i18nKey={`${prefixSettings}.more.privacyPolicyTip`}
components={{ privacyPolicyLink: <Link href={'https://docs.dify.ai/user-agreement/privacy-policy'} target='_blank' rel='noopener noreferrer' className='text-text-accent' /> }} components={{ privacyPolicyLink: <Link href={'https://dify.ai/privacy'} target='_blank' rel='noopener noreferrer' className='text-text-accent' /> }}
/> />
</p> </p>
<Input <Input

@ -234,4 +234,6 @@ const Answer: FC<AnswerProps> = ({
) )
} }
export default memo(Answer) export default memo(Answer, (prevProps, nextProps) =>
prevProps.responding === false && nextProps.responding === false,
)

@ -0,0 +1,11 @@
const IndeterminateIcon = () => {
return (
<div data-testid='indeterminate-icon'>
<svg xmlns="http://www.w3.org/2000/svg" width="12" height="12" viewBox="0 0 12 12" fill="none">
<path d="M2.5 6H9.5" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round"/>
</svg>
</div>
)
}
export default IndeterminateIcon

@ -1,5 +0,0 @@
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="check">
<path id="Vector 1" d="M2.5 6H9.5" stroke="white" stroke-width="1.5" stroke-linecap="round"/>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 217 B

@ -1,10 +0,0 @@
.mixed {
background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat;
background-size: 12px 12px;
border: none;
}
.checked.disabled {
background-color: #d0d5dd;
border-color: #d0d5dd;
}

@ -0,0 +1,67 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Checkbox from './index'
describe('Checkbox Component', () => {
const mockProps = {
id: 'test',
}
it('renders unchecked checkbox by default', () => {
render(<Checkbox {...mockProps} />)
const checkbox = screen.getByTestId('checkbox-test')
expect(checkbox).toBeInTheDocument()
expect(checkbox).not.toHaveClass('bg-components-checkbox-bg')
})
it('renders checked checkbox when checked prop is true', () => {
render(<Checkbox {...mockProps} checked />)
const checkbox = screen.getByTestId('checkbox-test')
expect(checkbox).toHaveClass('bg-components-checkbox-bg')
expect(screen.getByTestId('check-icon-test')).toBeInTheDocument()
})
it('renders indeterminate state correctly', () => {
render(<Checkbox {...mockProps} indeterminate />)
expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument()
})
it('handles click events when not disabled', () => {
const onCheck = jest.fn()
render(<Checkbox {...mockProps} onCheck={onCheck} />)
const checkbox = screen.getByTestId('checkbox-test')
fireEvent.click(checkbox)
expect(onCheck).toHaveBeenCalledTimes(1)
})
it('does not handle click events when disabled', () => {
const onCheck = jest.fn()
render(<Checkbox {...mockProps} disabled onCheck={onCheck} />)
const checkbox = screen.getByTestId('checkbox-test')
fireEvent.click(checkbox)
expect(onCheck).not.toHaveBeenCalled()
expect(checkbox).toHaveClass('cursor-not-allowed')
})
it('applies custom className when provided', () => {
const customClass = 'custom-class'
render(<Checkbox {...mockProps} className={customClass} />)
const checkbox = screen.getByTestId('checkbox-test')
expect(checkbox).toHaveClass(customClass)
})
it('applies correct styles for disabled checked state', () => {
render(<Checkbox {...mockProps} checked disabled />)
const checkbox = screen.getByTestId('checkbox-test')
expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked')
expect(checkbox).toHaveClass('cursor-not-allowed')
})
it('applies correct styles for disabled unchecked state', () => {
render(<Checkbox {...mockProps} disabled />)
const checkbox = screen.getByTestId('checkbox-test')
expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled')
expect(checkbox).toHaveClass('cursor-not-allowed')
})
})

@ -1,48 +1,49 @@
import { RiCheckLine } from '@remixicon/react' import { RiCheckLine } from '@remixicon/react'
import s from './index.module.css'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import IndeterminateIcon from './assets/indeterminate-icon'
type CheckboxProps = { type CheckboxProps = {
id?: string
checked?: boolean checked?: boolean
onCheck?: () => void onCheck?: () => void
className?: string className?: string
disabled?: boolean disabled?: boolean
mixed?: boolean indeterminate?: boolean
} }
const Checkbox = ({ checked, onCheck, className, disabled, mixed }: CheckboxProps) => { const Checkbox = ({
if (!checked) { id,
return ( checked,
<div onCheck,
className={cn( className,
'h-4 w-4 cursor-pointer rounded-[4px] border border-components-checkbox-border bg-components-checkbox-bg-unchecked shadow-xs hover:border-components-checkbox-border-hover', disabled,
mixed ? s.mixed : 'hover:bg-components-checkbox-bg-unchecked-hover', indeterminate,
disabled && 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled', }: CheckboxProps) => {
className, const checkClassName = (checked || indeterminate)
)} ? 'bg-components-checkbox-bg text-components-checkbox-icon hover:bg-components-checkbox-bg-hover'
onClick={() => { : 'border border-components-checkbox-border bg-components-checkbox-bg-unchecked hover:bg-components-checkbox-bg-unchecked-hover hover:border-components-checkbox-border-hover'
if (disabled) const disabledClassName = (checked || indeterminate)
return ? 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked'
onCheck?.() : 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled'
}}
></div>
)
}
return ( return (
<div <div
id={id}
className={cn( className={cn(
'flex h-4 w-4 cursor-pointer items-center justify-center rounded-[4px] bg-components-checkbox-bg text-components-checkbox-icon shadow-xs hover:bg-components-checkbox-bg-hover', 'flex h-4 w-4 cursor-pointer items-center justify-center rounded-[4px] shadow-xs shadow-shadow-shadow-3',
disabled && 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked', checkClassName,
disabled && disabledClassName,
className, className,
)} )}
onClick={() => { onClick={() => {
if (disabled) if (disabled)
return return
onCheck?.() onCheck?.()
}} }}
data-testid={`checkbox-${id}`}
> >
<RiCheckLine className={cn('h-3 w-3')} /> {!checked && indeterminate && <IndeterminateIcon />}
{checked && <RiCheckLine className='h-3 w-3' data-testid={`check-icon-${id}`} />}
</div> </div>
) )
} }

@ -0,0 +1,43 @@
import cn from '@/utils/classnames'
import { useFieldContext } from '../..'
import Checkbox from '../../../checkbox'
type CheckboxFieldProps = {
label: string;
labelClassName?: string;
}
const CheckboxField = ({
label,
labelClassName,
}: CheckboxFieldProps) => {
const field = useFieldContext<boolean>()
return (
<div className='flex gap-2'>
<div className='flex h-6 shrink-0 items-center'>
<Checkbox
id={field.name}
checked={field.state.value}
onCheck={() => {
field.handleChange(!field.state.value)
}}
/>
</div>
<label
htmlFor={field.name}
className={cn(
'system-sm-medium grow cursor-pointer pt-1 text-text-secondary',
labelClassName,
)}
onClick={() => {
field.handleChange(!field.state.value)
}}
>
{label}
</label>
</div>
)
}
export default CheckboxField

@ -0,0 +1,49 @@
import React from 'react'
import { useFieldContext } from '../..'
import Label from '../label'
import cn from '@/utils/classnames'
import type { InputNumberProps } from '../../../input-number'
import { InputNumber } from '../../../input-number'
type TextFieldProps = {
label: string
isRequired?: boolean
showOptional?: boolean
tooltip?: string
className?: string
labelClassName?: string
} & Omit<InputNumberProps, 'id' | 'value' | 'onChange' | 'onBlur'>
const NumberInputField = ({
label,
isRequired,
showOptional,
tooltip,
className,
labelClassName,
...inputProps
}: TextFieldProps) => {
const field = useFieldContext<number | undefined>()
return (
<div className={cn('flex flex-col gap-y-0.5', className)}>
<Label
htmlFor={field.name}
label={label}
isRequired={isRequired}
showOptional={showOptional}
tooltip={tooltip}
className={labelClassName}
/>
<InputNumber
id={field.name}
value={field.state.value}
onChange={value => field.handleChange(value)}
onBlur={field.handleBlur}
{...inputProps}
/>
</div>
)
}
export default NumberInputField

@ -0,0 +1,34 @@
import cn from '@/utils/classnames'
import { useFieldContext } from '../..'
import Label from '../label'
import ConfigSelect from '@/app/components/app/configuration/config-var/config-select'
type OptionsFieldProps = {
label: string;
className?: string;
labelClassName?: string;
}
const OptionsField = ({
label,
className,
labelClassName,
}: OptionsFieldProps) => {
const field = useFieldContext<string[]>()
return (
<div className={cn('flex flex-col gap-y-0.5', className)}>
<Label
htmlFor={field.name}
label={label}
className={labelClassName}
/>
<ConfigSelect
options={field.state.value}
onChange={value => field.handleChange(value)}
/>
</div>
)
}
export default OptionsField

@ -0,0 +1,51 @@
import cn from '@/utils/classnames'
import { useFieldContext } from '../..'
import PureSelect from '../../../select/pure'
import Label from '../label'
type SelectOption = {
value: string
label: string
}
type SelectFieldProps = {
label: string
options: SelectOption[]
isRequired?: boolean
showOptional?: boolean
tooltip?: string
className?: string
labelClassName?: string
}
const SelectField = ({
label,
options,
isRequired,
showOptional,
tooltip,
className,
labelClassName,
}: SelectFieldProps) => {
const field = useFieldContext<string>()
return (
<div className={cn('flex flex-col gap-y-0.5', className)}>
<Label
htmlFor={field.name}
label={label}
isRequired={isRequired}
showOptional={showOptional}
tooltip={tooltip}
className={labelClassName}
/>
<PureSelect
value={field.state.value}
options={options}
onChange={value => field.handleChange(value)}
/>
</div>
)
}
export default SelectField

@ -0,0 +1,48 @@
import React from 'react'
import { useFieldContext } from '../..'
import Input, { type InputProps } from '../../../input'
import Label from '../label'
import cn from '@/utils/classnames'
type TextFieldProps = {
label: string
isRequired?: boolean
showOptional?: boolean
tooltip?: string
className?: string
labelClassName?: string
} & Omit<InputProps, 'className' | 'onChange' | 'onBlur' | 'value' | 'id'>
const TextField = ({
label,
isRequired,
showOptional,
tooltip,
className,
labelClassName,
...inputProps
}: TextFieldProps) => {
const field = useFieldContext<string>()
return (
<div className={cn('flex flex-col gap-y-0.5', className)}>
<Label
htmlFor={field.name}
label={label}
isRequired={isRequired}
showOptional={showOptional}
tooltip={tooltip}
className={labelClassName}
/>
<Input
id={field.name}
value={field.state.value}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
{...inputProps}
/>
</div>
)
}
export default TextField

@ -0,0 +1,25 @@
import { useStore } from '@tanstack/react-form'
import { useFormContext } from '../..'
import Button, { type ButtonProps } from '../../../button'
type SubmitButtonProps = Omit<ButtonProps, 'disabled' | 'loading' | 'onClick'>
const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => {
const form = useFormContext()
const [isSubmitting, canSubmit] = useStore(form.store, state => [
state.isSubmitting,
state.canSubmit,
])
return (
<Button
disabled={isSubmitting || !canSubmit}
loading={isSubmitting}
onClick={() => form.handleSubmit()}
{...buttonProps}
/>
)
}
export default SubmitButton

@ -0,0 +1,53 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Label from './label'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('Label Component', () => {
const defaultProps = {
htmlFor: 'test-input',
label: 'Test Label',
}
it('renders basic label correctly', () => {
render(<Label {...defaultProps} />)
const label = screen.getByTestId('label')
expect(label).toBeInTheDocument()
expect(label).toHaveAttribute('for', 'test-input')
})
it('shows optional text when showOptional is true', () => {
render(<Label {...defaultProps} showOptional />)
expect(screen.getByText('common.label.optional')).toBeInTheDocument()
})
it('shows required asterisk when isRequired is true', () => {
render(<Label {...defaultProps} isRequired />)
expect(screen.getByText('*')).toBeInTheDocument()
})
it('renders tooltip when tooltip prop is provided', () => {
const tooltipText = 'Test Tooltip'
render(<Label {...defaultProps} tooltip={tooltipText} />)
const trigger = screen.getByTestId('test-input-tooltip')
fireEvent.mouseEnter(trigger)
expect(screen.getByText(tooltipText)).toBeInTheDocument()
})
it('applies custom className when provided', () => {
const customClass = 'custom-label'
render(<Label {...defaultProps} className={customClass} />)
const label = screen.getByTestId('label')
expect(label).toHaveClass(customClass)
})
it('does not show optional text and required asterisk simultaneously', () => {
render(<Label {...defaultProps} isRequired showOptional />)
expect(screen.queryByText('common.label.optional')).not.toBeInTheDocument()
expect(screen.getByText('*')).toBeInTheDocument()
})
})

@ -0,0 +1,48 @@
import cn from '@/utils/classnames'
import Tooltip from '../../tooltip'
import { useTranslation } from 'react-i18next'
export type LabelProps = {
htmlFor: string
label: string
isRequired?: boolean
showOptional?: boolean
tooltip?: string
className?: string
}
const Label = ({
htmlFor,
label,
isRequired,
showOptional,
tooltip,
className,
}: LabelProps) => {
const { t } = useTranslation()
return (
<div className='flex h-6 items-center'>
<label
data-testid='label'
htmlFor={htmlFor}
className={cn('system-sm-medium text-text-secondary', className)}
>
{label}
</label>
{!isRequired && showOptional && <div className='system-xs-regular ml-1 text-text-tertiary'>{t('common.label.optional')}</div>}
{isRequired && <div className='system-xs-regular ml-1 text-text-destructive-secondary'>*</div>}
{tooltip && (
<Tooltip
popupContent={
<div className='w-[200px]'>{tooltip}</div>
}
triggerClassName='ml-0.5 w-4 h-4'
triggerTestId={`${htmlFor}-tooltip`}
/>
)}
</div>
)
}
export default Label

@ -0,0 +1,35 @@
import { withForm } from '../..'
import { demoFormOpts } from './shared-options'
import { ContactMethods } from './types'
const ContactFields = withForm({
...demoFormOpts,
render: ({ form }) => {
return (
<div className='my-2'>
<h3 className='title-lg-bold text-text-primary'>Contacts</h3>
<div className='flex flex-col gap-4'>
<form.AppField
name='contact.email'
children={field => <field.TextField label='Email' />}
/>
<form.AppField
name='contact.phone'
children={field => <field.TextField label='Phone' />}
/>
<form.AppField
name='contact.preferredContactMethod'
children={field => (
<field.SelectField
label='Preferred Contact Method'
options={ContactMethods}
/>
)}
/>
</div>
</div>
)
},
})
export default ContactFields

@ -0,0 +1,68 @@
import { useStore } from '@tanstack/react-form'
import { useAppForm } from '../..'
import ContactFields from './contact-fields'
import { demoFormOpts } from './shared-options'
import { UserSchema } from './types'
const DemoForm = () => {
const form = useAppForm({
...demoFormOpts,
validators: {
onSubmit: ({ value }) => {
// Validate the entire form
const result = UserSchema.safeParse(value)
if (!result.success) {
const issues = result.error.issues
console.log('Validation errors:', issues)
return issues[0].message
}
return undefined
},
},
onSubmit: ({ value }) => {
console.log('Form submitted:', value)
},
})
const name = useStore(form.store, state => state.values.name)
return (
<form
className='flex w-[400px] flex-col gap-4'
onSubmit={(e) => {
e.preventDefault()
e.stopPropagation()
form.handleSubmit()
}}
>
<form.AppField
name='name'
children={field => (
<field.TextField label='Name' />
)}
/>
<form.AppField
name='surname'
children={field => (
<field.TextField label='Surname' />
)}
/>
<form.AppField
name='isAcceptingTerms'
children={field => (
<field.CheckboxField label='I accept the terms and conditions.' />
)}
/>
{
!!name && (
<ContactFields form={form} />
)
}
<form.AppForm>
<form.SubmitButton>Submit</form.SubmitButton>
</form.AppForm>
</form>
)
}
export default DemoForm

@ -0,0 +1,14 @@
import { formOptions } from '@tanstack/react-form'
export const demoFormOpts = formOptions({
defaultValues: {
name: '',
surname: '',
isAcceptingTerms: false,
contact: {
email: '',
phone: '',
preferredContactMethod: 'email',
},
},
})

@ -0,0 +1,34 @@
import { z } from 'zod'
const ContactMethod = z.union([
z.literal('email'),
z.literal('phone'),
z.literal('whatsapp'),
z.literal('sms'),
])
export const ContactMethods = ContactMethod.options.map(({ value }) => ({
value,
label: value.charAt(0).toUpperCase() + value.slice(1),
}))
export const UserSchema = z.object({
name: z
.string()
.regex(/^[A-Z]/, 'Name must start with a capital letter')
.min(3, 'Name must be at least 3 characters long'),
surname: z
.string()
.min(3, 'Surname must be at least 3 characters long')
.regex(/^[A-Z]/, 'Surname must start with a capital letter'),
isAcceptingTerms: z.boolean().refine(val => val, {
message: 'You must accept the terms and conditions',
}),
contact: z.object({
email: z.string().email('Invalid email address'),
phone: z.string().optional(),
preferredContactMethod: ContactMethod,
}),
})
export type User = z.infer<typeof UserSchema>

@ -0,0 +1,25 @@
import { createFormHook, createFormHookContexts } from '@tanstack/react-form'
import TextField from './components/field/text'
import NumberInputField from './components/field/number-input'
import CheckboxField from './components/field/checkbox'
import SelectField from './components/field/select'
import OptionsField from './components/field/options'
import SubmitButton from './components/form/submit-button'
export const { fieldContext, useFieldContext, formContext, useFormContext }
= createFormHookContexts()
export const { useAppForm, withForm } = createFormHook({
fieldComponents: {
TextField,
NumberInputField,
CheckboxField,
SelectField,
OptionsField,
},
formComponents: {
SubmitButton,
},
fieldContext,
formContext,
})

@ -0,0 +1,5 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="arrow-down-round-fill">
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="#101828"/>
</g>
</svg>

After

Width:  |  Height:  |  Size: 380 B

@ -0,0 +1,36 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "16",
"height": "16",
"viewBox": "0 0 16 16",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "g",
"attributes": {
"id": "arrow-down-round-fill"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"id": "Vector",
"d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z",
"fill": "currentColor"
},
"children": []
}
]
}
]
},
"name": "ArrowDownRoundFill"
}

@ -0,0 +1,20 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './ArrowDownRoundFill.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'ArrowDownRoundFill'
export default Icon

@ -1,4 +1,5 @@
export { default as AnswerTriangle } from './AnswerTriangle' export { default as AnswerTriangle } from './AnswerTriangle'
export { default as ArrowDownRoundFill } from './ArrowDownRoundFill'
export { default as CheckCircle } from './CheckCircle' export { default as CheckCircle } from './CheckCircle'
export { default as CheckDone01 } from './CheckDone01' export { default as CheckDone01 } from './CheckDone01'
export { default as Download02 } from './Download02' export { default as Download02 } from './Download02'

@ -0,0 +1,97 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { InputNumber } from './index'
jest.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
describe('InputNumber Component', () => {
const defaultProps = {
onChange: jest.fn(),
}
afterEach(() => {
jest.clearAllMocks()
})
it('renders input with default values', () => {
render(<InputNumber {...defaultProps} />)
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
})
it('handles increment button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(6)
})
it('handles decrement button click', () => {
render(<InputNumber {...defaultProps} value={5} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).toHaveBeenCalledWith(4)
})
it('respects max value constraint', () => {
render(<InputNumber {...defaultProps} value={10} max={10} />)
const incrementBtn = screen.getByRole('button', { name: /increment/i })
fireEvent.click(incrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
})
it('respects min value constraint', () => {
render(<InputNumber {...defaultProps} value={0} min={0} />)
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
fireEvent.click(decrementBtn)
expect(defaultProps.onChange).not.toHaveBeenCalled()
})
it('handles direct input changes', () => {
render(<InputNumber {...defaultProps} />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: '42' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(42)
})
it('handles empty input', () => {
render(<InputNumber {...defaultProps} value={0} />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: '' } })
expect(defaultProps.onChange).toHaveBeenCalledWith(undefined)
})
it('handles invalid input', () => {
render(<InputNumber {...defaultProps} />)
const input = screen.getByRole('textbox')
fireEvent.change(input, { target: { value: 'abc' } })
expect(defaultProps.onChange).not.toHaveBeenCalled()
})
it('displays unit when provided', () => {
const unit = 'px'
render(<InputNumber {...defaultProps} unit={unit} />)
expect(screen.getByText(unit)).toBeInTheDocument()
})
it('disables controls when disabled prop is true', () => {
render(<InputNumber {...defaultProps} disabled />)
const input = screen.getByRole('textbox')
const incrementBtn = screen.getByRole('button', { name: /increment/i })
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
expect(input).toBeDisabled()
expect(incrementBtn).toBeDisabled()
expect(decrementBtn).toBeDisabled()
})
})

@ -8,7 +8,7 @@ export type InputNumberProps = {
value?: number value?: number
onChange: (value?: number) => void onChange: (value?: number) => void
amount?: number amount?: number
size?: 'sm' | 'md' size?: 'regular' | 'large'
max?: number max?: number
min?: number min?: number
defaultValue?: number defaultValue?: number
@ -19,14 +19,12 @@ export type InputNumberProps = {
} & Omit<InputProps, 'value' | 'onChange' | 'size' | 'min' | 'max' | 'defaultValue'> } & Omit<InputProps, 'value' | 'onChange' | 'size' | 'min' | 'max' | 'defaultValue'>
export const InputNumber: FC<InputNumberProps> = (props) => { export const InputNumber: FC<InputNumberProps> = (props) => {
const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props const { unit, className, onChange, amount = 1, value, size = 'regular', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props
const isValidValue = (v: number) => { const isValidValue = (v: number) => {
if (max && v > max) if (typeof max === 'number' && v > max)
return false return false
if (min && v < min) return !(typeof min === 'number' && v < min)
return false
return true
} }
const inc = () => { const inc = () => {
@ -76,29 +74,39 @@ export const InputNumber: FC<InputNumberProps> = (props) => {
onChange(parsed) onChange(parsed)
}} }}
unit={unit} unit={unit}
size={size}
/> />
<div className={classNames( <div className={classNames(
'flex flex-col bg-components-input-bg-normal rounded-r-md border-l border-divider-subtle text-text-tertiary focus:shadow-xs', 'flex flex-col bg-components-input-bg-normal rounded-r-md border-l border-divider-subtle text-text-tertiary focus:shadow-xs',
disabled && 'opacity-50 cursor-not-allowed', disabled && 'opacity-50 cursor-not-allowed',
controlWrapClassName)} controlWrapClassName)}
> >
<button onClick={inc} disabled={disabled} className={classNames( <button
size === 'sm' ? 'pt-1' : 'pt-1.5', type='button'
'px-1.5 hover:bg-components-input-bg-hover', onClick={inc}
disabled && 'cursor-not-allowed hover:bg-transparent', disabled={disabled}
controlClassName, aria-label='increment'
)}> className={classNames(
size === 'regular' ? 'pt-1' : 'pt-1.5',
'px-1.5 hover:bg-components-input-bg-hover',
disabled && 'cursor-not-allowed hover:bg-transparent',
controlClassName,
)}
>
<RiArrowUpSLine className='size-3' /> <RiArrowUpSLine className='size-3' />
</button> </button>
<button <button
type='button'
onClick={dec} onClick={dec}
disabled={disabled} disabled={disabled}
aria-label='decrement'
className={classNames( className={classNames(
size === 'sm' ? 'pb-1' : 'pb-1.5', size === 'regular' ? 'pb-1' : 'pb-1.5',
'px-1.5 hover:bg-components-input-bg-hover', 'px-1.5 hover:bg-components-input-bg-hover',
disabled && 'cursor-not-allowed hover:bg-transparent', disabled && 'cursor-not-allowed hover:bg-transparent',
controlClassName, controlClassName,
)}> )}
>
<RiArrowDownSLine className='size-3' /> <RiArrowDownSLine className='size-3' />
</button> </button>
</div> </div>

@ -30,7 +30,7 @@ export type InputProps = {
wrapperClassName?: string wrapperClassName?: string
styleCss?: CSSProperties styleCss?: CSSProperties
unit?: string unit?: string
} & React.InputHTMLAttributes<HTMLInputElement> & VariantProps<typeof inputVariants> } & Omit<React.InputHTMLAttributes<HTMLInputElement>, 'size'> & VariantProps<typeof inputVariants>
const Input = ({ const Input = ({
size, size,

@ -0,0 +1,37 @@
import abcjs from 'abcjs'
import { useEffect, useRef } from 'react'
import 'abcjs/abcjs-audio.css'
const MarkdownMusic = ({ children }: { children: React.ReactNode }) => {
const containerRef = useRef<HTMLDivElement>(null)
const controlsRef = useRef<HTMLDivElement>(null)
useEffect(() => {
if (containerRef.current && controlsRef.current) {
if (typeof children === 'string') {
const visualObjs = abcjs.renderAbc(containerRef.current, children, {
add_classes: true, // Add classes to SVG elements for cursor tracking
responsive: 'resize', // Make notation responsive
})
const synthControl = new abcjs.synth.SynthController()
synthControl.load(controlsRef.current, {}, { displayPlay: true })
const synth = new abcjs.synth.CreateSynth()
const visualObj = visualObjs[0]
synth.init({ visualObj }).then(() => {
synthControl.setTune(visualObj, false)
})
containerRef.current.style.overflow = 'auto'
}
}
}, [children])
return (
<div style={{ minWidth: '100%', overflow: 'auto' }}>
<div ref={containerRef} />
<div ref={controlsRef} />
</div>
)
}
MarkdownMusic.displayName = 'MarkdownMusic'
export default MarkdownMusic

@ -23,6 +23,7 @@ import VideoGallery from '@/app/components/base/video-gallery'
import AudioGallery from '@/app/components/base/audio-gallery' import AudioGallery from '@/app/components/base/audio-gallery'
import MarkdownButton from '@/app/components/base/markdown-blocks/button' import MarkdownButton from '@/app/components/base/markdown-blocks/button'
import MarkdownForm from '@/app/components/base/markdown-blocks/form' import MarkdownForm from '@/app/components/base/markdown-blocks/form'
import MarkdownMusic from '@/app/components/base/markdown-blocks/music'
import ThinkBlock from '@/app/components/base/markdown-blocks/think-block' import ThinkBlock from '@/app/components/base/markdown-blocks/think-block'
import { Theme } from '@/types/app' import { Theme } from '@/types/app'
import useTheme from '@/hooks/use-theme' import useTheme from '@/hooks/use-theme'
@ -51,6 +52,7 @@ const capitalizationLanguageNameMap: Record<string, string> = {
json: 'JSON', json: 'JSON',
latex: 'Latex', latex: 'Latex',
svg: 'SVG', svg: 'SVG',
abc: 'ABC',
} }
const getCorrectCapitalizationLanguageName = (language: string) => { const getCorrectCapitalizationLanguageName = (language: string) => {
if (!language) if (!language)
@ -85,9 +87,11 @@ const preprocessLaTeX = (content: string) => {
} }
const preprocessThinkTag = (content: string) => { const preprocessThinkTag = (content: string) => {
const thinkOpenTagRegex = /<think>\n/g
const thinkCloseTagRegex = /\n<\/think>/g
return flow([ return flow([
(str: string) => str.replace('<think>\n', '<details data-think=true>\n'), (str: string) => str.replace(thinkOpenTagRegex, '<details data-think=true>\n'),
(str: string) => str.replace('\n</think>', '\n[ENDTHINKFLAG]</details>'), (str: string) => str.replace(thinkCloseTagRegex, '\n[ENDTHINKFLAG]</details>'),
])(content) ])(content)
} }
@ -135,45 +139,54 @@ const CodeBlock: any = memo(({ inline, className, children, ...props }: any) =>
const renderCodeContent = useMemo(() => { const renderCodeContent = useMemo(() => {
const content = String(children).replace(/\n$/, '') const content = String(children).replace(/\n$/, '')
if (language === 'mermaid' && isSVG) { switch (language) {
return <Flowchart PrimitiveCode={content} /> case 'mermaid':
} if (isSVG)
else if (language === 'echarts') { return <Flowchart PrimitiveCode={content} />
return ( break
<div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}> case 'echarts':
return (
<div style={{ minHeight: '350px', minWidth: '100%', overflowX: 'scroll' }}>
<ErrorBoundary>
<ReactEcharts option={chartData} style={{ minWidth: '700px' }} />
</ErrorBoundary>
</div>
)
case 'svg':
if (isSVG) {
return (
<ErrorBoundary>
<SVGRenderer content={content} />
</ErrorBoundary>
)
}
break
case 'abc':
return (
<ErrorBoundary> <ErrorBoundary>
<ReactEcharts option={chartData} style={{ minWidth: '700px' }} /> <MarkdownMusic children={content} />
</ErrorBoundary> </ErrorBoundary>
</div> )
) default:
} return (
else if (language === 'svg' && isSVG) { <SyntaxHighlighter
return ( {...props}
<ErrorBoundary> style={theme === Theme.light ? atelierHeathLight : atelierHeathDark}
<SVGRenderer content={content} /> customStyle={{
</ErrorBoundary> paddingLeft: 12,
) borderBottomLeftRadius: '10px',
} borderBottomRightRadius: '10px',
else { backgroundColor: 'var(--color-components-input-bg-normal)',
return ( }}
<SyntaxHighlighter language={match?.[1]}
{...props} showLineNumbers
style={theme === Theme.light ? atelierHeathLight : atelierHeathDark} PreTag="div"
customStyle={{ >
paddingLeft: 12, {content}
borderBottomLeftRadius: '10px', </SyntaxHighlighter>
borderBottomRightRadius: '10px', )
backgroundColor: 'var(--color-components-input-bg-normal)',
}}
language={match?.[1]}
showLineNumbers
PreTag="div"
>
{content}
</SyntaxHighlighter>
)
} }
}, [language, match, props, children, chartData, isSVG]) }, [children, language, isSVG, chartData, props, theme, match])
if (inline || !match) if (inline || !match)
return <code {...props} className={className}>{children}</code> return <code {...props} className={className}>{children}</code>

@ -54,7 +54,7 @@ const ParamItem: FC<Props> = ({ className, id, name, noTooltip, tip, step = 0.1,
max={max} max={max}
step={step} step={step}
amount={step} amount={step}
size='sm' size='regular'
value={value} value={value}
onChange={(value) => { onChange={(value) => {
onChange(id, value) onChange(id, value)

@ -14,7 +14,7 @@ export class HistoryBlockNode extends DecoratorNode<React.JSX.Element> {
} }
static clone(node: HistoryBlockNode): HistoryBlockNode { static clone(node: HistoryBlockNode): HistoryBlockNode {
return new HistoryBlockNode(node.__roleName, node.__onEditRole) return new HistoryBlockNode(node.__roleName, node.__onEditRole, node.__key)
} }
constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) { constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) {

@ -11,6 +11,7 @@ import { mergeRegister } from '@lexical/utils'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { import {
RiErrorWarningFill, RiErrorWarningFill,
RiMoreLine,
} from '@remixicon/react' } from '@remixicon/react'
import { useSelectOrDelete } from '../../hooks' import { useSelectOrDelete } from '../../hooks'
import type { WorkflowNodesMap } from './node' import type { WorkflowNodesMap } from './node'
@ -27,26 +28,35 @@ import { Line3 } from '@/app/components/base/icons/src/public/common'
import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import { isExceptionVariable } from '@/app/components/workflow/utils' import { isExceptionVariable } from '@/app/components/workflow/utils'
import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import type { ValueSelector } from '@/app/components/workflow/types'
type WorkflowVariableBlockComponentProps = { type WorkflowVariableBlockComponentProps = {
nodeKey: string nodeKey: string
variables: string[] variables: string[]
workflowNodesMap: WorkflowNodesMap workflowNodesMap: WorkflowNodesMap
getVarType?: (payload: {
nodeId: string,
valueSelector: ValueSelector,
}) => Type
} }
const WorkflowVariableBlockComponent = ({ const WorkflowVariableBlockComponent = ({
nodeKey, nodeKey,
variables, variables,
workflowNodesMap = {}, workflowNodesMap = {},
getVarType,
}: WorkflowVariableBlockComponentProps) => { }: WorkflowVariableBlockComponentProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const [editor] = useLexicalComposerContext() const [editor] = useLexicalComposerContext()
const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND) const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND)
const variablesLength = variables.length const variablesLength = variables.length
const isShowAPart = variablesLength > 2
const varName = ( const varName = (
() => { () => {
const isSystem = isSystemVar(variables) const isSystem = isSystemVar(variables)
const varName = variablesLength >= 3 ? (variables).slice(-2).join('.') : variables[variablesLength - 1] const varName = variables[variablesLength - 1]
return `${isSystem ? 'sys.' : ''}${varName}` return `${isSystem ? 'sys.' : ''}${varName}`
} }
)() )()
@ -76,7 +86,7 @@ const WorkflowVariableBlockComponent = ({
const Item = ( const Item = (
<div <div
className={cn( className={cn(
'group/wrap relative mx-0.5 flex h-[18px] select-none items-center rounded-[5px] border pl-0.5 pr-[3px]', 'group/wrap relative mx-0.5 flex h-[18px] select-none items-center rounded-[5px] border pl-0.5 pr-[3px] hover:border-state-accent-solid hover:bg-state-accent-hover',
isSelected ? ' border-state-accent-solid bg-state-accent-hover' : ' border-components-panel-border-subtle bg-components-badge-white-to-dark', isSelected ? ' border-state-accent-solid bg-state-accent-hover' : ' border-components-panel-border-subtle bg-components-badge-white-to-dark',
!node && !isEnv && !isChatVar && '!border-state-destructive-solid !bg-state-destructive-hover', !node && !isEnv && !isChatVar && '!border-state-destructive-solid !bg-state-destructive-hover',
)} )}
@ -99,6 +109,13 @@ const WorkflowVariableBlockComponent = ({
<Line3 className='mr-0.5 text-divider-deep'></Line3> <Line3 className='mr-0.5 text-divider-deep'></Line3>
</div> </div>
)} )}
{isShowAPart && (
<div className='flex items-center'>
<RiMoreLine className='h-3 w-3 text-text-secondary' />
<Line3 className='mr-0.5 text-divider-deep'></Line3>
</div>
)}
<div className='flex items-center text-text-accent'> <div className='flex items-center text-text-accent'>
{!isEnv && !isChatVar && <Variable02 className={cn('h-3.5 w-3.5 shrink-0', isException && 'text-text-warning')} />} {!isEnv && !isChatVar && <Variable02 className={cn('h-3.5 w-3.5 shrink-0', isException && 'text-text-warning')} />}
{isEnv && <Env className='h-3.5 w-3.5 shrink-0 text-util-colors-violet-violet-600' />} {isEnv && <Env className='h-3.5 w-3.5 shrink-0 text-util-colors-violet-violet-600' />}
@ -126,7 +143,27 @@ const WorkflowVariableBlockComponent = ({
) )
} }
return Item if (!node)
return null
return (
<Tooltip
noDecoration
popupContent={
<VarFullPathPanel
nodeName={node.title}
path={variables.slice(1)}
varType={getVarType ? getVarType({
nodeId: variables[0],
valueSelector: variables,
}) : Type.string}
nodeType={node?.type}
/>}
disabled={!isShowAPart}
>
<div>{Item}</div>
</Tooltip>
)
} }
export default memo(WorkflowVariableBlockComponent) export default memo(WorkflowVariableBlockComponent)

@ -9,7 +9,7 @@ import {
} from 'lexical' } from 'lexical'
import { mergeRegister } from '@lexical/utils' import { mergeRegister } from '@lexical/utils'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import type { WorkflowVariableBlockType } from '../../types' import type { GetVarType, WorkflowVariableBlockType } from '../../types'
import { import {
$createWorkflowVariableBlockNode, $createWorkflowVariableBlockNode,
WorkflowVariableBlockNode, WorkflowVariableBlockNode,
@ -25,11 +25,13 @@ export type WorkflowVariableBlockProps = {
getWorkflowNode: (nodeId: string) => Node getWorkflowNode: (nodeId: string) => Node
onInsert?: () => void onInsert?: () => void
onDelete?: () => void onDelete?: () => void
getVarType: GetVarType
} }
const WorkflowVariableBlock = memo(({ const WorkflowVariableBlock = memo(({
workflowNodesMap, workflowNodesMap,
onInsert, onInsert,
onDelete, onDelete,
getVarType,
}: WorkflowVariableBlockType) => { }: WorkflowVariableBlockType) => {
const [editor] = useLexicalComposerContext() const [editor] = useLexicalComposerContext()
@ -48,7 +50,7 @@ const WorkflowVariableBlock = memo(({
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
(variables: string[]) => { (variables: string[]) => {
editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap) const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType)
$insertNodes([workflowVariableBlockNode]) $insertNodes([workflowVariableBlockNode])
if (onInsert) if (onInsert)
@ -69,7 +71,7 @@ const WorkflowVariableBlock = memo(({
COMMAND_PRIORITY_EDITOR, COMMAND_PRIORITY_EDITOR,
), ),
) )
}, [editor, onInsert, onDelete, workflowNodesMap]) }, [editor, onInsert, onDelete, workflowNodesMap, getVarType])
return null return null
}) })

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save