diff --git a/api/.env.example b/api/.env.example index 502461f658..01ddb4adfd 100644 --- a/api/.env.example +++ b/api/.env.example @@ -424,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 MAX_VARIABLE_SIZE=204800 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 diff --git a/api/Dockerfile b/api/Dockerfile index df7f67785b..18ee598b24 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -6,15 +6,16 @@ WORKDIR /app/api # Install uv ENV UV_VERSION=0.6.14 # if you are located in China,you can add the sentence to acceding uv pip tool -# ENV UV_DEFAULT_INDEX="https://mirrors.aliyun.com/pypi/simple/" -# RUN python -m pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade pip -# RUN pip install --no-cache-dir uv==${UV_VERSION} -i https://mirrors.aliyun.com/pypi/simple/ +ENV UV_DEFAULT_INDEX="https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" -RUN pip install --no-cache-dir uv==${UV_VERSION} + +RUN python -m pip install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple --upgrade pip +RUN pip install --no-cache-dir uv==${UV_VERSION} -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple # if you meet some error about apt,you can try to use sources.list config apt images -# RUN rm -rf /etc/apt/source* -# ADD sources.list /etc/apt/ +RUN rm -rf /etc/apt/source* +ADD sources.list /etc/apt/ + FROM base AS packages diff --git a/api/app_factory.py b/api/app_factory.py index 1c886ac5c7..586f2ded9e 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp): ext_otel, ext_proxy_fix, ext_redis, + ext_repositories, ext_sentry, ext_set_secretkey, ext_storage, @@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp): ext_migrate, ext_redis, ext_storage, + ext_repositories, ext_celery, ext_login, ext_mail, diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index d35a74e3ee..f498dccbbc 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -12,7 +12,7 @@ from pydantic import ( ) from pydantic_settings import BaseSettings -from configs.feature.hosted_service import HostedServiceConfig +from .hosted_service import HostedServiceConfig class SecurityConfig(BaseSettings): @@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings): default=100, ) + WORKFLOW_NODE_EXECUTION_STORAGE: str = Field( + default="rdbms", + description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'", + ) + class AuthConfig(BaseSettings): """ diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 8518d34a8e..4046417076 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource): return code_result +class RuleStructuredOutputGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + account = current_user + try: + structured_output = LLMGenerator.generate_structured_output( + tenant_id=account.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + return structured_output + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") +api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index e911c9a5e5..b4bd80fe2f 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource): if not oauth_provider: return {"error": "Invalid provider"}, 400 if "code" in request.args: - code = request.args.get("code") + code = request.args.get("code", "") + if not code: + return {"error": "Invalid code"}, 400 try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index dc0009f36e..d4a33645ab 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -16,7 +16,7 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip @@ -30,6 +30,7 @@ from services.feature_service import FeatureService class ForgotPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") @@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=str, required=True, location="json") @@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource): AccountService.add_forgot_password_error_rate_limit(args["email"]) raise EmailCodeError() + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + AccountService.reset_forgot_password_error_rate_limit(args["email"]) - return {"is_valid": True, "email": token_data.get("email")} + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} class ForgotPasswordResetApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, nullable=False, location="json") @@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource): reset_data = AccountService.get_reset_password_data(args["token"]) if not reset_data: raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() # Revoke token to prevent reuse AccountService.revoke_reset_password_token(args["token"]) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 41362e9fa2..16c1dcc441 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,7 @@ from controllers.console.error import ( EmailSendIpLimitError, NotAllowedCreateWorkspace, ) -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password @@ -38,6 +38,7 @@ class LoginApi(Resource): """Resource for user login.""" @setup_required + @email_password_login_enabled def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -110,6 +111,7 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 6caaae87f4..e5e8038ad7 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -210,3 +210,16 @@ def enterprise_license_required(view): return view(*args, **kwargs) return decorated + + +def email_password_login_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_email_password_login: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 494b357d46..17e9a3990f 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -46,6 +46,7 @@ class MessageListApi(WebApiResource): "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "created_at": TimestampField, "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "metadata": fields.Raw(attribute="message_metadata_dict"), "status": fields.String, "error": fields.String, } diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 66f2c754bb..3bf6c330db 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event + event=event ) node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - session=session, event=event - ) + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event + ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_finish_resp: yield node_finish_resp diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 14441ada40..1f998edb6a 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline: session=session, workflow_run_id=self._workflow_run_id ) workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - session=session, workflow_run=workflow_run, event=event + workflow_run=workflow_run, event=event ) node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - session=session, event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline: if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( - session=session, event=event - ) - node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + event=event + ) + node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_success_response: yield node_success_response @@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( - session=session, - event=event, - ) - node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( - session=session, - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_failed_response: yield node_failed_response @@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline: workflow_app_log.created_by = self._user_id session.add(workflow_app_log) + session.commit() def _text_chunk_to_stream_response( self, text: str, from_variable_selector: Optional[list[str]] = None diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 4d629ca186..5ce9f737d1 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast from uuid import uuid4 from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( @@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repository import RepositoryFactory from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.workflow_entry import WorkflowEntry +from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.model import EndUser @@ -80,6 +82,21 @@ class WorkflowCycleManage: self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables + # Initialize the session factory and repository + # We use the global db engine instead of the session passed to methods + # Disable expire_on_commit to avoid the need for merging objects + self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": self._application_generate_entity.app_config.tenant_id, + "app_id": self._application_generate_entity.app_config.app_id, + "session_factory": self._session_factory, + } + ) + + # We'll still keep the cache for backward compatibility and performance + # but use the repository for database operations + def _handle_workflow_run_start( self, *, @@ -254,19 +271,15 @@ class WorkflowCycleManage: workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.exceptions_count = exceptions_count - stmt = select(WorkflowNodeExecution.node_execution_id).where( - WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, - WorkflowNodeExecution.app_id == workflow_run.app_id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == workflow_run.id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value, + # Use the instance repository to find running executions for a workflow run + running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + workflow_run_id=workflow_run.id ) - ids = session.scalars(stmt).all() - # Use self._get_workflow_node_execution here to make sure the cache is updated - running_workflow_node_executions = [ - self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id - ] + + # Update the cache with the retrieved executions + for execution in running_workflow_node_executions: + if execution.node_execution_id: + self._workflow_node_executions[execution.node_execution_id] = execution for workflow_node_execution in running_workflow_node_executions: now = datetime.now(UTC).replace(tzinfo=None) @@ -288,7 +301,7 @@ class WorkflowCycleManage: return workflow_run def _handle_node_execution_start( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent ) -> WorkflowNodeExecution: workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.id = str(uuid4()) @@ -315,17 +328,14 @@ class WorkflowCycleManage: ) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - def _handle_workflow_node_execution_success( - self, *, session: Session, event: QueueNodeSucceededEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) @@ -344,13 +354,13 @@ class WorkflowCycleManage: workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution = session.merge(workflow_node_execution) + # Use the instance repository to update the workflow node execution + self._workflow_node_execution_repository.update(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_failed( self, *, - session: Session, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent @@ -361,9 +371,7 @@ class WorkflowCycleManage: :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution( - session=session, node_execution_id=event.node_execution_id - ) + workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) @@ -387,14 +395,14 @@ class WorkflowCycleManage: workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution = session.merge(workflow_node_execution) return workflow_node_execution def _handle_workflow_node_execution_retried( - self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed + :param workflow_run: workflow run :param event: queue node failed event :return: """ @@ -439,15 +447,12 @@ class WorkflowCycleManage: workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.index = event.node_run_index - session.add(workflow_node_execution) + # Use the instance repository to save the workflow node execution + self._workflow_node_execution_repository.save(workflow_node_execution) self._workflow_node_executions[event.node_execution_id] = workflow_node_execution return workflow_node_execution - ################################################# - # to stream responses # - ################################################# - def _workflow_start_to_stream_response( self, *, @@ -455,7 +460,6 @@ class WorkflowCycleManage: task_id: str, workflow_run: WorkflowRun, ) -> WorkflowStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return WorkflowStartStreamResponse( task_id=task_id, @@ -521,14 +525,10 @@ class WorkflowCycleManage: def _workflow_node_start_to_stream_response( self, *, - session: Session, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeStartStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -571,7 +571,6 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, *, - session: Session, event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent @@ -580,8 +579,6 @@ class WorkflowCycleManage: task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -621,13 +618,10 @@ class WorkflowCycleManage: def _workflow_node_retry_to_stream_response( self, *, - session: Session, event: QueueNodeRetryEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this - _ = session if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None if not workflow_node_execution.workflow_run_id: @@ -668,7 +662,6 @@ class WorkflowCycleManage: def _workflow_parallel_branch_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent ) -> ParallelBranchStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchStartStreamResponse( task_id=task_id, @@ -692,7 +685,6 @@ class WorkflowCycleManage: workflow_run: WorkflowRun, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, ) -> ParallelBranchFinishedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return ParallelBranchFinishedStreamResponse( task_id=task_id, @@ -713,7 +705,6 @@ class WorkflowCycleManage: def _workflow_iteration_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent ) -> IterationNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeStartStreamResponse( task_id=task_id, @@ -735,7 +726,6 @@ class WorkflowCycleManage: def _workflow_iteration_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent ) -> IterationNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeNextStreamResponse( task_id=task_id, @@ -759,7 +749,6 @@ class WorkflowCycleManage: def _workflow_iteration_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent ) -> IterationNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return IterationNodeCompletedStreamResponse( task_id=task_id, @@ -790,7 +779,6 @@ class WorkflowCycleManage: def _workflow_loop_start_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent ) -> LoopNodeStartStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeStartStreamResponse( task_id=task_id, @@ -812,7 +800,6 @@ class WorkflowCycleManage: def _workflow_loop_next_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent ) -> LoopNodeNextStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeNextStreamResponse( task_id=task_id, @@ -836,7 +823,6 @@ class WorkflowCycleManage: def _workflow_loop_completed_to_stream_response( self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent ) -> LoopNodeCompletedStreamResponse: - # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this _ = session return LoopNodeCompletedStreamResponse( task_id=task_id, @@ -934,11 +920,22 @@ class WorkflowCycleManage: return workflow_run - def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: - if node_execution_id not in self._workflow_node_executions: + def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: + # First check the cache for performance + if node_execution_id in self._workflow_node_executions: + cached_execution = self._workflow_node_executions[node_execution_id] + # No need to merge with session since expire_on_commit=False + return cached_execution + + # If not in cache, use the instance repository to get by node_execution_id + execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) + + if not execution: raise ValueError(f"Workflow node execution not found: {node_execution_id}") - cached_workflow_node_execution = self._workflow_node_executions[node_execution_id] - return session.merge(cached_workflow_node_execution) + + # Update cache + self._workflow_node_executions[node_execution_id] = execution + return execution def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 64c734f626..56859df7f4 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -6,7 +6,6 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument -from models.model import DatasetRetrieverResource class DatasetIndexToolCallbackHandler: @@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler: def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" - if resource and len(resource) > 0: - for item in resource: - dataset_retriever_resource = DatasetRetrieverResource( - message_id=self._message_id, - position=item.get("position") or 0, - dataset_id=item.get("dataset_id"), - dataset_name=item.get("dataset_name"), - document_id=item.get("document_id"), - document_name=item.get("document_name"), - data_source_type=item.get("data_source_type"), - segment_id=item.get("segment_id"), - score=item.get("score") if "score" in item else None, - hit_count=item.get("hit_count") if "hit_count" in item else None, - word_count=item.get("word_count") if "word_count" in item else None, - segment_position=item.get("segment_position") if "segment_position" in item else None, - index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None, - content=item.get("content"), - retriever_from=item.get("retriever_from"), - created_by=self._user_id, - ) - db.session.add(dataset_retriever_resource) - db.session.commit() - self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 75687f9ae3..d5d2ca60fa 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from core.llm_generator.prompts import ( GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -340,3 +341,37 @@ class LLMGenerator: answer = cast(str, response.message.content) return answer.strip() + + @classmethod + def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + prompt_messages = [ + SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), + UserPromptMessage(content=instruction), + ] + model_parameters = model_config.get("model_parameters", {}) + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_json_schema = cast(str, response.message.content) + return {"output": generated_json_schema, "error": ""} + + except InvokeError as e: + error = str(e) + return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} + except Exception as e: + logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}") + return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index cf20e60c82..82d22d7f89 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}} You just need to generate the output """ # noqa: E501 + +SYSTEM_STRUCTURED_OUTPUT_GENERATE = """ +Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements. + +## Instructions: + +1. Analyze the user's description of their data needs +2. Identify each property that should be included in the schema +3. Determine the appropriate data type for each property +4. Decide which properties should be required +5. Generate a complete JSON Schema with proper syntax +6. Include appropriate constraints when specified (min/max values, patterns, formats) +7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting. +8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly. + +## Examples: + +### Example 1: +**User Input:** I need name and age +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" } + }, + "required": ["name", "age"] +} + +### Example 2: +**User Input:** I want to store information about books including title, author, publication year and optional page count +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "title": { "type": "string" }, + "author": { "type": "string" }, + "publicationYear": { "type": "integer" }, + "pageCount": { "type": "integer" } + }, + "required": ["title", "author", "publicationYear"] +} + +### Example 3: +**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18) +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email" + }, + "password": { + "type": "string", + "minLength": 8 + }, + "age": { + "type": "integer", + "minimum": 18 + } + }, + "required": ["email", "password", "age"] +} + +### Example 4: +**User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist. +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "properties": { + "songs": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "id": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "aritst": { + "type": "string" + } + }, + "required": [ + "name", + "id", + "duration", + "aritst" + ] + } + } + } + }, + "required": [ + "songs" + ] +} + +Now, generate a JSON Schema based on my description +""" # noqa: E501 diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index 977678b893..3bed2460dd 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from enum import Enum, StrEnum -from typing import Optional +from typing import Any, Optional, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_serializer, field_validator class PromptMessageRole(Enum): @@ -135,6 +135,16 @@ class PromptMessage(BaseModel): """ return not self.content + @field_serializer("content") + def serialize_content( + self, content: Optional[Union[str, Sequence[PromptMessageContent]]] + ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]: + if content is None or isinstance(content, str): + return content + if isinstance(content, list): + return [item.model_dump() if hasattr(item, "model_dump") else item for item in content] + return content + class UserPromptMessage(PromptMessage): """ diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3225f03fbd..373ef2bbe2 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum, StrEnum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject @@ -85,6 +85,7 @@ class ModelFeature(Enum): DOCUMENT = "document" VIDEO = "video" AUDIO = "audio" + STRUCTURED_OUTPUT = "structured-output" class DefaultParameterName(StrEnum): @@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel): parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None + @model_validator(mode="after") + def validate_model(self): + supported_schema_keys = ["json_schema"] + schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) + if not schema_key: + return self + if self.features is None: + self.features = [ModelFeature.STRUCTURED_OUTPUT] + else: + if ModelFeature.STRUCTURED_OUTPUT not in self.features: + self.features.append(ModelFeature.STRUCTURED_OUTPUT) + return self + class ModelUsage(BaseModel): pass diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 53de16d621..1b799131e7 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -1,5 +1,6 @@ import logging import time +import uuid from collections.abc import Generator, Sequence from typing import Optional, Union @@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager logger = logging.getLogger(__name__) +def _gen_tool_call_id() -> str: + return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" + + +def _increase_tool_call( + new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] +): + """ + Merge incremental tool call updates into existing tool calls. + + :param new_tool_calls: List of new tool call deltas to be merged. + :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. + """ + + def get_tool_call(tool_call_id: str): + """ + Get or create a tool call by ID + + :param tool_call_id: tool call ID + :return: existing or new tool call + """ + if not tool_call_id: + return existing_tools_calls[-1] + + _tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None) + if _tool_call is None: + _tool_call = AssistantPromptMessage.ToolCall( + id=tool_call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), + ) + existing_tools_calls.append(_tool_call) + + return _tool_call + + for new_tool_call in new_tool_calls: + # generate ID for tool calls with function name but no ID to track them + if new_tool_call.function.name and not new_tool_call.id: + new_tool_call.id = _gen_tool_call_id() + # get tool call + tool_call = get_tool_call(new_tool_call.id) + # update tool call + if new_tool_call.id: + tool_call.id = new_tool_call.id + if new_tool_call.type: + tool_call.type = new_tool_call.type + if new_tool_call.function.name: + tool_call.function.name = new_tool_call.function.name + if new_tool_call.function.arguments: + tool_call.function.arguments += new_tool_call.function.arguments + + class LargeLanguageModel(AIModel): """ Model class for large language model. @@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel): system_fingerprint = None tools_calls: list[AssistantPromptMessage.ToolCall] = [] - def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]): - def get_tool_call(tool_name: str): - if not tool_name: - return tools_calls[-1] - - tool_call = next( - (tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None - ) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id="", - type="", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""), - ) - tools_calls.append(tool_call) - - return tool_call - - for new_tool_call in new_tool_calls: - # get tool call - tool_call = get_tool_call(new_tool_call.function.name) - # update tool call - if new_tool_call.id: - tool_call.id = new_tool_call.id - if new_tool_call.type: - tool_call.type = new_tool_call.type - if new_tool_call.function.name: - tool_call.function.name = new_tool_call.function.name - if new_tool_call.function.arguments: - tool_call.function.arguments += new_tool_call.function.arguments - for chunk in result: if isinstance(chunk.delta.message.content, str): content += chunk.delta.message.content elif isinstance(chunk.delta.message.content, list): content_list.extend(chunk.delta.message.content) if chunk.delta.message.tool_calls: - increase_tool_call(chunk.delta.message.tool_calls) + _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) usage = chunk.delta.usage or LLMUsage.empty_usage() system_fingerprint = chunk.system_fingerprint diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index f67e270ab1..fa78b7b8e9 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -5,6 +5,7 @@ from datetime import datetime, timedelta from typing import Optional from langfuse import Langfuse # type: ignore +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance): ) self.add_trace(langfuse_trace_data=trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory}, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index e3494e2f23..85a0eafdc1 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -7,6 +7,7 @@ from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance): self.add_run(langsmith_run) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fabf38fbd6..923b9a24ed 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -7,6 +7,7 @@ from typing import Optional, cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import OpikConfig @@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) +from core.repository.repository_factory import RepositoryFactory from extensions.ext_database import db from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution logger = logging.getLogger(__name__) @@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance): } self.add_trace(trace_data) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": trace_info.tenant_id, + "app_id": trace_info.metadata.get("app_id"), + "session_factory": session_factory, + }, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id tenant_id = node_execution.tenant_id app_id = node_execution.app_id diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index f402da030f..db07e52f3f 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = ParameterExtractorNodeData( @@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = QuestionClassifierNodeData( diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 099acfd7f4..7570200175 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -124,6 +124,15 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) + # Ensure that both the original provider name and its ModelProviderID string representation + # are present in the dictionary to handle cases where either form might be used + for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()): + provider_id = ModelProviderID(provider_name) + if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict: + # Add the ModelProviderID string representation if it's not already present + provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = ( + provider_name_to_preferred_model_provider_records_dict[provider_name] + ) # Get All provider model settings provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) @@ -497,8 +506,8 @@ class ProviderManager: @staticmethod def _init_trial_provider_records( - tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] - ) -> dict[str, list]: + tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]] + ) -> dict[str, list[Provider]]: """ Initialize trial provider records if not exists. @@ -532,7 +541,7 @@ class ProviderManager: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: try: # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic - provider_record = Provider( + new_provider_record = Provider( tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, @@ -542,11 +551,12 @@ class ProviderManager: quota_used=0, is_valid=True, ) - db.session.add(provider_record) + db.session.add(new_provider_record) db.session.commit() + provider_name_to_provider_records_dict[provider_name].append(new_provider_record) except IntegrityError: db.session.rollback() - provider_record = ( + existed_provider_record = ( db.session.query(Provider) .filter( Provider.tenant_id == tenant_id, @@ -556,11 +566,14 @@ class ProviderManager: ) .first() ) - if provider_record and not provider_record.is_valid: - provider_record.is_valid = True + if not existed_provider_record: + continue + + if not existed_provider_record.is_valid: + existed_provider_record.is_valid = True db.session.commit() - provider_name_to_provider_records_dict[provider_name].append(provider_record) + provider_name_to_provider_records_dict[provider_name].append(existed_provider_record) return provider_name_to_provider_records_dict diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py index 778e8a07d8..c1792943bb 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py @@ -246,7 +246,7 @@ class AnalyticdbVectorBySql: ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score FROM {self.table_name} WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} - ORDER BY (score,id) DESC + ORDER BY score DESC, id DESC LIMIT {top_k}""", (f"'{query}'", f"'{query}'"), ) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 4af2578197..63695e6f3f 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -2,12 +2,12 @@ import array import json import re import uuid -from contextlib import contextmanager from typing import Any import jieba.posseg as pseg # type: ignore import numpy import oracledb +from oracledb.connection import Connection from pydantic import BaseModel, model_validator from configs import dify_config @@ -70,6 +70,7 @@ class OracleVector(BaseVector): super().__init__(collection_name) self.pool = self._create_connection_pool(config) self.table_name = f"embedding_{collection_name}" + self.config = config def get_type(self) -> str: return VectorType.ORACLE @@ -107,16 +108,19 @@ class OracleVector(BaseVector): outconverter=self.numpy_converter_out, ) + def _get_connection(self) -> Connection: + connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn) + return connection + def _create_connection_pool(self, config: OracleVectorConfig): pool_params = { "user": config.user, "password": config.password, "dsn": config.dsn, "min": 1, - "max": 50, + "max": 5, "increment": 1, } - if config.is_autonomous: pool_params.update( { @@ -125,22 +129,8 @@ class OracleVector(BaseVector): "wallet_password": config.wallet_password, } ) - return oracledb.create_pool(**pool_params) - @contextmanager - def _get_cursor(self): - conn = self.pool.acquire() - conn.inputtypehandler = self.input_type_handler - conn.outputtypehandler = self.output_type_handler - cur = conn.cursor() - try: - yield cur - finally: - cur.close() - conn.commit() - conn.close() - def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): dimension = len(embeddings[0]) self._create_collection(dimension) @@ -162,41 +152,68 @@ class OracleVector(BaseVector): numpy.array(embeddings[i]), ) ) - # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)") - with self._get_cursor() as cur: - cur.executemany( - f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values - ) + with self._get_connection() as conn: + conn.inputtypehandler = self.input_type_handler + conn.outputtypehandler = self.output_type_handler + # with conn.cursor() as cur: + # cur.executemany( + # f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values + # ) + # conn.commit() + for value in values: + with conn.cursor() as cur: + try: + cur.execute( + f"""INSERT INTO {self.table_name} (id, text, meta, embedding) + VALUES (:1, :2, :3, :4)""", + value, + ) + conn.commit() + except Exception as e: + print(e) + conn.close() return pks def text_exists(self, id: str) -> bool: - with self._get_cursor() as cur: - cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) - return cur.fetchone() is not None + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,)) + return cur.fetchone() is not None + conn.close() def get_by_ids(self, ids: list[str]) -> list[Document]: - with self._get_cursor() as cur: - cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) - docs = [] - for record in cur: - docs.append(Document(page_content=record[1], metadata=record[0])) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),)) + docs = [] + for record in cur: + docs.append(Document(page_content=record[1], metadata=record[0])) + self.pool.release(connection=conn) + conn.close() return docs def delete_by_ids(self, ids: list[str]) -> None: if not ids: return - with self._get_cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),)) + conn.commit() + conn.close() def delete_by_metadata_field(self, key: str, value: str) -> None: - with self._get_cursor() as cur: - cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value)) + conn.commit() + conn.close() def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """ Search the nearest neighbors to a vector. :param query_vector: The input vector to search for similar items. + :param top_k: The number of nearest neighbors to return, default is 5. :return: List of Documents that are nearest to the query vector. """ top_k = kwargs.get("top_k", 4) @@ -205,20 +222,25 @@ class OracleVector(BaseVector): if document_ids_filter: document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) where_clause = f"WHERE metadata->>'document_id' in ({document_ids})" - with self._get_cursor() as cur: - cur.execute( - f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}" - f" {where_clause} ORDER BY distance fetch first {top_k} rows only", - [numpy.array(query_vector)], - ) - docs = [] - score_threshold = float(kwargs.get("score_threshold") or 0.0) - for record in cur: - metadata, text, distance = record - score = 1 - distance - metadata["score"] = score - if score > score_threshold: - docs.append(Document(page_content=text, metadata=metadata)) + with self._get_connection() as conn: + conn.inputtypehandler = self.input_type_handler + conn.outputtypehandler = self.output_type_handler + with conn.cursor() as cur: + cur.execute( + f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine) + AS distance FROM {self.table_name} + {where_clause} ORDER BY distance fetch first {top_k} rows only""", + [numpy.array(query_vector)], + ) + docs = [] + score_threshold = float(kwargs.get("score_threshold") or 0.0) + for record in cur: + metadata, text, distance = record + score = 1 - distance + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) + conn.close() return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -228,7 +250,7 @@ class OracleVector(BaseVector): top_k = kwargs.get("top_k", 5) # just not implement fetch by score_threshold now, may be later - # score_threshold = float(kwargs.get("score_threshold") or 0.0) + score_threshold = float(kwargs.get("score_threshold") or 0.0) if len(query) > 0: # Check which language the query is in zh_pattern = re.compile("[\u4e00-\u9fa5]+") @@ -239,7 +261,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: @@ -260,30 +282,35 @@ class OracleVector(BaseVector): for token in all_tokens: if token not in stop_words: entities.append(token) - with self._get_cursor() as cur: - document_ids_filter = kwargs.get("document_ids_filter") - where_clause = "" - if document_ids_filter: - document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) - where_clause = f" AND metadata->>'document_id' in ({document_ids}) " - cur.execute( - f"select meta, text, embedding FROM {self.table_name}" - f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} " - f"order by score(1) desc fetch first {top_k} rows only", - [" ACCUM ".join(entities)], - ) - docs = [] - for record in cur: - metadata, text, embedding = record - docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + with self._get_connection() as conn: + with conn.cursor() as cur: + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause = f" AND metadata->>'document_id' in ({document_ids}) " + cur.execute( + f"""select meta, text, embedding FROM {self.table_name} + WHERE CONTAINS(text, :kk, 1) > 0 {where_clause} + order by score(1) desc fetch first {top_k} rows only""", + kk=" ACCUM ".join(entities), + ) + docs = [] + for record in cur: + metadata, text, embedding = record + docs.append(Document(page_content=text, vector=embedding, metadata=metadata)) + conn.close() return docs else: return [Document(page_content="", metadata={})] return [] def delete(self) -> None: - with self._get_cursor() as cur: - cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints") + conn.commit() + conn.close() def _create_collection(self, dimension: int): cache_key = f"vector_indexing_{self._collection_name}" @@ -293,11 +320,14 @@ class OracleVector(BaseVector): if redis_client.get(collection_exist_cache_key): return - with self._get_cursor() as cur: - cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name)) - redis_client.set(collection_exist_cache_key, 1, ex=3600) - with self._get_cursor() as cur: - cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) + with self._get_connection() as conn: + with conn.cursor() as cur: + cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name)) + redis_client.set(collection_exist_cache_key, 1, ex=3600) + with conn.cursor() as cur: + cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) + conn.commit() + conn.close() class OracleVectorFactory(AbstractVectorFactory): diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 70c618a631..edaa8c92fa 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = ( - f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" - ) + image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" return image_map diff --git a/api/core/repository/__init__.py b/api/core/repository/__init__.py new file mode 100644 index 0000000000..253df1251d --- /dev/null +++ b/api/core/repository/__init__.py @@ -0,0 +1,15 @@ +""" +Repository interfaces for data access. + +This package contains repository interfaces that define the contract +for accessing and manipulating data, regardless of the underlying +storage mechanism. +""" + +from core.repository.repository_factory import RepositoryFactory +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +__all__ = [ + "RepositoryFactory", + "WorkflowNodeExecutionRepository", +] diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py new file mode 100644 index 0000000000..7da7e49055 --- /dev/null +++ b/api/core/repository/repository_factory.py @@ -0,0 +1,97 @@ +""" +Repository factory for creating repository instances. + +This module provides a simple factory interface for creating repository instances. +It does not contain any implementation details or dependencies on specific repositories. +""" + +from collections.abc import Callable, Mapping +from typing import Any, Literal, Optional, cast + +from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository + +# Type for factory functions - takes a dict of parameters and returns any repository type +RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any] + +# Type for workflow node execution factory function +WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository] + +# Repository type literals +_RepositoryType = Literal["workflow_node_execution"] + + +class RepositoryFactory: + """ + Factory class for creating repository instances. + + This factory delegates the actual repository creation to implementation-specific + factory functions that are registered with the factory at runtime. + """ + + # Dictionary to store factory functions + _factory_functions: dict[str, RepositoryFactoryFunc] = {} + + @classmethod + def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None: + """ + Register a factory function for a specific repository type. + This is a private method and should not be called directly. + + Args: + repository_type: The type of repository (e.g., 'workflow_node_execution') + factory_func: A function that takes parameters and returns a repository instance + """ + cls._factory_functions[repository_type] = factory_func + + @classmethod + def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any: + """ + Create a new repository instance with the provided parameters. + This is a private method and should not be called directly. + + Args: + repository_type: The type of repository to create + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the requested repository + + Raises: + ValueError: If no factory function is registered for the repository type + """ + if repository_type not in cls._factory_functions: + raise ValueError(f"No factory function registered for repository type '{repository_type}'") + + # Use empty dict if params is None + params = params or {} + + return cls._factory_functions[repository_type](params) + + @classmethod + def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None: + """ + Register a factory function for the workflow node execution repository. + + Args: + factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance + """ + cls._register_factory("workflow_node_execution", factory_func) + + @classmethod + def create_workflow_node_execution_repository( + cls, params: Optional[Mapping[str, Any]] = None + ) -> WorkflowNodeExecutionRepository: + """ + Create a new WorkflowNodeExecutionRepository instance with the provided parameters. + + Args: + params: A dictionary of parameters to pass to the factory function + + Returns: + A new instance of the WorkflowNodeExecutionRepository + + Raises: + ValueError: If no factory function is registered for the workflow_node_execution repository type + """ + # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc + return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params)) diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py new file mode 100644 index 0000000000..9bb790cb0f --- /dev/null +++ b/api/core/repository/workflow_node_execution_repository.py @@ -0,0 +1,97 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Optional, Protocol + +from models.workflow import WorkflowNodeExecution + + +@dataclass +class OrderConfig: + """Configuration for ordering WorkflowNodeExecution instances.""" + + order_by: list[str] + order_direction: Optional[Literal["asc", "desc"]] = None + + +class WorkflowNodeExecutionRepository(Protocol): + """ + Repository interface for WorkflowNodeExecution. + + This interface defines the contract for accessing and manipulating + WorkflowNodeExecution data, regardless of the underlying storage mechanism. + + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and trigger sources (triggered_from) should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to save + """ + ... + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + ... + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + ... + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + ... + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance. + + Args: + execution: The WorkflowNodeExecution instance to update + """ + ... + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records based on implementation-specific criteria. + + This method is intended to be used for bulk deletion operations, such as removing + all records associated with a specific app_id and tenant_id in multi-tenant implementations. + """ + ... diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f661294ec4..f5838c3b76 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -94,7 +94,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): "title": item.metadata.get("title"), "content": item.page_content, } - context_list.append(source) + context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7c8960fe49..da40cbcdea 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -16,7 +16,7 @@ from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated +from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -251,7 +251,12 @@ class AgentNode(ToolNode): prompt_message.model_dump(mode="json") for prompt_message in prompt_messages ] value["history_prompt_messages"] = history_prompt_messages - value["entity"] = model_schema.model_dump(mode="json") if model_schema else None + if model_schema: + # remove structured output feature to support old version agent plugin + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None result[parameter_name] = value return result @@ -348,3 +353,10 @@ class AgentNode(ToolNode): ) model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_instance, model_schema + + def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features: + if feature.value not in AgentOldVersionModelFeatures: + model_schema.features.remove(feature) + return model_schema diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 87cc7e9824..77e94375bf 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData): class ParamsAutoGenerated(Enum): CLOSE = 0 OPEN = 1 + + +class AgentOldVersionModelFeatures(Enum): + """ + Enum class for old SDK version llm feature. + """ + + TOOL_CALL = "tool-call" + MULTI_TOOL_CALL = "multi-tool-call" + AGENT_THOUGHT = "agent-thought" + VISION = "vision" + STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bf54fdb80c..486b4b01af 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + structured_output: dict | None = None + structured_output_enabled: bool = False @field_validator("prompt_config", mode="before") @classmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fe0ed3e564..8db7394e54 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast +import json_repair + from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus @@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ModelProviderID @@ -57,6 +65,12 @@ from core.workflow.nodes.event import ( RunRetrieverResourceEvent, RunStreamChunkEvent, ) +from core.workflow.utils.structured_output.entities import ( + ResponseFormat, + SpecialModelType, + SupportStructuredOutputStatus, +) +from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation @@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]): _node_type = NodeType.LLM def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: + """Process structured output if enabled""" + if not self.node_data.structured_output_enabled or not self.node_data.structured_output: + return None + return self._parse_structured_output(text) + node_inputs: Optional[dict[str, Any]] = None process_data = None result_text = "" @@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(event, RunRetrieverResourceEvent): context = event.context yield event - if context: node_inputs["#context#"] = context @@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]): self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} - + structured_output = process_structured_output(result_text) + if structured_output: + outputs["structured_output"] = structured_output yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ModelNotExistError(f"Model {model_name} not exist.") - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: + completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) + elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + # Set appropriate response format based on model capabilities + self._set_response_format(completion_params, model_schema.parameter_rules) return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, model=model_name, @@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) stop = model_config.stop return filtered_prompt_messages, stop + def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]: + structured_output: dict[str, Any] | list[Any] = {} + try: + parsed = json.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + except json.JSONDecodeError as e: + # if the result_text is not a valid json, try to repair it + parsed = json_repair.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + return structured_output + @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages + def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: + """ + Handle structured output for models with native JSON schema support. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + :return: Updated model parameters with JSON schema configuration + """ + # Process schema according to model requirements + schema = self._fetch_structured_output_schema() + schema_json = self._prepare_schema_for_model(schema) + + # Set JSON schema in parameters + model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) + + # Set appropriate response format if required by the model + for rule in rules: + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + + return model_parameters + + def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]: + """ + Handle structured output for models without native JSON schema support. + This function modifies the prompt messages to include schema-based output requirements. + + Args: + prompt_messages: Original sequence of prompt messages + + Returns: + list[PromptMessage]: Updated prompt messages with structured output requirements + """ + # Convert schema to string format + schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False) + + # Find existing system prompt with schema placeholder + system_prompt = next( + (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), + None, + ) + structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str) + # Prepare system prompt content + system_prompt_content = ( + structured_output_prompt + "\n\n" + system_prompt.content + if system_prompt and isinstance(system_prompt.content, str) + else structured_output_prompt + ) + system_prompt = SystemPromptMessage(content=system_prompt_content) + + # Extract content from the last user message + + filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] + updated_prompt = [system_prompt] + filtered_prompts + + return updated_prompt + + def _set_response_format(self, model_parameters: dict, rules: list) -> None: + """ + Set the appropriate response format parameter based on model rules. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + """ + for rule in rules: + if rule.name == "response_format": + if ResponseFormat.JSON.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON.value + elif ResponseFormat.JSON_OBJECT.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + + def _prepare_schema_for_model(self, schema: dict) -> dict: + """ + Prepare JSON schema based on model requirements. + + Different models have different requirements for JSON schema formatting. + This function handles these differences. + + :param schema: The original JSON schema + :return: Processed schema compatible with the current model + """ + + # Deep copy to avoid modifying the original schema + processed_schema = schema.copy() + + # Convert boolean types to string types (common requirement) + convert_boolean_to_string(processed_schema) + + # Apply model-specific transformations + if SpecialModelType.GEMINI in self.node_data.model.name: + remove_additional_properties(processed_schema) + return processed_schema + elif SpecialModelType.OLLAMA in self.node_data.model.provider: + return processed_schema + else: + # Default format with name field + return {"schema": processed_schema, "name": "llm_response"} + + def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: + """ + Fetch model schema + """ + model_name = self.node_data.model.name + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name + ) + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_credentials = model_instance.credentials + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_schema + + def _fetch_structured_output_schema(self) -> dict[str, Any]: + """ + Fetch the structured output schema from the node data. + + Returns: + dict[str, Any]: The structured output schema + """ + if not self.node_data.structured_output: + raise LLMNodeError("Please provide a valid structured output schema") + structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + if not structured_output_schema: + raise LLMNodeError("Please provide a valid structured output schema") + + try: + schema = json.loads(structured_output_schema) + if not isinstance(schema, dict): + raise LLMNodeError("structured_output_schema must be a JSON object") + return schema + except json.JSONDecodeError: + raise LLMNodeError("structured_output_schema is not valid JSON format") + + def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: + """ + Check if the current model supports structured output. + + Returns: + SupportStructuredOutput: The support status of structured output + """ + # Early return if structured output is disabled + if ( + not isinstance(self.node_data, LLMNodeData) + or not self.node_data.structured_output_enabled + or not self.node_data.structured_output + ): + return SupportStructuredOutputStatus.DISABLED + # Get model schema and check if it exists + model_schema = self._fetch_model_schema(self.node_data.model.provider) + if not model_schema: + return SupportStructuredOutputStatus.DISABLED + + # Check if model supports structured output feature + return ( + SupportStructuredOutputStatus.SUPPORTED + if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) + else SupportStructuredOutputStatus.UNSUPPORTED + ) + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -1064,3 +1269,49 @@ def _handle_completion_template( ) prompt_messages.append(prompt_message) return prompt_messages + + +def remove_additional_properties(schema: dict) -> None: + """ + Remove additionalProperties fields from JSON schema. + Used for models like Gemini that don't support this property. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Remove additionalProperties at current level + schema.pop("additionalProperties", None) + + # Process nested structures recursively + for value in schema.values(): + if isinstance(value, dict): + remove_additional_properties(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) + + +def convert_boolean_to_string(schema: dict) -> None: + """ + Convert boolean type specifications to string in JSON schema. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Check for boolean type at current level + if schema.get("type") == "boolean": + schema["type"] = "string" + + # Process nested dictionaries and lists recursively + for value in schema.values(): + if isinstance(value, dict): + convert_boolean_to_string(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py new file mode 100644 index 0000000000..7954acbaee --- /dev/null +++ b/api/core/workflow/utils/structured_output/entities.py @@ -0,0 +1,24 @@ +from enum import StrEnum + + +class ResponseFormat(StrEnum): + """Constants for model response formats""" + + JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode. + JSON = "JSON" # model's json mode. some model like claude support this mode. + JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias. + + +class SpecialModelType(StrEnum): + """Constants for identifying model types""" + + GEMINI = "gemini" + OLLAMA = "ollama" + + +class SupportStructuredOutputStatus(StrEnum): + """Constants for structured output support status""" + + SUPPORTED = "supported" + UNSUPPORTED = "unsupported" + DISABLED = "disabled" diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py new file mode 100644 index 0000000000..06d9b2056e --- /dev/null +++ b/api/core/workflow/utils/structured_output/prompt.py @@ -0,0 +1,17 @@ +STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format. +constraints: + - You must output in JSON format. + - Do not output boolean value, use string type instead. + - Do not output integer or float value, use number type instead. +eg: + Here is the JSON schema: + {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} + + Here is the user's question: + My name is John Doe and I am 30 years old. + + output: + {"name": "John Doe", "age": 30} +Here is the JSON schema: +{{schema}} +""" # noqa: E501 diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index 422ec87765..aa55862b7c 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -26,9 +26,12 @@ def init_app(app: DifyApp): # Always add StreamHandler to log to console sh = logging.StreamHandler(sys.stdout) - sh.addFilter(RequestIdFilter()) log_handlers.append(sh) + # Apply RequestIdFilter to all handlers + for handler in log_handlers: + handler.addFilter(RequestIdFilter()) + logging.basicConfig( level=dify_config.LOG_LEVEL, format=dify_config.LOG_FORMAT, diff --git a/api/extensions/ext_repositories.py b/api/extensions/ext_repositories.py new file mode 100644 index 0000000000..27d8408ec1 --- /dev/null +++ b/api/extensions/ext_repositories.py @@ -0,0 +1,18 @@ +""" +Extension for initializing repositories. + +This extension registers repository implementations with the RepositoryFactory. +""" + +from dify_app import DifyApp +from repositories.repository_registry import register_repositories + + +def init_app(_app: DifyApp) -> None: + """ + Initialize repository implementations. + + Args: + _app: The Flask application instance (unused) + """ + register_repositories() diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 588bdb2d27..4c811c66ba 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -73,11 +73,7 @@ class Storage: raise ValueError(f"unsupported storage type {storage_type}") def save(self, filename, data): - try: - self.storage_runner.save(filename, data) - except Exception as e: - logger.exception(f"Failed to save file {filename}") - raise e + self.storage_runner.save(filename, data) @overload def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ... @@ -86,49 +82,25 @@ class Storage: def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ... def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]: - try: - if stream: - return self.load_stream(filename) - else: - return self.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load file {filename}") - raise e + if stream: + return self.load_stream(filename) + else: + return self.load_once(filename) def load_once(self, filename: str) -> bytes: - try: - return self.storage_runner.load_once(filename) - except Exception as e: - logger.exception(f"Failed to load_once file {filename}") - raise e + return self.storage_runner.load_once(filename) def load_stream(self, filename: str) -> Generator: - try: - return self.storage_runner.load_stream(filename) - except Exception as e: - logger.exception(f"Failed to load_stream file {filename}") - raise e + return self.storage_runner.load_stream(filename) def download(self, filename, target_filepath): - try: - self.storage_runner.download(filename, target_filepath) - except Exception as e: - logger.exception(f"Failed to download file {filename}") - raise e + self.storage_runner.download(filename, target_filepath) def exists(self, filename): - try: - return self.storage_runner.exists(filename) - except Exception as e: - logger.exception(f"Failed to check file exists {filename}") - raise e + return self.storage_runner.exists(filename) def delete(self, filename): - try: - return self.storage_runner.delete(filename) - except Exception as e: - logger.exception(f"Failed to delete file {filename}") - raise e + return self.storage_runner.delete(filename) storage = Storage() diff --git a/api/models/model.py b/api/models/model.py index a826d13e7d..6577492d1b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined] @property def retriever_resources(self): - return ( - db.session.query(DatasetRetrieverResource) - .filter(DatasetRetrieverResource.message_id == self.id) - .order_by(DatasetRetrieverResource.position.asc()) - .all() - ) + return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property def message_files(self): diff --git a/api/models/workflow.py b/api/models/workflow.py index 8b7c376e4b..5a67fa47a8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -245,6 +245,13 @@ class Workflow(Base): @property def tool_published(self) -> bool: + """ + DEPRECATED: This property is not accurate for determining if a workflow is published as a tool. + It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version + is the one being used by the tool. + + For accurate checking, use a direct query with tenant_id, app_id, and version. + """ from models.tools import WorkflowToolProvider return ( @@ -510,7 +517,7 @@ class WorkflowRun(Base): ) -class WorkflowNodeExecutionTriggeredFrom(Enum): +class WorkflowNodeExecutionTriggeredFrom(StrEnum): """ Workflow Node Execution Triggered From Enum """ @@ -518,21 +525,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum): SINGLE_STEP = "single-step" WORKFLOW_RUN = "workflow-run" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution triggered from value {value}") - -class WorkflowNodeExecutionStatus(Enum): +class WorkflowNodeExecutionStatus(StrEnum): """ Workflow Node Execution Status Enum """ @@ -543,19 +537,6 @@ class WorkflowNodeExecutionStatus(Enum): EXCEPTION = "exception" RETRY = "retry" - @classmethod - def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid workflow node execution status value {value}") - class WorkflowNodeExecution(Base): """ @@ -656,6 +637,7 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property @@ -663,6 +645,7 @@ class WorkflowNodeExecution(Base): from models.model import EndUser created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 85679a6359..4992178423 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "gunicorn~=23.0.0", "httpx[socks]~=0.27.0", "jieba==0.42.1", + "json-repair>=0.41.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "mailchimp-transactional~=1.0.50", @@ -163,10 +164,7 @@ storage = [ ############################################################ # [ Tools ] dependency group ############################################################ -tools = [ - "cloudscraper~=1.2.71", - "nltk~=3.9.1", -] +tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] ############################################################ # [ VDB ] dependency group @@ -180,7 +178,7 @@ vdb = [ "couchbase~=4.3.0", "elasticsearch==8.14.0", "opensearch-py==2.4.0", - "oracledb~=2.2.1", + "oracledb==3.0.0", "pgvecto-rs[sqlalchemy]~=0.2.1", "pgvector==0.2.5", "pymilvus~=2.5.0", diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..4cc339688b --- /dev/null +++ b/api/repositories/__init__.py @@ -0,0 +1,6 @@ +""" +Repository implementations for data access. + +This package contains concrete implementations of the repository interfaces +defined in the core.repository package. +""" diff --git a/api/repositories/repository_registry.py b/api/repositories/repository_registry.py new file mode 100644 index 0000000000..aa0a208d8e --- /dev/null +++ b/api/repositories/repository_registry.py @@ -0,0 +1,87 @@ +""" +Registry for repository implementations. + +This module is responsible for registering factory functions with the repository factory. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repository.repository_factory import RepositoryFactory +from extensions.ext_database import db +from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository + +logger = logging.getLogger(__name__) + +# Storage type constants +STORAGE_TYPE_RDBMS = "rdbms" +STORAGE_TYPE_HYBRID = "hybrid" + + +def register_repositories() -> None: + """ + Register repository factory functions with the RepositoryFactory. + + This function reads configuration settings to determine which repository + implementations to register. + """ + # Configure WorkflowNodeExecutionRepository factory based on configuration + workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE + + # Check storage type and register appropriate implementation + if workflow_node_execution_storage == STORAGE_TYPE_RDBMS: + # Register SQLAlchemy implementation for RDBMS storage + logger.info("Registering WorkflowNodeExecution repository with RDBMS storage") + RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository) + elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID: + # Hybrid storage is not yet implemented + raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented") + else: + # Unknown storage type + raise ValueError( + f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. " + f"Supported types: {STORAGE_TYPE_RDBMS}" + ) + + +def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation. + + This factory function creates a repository for the RDBMS storage type. + + Args: + params: Parameters for creating the repository, including: + - tenant_id: Required. The tenant ID for multi-tenancy. + - app_id: Optional. The application ID for filtering. + - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided, + a new sessionmaker will be created using the global database engine. + + Returns: + A WorkflowNodeExecutionRepository instance + + Raises: + ValueError: If required parameters are missing + """ + # Extract required parameters + tenant_id = params.get("tenant_id") + if tenant_id is None: + raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage") + + # Extract optional parameters + app_id = params.get("app_id") + + # Use the session_factory from params if provided, otherwise create one using the global db engine + session_factory = params.get("session_factory") + if session_factory is None: + # Create a sessionmaker using the same engine as the global db session + session_factory = sessionmaker(bind=db.engine) + + # Create and return the repository + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) diff --git a/api/repositories/workflow_node_execution/__init__.py b/api/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..eed827bd05 --- /dev/null +++ b/api/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,9 @@ +""" +WorkflowNodeExecution repository implementations. +""" + +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + +__all__ = [ + "SQLAlchemyWorkflowNodeExecutionRepository", +] diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py new file mode 100644 index 0000000000..0594d816a2 --- /dev/null +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -0,0 +1,192 @@ +""" +SQLAlchemy implementation of the WorkflowNodeExecutionRepository. +""" + +import logging +from collections.abc import Sequence +from typing import Optional + +from sqlalchemy import UnaryExpression, asc, delete, desc, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class SQLAlchemyWorkflowNodeExecutionRepository: + """ + SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface. + + This implementation supports multi-tenancy by filtering operations based on tenant_id. + Each method creates its own session, handles the transaction, and commits changes + to the database. This prevents long-running connections in the workflow core. + """ + + def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + tenant_id: Tenant ID for multi-tenancy + app_id: Optional app ID for filtering by application + """ + # If an engine is provided, create a sessionmaker from it + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + else: + self._session_factory = session_factory + + self._tenant_id = tenant_id + self._app_id = app_id + + def save(self, execution: WorkflowNodeExecution) -> None: + """ + Save a WorkflowNodeExecution instance and commit changes to the database. + + Args: + execution: The WorkflowNodeExecution instance to save + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.add(execution) + session.commit() + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + """ + Retrieve a WorkflowNodeExecution by its node_execution_id. + + Args: + node_execution_id: The node execution ID + + Returns: + The WorkflowNodeExecution instance if found, None otherwise + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.node_execution_id == node_execution_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalar(stmt) + + def get_by_workflow_run( + self, + workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") + + Returns: + A list of WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + # Apply ordering if provided + if order_config and order_config.order_by: + order_columns: list[UnaryExpression] = [] + for field in order_config.order_by: + column = getattr(WorkflowNodeExecution, field, None) + if not column: + continue + if order_config.order_direction == "desc": + order_columns.append(desc(column)) + else: + order_columns.append(asc(column)) + + if order_columns: + stmt = stmt.order_by(*order_columns) + + return session.scalars(stmt).all() + + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + """ + Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + + Args: + workflow_run_id: The workflow run ID + + Returns: + A list of running WorkflowNodeExecution instances + """ + with self._session_factory() as session: + stmt = select(WorkflowNodeExecution).where( + WorkflowNodeExecution.workflow_run_id == workflow_run_id, + WorkflowNodeExecution.tenant_id == self._tenant_id, + WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + return session.scalars(stmt).all() + + def update(self, execution: WorkflowNodeExecution) -> None: + """ + Update an existing WorkflowNodeExecution instance and commit changes to the database. + + Args: + execution: The WorkflowNodeExecution instance to update + """ + with self._session_factory() as session: + # Ensure tenant_id is set + if not execution.tenant_id: + execution.tenant_id = self._tenant_id + + # Set app_id if provided and not already set + if self._app_id and not execution.app_id: + execution.app_id = self._app_id + + session.merge(execution) + session.commit() + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records for the current tenant_id and app_id. + + This method deletes all WorkflowNodeExecution records that match the tenant_id + and app_id (if provided) associated with this repository instance. + """ + with self._session_factory() as session: + stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + result = session.execute(stmt) + session.commit() + + deleted_count = result.rowcount + logger.info( + f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + + (f" and app {self._app_id}" if self._app_id else "") + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index ada8109067..f930ef910b 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -407,10 +407,8 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data={"code": code} - ) + code, token = cls.generate_reset_password_token(account_email, account) + send_reset_password_mail_task.delay( language=language, to=account_email, @@ -419,6 +417,22 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def generate_reset_password_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py index 778f05a0cd..07e624b4e8 100644 --- a/api/services/plugin/dependencies_analysis.py +++ b/api/services/plugin/dependencies_analysis.py @@ -1,3 +1,4 @@ +from configs import dify_config from core.helper import marketplace from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID from core.plugin.manager.plugin import PluginInstallationManager @@ -111,6 +112,8 @@ class DependenciesAnalysisService: Generate the latest version of dependencies """ dependencies = list(set(dependencies)) + if not dify_config.MARKETPLACE_ENABLED: + return [] deps = marketplace.batch_fetch_plugin_manifests(dependencies) return [ PluginDependency( diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 0ddd18ea27..ff3b33eecd 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,13 +2,14 @@ import threading from typing import Optional import contexts +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, ) @@ -127,17 +128,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ) - .order_by(WorkflowNodeExecution.index.desc()) - .all() + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } ) - return node_executions + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + + return list(node_executions) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 992942fc70..5cd5c55746 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder +from core.repository import RepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError @@ -27,6 +28,7 @@ from extensions.ext_database import db from models.account import Account from models.enums import CreatedByRole from models.model import App, AppMode +from models.tools import WorkflowToolProvider from models.workflow import ( Workflow, WorkflowNodeExecution, @@ -282,8 +284,15 @@ class WorkflowService: workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = draft_workflow.id - db.session.add(workflow_node_execution) - db.session.commit() + # Use the repository to save the workflow node execution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } + ) + repository.save(workflow_node_execution) return workflow_node_execution @@ -515,8 +524,19 @@ class WorkflowService: # Cannot delete a workflow that's currently in use by an app raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") - # Check if this workflow is published as a tool - if workflow.tool_published: + # Don't use workflow.tool_published as it's not accurate for specific workflow versions + # Check if there's a tool provider using this specific workflow version + tool_provider = ( + session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == workflow.tenant_id, + WorkflowToolProvider.app_id == workflow.app_id, + WorkflowToolProvider.version == workflow.version, + ) + .first() + ) + + if tool_provider: # Cannot delete a workflow that's published as a tool raise WorkflowInUseError("Cannot delete workflow that is published as a tool") diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index c3910e2be3..4542b1b923 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -7,6 +7,7 @@ from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from core.repository import RepositoryFactory from extensions.ext_database import db from models.dataset import AppDatasetJoin from models.model import ( @@ -30,7 +31,7 @@ from models.model import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( - synchronize_session=False - ) - - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + # Create a repository instance for WorkflowNodeExecution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": tenant_id, + "app_id": app_id, + "session_factory": db.session.get_bind, + } ) + # Use the clear method to delete all records for this tenant_id and app_id + repository.clear() + + logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green")) + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/tests/unit_tests/core/model_runtime/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py new file mode 100644 index 0000000000..93d8a20cac --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py @@ -0,0 +1,99 @@ +from unittest.mock import MagicMock, patch + +from core.model_runtime.entities.message_entities import AssistantPromptMessage +from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call + +ToolCall = AssistantPromptMessage.ToolCall + +# CASE 1: Single tool call +INPUTS_CASE_1 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_1 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), +] + +# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...) +INPUTS_CASE_2 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_2 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), + ToolCall( + id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') + ), +] + +# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...) +INPUTS_CASE_3 = [ + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_3 = [ + ToolCall( + id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}') + ), + ToolCall( + id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}') + ), +] + +# CASE 4: Tool call sequences with no IDs +INPUTS_CASE_4 = [ + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')), + ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')), +] +EXPECTED_CASE_4 = [ + ToolCall( + id="RANDOM_ID_1", + type="function", + function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'), + ), + ToolCall( + id="RANDOM_ID_2", + type="function", + function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'), + ), +] + + +def _run_case(inputs: list[ToolCall], expected: list[ToolCall]): + actual = [] + _increase_tool_call(inputs, actual) + assert actual == expected + + +def test__increase_tool_call(): + # case 1: + _run_case(INPUTS_CASE_1, EXPECTED_CASE_1) + + # case 2: + _run_case(INPUTS_CASE_2, EXPECTED_CASE_2) + + # case 3: + _run_case(INPUTS_CASE_3, EXPECTED_CASE_3) + + # case 4: + mock_id_generator = MagicMock() + mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] + with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): + _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) diff --git a/api/tests/unit_tests/repositories/__init__.py b/api/tests/unit_tests/repositories/__init__.py new file mode 100644 index 0000000000..bc0d6e78c9 --- /dev/null +++ b/api/tests/unit_tests/repositories/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py new file mode 100644 index 0000000000..78815a8d1a --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/__init__.py @@ -0,0 +1,3 @@ +""" +Unit tests for workflow_node_execution repositories. +""" diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py new file mode 100644 index 0000000000..36847f8a13 --- /dev/null +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -0,0 +1,178 @@ +""" +Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository. +""" + +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.orm import Session, sessionmaker + +from core.repository.workflow_node_execution_repository import OrderConfig +from models.workflow import WorkflowNodeExecution +from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository + + +@pytest.fixture +def session(): + """Create a mock SQLAlchemy session.""" + session = MagicMock(spec=Session) + # Configure the session to be used as a context manager + session.__enter__ = MagicMock(return_value=session) + session.__exit__ = MagicMock(return_value=None) + + # Configure the session factory to return the session + session_factory = MagicMock(spec=sessionmaker) + session_factory.return_value = session + return session, session_factory + + +@pytest.fixture +def repository(session): + """Create a repository instance with test data.""" + _, session_factory = session + tenant_id = "test-tenant" + app_id = "test-app" + return SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, tenant_id=tenant_id, app_id=app_id + ) + + +def test_save(repository, session): + """Test save method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_save_with_existing_tenant_id(repository, session): + """Test save method with existing tenant_id.""" + session_obj, _ = session + # Create a mock execution with existing tenant_id + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = "existing-tenant" + execution.app_id = None + + # Call save method + repository.save(execution) + + # Assert tenant_id is not changed and app_id is set + assert execution.tenant_id == "existing-tenant" + assert execution.app_id == repository._app_id + + # Assert session.add was called + session_obj.add.assert_called_once_with(execution) + + +def test_get_by_node_execution_id(repository, session, mocker: MockerFixture): + """Test get_by_node_execution_id method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution) + + # Call method + result = repository.get_by_node_execution_id("test-node-execution-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalar.assert_called_once_with(mock_stmt) + assert result is not None + + +def test_get_by_workflow_run(repository, session, mocker: MockerFixture): + """Test get_by_workflow_run method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + mock_stmt.order_by.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + order_config = OrderConfig(order_by=["index"], order_direction="desc") + result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config) + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_get_running_executions(repository, session, mocker: MockerFixture): + """Test get_running_executions method.""" + session_obj, _ = session + # Set up mock + mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select") + mock_stmt = mocker.MagicMock() + mock_select.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)] + + # Call method + result = repository.get_running_executions("test-workflow-run-id") + + # Assert select was called with correct parameters + mock_select.assert_called_once() + session_obj.scalars.assert_called_once_with(mock_stmt) + assert len(result) == 1 + + +def test_update(repository, session): + """Test update method.""" + session_obj, _ = session + # Create a mock execution + execution = MagicMock(spec=WorkflowNodeExecution) + execution.tenant_id = None + execution.app_id = None + + # Call update method + repository.update(execution) + + # Assert tenant_id and app_id are set + assert execution.tenant_id == repository._tenant_id + assert execution.app_id == repository._app_id + + # Assert session.merge was called + session_obj.merge.assert_called_once_with(execution) + + +def test_clear(repository, session, mocker: MockerFixture): + """Test clear method.""" + session_obj, _ = session + # Set up mock + mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete") + mock_stmt = mocker.MagicMock() + mock_delete.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + + # Mock the execute result with rowcount + mock_result = mocker.MagicMock() + mock_result.rowcount = 5 # Simulate 5 records deleted + session_obj.execute.return_value = mock_result + + # Call method + repository.clear() + + # Assert delete was called with correct parameters + mock_delete.assert_called_once_with(WorkflowNodeExecution) + mock_stmt.where.assert_called() + session_obj.execute.assert_called_once_with(mock_stmt) + session_obj.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py index 56efcccc78..223020c2c5 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py @@ -40,6 +40,10 @@ def workflow_setup(): def test_delete_workflow_success(workflow_setup): # Setup mocks + + # Mock the tool provider query to return None (not published as a tool) + workflow_setup["session"].query.return_value.filter.return_value.first.return_value = None + workflow_setup["session"].scalar = MagicMock( side_effect=[workflow_setup["workflow"], None] ) # Return workflow first, then None for app @@ -97,7 +101,12 @@ def test_delete_workflow_in_use_by_app_error(workflow_setup): def test_delete_workflow_published_as_tool_error(workflow_setup): # Setup mocks - workflow_setup["workflow"].tool_published = True + from models.tools import WorkflowToolProvider + + # Mock the tool provider query + mock_tool_provider = MagicMock(spec=WorkflowToolProvider) + workflow_setup["session"].query.return_value.filter.return_value.first.return_value = mock_tool_provider + workflow_setup["session"].scalar = MagicMock( side_effect=[workflow_setup["workflow"], None] ) # Return workflow first, then None for app diff --git a/api/uv.lock b/api/uv.lock index 4ff9c34446..6c8699dd7c 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'", @@ -1178,6 +1177,7 @@ dependencies = [ { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "jieba" }, + { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, { name = "mailchimp-transactional" }, @@ -1346,6 +1346,7 @@ requires-dist = [ { name = "gunicorn", specifier = "~=23.0.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, { name = "jieba", specifier = "==0.42.1" }, + { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" }, @@ -1470,7 +1471,7 @@ vdb = [ { name = "couchbase", specifier = "~=4.3.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, { name = "opensearch-py", specifier = "==2.4.0" }, - { name = "oracledb", specifier = "~=2.2.1" }, + { name = "oracledb", specifier = "==3.0.0" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, { name = "pgvector", specifier = "==0.2.5" }, { name = "pymilvus", specifier = "~=2.5.0" }, @@ -2524,6 +2525,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, ] +[[package]] +name = "json-repair" +version = "0.41.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 }, +] + [[package]] name = "jsonpath-python" version = "1.0.6" @@ -3590,23 +3600,23 @@ wheels = [ [[package]] name = "oracledb" -version = "2.2.1" +version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/36/fb/3fbacb351833dd794abb184303a5761c4bb33df9d770fd15d01ead2ff738/oracledb-2.2.1.tar.gz", hash = "sha256:8464c6f0295f3318daf6c2c72c83c2dcbc37e13f8fd44e3e39ff8665f442d6b6", size = 580818 } +sdist = { url = "https://files.pythonhosted.org/packages/bf/39/712f797b75705c21148fa1d98651f63c2e5cc6876e509a0a9e2f5b406572/oracledb-3.0.0.tar.gz", hash = "sha256:64dc86ee5c032febc556798b06e7b000ef6828bb0252084f6addacad3363db85", size = 840431 } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/b7/a4238295944670fb8cc50a8cc082e0af5a0440bfb1c2bac2b18429c0a579/oracledb-2.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:fb6d9a4d7400398b22edb9431334f9add884dec9877fd9c4ae531e1ccc6ee1fd", size = 3551303 }, - { url = "https://files.pythonhosted.org/packages/4f/5f/98481d44976cd2b3086361f2d50026066b24090b0e6cd1f2a12c824e9717/oracledb-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07757c240afbb4f28112a6affc2c5e4e34b8a92e5bb9af81a40fba398da2b028", size = 12258455 }, - { url = "https://files.pythonhosted.org/packages/e9/54/06b2540286e2b63f60877d6f3c6c40747e216b6eeda0756260e194897076/oracledb-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63daec72f853c47179e98493e9b732909d96d495bdceb521c5973a3940d28142", size = 12317476 }, - { url = "https://files.pythonhosted.org/packages/4d/1a/67814439a4e24df83281a72cb0ba433d6b74e1bff52a9975b87a725bcba5/oracledb-2.2.1-cp311-cp311-win32.whl", hash = "sha256:fec5318d1e0ada7e4674574cb6c8d1665398e8b9c02982279107212f05df1660", size = 1369368 }, - { url = "https://files.pythonhosted.org/packages/e3/b8/b2a8f0607be17f58ec6689ad5fd15c2956f4996c64547325e96439570edf/oracledb-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:5134dccb5a11bc755abf02fd49be6dc8141dfcae4b650b55d40509323d00b5c2", size = 1655035 }, - { url = "https://files.pythonhosted.org/packages/24/5b/2fff762243030f31a6b1561fc8eeb142e69ba6ebd3e7fbe4a2c82f0eb6f0/oracledb-2.2.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ac5716bc9a48247fdf563f5f4ec097f5c9f074a60fd130cdfe16699208ca29b5", size = 3583960 }, - { url = "https://files.pythonhosted.org/packages/e6/88/34117ae830e7338af7c0481f1c0fc6eda44d558e12f9203b45b491e53071/oracledb-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c150bddb882b7c73fb462aa2d698744da76c363e404570ed11d05b65811d96c3", size = 11749006 }, - { url = "https://files.pythonhosted.org/packages/9d/58/bac788f18c21f727955652fe238de2d24a12c2b455ed4db18a6d23ff781e/oracledb-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193e1888411bc21187ade4b16b76820bd1e8f216e25602f6cd0a97d45723c1dc", size = 11950663 }, - { url = "https://files.pythonhosted.org/packages/3b/e2/005f66ae919c6f7c73e06863256cf43aa844330e2dc61a5f9779ae44a801/oracledb-2.2.1-cp312-cp312-win32.whl", hash = "sha256:44a960f8bbb0711af222e0a9690e037b6a2a382e0559ae8eeb9cfafe26c7a3bc", size = 1324255 }, - { url = "https://files.pythonhosted.org/packages/e6/25/759eb2143134513382e66d874c4aacfd691dec3fef7141170cfa6c1b154f/oracledb-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:470136add32f0d0084225c793f12a52b61b52c3dc00c9cd388ec6a3db3a7643e", size = 1613047 }, + { url = "https://files.pythonhosted.org/packages/fa/bf/d872c4b3fc15cd3261fe0ea72b21d181700c92dbc050160e161654987062/oracledb-3.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:52daa9141c63dfa75c07d445e9bb7f69f43bfb3c5a173ecc48c798fe50288d26", size = 4312963 }, + { url = "https://files.pythonhosted.org/packages/b1/ea/01ee29e76a610a53bb34fdc1030f04b7669c3f80b25f661e07850fc6160e/oracledb-3.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af98941789df4c6aaaf4338f5b5f6b7f2c8c3fe6f8d6a9382f177f350868747a", size = 2661536 }, + { url = "https://files.pythonhosted.org/packages/3d/8e/ad380e34a46819224423b4773e58c350bc6269643c8969604097ced8c3bc/oracledb-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9812bb48865aaec35d73af54cd1746679f2a8a13cbd1412ab371aba2e39b3943", size = 2867461 }, + { url = "https://files.pythonhosted.org/packages/96/09/ecc4384a27fd6e1e4de824ae9c160e4ad3aaebdaade5b4bdcf56a4d1ff63/oracledb-3.0.0-cp311-cp311-win32.whl", hash = "sha256:6c27fe0de64f2652e949eb05b3baa94df9b981a4a45fa7f8a991e1afb450c8e2", size = 1752046 }, + { url = "https://files.pythonhosted.org/packages/62/e8/f34bde24050c6e55eeba46b23b2291f2dd7fd272fa8b322dcbe71be55778/oracledb-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:f922709672002f0b40997456f03a95f03e5712a86c61159951c5ce09334325e0", size = 2101210 }, + { url = "https://files.pythonhosted.org/packages/6f/fc/24590c3a3d41e58494bd3c3b447a62835138e5f9b243d9f8da0cfb5da8dc/oracledb-3.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:acd0e747227dea01bebe627b07e958bf36588a337539f24db629dc3431d3f7eb", size = 4351993 }, + { url = "https://files.pythonhosted.org/packages/b7/b6/1f3b0b7bb94d53e8857d77b2e8dbdf6da091dd7e377523e24b79dac4fd71/oracledb-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f8b402f77c22af031cd0051aea2472ecd0635c1b452998f511aa08b7350c90a4", size = 2532640 }, + { url = "https://files.pythonhosted.org/packages/72/1a/1815f6c086ab49c00921cf155ff5eede5267fb29fcec37cb246339a5ce4d/oracledb-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:378a27782e9a37918bd07a5a1427a77cb6f777d0a5a8eac9c070d786f50120ef", size = 2765949 }, + { url = "https://files.pythonhosted.org/packages/33/8d/208900f8d372909792ee70b2daad3f7361181e55f2217c45ed9dff658b54/oracledb-3.0.0-cp312-cp312-win32.whl", hash = "sha256:54a28c2cb08316a527cd1467740a63771cc1c1164697c932aa834c0967dc4efc", size = 1709373 }, + { url = "https://files.pythonhosted.org/packages/0c/5e/c21754f19c896102793c3afec2277e2180aa7d505e4d7fcca24b52d14e4f/oracledb-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:8289bad6d103ce42b140e40576cf0c81633e344d56e2d738b539341eacf65624", size = 2056452 }, ] [[package]] @@ -4074,6 +4084,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 }, { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 }, { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 }, + { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 }, + { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 }, ] [[package]] diff --git a/docker/.env.example b/docker/.env.example index 9b372dcec9..82ef4174c2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -744,6 +744,12 @@ MAX_VARIABLE_SIZE=204800 WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_FILE_UPLOAD_LIMIT=10 +# Workflow storage configuration +# Options: rdbms, hybrid +# rdbms: Use only the relational database (default) +# hybrid: Save new data to object storage, read from both object storage and RDBMS +WORKFLOW_NODE_EXECUTION_STORAGE=rdbms + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index a8f7b755fb..c6d41849ef 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -130,6 +130,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 27d6d660d0..1702a5395f 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -60,6 +60,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 172cbe2d2f..15fab4a4bf 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -327,6 +327,7 @@ x-shared-env: &shared-api-worker-env MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800} WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} + WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} @@ -474,7 +475,8 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:1.2.0 + # image: langgenius/dify-api:1.2.0 + build: ../api restart: always environment: # Use the shared environment variables. @@ -503,7 +505,8 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:1.2.0 + # image: langgenius/dify-api:1.2.0 + build: ../api restart: always environment: # Use the shared environment variables. @@ -529,7 +532,8 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.2.0 + # image: langgenius/dify-web:1.2.0 + build: ../web restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -602,6 +606,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/web/Dockerfile b/web/Dockerfile index dfc5ba8b46..ed7946924e 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -10,6 +10,8 @@ RUN npm install -g pnpm@10.8.0 ENV PNPM_HOME="/pnpm" ENV PATH="$PNPM_HOME:$PATH" +# if you located in China, you can use taobao mirror to speed up +RUN pnpm config set registry https://registry.npmmirror.com/ # install packages FROM base AS packages @@ -73,4 +75,4 @@ ENV COMMIT_SHA=${COMMIT_SHA} USER 1001 EXPOSE 3000 -ENTRYPOINT ["/bin/sh", "./entrypoint.sh"] +ENTRYPOINT ["/bin/sh", "./entrypoint.sh"] \ No newline at end of file diff --git a/web/README.md b/web/README.md index 3236347e80..3d9fd2de87 100644 --- a/web/README.md +++ b/web/README.md @@ -7,7 +7,7 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next ### Run by source code Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) >= v18.x +- [Node.js](https://nodejs.org) >= v22.11.x - [pnpm](https://pnpm.io) v10.x First, install the dependencies: diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx index f4d49425ae..d5df70f004 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/workflow/page.tsx @@ -1,11 +1,11 @@ 'use client' -import Workflow from '@/app/components/workflow' +import WorkflowApp from '@/app/components/workflow-app' const Page = () => { return (
- +
) } diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index b435a9bb67..a8bb7046e6 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - semantic_search 语义检索 - full_text_search 全文检索 - reranking_enable (bool) 是否开启rerank + - reranking_mode (String) 混合检索 + - weighted_score 权重设置 + - reranking_model Rerank 模型 - reranking_model (object) Rerank 模型配置 - reranking_provider_name (string) Rerank 模型的提供商 - reranking_model_name (string) Rerank 模型的名称 @@ -591,7 +594,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 检索参数(选填,如不填,按照默认方式召回) - - search_method (text) 检索方法:以下三个关键字之一,必填 + - search_method (text) 检索方法:以下四个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 @@ -1817,7 +1820,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 检索参数(选填,如不填,按照默认方式召回) - - search_method (text) 检索方法:以下三个关键字之一,必填 + - search_method (text) 检索方法:以下四个关键字之一,必填 - keyword_search 关键字检索 - semantic_search 语义检索 - full_text_search 全文检索 diff --git a/web/app/components/app/configuration/config-var/config-select/index.spec.tsx b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx new file mode 100644 index 0000000000..18df318de3 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx @@ -0,0 +1,82 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ConfigSelect from './index' + +jest.mock('react-sortablejs', () => ({ + ReactSortable: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('ConfigSelect Component', () => { + const defaultProps = { + options: ['Option 1', 'Option 2'], + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders all options', () => { + render() + + defaultProps.options.forEach((option) => { + expect(screen.getByDisplayValue(option)).toBeInTheDocument() + }) + }) + + it('renders add button', () => { + render() + + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) + + it('handles option deletion', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.click(deleteButton) + expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2']) + }) + + it('handles adding new option', () => { + render() + const addButton = screen.getByText('appDebug.variableConfig.addOption') + + fireEvent.click(addButton) + + expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, '']) + }) + + it('applies focus styles on input focus', () => { + render() + const firstInput = screen.getByDisplayValue('Option 1') + + fireEvent.focus(firstInput) + + expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active') + }) + + it('applies delete hover styles', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.mouseEnter(deleteButton) + expect(optionContainer).toHaveClass('border-components-input-border-destructive') + }) + + it('renders empty state correctly', () => { + render() + + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx index d2dc1662c1..40ddaef78f 100644 --- a/web/app/components/app/configuration/config-var/config-select/index.tsx +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -51,7 +51,7 @@ const ConfigSelect: FC = ({ { const value = e.target.value @@ -67,6 +67,7 @@ const ConfigSelect: FC = ({ onBlur={() => setFocusID(null)} />
{ onChange(options.filter((_, i) => index !== i)) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index 75183ab5a7..952ad66fc4 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -163,7 +163,7 @@ const SettingBuiltInTool: FC = ({ footer={null} mask={false} positionCenter={false} - panelClassname={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} + panelClassName={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} > <> {isLoading && } diff --git a/web/app/components/app/configuration/dataset-config/card-item/item.tsx b/web/app/components/app/configuration/dataset-config/card-item/item.tsx index d44fb145bb..65ad2ca941 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/item.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/item.tsx @@ -97,7 +97,7 @@ const Item: FC = ({
- setShowSettingsModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> + setShowSettingsModal(false)} footer={null} mask={isMobile} panelClassName='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> setShowSettingsModal(false)} diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index 90885dacc8..645f6045f0 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -62,13 +62,13 @@ const SettingsModal: FC = ({ const { notify } = useToastContext() const ref = useRef(null) const isExternal = currentDataset.provider === 'external' - const [topK, setTopK] = useState(currentDataset?.external_retrieval_model.top_k ?? 2) - const [scoreThreshold, setScoreThreshold] = useState(currentDataset?.external_retrieval_model.score_threshold ?? 0.5) - const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(currentDataset?.external_retrieval_model.score_threshold_enabled ?? false) const { setShowAccountSettingModal } = useModalContext() const [loading, setLoading] = useState(false) const { isCurrentWorkspaceDatasetOperator } = useAppContext() const [localeCurrentDataset, setLocaleCurrentDataset] = useState({ ...currentDataset }) + const [topK, setTopK] = useState(localeCurrentDataset?.external_retrieval_model.top_k ?? 2) + const [scoreThreshold, setScoreThreshold] = useState(localeCurrentDataset?.external_retrieval_model.score_threshold ?? 0.5) + const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(localeCurrentDataset?.external_retrieval_model.score_threshold_enabled ?? false) const [selectedMemberIDs, setSelectedMemberIDs] = useState(currentDataset.partial_member_list || []) const [memberList, setMemberList] = useState([]) @@ -88,6 +88,14 @@ const SettingsModal: FC = ({ setScoreThreshold(data.score_threshold) if (data.score_threshold_enabled !== undefined) setScoreThresholdEnabled(data.score_threshold_enabled) + + setLocaleCurrentDataset({ + ...localeCurrentDataset, + external_retrieval_model: { + ...localeCurrentDataset?.external_retrieval_model, + ...data, + }, + }) } const handleSave = async () => { diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index b78af5cdba..056ce84f1e 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -743,7 +743,7 @@ const ConversationList: FC = ({ logs, appDetail, onRefresh }) onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl bg-components-panel-bg' + panelClassName='mt-16 mx-2 sm:mr-2 mb-4 !p-0 !max-w-[640px] rounded-xl bg-components-panel-bg' > = ({ return check } + const validatePrivacyPolicy = (privacyPolicy: string | null) => { + if (privacyPolicy === null || privacyPolicy?.length === 0) + return true + + return privacyPolicy.startsWith('http://') || privacyPolicy.startsWith('https://') + } + if (inputInfo !== null) { if (!validateColorHex(inputInfo.chatColorTheme)) { notify({ type: 'error', message: t(`${prefixSettings}.invalidHexMessage`) }) return } + if (!validatePrivacyPolicy(inputInfo.privacyPolicy)) { + notify({ type: 'error', message: t(`${prefixSettings}.invalidPrivacyPolicy`) }) + return + } } setSaveLoading(true) @@ -410,7 +421,7 @@ const SettingsModal: FC = ({

}} + components={{ privacyPolicyLink: }} />

= ({ logs, appDetail, onRefresh }) => { onClose={onCloseDrawer} mask={isMobile} footer={null} - panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[600px] rounded-xl border border-components-panel-border' + panelClassName='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[600px] rounded-xl border border-components-panel-border' >
diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 3722556931..a0a9323729 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -234,4 +234,6 @@ const Answer: FC = ({ ) } -export default memo(Answer) +export default memo(Answer, (prevProps, nextProps) => + prevProps.responding === false && nextProps.responding === false, +) diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index d6a7b230e4..0f2529152c 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -80,8 +80,30 @@ export const useEmbeddedChatbot = () => { }, []) useEffect(() => { - if (appInfo?.site.default_language) - changeLanguage(appInfo.site.default_language) + const setLanguageFromParams = async () => { + // Check URL parameters for language override + const urlParams = new URLSearchParams(window.location.search) + const localeParam = urlParams.get('locale') + + // Check for encoded system variables + const systemVariables = await getProcessedSystemVariablesFromUrlParams() + const localeFromSysVar = systemVariables.locale + + if (localeParam) { + // If locale parameter exists in URL, use it instead of default + changeLanguage(localeParam) + } + else if (localeFromSysVar) { + // If locale is set as a system variable, use that + changeLanguage(localeFromSysVar) + } + else if (appInfo?.site.default_language) { + // Otherwise use the default from app config + changeLanguage(appInfo.site.default_language) + } + } + + setLanguageFromParams() }, [appInfo]) const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState>>(CONVERSATION_ID_INFO, { diff --git a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx new file mode 100644 index 0000000000..56df8db6a4 --- /dev/null +++ b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx @@ -0,0 +1,11 @@ +const IndeterminateIcon = () => { + return ( +
+ + + +
+ ) +} + +export default IndeterminateIcon diff --git a/web/app/components/base/checkbox/assets/mixed.svg b/web/app/components/base/checkbox/assets/mixed.svg deleted file mode 100644 index e16b8fc975..0000000000 --- a/web/app/components/base/checkbox/assets/mixed.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/web/app/components/base/checkbox/index.module.css b/web/app/components/base/checkbox/index.module.css deleted file mode 100644 index d675607b46..0000000000 --- a/web/app/components/base/checkbox/index.module.css +++ /dev/null @@ -1,10 +0,0 @@ -.mixed { - background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat; - background-size: 12px 12px; - border: none; -} - -.checked.disabled { - background-color: #d0d5dd; - border-color: #d0d5dd; -} \ No newline at end of file diff --git a/web/app/components/base/checkbox/index.spec.tsx b/web/app/components/base/checkbox/index.spec.tsx new file mode 100644 index 0000000000..7ef901aef5 --- /dev/null +++ b/web/app/components/base/checkbox/index.spec.tsx @@ -0,0 +1,67 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import Checkbox from './index' + +describe('Checkbox Component', () => { + const mockProps = { + id: 'test', + } + + it('renders unchecked checkbox by default', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toBeInTheDocument() + expect(checkbox).not.toHaveClass('bg-components-checkbox-bg') + }) + + it('renders checked checkbox when checked prop is true', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg') + expect(screen.getByTestId('check-icon-test')).toBeInTheDocument() + }) + + it('renders indeterminate state correctly', () => { + render() + expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument() + }) + + it('handles click events when not disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).toHaveBeenCalledTimes(1) + }) + + it('does not handle click events when disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).not.toHaveBeenCalled() + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies custom className when provided', () => { + const customClass = 'custom-class' + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass(customClass) + }) + + it('applies correct styles for disabled checked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies correct styles for disabled unchecked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) +}) diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx index b0b0ebca7c..3e47967c62 100644 --- a/web/app/components/base/checkbox/index.tsx +++ b/web/app/components/base/checkbox/index.tsx @@ -1,48 +1,49 @@ import { RiCheckLine } from '@remixicon/react' -import s from './index.module.css' import cn from '@/utils/classnames' +import IndeterminateIcon from './assets/indeterminate-icon' type CheckboxProps = { + id?: string checked?: boolean onCheck?: () => void className?: string disabled?: boolean - mixed?: boolean + indeterminate?: boolean } -const Checkbox = ({ checked, onCheck, className, disabled, mixed }: CheckboxProps) => { - if (!checked) { - return ( -
{ - if (disabled) - return - onCheck?.() - }} - >
- ) - } +const Checkbox = ({ + id, + checked, + onCheck, + className, + disabled, + indeterminate, +}: CheckboxProps) => { + const checkClassName = (checked || indeterminate) + ? 'bg-components-checkbox-bg text-components-checkbox-icon hover:bg-components-checkbox-bg-hover' + : 'border border-components-checkbox-border bg-components-checkbox-bg-unchecked hover:bg-components-checkbox-bg-unchecked-hover hover:border-components-checkbox-border-hover' + const disabledClassName = (checked || indeterminate) + ? 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked' + : 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled' + return (
{ if (disabled) return - onCheck?.() }} + data-testid={`checkbox-${id}`} > - + {!checked && indeterminate && } + {checked && }
) } diff --git a/web/app/components/base/drawer-plus/index.tsx b/web/app/components/base/drawer-plus/index.tsx index bb022acdcb..33a1948181 100644 --- a/web/app/components/base/drawer-plus/index.tsx +++ b/web/app/components/base/drawer-plus/index.tsx @@ -9,6 +9,8 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' type Props = { isShow: boolean onHide: () => void + dialogClassName?: string + dialogBackdropClassName?: string panelClassName?: string maxWidthClassName?: string contentClassName?: string @@ -26,6 +28,8 @@ type Props = { const DrawerPlus: FC = ({ isShow, onHide, + dialogClassName = '', + dialogBackdropClassName = '', panelClassName = '', maxWidthClassName = '!max-w-[640px]', height = 'calc(100vh - 72px)', @@ -55,7 +59,9 @@ const DrawerPlus: FC = ({ footer={null} mask={isMobile || isShowMask} positionCenter={positionCenter} - panelClassname={cn('mx-2 mb-3 mt-16 rounded-xl !p-0 sm:mr-2', panelClassName, maxWidthClassName)} + dialogClassName={dialogClassName} + dialogBackdropClassName={dialogBackdropClassName} + panelClassName={cn('mx-2 mb-3 mt-16 rounded-xl !p-0 sm:mr-2', panelClassName, maxWidthClassName)} >
!clickOutsideNotOpen && onClose()} - className="fixed inset-0 z-[80] overflow-y-auto" + className={cn('fixed inset-0 z-[30] overflow-y-auto', dialogClassName)} >
{/* mask */} { !clickOutsideNotOpen && onClose() }} /> -
+
<>
{title && { + const field = useFieldContext() + + return ( +
+
+ { + field.handleChange(!field.state.value) + }} + /> +
+ +
+ ) +} + +export default CheckboxField diff --git a/web/app/components/base/form/components/field/number-input.tsx b/web/app/components/base/form/components/field/number-input.tsx new file mode 100644 index 0000000000..fce3143fe1 --- /dev/null +++ b/web/app/components/base/form/components/field/number-input.tsx @@ -0,0 +1,49 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Label from '../label' +import cn from '@/utils/classnames' +import type { InputNumberProps } from '../../../input-number' +import { InputNumber } from '../../../input-number' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const NumberInputField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default NumberInputField diff --git a/web/app/components/base/form/components/field/options.tsx b/web/app/components/base/form/components/field/options.tsx new file mode 100644 index 0000000000..9ff71e50af --- /dev/null +++ b/web/app/components/base/form/components/field/options.tsx @@ -0,0 +1,34 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import Label from '../label' +import ConfigSelect from '@/app/components/app/configuration/config-var/config-select' + +type OptionsFieldProps = { + label: string; + className?: string; + labelClassName?: string; +} + +const OptionsField = ({ + label, + className, + labelClassName, +}: OptionsFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default OptionsField diff --git a/web/app/components/base/form/components/field/select.tsx b/web/app/components/base/form/components/field/select.tsx new file mode 100644 index 0000000000..95af3c0116 --- /dev/null +++ b/web/app/components/base/form/components/field/select.tsx @@ -0,0 +1,51 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import PureSelect from '../../../select/pure' +import Label from '../label' + +type SelectOption = { + value: string + label: string +} + +type SelectFieldProps = { + label: string + options: SelectOption[] + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} + +const SelectField = ({ + label, + options, + isRequired, + showOptional, + tooltip, + className, + labelClassName, +}: SelectFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default SelectField diff --git a/web/app/components/base/form/components/field/text.tsx b/web/app/components/base/form/components/field/text.tsx new file mode 100644 index 0000000000..b2090291a0 --- /dev/null +++ b/web/app/components/base/form/components/field/text.tsx @@ -0,0 +1,48 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Input, { type InputProps } from '../../../input' +import Label from '../label' +import cn from '@/utils/classnames' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const TextField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default TextField diff --git a/web/app/components/base/form/components/form/submit-button.tsx b/web/app/components/base/form/components/form/submit-button.tsx new file mode 100644 index 0000000000..494d19b843 --- /dev/null +++ b/web/app/components/base/form/components/form/submit-button.tsx @@ -0,0 +1,25 @@ +import { useStore } from '@tanstack/react-form' +import { useFormContext } from '../..' +import Button, { type ButtonProps } from '../../../button' + +type SubmitButtonProps = Omit + +const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => { + const form = useFormContext() + + const [isSubmitting, canSubmit] = useStore(form.store, state => [ + state.isSubmitting, + state.canSubmit, + ]) + + return ( +
+ ) +} + +export default Label diff --git a/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx new file mode 100644 index 0000000000..9ba664fc10 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx @@ -0,0 +1,35 @@ +import { withForm } from '../..' +import { demoFormOpts } from './shared-options' +import { ContactMethods } from './types' + +const ContactFields = withForm({ + ...demoFormOpts, + render: ({ form }) => { + return ( +
+

Contacts

+
+ } + /> + } + /> + ( + + )} + /> +
+
+ ) + }, +}) + +export default ContactFields diff --git a/web/app/components/base/form/form-scenarios/demo/index.tsx b/web/app/components/base/form/form-scenarios/demo/index.tsx new file mode 100644 index 0000000000..f08edee41e --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/index.tsx @@ -0,0 +1,68 @@ +import { useStore } from '@tanstack/react-form' +import { useAppForm } from '../..' +import ContactFields from './contact-fields' +import { demoFormOpts } from './shared-options' +import { UserSchema } from './types' + +const DemoForm = () => { + const form = useAppForm({ + ...demoFormOpts, + validators: { + onSubmit: ({ value }) => { + // Validate the entire form + const result = UserSchema.safeParse(value) + if (!result.success) { + const issues = result.error.issues + console.log('Validation errors:', issues) + return issues[0].message + } + return undefined + }, + }, + onSubmit: ({ value }) => { + console.log('Form submitted:', value) + }, + }) + +const name = useStore(form.store, state => state.values.name) + + return ( +
{ + e.preventDefault() + e.stopPropagation() + form.handleSubmit() + }} + > + ( + + )} + /> + ( + + )} + /> + ( + + )} + /> + { + !!name && ( + + ) + } + + Submit + + + ) +} + +export default DemoForm diff --git a/web/app/components/base/form/form-scenarios/demo/shared-options.tsx b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx new file mode 100644 index 0000000000..8b216c8b90 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx @@ -0,0 +1,14 @@ +import { formOptions } from '@tanstack/react-form' + +export const demoFormOpts = formOptions({ + defaultValues: { + name: '', + surname: '', + isAcceptingTerms: false, + contact: { + email: '', + phone: '', + preferredContactMethod: 'email', + }, + }, +}) diff --git a/web/app/components/base/form/form-scenarios/demo/types.ts b/web/app/components/base/form/form-scenarios/demo/types.ts new file mode 100644 index 0000000000..c4e626ef63 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/types.ts @@ -0,0 +1,34 @@ +import { z } from 'zod' + +const ContactMethod = z.union([ + z.literal('email'), + z.literal('phone'), + z.literal('whatsapp'), + z.literal('sms'), +]) + +export const ContactMethods = ContactMethod.options.map(({ value }) => ({ + value, + label: value.charAt(0).toUpperCase() + value.slice(1), +})) + +export const UserSchema = z.object({ + name: z + .string() + .regex(/^[A-Z]/, 'Name must start with a capital letter') + .min(3, 'Name must be at least 3 characters long'), + surname: z + .string() + .min(3, 'Surname must be at least 3 characters long') + .regex(/^[A-Z]/, 'Surname must start with a capital letter'), + isAcceptingTerms: z.boolean().refine(val => val, { + message: 'You must accept the terms and conditions', + }), + contact: z.object({ + email: z.string().email('Invalid email address'), + phone: z.string().optional(), + preferredContactMethod: ContactMethod, + }), +}) + +export type User = z.infer diff --git a/web/app/components/base/form/index.tsx b/web/app/components/base/form/index.tsx new file mode 100644 index 0000000000..aeb482ad02 --- /dev/null +++ b/web/app/components/base/form/index.tsx @@ -0,0 +1,25 @@ +import { createFormHook, createFormHookContexts } from '@tanstack/react-form' +import TextField from './components/field/text' +import NumberInputField from './components/field/number-input' +import CheckboxField from './components/field/checkbox' +import SelectField from './components/field/select' +import OptionsField from './components/field/options' +import SubmitButton from './components/form/submit-button' + +export const { fieldContext, useFieldContext, formContext, useFormContext } + = createFormHookContexts() + +export const { useAppForm, withForm } = createFormHook({ + fieldComponents: { + TextField, + NumberInputField, + CheckboxField, + SelectField, + OptionsField, + }, + formComponents: { + SubmitButton, + }, + fieldContext, + formContext, +}) diff --git a/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg new file mode 100644 index 0000000000..9566fcc0c3 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json new file mode 100644 index 0000000000..4e7da3c801 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json @@ -0,0 +1,36 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "arrow-down-round-fill" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "id": "Vector", + "d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + }, + "name": "ArrowDownRoundFill" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx new file mode 100644 index 0000000000..c766a72b94 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './ArrowDownRoundFill.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject>; + }, +) => + +Icon.displayName = 'ArrowDownRoundFill' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/solid/general/index.ts b/web/app/components/base/icons/src/vender/solid/general/index.ts index 52647905ab..4c4dd9a437 100644 --- a/web/app/components/base/icons/src/vender/solid/general/index.ts +++ b/web/app/components/base/icons/src/vender/solid/general/index.ts @@ -1,4 +1,5 @@ export { default as AnswerTriangle } from './AnswerTriangle' +export { default as ArrowDownRoundFill } from './ArrowDownRoundFill' export { default as CheckCircle } from './CheckCircle' export { default as CheckDone01 } from './CheckDone01' export { default as Download02 } from './Download02' diff --git a/web/app/components/base/input-number/index.spec.tsx b/web/app/components/base/input-number/index.spec.tsx new file mode 100644 index 0000000000..8dfd1184b0 --- /dev/null +++ b/web/app/components/base/input-number/index.spec.tsx @@ -0,0 +1,97 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { InputNumber } from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('InputNumber Component', () => { + const defaultProps = { + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders input with default values', () => { + render() + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + }) + + it('handles increment button click', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(6) + }) + + it('handles decrement button click', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(4) + }) + + it('respects max value constraint', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('respects min value constraint', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('handles direct input changes', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '42' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(42) + }) + + it('handles empty input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(undefined) + }) + + it('handles invalid input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: 'abc' } }) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('displays unit when provided', () => { + const unit = 'px' + render() + expect(screen.getByText(unit)).toBeInTheDocument() + }) + + it('disables controls when disabled prop is true', () => { + render() + const input = screen.getByRole('textbox') + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + expect(input).toBeDisabled() + expect(incrementBtn).toBeDisabled() + expect(decrementBtn).toBeDisabled() + }) +}) diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx index 5b88fc67f8..98efc94462 100644 --- a/web/app/components/base/input-number/index.tsx +++ b/web/app/components/base/input-number/index.tsx @@ -8,7 +8,7 @@ export type InputNumberProps = { value?: number onChange: (value?: number) => void amount?: number - size?: 'sm' | 'md' + size?: 'regular' | 'large' max?: number min?: number defaultValue?: number @@ -19,14 +19,12 @@ export type InputNumberProps = { } & Omit export const InputNumber: FC = (props) => { - const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props + const { unit, className, onChange, amount = 1, value, size = 'regular', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props const isValidValue = (v: number) => { - if (max && v > max) + if (typeof max === 'number' && v > max) return false - if (min && v < min) - return false - return true + return !(typeof min === 'number' && v < min) } const inc = () => { @@ -76,29 +74,39 @@ export const InputNumber: FC = (props) => { onChange(parsed) }} unit={unit} + size={size} />
-
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 5f059c3b7f..30fd90aff8 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -30,7 +30,7 @@ export type InputProps = { wrapperClassName?: string styleCss?: CSSProperties unit?: string -} & React.InputHTMLAttributes & VariantProps +} & Omit, 'size'> & VariantProps const Input = ({ size, diff --git a/web/app/components/base/markdown-blocks/music.tsx b/web/app/components/base/markdown-blocks/music.tsx new file mode 100644 index 0000000000..7edd1713c9 --- /dev/null +++ b/web/app/components/base/markdown-blocks/music.tsx @@ -0,0 +1,37 @@ +import abcjs from 'abcjs' +import { useEffect, useRef } from 'react' +import 'abcjs/abcjs-audio.css' + +const MarkdownMusic = ({ children }: { children: React.ReactNode }) => { + const containerRef = useRef(null) + const controlsRef = useRef(null) + + useEffect(() => { + if (containerRef.current && controlsRef.current) { + if (typeof children === 'string') { + const visualObjs = abcjs.renderAbc(containerRef.current, children, { + add_classes: true, // Add classes to SVG elements for cursor tracking + responsive: 'resize', // Make notation responsive + }) + const synthControl = new abcjs.synth.SynthController() + synthControl.load(controlsRef.current, {}, { displayPlay: true }) + const synth = new abcjs.synth.CreateSynth() + const visualObj = visualObjs[0] + synth.init({ visualObj }).then(() => { + synthControl.setTune(visualObj, false) + }) + containerRef.current.style.overflow = 'auto' + } + } + }, [children]) + + return ( +
+
+
+
+ ) +} +MarkdownMusic.displayName = 'MarkdownMusic' + +export default MarkdownMusic diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index 24ae59af73..6ea84a2842 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -23,6 +23,7 @@ import VideoGallery from '@/app/components/base/video-gallery' import AudioGallery from '@/app/components/base/audio-gallery' import MarkdownButton from '@/app/components/base/markdown-blocks/button' import MarkdownForm from '@/app/components/base/markdown-blocks/form' +import MarkdownMusic from '@/app/components/base/markdown-blocks/music' import ThinkBlock from '@/app/components/base/markdown-blocks/think-block' import { Theme } from '@/types/app' import useTheme from '@/hooks/use-theme' @@ -51,6 +52,7 @@ const capitalizationLanguageNameMap: Record = { json: 'JSON', latex: 'Latex', svg: 'SVG', + abc: 'ABC', } const getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -85,9 +87,11 @@ const preprocessLaTeX = (content: string) => { } const preprocessThinkTag = (content: string) => { + const thinkOpenTagRegex = /\n/g + const thinkCloseTagRegex = /\n<\/think>/g return flow([ - (str: string) => str.replace('\n', '
\n'), - (str: string) => str.replace('\n', '\n[ENDTHINKFLAG]
'), + (str: string) => str.replace(thinkOpenTagRegex, '
\n'), + (str: string) => str.replace(thinkCloseTagRegex, '\n[ENDTHINKFLAG]
'), ])(content) } @@ -135,45 +139,54 @@ const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => const renderCodeContent = useMemo(() => { const content = String(children).replace(/\n$/, '') - if (language === 'mermaid' && isSVG) { - return - } - else if (language === 'echarts') { - return ( -
+ switch (language) { + case 'mermaid': + if (isSVG) + return + break + case 'echarts': + return ( +
+ + + +
+ ) + case 'svg': + if (isSVG) { + return ( + + + + ) + } + break + case 'abc': + return ( - + -
- ) - } - else if (language === 'svg' && isSVG) { - return ( - - - - ) - } - else { - return ( - - {content} - - ) + ) + default: + return ( + + {content} + + ) } - }, [language, match, props, children, chartData, isSVG]) + }, [children, language, isSVG, chartData, props, theme, match]) if (inline || !match) return {children} @@ -239,7 +252,7 @@ const Img = ({ src }: any) => { return
} -const Link = ({ node, ...props }: any) => { +const Link = ({ node, children, ...props }: any) => { if (node.properties?.href && node.properties.href?.toString().startsWith('abbr')) { // eslint-disable-next-line react-hooks/rules-of-hooks const { onSend } = useChatContext() @@ -248,7 +261,7 @@ const Link = ({ node, ...props }: any) => { return onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value} } else { - return {node.children[0] ? node.children[0]?.value : 'Download'} + return {children || 'Download'} } } diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 4cae402e3b..03eb5a7c42 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -54,7 +54,7 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, max={max} step={step} amount={step} - size='sm' + size='regular' value={value} onChange={(value) => { onChange(id, value) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index d7a3a81417..562bb8c0d9 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -31,6 +31,7 @@ import { useOptions } from './hooks' import type { PickerBlockMenuOption } from './menu' import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars' import { useEventEmitterContextContext } from '@/context/event-emitter' +import { KEY_ESCAPE_COMMAND } from 'lexical' type ComponentPickerProps = { triggerString: string @@ -118,6 +119,13 @@ const ComponentPicker = ({ editor.dispatchCommand(INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, variables) }, [editor, checkForTriggerMatch, triggerString]) + const handleClose = useCallback(() => { + ReactDOM.flushSync(() => { + const escapeEvent = new KeyboardEvent('keydown', { key: 'Escape' }) + editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent) + }) + }, [editor]) + const renderMenu = useCallback>(( anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, @@ -141,51 +149,54 @@ const ComponentPicker = ({ visibility: isPositioned ? 'visible' : 'hidden', }} ref={refs.setFloating} + data-testid="component-picker-container" > { - options.map((option, index) => ( - - { - // Divider - index !== 0 && options.at(index - 1)?.group !== option.group && ( -
- ) - } - {option.renderMenuOption({ - queryString, - isSelected: selectedIndex === index, - onSelect: () => { - selectOptionAndCleanUp(option) - }, - onSetHighlight: () => { - setHighlightedIndex(index) - }, - })} -
- )) + workflowVariableBlock?.show && ( +
+ { + handleSelectWorkflowVariable(variables) + }} + maxHeightClass='max-h-[34vh]' + isSupportFileVar={isSupportFileVar} + onClose={handleClose} + onBlur={handleClose} + /> +
+ ) } { - workflowVariableBlock?.show && ( - <> - { - (!!options.length) && ( -
- ) - } -
- { - handleSelectWorkflowVariable(variables) - }} - maxHeightClass='max-h-[34vh]' - isSupportFileVar={isSupportFileVar} - /> -
- + workflowVariableBlock?.show && !!options.length && ( +
) } +
+ { + options.map((option, index) => ( + + { + // Divider + index !== 0 && options.at(index - 1)?.group !== option.group && ( +
+ ) + } + {option.renderMenuOption({ + queryString, + isSelected: selectedIndex === index, + onSelect: () => { + selectOptionAndCleanUp(option) + }, + onSetHighlight: () => { + setHighlightedIndex(index) + }, + })} +
+ )) + } +
, anchorElementRef.current, @@ -193,7 +204,7 @@ const ComponentPicker = ({ } ) - }, [allFlattenOptions.length, workflowVariableBlock?.show, refs, isPositioned, floatingStyles, queryString, workflowVariableOptions, handleSelectWorkflowVariable]) + }, [allFlattenOptions.length, workflowVariableBlock?.show, refs, isPositioned, floatingStyles, queryString, workflowVariableOptions, handleSelectWorkflowVariable, handleClose, isSupportFileVar]) return ( { } static clone(node: HistoryBlockNode): HistoryBlockNode { - return new HistoryBlockNode(node.__roleName, node.__onEditRole) + return new HistoryBlockNode(node.__roleName, node.__onEditRole, node.__key) } constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) { diff --git a/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx b/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx index 2e3adc15cf..246fd96769 100644 --- a/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/on-blur-or-focus-block.tsx @@ -37,14 +37,16 @@ const OnBlurBlock: FC = ({ ), editor.registerCommand( BLUR_COMMAND, - () => { - ref.current = setTimeout(() => { - editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 200) - - if (onBlur) - onBlur() - + (event) => { + // Check if the clicked target element is var-search-input + const target = event?.relatedTarget as HTMLElement + if (!target?.classList?.contains('var-search-input')) { + ref.current = setTimeout(() => { + editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) + }, 200) + if (onBlur) + onBlur() + } return true }, COMMAND_PRIORITY_EDITOR, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 2cf4c95b87..2f6c3374a7 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -11,6 +11,7 @@ import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { RiErrorWarningFill, + RiMoreLine, } from '@remixicon/react' import { useSelectOrDelete } from '../../hooks' import type { WorkflowNodesMap } from './node' @@ -27,26 +28,35 @@ import { Line3 } from '@/app/components/base/icons/src/public/common' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import Tooltip from '@/app/components/base/tooltip' import { isExceptionVariable } from '@/app/components/workflow/utils' +import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel' +import { Type } from '@/app/components/workflow/nodes/llm/types' +import type { ValueSelector } from '@/app/components/workflow/types' type WorkflowVariableBlockComponentProps = { nodeKey: string variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: (payload: { + nodeId: string, + valueSelector: ValueSelector, + }) => Type } const WorkflowVariableBlockComponent = ({ nodeKey, variables, workflowNodesMap = {}, + getVarType, }: WorkflowVariableBlockComponentProps) => { const { t } = useTranslation() const [editor] = useLexicalComposerContext() const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND) const variablesLength = variables.length + const isShowAPart = variablesLength > 2 const varName = ( () => { const isSystem = isSystemVar(variables) - const varName = variablesLength >= 3 ? (variables).slice(-2).join('.') : variables[variablesLength - 1] + const varName = variables[variablesLength - 1] return `${isSystem ? 'sys.' : ''}${varName}` } )() @@ -76,7 +86,7 @@ const WorkflowVariableBlockComponent = ({ const Item = (
)} + {isShowAPart && ( +
+ + +
+ )} +
{!isEnv && !isChatVar && } {isEnv && } @@ -126,7 +143,27 @@ const WorkflowVariableBlockComponent = ({ ) } - return Item + if (!node) + return null + + return ( + } + disabled={!isShowAPart} + > +
{Item}
+
+ ) } export default memo(WorkflowVariableBlockComponent) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 05d4505e20..479dce9615 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -9,7 +9,7 @@ import { } from 'lexical' import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import type { WorkflowVariableBlockType } from '../../types' +import type { GetVarType, WorkflowVariableBlockType } from '../../types' import { $createWorkflowVariableBlockNode, WorkflowVariableBlockNode, @@ -25,11 +25,13 @@ export type WorkflowVariableBlockProps = { getWorkflowNode: (nodeId: string) => Node onInsert?: () => void onDelete?: () => void + getVarType: GetVarType } const WorkflowVariableBlock = memo(({ workflowNodesMap, onInsert, onDelete, + getVarType, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -48,7 +50,7 @@ const WorkflowVariableBlock = memo(({ INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap) + const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) if (onInsert) @@ -69,7 +71,7 @@ const WorkflowVariableBlock = memo(({ COMMAND_PRIORITY_EDITOR, ), ) - }, [editor, onInsert, onDelete, workflowNodesMap]) + }, [editor, onInsert, onDelete, workflowNodesMap, getVarType]) return null }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 0564e6f16d..dce636d92d 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -2,34 +2,39 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical' import { DecoratorNode } from 'lexical' import type { WorkflowVariableBlockType } from '../../types' import WorkflowVariableBlockComponent from './component' +import type { GetVarType } from '../../types' export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap'] + export type SerializedNode = SerializedLexicalNode & { variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: GetVarType } export class WorkflowVariableBlockNode extends DecoratorNode { __variables: string[] __workflowNodesMap: WorkflowNodesMap + __getVarType?: GetVarType static getType(): string { return 'workflow-variable-block' } static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__key) + return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__getVarType, node.__key) } isInline(): boolean { return true } - constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, key?: NodeKey) { + constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType: any, key?: NodeKey) { super(key) this.__variables = variables this.__workflowNodesMap = workflowNodesMap + this.__getVarType = getVarType } createDOM(): HTMLElement { @@ -48,12 +53,13 @@ export class WorkflowVariableBlockNode extends DecoratorNode nodeKey={this.getKey()} variables={this.__variables} workflowNodesMap={this.__workflowNodesMap} + getVarType={this.__getVarType!} /> ) } static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { - const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap) + const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap, serializedNode.getVarType) return node } @@ -64,6 +70,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode version: 1, variables: this.getVariables(), workflowNodesMap: this.getWorkflowNodesMap(), + getVarType: this.getVarType(), } } @@ -77,12 +84,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__workflowNodesMap } + getVarType(): any { + const self = this.getLatest() + return self.__getVarType + } + getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } -export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(variables, workflowNodesMap) +export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType): WorkflowVariableBlockNode { + return new WorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) } export function $isWorkflowVariableBlockNode( diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx index 22ebc5d248..288008bbcc 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx @@ -16,6 +16,7 @@ import { VAR_REGEX as REGEX, resetReg } from '@/config' const WorkflowVariableBlockReplacementBlock = ({ workflowNodesMap, + getVarType, onInsert, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -30,8 +31,8 @@ const WorkflowVariableBlockReplacementBlock = ({ onInsert() const nodePathString = textNode.getTextContent().slice(3, -3) - return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap)) - }, [onInsert, workflowNodesMap]) + return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType)) + }, [onInsert, workflowNodesMap, getVarType]) const getMatch = useCallback((text: string) => { const matchArr = REGEX.exec(text) diff --git a/web/app/components/base/prompt-editor/types.ts b/web/app/components/base/prompt-editor/types.ts index 6d0f307c17..0f09fb2473 100644 --- a/web/app/components/base/prompt-editor/types.ts +++ b/web/app/components/base/prompt-editor/types.ts @@ -1,8 +1,10 @@ +import type { Type } from '../../workflow/nodes/llm/types' import type { Dataset } from './plugins/context-block' import type { RoleName } from './plugins/history-block' import type { Node, NodeOutPutVar, + ValueSelector, } from '@/app/components/workflow/types' export type Option = { @@ -54,12 +56,18 @@ export type ExternalToolBlockType = { onAddExternalTool?: () => void } +export type GetVarType = (payload: { + nodeId: string, + valueSelector: ValueSelector, +}) => Type + export type WorkflowVariableBlockType = { show?: boolean variables?: NodeOutPutVar[] workflowNodesMap?: Record> onInsert?: () => void onDelete?: () => void + getVarType?: GetVarType } export type MenuTextMatch = { diff --git a/web/app/components/base/segmented-control/index.tsx b/web/app/components/base/segmented-control/index.tsx new file mode 100644 index 0000000000..bd921e4243 --- /dev/null +++ b/web/app/components/base/segmented-control/index.tsx @@ -0,0 +1,68 @@ +import React from 'react' +import classNames from '@/utils/classnames' +import type { RemixiconComponentType } from '@remixicon/react' +import Divider from '../divider' + +// Updated generic type to allow enum values +type SegmentedControlProps = { + options: { Icon: RemixiconComponentType, text: string, value: T }[] + value: T + onChange: (value: T) => void + className?: string +} + +export const SegmentedControl = ({ + options, + value, + onChange, + className, +}: SegmentedControlProps): JSX.Element => { + const selectedOptionIndex = options.findIndex(option => option.value === value) + + return ( +
+ {options.map((option, index) => { + const { Icon } = option + const isSelected = index === selectedOptionIndex + const isNextSelected = index === selectedOptionIndex - 1 + const isLast = index === options.length - 1 + return ( + + ) + })} +
+ ) +} + +export default React.memo(SegmentedControl) as typeof SegmentedControl diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx index 0f18bebedf..1e274515f8 100644 --- a/web/app/components/base/textarea/index.tsx +++ b/web/app/components/base/textarea/index.tsx @@ -8,8 +8,9 @@ const textareaVariants = cva( { variants: { size: { - regular: 'px-3 radius-md system-sm-regular', - large: 'px-4 radius-lg system-md-regular', + small: 'py-1 rounded-md system-xs-regular', + regular: 'px-3 rounded-md system-sm-regular', + large: 'px-4 rounded-lg system-md-regular', }, }, defaultVariants: { diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e9b7ab047a..e6c4de31f1 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -10,6 +10,7 @@ export type TooltipProps = { position?: Placement triggerMethod?: 'hover' | 'click' triggerClassName?: string + triggerTestId?: string disabled?: boolean popupContent?: React.ReactNode children?: React.ReactNode @@ -24,6 +25,7 @@ const Tooltip: FC = ({ position = 'top', triggerMethod = 'hover', triggerClassName, + triggerTestId, disabled = false, popupContent, children, @@ -91,7 +93,7 @@ const Tooltip: FC = ({ onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} asChild={asChild} > - {children ||
} + {children ||
} = (props) => {
}> = (props) => {
}> = ({ = ({ const resetList = useCallback(() => { setSelectedSegmentIds([]) invalidSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidSegmentList]) const resetChildList = useCallback(() => { invalidChildSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidChildSegmentList]) const onClickCard = (detail: SegmentDetailModel, isEditMode = false) => { setCurrSegment({ segInfo: detail, showModal: true, isEditMode }) @@ -253,7 +251,7 @@ const Completed: FC = ({ const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey) const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey) - const refreshChunkListWithStatusChanged = () => { + const refreshChunkListWithStatusChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -262,7 +260,7 @@ const Completed: FC = ({ default: invalidSegmentList() } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidSegmentList]) const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => { const operationApi = enable ? enableSegment : disableSegment @@ -280,8 +278,7 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds, segments]) + }, [datasetId, documentId, selectedSegmentIds, segments, disableSegment, enableSegment, t, notify, refreshChunkListWithStatusChanged]) const { mutateAsync: deleteSegment } = useDeleteSegment() @@ -296,12 +293,11 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds]) + }, [datasetId, documentId, selectedSegmentIds, deleteSegment, resetList, t, notify]) const { mutateAsync: updateSegment } = useUpdateSegment() - const refreshChunkListDataWithDetailChanged = () => { + const refreshChunkListDataWithDetailChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -316,7 +312,7 @@ const Completed: FC = ({ invalidChunkListEnabled() break } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidChunkListAll]) const handleUpdateSegment = useCallback(async ( segmentId: string, @@ -375,17 +371,18 @@ const Completed: FC = ({ eventEmitter?.emit('update-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, datasetId, documentId]) + }, [segments, datasetId, documentId, updateSegment, docForm, notify, eventEmitter, onCloseSegmentDetail, refreshChunkListDataWithDetailChanged, t]) useEffect(() => { resetList() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [pathname]) useEffect(() => { if (importStatus === ProcessStatus.COMPLETED) resetList() - }, [importStatus, resetList]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [importStatus]) const onCancelBatchOperation = useCallback(() => { setSelectedSegmentIds([]) @@ -430,8 +427,7 @@ const Completed: FC = ({ const count = segmentListData?.total || 0 return `${total} ${t('datasetDocuments.segment.searchResults', { count })}` } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData?.total, mode, parentMode, searchValue, selectedStatus]) + }, [segmentListData, mode, parentMode, searchValue, selectedStatus, t]) const toggleFullScreen = useCallback(() => { setFullScreen(!fullScreen) @@ -449,8 +445,7 @@ const Completed: FC = ({ resetList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData, limit, currentPage]) + }, [segmentListData, limit, currentPage, resetList]) const { mutateAsync: deleteChildSegment } = useDeleteChildSegment() @@ -470,8 +465,7 @@ const Completed: FC = ({ }, }, ) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, parentMode]) + }, [datasetId, documentId, parentMode, deleteChildSegment, resetList, resetChildList, t, notify]) const handleAddNewChildChunk = useCallback((parentChunkId: string) => { setShowNewChildSegmentModal(true) @@ -490,8 +484,7 @@ const Completed: FC = ({ else { resetChildList() } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [parentMode, currChunkId, segments]) + }, [parentMode, currChunkId, segments, refreshChunkListDataWithDetailChanged, resetChildList]) const viewNewlyAddedChildChunk = useCallback(() => { const totalPages = childChunkListData?.total_pages || 0 @@ -505,8 +498,7 @@ const Completed: FC = ({ resetChildList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [childChunkListData, limit, currentPage]) + }, [childChunkListData, limit, currentPage, resetChildList]) const onClickSlice = useCallback((detail: ChildChunkDetail) => { setCurrChildChunk({ childChunkInfo: detail, showModal: true }) @@ -560,8 +552,7 @@ const Completed: FC = ({ eventEmitter?.emit('update-child-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, childSegments, datasetId, documentId, parentMode]) + }, [segments, datasetId, documentId, parentMode, updateChildSegment, notify, eventEmitter, onCloseChildSegmentDetail, refreshChunkListDataWithDetailChanged, resetChildList, t]) const onClearFilter = useCallback(() => { setInputValue('') @@ -570,6 +561,12 @@ const Completed: FC = ({ setCurrentPage(1) }, []) + const selectDefaultValue = useMemo(() => { + if (selectedStatus === 'all') + return 'all' + return selectedStatus ? 1 : 0 + }, [selectedStatus]) + return ( = ({ @@ -591,7 +588,7 @@ const Completed: FC = ({ = ({ const wordCountText = useMemo(() => { const total = formatNumber(word_count) return `${total} ${t('datasetDocuments.segment.characters', { count: word_count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [word_count]) + }, [word_count, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) if (loading) return diff --git a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx index cea3402499..d3575c18ed 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx @@ -86,8 +86,7 @@ const SegmentDetail: FC = ({ const titleText = useMemo(() => { return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode]) + }, [isEditMode, t]) const isQAModel = useMemo(() => { return docForm === ChunkingMode.qa @@ -98,13 +97,11 @@ const SegmentDetail: FC = ({ const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number) const count = isEditMode ? contentLength : segInfo!.word_count as number return `${total} ${t('datasetDocuments.segment.characters', { count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode, question.length, answer.length, segInfo?.word_count, isQAModel]) + }, [isEditMode, question.length, answer.length, isQAModel, segInfo, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) return (
diff --git a/web/app/components/datasets/documents/detail/completed/segment-list.tsx b/web/app/components/datasets/documents/detail/completed/segment-list.tsx index b2351c1b97..f6076e5813 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-list.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-list.tsx @@ -42,7 +42,7 @@ const SegmentList = ( embeddingAvailable, onClearFilter, }: ISegmentListProps & { - ref: React.RefObject; + ref: React.LegacyRef }, ) => { const mode = useDocumentContext(s => s.mode) diff --git a/web/app/components/datasets/documents/detail/index.tsx b/web/app/components/datasets/documents/detail/index.tsx index 2ee37bfe6a..aff74038e3 100644 --- a/web/app/components/datasets/documents/detail/index.tsx +++ b/web/app/components/datasets/documents/detail/index.tsx @@ -277,7 +277,7 @@ const DocumentDetail: FC = ({ datasetId, documentId }) => { }
} - setShowMetadata(false)} isMobile={isMobile} panelClassname='!justify-start' footer={null}> + setShowMetadata(false)} isMobile={isMobile} panelClassName='!justify-start' footer={null}> ) => { return @@ -98,7 +100,7 @@ const Documents: FC = ({ datasetId }) => { const isDataSourceWeb = dataset?.data_source_type === DataSourceType.WEB const isDataSourceFile = dataset?.data_source_type === DataSourceType.FILE const embeddingAvailable = !!dataset?.embedding_available - + const locale = getLocaleOnClient() const debouncedSearchValue = useDebounce(searchValue, { wait: 500 }) const { data: documentsRes, isFetching: isListLoading } = useDocumentList({ @@ -260,7 +262,12 @@ const Documents: FC = ({ datasetId }) => { + href={ + locale === LanguagesSupported[1] + ? 'https://docs.dify.ai/v/zh-hans/guides/knowledge-base/integrate-knowledge-within-application' + : 'https://docs.dify.ai/guides/knowledge-base/integrate-knowledge-within-application' + } + > {t('datasetDocuments.list.learnMore')} diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 8ed878fe56..cb349ee01c 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -202,7 +202,7 @@ export const OperationAction: FC<{ const isListScene = scene === 'list' const onOperate = async (operationName: OperationName) => { - let opApi = deleteDocument + let opApi switch (operationName) { case 'archive': opApi = archiveDocument @@ -490,7 +490,7 @@ const DocumentList: FC = ({ const handleAction = (actionName: DocumentActionType) => { return async () => { - let opApi = deleteDocument + let opApi switch (actionName) { case DocumentActionType.archive: opApi = archiveDocument @@ -527,7 +527,7 @@ const DocumentList: FC = ({ )} diff --git a/web/app/components/datasets/hit-testing/index.tsx b/web/app/components/datasets/hit-testing/index.tsx index b74cb8bc83..fef69a5e61 100644 --- a/web/app/components/datasets/hit-testing/index.tsx +++ b/web/app/components/datasets/hit-testing/index.tsx @@ -176,7 +176,7 @@ const HitTestingPage: FC = ({ datasetId }: Props) => { )}
- +
{/* {renderHitResults(generalResultData)} */} {submitLoading @@ -197,7 +197,7 @@ const HitTestingPage: FC = ({ datasetId }: Props) => { }
- setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassname='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> + setIsShowModifyRetrievalModal(false)} footer={null} mask={isMobile} panelClassName='mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl'> = ({ className={cn(className, 'rounded-l-md')} value={value} onChange={onChange} - size='sm' + size='regular' controlWrapClassName='overflow-hidden' controlClassName='pt-0 pb-0' readOnly={readOnly} diff --git a/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx b/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx index 81c5b641c5..b5e4d1765b 100644 --- a/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx +++ b/web/app/components/datasets/metadata/metadata-dataset/dataset-metadata-drawer.tsx @@ -173,7 +173,7 @@ const DatasetMetadataDrawer: FC = ({ showClose title={t('dataset.metadata.metadata')} footer={null} - panelClassname='px-4 block !max-w-[420px] my-2 rounded-l-2xl' + panelClassName='px-4 block !max-w-[420px] my-2 rounded-l-2xl' >
{t(`${i18nPrefix}.description`)}
diff --git a/web/app/components/datasets/settings/permission-selector/index.tsx b/web/app/components/datasets/settings/permission-selector/index.tsx index 71a46087af..9bb6f812d4 100644 --- a/web/app/components/datasets/settings/permission-selector/index.tsx +++ b/web/app/components/datasets/settings/permission-selector/index.tsx @@ -150,8 +150,8 @@ const PermissionSelector = ({ disabled, permission, value, memberList, onChange,
{isPartialMembers && ( -
-
+
+
嵌入模型的提供商和模型名称可以通过以下接口获取:v1/workspaces/current/models/model-types/text-embedding, 具体见:通过 API 维护知识库。 使用的Authorization是Dataset的API Token。 + 该接口是异步执行,所以会返回一个job_id,通过查询job状态接口可以获取到最终的执行结果。 diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index 42eaf4f7b2..7135cf6188 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -523,7 +523,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested'?user=abc-123 \ + curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123' \ --header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \ --header 'Content-Type: application/json' \ ``` @@ -967,7 +967,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "user": "abc-123" }' ``` - + @@ -1191,10 +1191,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="GET" label="/apps/annotations" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`} + targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`} > ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \ + curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1245,10 +1245,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="POST" label="/apps/annotations" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} > ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \ + curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1301,10 +1301,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="PUT" label="/apps/annotations/{annotation_id}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`} > ```bash {{ title: 'cURL' }} - curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \ + curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \ --header 'Authorization: Bearer {api_key}' \ --header 'Content-Type: application/json' \ --data-raw '{ @@ -1351,10 +1351,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="PUT" label="/apps/annotations/{annotation_id}" - targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`} + targetCode={`curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`} > ```bash {{ title: 'cURL' }} - curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \ + curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \ --header 'Authorization: Bearer {api_key}' ``` @@ -1398,7 +1398,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="POST" label="/apps/annotation-reply/{action}" - targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`} + targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`} > ```bash {{ title: 'cURL' }} curl --location --request POST 'https://api.dify.ai/v1/apps/annotation-reply/{action}' \ @@ -1448,10 +1448,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' title="Request" tag="GET" label="/apps/annotations" - targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`} + targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`} > ```bash {{ title: 'cURL' }} - curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \ + curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \ --header 'Authorization: Bearer {api_key}' ``` diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.tsx index a9a886376a..da3f8bae6d 100644 --- a/web/app/components/header/account-dropdown/workplace-selector/index.tsx +++ b/web/app/components/header/account-dropdown/workplace-selector/index.tsx @@ -42,7 +42,7 @@ const WorkplaceSelector = () => { `, )}>
- {currentWorkspace?.name[0]?.toLocaleUpperCase()} + {currentWorkspace?.name[0]?.toLocaleUpperCase()}
{currentWorkspace?.name}
@@ -73,7 +73,7 @@ const WorkplaceSelector = () => { workspaces.map(workspace => (
handleSwitchWorkspace(workspace.id)}>
- {workspace?.name[0]?.toLocaleUpperCase()} + {workspace?.name[0]?.toLocaleUpperCase()}
{workspace.name}
diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 39e229cd54..12dd9b3b5b 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -60,6 +60,7 @@ export enum ModelFeatureEnum { video = 'video', document = 'document', audio = 'audio', + StructuredOutput = 'structured-output', } export enum ModelFeatureTextEnum { diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index 025cb87dc1..9019051989 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -23,9 +23,9 @@ const ModelIcon: FC = ({ isDeprecated = false, }) => { const language = useLanguage() - if (provider?.provider.includes('openai') && modelName?.includes('gpt-4o')) + if (provider?.provider && ['openai', 'langgenius/openai/openai'].includes(provider.provider) && modelName?.includes('gpt-4o')) return
- if (provider?.provider.includes('openai') && modelName?.startsWith('gpt-4')) + if (provider?.provider && ['openai', 'langgenius/openai/openai'].includes(provider.provider) && modelName?.startsWith('gpt-4')) return
if (provider?.icon_small) { diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index 28001bef5e..c5af4ed8a1 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -376,6 +376,7 @@ function Form< tooltip={tooltip?.[language] || tooltip?.en_US} value={value[variable] || []} onChange={item => handleFormChange(variable, item as any)} + supportCollapse /> {fieldMoreInfo?.(formSchema)} {validating && changeKey === variable && } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 4bb3cbf7d5..3e969d708b 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -10,6 +10,7 @@ import Slider from '@/app/components/base/slider' import Radio from '@/app/components/base/radio' import { SimpleSelect } from '@/app/components/base/select' import TagInput from '@/app/components/base/tag-input' +import { useTranslation } from 'react-i18next' export type ParameterValue = number | string | string[] | boolean | undefined @@ -27,6 +28,7 @@ const ParameterItem: FC = ({ onSwitch, isInWorkflow, }) => { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx index 6a336fb6f7..63849bddda 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx @@ -74,7 +74,7 @@ const Popup: FC = ({ /> setSearchText(e.target.value)} /> diff --git a/web/app/components/plugins/marketplace/list/list-with-collection.tsx b/web/app/components/plugins/marketplace/list/list-with-collection.tsx index e18356cd85..4c396c565f 100644 --- a/web/app/components/plugins/marketplace/list/list-with-collection.tsx +++ b/web/app/components/plugins/marketplace/list/list-with-collection.tsx @@ -32,7 +32,9 @@ const ListWithCollection = ({ return ( <> { - marketplaceCollections.map(collection => ( + marketplaceCollections.filter((collection) => { + return marketplaceCollectionPluginsMap[collection.name]?.length + }).map(collection => (
= ({ footer={null} mask positionCenter={false} - panelClassname={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} + panelClassName={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} > <>
diff --git a/web/app/components/plugins/plugin-detail-panel/index.tsx b/web/app/components/plugins/plugin-detail-panel/index.tsx index 70bd9edabc..3ec867faae 100644 --- a/web/app/components/plugins/plugin-detail-panel/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/index.tsx @@ -38,7 +38,7 @@ const PluginDetailPanel: FC = ({ footer={null} mask={false} positionCenter={false} - panelClassname={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} + panelClassName={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} > {detail && ( <> diff --git a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx index fc29feaefc..f243d30aff 100644 --- a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx @@ -2,7 +2,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { RiAddLine, - RiArrowDropDownLine, RiQuestionLine, } from '@remixicon/react' import ToolSelector from '@/app/components/plugins/plugin-detail-panel/tool-selector' @@ -13,6 +12,7 @@ import type { ToolValue } from '@/app/components/workflow/block-selector/types' import type { Node } from 'reactflow' import type { NodeOutPutVar } from '@/app/components/workflow/types' import cn from '@/utils/classnames' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' type Props = { disabled?: boolean @@ -98,14 +98,12 @@ const MultipleToolSelector = ({ )} {supportCollapse && ( -
- -
+ )}
{value.length > 0 && ( diff --git a/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx b/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx index 89ee850e03..00794d83ed 100644 --- a/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx +++ b/web/app/components/plugins/plugin-detail-panel/strategy-detail.tsx @@ -78,7 +78,7 @@ const StrategyDetail: FC = ({ footer={null} mask={false} positionCenter={false} - panelClassname={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} + panelClassName={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} > <> {/* header */} diff --git a/web/app/components/tools/add-tool-modal/index.tsx b/web/app/components/tools/add-tool-modal/index.tsx index 1129fe55ce..c45313fc09 100644 --- a/web/app/components/tools/add-tool-modal/index.tsx +++ b/web/app/components/tools/add-tool-modal/index.tsx @@ -178,7 +178,7 @@ const AddToolModal: FC = ({ clickOutsideNotOpen onClose={onHide} footer={null} - panelClassname={cn('mx-2 mb-3 mt-16 rounded-xl !p-0 sm:mr-2', 'mt-2 !w-[640px]', '!max-w-[640px]')} + panelClassName={cn('mx-2 mb-3 mt-16 rounded-xl !p-0 sm:mr-2', 'mt-2 !w-[640px]', '!max-w-[640px]')} >
= ({ positionCenter={positionCenter} onHide={onHide} title={t('tools.createTool.authMethod.title')!} - panelClassName='mt-2 !w-[520px] h-fit' + dialogClassName='z-[60]' + dialogBackdropClassName='z-[70]' + panelClassName='mt-2 !w-[520px] h-fit z-[80]' maxWidthClassName='!max-w-[520px]' height={'fit-content'} headerClassName='!border-b-divider-regular' diff --git a/web/app/components/tools/provider/detail.tsx b/web/app/components/tools/provider/detail.tsx index 5d3a1794d8..21ea8bc464 100644 --- a/web/app/components/tools/provider/detail.tsx +++ b/web/app/components/tools/provider/detail.tsx @@ -234,7 +234,7 @@ const ProviderDetail = ({ footer={null} mask={false} positionCenter={false} - panelClassname={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} + panelClassName={cn('mb-2 mr-2 mt-[64px] !w-[420px] !max-w-[420px] justify-start rounded-2xl border-[0.5px] border-components-panel-border !bg-components-panel-bg !p-0 shadow-xl')} >
diff --git a/web/app/components/workflow-app/components/workflow-children.tsx b/web/app/components/workflow-app/components/workflow-children.tsx new file mode 100644 index 0000000000..6a6bbcd61a --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-children.tsx @@ -0,0 +1,69 @@ +import { + memo, + useState, +} from 'react' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { DSL_EXPORT_CHECK } from '@/app/components/workflow/constants' +import { useStore } from '@/app/components/workflow/store' +import Features from '@/app/components/workflow/features' +import PluginDependency from '@/app/components/workflow/plugin-dependency' +import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal' +import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal' +import { + useDSL, + usePanelInteractions, +} from '@/app/components/workflow/hooks' +import { useEventEmitterContextContext } from '@/context/event-emitter' +import WorkflowHeader from './workflow-header' +import WorkflowPanel from './workflow-panel' + +const WorkflowChildren = () => { + const { eventEmitter } = useEventEmitterContextContext() + const [secretEnvList, setSecretEnvList] = useState([]) + const showFeaturesPanel = useStore(s => s.showFeaturesPanel) + const showImportDSLModal = useStore(s => s.showImportDSLModal) + const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal) + const { + handlePaneContextmenuCancel, + } = usePanelInteractions() + const { + exportCheck, + handleExportDSL, + } = useDSL() + + eventEmitter?.useSubscription((v: any) => { + if (v.type === DSL_EXPORT_CHECK) + setSecretEnvList(v.payload.data as EnvironmentVariable[]) + }) + + return ( + <> + + { + showFeaturesPanel && + } + { + showImportDSLModal && ( + setShowImportDSLModal(false)} + onBackup={exportCheck} + onImport={handlePaneContextmenuCancel} + /> + ) + } + { + secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + ) + } + + + + ) +} + +export default memo(WorkflowChildren) diff --git a/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx new file mode 100644 index 0000000000..df93914285 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/chat-variable-trigger.tsx @@ -0,0 +1,11 @@ +import { memo } from 'react' +import ChatVariableButton from '@/app/components/workflow/header/chat-variable-button' +import { + useNodesReadOnly, +} from '@/app/components/workflow/hooks' + +const ChatVariableTrigger = () => { + const { nodesReadOnly } = useNodesReadOnly() + return +} +export default memo(ChatVariableTrigger) diff --git a/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx new file mode 100644 index 0000000000..da64409090 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/features-trigger.tsx @@ -0,0 +1,152 @@ +import { + memo, + useCallback, + useMemo, +} from 'react' +import { useNodes } from 'reactflow' +import { RiApps2AddLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + useStore, + useWorkflowStore, +} from '@/app/components/workflow/store' +import { + useChecklistBeforePublish, + useNodesReadOnly, + useNodesSyncDraft, +} from '@/app/components/workflow/hooks' +import Button from '@/app/components/base/button' +import AppPublisher from '@/app/components/app/app-publisher' +import { useFeatures } from '@/app/components/base/features/hooks' +import { + BlockEnum, + InputVarType, +} from '@/app/components/workflow/types' +import type { StartNodeType } from '@/app/components/workflow/nodes/start/types' +import { useToastContext } from '@/app/components/base/toast' +import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow' +import type { PublishWorkflowParams } from '@/types/workflow' +import { fetchAppDetail, fetchAppSSO } from '@/service/apps' +import { useStore as useAppStore } from '@/app/components/app/store' +import { useSelector as useAppSelector } from '@/context/app-context' + +const FeaturesTrigger = () => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + const appDetail = useAppStore(s => s.appDetail) + const appID = appDetail?.id + const setAppDetail = useAppStore(s => s.setAppDetail) + const systemFeatures = useAppSelector(state => state.systemFeatures) + const { + nodesReadOnly, + getNodesReadOnly, + } = useNodesReadOnly() + const publishedAt = useStore(s => s.publishedAt) + const draftUpdatedAt = useStore(s => s.draftUpdatedAt) + const toolPublished = useStore(s => s.toolPublished) + const nodes = useNodes() + const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const startVariables = startNode?.data.variables + const fileSettings = useFeatures(s => s.features.file) + const variables = useMemo(() => { + const data = startVariables || [] + if (fileSettings?.image?.enabled) { + return [ + ...data, + { + type: InputVarType.files, + variable: '__image', + required: false, + label: 'files', + }, + ] + } + + return data + }, [fileSettings?.image?.enabled, startVariables]) + + const { handleCheckBeforePublish } = useChecklistBeforePublish() + const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const { notify } = useToastContext() + + const handleShowFeatures = useCallback(() => { + const { + showFeaturesPanel, + isRestoring, + setShowFeaturesPanel, + } = workflowStore.getState() + if (getNodesReadOnly() && !isRestoring) + return + setShowFeaturesPanel(!showFeaturesPanel) + }, [workflowStore, getNodesReadOnly]) + + const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) + + const updateAppDetail = useCallback(async () => { + try { + const res = await fetchAppDetail({ url: '/apps', id: appID! }) + if (systemFeatures.enable_web_sso_switch_component) { + const ssoRes = await fetchAppSSO({ appId: appID! }) + setAppDetail({ ...res, enable_sso: ssoRes.enabled }) + } + else { + setAppDetail({ ...res }) + } + } + catch (error) { + console.error(error) + } + }, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component]) + const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!) + const onPublish = useCallback(async (params?: PublishWorkflowParams) => { + if (await handleCheckBeforePublish()) { + const res = await publishWorkflow({ + title: params?.title || '', + releaseNotes: params?.releaseNotes || '', + }) + + if (res) { + notify({ type: 'success', message: t('common.api.actionSuccess') }) + updateAppDetail() + workflowStore.getState().setPublishedAt(res.created_at) + resetWorkflowVersionHistory() + } + } + else { + throw new Error('Checklist failed') + } + }, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail]) + + const onPublisherToggle = useCallback((state: boolean) => { + if (state) + handleSyncWorkflowDraft(true) + }, [handleSyncWorkflowDraft]) + + const handleToolConfigureUpdate = useCallback(() => { + workflowStore.setState({ toolPublished: true }) + }, [workflowStore]) + + return ( + <> + + + + ) +} + +export default memo(FeaturesTrigger) diff --git a/web/app/components/workflow-app/components/workflow-header/index.tsx b/web/app/components/workflow-app/components/workflow-header/index.tsx new file mode 100644 index 0000000000..4eb8df7162 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-header/index.tsx @@ -0,0 +1,31 @@ +import { useMemo } from 'react' +import type { HeaderProps } from '@/app/components/workflow/header' +import Header from '@/app/components/workflow/header' +import { useStore as useAppStore } from '@/app/components/app/store' +import ChatVariableTrigger from './chat-variable-trigger' +import FeaturesTrigger from './features-trigger' +import { useResetWorkflowVersionHistory } from '@/service/use-workflow' + +const WorkflowHeader = () => { + const appDetail = useAppStore(s => s.appDetail) + const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) + + const headerProps: HeaderProps = useMemo(() => { + return { + normal: { + components: { + left: , + middle: , + }, + }, + restoring: { + onRestoreSettled: resetWorkflowVersionHistory, + }, + } + }, [resetWorkflowVersionHistory]) + return ( +
+ ) +} + +export default WorkflowHeader diff --git a/web/app/components/workflow-app/components/workflow-main.tsx b/web/app/components/workflow-app/components/workflow-main.tsx new file mode 100644 index 0000000000..4ff1f4c624 --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-main.tsx @@ -0,0 +1,87 @@ +import { + useCallback, + useMemo, +} from 'react' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { WorkflowWithInnerContext } from '@/app/components/workflow' +import type { WorkflowProps } from '@/app/components/workflow' +import WorkflowChildren from './workflow-children' +import { + useNodesSyncDraft, + useWorkflowRun, + useWorkflowStartRun, +} from '../hooks' + +type WorkflowMainProps = Pick +const WorkflowMain = ({ + nodes, + edges, + viewport, +}: WorkflowMainProps) => { + const featuresStore = useFeaturesStore() + + const handleWorkflowDataUpdate = useCallback((payload: any) => { + if (payload.features && featuresStore) { + const { setFeatures } = featuresStore.getState() + + setFeatures(payload.features) + } + }, [featuresStore]) + + const { + doSyncWorkflowDraft, + syncWorkflowDraftWhenPageClose, + } = useNodesSyncDraft() + const { + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + } = useWorkflowRun() + const { + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + } = useWorkflowStartRun() + + const hooksStore = useMemo(() => { + return { + syncWorkflowDraftWhenPageClose, + doSyncWorkflowDraft, + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + } + }, [ + syncWorkflowDraftWhenPageClose, + doSyncWorkflowDraft, + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + handleStartWorkflowRun, + handleWorkflowStartRunInChatflow, + handleWorkflowStartRunInWorkflow, + ]) + + return ( + + + + ) +} + +export default WorkflowMain diff --git a/web/app/components/workflow-app/components/workflow-panel.tsx b/web/app/components/workflow-app/components/workflow-panel.tsx new file mode 100644 index 0000000000..3c1b5c8aac --- /dev/null +++ b/web/app/components/workflow-app/components/workflow-panel.tsx @@ -0,0 +1,109 @@ +import { useMemo } from 'react' +import { useShallow } from 'zustand/react/shallow' +import { useStore } from '@/app/components/workflow/store' +import { + useIsChatMode, +} from '../hooks' +import DebugAndPreview from '@/app/components/workflow/panel/debug-and-preview' +import Record from '@/app/components/workflow/panel/record' +import WorkflowPreview from '@/app/components/workflow/panel/workflow-preview' +import ChatRecord from '@/app/components/workflow/panel/chat-record' +import ChatVariablePanel from '@/app/components/workflow/panel/chat-variable-panel' +import GlobalVariablePanel from '@/app/components/workflow/panel/global-variable-panel' +import VersionHistoryPanel from '@/app/components/workflow/panel/version-history-panel' +import { useStore as useAppStore } from '@/app/components/app/store' +import MessageLogModal from '@/app/components/base/message-log-modal' +import type { PanelProps } from '@/app/components/workflow/panel' +import Panel from '@/app/components/workflow/panel' + +const WorkflowPanelOnLeft = () => { + const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + currentLogItem: state.currentLogItem, + setCurrentLogItem: state.setCurrentLogItem, + showMessageLogModal: state.showMessageLogModal, + setShowMessageLogModal: state.setShowMessageLogModal, + currentLogModalActiveTab: state.currentLogModalActiveTab, + }))) + return ( + <> + { + showMessageLogModal && ( + { + setCurrentLogItem() + setShowMessageLogModal(false) + }} + defaultTab={currentLogModalActiveTab} + /> + ) + } + + ) +} +const WorkflowPanelOnRight = () => { + const isChatMode = useIsChatMode() + const historyWorkflowData = useStore(s => s.historyWorkflowData) + const showDebugAndPreviewPanel = useStore(s => s.showDebugAndPreviewPanel) + const showChatVariablePanel = useStore(s => s.showChatVariablePanel) + const showGlobalVariablePanel = useStore(s => s.showGlobalVariablePanel) + const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel) + + return ( + <> + { + historyWorkflowData && !isChatMode && ( + + ) + } + { + historyWorkflowData && isChatMode && ( + + ) + } + { + showDebugAndPreviewPanel && isChatMode && ( + + ) + } + { + showDebugAndPreviewPanel && !isChatMode && ( + + ) + } + { + showChatVariablePanel && ( + + ) + } + { + showGlobalVariablePanel && ( + + ) + } + { + showWorkflowVersionHistoryPanel && ( + + ) + } + + ) +} +const WorkflowPanel = () => { + const panelProps: PanelProps = useMemo(() => { + return { + components: { + left: , + right: , + }, + } + }, []) + + return ( + + ) +} + +export default WorkflowPanel diff --git a/web/app/components/workflow-app/hooks/index.ts b/web/app/components/workflow-app/hooks/index.ts new file mode 100644 index 0000000000..1517eb9a16 --- /dev/null +++ b/web/app/components/workflow-app/hooks/index.ts @@ -0,0 +1,6 @@ +export * from './use-workflow-init' +export * from './use-workflow-template' +export * from './use-nodes-sync-draft' +export * from './use-workflow-run' +export * from './use-workflow-start-run' +export * from './use-is-chat-mode' diff --git a/web/app/components/workflow-app/hooks/use-is-chat-mode.ts b/web/app/components/workflow-app/hooks/use-is-chat-mode.ts new file mode 100644 index 0000000000..3cdfc77b2a --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-is-chat-mode.ts @@ -0,0 +1,7 @@ +import { useStore as useAppStore } from '@/app/components/app/store' + +export const useIsChatMode = () => { + const appDetail = useAppStore(s => s.appDetail) + + return appDetail?.mode === 'advanced-chat' +} diff --git a/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts new file mode 100644 index 0000000000..7c6eb6a5be --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-nodes-sync-draft.ts @@ -0,0 +1,148 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' +import { useParams } from 'next/navigation' +import { + useWorkflowStore, +} from '@/app/components/workflow/store' +import { BlockEnum } from '@/app/components/workflow/types' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks' +import { + useNodesReadOnly, +} from '@/app/components/workflow/hooks/use-workflow' +import { syncWorkflowDraft } from '@/service/workflow' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { API_PREFIX } from '@/config' + +export const useNodesSyncDraft = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const featuresStore = useFeaturesStore() + const { getNodesReadOnly } = useNodesReadOnly() + const { handleRefreshWorkflowDraft } = useWorkflowUpdate() + const params = useParams() + + const getPostParams = useCallback(() => { + const { + getNodes, + edges, + transform, + } = store.getState() + const [x, y, zoom] = transform + const { + appId, + conversationVariables, + environmentVariables, + syncWorkflowDraftHash, + } = workflowStore.getState() + + if (appId) { + const nodes = getNodes() + const hasStartNode = nodes.find(node => node.data.type === BlockEnum.Start) + + if (!hasStartNode) + return + + const features = featuresStore!.getState().features + const producedNodes = produce(nodes, (draft) => { + draft.forEach((node) => { + Object.keys(node.data).forEach((key) => { + if (key.startsWith('_')) + delete node.data[key] + }) + }) + }) + const producedEdges = produce(edges, (draft) => { + draft.forEach((edge) => { + Object.keys(edge.data).forEach((key) => { + if (key.startsWith('_')) + delete edge.data[key] + }) + }) + }) + return { + url: `/apps/${appId}/workflows/draft`, + params: { + graph: { + nodes: producedNodes, + edges: producedEdges, + viewport: { + x, + y, + zoom, + }, + }, + features: { + opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', + suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], + suggested_questions_after_answer: features.suggested, + text_to_speech: features.text2speech, + speech_to_text: features.speech2text, + retriever_resource: features.citation, + sensitive_word_avoidance: features.moderation, + file_upload: features.file, + }, + environment_variables: environmentVariables, + conversation_variables: conversationVariables, + hash: syncWorkflowDraftHash, + }, + } + } + }, [store, featuresStore, workflowStore]) + + const syncWorkflowDraftWhenPageClose = useCallback(() => { + if (getNodesReadOnly()) + return + const postParams = getPostParams() + + if (postParams) { + navigator.sendBeacon( + `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, + JSON.stringify(postParams.params), + ) + } + }, [getPostParams, params.appId, getNodesReadOnly]) + + const doSyncWorkflowDraft = useCallback(async ( + notRefreshWhenSyncError?: boolean, + callback?: { + onSuccess?: () => void + onError?: () => void + onSettled?: () => void + }, + ) => { + if (getNodesReadOnly()) + return + const postParams = getPostParams() + + if (postParams) { + const { + setSyncWorkflowDraftHash, + setDraftUpdatedAt, + } = workflowStore.getState() + try { + const res = await syncWorkflowDraft(postParams) + setSyncWorkflowDraftHash(res.hash) + setDraftUpdatedAt(res.updated_at) + callback?.onSuccess && callback.onSuccess() + } + catch (error: any) { + if (error && error.json && !error.bodyUsed) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) + handleRefreshWorkflowDraft() + }) + } + callback?.onError && callback.onError() + } + finally { + callback?.onSettled && callback.onSettled() + } + } + }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + + return { + doSyncWorkflowDraft, + syncWorkflowDraftWhenPageClose, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-init.ts b/web/app/components/workflow-app/hooks/use-workflow-init.ts new file mode 100644 index 0000000000..e1c4c25a4e --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-workflow-init.ts @@ -0,0 +1,123 @@ +import { + useCallback, + useEffect, + useState, +} from 'react' +import { + useStore, + useWorkflowStore, +} from '@/app/components/workflow/store' +import { useWorkflowTemplate } from './use-workflow-template' +import { useStore as useAppStore } from '@/app/components/app/store' +import { + fetchNodesDefaultConfigs, + fetchPublishedWorkflow, + fetchWorkflowDraft, + syncWorkflowDraft, +} from '@/service/workflow' +import type { FetchWorkflowDraftResponse } from '@/types/workflow' +import { useWorkflowConfig } from '@/service/use-workflow' + +export const useWorkflowInit = () => { + const workflowStore = useWorkflowStore() + const { + nodes: nodesTemplate, + edges: edgesTemplate, + } = useWorkflowTemplate() + const appDetail = useAppStore(state => state.appDetail)! + const setSyncWorkflowDraftHash = useStore(s => s.setSyncWorkflowDraftHash) + const [data, setData] = useState() + const [isLoading, setIsLoading] = useState(true) + useEffect(() => { + workflowStore.setState({ appId: appDetail.id }) + }, [appDetail.id, workflowStore]) + + const handleUpdateWorkflowConfig = useCallback((config: Record) => { + const { setWorkflowConfig } = workflowStore.getState() + + setWorkflowConfig(config) + }, [workflowStore]) + useWorkflowConfig(appDetail.id, handleUpdateWorkflowConfig) + + const handleGetInitialWorkflowData = useCallback(async () => { + try { + const res = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) + setData(res) + workflowStore.setState({ + envSecrets: (res.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { + acc[env.id] = env.value + return acc + }, {} as Record), + environmentVariables: res.environment_variables?.map(env => env.value_type === 'secret' ? { ...env, value: '[__HIDDEN__]' } : env) || [], + conversationVariables: res.conversation_variables || [], + }) + setSyncWorkflowDraftHash(res.hash) + setIsLoading(false) + } + catch (error: any) { + if (error && error.json && !error.bodyUsed && appDetail) { + error.json().then((err: any) => { + if (err.code === 'draft_workflow_not_exist') { + workflowStore.setState({ notInitialWorkflow: true }) + syncWorkflowDraft({ + url: `/apps/${appDetail.id}/workflows/draft`, + params: { + graph: { + nodes: nodesTemplate, + edges: edgesTemplate, + }, + features: { + retriever_resource: { enabled: true }, + }, + environment_variables: [], + conversation_variables: [], + }, + }).then((res) => { + workflowStore.getState().setDraftUpdatedAt(res.updated_at) + handleGetInitialWorkflowData() + }) + } + }) + } + } + }, [appDetail, nodesTemplate, edgesTemplate, workflowStore, setSyncWorkflowDraftHash]) + + useEffect(() => { + handleGetInitialWorkflowData() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const handleFetchPreloadData = useCallback(async () => { + try { + const nodesDefaultConfigsData = await fetchNodesDefaultConfigs(`/apps/${appDetail?.id}/workflows/default-workflow-block-configs`) + const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`) + workflowStore.setState({ + nodesDefaultConfigs: nodesDefaultConfigsData.reduce((acc, block) => { + if (!acc[block.type]) + acc[block.type] = { ...block.config } + return acc + }, {} as Record), + }) + workflowStore.getState().setPublishedAt(publishedWorkflow?.created_at) + } + catch (e) { + console.error(e) + } + }, [workflowStore, appDetail]) + + useEffect(() => { + handleFetchPreloadData() + }, [handleFetchPreloadData]) + + useEffect(() => { + if (data) { + workflowStore.getState().setDraftUpdatedAt(data.updated_at) + workflowStore.getState().setToolPublished(data.tool_published) + } + }, [data, workflowStore]) + + return { + data, + isLoading, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-run.ts b/web/app/components/workflow-app/hooks/use-workflow-run.ts new file mode 100644 index 0000000000..1e484d0760 --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-workflow-run.ts @@ -0,0 +1,357 @@ +import { useCallback } from 'react' +import { + useReactFlow, + useStoreApi, +} from 'reactflow' +import produce from 'immer' +import { v4 as uuidV4 } from 'uuid' +import { usePathname } from 'next/navigation' +import { useWorkflowStore } from '@/app/components/workflow/store' +import { WorkflowRunningStatus } from '@/app/components/workflow/types' +import { useWorkflowUpdate } from '@/app/components/workflow/hooks/use-workflow-interactions' +import { useWorkflowRunEvent } from '@/app/components/workflow/hooks/use-workflow-run-event/use-workflow-run-event' +import { useStore as useAppStore } from '@/app/components/app/store' +import type { IOtherOptions } from '@/service/base' +import { ssePost } from '@/service/base' +import { stopWorkflowRun } from '@/service/workflow' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' +import type { VersionHistory } from '@/types/workflow' +import { noop } from 'lodash-es' +import { useNodesSyncDraft } from './use-nodes-sync-draft' + +export const useWorkflowRun = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const reactflow = useReactFlow() + const featuresStore = useFeaturesStore() + const { doSyncWorkflowDraft } = useNodesSyncDraft() + const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() + const pathname = usePathname() + + const { + handleWorkflowStarted, + handleWorkflowFinished, + handleWorkflowFailed, + handleWorkflowNodeStarted, + handleWorkflowNodeFinished, + handleWorkflowNodeIterationStarted, + handleWorkflowNodeIterationNext, + handleWorkflowNodeIterationFinished, + handleWorkflowNodeLoopStarted, + handleWorkflowNodeLoopNext, + handleWorkflowNodeLoopFinished, + handleWorkflowNodeRetry, + handleWorkflowAgentLog, + handleWorkflowTextChunk, + handleWorkflowTextReplace, + } = useWorkflowRunEvent() + + const handleBackupDraft = useCallback(() => { + const { + getNodes, + edges, + } = store.getState() + const { getViewport } = reactflow + const { + backupDraft, + setBackupDraft, + environmentVariables, + } = workflowStore.getState() + const { features } = featuresStore!.getState() + + if (!backupDraft) { + setBackupDraft({ + nodes: getNodes(), + edges, + viewport: getViewport(), + features, + environmentVariables, + }) + doSyncWorkflowDraft() + } + }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft]) + + const handleLoadBackupDraft = useCallback(() => { + const { + backupDraft, + setBackupDraft, + setEnvironmentVariables, + } = workflowStore.getState() + + if (backupDraft) { + const { + nodes, + edges, + viewport, + features, + environmentVariables, + } = backupDraft + handleUpdateWorkflowCanvas({ + nodes, + edges, + viewport, + }) + setEnvironmentVariables(environmentVariables) + featuresStore!.setState({ features }) + setBackupDraft(undefined) + } + }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore]) + + const handleRun = useCallback(async ( + params: any, + callback?: IOtherOptions, + ) => { + const { + getNodes, + setNodes, + } = store.getState() + const newNodes = produce(getNodes(), (draft) => { + draft.forEach((node) => { + node.data.selected = false + node.data._runningStatus = undefined + }) + }) + setNodes(newNodes) + await doSyncWorkflowDraft() + + const { + onWorkflowStarted, + onWorkflowFinished, + onNodeStarted, + onNodeFinished, + onIterationStart, + onIterationNext, + onIterationFinish, + onLoopStart, + onLoopNext, + onLoopFinish, + onNodeRetry, + onAgentLog, + onError, + ...restCallback + } = callback || {} + workflowStore.setState({ historyWorkflowData: undefined }) + const appDetail = useAppStore.getState().appDetail + const workflowContainer = document.getElementById('workflow-container') + + const { + clientWidth, + clientHeight, + } = workflowContainer! + + let url = '' + if (appDetail?.mode === 'advanced-chat') + url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run` + + if (appDetail?.mode === 'workflow') + url = `/apps/${appDetail.id}/workflows/draft/run` + + const { + setWorkflowRunningData, + } = workflowStore.getState() + setWorkflowRunningData({ + result: { + status: WorkflowRunningStatus.Running, + }, + tracing: [], + resultText: '', + }) + + let ttsUrl = '' + let ttsIsPublic = false + if (params.token) { + ttsUrl = '/text-to-audio' + ttsIsPublic = true + } + else if (params.appId) { + if (pathname.search('explore/installed') > -1) + ttsUrl = `/installed-apps/${params.appId}/text-to-audio` + else + ttsUrl = `/apps/${params.appId}/text-to-audio` + } + const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', noop) + + ssePost( + url, + { + body: params, + }, + { + onWorkflowStarted: (params) => { + handleWorkflowStarted(params) + + if (onWorkflowStarted) + onWorkflowStarted(params) + }, + onWorkflowFinished: (params) => { + handleWorkflowFinished(params) + + if (onWorkflowFinished) + onWorkflowFinished(params) + }, + onError: (params) => { + handleWorkflowFailed() + + if (onError) + onError(params) + }, + onNodeStarted: (params) => { + handleWorkflowNodeStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onNodeStarted) + onNodeStarted(params) + }, + onNodeFinished: (params) => { + handleWorkflowNodeFinished(params) + + if (onNodeFinished) + onNodeFinished(params) + }, + onIterationStart: (params) => { + handleWorkflowNodeIterationStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onIterationStart) + onIterationStart(params) + }, + onIterationNext: (params) => { + handleWorkflowNodeIterationNext(params) + + if (onIterationNext) + onIterationNext(params) + }, + onIterationFinish: (params) => { + handleWorkflowNodeIterationFinished(params) + + if (onIterationFinish) + onIterationFinish(params) + }, + onLoopStart: (params) => { + handleWorkflowNodeLoopStarted( + params, + { + clientWidth, + clientHeight, + }, + ) + + if (onLoopStart) + onLoopStart(params) + }, + onLoopNext: (params) => { + handleWorkflowNodeLoopNext(params) + + if (onLoopNext) + onLoopNext(params) + }, + onLoopFinish: (params) => { + handleWorkflowNodeLoopFinished(params) + + if (onLoopFinish) + onLoopFinish(params) + }, + onNodeRetry: (params) => { + handleWorkflowNodeRetry(params) + + if (onNodeRetry) + onNodeRetry(params) + }, + onAgentLog: (params) => { + handleWorkflowAgentLog(params) + + if (onAgentLog) + onAgentLog(params) + }, + onTextChunk: (params) => { + handleWorkflowTextChunk(params) + }, + onTextReplace: (params) => { + handleWorkflowTextReplace(params) + }, + onTTSChunk: (messageId: string, audio: string) => { + if (!audio || audio === '') + return + player.playAudioWithAudio(audio, true) + AudioPlayerManager.getInstance().resetMsgId(messageId) + }, + onTTSEnd: (messageId: string, audio: string) => { + player.playAudioWithAudio(audio, false) + }, + ...restCallback, + }, + ) + }, [ + store, + workflowStore, + doSyncWorkflowDraft, + handleWorkflowStarted, + handleWorkflowFinished, + handleWorkflowFailed, + handleWorkflowNodeStarted, + handleWorkflowNodeFinished, + handleWorkflowNodeIterationStarted, + handleWorkflowNodeIterationNext, + handleWorkflowNodeIterationFinished, + handleWorkflowNodeLoopStarted, + handleWorkflowNodeLoopNext, + handleWorkflowNodeLoopFinished, + handleWorkflowNodeRetry, + handleWorkflowTextChunk, + handleWorkflowTextReplace, + handleWorkflowAgentLog, + pathname], + ) + + const handleStopRun = useCallback((taskId: string) => { + const appId = useAppStore.getState().appDetail?.id + + stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`) + }, []) + + const handleRestoreFromPublishedWorkflow = useCallback((publishedWorkflow: VersionHistory) => { + const nodes = publishedWorkflow.graph.nodes.map(node => ({ ...node, selected: false, data: { ...node.data, selected: false } })) + const edges = publishedWorkflow.graph.edges + const viewport = publishedWorkflow.graph.viewport! + handleUpdateWorkflowCanvas({ + nodes, + edges, + viewport, + }) + const mappedFeatures = { + opening: { + enabled: !!publishedWorkflow.features.opening_statement || !!publishedWorkflow.features.suggested_questions.length, + opening_statement: publishedWorkflow.features.opening_statement, + suggested_questions: publishedWorkflow.features.suggested_questions, + }, + suggested: publishedWorkflow.features.suggested_questions_after_answer, + text2speech: publishedWorkflow.features.text_to_speech, + speech2text: publishedWorkflow.features.speech_to_text, + citation: publishedWorkflow.features.retriever_resource, + moderation: publishedWorkflow.features.sensitive_word_avoidance, + file: publishedWorkflow.features.file_upload, + } + + featuresStore?.setState({ features: mappedFeatures }) + workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || []) + }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore]) + + return { + handleBackupDraft, + handleLoadBackupDraft, + handleRun, + handleStopRun, + handleRestoreFromPublishedWorkflow, + } +} diff --git a/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx b/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx new file mode 100644 index 0000000000..3f5ea1c1df --- /dev/null +++ b/web/app/components/workflow-app/hooks/use-workflow-start-run.tsx @@ -0,0 +1,96 @@ +import { useCallback } from 'react' +import { useStoreApi } from 'reactflow' +import { useWorkflowStore } from '@/app/components/workflow/store' +import { + BlockEnum, + WorkflowRunningStatus, +} from '@/app/components/workflow/types' +import { useWorkflowInteractions } from '@/app/components/workflow/hooks' +import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { + useIsChatMode, + useNodesSyncDraft, + useWorkflowRun, +} from '.' + +export const useWorkflowStartRun = () => { + const store = useStoreApi() + const workflowStore = useWorkflowStore() + const featuresStore = useFeaturesStore() + const isChatMode = useIsChatMode() + const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() + const { handleRun } = useWorkflowRun() + const { doSyncWorkflowDraft } = useNodesSyncDraft() + + const handleWorkflowStartRunInWorkflow = useCallback(async () => { + const { + workflowRunningData, + } = workflowStore.getState() + + if (workflowRunningData?.result.status === WorkflowRunningStatus.Running) + return + + const { getNodes } = store.getState() + const nodes = getNodes() + const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const startVariables = startNode?.data.variables || [] + const fileSettings = featuresStore!.getState().features.file + const { + showDebugAndPreviewPanel, + setShowDebugAndPreviewPanel, + setShowInputsPanel, + setShowEnvPanel, + } = workflowStore.getState() + + setShowEnvPanel(false) + + if (showDebugAndPreviewPanel) { + handleCancelDebugAndPreviewPanel() + return + } + + if (!startVariables.length && !fileSettings?.image?.enabled) { + await doSyncWorkflowDraft() + handleRun({ inputs: {}, files: [] }) + setShowDebugAndPreviewPanel(true) + setShowInputsPanel(false) + } + else { + setShowDebugAndPreviewPanel(true) + setShowInputsPanel(true) + } + }, [store, workflowStore, featuresStore, handleCancelDebugAndPreviewPanel, handleRun, doSyncWorkflowDraft]) + + const handleWorkflowStartRunInChatflow = useCallback(async () => { + const { + showDebugAndPreviewPanel, + setShowDebugAndPreviewPanel, + setHistoryWorkflowData, + setShowEnvPanel, + setShowChatVariablePanel, + } = workflowStore.getState() + + setShowEnvPanel(false) + setShowChatVariablePanel(false) + + if (showDebugAndPreviewPanel) + handleCancelDebugAndPreviewPanel() + else + setShowDebugAndPreviewPanel(true) + + setHistoryWorkflowData(undefined) + }, [workflowStore, handleCancelDebugAndPreviewPanel]) + + const handleStartWorkflowRun = useCallback(() => { + if (!isChatMode) + handleWorkflowStartRunInWorkflow() + else + handleWorkflowStartRunInChatflow() + }, [isChatMode, handleWorkflowStartRunInWorkflow, handleWorkflowStartRunInChatflow]) + + return { + handleStartWorkflowRun, + handleWorkflowStartRunInWorkflow, + handleWorkflowStartRunInChatflow, + } +} diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow-app/hooks/use-workflow-template.ts similarity index 87% rename from web/app/components/workflow/hooks/use-workflow-template.ts rename to web/app/components/workflow-app/hooks/use-workflow-template.ts index c2dc956b63..9f47b981dc 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-template.ts @@ -1,10 +1,10 @@ -import { generateNewNode } from '../utils' +import { generateNewNode } from '@/app/components/workflow/utils' import { NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION, -} from '../constants' -import { useIsChatMode } from './use-workflow' -import { useNodesInitialData } from './use-nodes-data' +} from '@/app/components/workflow/constants' +import { useNodesInitialData } from '@/app/components/workflow/hooks' +import { useIsChatMode } from './use-is-chat-mode' export const useWorkflowTemplate = () => { const isChatMode = useIsChatMode() diff --git a/web/app/components/workflow-app/index.tsx b/web/app/components/workflow-app/index.tsx new file mode 100644 index 0000000000..761a7f29c4 --- /dev/null +++ b/web/app/components/workflow-app/index.tsx @@ -0,0 +1,108 @@ +import { + useMemo, +} from 'react' +import useSWR from 'swr' +import { + SupportUploadFileTypes, +} from '@/app/components/workflow/types' +import { + useWorkflowInit, +} from './hooks' +import { + initialEdges, + initialNodes, +} from '@/app/components/workflow/utils' +import Loading from '@/app/components/base/loading' +import { FeaturesProvider } from '@/app/components/base/features' +import type { Features as FeaturesData } from '@/app/components/base/features/types' +import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' +import { fetchFileUploadConfig } from '@/service/common' +import WorkflowWithDefaultContext from '@/app/components/workflow' +import { + WorkflowContextProvider, +} from '@/app/components/workflow/context' +import { createWorkflowSlice } from './store/workflow/workflow-slice' +import WorkflowAppMain from './components/workflow-main' + +const WorkflowAppWithAdditionalContext = () => { + const { + data, + isLoading, + } = useWorkflowInit() + const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) + + const nodesData = useMemo(() => { + if (data) + return initialNodes(data.graph.nodes, data.graph.edges) + + return [] + }, [data]) + const edgesData = useMemo(() => { + if (data) + return initialEdges(data.graph.edges, data.graph.nodes) + + return [] + }, [data]) + + if (!data || isLoading) { + return ( +
+ +
+ ) + } + + const features = data.features || {} + const initialFeatures: FeaturesData = { + file: { + image: { + enabled: !!features.file_upload?.image?.enabled, + number_limits: features.file_upload?.image?.number_limits || 3, + transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + }, + enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), + allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], + allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), + allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], + number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, + fileUploadConfig: fileUploadConfigResponse, + }, + opening: { + enabled: !!features.opening_statement, + opening_statement: features.opening_statement, + suggested_questions: features.suggested_questions, + }, + suggested: features.suggested_questions_after_answer || { enabled: false }, + speech2text: features.speech_to_text || { enabled: false }, + text2speech: features.text_to_speech || { enabled: false }, + citation: features.retriever_resource || { enabled: false }, + moderation: features.sensitive_word_avoidance || { enabled: false }, + } + + return ( + + + + + + ) +} + +const WorkflowAppWrapper = () => { + return ( + + + + ) +} + +export default WorkflowAppWrapper diff --git a/web/app/components/workflow-app/store/workflow/workflow-slice.ts b/web/app/components/workflow-app/store/workflow/workflow-slice.ts new file mode 100644 index 0000000000..77626e52b1 --- /dev/null +++ b/web/app/components/workflow-app/store/workflow/workflow-slice.ts @@ -0,0 +1,18 @@ +import type { StateCreator } from 'zustand' + +export type WorkflowSliceShape = { + appId: string + notInitialWorkflow: boolean + setNotInitialWorkflow: (notInitialWorkflow: boolean) => void + nodesDefaultConfigs: Record + setNodesDefaultConfigs: (nodesDefaultConfigs: Record) => void +} + +export type CreateWorkflowSlice = StateCreator +export const createWorkflowSlice: StateCreator = set => ({ + appId: '', + notInitialWorkflow: false, + setNotInitialWorkflow: notInitialWorkflow => set(() => ({ notInitialWorkflow })), + nodesDefaultConfigs: {}, + setNodesDefaultConfigs: nodesDefaultConfigs => set(() => ({ nodesDefaultConfigs })), +}) diff --git a/web/app/components/workflow/context.tsx b/web/app/components/workflow/context.tsx index bb34ce6319..cae14fc2b2 100644 --- a/web/app/components/workflow/context.tsx +++ b/web/app/components/workflow/context.tsx @@ -2,19 +2,24 @@ import { createContext, useRef, } from 'react' -import { createWorkflowStore } from './store' +import { + createWorkflowStore, +} from './store' +import type { StateCreator } from 'zustand' +import type { WorkflowSliceShape } from '@/app/components/workflow-app/store/workflow/workflow-slice' type WorkflowStore = ReturnType export const WorkflowContext = createContext(null) -type WorkflowProviderProps = { +export type WorkflowProviderProps = { children: React.ReactNode + injectWorkflowStoreSliceFn?: StateCreator } -export const WorkflowContextProvider = ({ children }: WorkflowProviderProps) => { +export const WorkflowContextProvider = ({ children, injectWorkflowStoreSliceFn }: WorkflowProviderProps) => { const storeRef = useRef(undefined) if (!storeRef.current) - storeRef.current = createWorkflowStore() + storeRef.current = createWorkflowStore({ injectWorkflowStoreSliceFn }) return ( diff --git a/web/app/components/workflow/header/editing-title.tsx b/web/app/components/workflow/header/editing-title.tsx index b99564a5f9..2444cf8c29 100644 --- a/web/app/components/workflow/header/editing-title.tsx +++ b/web/app/components/workflow/header/editing-title.tsx @@ -1,13 +1,13 @@ import { memo } from 'react' import { useTranslation } from 'react-i18next' -import { useWorkflow } from '../hooks' +import { useFormatTimeFromNow } from '../hooks' import { useStore } from '@/app/components/workflow/store' import useTimestamp from '@/hooks/use-timestamp' const EditingTitle = () => { const { t } = useTranslation() const { formatTime } = useTimestamp() - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const draftUpdatedAt = useStore(state => state.draftUpdatedAt) const publishedAt = useStore(state => state.publishedAt) const isSyncingWorkflowDraft = useStore(s => s.isSyncingWorkflowDraft) diff --git a/web/app/components/workflow/header/header-in-normal.tsx b/web/app/components/workflow/header/header-in-normal.tsx new file mode 100644 index 0000000000..ec016b1b65 --- /dev/null +++ b/web/app/components/workflow/header/header-in-normal.tsx @@ -0,0 +1,69 @@ +import { + useCallback, +} from 'react' +import { useNodes } from 'reactflow' +import { + useStore, + useWorkflowStore, +} from '../store' +import type { StartNodeType } from '../nodes/start/types' +import { + useNodesInteractions, + useNodesReadOnly, + useWorkflowRun, +} from '../hooks' +import Divider from '../../base/divider' +import RunAndHistory from './run-and-history' +import EditingTitle from './editing-title' +import EnvButton from './env-button' +import VersionHistoryButton from './version-history-button' + +export type HeaderInNormalProps = { + components?: { + left?: React.ReactNode + middle?: React.ReactNode + } +} +const HeaderInNormal = ({ + components, +}: HeaderInNormalProps) => { + const workflowStore = useWorkflowStore() + const { nodesReadOnly } = useNodesReadOnly() + const { handleNodeSelect } = useNodesInteractions() + const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) + const setShowEnvPanel = useStore(s => s.setShowEnvPanel) + const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) + const nodes = useNodes() + const selectedNode = nodes.find(node => node.data.selected) + const { handleBackupDraft } = useWorkflowRun() + + const onStartRestoring = useCallback(() => { + workflowStore.setState({ isRestoring: true }) + handleBackupDraft() + // clear right panel + if (selectedNode) + handleNodeSelect(selectedNode.id, true) + setShowWorkflowVersionHistoryPanel(true) + setShowEnvPanel(false) + setShowDebugAndPreviewPanel(false) + }, [handleBackupDraft, workflowStore, handleNodeSelect, selectedNode, + setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel]) + + return ( + <> +
+ +
+
+ {components?.left} + + + + {components?.middle} + +
+ + ) +} + +export default HeaderInNormal diff --git a/web/app/components/workflow/header/header-in-restoring.tsx b/web/app/components/workflow/header/header-in-restoring.tsx new file mode 100644 index 0000000000..4d1954587d --- /dev/null +++ b/web/app/components/workflow/header/header-in-restoring.tsx @@ -0,0 +1,93 @@ +import { + useCallback, +} from 'react' +import { RiHistoryLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { + useStore, + useWorkflowStore, +} from '../store' +import { + WorkflowVersion, +} from '../types' +import { + useNodesSyncDraft, + useWorkflowRun, +} from '../hooks' +import Toast from '../../base/toast' +import RestoringTitle from './restoring-title' +import Button from '@/app/components/base/button' + +export type HeaderInRestoringProps = { + onRestoreSettled?: () => void +} +const HeaderInRestoring = ({ + onRestoreSettled, +}: HeaderInRestoringProps) => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + const currentVersion = useStore(s => s.currentVersion) + const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) + + const { + handleLoadBackupDraft, + } = useWorkflowRun() + const { handleSyncWorkflowDraft } = useNodesSyncDraft() + + const handleCancelRestore = useCallback(() => { + handleLoadBackupDraft() + workflowStore.setState({ isRestoring: false }) + setShowWorkflowVersionHistoryPanel(false) + }, [workflowStore, handleLoadBackupDraft, setShowWorkflowVersionHistoryPanel]) + + const handleRestore = useCallback(() => { + setShowWorkflowVersionHistoryPanel(false) + workflowStore.setState({ isRestoring: false }) + workflowStore.setState({ backupDraft: undefined }) + handleSyncWorkflowDraft(true, false, { + onSuccess: () => { + Toast.notify({ + type: 'success', + message: t('workflow.versionHistory.action.restoreSuccess'), + }) + }, + onError: () => { + Toast.notify({ + type: 'error', + message: t('workflow.versionHistory.action.restoreFailure'), + }) + }, + onSettled: () => { + onRestoreSettled?.() + }, + }) + }, [handleSyncWorkflowDraft, workflowStore, setShowWorkflowVersionHistoryPanel, onRestoreSettled, t]) + + return ( + <> +
+ +
+
+ + +
+ + ) +} + +export default HeaderInRestoring diff --git a/web/app/components/workflow/header/header-in-view-history.tsx b/web/app/components/workflow/header/header-in-view-history.tsx new file mode 100644 index 0000000000..81858ccc89 --- /dev/null +++ b/web/app/components/workflow/header/header-in-view-history.tsx @@ -0,0 +1,50 @@ +import { + useCallback, +} from 'react' +import { useTranslation } from 'react-i18next' +import { + useWorkflowStore, +} from '../store' +import { + useWorkflowRun, +} from '../hooks' +import Divider from '../../base/divider' +import RunningTitle from './running-title' +import ViewHistory from './view-history' +import Button from '@/app/components/base/button' +import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' + +const HeaderInHistory = () => { + const { t } = useTranslation() + const workflowStore = useWorkflowStore() + + const { + handleLoadBackupDraft, + } = useWorkflowRun() + + const handleGoBackToEdit = useCallback(() => { + handleLoadBackupDraft() + workflowStore.setState({ historyWorkflowData: undefined }) + }, [workflowStore, handleLoadBackupDraft]) + + return ( + <> +
+ +
+
+ + + +
+ + ) +} + +export default HeaderInHistory diff --git a/web/app/components/workflow/header/index.tsx b/web/app/components/workflow/header/index.tsx index 7e99f5dd6b..e5391afb09 100644 --- a/web/app/components/workflow/header/index.tsx +++ b/web/app/components/workflow/header/index.tsx @@ -1,292 +1,51 @@ -import type { FC } from 'react' import { - memo, - useCallback, - useMemo, -} from 'react' -import { RiApps2AddLine, RiHistoryLine } from '@remixicon/react' -import { useNodes } from 'reactflow' -import { useTranslation } from 'react-i18next' -import { useContext, useContextSelector } from 'use-context-selector' -import { - useStore, - useWorkflowStore, -} from '../store' -import { - BlockEnum, - InputVarType, - WorkflowVersion, -} from '../types' -import type { StartNodeType } from '../nodes/start/types' -import { - useChecklistBeforePublish, - useIsChatMode, - useNodesInteractions, - useNodesReadOnly, - useNodesSyncDraft, useWorkflowMode, - useWorkflowRun, } from '../hooks' -import AppPublisher from '../../app/app-publisher' -import Toast, { ToastContext } from '../../base/toast' -import Divider from '../../base/divider' -import RunAndHistory from './run-and-history' -import EditingTitle from './editing-title' -import RunningTitle from './running-title' -import RestoringTitle from './restoring-title' -import ViewHistory from './view-history' -import ChatVariableButton from './chat-variable-button' -import EnvButton from './env-button' -import VersionHistoryButton from './version-history-button' -import Button from '@/app/components/base/button' -import { useStore as useAppStore } from '@/app/components/app/store' -import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' -import { useFeatures } from '@/app/components/base/features/hooks' -import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow' -import type { PublishWorkflowParams } from '@/types/workflow' -import { fetchAppDetail, fetchAppSSO } from '@/service/apps' -import AppContext from '@/context/app-context' - -const Header: FC = () => { - const { t } = useTranslation() - const workflowStore = useWorkflowStore() - const appDetail = useAppStore(s => s.appDetail) - const setAppDetail = useAppStore(s => s.setAppDetail) - const systemFeatures = useContextSelector(AppContext, state => state.systemFeatures) - const appID = appDetail?.id - const isChatMode = useIsChatMode() - const { nodesReadOnly, getNodesReadOnly } = useNodesReadOnly() - const { handleNodeSelect } = useNodesInteractions() - const publishedAt = useStore(s => s.publishedAt) - const draftUpdatedAt = useStore(s => s.draftUpdatedAt) - const toolPublished = useStore(s => s.toolPublished) - const currentVersion = useStore(s => s.currentVersion) - const setShowWorkflowVersionHistoryPanel = useStore(s => s.setShowWorkflowVersionHistoryPanel) - const setShowEnvPanel = useStore(s => s.setShowEnvPanel) - const setShowDebugAndPreviewPanel = useStore(s => s.setShowDebugAndPreviewPanel) - const nodes = useNodes() - const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const selectedNode = nodes.find(node => node.data.selected) - const startVariables = startNode?.data.variables - const fileSettings = useFeatures(s => s.features.file) - const variables = useMemo(() => { - const data = startVariables || [] - if (fileSettings?.image?.enabled) { - return [ - ...data, - { - type: InputVarType.files, - variable: '__image', - required: false, - label: 'files', - }, - ] - } - - return data - }, [fileSettings?.image?.enabled, startVariables]) - - const { - handleLoadBackupDraft, - handleBackupDraft, - } = useWorkflowRun() - const { handleCheckBeforePublish } = useChecklistBeforePublish() - const { handleSyncWorkflowDraft } = useNodesSyncDraft() - const { notify } = useContext(ToastContext) +import type { HeaderInNormalProps } from './header-in-normal' +import HeaderInNormal from './header-in-normal' +import HeaderInHistory from './header-in-view-history' +import type { HeaderInRestoringProps } from './header-in-restoring' +import HeaderInRestoring from './header-in-restoring' + +export type HeaderProps = { + normal?: HeaderInNormalProps + restoring?: HeaderInRestoringProps +} +const Header = ({ + normal: normalProps, + restoring: restoringProps, +}: HeaderProps) => { const { normal, restoring, viewHistory, } = useWorkflowMode() - const handleShowFeatures = useCallback(() => { - const { - showFeaturesPanel, - isRestoring, - setShowFeaturesPanel, - } = workflowStore.getState() - if (getNodesReadOnly() && !isRestoring) - return - setShowFeaturesPanel(!showFeaturesPanel) - }, [workflowStore, getNodesReadOnly]) - - const handleCancelRestore = useCallback(() => { - handleLoadBackupDraft() - workflowStore.setState({ isRestoring: false }) - setShowWorkflowVersionHistoryPanel(false) - }, [workflowStore, handleLoadBackupDraft, setShowWorkflowVersionHistoryPanel]) - - const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id) - - const handleRestore = useCallback(() => { - setShowWorkflowVersionHistoryPanel(false) - workflowStore.setState({ isRestoring: false }) - workflowStore.setState({ backupDraft: undefined }) - handleSyncWorkflowDraft(true, false, { - onSuccess: () => { - Toast.notify({ - type: 'success', - message: t('workflow.versionHistory.action.restoreSuccess'), - }) - }, - onError: () => { - Toast.notify({ - type: 'error', - message: t('workflow.versionHistory.action.restoreFailure'), - }) - }, - onSettled: () => { - resetWorkflowVersionHistory() - }, - }) - }, [handleSyncWorkflowDraft, workflowStore, setShowWorkflowVersionHistoryPanel, resetWorkflowVersionHistory, t]) - - const updateAppDetail = useCallback(async () => { - try { - const res = await fetchAppDetail({ url: '/apps', id: appID! }) - if (systemFeatures.enable_web_sso_switch_component) { - const ssoRes = await fetchAppSSO({ appId: appID! }) - setAppDetail({ ...res, enable_sso: ssoRes.enabled }) - } - else { - setAppDetail({ ...res }) - } - } - catch (error) { - console.error(error) - } - }, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component]) - - const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!) - - const onPublish = useCallback(async (params?: PublishWorkflowParams) => { - if (await handleCheckBeforePublish()) { - const res = await publishWorkflow({ - title: params?.title || '', - releaseNotes: params?.releaseNotes || '', - }) - - if (res) { - notify({ type: 'success', message: t('common.api.actionSuccess') }) - updateAppDetail() - workflowStore.getState().setPublishedAt(res.created_at) - resetWorkflowVersionHistory() - } - } - else { - throw new Error('Checklist failed') - } - }, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail]) - - const onStartRestoring = useCallback(() => { - workflowStore.setState({ isRestoring: true }) - handleBackupDraft() - // clear right panel - if (selectedNode) - handleNodeSelect(selectedNode.id, true) - setShowWorkflowVersionHistoryPanel(true) - setShowEnvPanel(false) - setShowDebugAndPreviewPanel(false) - }, [handleBackupDraft, workflowStore, handleNodeSelect, selectedNode, - setShowWorkflowVersionHistoryPanel, setShowEnvPanel, setShowDebugAndPreviewPanel]) - - const onPublisherToggle = useCallback((state: boolean) => { - if (state) - handleSyncWorkflowDraft(true) - }, [handleSyncWorkflowDraft]) - - const handleGoBackToEdit = useCallback(() => { - handleLoadBackupDraft() - workflowStore.setState({ historyWorkflowData: undefined }) - }, [workflowStore, handleLoadBackupDraft]) - - const handleToolConfigureUpdate = useCallback(() => { - workflowStore.setState({ toolPublished: true }) - }, [workflowStore]) - return (
-
- { - normal && - } - { - viewHistory && - } - { - restoring && - } -
{ normal && ( -
- {/* */} - {isChatMode && } - - - - - - -
+ ) } { viewHistory && ( -
- - - -
+ ) } { restoring && ( -
- - -
+ ) }
) } -export default memo(Header) +export default Header diff --git a/web/app/components/workflow/header/restoring-title.tsx b/web/app/components/workflow/header/restoring-title.tsx index 310ab5c35a..26cdd79d13 100644 --- a/web/app/components/workflow/header/restoring-title.tsx +++ b/web/app/components/workflow/header/restoring-title.tsx @@ -1,13 +1,13 @@ import { memo, useMemo } from 'react' import { useTranslation } from 'react-i18next' -import { useWorkflow } from '../hooks' +import { useFormatTimeFromNow } from '../hooks' import { useStore } from '../store' import { WorkflowVersion } from '../types' import useTimestamp from '@/hooks/use-timestamp' const RestoringTitle = () => { const { t } = useTranslation() - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const { formatTime } = useTimestamp() const currentVersion = useStore(state => state.currentVersion) const isDraft = currentVersion?.version === WorkflowVersion.Draft diff --git a/web/app/components/workflow/header/view-history.tsx b/web/app/components/workflow/header/view-history.tsx index 1298c0e42d..21b4462867 100644 --- a/web/app/components/workflow/header/view-history.tsx +++ b/web/app/components/workflow/header/view-history.tsx @@ -11,9 +11,9 @@ import { RiErrorWarningLine, } from '@remixicon/react' import { + useFormatTimeFromNow, useIsChatMode, useNodesInteractions, - useWorkflow, useWorkflowInteractions, useWorkflowRun, } from '../hooks' @@ -50,7 +50,7 @@ const ViewHistory = ({ const { t } = useTranslation() const isChatMode = useIsChatMode() const [open, setOpen] = useState(false) - const { formatTimeFromNow } = useWorkflow() + const { formatTimeFromNow } = useFormatTimeFromNow() const { handleNodesCancelSelected, } = useNodesInteractions() diff --git a/web/app/components/workflow/hooks-store/index.ts b/web/app/components/workflow/hooks-store/index.ts new file mode 100644 index 0000000000..40b4132dfd --- /dev/null +++ b/web/app/components/workflow/hooks-store/index.ts @@ -0,0 +1,2 @@ +export * from './provider' +export * from './store' diff --git a/web/app/components/workflow/hooks-store/provider.tsx b/web/app/components/workflow/hooks-store/provider.tsx new file mode 100644 index 0000000000..c1090ae3f8 --- /dev/null +++ b/web/app/components/workflow/hooks-store/provider.tsx @@ -0,0 +1,36 @@ +import { + createContext, + useEffect, + useRef, +} from 'react' +import { useStore } from 'reactflow' +import { + createHooksStore, +} from './store' +import type { Shape } from './store' + +type HooksStore = ReturnType +export const HooksStoreContext = createContext(null) +type HooksStoreContextProviderProps = Partial & { + children: React.ReactNode +} +export const HooksStoreContextProvider = ({ children, ...restProps }: HooksStoreContextProviderProps) => { + const storeRef = useRef(undefined) + const d3Selection = useStore(s => s.d3Selection) + const d3Zoom = useStore(s => s.d3Zoom) + + useEffect(() => { + if (storeRef.current && d3Selection && d3Zoom) + storeRef.current.getState().refreshAll(restProps) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [d3Selection, d3Zoom]) + + if (!storeRef.current) + storeRef.current = createHooksStore(restProps) + + return ( + + {children} + + ) +} diff --git a/web/app/components/workflow/hooks-store/store.ts b/web/app/components/workflow/hooks-store/store.ts new file mode 100644 index 0000000000..2e40cbfbc9 --- /dev/null +++ b/web/app/components/workflow/hooks-store/store.ts @@ -0,0 +1,72 @@ +import { useContext } from 'react' +import { + noop, +} from 'lodash-es' +import { + useStore as useZustandStore, +} from 'zustand' +import { createStore } from 'zustand/vanilla' +import { HooksStoreContext } from './provider' + +type CommonHooksFnMap = { + doSyncWorkflowDraft: ( + notRefreshWhenSyncError?: boolean, + callback?: { + onSuccess?: () => void + onError?: () => void + onSettled?: () => void + } + ) => Promise + syncWorkflowDraftWhenPageClose: () => void + handleBackupDraft: () => void + handleLoadBackupDraft: () => void + handleRestoreFromPublishedWorkflow: (...args: any[]) => void + handleRun: (...args: any[]) => void + handleStopRun: (...args: any[]) => void + handleStartWorkflowRun: () => void + handleWorkflowStartRunInWorkflow: () => void + handleWorkflowStartRunInChatflow: () => void +} + +export type Shape = { + refreshAll: (props: Partial) => void +} & CommonHooksFnMap + +export const createHooksStore = ({ + doSyncWorkflowDraft = async () => noop(), + syncWorkflowDraftWhenPageClose = noop, + handleBackupDraft = noop, + handleLoadBackupDraft = noop, + handleRestoreFromPublishedWorkflow = noop, + handleRun = noop, + handleStopRun = noop, + handleStartWorkflowRun = noop, + handleWorkflowStartRunInWorkflow = noop, + handleWorkflowStartRunInChatflow = noop, +}: Partial) => { + return createStore(set => ({ + refreshAll: props => set(state => ({ ...state, ...props })), + doSyncWorkflowDraft, + syncWorkflowDraftWhenPageClose, + handleBackupDraft, + handleLoadBackupDraft, + handleRestoreFromPublishedWorkflow, + handleRun, + handleStopRun, + handleStartWorkflowRun, + handleWorkflowStartRunInWorkflow, + handleWorkflowStartRunInChatflow, + })) +} + +export function useHooksStore(selector: (state: Shape) => T): T { + const store = useContext(HooksStoreContext) + if (!store) + throw new Error('Missing HooksStoreContext.Provider in the tree') + + return useZustandStore(store, selector) +} + +export const useHooksStoreApi = () => { + return useContext(HooksStoreContext)! +} diff --git a/web/app/components/workflow/hooks/index.ts b/web/app/components/workflow/hooks/index.ts index 463e9b3271..20a34c69e3 100644 --- a/web/app/components/workflow/hooks/index.ts +++ b/web/app/components/workflow/hooks/index.ts @@ -5,7 +5,6 @@ export * from './use-nodes-data' export * from './use-nodes-sync-draft' export * from './use-workflow' export * from './use-workflow-run' -export * from './use-workflow-template' export * from './use-checklist' export * from './use-selection-interactions' export * from './use-panel-interactions' @@ -16,3 +15,4 @@ export * from './use-workflow-variables' export * from './use-shortcuts' export * from './use-workflow-interactions' export * from './use-workflow-mode' +export * from './use-format-time-from-now' diff --git a/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts b/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts new file mode 100644 index 0000000000..c4c709cd25 --- /dev/null +++ b/web/app/components/workflow/hooks/use-edges-interactions-without-sync.ts @@ -0,0 +1,27 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' + +export const useEdgesInteractionsWithoutSync = () => { + const store = useStoreApi() + + const handleEdgeCancelRunningStatus = useCallback(() => { + const { + edges, + setEdges, + } = store.getState() + + const newEdges = produce(edges, (draft) => { + draft.forEach((edge) => { + edge.data._sourceRunningStatus = undefined + edge.data._targetRunningStatus = undefined + edge.data._waitingRun = false + }) + }) + setEdges(newEdges) + }, [store]) + + return { + handleEdgeCancelRunningStatus, + } +} diff --git a/web/app/components/workflow/hooks/use-edges-interactions.ts b/web/app/components/workflow/hooks/use-edges-interactions.ts index 688f0b26ce..306af1e96c 100644 --- a/web/app/components/workflow/hooks/use-edges-interactions.ts +++ b/web/app/components/workflow/hooks/use-edges-interactions.ts @@ -151,28 +151,11 @@ export const useEdgesInteractions = () => { setEdges(newEdges) }, [store, getNodesReadOnly]) - const handleEdgeCancelRunningStatus = useCallback(() => { - const { - edges, - setEdges, - } = store.getState() - - const newEdges = produce(edges, (draft) => { - draft.forEach((edge) => { - edge.data._sourceRunningStatus = undefined - edge.data._targetRunningStatus = undefined - edge.data._waitingRun = false - }) - }) - setEdges(newEdges) - }, [store]) - return { handleEdgeEnter, handleEdgeLeave, handleEdgeDeleteByDeleteBranch, handleEdgeDelete, handleEdgesChange, - handleEdgeCancelRunningStatus, } } diff --git a/web/app/components/workflow/hooks/use-format-time-from-now.ts b/web/app/components/workflow/hooks/use-format-time-from-now.ts new file mode 100644 index 0000000000..b2b521557f --- /dev/null +++ b/web/app/components/workflow/hooks/use-format-time-from-now.ts @@ -0,0 +1,12 @@ +import dayjs from 'dayjs' +import { useCallback } from 'react' +import { useI18N } from '@/context/i18n' + +export const useFormatTimeFromNow = () => { + const { locale } = useI18N() + const formatTimeFromNow = useCallback((time: number) => { + return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() + }, [locale]) + + return { formatTimeFromNow } +} diff --git a/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts new file mode 100644 index 0000000000..7fbf0ce868 --- /dev/null +++ b/web/app/components/workflow/hooks/use-nodes-interactions-without-sync.ts @@ -0,0 +1,27 @@ +import { useCallback } from 'react' +import produce from 'immer' +import { useStoreApi } from 'reactflow' + +export const useNodesInteractionsWithoutSync = () => { + const store = useStoreApi() + + const handleNodeCancelRunningStatus = useCallback(() => { + const { + getNodes, + setNodes, + } = store.getState() + + const nodes = getNodes() + const newNodes = produce(nodes, (draft) => { + draft.forEach((node) => { + node.data._runningStatus = undefined + node.data._waitingRun = false + }) + }) + setNodes(newNodes) + }, [store]) + + return { + handleNodeCancelRunningStatus, + } +} diff --git a/web/app/components/workflow/hooks/use-nodes-interactions.ts b/web/app/components/workflow/hooks/use-nodes-interactions.ts index 90231cfcc8..94b10c9929 100644 --- a/web/app/components/workflow/hooks/use-nodes-interactions.ts +++ b/web/app/components/workflow/hooks/use-nodes-interactions.ts @@ -1177,22 +1177,6 @@ export const useNodesInteractions = () => { saveStateToHistory(WorkflowHistoryEvent.NodeChange) }, [getNodesReadOnly, store, t, handleSyncWorkflowDraft, saveStateToHistory]) - const handleNodeCancelRunningStatus = useCallback(() => { - const { - getNodes, - setNodes, - } = store.getState() - - const nodes = getNodes() - const newNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - node.data._runningStatus = undefined - node.data._waitingRun = false - }) - }) - setNodes(newNodes) - }, [store]) - const handleNodesCancelSelected = useCallback(() => { const { getNodes, @@ -1554,7 +1538,6 @@ export const useNodesInteractions = () => { handleNodeDelete, handleNodeChange, handleNodeAdd, - handleNodeCancelRunningStatus, handleNodesCancelSelected, handleNodeContextMenu, handleNodesCopy, diff --git a/web/app/components/workflow/hooks/use-nodes-sync-draft.ts b/web/app/components/workflow/hooks/use-nodes-sync-draft.ts index 5cd8f36bff..e6cc3a97e3 100644 --- a/web/app/components/workflow/hooks/use-nodes-sync-draft.ts +++ b/web/app/components/workflow/hooks/use-nodes-sync-draft.ts @@ -1,147 +1,17 @@ import { useCallback } from 'react' -import produce from 'immer' -import { useStoreApi } from 'reactflow' -import { useParams } from 'next/navigation' import { useStore, - useWorkflowStore, } from '../store' -import { BlockEnum } from '../types' -import { useWorkflowUpdate } from '../hooks' import { useNodesReadOnly, } from './use-workflow' -import { syncWorkflowDraft } from '@/service/workflow' -import { useFeaturesStore } from '@/app/components/base/features/hooks' -import { API_PREFIX } from '@/config' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useNodesSyncDraft = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const featuresStore = useFeaturesStore() const { getNodesReadOnly } = useNodesReadOnly() - const { handleRefreshWorkflowDraft } = useWorkflowUpdate() const debouncedSyncWorkflowDraft = useStore(s => s.debouncedSyncWorkflowDraft) - const params = useParams() - - const getPostParams = useCallback(() => { - const { - getNodes, - edges, - transform, - } = store.getState() - const [x, y, zoom] = transform - const { - appId, - conversationVariables, - environmentVariables, - syncWorkflowDraftHash, - } = workflowStore.getState() - - if (appId) { - const nodes = getNodes() - const hasStartNode = nodes.find(node => node.data.type === BlockEnum.Start) - - if (!hasStartNode) - return - - const features = featuresStore!.getState().features - const producedNodes = produce(nodes, (draft) => { - draft.forEach((node) => { - Object.keys(node.data).forEach((key) => { - if (key.startsWith('_')) - delete node.data[key] - }) - }) - }) - const producedEdges = produce(edges, (draft) => { - draft.forEach((edge) => { - Object.keys(edge.data).forEach((key) => { - if (key.startsWith('_')) - delete edge.data[key] - }) - }) - }) - return { - url: `/apps/${appId}/workflows/draft`, - params: { - graph: { - nodes: producedNodes, - edges: producedEdges, - viewport: { - x, - y, - zoom, - }, - }, - features: { - opening_statement: features.opening?.enabled ? (features.opening?.opening_statement || '') : '', - suggested_questions: features.opening?.enabled ? (features.opening?.suggested_questions || []) : [], - suggested_questions_after_answer: features.suggested, - text_to_speech: features.text2speech, - speech_to_text: features.speech2text, - retriever_resource: features.citation, - sensitive_word_avoidance: features.moderation, - file_upload: features.file, - }, - environment_variables: environmentVariables, - conversation_variables: conversationVariables, - hash: syncWorkflowDraftHash, - }, - } - } - }, [store, featuresStore, workflowStore]) - - const syncWorkflowDraftWhenPageClose = useCallback(() => { - if (getNodesReadOnly()) - return - const postParams = getPostParams() - - if (postParams) { - navigator.sendBeacon( - `${API_PREFIX}/apps/${params.appId}/workflows/draft?_token=${localStorage.getItem('console_token')}`, - JSON.stringify(postParams.params), - ) - } - }, [getPostParams, params.appId, getNodesReadOnly]) - - const doSyncWorkflowDraft = useCallback(async ( - notRefreshWhenSyncError?: boolean, - callback?: { - onSuccess?: () => void - onError?: () => void - onSettled?: () => void - }, - ) => { - if (getNodesReadOnly()) - return - const postParams = getPostParams() - - if (postParams) { - const { - setSyncWorkflowDraftHash, - setDraftUpdatedAt, - } = workflowStore.getState() - try { - const res = await syncWorkflowDraft(postParams) - setSyncWorkflowDraftHash(res.hash) - setDraftUpdatedAt(res.updated_at) - callback?.onSuccess && callback.onSuccess() - } - catch (error: any) { - if (error && error.json && !error.bodyUsed) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_sync' && !notRefreshWhenSyncError) - handleRefreshWorkflowDraft() - }) - } - callback?.onError && callback.onError() - } - finally { - callback?.onSettled && callback.onSettled() - } - } - }, [workflowStore, getPostParams, getNodesReadOnly, handleRefreshWorkflowDraft]) + const doSyncWorkflowDraft = useHooksStore(s => s.doSyncWorkflowDraft) + const syncWorkflowDraftWhenPageClose = useHooksStore(s => s.syncWorkflowDraftWhenPageClose) const handleSyncWorkflowDraft = useCallback(( sync?: boolean, diff --git a/web/app/components/workflow/hooks/use-workflow-interactions.ts b/web/app/components/workflow/hooks/use-workflow-interactions.ts index 202867e22f..740868c594 100644 --- a/web/app/components/workflow/hooks/use-workflow-interactions.ts +++ b/web/app/components/workflow/hooks/use-workflow-interactions.ts @@ -25,8 +25,8 @@ import { useSelectionInteractions, useWorkflowReadOnly, } from '../hooks' -import { useEdgesInteractions } from './use-edges-interactions' -import { useNodesInteractions } from './use-nodes-interactions' +import { useEdgesInteractionsWithoutSync } from './use-edges-interactions-without-sync' +import { useNodesInteractionsWithoutSync } from './use-nodes-interactions-without-sync' import { useNodesSyncDraft } from './use-nodes-sync-draft' import { WorkflowHistoryEvent, useWorkflowHistory } from './use-workflow-history' import { useEventEmitterContextContext } from '@/context/event-emitter' @@ -37,8 +37,8 @@ import { useStore as useAppStore } from '@/app/components/app/store' export const useWorkflowInteractions = () => { const workflowStore = useWorkflowStore() - const { handleNodeCancelRunningStatus } = useNodesInteractions() - const { handleEdgeCancelRunningStatus } = useEdgesInteractions() + const { handleNodeCancelRunningStatus } = useNodesInteractionsWithoutSync() + const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync() const handleCancelDebugAndPreviewPanel = useCallback(() => { workflowStore.setState({ diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 99d9a45702..05a60ebb4b 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -1,350 +1,11 @@ -import { useCallback } from 'react' -import { - useReactFlow, - useStoreApi, -} from 'reactflow' -import produce from 'immer' -import { v4 as uuidV4 } from 'uuid' -import { usePathname } from 'next/navigation' -import { useWorkflowStore } from '../store' -import { useNodesSyncDraft } from '../hooks' -import { WorkflowRunningStatus } from '../types' -import { useWorkflowUpdate } from './use-workflow-interactions' -import { useWorkflowRunEvent } from './use-workflow-run-event/use-workflow-run-event' -import { useStore as useAppStore } from '@/app/components/app/store' -import type { IOtherOptions } from '@/service/base' -import { ssePost } from '@/service/base' -import { stopWorkflowRun } from '@/service/workflow' -import { useFeaturesStore } from '@/app/components/base/features/hooks' -import { AudioPlayerManager } from '@/app/components/base/audio-btn/audio.player.manager' -import type { VersionHistory } from '@/types/workflow' -import { noop } from 'lodash-es' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useWorkflowRun = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const reactflow = useReactFlow() - const featuresStore = useFeaturesStore() - const { doSyncWorkflowDraft } = useNodesSyncDraft() - const { handleUpdateWorkflowCanvas } = useWorkflowUpdate() - const pathname = usePathname() - const { - handleWorkflowStarted, - handleWorkflowFinished, - handleWorkflowFailed, - handleWorkflowNodeStarted, - handleWorkflowNodeFinished, - handleWorkflowNodeIterationStarted, - handleWorkflowNodeIterationNext, - handleWorkflowNodeIterationFinished, - handleWorkflowNodeLoopStarted, - handleWorkflowNodeLoopNext, - handleWorkflowNodeLoopFinished, - handleWorkflowNodeRetry, - handleWorkflowAgentLog, - handleWorkflowTextChunk, - handleWorkflowTextReplace, - } = useWorkflowRunEvent() - - const handleBackupDraft = useCallback(() => { - const { - getNodes, - edges, - } = store.getState() - const { getViewport } = reactflow - const { - backupDraft, - setBackupDraft, - environmentVariables, - } = workflowStore.getState() - const { features } = featuresStore!.getState() - - if (!backupDraft) { - setBackupDraft({ - nodes: getNodes(), - edges, - viewport: getViewport(), - features, - environmentVariables, - }) - doSyncWorkflowDraft() - } - }, [reactflow, workflowStore, store, featuresStore, doSyncWorkflowDraft]) - - const handleLoadBackupDraft = useCallback(() => { - const { - backupDraft, - setBackupDraft, - setEnvironmentVariables, - } = workflowStore.getState() - - if (backupDraft) { - const { - nodes, - edges, - viewport, - features, - environmentVariables, - } = backupDraft - handleUpdateWorkflowCanvas({ - nodes, - edges, - viewport, - }) - setEnvironmentVariables(environmentVariables) - featuresStore!.setState({ features }) - setBackupDraft(undefined) - } - }, [handleUpdateWorkflowCanvas, workflowStore, featuresStore]) - - const handleRun = useCallback(async ( - params: any, - callback?: IOtherOptions, - ) => { - const { - getNodes, - setNodes, - } = store.getState() - const newNodes = produce(getNodes(), (draft) => { - draft.forEach((node) => { - node.data.selected = false - node.data._runningStatus = undefined - }) - }) - setNodes(newNodes) - await doSyncWorkflowDraft() - - const { - onWorkflowStarted, - onWorkflowFinished, - onNodeStarted, - onNodeFinished, - onIterationStart, - onIterationNext, - onIterationFinish, - onLoopStart, - onLoopNext, - onLoopFinish, - onNodeRetry, - onAgentLog, - onError, - ...restCallback - } = callback || {} - workflowStore.setState({ historyWorkflowData: undefined }) - const appDetail = useAppStore.getState().appDetail - const workflowContainer = document.getElementById('workflow-container') - - const { - clientWidth, - clientHeight, - } = workflowContainer! - - let url = '' - if (appDetail?.mode === 'advanced-chat') - url = `/apps/${appDetail.id}/advanced-chat/workflows/draft/run` - - if (appDetail?.mode === 'workflow') - url = `/apps/${appDetail.id}/workflows/draft/run` - - const { - setWorkflowRunningData, - } = workflowStore.getState() - setWorkflowRunningData({ - result: { - status: WorkflowRunningStatus.Running, - }, - tracing: [], - resultText: '', - }) - - let ttsUrl = '' - let ttsIsPublic = false - if (params.token) { - ttsUrl = '/text-to-audio' - ttsIsPublic = true - } - else if (params.appId) { - if (pathname.search('explore/installed') > -1) - ttsUrl = `/installed-apps/${params.appId}/text-to-audio` - else - ttsUrl = `/apps/${params.appId}/text-to-audio` - } - const player = AudioPlayerManager.getInstance().getAudioPlayer(ttsUrl, ttsIsPublic, uuidV4(), 'none', 'none', noop) - - ssePost( - url, - { - body: params, - }, - { - onWorkflowStarted: (params) => { - handleWorkflowStarted(params) - - if (onWorkflowStarted) - onWorkflowStarted(params) - }, - onWorkflowFinished: (params) => { - handleWorkflowFinished(params) - - if (onWorkflowFinished) - onWorkflowFinished(params) - }, - onError: (params) => { - handleWorkflowFailed() - - if (onError) - onError(params) - }, - onNodeStarted: (params) => { - handleWorkflowNodeStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onNodeStarted) - onNodeStarted(params) - }, - onNodeFinished: (params) => { - handleWorkflowNodeFinished(params) - - if (onNodeFinished) - onNodeFinished(params) - }, - onIterationStart: (params) => { - handleWorkflowNodeIterationStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onIterationStart) - onIterationStart(params) - }, - onIterationNext: (params) => { - handleWorkflowNodeIterationNext(params) - - if (onIterationNext) - onIterationNext(params) - }, - onIterationFinish: (params) => { - handleWorkflowNodeIterationFinished(params) - - if (onIterationFinish) - onIterationFinish(params) - }, - onLoopStart: (params) => { - handleWorkflowNodeLoopStarted( - params, - { - clientWidth, - clientHeight, - }, - ) - - if (onLoopStart) - onLoopStart(params) - }, - onLoopNext: (params) => { - handleWorkflowNodeLoopNext(params) - - if (onLoopNext) - onLoopNext(params) - }, - onLoopFinish: (params) => { - handleWorkflowNodeLoopFinished(params) - - if (onLoopFinish) - onLoopFinish(params) - }, - onNodeRetry: (params) => { - handleWorkflowNodeRetry(params) - - if (onNodeRetry) - onNodeRetry(params) - }, - onAgentLog: (params) => { - handleWorkflowAgentLog(params) - - if (onAgentLog) - onAgentLog(params) - }, - onTextChunk: (params) => { - handleWorkflowTextChunk(params) - }, - onTextReplace: (params) => { - handleWorkflowTextReplace(params) - }, - onTTSChunk: (messageId: string, audio: string) => { - if (!audio || audio === '') - return - player.playAudioWithAudio(audio, true) - AudioPlayerManager.getInstance().resetMsgId(messageId) - }, - onTTSEnd: (messageId: string, audio: string) => { - player.playAudioWithAudio(audio, false) - }, - ...restCallback, - }, - ) - }, [ - store, - workflowStore, - doSyncWorkflowDraft, - handleWorkflowStarted, - handleWorkflowFinished, - handleWorkflowFailed, - handleWorkflowNodeStarted, - handleWorkflowNodeFinished, - handleWorkflowNodeIterationStarted, - handleWorkflowNodeIterationNext, - handleWorkflowNodeIterationFinished, - handleWorkflowNodeLoopStarted, - handleWorkflowNodeLoopNext, - handleWorkflowNodeLoopFinished, - handleWorkflowNodeRetry, - handleWorkflowTextChunk, - handleWorkflowTextReplace, - handleWorkflowAgentLog, - pathname], - ) - - const handleStopRun = useCallback((taskId: string) => { - const appId = useAppStore.getState().appDetail?.id - - stopWorkflowRun(`/apps/${appId}/workflow-runs/tasks/${taskId}/stop`) - }, []) - - const handleRestoreFromPublishedWorkflow = useCallback((publishedWorkflow: VersionHistory) => { - const nodes = publishedWorkflow.graph.nodes.map(node => ({ ...node, selected: false, data: { ...node.data, selected: false } })) - const edges = publishedWorkflow.graph.edges - const viewport = publishedWorkflow.graph.viewport! - handleUpdateWorkflowCanvas({ - nodes, - edges, - viewport, - }) - const mappedFeatures = { - opening: { - enabled: !!publishedWorkflow.features.opening_statement || !!publishedWorkflow.features.suggested_questions.length, - opening_statement: publishedWorkflow.features.opening_statement, - suggested_questions: publishedWorkflow.features.suggested_questions, - }, - suggested: publishedWorkflow.features.suggested_questions_after_answer, - text2speech: publishedWorkflow.features.text_to_speech, - speech2text: publishedWorkflow.features.speech_to_text, - citation: publishedWorkflow.features.retriever_resource, - moderation: publishedWorkflow.features.sensitive_word_avoidance, - file: publishedWorkflow.features.file_upload, - } - - featuresStore?.setState({ features: mappedFeatures }) - workflowStore.getState().setEnvironmentVariables(publishedWorkflow.environment_variables || []) - }, [featuresStore, handleUpdateWorkflowCanvas, workflowStore]) + const handleBackupDraft = useHooksStore(s => s.handleBackupDraft) + const handleLoadBackupDraft = useHooksStore(s => s.handleLoadBackupDraft) + const handleRestoreFromPublishedWorkflow = useHooksStore(s => s.handleRestoreFromPublishedWorkflow) + const handleRun = useHooksStore(s => s.handleRun) + const handleStopRun = useHooksStore(s => s.handleStopRun) return { handleBackupDraft, diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index b2b1c69975..0f4e68fe95 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,92 +1,9 @@ -import { useCallback } from 'react' -import { useStoreApi } from 'reactflow' -import { useWorkflowStore } from '../store' -import { - BlockEnum, - WorkflowRunningStatus, -} from '../types' -import { - useIsChatMode, - useNodesSyncDraft, - useWorkflowInteractions, - useWorkflowRun, -} from './index' -import { useFeaturesStore } from '@/app/components/base/features/hooks' +import { useHooksStore } from '@/app/components/workflow/hooks-store' export const useWorkflowStartRun = () => { - const store = useStoreApi() - const workflowStore = useWorkflowStore() - const featuresStore = useFeaturesStore() - const isChatMode = useIsChatMode() - const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() - const { handleRun } = useWorkflowRun() - const { doSyncWorkflowDraft } = useNodesSyncDraft() - - const handleWorkflowStartRunInWorkflow = useCallback(async () => { - const { - workflowRunningData, - } = workflowStore.getState() - - if (workflowRunningData?.result.status === WorkflowRunningStatus.Running) - return - - const { getNodes } = store.getState() - const nodes = getNodes() - const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const startVariables = startNode?.data.variables || [] - const fileSettings = featuresStore!.getState().features.file - const { - showDebugAndPreviewPanel, - setShowDebugAndPreviewPanel, - setShowInputsPanel, - setShowEnvPanel, - } = workflowStore.getState() - - setShowEnvPanel(false) - - if (showDebugAndPreviewPanel) { - handleCancelDebugAndPreviewPanel() - return - } - - if (!startVariables.length && !fileSettings?.image?.enabled) { - await doSyncWorkflowDraft() - handleRun({ inputs: {}, files: [] }) - setShowDebugAndPreviewPanel(true) - setShowInputsPanel(false) - } - else { - setShowDebugAndPreviewPanel(true) - setShowInputsPanel(true) - } - }, [store, workflowStore, featuresStore, handleCancelDebugAndPreviewPanel, handleRun, doSyncWorkflowDraft]) - - const handleWorkflowStartRunInChatflow = useCallback(async () => { - const { - showDebugAndPreviewPanel, - setShowDebugAndPreviewPanel, - setHistoryWorkflowData, - setShowEnvPanel, - setShowChatVariablePanel, - } = workflowStore.getState() - - setShowEnvPanel(false) - setShowChatVariablePanel(false) - - if (showDebugAndPreviewPanel) - handleCancelDebugAndPreviewPanel() - else - setShowDebugAndPreviewPanel(true) - - setHistoryWorkflowData(undefined) - }, [workflowStore, handleCancelDebugAndPreviewPanel]) - - const handleStartWorkflowRun = useCallback(() => { - if (!isChatMode) - handleWorkflowStartRunInWorkflow() - else - handleWorkflowStartRunInChatflow() - }, [isChatMode, handleWorkflowStartRunInWorkflow, handleWorkflowStartRunInChatflow]) + const handleStartWorkflowRun = useHooksStore(s => s.handleStartWorkflowRun) + const handleWorkflowStartRunInWorkflow = useHooksStore(s => s.handleWorkflowStartRunInWorkflow) + const handleWorkflowStartRunInChatflow = useHooksStore(s => s.handleWorkflowStartRunInChatflow) return { handleStartWorkflowRun, diff --git a/web/app/components/workflow/hooks/use-workflow-variables.ts b/web/app/components/workflow/hooks/use-workflow-variables.ts index a2863671ed..35637bc775 100644 --- a/web/app/components/workflow/hooks/use-workflow-variables.ts +++ b/web/app/components/workflow/hooks/use-workflow-variables.ts @@ -8,6 +8,8 @@ import type { ValueSelector, Var, } from '@/app/components/workflow/types' +import { useIsChatMode } from './use-workflow' +import { useStoreApi } from 'reactflow' export const useWorkflowVariables = () => { const { t } = useTranslation() @@ -75,3 +77,37 @@ export const useWorkflowVariables = () => { getCurrentVariableType, } } + +export const useWorkflowVariableType = () => { + const store = useStoreApi() + const { + getNodes, + } = store.getState() + const { getCurrentVariableType } = useWorkflowVariables() + + const isChatMode = useIsChatMode() + + const getVarType = ({ + nodeId, + valueSelector, + }: { + nodeId: string, + valueSelector: ValueSelector, + }) => { + const node = getNodes().find(n => n.id === nodeId) + const isInIteration = !!node?.data.isInIteration + const iterationNode = isInIteration ? getNodes().find(n => n.id === node.parentId) : null + const availableNodes = [node] + + const type = getCurrentVariableType({ + parentNode: iterationNode, + valueSelector, + availableNodes, + isChatMode, + isConstant: false, + }) + return type + } + + return getVarType +} diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index 7a15afa2e4..99dce4dc15 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -1,13 +1,9 @@ import { useCallback, - useEffect, useMemo, - useState, } from 'react' -import dayjs from 'dayjs' import { uniqBy } from 'lodash-es' import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' import { getIncomers, getOutgoers, @@ -40,25 +36,15 @@ import { import { CUSTOM_NOTE_NODE } from '../note-node/constants' import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils' import { useNodesExtraData } from './use-nodes-data' -import { useWorkflowTemplate } from './use-workflow-template' import { useStore as useAppStore } from '@/app/components/app/store' -import { - fetchNodesDefaultConfigs, - fetchPublishedWorkflow, - fetchWorkflowDraft, - syncWorkflowDraft, -} from '@/service/workflow' -import type { FetchWorkflowDraftResponse } from '@/types/workflow' import { fetchAllBuiltInTools, fetchAllCustomTools, fetchAllWorkflowTools, } from '@/service/tools' -import I18n from '@/context/i18n' import { CollectionType } from '@/app/components/tools/types' import { CUSTOM_ITERATION_START_NODE } from '@/app/components/workflow/nodes/iteration-start/constants' import { CUSTOM_LOOP_START_NODE } from '@/app/components/workflow/nodes/loop-start/constants' -import { useWorkflowConfig } from '@/service/use-workflow' import { basePath } from '@/utils/var' import { canFindTool } from '@/utils' @@ -70,12 +56,9 @@ export const useIsChatMode = () => { export const useWorkflow = () => { const { t } = useTranslation() - const { locale } = useContext(I18n) const store = useStoreApi() const workflowStore = useWorkflowStore() - const appId = useStore(s => s.appId) const nodesExtraData = useNodesExtraData() - const { data: workflowConfig } = useWorkflowConfig(appId) const setPanelWidth = useCallback((width: number) => { localStorage.setItem('workflow-node-panel-width', `${width}`) workflowStore.setState({ panelWidth: width }) @@ -120,7 +103,7 @@ export const useWorkflow = () => { list.push(...incomers) - return uniqBy(list, 'id').filter((item) => { + return uniqBy(list, 'id').filter((item: Node) => { return SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type) }) }, [store]) @@ -167,7 +150,7 @@ export const useWorkflow = () => { const length = list.length if (length) { - return uniqBy(list, 'id').reverse().filter((item) => { + return uniqBy(list, 'id').reverse().filter((item: Node) => { return SUPPORT_OUTPUT_VARS_NODE.includes(item.data.type) }) } @@ -344,6 +327,7 @@ export const useWorkflow = () => { parallelList, hasAbnormalEdges, } = getParallelInfo(nodes, edges, parentNodeId) + const { workflowConfig } = workflowStore.getState() if (hasAbnormalEdges) return false @@ -359,7 +343,7 @@ export const useWorkflow = () => { } return true - }, [t, workflowStore, workflowConfig?.parallel_depth_limit]) + }, [t, workflowStore]) const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => { const { @@ -407,10 +391,6 @@ export const useWorkflow = () => { return !hasCycle(targetNode) }, [store, nodesExtraData, checkParallelLimit]) - const formatTimeFromNow = useCallback((time: number) => { - return dayjs(time).locale(locale === 'zh-Hans' ? 'zh-cn' : locale).fromNow() - }, [locale]) - const getNode = useCallback((nodeId?: string) => { const { getNodes } = store.getState() const nodes = getNodes() @@ -432,7 +412,6 @@ export const useWorkflow = () => { checkNestedParallelLimit, isValidConnection, isFromStartNode, - formatTimeFromNow, getNode, getBeforeNodeById, getIterationNodeChildren, @@ -478,107 +457,6 @@ export const useFetchToolsData = () => { } } -export const useWorkflowInit = () => { - const workflowStore = useWorkflowStore() - const { - nodes: nodesTemplate, - edges: edgesTemplate, - } = useWorkflowTemplate() - const { handleFetchAllTools } = useFetchToolsData() - const appDetail = useAppStore(state => state.appDetail)! - const setSyncWorkflowDraftHash = useStore(s => s.setSyncWorkflowDraftHash) - const [data, setData] = useState() - const [isLoading, setIsLoading] = useState(true) - useEffect(() => { - workflowStore.setState({ appId: appDetail.id }) - }, [appDetail.id, workflowStore]) - - const handleGetInitialWorkflowData = useCallback(async () => { - try { - const res = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) - setData(res) - workflowStore.setState({ - envSecrets: (res.environment_variables || []).filter(env => env.value_type === 'secret').reduce((acc, env) => { - acc[env.id] = env.value - return acc - }, {} as Record), - environmentVariables: res.environment_variables?.map(env => env.value_type === 'secret' ? { ...env, value: '[__HIDDEN__]' } : env) || [], - conversationVariables: res.conversation_variables || [], - }) - setSyncWorkflowDraftHash(res.hash) - setIsLoading(false) - } - catch (error: any) { - if (error && error.json && !error.bodyUsed && appDetail) { - error.json().then((err: any) => { - if (err.code === 'draft_workflow_not_exist') { - workflowStore.setState({ notInitialWorkflow: true }) - syncWorkflowDraft({ - url: `/apps/${appDetail.id}/workflows/draft`, - params: { - graph: { - nodes: nodesTemplate, - edges: edgesTemplate, - }, - features: { - retriever_resource: { enabled: true }, - }, - environment_variables: [], - conversation_variables: [], - }, - }).then((res) => { - workflowStore.getState().setDraftUpdatedAt(res.updated_at) - handleGetInitialWorkflowData() - }) - } - }) - } - } - }, [appDetail, nodesTemplate, edgesTemplate, workflowStore, setSyncWorkflowDraftHash]) - - useEffect(() => { - handleGetInitialWorkflowData() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) - - const handleFetchPreloadData = useCallback(async () => { - try { - const nodesDefaultConfigsData = await fetchNodesDefaultConfigs(`/apps/${appDetail?.id}/workflows/default-workflow-block-configs`) - const publishedWorkflow = await fetchPublishedWorkflow(`/apps/${appDetail?.id}/workflows/publish`) - workflowStore.setState({ - nodesDefaultConfigs: nodesDefaultConfigsData.reduce((acc, block) => { - if (!acc[block.type]) - acc[block.type] = { ...block.config } - return acc - }, {} as Record), - }) - workflowStore.getState().setPublishedAt(publishedWorkflow?.created_at) - } - catch (e) { - console.error(e) - } - }, [workflowStore, appDetail]) - - useEffect(() => { - handleFetchPreloadData() - handleFetchAllTools('builtin') - handleFetchAllTools('custom') - handleFetchAllTools('workflow') - }, [handleFetchPreloadData, handleFetchAllTools]) - - useEffect(() => { - if (data) { - workflowStore.getState().setDraftUpdatedAt(data.updated_at) - workflowStore.getState().setToolPublished(data.tool_published) - } - }, [data, workflowStore]) - - return { - data, - isLoading, - } -} - export const useWorkflowReadOnly = () => { const workflowStore = useWorkflowStore() const workflowRunningData = useStore(s => s.workflowRunningData) diff --git a/web/app/components/workflow/index.tsx b/web/app/components/workflow/index.tsx index 4c48afb56c..9a3e13822a 100644 --- a/web/app/components/workflow/index.tsx +++ b/web/app/components/workflow/index.tsx @@ -5,11 +5,8 @@ import { memo, useCallback, useEffect, - useMemo, useRef, - useState, } from 'react' -import useSWR from 'swr' import { setAutoFreeze } from 'immer' import { useEventListener, @@ -31,17 +28,14 @@ import 'reactflow/dist/style.css' import './style.css' import type { Edge, - EnvironmentVariable, Node, } from './types' import { ControlMode, - SupportUploadFileTypes, } from './types' -import { WorkflowContextProvider } from './context' import { - useDSL, useEdgesInteractions, + useFetchToolsData, useNodesInteractions, useNodesReadOnly, useNodesSyncDraft, @@ -49,11 +43,9 @@ import { useSelectionInteractions, useShortcuts, useWorkflow, - useWorkflowInit, useWorkflowReadOnly, useWorkflowUpdate, } from './hooks' -import Header from './header' import CustomNode from './nodes' import CustomNoteNode from './note-node' import { CUSTOM_NOTE_NODE } from './note-node/constants' @@ -66,42 +58,28 @@ import { CUSTOM_SIMPLE_NODE } from './simple-node/constants' import Operator from './operator' import CustomEdge from './custom-edge' import CustomConnectionLine from './custom-connection-line' -import Panel from './panel' -import Features from './features' import HelpLine from './help-line' import CandidateNode from './candidate-node' import PanelContextmenu from './panel-contextmenu' import NodeContextmenu from './node-contextmenu' import SyncingDataModal from './syncing-data-modal' -import UpdateDSLModal from './update-dsl-modal' -import DSLExportConfirmModal from './dsl-export-confirm-modal' import LimitTips from './limit-tips' -import PluginDependency from './plugin-dependency' import { useStore, useWorkflowStore, } from './store' -import { - initialEdges, - initialNodes, -} from './utils' import { CUSTOM_EDGE, CUSTOM_NODE, - DSL_EXPORT_CHECK, ITERATION_CHILDREN_Z_INDEX, WORKFLOW_DATA_UPDATE, } from './constants' import { WorkflowHistoryProvider } from './workflow-history-store' -import Loading from '@/app/components/base/loading' -import { FeaturesProvider } from '@/app/components/base/features' -import type { Features as FeaturesData } from '@/app/components/base/features/types' -import { useFeaturesStore } from '@/app/components/base/features/hooks' import { useEventEmitterContextContext } from '@/context/event-emitter' import Confirm from '@/app/components/base/confirm' -import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' -import { fetchFileUploadConfig } from '@/service/common' import DatasetsDetailProvider from './datasets-detail-store/provider' +import { HooksStoreContextProvider } from './hooks-store' +import type { Shape as HooksStoreShape } from './hooks-store' const nodeTypes = { [CUSTOM_NODE]: CustomNode, @@ -114,32 +92,32 @@ const edgeTypes = { [CUSTOM_EDGE]: CustomEdge, } -type WorkflowProps = { +export type WorkflowProps = { nodes: Node[] edges: Edge[] viewport?: Viewport + children?: React.ReactNode + onWorkflowDataUpdate?: (v: any) => void } -const Workflow: FC = memo(({ +export const Workflow: FC = memo(({ nodes: originalNodes, edges: originalEdges, viewport, + children, + onWorkflowDataUpdate, }) => { const workflowContainerRef = useRef(null) const workflowStore = useWorkflowStore() const reactflow = useReactFlow() - const featuresStore = useFeaturesStore() const [nodes, setNodes] = useNodesState(originalNodes) const [edges, setEdges] = useEdgesState(originalEdges) - const showFeaturesPanel = useStore(state => state.showFeaturesPanel) const controlMode = useStore(s => s.controlMode) const nodeAnimation = useStore(s => s.nodeAnimation) const showConfirm = useStore(s => s.showConfirm) - const showImportDSLModal = useStore(s => s.showImportDSLModal) const { setShowConfirm, setControlPromptEditorRerenderKey, - setShowImportDSLModal, setSyncWorkflowDraftHash, } = workflowStore.getState() const { @@ -148,9 +126,6 @@ const Workflow: FC = memo(({ } = useNodesSyncDraft() const { workflowReadOnly } = useWorkflowReadOnly() const { nodesReadOnly } = useNodesReadOnly() - - const [secretEnvList, setSecretEnvList] = useState([]) - const { eventEmitter } = useEventEmitterContextContext() eventEmitter?.useSubscription((v: any) => { @@ -161,19 +136,13 @@ const Workflow: FC = memo(({ if (v.payload.viewport) reactflow.setViewport(v.payload.viewport) - if (v.payload.features && featuresStore) { - const { setFeatures } = featuresStore.getState() - - setFeatures(v.payload.features) - } - if (v.payload.hash) setSyncWorkflowDraftHash(v.payload.hash) + onWorkflowDataUpdate?.(v.payload) + setTimeout(() => setControlPromptEditorRerenderKey(Date.now())) } - if (v.type === DSL_EXPORT_CHECK) - setSecretEnvList(v.payload.data as EnvironmentVariable[]) }) useEffect(() => { @@ -231,6 +200,12 @@ const Workflow: FC = memo(({ }) } }) + const { handleFetchAllTools } = useFetchToolsData() + useEffect(() => { + handleFetchAllTools('builtin') + handleFetchAllTools('custom') + handleFetchAllTools('workflow') + }, [handleFetchAllTools]) const { handleNodeDragStart, @@ -258,15 +233,10 @@ const Workflow: FC = memo(({ } = useSelectionInteractions() const { handlePaneContextMenu, - handlePaneContextmenuCancel, } = usePanelInteractions() const { isValidConnection, } = useWorkflow() - const { - exportCheck, - handleExportDSL, - } = useDSL() useOnViewportChange({ onEnd: () => { @@ -297,12 +267,7 @@ const Workflow: FC = memo(({ > -
- - { - showFeaturesPanel && - } @@ -317,26 +282,8 @@ const Workflow: FC = memo(({ /> ) } - { - showImportDSLModal && ( - setShowImportDSLModal(false)} - onBackup={exportCheck} - onImport={handlePaneContextmenuCancel} - /> - ) - } - { - secretEnvList.length > 0 && ( - setSecretEnvList([])} - /> - ) - } - + {children} = memo(({
) }) -Workflow.displayName = 'Workflow' - -const WorkflowWrap = memo(() => { - const { - data, - isLoading, - } = useWorkflowInit() - const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - - const nodesData = useMemo(() => { - if (data) - return initialNodes(data.graph.nodes, data.graph.edges) - - return [] - }, [data]) - const edgesData = useMemo(() => { - if (data) - return initialEdges(data.graph.edges, data.graph.nodes) - return [] - }, [data]) - - if (!data || isLoading) { - return ( -
- -
- ) - } +type WorkflowWithInnerContextProps = WorkflowProps & { + hooksStore?: Partial +} +export const WorkflowWithInnerContext = memo(({ + hooksStore, + ...restProps +}: WorkflowWithInnerContextProps) => { + return ( + + + + ) +}) - const features = data.features || {} - const initialFeatures: FeaturesData = { - file: { - image: { - enabled: !!features.file_upload?.image?.enabled, - number_limits: features.file_upload?.image?.number_limits || 3, - transfer_methods: features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], - }, - enabled: !!(features.file_upload?.enabled || features.file_upload?.image?.enabled), - allowed_file_types: features.file_upload?.allowed_file_types || [SupportUploadFileTypes.image], - allowed_file_extensions: features.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`), - allowed_file_upload_methods: features.file_upload?.allowed_file_upload_methods || features.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], - number_limits: features.file_upload?.number_limits || features.file_upload?.image?.number_limits || 3, - fileUploadConfig: fileUploadConfigResponse, - }, - opening: { - enabled: !!features.opening_statement, - opening_statement: features.opening_statement, - suggested_questions: features.suggested_questions, - }, - suggested: features.suggested_questions_after_answer || { enabled: false }, - speech2text: features.speech_to_text || { enabled: false }, - text2speech: features.text_to_speech || { enabled: false }, - citation: features.retriever_resource || { enabled: false }, - moderation: features.sensitive_word_avoidance || { enabled: false }, +type WorkflowWithDefaultContextProps = + Pick + & { + children: React.ReactNode } +const WorkflowWithDefaultContext = ({ + nodes, + edges, + children, +}: WorkflowWithDefaultContextProps) => { return ( - - - - - + nodes={nodes} + edges={edges} > + + {children} + ) -}) -WorkflowWrap.displayName = 'WorkflowWrap' - -const WorkflowContainer = () => { - return ( - - - - ) } -export default memo(WorkflowContainer) +export default memo(WorkflowWithDefaultContext) diff --git a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx index be57cbca0f..d67b7af1a4 100644 --- a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx +++ b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx @@ -133,7 +133,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => { // TODO: maybe empty, handle this onChange={onChange as any} defaultValue={defaultValue} - size='sm' + size='regular' min={def.min} max={def.max} className='w-12' diff --git a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx index 4b36125575..2390dfd74e 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx @@ -4,10 +4,16 @@ import Collapse from '.' type FieldCollapseProps = { title: string children: ReactNode + collapsed?: boolean + onCollapse?: (collapsed: boolean) => void + operations?: ReactNode } const FieldCollapse = ({ title, children, + collapsed, + onCollapse, + operations, }: FieldCollapseProps) => { return (
@@ -15,6 +21,9 @@ const FieldCollapse = ({ trigger={
{title}
} + operations={operations} + collapsed={collapsed} + onCollapse={onCollapse} >
{children} diff --git a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx index 1f39c1c1c5..16fba88a25 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx @@ -1,15 +1,18 @@ -import { useState } from 'react' -import { RiArrowDropRightLine } from '@remixicon/react' +import type { ReactNode } from 'react' +import { useMemo, useState } from 'react' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' import cn from '@/utils/classnames' export { default as FieldCollapse } from './field-collapse' type CollapseProps = { disabled?: boolean - trigger: React.JSX.Element + trigger: React.JSX.Element | ((collapseIcon: React.JSX.Element | null) => React.JSX.Element) children: React.JSX.Element collapsed?: boolean onCollapse?: (collapsed: boolean) => void + operations?: ReactNode + hideCollapseIcon?: boolean } const Collapse = ({ disabled, @@ -17,34 +20,44 @@ const Collapse = ({ children, collapsed, onCollapse, + operations, + hideCollapseIcon, }: CollapseProps) => { const [collapsedLocal, setCollapsedLocal] = useState(true) const collapsedMerged = collapsed !== undefined ? collapsed : collapsedLocal + const collapseIcon = useMemo(() => { + if (disabled) + return null + return ( + + ) + }, [collapsedMerged, disabled]) return ( <> -
{ - if (!disabled) { - setCollapsedLocal(!collapsedMerged) - onCollapse?.(!collapsedMerged) - } - }} - > -
- { - !disabled && ( - - ) - } +
+
{ + if (!disabled) { + setCollapsedLocal(!collapsedMerged) + onCollapse?.(!collapsedMerged) + } + }} + > + {typeof trigger === 'function' ? trigger(collapseIcon) : trigger} + {!hideCollapseIcon && ( +
+ {collapseIcon} +
+ )}
- {trigger} + {operations}
{ !collapsedMerged && children diff --git a/web/app/components/workflow/nodes/_base/components/editor/base.tsx b/web/app/components/workflow/nodes/_base/components/editor/base.tsx index 3b31f44619..38968b2e0d 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/base.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/base.tsx @@ -109,7 +109,7 @@ const Base: FC = ({ onHeightChange={setEditorContentHeight} hideResize={isExpand} > -
+
{children}
diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css b/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css index 296ea0ab14..72e0087a3c 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/style.css @@ -1,10 +1,3 @@ -.margin-view-overlays { - padding-left: 10px; -} - -.no-wrapper .margin-view-overlays { - padding-left: 0; -} .monaco-editor { background-color: transparent !important; diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index b36abbfb00..cfcbae80f3 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -49,20 +49,23 @@ const ErrorHandle = ({ disabled={!error_strategy} collapsed={collapsed} onCollapse={setCollapsed} + hideCollapseIcon trigger={ -
-
-
- {t('workflow.nodes.common.errorHandle.title')} + collapseIcon => ( +
+
+
+ {t('workflow.nodes.common.errorHandle.title')} +
+ + {collapseIcon}
- +
- -
- } + )} > <> { diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx index 190c748831..d9516dfcf5 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx @@ -50,6 +50,7 @@ const ErrorHandleTypeSelector = ({ > { e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() setOpen(v => !v) }}> + + )} + + + +
+
+
+ +
+
+ ) +} + +export default React.memo(CodeEditor) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx new file mode 100644 index 0000000000..2685182f9f --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx @@ -0,0 +1,27 @@ +import React from 'react' +import type { FC } from 'react' +import { RiErrorWarningFill } from '@remixicon/react' +import classNames from '@/utils/classnames' + +type ErrorMessageProps = { + message: string +} & React.HTMLAttributes + +const ErrorMessage: FC = ({ + message, + className, +}) => { + return ( +
+ +
+ {message} +
+
+ ) +} + +export default React.memo(ErrorMessage) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx new file mode 100644 index 0000000000..d34836d5b2 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx @@ -0,0 +1,34 @@ +import React, { type FC } from 'react' +import Modal from '../../../../../base/modal' +import type { SchemaRoot } from '../../types' +import JsonSchemaConfig from './json-schema-config' + +type JsonSchemaConfigModalProps = { + isShow: boolean + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +const JsonSchemaConfigModal: FC = ({ + isShow, + defaultSchema, + onSave, + onClose, +}) => { + return ( + + + + ) +} + +export default JsonSchemaConfigModal diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx new file mode 100644 index 0000000000..643059adbd --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx @@ -0,0 +1,136 @@ +import React, { type FC, useCallback, useEffect, useRef, useState } from 'react' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { RiCloseLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import { checkJsonDepth } from '../../utils' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' +import CodeEditor from './code-editor' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import { useMittContext } from './visual-editor/context' + +type JsonImporterProps = { + onSubmit: (schema: any) => void + updateBtnWidth: (width: number) => void +} + +const JsonImporter: FC = ({ + onSubmit, + updateBtnWidth, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [json, setJson] = useState('') + const [parseError, setParseError] = useState(null) + const importBtnRef = useRef(null) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + + useEffect(() => { + if (importBtnRef.current) { + const rect = importBtnRef.current.getBoundingClientRect() + updateBtnWidth(rect.width) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleSubmit = useCallback(() => { + try { + const parsedJSON = JSON.parse(json) + if (typeof parsedJSON !== 'object' || Array.isArray(parsedJSON)) { + setParseError(new Error('Root must be an object, not an array or primitive value.')) + return + } + const maxDepth = checkJsonDepth(parsedJSON) + if (maxDepth > JSON_SCHEMA_MAX_DEPTH) { + setParseError({ + type: 'error', + message: `Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`, + }) + return + } + onSubmit(parsedJSON) + setParseError(null) + setOpen(false) + } + catch (e: any) { + if (e instanceof Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + } + }, [onSubmit, json]) + + return ( + + + + + +
+ {/* Title */} +
+
+ +
+
+ {t('workflow.nodes.llm.jsonSchema.import')} +
+
+ {/* Content */} +
+ + {parseError && } +
+ {/* Footer */} +
+ + +
+
+
+
+ ) +} + +export default JsonImporter diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx new file mode 100644 index 0000000000..d125e31dae --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx @@ -0,0 +1,301 @@ +import React, { type FC, useCallback, useState } from 'react' +import { type SchemaRoot, Type } from '../../types' +import { RiBracesLine, RiCloseLine, RiExternalLinkLine, RiTimelineView } from '@remixicon/react' +import { SegmentedControl } from '../../../../../base/segmented-control' +import JsonSchemaGenerator from './json-schema-generator' +import Divider from '@/app/components/base/divider' +import JsonImporter from './json-importer' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import VisualEditor from './visual-editor' +import SchemaEditor from './schema-editor' +import { + checkJsonSchemaDepth, + convertBooleanToString, + getValidationErrorMessage, + jsonToSchema, + preValidateSchema, + validateSchemaAgainstDraft7, +} from '../../utils' +import { MittProvider, VisualEditorContextProvider, useMittContext } from './visual-editor/context' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import Toast from '@/app/components/base/toast' +import { useGetLanguage } from '@/context/i18n' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' + +type JsonSchemaConfigProps = { + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +enum SchemaView { + VisualEditor = 'visualEditor', + JsonSchema = 'jsonSchema', +} + +const VIEW_TABS = [ + { Icon: RiTimelineView, text: 'Visual Editor', value: SchemaView.VisualEditor }, + { Icon: RiBracesLine, text: 'JSON Schema', value: SchemaView.JsonSchema }, +] + +const DEFAULT_SCHEMA: SchemaRoot = { + type: Type.object, + properties: {}, + required: [], + additionalProperties: false, +} + +const HELP_DOC_URL = { + zh_Hans: 'https://docs.dify.ai/zh-hans/guides/workflow/structured-outputs', + en_US: 'https://docs.dify.ai/guides/workflow/structured-outputs', + ja_JP: 'https://docs.dify.ai/ja-jp/guides/workflow/structured-outputs', +} + +type LocaleKey = keyof typeof HELP_DOC_URL + +const JsonSchemaConfig: FC = ({ + defaultSchema, + onSave, + onClose, +}) => { + const { t } = useTranslation() + const locale = useGetLanguage() as LocaleKey + const [currentTab, setCurrentTab] = useState(SchemaView.VisualEditor) + const [jsonSchema, setJsonSchema] = useState(defaultSchema || DEFAULT_SCHEMA) + const [json, setJson] = useState(JSON.stringify(jsonSchema, null, 2)) + const [btnWidth, setBtnWidth] = useState(0) + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const setAdvancedEditing = useVisualEditorStore(state => state.setAdvancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const setIsAddingNewField = useVisualEditorStore(state => state.setIsAddingNewField) + const setHoveringProperty = useVisualEditorStore(state => state.setHoveringProperty) + const { emit } = useMittContext() + + const updateBtnWidth = useCallback((width: number) => { + setBtnWidth(width + 32) + }, []) + + const handleTabChange = useCallback((value: SchemaView) => { + if (currentTab === value) return + if (currentTab === SchemaView.JsonSchema) { + try { + const schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || isAddingNewField) + emit('quitEditing', { callback: (backup: SchemaRoot) => setJson(JSON.stringify(backup || jsonSchema, null, 2)) }) + else + setJson(JSON.stringify(jsonSchema, null, 2)) + } + + setCurrentTab(value) + }, [currentTab, jsonSchema, json, advancedEditing, isAddingNewField, emit]) + + const handleApplySchema = useCallback((schema: SchemaRoot) => { + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(schema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(schema, null, 2)) + }, [currentTab]) + + const handleSubmit = useCallback((schema: any) => { + const jsonSchema = jsonToSchema(schema) as SchemaRoot + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(jsonSchema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(jsonSchema, null, 2)) + }, [currentTab]) + + const handleVisualEditorUpdate = useCallback((schema: SchemaRoot) => { + setJsonSchema(schema) + }, []) + + const handleSchemaEditorUpdate = useCallback((schema: string) => { + setJson(schema) + }, []) + + const handleResetDefaults = useCallback(() => { + if (currentTab === SchemaView.VisualEditor) { + setHoveringProperty(null) + advancedEditing && setAdvancedEditing(false) + isAddingNewField && setIsAddingNewField(false) + } + setJsonSchema(DEFAULT_SCHEMA) + setJson(JSON.stringify(DEFAULT_SCHEMA, null, 2)) + }, [currentTab, advancedEditing, isAddingNewField, setAdvancedEditing, setIsAddingNewField, setHoveringProperty]) + + const handleCancel = useCallback(() => { + onClose() + }, [onClose]) + + const handleSave = useCallback(() => { + let schema = jsonSchema + if (currentTab === SchemaView.JsonSchema) { + try { + schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || isAddingNewField) { + Toast.notify({ + type: 'warning', + message: t('workflow.nodes.llm.jsonSchema.warningTips.saveSchema'), + }) + return + } + } + onSave(schema) + onClose() + }, [currentTab, jsonSchema, json, onSave, onClose, advancedEditing, isAddingNewField, t]) + + return ( +
+ {/* Header */} +
+
+ {t('workflow.nodes.llm.jsonSchema.title')} +
+
+ +
+
+ {/* Content */} +
+ {/* Tab */} + + options={VIEW_TABS} + value={currentTab} + onChange={handleTabChange} + /> +
+ {/* JSON Schema Generator */} + + + {/* JSON Schema Importer */} + +
+
+
+ {currentTab === SchemaView.VisualEditor && ( + + )} + {currentTab === SchemaView.JsonSchema && ( + + )} + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ + {t('workflow.nodes.llm.jsonSchema.doc')} + + +
+
+ + +
+
+ + +
+
+
+
+ ) +} + +const JsonSchemaConfigWrapper: FC = (props) => { + return ( + + + + + + ) +} + +export default JsonSchemaConfigWrapper diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx new file mode 100644 index 0000000000..5f1f117086 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx @@ -0,0 +1,7 @@ +import SchemaGeneratorLight from './schema-generator-light' +import SchemaGeneratorDark from './schema-generator-dark' + +export { + SchemaGeneratorLight, + SchemaGeneratorDark, +} diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx new file mode 100644 index 0000000000..ac4793b1e3 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorDark = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorDark diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx new file mode 100644 index 0000000000..8b898bde68 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorLight = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorLight diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx new file mode 100644 index 0000000000..00f57237e5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx @@ -0,0 +1,121 @@ +import React, { type FC, useCallback, useMemo, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { RiArrowLeftLine, RiCloseLine, RiSparklingLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import CodeEditor from '../code-editor' +import ErrorMessage from '../error-message' +import { getValidationErrorMessage, validateSchemaAgainstDraft7 } from '../../../utils' +import Loading from '@/app/components/base/loading' + +type GeneratedResultProps = { + schema: SchemaRoot + isGenerating: boolean + onBack: () => void + onRegenerate: () => void + onClose: () => void + onApply: () => void +} + +const GeneratedResult: FC = ({ + schema, + isGenerating, + onBack, + onRegenerate, + onClose, + onApply, +}) => { + const { t } = useTranslation() + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + + const formatJSON = (json: SchemaRoot) => { + try { + const schema = JSON.stringify(json, null, 2) + setParseError(null) + return schema + } + catch (e) { + if (e instanceof Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + return '' + } + } + + const jsonSchema = useMemo(() => formatJSON(schema), [schema]) + + const handleApply = useCallback(() => { + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + onApply() + setValidationError('') + }, [schema, onApply]) + + return ( +
+ { + isGenerating ? ( +
+ +
{t('workflow.nodes.llm.jsonSchema.generating')}
+
+ ) : ( + <> +
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generatedResult')} +
+
+ {t('workflow.nodes.llm.jsonSchema.resultTip')} +
+
+ {/* Content */} +
+ + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ +
+ + +
+
+ + + ) + } +
+ ) +} + +export default React.memo(GeneratedResult) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx new file mode 100644 index 0000000000..4732499f3a --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx @@ -0,0 +1,183 @@ +import React, { type FC, useCallback, useEffect, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import useTheme from '@/hooks/use-theme' +import type { CompletionParams, Model } from '@/types/app' +import { ModelModeType } from '@/types/app' +import { Theme } from '@/types/app' +import { SchemaGeneratorDark, SchemaGeneratorLight } from './assets' +import cn from '@/utils/classnames' +import type { ModelInfo } from './prompt-editor' +import PromptEditor from './prompt-editor' +import GeneratedResult from './generated-result' +import { useGenerateStructuredOutputRules } from '@/service/use-common' +import Toast from '@/app/components/base/toast' +import { type FormValue, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useVisualEditorStore } from '../visual-editor/store' +import { useTranslation } from 'react-i18next' +import { useMittContext } from '../visual-editor/context' + +type JsonSchemaGeneratorProps = { + onApply: (schema: SchemaRoot) => void + crossAxisOffset?: number +} + +enum GeneratorView { + promptEditor = 'promptEditor', + result = 'result', +} + +export const JsonSchemaGenerator: FC = ({ + onApply, + crossAxisOffset, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [view, setView] = useState(GeneratorView.promptEditor) + const [model, setModel] = useState({ + name: '', + provider: '', + mode: ModelModeType.completion, + completion_params: {} as CompletionParams, + }) + const [instruction, setInstruction] = useState('') + const [schema, setSchema] = useState(null) + const { theme } = useTheme() + const { + defaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + const SchemaGenerator = theme === Theme.light ? SchemaGeneratorLight : SchemaGeneratorDark + + useEffect(() => { + if (defaultModel) { + setModel(prev => ({ + ...prev, + name: defaultModel.model, + provider: defaultModel.provider.provider, + })) + } + }, [defaultModel]) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleModelChange = useCallback((model: ModelInfo) => { + setModel(prev => ({ + ...prev, + provider: model.provider, + name: model.modelId, + mode: model.mode as ModelModeType, + })) + }, []) + + const handleCompletionParamsChange = useCallback((newParams: FormValue) => { + setModel(prev => ({ + ...prev, + completion_params: newParams as CompletionParams, + }), + ) + }, []) + + const { mutateAsync: generateStructuredOutputRules, isPending: isGenerating } = useGenerateStructuredOutputRules() + + const generateSchema = useCallback(async () => { + const { output, error } = await generateStructuredOutputRules({ instruction, model_config: model! }) + if (error) { + Toast.notify({ + type: 'error', + message: error, + }) + setSchema(null) + setView(GeneratorView.promptEditor) + return + } + return output + }, [instruction, model, generateStructuredOutputRules]) + + const handleGenerate = useCallback(async () => { + setView(GeneratorView.result) + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const goBackToPromptEditor = () => { + setView(GeneratorView.promptEditor) + } + + const handleRegenerate = useCallback(async () => { + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const handleApply = () => { + onApply(schema!) + setOpen(false) + } + + return ( + + + + + + {view === GeneratorView.promptEditor && ( + + )} + {view === GeneratorView.result && ( + + )} + + + ) +} + +export default JsonSchemaGenerator diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx new file mode 100644 index 0000000000..9387813ee5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx @@ -0,0 +1,108 @@ +import React, { useCallback } from 'react' +import type { FC } from 'react' +import { RiCloseLine, RiSparklingFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Textarea from '@/app/components/base/textarea' +import Tooltip from '@/app/components/base/tooltip' +import Button from '@/app/components/base/button' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import type { Model } from '@/types/app' + +export type ModelInfo = { + modelId: string + provider: string + mode?: string + features?: string[] +} + +type PromptEditorProps = { + instruction: string + model: Model + onInstructionChange: (instruction: string) => void + onCompletionParamsChange: (newParams: FormValue) => void + onModelChange: (model: ModelInfo) => void + onClose: () => void + onGenerate: () => void +} + +const PromptEditor: FC = ({ + instruction, + model, + onInstructionChange, + onCompletionParamsChange, + onClose, + onGenerate, + onModelChange, +}) => { + const { t } = useTranslation() + + const handleInstructionChange = useCallback((e: React.ChangeEvent) => { + onInstructionChange(e.target.value) + }, [onInstructionChange]) + + return ( +
+
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generateJsonSchema')} +
+
+ {t('workflow.nodes.llm.jsonSchema.generationTip')} +
+
+ {/* Content */} +
+
+ {t('common.modelProvider.model')} +
+ +
+
+
+ {t('workflow.nodes.llm.jsonSchema.instruction')} + +
+
+