Merge remote-tracking branch 'company/feat/17150_2'

pull/18136/head
JF.Hsiong 1 year ago
commit 452fa1b4c6

@ -45,7 +45,17 @@ jobs:
run: uv sync --project api --dev
- name: Run Unit tests
run: uv run --project api bash dev/pytest/pytest_unit_tests.sh
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

1
.gitignore vendored

@ -46,6 +46,7 @@ htmlcov/
.cache
nosetests.xml
coverage.xml
coverage.json
*.cover
*.py,cover
.hypothesis/

@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_ANALYZER_PARAMS=
# MyScale configuration
MYSCALE_HOST=127.0.0.1
@ -423,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0

@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
ext_repositories,
ext_sentry,
ext_set_secretkey,
ext_storage,
@ -74,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
ext_repositories,
ext_celery,
ext_login,
ext_mail,

@ -12,7 +12,7 @@ from pydantic import (
)
from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig
from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings):
@ -519,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100,
)
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings):
"""

@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
"older versions",
default=True,
)
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

@ -4,14 +4,10 @@ import platform
import re
import urllib.parse
import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
import httpx
from constants import DEFAULT_FILE_NUMBER_LIMITS
try:
import magic
except ImportError:
@ -31,8 +27,6 @@ except ImportError:
from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel):
filename: str
@ -89,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)),
)
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -1,5 +1,4 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider:
return {"error": "Invalid provider"}, 400
if "code" in request.args:
code = request.args.get("code")
code = request.args.get("code", "")
if not code:
return {"error": "Invalid code"}, 400
try:
oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e:

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp
from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class ExploreAppMetaApi(InstalledAppResource):

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppMetaApi(Resource):

@ -1,6 +1,6 @@
import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

@ -139,7 +139,9 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"]),
retrieval_model=RetrievalModel(**args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()

@ -122,6 +122,8 @@ class SegmentApi(DatasetApiResource):
tenant_id=current_user.current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
limit=limit,
)
response = {

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppMeta(WebApiResource):

@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
workflow_run=workflow_run, event=event
)
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
workflow_run=workflow_run, event=event
)
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
event=event
)
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_finish_resp:
yield node_finish_resp

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
strict_type_validation=strict_type_validation,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config,
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=list(system_files),
user_id=user.id,

@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
workflow_run=workflow_run, event=event
)
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id
)
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
workflow_run=workflow_run, event=event
)
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
event=event
)
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_success_response:
yield node_success_response
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id
session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start(
self,
*,
@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
# Use the instance repository to find running executions for a workflow run
running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
workflow_run_id=workflow_run.id
)
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
# Update the cache with the retrieved executions
for execution in running_workflow_node_executions:
if execution.node_execution_id:
self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run
def _handle_node_execution_start(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
@ -315,17 +328,14 @@ class WorkflowCycleManage:
)
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(
self, *, session: Session, event: QueueNodeSucceededEvent
) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
# Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self,
*,
session: Session,
event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event
:return:
"""
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
@ -387,14 +395,14 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event
:return:
"""
@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution)
# Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
self,
*,
@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return WorkflowStartStreamResponse(
task_id=task_id,
@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response(
self,
*,
session: Session,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
*,
session: Session,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response(
self,
*,
session: Session,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
if not workflow_node_execution.workflow_run_id:
@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return ParallelBranchStartStreamResponse(
task_id=task_id,
@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeStartStreamResponse(
task_id=task_id,
@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeNextStreamResponse(
task_id=task_id,
@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return IterationNodeCompletedStreamResponse(
task_id=task_id,
@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeStartStreamResponse(
task_id=task_id,
@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeNextStreamResponse(
task_id=task_id,
@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
return LoopNodeCompletedStreamResponse(
task_id=task_id,
@ -934,11 +920,22 @@ class WorkflowCycleManage:
return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._workflow_node_executions:
def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
# First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution)
# Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
"""

@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

@ -48,25 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
)
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
ssl_verify = kwargs.pop("ssl_verify")
retries = 0
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
}
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
else:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client:
with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST:

@ -44,6 +44,7 @@ class TokenBufferMemory:
Message.created_at,
Message.workflow_run_id,
Message.parent_message_id,
Message.answer_tokens,
)
.filter(
Message.conversation_id == self.conversation.id,
@ -63,7 +64,7 @@ class TokenBufferMemory:
thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer:
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
messages = list(reversed(thread_messages))

@ -1,5 +1,6 @@
import logging
import time
import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel):
"""
Model class for large language model.
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result:
if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls)
_increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional
from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_trace(langfuse_trace_data=trace_data)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id

@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id

@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
}
self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution
workflow_nodes_execution_id_records = (
db.session.query(WorkflowNodeExecution.id)
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
.all()
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
)
for node_execution_id_record in workflow_nodes_execution_id_records:
node_execution = (
db.session.query(
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id
app_id = node_execution.app_id

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod
def invoke_app(
cls,

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str
mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)
# Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list]
) -> dict[str, list]:
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list[Provider]]:
"""
Initialize trial provider records if not exists.
@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider(
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0,
is_valid=True,
)
db.session.add(provider_record)
db.session.add(new_provider_record)
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
provider_record = (
existed_provider_record = (
db.session.query(Provider)
.filter(
Provider.tenant_id == tenant_id,
@ -556,11 +566,14 @@ class ProviderManager:
)
.first()
)
if provider_record and not provider_record.is_valid:
provider_record.is_valid = True
if not existed_provider_record:
continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record)
provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict

@ -246,7 +246,7 @@ class AnalyticdbVectorBySql:
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY (score,id) DESC
ORDER BY score DESC, id DESC
LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"),
)

@ -32,6 +32,7 @@ class MilvusConfig(BaseModel):
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: Optional[str] = None # Analyzer params
@model_validator(mode="before")
@classmethod
@ -58,6 +59,7 @@ class MilvusConfig(BaseModel):
"user": self.user,
"password": self.password,
"db_name": self.database,
"analyzer_params": self.analyzer_params,
}
@ -300,14 +302,19 @@ class MilvusVector(BaseVector):
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
content_field_kwargs: dict[str, Any] = {
"max_length": 65_535,
"enable_analyzer": self._hybrid_search_enabled,
}
if (
self._hybrid_search_enabled
and self._client_config.analyzer_params is not None
and self._client_config.analyzer_params.strip()
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
),
)

@ -39,6 +39,12 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
else:
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
def _character_encoder(texts: list[str]) -> list[int]:
if not texts:
return []
return [len(text) for text in texts]
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
@ -47,7 +53,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_token_encoder, **kwargs)
return cls(length_function=_character_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
@ -103,7 +109,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
s_lens = self._length_function(splits)
if _separator != "":
if separator != "":
for s, s_len in zip(splits, s_lens):
if s_len < self._chunk_size:
_good_splits.append(s)

@ -0,0 +1,15 @@
"""
Repository interfaces for data access.
This package contains repository interfaces that define the contract
for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
from core.repository.repository_factory import RepositoryFactory
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
__all__ = [
"RepositoryFactory",
"WorkflowNodeExecutionRepository",
]

@ -0,0 +1,97 @@
"""
Repository factory for creating repository instances.
This module provides a simple factory interface for creating repository instances.
It does not contain any implementation details or dependencies on specific repositories.
"""
from collections.abc import Callable, Mapping
from typing import Any, Literal, Optional, cast
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
# Type for factory functions - takes a dict of parameters and returns any repository type
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
# Type for workflow node execution factory function
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
# Repository type literals
_RepositoryType = Literal["workflow_node_execution"]
class RepositoryFactory:
"""
Factory class for creating repository instances.
This factory delegates the actual repository creation to implementation-specific
factory functions that are registered with the factory at runtime.
"""
# Dictionary to store factory functions
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
@classmethod
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
"""
Register a factory function for a specific repository type.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository (e.g., 'workflow_node_execution')
factory_func: A function that takes parameters and returns a repository instance
"""
cls._factory_functions[repository_type] = factory_func
@classmethod
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
"""
Create a new repository instance with the provided parameters.
This is a private method and should not be called directly.
Args:
repository_type: The type of repository to create
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the requested repository
Raises:
ValueError: If no factory function is registered for the repository type
"""
if repository_type not in cls._factory_functions:
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
# Use empty dict if params is None
params = params or {}
return cls._factory_functions[repository_type](params)
@classmethod
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
"""
Register a factory function for the workflow node execution repository.
Args:
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
"""
cls._register_factory("workflow_node_execution", factory_func)
@classmethod
def create_workflow_node_execution_repository(
cls, params: Optional[Mapping[str, Any]] = None
) -> WorkflowNodeExecutionRepository:
"""
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
Args:
params: A dictionary of parameters to pass to the factory function
Returns:
A new instance of the WorkflowNodeExecutionRepository
Raises:
ValueError: If no factory function is registered for the workflow_node_execution repository type
"""
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))

@ -0,0 +1,88 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, Optional, Protocol
from models.workflow import WorkflowNodeExecution
@dataclass
class OrderConfig:
"""Configuration for ordering WorkflowNodeExecution instances."""
order_by: list[str]
order_direction: Optional[Literal["asc", "desc"]] = None
class WorkflowNodeExecutionRepository(Protocol):
"""
Repository interface for WorkflowNodeExecution.
This interface defines the contract for accessing and manipulating
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and trigger sources (triggered_from) should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to save
"""
...
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
...
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
...
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
...
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance.
Args:
execution: The WorkflowNodeExecution instance to update
"""
...

@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# exclude current node id
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies
if event.node_id in answer_dependencies[answer_node_id]:
edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids

@ -90,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData):
params: str
body: Optional[HttpRequestNodeBody] = None
timeout: Optional[HttpRequestNodeTimeout] = None
ssl_verify: Optional[bool] = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
class Response:

@ -88,6 +88,7 @@ class Executor:
self.method = node_data.method
self.auth = node_data.authorization
self.timeout = timeout
self.ssl_verify = node_data.ssl_verify
self.params = []
self.headers = {}
self.content = None
@ -316,6 +317,7 @@ class Executor:
"headers": headers,
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"ssl_verify": self.ssl_verify,
"follow_redirects": True,
"max_retries": self.max_retries,
}

@ -51,6 +51,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"max_read_timeout": dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT,
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
"ssl_verify": dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
},
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,

@ -116,20 +116,10 @@ class MetadataFilteringCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
class MetadataFilteringComplexSubCondition(BaseModel):
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexSubCondition"]] = None
class MetadataFilteringComplexCondition(BaseModel):
"""
Complex Metadata Filtering Condition.
"""
logical_operator: Optional[Literal["and", "or"]] = "and"
conditions: Optional[list[MetadataFilteringComplexSubCondition]] = Field(default=None, deprecated=True)
conditions: Optional[list[Condition]] = Field(default=None, deprecated=True)
sub_conditions: Optional[list["MetadataFilteringComplexCondition"]] = None
class KnowledgeRetrievalNodeData(BaseNodeData):

@ -47,7 +47,6 @@ from services.feature_service import FeatureService
from .entities import (
KnowledgeRetrievalNodeData,
MetadataFilteringComplexCondition,
MetadataFilteringComplexSubCondition,
ModelConfig,
)
from .exc import (
@ -322,7 +321,7 @@ class KnowledgeRetrievalNode(LLMNode):
return retrieval_resource_list
def _recursive_metadata_filter(
self, metadata_filtering_complex_conditions: MetadataFilteringComplexSubCondition, filters
self, metadata_filtering_complex_conditions: MetadataFilteringComplexCondition, filters
):
logical_operator = metadata_filtering_complex_conditions.logical_operator
conditions = metadata_filtering_complex_conditions.conditions
@ -331,8 +330,8 @@ class KnowledgeRetrievalNode(LLMNode):
sub_filters = []
if sub_conditions:
for sub_condition in sub_conditions:
sub_filter = self._recursive_metadata_filter(sub_condition, filters)
sub_filters.append(sub_filter)
sub_filter = self._recursive_metadata_filter(sub_condition, [])
sub_filters.extend(sub_filter)
temp_filters: list = []
if conditions:
@ -357,16 +356,27 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value,
temp_filters,
)
temp_filters_result: ColumnElement[bool]
if temp_filters:
sub_filters_result: ColumnElement
temp_filters_result: ColumnElement
if temp_filters and sub_filters:
temp_all_filters = sub_filters +temp_filters
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*temp_all_filters)
else:
sub_filters_result = or_(*temp_all_filters)
filters.append(sub_filters_result)
return filters
if temp_filters: # text
if logical_operator == "and": # type: ignore
temp_filters_result = and_(*temp_filters)
else:
temp_filters_result = or_(*temp_filters)
filters.append(temp_filters_result)
return filters
sub_filters_result: ColumnElement[bool]
if sub_filters:
if sub_filters: # Boolean
if logical_operator == "and": # type: ignore
sub_filters_result = and_(*sub_filters)
else:
@ -375,6 +385,7 @@ class KnowledgeRetrievalNode(LLMNode):
return filters
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
@ -392,17 +403,13 @@ class KnowledgeRetrievalNode(LLMNode):
# todo: do not support external_knowledge_retrieval
if node_data.metadata_filtering_complex_conditions:
# Enable forward references
MetadataFilteringComplexSubCondition.model_rebuild()
MetadataFilteringComplexCondition.model_rebuild()
metadata_filtering_complex_conditions = MetadataFilteringComplexCondition(
**node_data.metadata_filtering_complex_conditions.model_dump()
)
for condition in metadata_filtering_complex_conditions.conditions: # type: ignore
filters = self._recursive_metadata_filter(condition, filters)
filters = self._recursive_metadata_filter(metadata_filtering_complex_conditions, filters)
if filters:
if metadata_filtering_complex_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.filter(*filters)
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore

@ -149,7 +149,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - 1
value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1
if len(variable.value) > int(value):
result = variable.value[value]
else:

@ -26,9 +26,12 @@ def init_app(app: DifyApp):
# Always add StreamHandler to log to console
sh = logging.StreamHandler(sys.stdout)
sh.addFilter(RequestIdFilter())
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,

@ -0,0 +1,18 @@
"""
Extension for initializing repositories.
This extension registers repository implementations with the RepositoryFactory.
"""
from dify_app import DifyApp
from repositories.repository_registry import register_repositories
def init_app(_app: DifyApp) -> None:
"""
Initialize repository implementations.
Args:
_app: The Flask application instance (unused)
"""
register_repositories()

@ -73,11 +73,7 @@ class Storage:
raise ValueError(f"unsupported storage type {storage_type}")
def save(self, filename, data):
try:
self.storage_runner.save(filename, data)
except Exception as e:
logger.exception(f"Failed to save file {filename}")
raise e
self.storage_runner.save(filename, data)
@overload
def load(self, filename: str, /, *, stream: Literal[False] = False) -> bytes: ...
@ -86,49 +82,25 @@ class Storage:
def load(self, filename: str, /, *, stream: Literal[True]) -> Generator: ...
def load(self, filename: str, /, *, stream: bool = False) -> Union[bytes, Generator]:
try:
if stream:
return self.load_stream(filename)
else:
return self.load_once(filename)
except Exception as e:
logger.exception(f"Failed to load file {filename}")
raise e
if stream:
return self.load_stream(filename)
else:
return self.load_once(filename)
def load_once(self, filename: str) -> bytes:
try:
return self.storage_runner.load_once(filename)
except Exception as e:
logger.exception(f"Failed to load_once file {filename}")
raise e
return self.storage_runner.load_once(filename)
def load_stream(self, filename: str) -> Generator:
try:
return self.storage_runner.load_stream(filename)
except Exception as e:
logger.exception(f"Failed to load_stream file {filename}")
raise e
return self.storage_runner.load_stream(filename)
def download(self, filename, target_filepath):
try:
self.storage_runner.download(filename, target_filepath)
except Exception as e:
logger.exception(f"Failed to download file {filename}")
raise e
self.storage_runner.download(filename, target_filepath)
def exists(self, filename):
try:
return self.storage_runner.exists(filename)
except Exception as e:
logger.exception(f"Failed to check file exists {filename}")
raise e
return self.storage_runner.exists(filename)
def delete(self, filename):
try:
return self.storage_runner.delete(filename)
except Exception as e:
logger.exception(f"Failed to delete file {filename}")
raise e
return self.storage_runner.delete(filename)
storage = Storage()

@ -52,6 +52,7 @@ def build_from_mapping(
mapping: Mapping[str, Any],
tenant_id: str,
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
@ -69,6 +70,7 @@ def build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
)
if config and not _is_file_valid_with_config(
@ -87,12 +89,14 @@ def build_from_mappings(
mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None = None,
tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]:
files = [
build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
strict_type_validation=strict_type_validation,
)
for mapping in mappings
]
@ -116,6 +120,7 @@ def _build_from_local_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if not upload_file_id:
@ -134,10 +139,16 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
specified_type = mapping.get("type", "custom")
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),
filename=row.name,
@ -158,6 +169,7 @@ def _build_from_remote_url(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if upload_file_id:
@ -174,10 +186,21 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(
extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type)
if specified_type and specified_type != FileType.CUSTOM.value
else detected_file_type
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
@ -237,6 +260,7 @@ def _build_from_tool_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = (
db.session.query(ToolFile)
@ -252,7 +276,16 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),

@ -1091,12 +1091,7 @@ class Message(db.Model): # type: ignore[name-defined]
@property
def retriever_resources(self):
return (
db.session.query(DatasetRetrieverResource)
.filter(DatasetRetrieverResource.message_id == self.id)
.order_by(DatasetRetrieverResource.position.asc())
.all()
)
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
@property
def message_files(self):

@ -510,7 +510,7 @@ class WorkflowRun(Base):
)
class WorkflowNodeExecutionTriggeredFrom(Enum):
class WorkflowNodeExecutionTriggeredFrom(StrEnum):
"""
Workflow Node Execution Triggered From Enum
"""
@ -518,21 +518,8 @@ class WorkflowNodeExecutionTriggeredFrom(Enum):
SINGLE_STEP = "single-step"
WORKFLOW_RUN = "workflow-run"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionTriggeredFrom":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid workflow node execution triggered from value {value}")
class WorkflowNodeExecutionStatus(Enum):
class WorkflowNodeExecutionStatus(StrEnum):
"""
Workflow Node Execution Status Enum
"""
@ -543,19 +530,6 @@ class WorkflowNodeExecutionStatus(Enum):
EXCEPTION = "exception"
RETRY = "retry"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid workflow node execution status value {value}")
class WorkflowNodeExecution(Base):
"""

@ -1,4 +0,0 @@
[virtualenvs]
in-project = true
create = true
prefer-active-python = true

@ -104,6 +104,7 @@ dev = [
"ruff~=0.11.5",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",
"pytest-env~=1.1.3",
"pytest-mock~=3.14.0",
"types-aiofiles~=24.1.0",

@ -1,5 +1,6 @@
[pytest]
continue-on-collection-errors = true
addopts = --cov=./api --cov-report=json --cov-report=xml
env =
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com

@ -0,0 +1,6 @@
"""
Repository implementations for data access.
This package contains concrete implementations of the repository interfaces
defined in the core.repository package.
"""

@ -0,0 +1,87 @@
"""
Registry for repository implementations.
This module is responsible for registering factory functions with the repository factory.
"""
import logging
from collections.abc import Mapping
from typing import Any
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
logger = logging.getLogger(__name__)
# Storage type constants
STORAGE_TYPE_RDBMS = "rdbms"
STORAGE_TYPE_HYBRID = "hybrid"
def register_repositories() -> None:
"""
Register repository factory functions with the RepositoryFactory.
This function reads configuration settings to determine which repository
implementations to register.
"""
# Configure WorkflowNodeExecutionRepository factory based on configuration
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
# Check storage type and register appropriate implementation
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
# Register SQLAlchemy implementation for RDBMS storage
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
# Hybrid storage is not yet implemented
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
else:
# Unknown storage type
raise ValueError(
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
f"Supported types: {STORAGE_TYPE_RDBMS}"
)
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
"""
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
This factory function creates a repository for the RDBMS storage type.
Args:
params: Parameters for creating the repository, including:
- tenant_id: Required. The tenant ID for multi-tenancy.
- app_id: Optional. The application ID for filtering.
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
a new sessionmaker will be created using the global database engine.
Returns:
A WorkflowNodeExecutionRepository instance
Raises:
ValueError: If required parameters are missing
"""
# Extract required parameters
tenant_id = params.get("tenant_id")
if tenant_id is None:
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
# Extract optional parameters
app_id = params.get("app_id")
# Use the session_factory from params if provided, otherwise create one using the global db engine
session_factory = params.get("session_factory")
if session_factory is None:
# Create a sessionmaker using the same engine as the global db session
session_factory = sessionmaker(bind=db.engine)
# Create and return the repository
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
)

@ -0,0 +1,9 @@
"""
WorkflowNodeExecution repository implementations.
"""
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"SQLAlchemyWorkflowNodeExecutionRepository",
]

@ -0,0 +1,170 @@
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
"""
import logging
from collections.abc import Sequence
from typing import Optional
from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
class SQLAlchemyWorkflowNodeExecutionRepository:
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
This implementation supports multi-tenancy by filtering operations based on tenant_id.
Each method creates its own session, handles the transaction, and commits changes
to the database. This prevents long-running connections in the workflow core.
"""
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
"""
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
Args:
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
tenant_id: Tenant ID for multi-tenancy
app_id: Optional app ID for filtering by application
"""
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
else:
self._session_factory = session_factory
self._tenant_id = tenant_id
self._app_id = app_id
def save(self, execution: WorkflowNodeExecution) -> None:
"""
Save a WorkflowNodeExecution instance and commit changes to the database.
Args:
execution: The WorkflowNodeExecution instance to save
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.add(execution)
session.commit()
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
"""
Retrieve a WorkflowNodeExecution by its node_execution_id.
Args:
node_execution_id: The node execution ID
Returns:
The WorkflowNodeExecution instance if found, None otherwise
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.node_execution_id == node_execution_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalar(stmt)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
order_config: Optional configuration for ordering results
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
order_config.order_direction: Direction to order ("asc" or "desc")
Returns:
A list of WorkflowNodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
# Apply ordering if provided
if order_config and order_config.order_by:
order_columns: list[UnaryExpression] = []
for field in order_config.order_by:
column = getattr(WorkflowNodeExecution, field, None)
if not column:
continue
if order_config.order_direction == "desc":
order_columns.append(desc(column))
else:
order_columns.append(asc(column))
if order_columns:
stmt = stmt.order_by(*order_columns)
return session.scalars(stmt).all()
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
"""
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
Args:
workflow_run_id: The workflow run ID
Returns:
A list of running WorkflowNodeExecution instances
"""
with self._session_factory() as session:
stmt = select(WorkflowNodeExecution).where(
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
WorkflowNodeExecution.tenant_id == self._tenant_id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
if self._app_id:
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
return session.scalars(stmt).all()
def update(self, execution: WorkflowNodeExecution) -> None:
"""
Update an existing WorkflowNodeExecution instance and commit changes to the database.
Args:
execution: The WorkflowNodeExecution instance to update
"""
with self._session_factory() as session:
# Ensure tenant_id is set
if not execution.tenant_id:
execution.tenant_id = self._tenant_id
# Set app_id if provided and not already set
if self._app_id and not execution.app_id:
execution.app_id = self._app_id
session.merge(execution)
session.commit()

@ -553,7 +553,7 @@ class DocumentService:
{"id": "remove_extra_spaces", "enabled": True},
{"id": "remove_urls_emails", "enabled": False},
],
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
"segmentation": {"delimiter": "\n", "max_tokens": 1024, "chunk_overlap": 50},
},
"limits": {
"indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH,
@ -2025,7 +2025,7 @@ class SegmentService:
dataset_id=dataset.id,
document_id=document.id,
segment_id=segment.id,
position=max_position + 1,
position=max_position + 1 if max_position else 1,
index_node_id=index_node_id,
index_node_hash=index_node_hash,
content=content,
@ -2175,7 +2175,13 @@ class SegmentService:
@classmethod
def get_segments(
cls, document_id: str, tenant_id: str, status_list: list[str] | None = None, keyword: str | None = None
cls,
document_id: str,
tenant_id: str,
status_list: list[str] | None = None,
keyword: str | None = None,
page: int = 1,
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
@ -2188,10 +2194,11 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
segments = query.order_by(DocumentSegment.position.asc()).all()
total = len(segments)
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
return segments, total
return paginated_segments.items, paginated_segments.total
@classmethod
def update_segment_by_id(

@ -1,3 +1,4 @@
from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
from core.plugin.manager.plugin import PluginInstallationManager
@ -111,6 +112,8 @@ class DependenciesAnalysisService:
Generate the latest version of dependencies
"""
dependencies = list(set(dependencies))
if not dify_config.MARKETPLACE_ENABLED:
return []
deps = marketplace.batch_fetch_plugin_manifests(dependencies)
return [
PluginDependency(

@ -0,0 +1,99 @@
from unittest.mock import MagicMock, patch
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
ToolCall = AssistantPromptMessage.ToolCall
# CASE 1: Single tool call
INPUTS_CASE_1 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_1 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
]
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
INPUTS_CASE_2 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_2 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
INPUTS_CASE_3 = [
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_3 = [
ToolCall(
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
),
ToolCall(
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
),
]
# CASE 4: Tool call sequences with no IDs
INPUTS_CASE_4 = [
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
]
EXPECTED_CASE_4 = [
ToolCall(
id="RANDOM_ID_1",
type="function",
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
),
ToolCall(
id="RANDOM_ID_2",
type="function",
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
),
]
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
actual = []
_increase_tool_call(inputs, actual)
assert actual == expected
def test__increase_tool_call():
# case 1:
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
# case 2:
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
# case 3:
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
# case 4:
mock_id_generator = MagicMock()
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)

@ -1,14 +1,20 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class ContinueOnErrorTestHelper:
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
ContinueOnErrorTestHelper.get_http_node(),
],
}
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
def test_stream_output_with_fail_branch_continue_on_error():
"""Test stream output with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
with patch.object(LLMNode, "_run", new=llm_generator):
events = list(graph_engine.run())
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)

@ -0,0 +1,198 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from httpx import Response
from factories.file_factory import (
File,
FileTransferMethod,
FileType,
FileUploadConfig,
build_from_mapping,
)
from models import ToolFile, UploadFile
# Test Data
TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"
# Test Config
TEST_CONFIG = FileUploadConfig(
allowed_file_types=["image", "document"],
allowed_file_extensions=[".jpg", ".pdf"],
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
number_limits=10,
)
# Fixtures
@pytest.fixture
def mock_upload_file():
mock = MagicMock(spec=UploadFile)
mock.id = TEST_UPLOAD_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "test.jpg"
mock.extension = "jpg"
mock.mime_type = "image/jpeg"
mock.source_url = TEST_REMOTE_URL
mock.size = 1024
mock.key = "test_key"
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
yield m
@pytest.fixture
def mock_tool_file():
mock = MagicMock(spec=ToolFile)
mock.id = TEST_TOOL_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "tool_file.pdf"
mock.file_key = "tool_file.pdf"
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = mock
yield mock
@pytest.fixture
def mock_http_head():
def _mock_response(filename, size, content_type):
return Response(
status_code=200,
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(size),
"Content-Type": content_type,
},
)
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
yield mock_head
# Helper functions
def local_file_mapping(file_type="image"):
return {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
"type": file_type,
}
def tool_file_mapping(file_type="document"):
return {
"transfer_method": "tool_file",
"tool_file_id": TEST_TOOL_FILE_ID,
"type": file_type,
}
# Tests
def test_build_from_mapping_backward_compatibility(mock_upload_file):
mapping = local_file_mapping(file_type="image")
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert isinstance(file, File)
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
assert file.type == FileType.IMAGE
assert file.related_id == TEST_UPLOAD_FILE_ID
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("document", False, "Detected file type does not match"),
],
)
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("document", True, None),
("image", False, "Detected file type does not match"),
],
)
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
"""Strict type validation for tool_file."""
mapping = tool_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
def test_build_from_remote_url(mock_http_head):
mapping = {
"transfer_method": "remote_url",
"url": TEST_REMOTE_URL,
"type": "image",
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert file.transfer_method == FileTransferMethod.REMOTE_URL
assert file.type == FileType.IMAGE
assert file.filename == "remote_test.jpg"
assert file.size == 2048
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = None
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_local_file_not_found():
"""Test UploadFile not found in database."""
with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = local_file_mapping()
with pytest.raises(ValueError, match="Invalid upload file"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_build_without_type_specification(mock_upload_file):
"""Test the situation where no file type is specified"""
mapping = {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
# leave out the type
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
# It should automatically infer the type as "image" based on the file extension
assert file.type == FileType.IMAGE
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("video", False, "File validation failed"),
],
)
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
"""Test the validation of files and configurations"""
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
assert file is not None
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)

@ -0,0 +1,3 @@
"""
Unit tests for repositories.
"""

@ -0,0 +1,3 @@
"""
Unit tests for workflow_node_execution repositories.
"""

@ -0,0 +1,154 @@
"""
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
"""
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session, sessionmaker
from core.repository.workflow_node_execution_repository import OrderConfig
from models.workflow import WorkflowNodeExecution
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
@pytest.fixture
def session():
"""Create a mock SQLAlchemy session."""
session = MagicMock(spec=Session)
# Configure the session to be used as a context manager
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=None)
# Configure the session factory to return the session
session_factory = MagicMock(spec=sessionmaker)
session_factory.return_value = session
return session, session_factory
@pytest.fixture
def repository(session):
"""Create a repository instance with test data."""
_, session_factory = session
tenant_id = "test-tenant"
app_id = "test-app"
return SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
)
def test_save(repository, session):
"""Test save method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = None
execution.app_id = None
# Call save method
repository.save(execution)
# Assert tenant_id and app_id are set
assert execution.tenant_id == repository._tenant_id
assert execution.app_id == repository._app_id
# Assert session.add was called
session_obj.add.assert_called_once_with(execution)
def test_save_with_existing_tenant_id(repository, session):
"""Test save method with existing tenant_id."""
session_obj, _ = session
# Create a mock execution with existing tenant_id
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = "existing-tenant"
execution.app_id = None
# Call save method
repository.save(execution)
# Assert tenant_id is not changed and app_id is set
assert execution.tenant_id == "existing-tenant"
assert execution.app_id == repository._app_id
# Assert session.add was called
session_obj.add.assert_called_once_with(execution)
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
"""Test get_by_node_execution_id method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
# Call method
result = repository.get_by_node_execution_id("test-node-execution-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalar.assert_called_once_with(mock_stmt)
assert result is not None
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
"""Test get_by_workflow_run method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
mock_stmt.order_by.return_value = mock_stmt
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
# Call method
order_config = OrderConfig(order_by=["index"], order_direction="desc")
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
assert len(result) == 1
def test_get_running_executions(repository, session, mocker: MockerFixture):
"""Test get_running_executions method."""
session_obj, _ = session
# Set up mock
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
mock_stmt = mocker.MagicMock()
mock_select.return_value = mock_stmt
mock_stmt.where.return_value = mock_stmt
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
# Call method
result = repository.get_running_executions("test-workflow-run-id")
# Assert select was called with correct parameters
mock_select.assert_called_once()
session_obj.scalars.assert_called_once_with(mock_stmt)
assert len(result) == 1
def test_update(repository, session):
"""Test update method."""
session_obj, _ = session
# Create a mock execution
execution = MagicMock(spec=WorkflowNodeExecution)
execution.tenant_id = None
execution.app_id = None
# Call update method
repository.update(execution)
# Assert tenant_id and app_id are set
assert execution.tenant_id == repository._tenant_id
assert execution.app_id == repository._app_id
# Assert session.merge was called
session_obj.merge.assert_called_once_with(execution)

@ -1012,6 +1012,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/00/14b00a0748e9eda26e97be07a63cc911108844004687321ddcc213be956c/coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3", size = 204347 },
]
[package.optional-dependencies]
toml = [
{ name = "tomli", marker = "python_full_version <= '3.11'" },
]
[[package]]
name = "crc32c"
version = "2.7.1"
@ -1234,6 +1239,7 @@ dev = [
{ name = "mypy" },
{ name = "pytest" },
{ name = "pytest-benchmark" },
{ name = "pytest-cov" },
{ name = "pytest-env" },
{ name = "pytest-mock" },
{ name = "ruff" },
@ -1401,6 +1407,7 @@ dev = [
{ name = "mypy", specifier = "~=1.15.0" },
{ name = "pytest", specifier = "~=8.3.2" },
{ name = "pytest-benchmark", specifier = "~=4.0.0" },
{ name = "pytest-cov", specifier = "~=4.1.0" },
{ name = "pytest-env", specifier = "~=1.1.3" },
{ name = "pytest-mock", specifier = "~=3.14.0" },
{ name = "ruff", specifier = "~=0.11.5" },
@ -4333,6 +4340,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/a1/3b70862b5b3f830f0422844f25a823d0470739d994466be9dbbbb414d85a/pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6", size = 43951 },
]
[[package]]
name = "pytest-cov"
version = "4.1.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "coverage", extra = ["toml"] },
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7a/15/da3df99fd551507694a9b01f512a2f6cf1254f33601605843c3775f39460/pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", size = 63245 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949 },
]
[[package]]
name = "pytest-env"
version = "1.1.5"
@ -5235,6 +5255,35 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588 },
]
[[package]]
name = "tomli"
version = "2.2.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 },
{ url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 },
{ url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 },
{ url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 },
{ url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 },
{ url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 },
{ url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 },
{ url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 },
{ url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 },
{ url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 },
{ url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 },
{ url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 },
{ url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 },
{ url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 },
{ url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 },
{ url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 },
{ url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 },
{ url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 },
{ url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 },
{ url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 },
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
]
[[package]]
name = "tos"
version = "2.7.2"

@ -174,6 +174,12 @@ CELERY_MIN_WORKERS=
API_TOOL_DEFAULT_CONNECT_TIMEOUT=10
API_TOOL_DEFAULT_READ_TIMEOUT=60
# -------------------------------
# Datasource Configuration
# --------------------------------
ENABLE_WEBSITE_JINAREADER=true
ENABLE_WEBSITE_FIRECRAWL=true
ENABLE_WEBSITE_WATERCRAWL=true
# ------------------------------
# Database Configuration
@ -404,6 +410,7 @@ MILVUS_TOKEN=
MILVUS_USER=
MILVUS_PASSWORD=
MILVUS_ENABLE_HYBRID_SEARCH=False
MILVUS_ANALYZER_PARAMS=
# MyScale configuration, only available when VECTOR_STORE is `myscale`
# For multi-language support, please set MYSCALE_FTS_PARAMS with referring to:
@ -737,6 +744,12 @@ MAX_VARIABLE_SIZE=204800
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
WORKFLOW_FILE_UPLOAD_LIMIT=10
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# HTTP request node in workflow configuration
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576

@ -75,7 +75,9 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
# The postgres database.
db:
image: postgres:15-alpine

@ -43,6 +43,9 @@ x-shared-env: &shared-api-worker-env
CELERY_MIN_WORKERS: ${CELERY_MIN_WORKERS:-}
API_TOOL_DEFAULT_CONNECT_TIMEOUT: ${API_TOOL_DEFAULT_CONNECT_TIMEOUT:-10}
API_TOOL_DEFAULT_READ_TIMEOUT: ${API_TOOL_DEFAULT_READ_TIMEOUT:-60}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
DB_USERNAME: ${DB_USERNAME:-postgres}
DB_PASSWORD: ${DB_PASSWORD:-difyai123456}
DB_HOST: ${DB_HOST:-db}
@ -139,6 +142,7 @@ x-shared-env: &shared-api-worker-env
MILVUS_USER: ${MILVUS_USER:-}
MILVUS_PASSWORD: ${MILVUS_PASSWORD:-}
MILVUS_ENABLE_HYBRID_SEARCH: ${MILVUS_ENABLE_HYBRID_SEARCH:-False}
MILVUS_ANALYZER_PARAMS: ${MILVUS_ANALYZER_PARAMS:-}
MYSCALE_HOST: ${MYSCALE_HOST:-myscale}
MYSCALE_PORT: ${MYSCALE_PORT:-8123}
MYSCALE_USER: ${MYSCALE_USER:-default}
@ -323,6 +327,7 @@ x-shared-env: &shared-api-worker-env
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True}
@ -543,7 +548,9 @@ services:
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-5}
ENABLE_WEBSITE_JINAREADER: ${ENABLE_WEBSITE_JINAREADER:-true}
ENABLE_WEBSITE_FIRECRAWL: ${ENABLE_WEBSITE_FIRECRAWL:-true}
ENABLE_WEBSITE_WATERCRAWL: ${ENABLE_WEBSITE_WATERCRAWL:-true}
# The postgres database.
db:
image: postgres:15-alpine

@ -49,3 +49,8 @@ NEXT_PUBLIC_MAX_PARALLEL_LIMIT=10
# The maximum number of iterations for agent setting
NEXT_PUBLIC_MAX_ITERATIONS_NUM=5
NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER=true
NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL=true
NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL=true

@ -1,11 +1,11 @@
'use client'
import Workflow from '@/app/components/workflow'
import WorkflowApp from '@/app/components/workflow-app'
const Page = () => {
return (
<div className='h-full w-full overflow-x-auto'>
<Workflow />
<WorkflowApp />
</div>
)
}

@ -557,7 +557,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Heading
url='/datasets/{dataset_id}'
method='POST'
method='PATCH'
title='Update knowledge base'
name='#update_dataset'
/>
@ -585,8 +585,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='string' key='embedding_model'>
Specified embedding model, corresponding to the model field(Optional)
</Property>
<Property name='retrieval_model' type='string' key='retrieval_model'>
Specified retrieval model, corresponding to the model field(Optional)
<Property name='retrieval_model' type='object' key='retrieval_model'>
Retrieval model (optional, if not filled, it will be recalled according to the default method)
- <code>search_method</code> (text) Search method: One of the following four keywords is required
- <code>keyword_search</code> Keyword search
- <code>semantic_search</code> Semantic search
- <code>full_text_search</code> Full-text search
- <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking, required if the search mode is semantic_search or hybrid_search (optional)
- <code>reranking_mode</code> (object) Rerank model configuration, required if reranking is enabled
- <code>reranking_provider_name</code> (string) Rerank model provider
- <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (float) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return (optional)
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
- <code>score_threshold</code> (float) Score threshold
</Property>
<Property name='partial_member_list' type='array' key='partial_member_list'>
Partial member list(Optional)
@ -596,16 +609,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="PATCH"
label="/datasets/{dataset_id}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me", "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' `}
targetCode={`curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{
"name": "Test Knowledge Base",
"indexing_technique": "high_quality",
"permission": "only_me",
"embedding_model_provider": "zhipuai",
"embedding_model": "embedding-3",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
},
"partial_member_list": []
}'
`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \
curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\
"embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}'
--data-raw '{
"name": "Test Knowledge Base",
"indexing_technique": "high_quality",
"permission": "only_me",
"embedding_model_provider": "zhipuai",
"embedding_model": "embedding-3",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
},
"partial_member_list": []
}'
```
</CodeGroup>
<CodeGroup title="Response">

@ -94,6 +94,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>reranking_enable</code> (bool) 是否开启rerank
- <code>reranking_mode</code> (String) 混合检索
- <code>weighted_score</code> 权重设置
- <code>reranking_model</code> Rerank 模型
- <code>reranking_model</code> (object) Rerank 模型配置
- <code>reranking_provider_name</code> (string) Rerank 模型的提供商
- <code>reranking_model_name</code> (string) Rerank 模型的名称
@ -557,7 +560,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Heading
url='/datasets/{dataset_id}'
method='POST'
method='PATCH'
title='修改知识库详情'
name='#update_dataset'
/>
@ -589,8 +592,21 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='string' key='embedding_model'>
嵌入模型(选填)
</Property>
<Property name='retrieval_model' type='string' key='retrieval_model'>
检索模型(选填)
<Property name='retrieval_model' type='object' key='retrieval_model'>
检索参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下四个关键字之一,必填
- <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>hybrid_search</code> 混合检索
- <code>reranking_enable</code> (bool) 是否启用 Reranking非必填如果检索模式为 semantic_search 模式或者 hybrid_search 则传值
- <code>reranking_mode</code> (object) Rerank 模型配置,非必填,如果启用了 reranking 则传值
- <code>reranking_provider_name</code> (string) Rerank 模型提供商
- <code>reranking_model_name</code> (string) Rerank 模型名称
- <code>weights</code> (float) 混合检索模式下语意检索的权重设置
- <code>top_k</code> (integer) 返回结果数量,非必填
- <code>score_threshold_enabled</code> (bool) 是否开启 score 阈值
- <code>score_threshold</code> (float) Score 阈值
</Property>
<Property name='partial_member_list' type='array' key='partial_member_list'>
部分团队成员 ID 列表(选填)
@ -600,16 +616,56 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
tag="PATCH"
label="/datasets/{dataset_id}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me", "embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}' `}
targetCode={`curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{
"name": "Test Knowledge Base",
"indexing_technique": "high_quality",
"permission": "only_me",
"embedding_model_provider": "zhipuai",
"embedding_model": "embedding-3",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
},
"partial_member_list": []
}'
`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}' \
curl --location --request PATCH '${props.apiBaseUrl}/datasets/{dataset_id}' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{"name": "Test Knowledge Base", "indexing_technique": "high_quality", "permission": "only_me",\
"embedding_model_provider": "zhipuai", "embedding_model": "embedding-3", "retrieval_model": "", "partial_member_list": []}'
--data-raw '{
"name": "Test Knowledge Base",
"indexing_technique": "high_quality",
"permission": "only_me",
"embedding_model_provider": "zhipuai",
"embedding_model": "embedding-3",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
},
"partial_member_list": []
}'
```
</CodeGroup>
<CodeGroup title="Response">
@ -1764,7 +1820,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
检索参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下个关键字之一,必填
- <code>search_method</code> (text) 检索方法:以下个关键字之一,必填
- <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索

@ -16,7 +16,7 @@ import type {
Feedback,
} from '../types'
import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree } from '../utils'
import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import {
@ -106,6 +106,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}, [isInstalledApp, installedAppInfo, appInfo])
const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => {
if (appData?.site.default_language)
changeLanguage(appData.site.default_language)
@ -124,18 +131,24 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
setSidebarCollapseState(localState === 'collapsed')
}
}, [appId])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, {
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {},
})
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo])
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({
...conversationIdInfo,
[appId || '']: changeConversationId,
[appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
})
}
}, [appId, conversationIdInfo, setConversationIdInfo])
}, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => {

@ -2,8 +2,6 @@ import type { FC } from 'react'
import { memo } from 'react'
import type { ChatItem } from '../../types'
import { useChatContext } from '../context'
import Button from '@/app/components/base/button'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
type SuggestedQuestionsProps = {
item: ChatItem
@ -12,9 +10,6 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
item,
}) => {
const { onSend } = useChatContext()
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const klassName = `mr-1 mt-1 ${isMobile ? 'block overflow-hidden text-ellipsis' : ''} max-w-full shrink-0 last:mr-0`
const {
isOpeningStatement,
@ -27,14 +22,13 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
return (
<div className='flex flex-wrap'>
{suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => (
<Button
<div
key={index}
variant='secondary-accent'
className={klassName}
className='system-sm-medium mr-1 mt-1 inline-flex max-w-full shrink-0 cursor-pointer flex-wrap rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3.5 py-2 text-components-button-secondary-accent-text shadow-xs last:mr-0 hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover'
onClick={() => onSend?.(question)}
>
{question}
</Button>),
</div>),
)}
</div>
)

@ -184,7 +184,7 @@ const ChatWrapper = () => {
return null
if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) {
return (
<div className='flex h-[50vh] items-center justify-center px-4 py-12'>
<div className={cn('flex items-center justify-center px-4 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<div className='flex max-w-[720px] grow gap-4'>
<AppIcon
size='xl'
@ -202,7 +202,7 @@ const ChatWrapper = () => {
)
}
return (
<div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12')}>
<div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<AppIcon
size='xl'
iconType={appData?.site.icon_type}

@ -15,7 +15,7 @@ import type {
Feedback,
} from '../types'
import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedInputsFromUrlParams } from '../utils'
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
import {
fetchAppInfo,
@ -72,23 +72,36 @@ export const useEmbeddedChatbot = () => {
}, [appInfo])
const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => {
if (appInfo?.site.default_language)
changeLanguage(appInfo.site.default_language)
}, [appInfo])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, {
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {},
})
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo])
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({
...conversationIdInfo,
[appId || '']: changeConversationId,
[appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
})
}
}, [appId, conversationIdInfo, setConversationIdInfo])
}, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => {

@ -85,9 +85,11 @@ const preprocessLaTeX = (content: string) => {
}
const preprocessThinkTag = (content: string) => {
const thinkOpenTagRegex = /<think>\n/g
const thinkCloseTagRegex = /\n<\/think>/g
return flow([
(str: string) => str.replace('<think>\n', '<details data-think=true>\n'),
(str: string) => str.replace('\n</think>', '\n[ENDTHINKFLAG]</details>'),
(str: string) => str.replace(thinkOpenTagRegex, '<details data-think=true>\n'),
(str: string) => str.replace(thinkCloseTagRegex, '\n[ENDTHINKFLAG]</details>'),
])(content)
}

@ -20,7 +20,7 @@ import { useProviderContext } from '@/context/provider-context'
import VectorSpaceFull from '@/app/components/billing/vector-space-full'
import classNames from '@/utils/classnames'
import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others'
import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
type IStepOneProps = {
datasetId?: string
dataSourceType?: DataSourceType
@ -126,9 +126,7 @@ const StepOne = ({
return true
if (files.some(file => !file.file.id))
return true
if (isShowVectorSpaceFull)
return true
return false
return isShowVectorSpaceFull
}, [files, isShowVectorSpaceFull])
return (
@ -193,7 +191,8 @@ const StepOne = ({
{t('datasetCreation.stepOne.dataSourceType.notion')}
</span>
</div>
<div
{(ENABLE_WEBSITE_FIRECRAWL || ENABLE_WEBSITE_JINAREADER || ENABLE_WEBSITE_WATERCRAWL) && (
<div
className={cn(
s.dataSourceItem,
'system-sm-medium',
@ -201,7 +200,7 @@ const StepOne = ({
dataSourceTypeDisable && dataSourceType !== DataSourceType.WEB && s.disabled,
)}
onClick={() => changeType(DataSourceType.WEB)}
>
>
<span className={cn(s.datasetIcon, s.web)} />
<span
title={t('datasetCreation.stepOne.dataSourceType.web')}
@ -209,7 +208,8 @@ const StepOne = ({
>
{t('datasetCreation.stepOne.dataSourceType.web')}
</span>
</div>
</div>
)}
</div>
)
}

@ -97,7 +97,7 @@ export enum IndexingType {
}
const DEFAULT_SEGMENT_IDENTIFIER = '\\n\\n'
const DEFAULT_MAXIMUM_CHUNK_LENGTH = 500
const DEFAULT_MAXIMUM_CHUNK_LENGTH = 1024
const DEFAULT_OVERLAP = 50
const MAXIMUM_CHUNK_TOKEN_LENGTH = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-indexing-max-segmentation-tokens-length') || '4000', 10)
@ -117,11 +117,11 @@ const defaultParentChildConfig: ParentChildConfig = {
chunkForContext: 'paragraph',
parent: {
delimiter: '\\n\\n',
maxLength: 500,
maxLength: 1024,
},
child: {
delimiter: '\\n',
maxLength: 200,
maxLength: 512,
},
}
@ -623,12 +623,12 @@ const StepTwo = ({
onChange={e => setSegmentIdentifier(e.target.value, true)}
/>
<MaxLengthInput
unit='tokens'
unit='characters'
value={maxChunkLength}
onChange={setMaxChunkLength}
/>
<OverlapInput
unit='tokens'
unit='characters'
value={overlap}
min={1}
onChange={setOverlap}
@ -756,7 +756,7 @@ const StepTwo = ({
})}
/>
<MaxLengthInput
unit='tokens'
unit='characters'
value={parentChildConfig.parent.maxLength}
onChange={value => setParentChildConfig({
...parentChildConfig,
@ -803,7 +803,7 @@ const StepTwo = ({
})}
/>
<MaxLengthInput
unit='tokens'
unit='characters'
value={parentChildConfig.child.maxLength}
onChange={value => setParentChildConfig({
...parentChildConfig,

@ -12,6 +12,7 @@ import { useModalContext } from '@/context/modal-context'
import type { CrawlOptions, CrawlResultItem } from '@/models/datasets'
import { fetchDataSources } from '@/service/datasets'
import { type DataSourceItem, DataSourceProvider } from '@/models/common'
import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
type Props = {
onPreview: (payload: CrawlResultItem) => void
@ -84,7 +85,7 @@ const Website: FC<Props> = ({
{t('datasetCreation.stepOne.website.chooseProvider')}
</div>
<div className="flex space-x-2">
<button
{ENABLE_WEBSITE_JINAREADER && <button
className={cn('flex items-center justify-center rounded-lg px-4 py-2',
selectedProvider === DataSourceProvider.jinaReader
? 'system-sm-medium border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary'
@ -95,8 +96,8 @@ const Website: FC<Props> = ({
>
<span className={cn(s.jinaLogo, 'mr-2')}/>
<span>Jina Reader</span>
</button>
<button
</button>}
{ENABLE_WEBSITE_FIRECRAWL && <button
className={cn('rounded-lg px-4 py-2',
selectedProvider === DataSourceProvider.fireCrawl
? 'system-sm-medium border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary'
@ -106,8 +107,8 @@ const Website: FC<Props> = ({
onClick={() => setSelectedProvider(DataSourceProvider.fireCrawl)}
>
🔥 Firecrawl
</button>
<button
</button>}
{ENABLE_WEBSITE_WATERCRAWL && <button
className={cn('flex items-center justify-center rounded-lg px-4 py-2',
selectedProvider === DataSourceProvider.waterCrawl
? 'system-sm-medium border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary'
@ -118,7 +119,7 @@ const Website: FC<Props> = ({
>
<span className={cn(s.watercrawlLogo, 'mr-2')}/>
<span>WaterCrawl</span>
</button>
</button>}
</div>
</div>
{source && selectedProvider === DataSourceProvider.fireCrawl && (

@ -6,6 +6,7 @@ import s from './index.module.css'
import { Icon3Dots } from '@/app/components/base/icons/src/vender/line/others'
import Button from '@/app/components/base/button'
import { DataSourceProvider } from '@/models/common'
import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
const I18N_PREFIX = 'datasetCreation.stepOne.website'
@ -16,29 +17,30 @@ type Props = {
const NoData: FC<Props> = ({
onConfig,
provider,
}) => {
const { t } = useTranslation()
const providerConfig = {
[DataSourceProvider.jinaReader]: {
[DataSourceProvider.jinaReader]: ENABLE_WEBSITE_JINAREADER ? {
emoji: <span className={s.jinaLogo} />,
title: t(`${I18N_PREFIX}.jinaReaderNotConfigured`),
description: t(`${I18N_PREFIX}.jinaReaderNotConfiguredDescription`),
},
[DataSourceProvider.fireCrawl]: {
} : null,
[DataSourceProvider.fireCrawl]: ENABLE_WEBSITE_FIRECRAWL ? {
emoji: '🔥',
title: t(`${I18N_PREFIX}.fireCrawlNotConfigured`),
description: t(`${I18N_PREFIX}.fireCrawlNotConfiguredDescription`),
},
[DataSourceProvider.waterCrawl]: {
emoji: <span className={s.watercrawlLogo} />,
} : null,
[DataSourceProvider.waterCrawl]: ENABLE_WEBSITE_WATERCRAWL ? {
emoji: '💧',
title: t(`${I18N_PREFIX}.waterCrawlNotConfigured`),
description: t(`${I18N_PREFIX}.waterCrawlNotConfiguredDescription`),
},
} : null,
}
const currentProvider = providerConfig[provider]
const currentProvider = Object.values(providerConfig).find(provider => provider !== null) || providerConfig[DataSourceProvider.jinaReader]
if (!currentProvider) return null
return (
<>

@ -7,9 +7,11 @@ import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-butt
type ApiServerProps = {
apiBaseUrl: string
appId?: string
}
const ApiServer: FC<ApiServerProps> = ({
apiBaseUrl,
appId,
}) => {
const { t } = useTranslation()
@ -25,7 +27,7 @@ const ApiServer: FC<ApiServerProps> = ({
{t('appApi.ok')}
</div>
<SecretKeyButton
className='!h-8 shrink-0'
className='!h-8 shrink-0' appId={appId}
/>
</div>
)

@ -23,7 +23,7 @@ const DevelopMain = ({ appId }: IDevelopMainProps) => {
<div className='relative flex h-full flex-col overflow-hidden'>
<div className='flex shrink-0 items-center justify-between border-b border-solid border-b-divider-regular px-6 py-2'>
<div className='text-lg font-medium text-text-primary'></div>
<ApiServer apiBaseUrl={appDetail.api_base_url} />
<ApiServer apiBaseUrl={appDetail.api_base_url} appId={appId} />
</div>
<div className='grow overflow-auto px-4 py-4 sm:px-10'>
<Doc appDetail={appDetail} />

@ -12,7 +12,7 @@ type IChildrenProps = {
type IHeaderingProps = {
url: string
method: 'PUT' | 'DELETE' | 'GET' | 'POST'
method: 'PUT' | 'DELETE' | 'GET' | 'POST' | 'PATCH'
title: string
name: string
}
@ -34,6 +34,9 @@ export const Heading = function H2({
case 'POST':
style = 'ring-sky-300 bg-sky-400/10 text-sky-500 dark:ring-sky-400/30 dark:bg-sky-400/10 dark:text-sky-400'
break
case 'PATCH':
style = 'ring-violet-300 bg-violet-400/10 text-violet-500 dark:ring-violet-400/30 dark:bg-violet-400/10 dark:text-violet-400'
break
default:
style = 'ring-emerald-300 dark:ring-emerald-400/30 bg-emerald-400/10 text-emerald-500 dark:text-emerald-400'
break

@ -776,6 +776,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Col>
<Col sticky>
嵌入模型的提供商和模型名称可以通过以下接口获取v1/workspaces/current/models/model-types/text-embedding 具体见:通过 API 维护知识库。 使用的Authorization是Dataset的API Token。
该接口是异步执行所以会返回一个job_id通过查询job状态接口可以获取到最终的执行结果。
<CodeGroup
title="Request"
tag="POST"
@ -801,7 +802,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
"job_status": "waiting"
}
```
该接口是异步执行所以会返回一个job_id通过查询job状态接口可以获取到最终的执行结果。
</CodeGroup>
</Col>
</Row>

@ -523,7 +523,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<CodeGroup title="Request" tag="GET" label="/messages/{message_id}/suggested" targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123 \\\n--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \\\n--header 'Content-Type: application/json'`}>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested'?user=abc-123 \
curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested?user=abc-123' \
--header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \
--header 'Content-Type: application/json' \
```
@ -967,7 +967,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
"user": "abc-123"
}'
```
</CodeGroup>
<CodeGroup title="headers">
@ -1191,10 +1191,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="GET"
label="/apps/annotations"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \\\n--header 'Authorization: Bearer {api_key}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.apiBaseUrl}/apps/annotations?page=1&limit=20' \
curl --location --request GET '${props.appDetail.api_base_url}/apps/annotations?page=1&limit=20' \
--header 'Authorization: Bearer {api_key}'
```
</CodeGroup>
@ -1245,10 +1245,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="POST"
label="/apps/annotations"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`}
targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/apps/annotations' \
curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
@ -1301,10 +1301,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="PUT"
label="/apps/annotations/{annotation_id}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`}
targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"question": "What is your name?","answer": "I am Dify."}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \
curl --location --request POST '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
@ -1351,10 +1351,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="PUT"
label="/apps/annotations/{annotation_id}"
targetCode={`curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`}
targetCode={`curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json'`}
>
```bash {{ title: 'cURL' }}
curl --location --request DELETE '${props.apiBaseUrl}/apps/annotations/{annotation_id}' \
curl --location --request DELETE '${props.appDetail.api_base_url}/apps/annotations/{annotation_id}' \
--header 'Authorization: Bearer {api_key}'
```
</CodeGroup>
@ -1398,7 +1398,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="POST"
label="/apps/annotation-reply/{action}"
targetCode={`curl --location --request POST '${props.apiBaseUrl}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`}
targetCode={`curl --location --request POST '${props.appDetail.api_base_url}/apps/annotation-reply/{action}' \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{"score_threshold": 0.9, "embedding_provider_name": "zhipu", "embedding_model_name": "embedding_3"}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST 'https://api.dify.ai/v1/apps/annotation-reply/{action}' \
@ -1448,10 +1448,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
title="Request"
tag="GET"
label="/apps/annotations"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
targetCode={`curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \\\n--header 'Authorization: Bearer {api_key}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request GET '${props.apiBaseUrl}/apps/annotation-reply/{action}/status/{job_id}' \
curl --location --request GET '${props.appDetail.api_base_url}/apps/annotation-reply/{action}/status/{job_id}' \
--header 'Authorization: Bearer {api_key}'
```
</CodeGroup>

@ -3,6 +3,7 @@ import DataSourceNotion from './data-source-notion'
import DataSourceWebsite from './data-source-website'
import { fetchDataSource } from '@/service/common'
import { DataSourceProvider } from '@/models/common'
import { ENABLE_WEBSITE_FIRECRAWL, ENABLE_WEBSITE_JINAREADER, ENABLE_WEBSITE_WATERCRAWL } from '@/config'
export default function DataSourcePage() {
const { data } = useSWR({ url: 'data-source/integrates' }, fetchDataSource)
@ -11,9 +12,9 @@ export default function DataSourcePage() {
return (
<div className='mb-8'>
<DataSourceNotion workspaces={notionWorkspaces} />
<DataSourceWebsite provider={DataSourceProvider.jinaReader} />
<DataSourceWebsite provider={DataSourceProvider.fireCrawl} />
<DataSourceWebsite provider={DataSourceProvider.waterCrawl} />
{ENABLE_WEBSITE_JINAREADER && <DataSourceWebsite provider={DataSourceProvider.jinaReader} />}
{ENABLE_WEBSITE_FIRECRAWL && <DataSourceWebsite provider={DataSourceProvider.fireCrawl} />}
{ENABLE_WEBSITE_WATERCRAWL && <DataSourceWebsite provider={DataSourceProvider.waterCrawl} />}
</div>
)
}

@ -104,7 +104,10 @@ const PluginItem: FC<Props> = ({
{!isDifyVersionCompatible && <Tooltip popupContent={
t('plugin.difyVersionNotCompatible', { minimalDifyVersion: declarationMeta.minimum_dify_version })
}><RiErrorWarningLine color='red' className="ml-0.5 h-4 w-4 shrink-0 text-text-accent" /></Tooltip>}
<Badge className='ml-1 shrink-0' text={source === PluginSource.github ? plugin.meta!.version : plugin.version} />
<Badge className='ml-1 shrink-0'
text={source === PluginSource.github ? plugin.meta!.version : plugin.version}
hasRedCornerMark={(source === PluginSource.marketplace) && !!plugin.latest_unique_identifier && plugin.latest_unique_identifier !== plugin_unique_identifier}
/>
</div>
<div className='flex items-center justify-between'>
<Description text={descriptionText} descriptionLineRows={1}></Description>

@ -406,8 +406,7 @@ export type VersionProps = {
export type StrategyParamItem = {
name: string
label: Record<Locale, string>
human_description: Record<Locale, string>
llm_description: string
help: Record<Locale, string>
placeholder: Record<Locale, string>
type: string
scope: string

@ -2,29 +2,44 @@ import { CONVERSATION_ID_INFO } from '../base/chat/constants'
import { fetchAccessToken } from '@/service/share'
import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils'
export const isTokenV1 = (token: Record<string, any>) => {
return !token.version
}
export const getInitialTokenV2 = (): Record<string, any> => ({
version: 2,
})
export const checkOrSetAccessToken = async () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {
}
if (!accessTokenJson[sharedToken]) {
const sysUserId = (await getProcessedSystemVariablesFromUrlParams()).user_id
const res = await fetchAccessToken(sharedToken, sysUserId)
accessTokenJson[sharedToken] = res.access_token
if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) {
const res = await fetchAccessToken(sharedToken, userId)
accessTokenJson[sharedToken] = {
...accessTokenJson[sharedToken],
[userId || 'DEFAULT']: res.access_token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
}
export const setAccessToken = async (sharedToken: string, token: string) => {
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
export const setAccessToken = async (sharedToken: string, token: string, user_id?: string) => {
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {
@ -32,17 +47,22 @@ export const setAccessToken = async (sharedToken: string, token: string) => {
localStorage.removeItem(CONVERSATION_ID_INFO)
accessTokenJson[sharedToken] = token
accessTokenJson[sharedToken] = {
...accessTokenJson[sharedToken],
[user_id || 'DEFAULT']: token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
export const removeAccessToken = () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {

@ -0,0 +1,69 @@
import {
memo,
useState,
} from 'react'
import type { EnvironmentVariable } from '@/app/components/workflow/types'
import { DSL_EXPORT_CHECK } from '@/app/components/workflow/constants'
import { useStore } from '@/app/components/workflow/store'
import Features from '@/app/components/workflow/features'
import PluginDependency from '@/app/components/workflow/plugin-dependency'
import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal'
import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal'
import {
useDSL,
usePanelInteractions,
} from '@/app/components/workflow/hooks'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import WorkflowHeader from './workflow-header'
import WorkflowPanel from './workflow-panel'
const WorkflowChildren = () => {
const { eventEmitter } = useEventEmitterContextContext()
const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([])
const showFeaturesPanel = useStore(s => s.showFeaturesPanel)
const showImportDSLModal = useStore(s => s.showImportDSLModal)
const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal)
const {
handlePaneContextmenuCancel,
} = usePanelInteractions()
const {
exportCheck,
handleExportDSL,
} = useDSL()
eventEmitter?.useSubscription((v: any) => {
if (v.type === DSL_EXPORT_CHECK)
setSecretEnvList(v.payload.data as EnvironmentVariable[])
})
return (
<>
<PluginDependency />
{
showFeaturesPanel && <Features />
}
{
showImportDSLModal && (
<UpdateDSLModal
onCancel={() => setShowImportDSLModal(false)}
onBackup={exportCheck}
onImport={handlePaneContextmenuCancel}
/>
)
}
{
secretEnvList.length > 0 && (
<DSLExportConfirmModal
envList={secretEnvList}
onConfirm={handleExportDSL}
onClose={() => setSecretEnvList([])}
/>
)
}
<WorkflowHeader />
<WorkflowPanel />
</>
)
}
export default memo(WorkflowChildren)

@ -0,0 +1,11 @@
import { memo } from 'react'
import ChatVariableButton from '@/app/components/workflow/header/chat-variable-button'
import {
useNodesReadOnly,
} from '@/app/components/workflow/hooks'
const ChatVariableTrigger = () => {
const { nodesReadOnly } = useNodesReadOnly()
return <ChatVariableButton disabled={nodesReadOnly} />
}
export default memo(ChatVariableTrigger)

@ -0,0 +1,152 @@
import {
memo,
useCallback,
useMemo,
} from 'react'
import { useNodes } from 'reactflow'
import { RiApps2AddLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import {
useStore,
useWorkflowStore,
} from '@/app/components/workflow/store'
import {
useChecklistBeforePublish,
useNodesReadOnly,
useNodesSyncDraft,
} from '@/app/components/workflow/hooks'
import Button from '@/app/components/base/button'
import AppPublisher from '@/app/components/app/app-publisher'
import { useFeatures } from '@/app/components/base/features/hooks'
import {
BlockEnum,
InputVarType,
} from '@/app/components/workflow/types'
import type { StartNodeType } from '@/app/components/workflow/nodes/start/types'
import { useToastContext } from '@/app/components/base/toast'
import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow'
import type { PublishWorkflowParams } from '@/types/workflow'
import { fetchAppDetail, fetchAppSSO } from '@/service/apps'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useSelector as useAppSelector } from '@/context/app-context'
const FeaturesTrigger = () => {
const { t } = useTranslation()
const workflowStore = useWorkflowStore()
const appDetail = useAppStore(s => s.appDetail)
const appID = appDetail?.id
const setAppDetail = useAppStore(s => s.setAppDetail)
const systemFeatures = useAppSelector(state => state.systemFeatures)
const {
nodesReadOnly,
getNodesReadOnly,
} = useNodesReadOnly()
const publishedAt = useStore(s => s.publishedAt)
const draftUpdatedAt = useStore(s => s.draftUpdatedAt)
const toolPublished = useStore(s => s.toolPublished)
const nodes = useNodes<StartNodeType>()
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
const startVariables = startNode?.data.variables
const fileSettings = useFeatures(s => s.features.file)
const variables = useMemo(() => {
const data = startVariables || []
if (fileSettings?.image?.enabled) {
return [
...data,
{
type: InputVarType.files,
variable: '__image',
required: false,
label: 'files',
},
]
}
return data
}, [fileSettings?.image?.enabled, startVariables])
const { handleCheckBeforePublish } = useChecklistBeforePublish()
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
const { notify } = useToastContext()
const handleShowFeatures = useCallback(() => {
const {
showFeaturesPanel,
isRestoring,
setShowFeaturesPanel,
} = workflowStore.getState()
if (getNodesReadOnly() && !isRestoring)
return
setShowFeaturesPanel(!showFeaturesPanel)
}, [workflowStore, getNodesReadOnly])
const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id)
const updateAppDetail = useCallback(async () => {
try {
const res = await fetchAppDetail({ url: '/apps', id: appID! })
if (systemFeatures.enable_web_sso_switch_component) {
const ssoRes = await fetchAppSSO({ appId: appID! })
setAppDetail({ ...res, enable_sso: ssoRes.enabled })
}
else {
setAppDetail({ ...res })
}
}
catch (error) {
console.error(error)
}
}, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component])
const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)
const onPublish = useCallback(async (params?: PublishWorkflowParams) => {
if (await handleCheckBeforePublish()) {
const res = await publishWorkflow({
title: params?.title || '',
releaseNotes: params?.releaseNotes || '',
})
if (res) {
notify({ type: 'success', message: t('common.api.actionSuccess') })
updateAppDetail()
workflowStore.getState().setPublishedAt(res.created_at)
resetWorkflowVersionHistory()
}
}
else {
throw new Error('Checklist failed')
}
}, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail])
const onPublisherToggle = useCallback((state: boolean) => {
if (state)
handleSyncWorkflowDraft(true)
}, [handleSyncWorkflowDraft])
const handleToolConfigureUpdate = useCallback(() => {
workflowStore.setState({ toolPublished: true })
}, [workflowStore])
return (
<>
<Button className='text-components-button-secondary-text' onClick={handleShowFeatures}>
<RiApps2AddLine className='mr-1 h-4 w-4 text-components-button-secondary-text' />
{t('workflow.common.features')}
</Button>
<AppPublisher
{...{
publishedAt,
draftUpdatedAt,
disabled: nodesReadOnly,
toolPublished,
inputs: variables,
onRefreshData: handleToolConfigureUpdate,
onPublish,
onToggle: onPublisherToggle,
crossAxisOffset: 4,
}}
/>
</>
)
}
export default memo(FeaturesTrigger)

@ -0,0 +1,31 @@
import { useMemo } from 'react'
import type { HeaderProps } from '@/app/components/workflow/header'
import Header from '@/app/components/workflow/header'
import { useStore as useAppStore } from '@/app/components/app/store'
import ChatVariableTrigger from './chat-variable-trigger'
import FeaturesTrigger from './features-trigger'
import { useResetWorkflowVersionHistory } from '@/service/use-workflow'
const WorkflowHeader = () => {
const appDetail = useAppStore(s => s.appDetail)
const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id)
const headerProps: HeaderProps = useMemo(() => {
return {
normal: {
components: {
left: <ChatVariableTrigger />,
middle: <FeaturesTrigger />,
},
},
restoring: {
onRestoreSettled: resetWorkflowVersionHistory,
},
}
}, [resetWorkflowVersionHistory])
return (
<Header {...headerProps} />
)
}
export default WorkflowHeader

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save