Merge remote-tracking branch 'origin/deploy/rag-dev' into deploy/rag-dev

feat/datasource
jyong 12 months ago
commit 52c118f5b8

@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking # Prevent Clickjacking
ALLOW_EMBED=false ALLOW_EMBED=false
# Dataset queue monitor configuration
QUEUE_MONITOR_THRESHOLD=200
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30

@ -2,7 +2,7 @@ import os
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig from .cache.redis_config import RedisConfig
@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
) )
class DatasetQueueMonitorConfig(BaseSettings):
"""
Configuration settings for Dataset Queue Monitor
"""
QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
description="Threshold for dataset queue monitor",
default=200,
)
QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
description="Emails for dataset queue monitor alert, separated by commas",
default=None,
)
QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
description="Interval for dataset queue monitor in minutes",
default=30,
)
class MiddlewareConfig( class MiddlewareConfig(
# place the configs in alphabet order # place the configs in alphabet order
CeleryConfig, CeleryConfig,
@ -303,5 +322,6 @@ class MiddlewareConfig(
BaiduVectorDBConfig, BaiduVectorDBConfig,
OpenGaussConfig, OpenGaussConfig,
TableStoreConfig, TableStoreConfig,
DatasetQueueMonitorConfig,
): ):
pass pass

@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"):
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique
# save file info # save file info
file = request.files["file"] file = request.files["file"]

@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus status: ModelStatus
load_balancing_enabled: bool = False load_balancing_enabled: bool = False
def raise_for_status(self) -> None:
"""
Check model status and raise ValueError if not active.
:raises ValueError: When model status is not active, with a descriptive message
"""
if self.status == ModelStatus.ACTIVE:
return
error_messages = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}
if self.status in error_messages:
raise ValueError(error_messages[self.status])
class ModelWithProviderEntity(ProviderModelWithStatusEntity): class ModelWithProviderEntity(ProviderModelWithStatusEntity):
""" """

@ -41,45 +41,53 @@ class Extensible:
extensions = [] extensions = []
position_map: dict[str, int] = {} position_map: dict[str, int] = {}
# get the path of the current class # Get the package name from the module path
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") package_name = ".".join(cls.__module__.split(".")[:-1])
current_dir_path = os.path.dirname(current_path)
try:
# traverse subdirectories # Get package directory path
for subdir_name in os.listdir(current_dir_path): package_spec = importlib.util.find_spec(package_name)
if subdir_name.startswith("__"): if not package_spec or not package_spec.origin:
continue raise ImportError(f"Could not find package {package_name}")
subdir_path = os.path.join(current_dir_path, subdir_name) package_dir = os.path.dirname(package_spec.origin)
extension_name = subdir_name
if os.path.isdir(subdir_path): # Traverse subdirectories
for subdir_name in os.listdir(package_dir):
if subdir_name.startswith("__"):
continue
subdir_path = os.path.join(package_dir, subdir_name)
if not os.path.isdir(subdir_path):
continue
extension_name = subdir_name
file_names = os.listdir(subdir_path) file_names = os.listdir(subdir_path)
# is builtin extension, builtin extension # Check for extension module file
# in the front-end page and business logic, there are special treatments. if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Check for builtin flag and position
builtin = False builtin = False
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0 position = 0
if "__builtin__" in file_names: if "__builtin__" in file_names:
builtin = True builtin = True
builtin_file_path = os.path.join(subdir_path, "__builtin__") builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path): if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position position_map[extension_name] = position
if (extension_name + ".py") not in file_names: # Import the extension module
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") module_name = f"{package_name}.{extension_name}.{extension_name}"
continue spec = importlib.util.find_spec(module_name)
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader: if not spec or not spec.loader:
raise Exception(f"Failed to load module {extension_name} from {py_path}") raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)
# Find extension class
extension_class = None extension_class = None
for name, obj in vars(mod).items(): for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@ -87,21 +95,21 @@ class Extensible:
break break
if not extension_class: if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue continue
# Load schema if not builtin
json_data: dict[str, Any] = {} json_data: dict[str, Any] = {}
if not builtin: if not builtin:
if "schema.json" not in file_names: json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue continue
json_path = os.path.join(subdir_path, "schema.json") with open(json_path, encoding="utf-8") as f:
json_data = {} json_data = json.load(f)
if os.path.exists(json_path):
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
# Create extension
extensions.append( extensions.append(
ModuleExtension( ModuleExtension(
extension_class=extension_class, extension_class=extension_class,
@ -113,6 +121,11 @@ class Extensible:
) )
) )
except Exception as e:
logging.exception("Error scanning extensions")
raise
# Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map( sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name position_map=position_map, data=extensions, name_func=lambda x: x.name
) )

@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
deprecated: bool = False deprecated: bool = False
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@property
def support_structure_output(self) -> bool:
return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
class ParameterRule(BaseModel): class ParameterRule(BaseModel):
""" """

@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod @staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
"""
Get all provider records of the workspace.
:param tenant_id: workspace id
:return:
"""
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list) provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers: with Session(db.engine, expire_on_commit=False) as session:
# TODO: Use provider name with prefix after the data migration stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) providers = session.scalars(stmt)
for provider in providers:
# Use provider name with prefix after the data migration
provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict
@staticmethod @staticmethod
@ -416,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
# Get all provider model records of the workspace
provider_models = (
db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all()
)
provider_name_to_provider_model_records_dict = defaultdict(list) provider_name_to_provider_model_records_dict = defaultdict(list)
for provider_model in provider_models: with Session(db.engine, expire_on_commit=False) as session:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
provider_models = session.scalars(stmt)
for provider_model in provider_models:
provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict return provider_name_to_provider_model_records_dict
@staticmethod @staticmethod
@ -437,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
preferred_provider_types = ( provider_name_to_preferred_provider_type_records_dict = {}
db.session.query(TenantPreferredModelProvider) with Session(db.engine, expire_on_commit=False) as session:
.filter(TenantPreferredModelProvider.tenant_id == tenant_id) stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
.all() preferred_provider_types = session.scalars(stmt)
) provider_name_to_preferred_provider_type_records_dict = {
preferred_provider_type.provider_name: preferred_provider_type
provider_name_to_preferred_provider_type_records_dict = { for preferred_provider_type in preferred_provider_types
preferred_provider_type.provider_name: preferred_provider_type }
for preferred_provider_type in preferred_provider_types
}
return provider_name_to_preferred_provider_type_records_dict return provider_name_to_preferred_provider_type_records_dict
@staticmethod @staticmethod
@ -458,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
provider_model_settings = (
db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
)
provider_name_to_provider_model_settings_dict = defaultdict(list) provider_name_to_provider_model_settings_dict = defaultdict(list)
for provider_model_setting in provider_model_settings: with Session(db.engine, expire_on_commit=False) as session:
( stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
provider_model_settings = session.scalars(stmt)
for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting provider_model_setting
) )
)
return provider_name_to_provider_model_settings_dict return provider_name_to_provider_model_settings_dict
@staticmethod @staticmethod
@ -492,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled: if not model_load_balancing_enabled:
return {} return {}
provider_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
)
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
for provider_load_balancing_config in provider_load_balancing_configs: with Session(db.engine, expire_on_commit=False) as session:
provider_name_to_provider_load_balancing_model_configs_dict[ stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
provider_load_balancing_config.provider_name provider_load_balancing_configs = session.scalars(stmt)
].append(provider_load_balancing_config) for provider_load_balancing_config in provider_load_balancing_configs:
provider_name_to_provider_load_balancing_model_configs_dict[
provider_load_balancing_config.provider_name
].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict
@ -626,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials: if not cached_provider_credentials:
try: try:
# fix origin data # fix origin data
if ( if custom_provider_record.encrypted_config is None:
custom_provider_record.encrypted_config raise ValueError("No credentials found")
and not custom_provider_record.encrypted_config.startswith("{") if not custom_provider_record.encrypted_config.startswith("{"):
):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else: else:
provider_credentials = json.loads(custom_provider_record.encrypted_config) provider_credentials = json.loads(custom_provider_record.encrypted_config)
@ -733,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False) return SystemConfiguration(enabled=False)
# Convert provider_records to dict # Convert provider_records to dict
quota_type_to_provider_records_dict = {} quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records: for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value: if provider_record.provider_type != ProviderType.SYSTEM.value:
continue continue
@ -758,6 +740,11 @@ class ProviderManager:
else: else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
if provider_record.quota_used is None:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
quota_configuration = QuotaConfiguration( quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type, quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@ -791,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get() cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials: if not cached_provider_credentials:
try: provider_credentials: dict[str, Any] = {}
provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) if provider_records and provider_records[0].encrypted_config:
except JSONDecodeError: provider_credentials = json.loads(provider_records[0].encrypted_config)
provider_credentials = {}
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(

@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None structured_output: dict | None = None
structured_output_enabled: bool = False # We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before") @field_validator("prompt_config", mode="before")
@classmethod @classmethod
@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None: if v is None:
return PromptConfig() return PromptConfig()
return v return v
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import ( from core.workflow.utils.structured_output.entities import (
ResponseFormat, ResponseFormat,
SpecialModelType, SpecialModelType,
SupportStructuredOutputStatus,
) )
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage, llm_usage=usage,
) )
) )
except LLMNodeError as e: except ValueError as e:
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config( def _fetch_model_config(
self, node_data_model: ModelConfig self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = node_data_model.name if not node_data_model.mode:
provider_name = node_data_model.provider raise LLMModeRequiredError("LLM mode is required.")
model_manager = ModelManager() model = ModelManager().get_model_instance(
model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id,
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name model_type=ModelType.LLM,
provider=node_data_model.provider,
model=node_data_model.name,
) )
provider_model_bundle = model_instance.provider_model_bundle model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model # check model
provider_model = provider_model_bundle.configuration.get_provider_model( provider_model = model.provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM model=node_data_model.name, model_type=ModelType.LLM
) )
if provider_model is None: if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.") raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config # model config
completion_params = node_data_model.completion_params stop: list[str] = []
stop = [] if "stop" in node_data_model.completion_params:
if "stop" in completion_params: stop = node_data_model.completion_params.pop("stop")
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = node_data_model.mode
if not model_mode:
raise LLMModeRequiredError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema: if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.") raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: if self.node_data.structured_output_enabled:
completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) if model_schema.support_structure_output:
elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: node_data_model.completion_params = self._handle_native_json_schema(
# Set appropriate response format based on model capabilities node_data_model.completion_params, model_schema.parameter_rules
self._set_response_format(completion_params, model_schema.parameter_rules) )
return model_instance, ModelConfigWithCredentialsEntity( else:
provider=provider_name, # Set appropriate response format based on model capabilities
model=model_name, self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
return model, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema, model_schema=model_schema,
mode=model_mode, mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle, provider_model_bundle=model.provider_model_bundle,
credentials=model_credentials, credentials=model.credentials,
parameters=completion_params, parameters=node_data_model.completion_params,
stop=stop, stop=stop,
) )
@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
support_structured_output = self._check_model_structured_output_support()
if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: model = ModelManager().get_model_instance(
filtered_prompt_messages = self._handle_prompt_based_schema( tenant_id=self.tenant_id,
prompt_messages=filtered_prompt_messages, model_type=ModelType.LLM,
) provider=self.node_data.model.provider,
stop = model_config.stop model=self.node_data.model.name,
return filtered_prompt_messages, stop )
model_schema = model.model_type_instance.get_model_schema(
model=self.node_data.model.name,
credentials=model.credentials,
)
if not model_schema:
raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
if self.node_data.structured_output_enabled:
if not model_schema.support_structure_output:
filtered_prompt_messages = self._handle_prompt_based_schema(
prompt_messages=filtered_prompt_messages,
)
return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]: def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {} structured_output: dict[str, Any] = {}
@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping["#context#"] = node_data.context.variable_selector variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled: if node_data.vision.enabled:
variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if node_data.memory: if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError: except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format") raise LLMNodeError("structured_output_schema is not valid JSON format")
def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
"""
Check if the current model supports structured output.
Returns:
SupportStructuredOutput: The support status of structured output
"""
# Early return if structured output is disabled
if (
not isinstance(self.node_data, LLMNodeData)
or not self.node_data.structured_output_enabled
or not self.node_data.structured_output
):
return SupportStructuredOutputStatus.DISABLED
# Get model schema and check if it exists
model_schema = self._fetch_model_schema(self.node_data.model.provider)
if not model_schema:
return SupportStructuredOutputStatus.DISABLED
# Check if model supports structured output feature
return (
SupportStructuredOutputStatus.SUPPORTED
if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown( def _save_multimodal_output_and_convert_result_to_markdown(
self, self,
contents: str | list[PromptMessageContentUnionTypes] | None, contents: str | list[PromptMessageContentUnionTypes] | None,

@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini" GEMINI = "gemini"
OLLAMA = "ollama" OLLAMA = "ollama"
class SupportStructuredOutputStatus(StrEnum):
"""Constants for structured output support status"""
SUPPORTED = "supported"
UNSUPPORTED = "unsupported"
DISABLED = "disabled"

@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
"schedule.update_tidb_serverless_status_task", "schedule.update_tidb_serverless_status_task",
"schedule.clean_messages", "schedule.clean_messages",
"schedule.mail_clean_document_notify_task", "schedule.mail_clean_document_notify_task",
"schedule.queue_monitor_task",
] ]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = { beat_schedule = {
@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"), "schedule": crontab(minute="0", hour="10", day_of_week="1"),
}, },
"datasets-queue-monitor": {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
},
} }
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

@ -1,6 +1,9 @@
from datetime import datetime
from enum import Enum from enum import Enum
from typing import Optional
from sqlalchemy import func from sqlalchemy import func, text
from sqlalchemy.orm import Mapped, mapped_column
from .base import Base from .base import Base
from .engine import db from .engine import db
@ -51,20 +54,24 @@ class Provider(Base):
), ),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) provider_type: Mapped[str] = mapped_column(
encrypted_config = db.Column(db.Text, nullable=True) db.String(40), nullable=False, server_default=text("'custom'::character varying")
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) )
last_used = db.Column(db.DateTime, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) quota_type: Mapped[Optional[str]] = mapped_column(
quota_limit = db.Column(db.BigInteger, nullable=True) db.String(40), nullable=True, server_default=text("''::character varying")
quota_used = db.Column(db.BigInteger, default=0) )
quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self): def __repr__(self):
return ( return (
@ -104,15 +111,15 @@ class ProviderModel(Base):
), ),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base): class TenantDefaultModel(Base):
@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base): class TenantPreferredModelProvider(Base):
@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False) preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base): class ProviderOrder(Base):
@ -153,22 +160,24 @@ class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
account_id = db.Column(StringUUID, nullable=False) account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
payment_product_id = db.Column(db.String(191), nullable=False) payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
payment_id = db.Column(db.String(191)) payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
transaction_id = db.Column(db.String(191)) transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
currency = db.Column(db.String(40)) currency: Mapped[Optional[str]] = mapped_column(db.String(40))
total_amount = db.Column(db.Integer) total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) payment_status: Mapped[str] = mapped_column(
paid_at = db.Column(db.DateTime) db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
pay_failed_at = db.Column(db.DateTime) )
refunded_at = db.Column(db.DateTime) paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base): class ProviderModelSetting(Base):
@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base): class LoadBalancingModelConfig(Base):
@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False) provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False) model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False) model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
name = db.Column(db.String(255), nullable=False) name: Mapped[str] = mapped_column(db.String(255), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True) encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@ -0,0 +1,62 @@
import logging
from datetime import datetime
from urllib.parse import urlparse
import click
from flask import render_template
from redis import Redis
import app
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_mail import mail
# Create a dedicated Redis connection (using the same configuration as Celery)
celery_broker_url = dify_config.CELERY_BROKER_URL
parsed = urlparse(celery_broker_url)
host = parsed.hostname or "localhost"
port = parsed.port or 6379
password = parsed.password or None
redis_db = parsed.path.strip("/") or "1" # type: ignore
celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
@app.celery.task(queue="monitor")
def queue_monitor_task():
queue_name = "dataset"
threshold = dify_config.QUEUE_MONITOR_THRESHOLD
try:
queue_length = celery_redis.llen(f"{queue_name}")
logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
if queue_length >= threshold:
warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
logging.warning(click.style(warning_msg, fg="red"))
alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
if alter_emails:
to_list = alter_emails.split(",")
for to in to_list:
try:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
html_content = render_template(
"queue_monitor_alert_email_template_en-US.html",
queue_name=queue_name,
queue_length=queue_length,
threshold=threshold,
alert_time=current_time,
)
mail.send(
to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
)
except Exception as e:
logging.exception(click.style("Exception occurred during sending email", fg="red"))
except Exception as e:
logging.exception(click.style("Exception occurred during queue monitoring", fg="red"))
finally:
if db.session.is_active:
db.session.close()

@ -5,7 +5,7 @@ import uuid
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from sqlalchemy import func, select from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.model_manager import ModelManager from core.model_manager import ModelManager
@ -68,11 +68,6 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model, model=dataset.embedding_model,
) )
word_count_change = 0
segments_to_insert: list[str] = []
max_position_stmt = select(func.max(DocumentSegment.position)).where(
DocumentSegment.document_id == dataset_document.id
)
word_count_change = 0 word_count_change = 0
if embedding_model: if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens( tokens_list = embedding_model.get_text_embedding_num_tokens(

@ -0,0 +1,129 @@
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: 'Arial', sans-serif;
line-height: 16pt;
color: #101828;
background-color: #e9ebf0;
margin: 0;
padding: 0;
}
.container {
width: 600px;
min-height: 605px;
margin: 40px auto;
padding: 36px 48px;
background-color: #fcfcfd;
border-radius: 16px;
border: 1px solid #ffffff;
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
}
.header {
margin-bottom: 24px;
}
.header img {
max-width: 100px;
height: auto;
}
.title {
font-weight: 600;
font-size: 24px;
line-height: 28.8px;
}
.description {
font-size: 13px;
line-height: 16px;
color: #676f83;
margin-top: 12px;
}
.alert-content {
padding: 16px 32px;
text-align: center;
border-radius: 16px;
background-color: #fef0f0;
margin: 16px auto;
border: 1px solid #fda29b;
}
.alert-title {
line-height: 24px;
font-weight: 700;
font-size: 18px;
color: #d92d20;
}
.alert-detail {
line-height: 20px;
font-size: 14px;
margin-top: 8px;
}
.typography {
letter-spacing: -0.07px;
font-weight: 400;
font-style: normal;
font-size: 14px;
line-height: 20px;
color: #354052;
margin-top: 12px;
margin-bottom: 12px;
}
.typography p{
margin: 0 auto;
}
.typography-title {
color: #101828;
font-size: 14px;
font-style: normal;
font-weight: 600;
line-height: 20px;
margin-top: 12px;
margin-bottom: 4px;
}
.tip-list{
margin: 0;
padding-left: 10px;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<img src="https://assets.dify.ai/images/logo.png" alt="Dify Logo" />
</div>
<p class="title">Queue Monitoring Alert</p>
<p class="typography">Our system has detected an abnormal queue status that requires your attention:</p>
<div class="alert-content">
<div class="alert-title">Queue Task Alert</div>
<div class="alert-detail">
Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})
</div>
</div>
<div class="typography">
<p style="margin-bottom:4px">Recommended actions:</p>
<p>1. Check the queue processing status in the system dashboard</p>
<p>2. Verify if there are any processing bottlenecks</p>
<p>3. Consider scaling up workers if needed</p>
</div>
<p class="typography-title">Additional Information:</p>
<ul class="typography tip-list">
<li>Alert triggered at: {{alert_time}}</li>
</ul>
</div>
</body>
</html>

@ -3,11 +3,16 @@ import os
import time import time
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from unittest.mock import MagicMock from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest import pytest
from app_factory import create_app
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE""" """FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
@pytest.fixture(scope="session")
def app():
# Set up storage configuration
os.environ["STORAGE_TYPE"] = "opendal"
os.environ["OPENDAL_SCHEME"] = "fs"
os.environ["OPENDAL_FS_ROOT"] = "storage"
# Ensure storage directory exists
os.makedirs("storage", exist_ok=True)
app = create_app()
dify_config.LOGIN_DISABLED = True
return app
def init_llm_node(config: dict) -> LLMNode: def init_llm_node(config: dict) -> LLMNode:
graph_config = { graph_config = {
"edges": [ "edges": [
@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
graph = Graph.init(graph_config=graph_config) graph = Graph.init(graph_config=graph_config)
# Use proper UUIDs for database compatibility
tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
init_params = GraphInitParams( init_params = GraphInitParams(
tenant_id="1", tenant_id=tenant_id,
app_id="1", app_id=app_id,
workflow_type=WorkflowType.WORKFLOW, workflow_type=WorkflowType.WORKFLOW,
workflow_id="1", workflow_id=workflow_id,
graph_config=graph_config, graph_config=graph_config,
user_id="1", user_id=user_id,
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0, call_depth=0,
@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
return node return node
def test_execute_llm(setup_model_mock): def test_execute_llm(app):
node = init_llm_node( with app.app_context():
config={ node = init_llm_node(
"id": "llm", config={
"data": { "id": "llm",
"title": "123", "data": {
"type": "llm", "title": "123",
"model": { "type": "llm",
"provider": "langgenius/openai/openai", "model": {
"name": "gpt-3.5-turbo", "provider": "langgenius/openai/openai",
"mode": "chat", "name": "gpt-3.5-turbo",
"completion_params": {}, "mode": "chat",
"completion_params": {},
},
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
"prompt_template": [
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
{"role": "user", "text": "{{#sys.query#}}"},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
}, },
}, )
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
# Mock db.session.close() # Create a proper LLM result with real entities
db.session.close = MagicMock() mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
node._fetch_model_config = get_mocked_fetch_model_config( mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
provider="langgenius/openai/openai",
model="gpt-3.5-turbo", mock_llm_result = LLMResult(
mode="chat", model="gpt-3.5-turbo",
credentials=credentials, prompt_messages=[],
) message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# execute node with (
result = node._run() patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
assert isinstance(result, Generator) patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
assert isinstance(result, Generator)
for item in result: for item in result:
if isinstance(item, RunCompletedEvent): if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None assert item.run_result.process_data is not None
assert item.run_result.outputs is not None assert item.run_result.outputs is not None
assert item.run_result.outputs.get("text") is not None assert item.run_result.outputs.get("text") is not None
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
""" """
Test execute LLM node with jinja2 Test execute LLM node with jinja2
""" """
node = init_llm_node( with app.app_context():
config={ node = init_llm_node(
"id": "llm", config={
"data": { "id": "llm",
"title": "123", "data": {
"type": "llm", "title": "123",
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, "type": "llm",
"prompt_config": { "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
"jinja2_variables": [ "prompt_config": {
{"variable": "sys_query", "value_selector": ["sys", "query"]}, "jinja2_variables": [
{"variable": "output", "value_selector": ["abc", "output"]}, {"variable": "sys_query", "value_selector": ["sys", "query"]},
] {"variable": "output", "value_selector": ["abc", "output"]},
}, ]
"prompt_template": [
{
"role": "system",
"text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
"edition_type": "jinja2",
}, },
{ "prompt_template": [
"role": "user", {
"text": "{{#sys.query#}}", "role": "system",
"jinja2_text": "{{sys_query}}", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
"edition_type": "basic", "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
}, "edition_type": "jinja2",
], },
"memory": None, {
"context": {"enabled": False}, "role": "user",
"vision": {"enabled": False}, "text": "{{#sys.query#}}",
"jinja2_text": "{{sys_query}}",
"edition_type": "basic",
},
],
"memory": None,
"context": {"enabled": False},
"vision": {"enabled": False},
},
}, },
}, )
)
credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} # Mock db.session.close()
db.session.close = MagicMock()
# Mock db.session.close() # Create a proper LLM result with real entities
db.session.close = MagicMock() mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
node._fetch_model_config = get_mocked_fetch_model_config( mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
provider="langgenius/openai/openai",
model="gpt-3.5-turbo", mock_llm_result = LLMResult(
mode="chat", model="gpt-3.5-turbo",
credentials=credentials, prompt_messages=[],
) message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# execute node with (
result = node._run() patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
):
# execute node
result = node._run()
for item in result: for item in result:
if isinstance(item, RunCompletedEvent): if isinstance(item, RunCompletedEvent):
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None assert item.run_result.process_data is not None
assert "sunny" in json.dumps(item.run_result.process_data) assert "sunny" in json.dumps(item.run_result.process_data)
assert "what's the weather today?" in json.dumps(item.run_result.process_data) assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json(): def test_extract_json():

@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600
PIP_MIRROR_URL= PIP_MIRROR_URL=
# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example # https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example
# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss # Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos
PLUGIN_STORAGE_TYPE=local PLUGIN_STORAGE_TYPE=local
PLUGIN_STORAGE_LOCAL_ROOT=/app/storage PLUGIN_STORAGE_LOCAL_ROOT=/app/storage
PLUGIN_WORKING_PATH=/app/storage/cwd PLUGIN_WORKING_PATH=/app/storage/cwd
@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH= PLUGIN_ALIYUN_OSS_PATH=
# Plugin oss volcengine tos
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------ # ------------------------------
# OTLP Collector Configuration # OTLP Collector Configuration
@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking # Prevent Clickjacking
ALLOW_EMBED=false ALLOW_EMBED=false
# Dataset queue monitor configuration
QUEUE_MONITOR_THRESHOLD=200
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30

@ -184,6 +184,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports: ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes: volumes:

@ -121,6 +121,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports: ports:
- "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"

@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false} ENABLE_OTEL: ${ENABLE_OTEL:-false}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318} OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-} OTLP_API_KEY: ${OTLP_API_KEY:-}
@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env
OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000} OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000}
OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000} OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000}
ALLOW_EMBED: ${ALLOW_EMBED:-false} ALLOW_EMBED: ${ALLOW_EMBED:-false}
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
services: services:
# API service # API service
@ -683,6 +690,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports: ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes: volumes:

@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH= PLUGIN_ALIYUN_OSS_PATH=
# Plugin oss volcengine tos
PLUGIN_VOLCENGINE_TOS_ENDPOINT=
PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
PLUGIN_VOLCENGINE_TOS_REGION=

@ -5,7 +5,6 @@ import type { Area } from 'react-easy-crop'
import Modal from '../modal' import Modal from '../modal'
import Divider from '../divider' import Divider from '../divider'
import Button from '../button' import Button from '../button'
import { ImagePlus } from '../icons/src/vender/line/images'
import { useLocalFileUploader } from '../image-uploader/hooks' import { useLocalFileUploader } from '../image-uploader/hooks'
import EmojiPickerInner from '../emoji-picker/Inner' import EmojiPickerInner from '../emoji-picker/Inner'
import type { OnImageInput } from './ImageInput' import type { OnImageInput } from './ImageInput'
@ -16,6 +15,7 @@ import type { AppIconType, ImageFile } from '@/types/app'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config'
import { noop } from 'lodash-es' import { noop } from 'lodash-es'
import { RiImageCircleAiLine } from '@remixicon/react'
export type AppIconEmojiSelection = { export type AppIconEmojiSelection = {
type: 'emoji' type: 'emoji'
@ -46,7 +46,7 @@ const AppIconPicker: FC<AppIconPickerProps> = ({
const tabs = [ const tabs = [
{ key: 'emoji', label: t('app.iconPicker.emoji'), icon: <span className="text-lg">🤖</span> }, { key: 'emoji', label: t('app.iconPicker.emoji'), icon: <span className="text-lg">🤖</span> },
{ key: 'image', label: t('app.iconPicker.image'), icon: <ImagePlus /> }, { key: 'image', label: t('app.iconPicker.image'), icon: <RiImageCircleAiLine className='size-4' /> },
] ]
const [activeTab, setActiveTab] = useState<AppIconType>('emoji') const [activeTab, setActiveTab] = useState<AppIconType>('emoji')
@ -119,10 +119,10 @@ const AppIconPicker: FC<AppIconPickerProps> = ({
{tabs.map(tab => ( {tabs.map(tab => (
<button <button
key={tab.key} key={tab.key}
className={` className={cn(
flex h-8 flex-1 shrink-0 items-center justify-center rounded-xl p-2 text-sm font-medium 'system-sm-medium flex h-8 flex-1 shrink-0 items-center justify-center rounded-xl p-2 text-text-tertiary',
${activeTab === tab.key && 'bg-components-main-nav-nav-button-bg-active shadow-md'} activeTab === tab.key && 'bg-components-main-nav-nav-button-bg-active text-text-accent shadow-md',
`} )}
onClick={() => setActiveTab(tab.key as AppIconType)} onClick={() => setActiveTab(tab.key as AppIconType)}
> >
{tab.icon} &nbsp; {tab.label} {tab.icon} &nbsp; {tab.label}

@ -16,7 +16,7 @@ import type {
Feedback, Feedback,
} from '../types' } from '../types'
import { CONVERSATION_ID_INFO } from '../constants' import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils' import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams, getRawInputsFromUrlParams } from '../utils'
import { addFileInfos, sortAgentSorts } from '../../../tools/utils' import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import { import {
@ -195,6 +195,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
const { t } = useTranslation() const { t } = useTranslation()
const newConversationInputsRef = useRef<Record<string, any>>({}) const newConversationInputsRef = useRef<Record<string, any>>({})
const [newConversationInputs, setNewConversationInputs] = useState<Record<string, any>>({}) const [newConversationInputs, setNewConversationInputs] = useState<Record<string, any>>({})
const [initInputs, setInitInputs] = useState<Record<string, any>>({})
const handleNewConversationInputsChange = useCallback((newInputs: Record<string, any>) => { const handleNewConversationInputsChange = useCallback((newInputs: Record<string, any>) => {
newConversationInputsRef.current = newInputs newConversationInputsRef.current = newInputs
setNewConversationInputs(newInputs) setNewConversationInputs(newInputs)
@ -202,20 +203,29 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
const inputsForms = useMemo(() => { const inputsForms = useMemo(() => {
return (appParams?.user_input_form || []).filter((item: any) => !item.external_data_tool).map((item: any) => { return (appParams?.user_input_form || []).filter((item: any) => !item.external_data_tool).map((item: any) => {
if (item.paragraph) { if (item.paragraph) {
let value = initInputs[item.paragraph.variable]
if (value && item.paragraph.max_length && value.length > item.paragraph.max_length)
value = value.slice(0, item.paragraph.max_length)
return { return {
...item.paragraph, ...item.paragraph,
default: value || item.default,
type: 'paragraph', type: 'paragraph',
} }
} }
if (item.number) { if (item.number) {
const convertedNumber = Number(initInputs[item.number.variable]) ?? undefined
return { return {
...item.number, ...item.number,
default: convertedNumber || item.default,
type: 'number', type: 'number',
} }
} }
if (item.select) { if (item.select) {
const isInputInOptions = item.select.options.includes(initInputs[item.select.variable])
return { return {
...item.select, ...item.select,
default: (isInputInOptions ? initInputs[item.select.variable] : undefined) || item.default,
type: 'select', type: 'select',
} }
} }
@ -234,17 +244,30 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
} }
} }
let value = initInputs[item['text-input'].variable]
if (value && item['text-input'].max_length && value.length > item['text-input'].max_length)
value = value.slice(0, item['text-input'].max_length)
return { return {
...item['text-input'], ...item['text-input'],
default: value || item.default,
type: 'text-input', type: 'text-input',
} }
}) })
}, [appParams]) }, [initInputs, appParams])
const allInputsHidden = useMemo(() => { const allInputsHidden = useMemo(() => {
return inputsForms.length > 0 && inputsForms.every(item => item.hide === true) return inputsForms.length > 0 && inputsForms.every(item => item.hide === true)
}, [inputsForms]) }, [inputsForms])
useEffect(() => {
// init inputs from url params
(async () => {
const inputs = await getRawInputsFromUrlParams()
setInitInputs(inputs)
})()
}, [])
useEffect(() => { useEffect(() => {
const conversationInputs: Record<string, any> = {} const conversationInputs: Record<string, any> = {}
@ -362,11 +385,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
if (conversationId) if (conversationId)
setClearChatList(false) setClearChatList(false)
}, [handleConversationIdInfoChange, setClearChatList]) }, [handleConversationIdInfoChange, setClearChatList])
const handleNewConversation = useCallback(() => { const handleNewConversation = useCallback(async () => {
currentChatInstanceRef.current.handleStop() currentChatInstanceRef.current.handleStop()
setShowNewConversationItemInList(true) setShowNewConversationItemInList(true)
handleChangeConversation('') handleChangeConversation('')
handleNewConversationInputsChange({}) handleNewConversationInputsChange(await getRawInputsFromUrlParams())
setClearChatList(true) setClearChatList(true)
}, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList]) }, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList])
const handleUpdateConversationList = useCallback(() => { const handleUpdateConversationList = useCallback(() => {

@ -15,6 +15,17 @@ async function decodeBase64AndDecompress(base64String: string) {
} }
} }
async function getRawInputsFromUrlParams(): Promise<Record<string, any>> {
const urlParams = new URLSearchParams(window.location.search)
const inputs: Record<string, any> = {}
const entriesArray = Array.from(urlParams.entries())
entriesArray.forEach(([key, value]) => {
if (!key.startsWith('sys.'))
inputs[key] = decodeURIComponent(value)
})
return inputs
}
async function getProcessedInputsFromUrlParams(): Promise<Record<string, any>> { async function getProcessedInputsFromUrlParams(): Promise<Record<string, any>> {
const urlParams = new URLSearchParams(window.location.search) const urlParams = new URLSearchParams(window.location.search)
const inputs: Record<string, any> = {} const inputs: Record<string, any> = {}
@ -184,6 +195,7 @@ function getThreadMessages(tree: ChatItemInTree[], targetMessageId?: string): Ch
} }
export { export {
getRawInputsFromUrlParams,
getProcessedInputsFromUrlParams, getProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams, getProcessedSystemVariablesFromUrlParams,
isValidGeneratedAnswer, isValidGeneratedAnswer,

@ -1,7 +1,7 @@
import { useChatContext } from '@/app/components/base/chat/chat/context' import { useChatContext } from '@/app/components/base/chat/chat/context'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { isValidUrl } from './utils'
const MarkdownButton = ({ node }: any) => { const MarkdownButton = ({ node }: any) => {
const { onSend } = useChatContext() const { onSend } = useChatContext()
const variant = node.properties.dataVariant const variant = node.properties.dataVariant
@ -9,25 +9,17 @@ const MarkdownButton = ({ node }: any) => {
const link = node.properties.dataLink const link = node.properties.dataLink
const size = node.properties.dataSize const size = node.properties.dataSize
function is_valid_url(url: string): boolean {
try {
const parsed_url = new URL(url)
return ['http:', 'https:'].includes(parsed_url.protocol)
}
catch {
return false
}
}
return <Button return <Button
variant={variant} variant={variant}
size={size} size={size}
className={cn('!h-auto min-h-8 select-none whitespace-normal !px-3')} className={cn('!h-auto min-h-8 select-none whitespace-normal !px-3')}
onClick={() => { onClick={() => {
if (is_valid_url(link)) { if (isValidUrl(link)) {
window.open(link, '_blank') window.open(link, '_blank')
return return
} }
if(!message)
return
onSend?.(message) onSend?.(message)
}} }}
> >

@ -5,6 +5,7 @@
*/ */
import React from 'react' import React from 'react'
import { useChatContext } from '@/app/components/base/chat/chat/context' import { useChatContext } from '@/app/components/base/chat/chat/context'
import { isValidUrl } from './utils'
const Link = ({ node, children, ...props }: any) => { const Link = ({ node, children, ...props }: any) => {
const { onSend } = useChatContext() const { onSend } = useChatContext()
@ -14,7 +15,11 @@ const Link = ({ node, children, ...props }: any) => {
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr> return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr>
} }
else { else {
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a> const href = props.href || node.properties?.href
if(!isValidUrl(href))
return <span>{children}</span>
return <a href={href} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
} }
} }

@ -0,0 +1,3 @@
export const isValidUrl = (url: string): boolean => {
return ['http:', 'https:', '//', 'mailto:'].some(prefix => url.startsWith(prefix))
}

@ -7,7 +7,7 @@ import RemarkGfm from 'remark-gfm'
import RehypeRaw from 'rehype-raw' import RehypeRaw from 'rehype-raw'
import { flow } from 'lodash-es' import { flow } from 'lodash-es'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { preprocessLaTeX, preprocessThinkTag } from './markdown-utils' import { customUrlTransform, preprocessLaTeX, preprocessThinkTag } from './markdown-utils'
import { import {
AudioBlock, AudioBlock,
CodeBlock, CodeBlock,
@ -65,6 +65,7 @@ export function Markdown(props: { content: string; className?: string; customDis
} }
}, },
]} ]}
urlTransform={customUrlTransform}
disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]} disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]}
components={{ components={{
code: CodeBlock, code: CodeBlock,

@ -36,3 +36,52 @@ export const preprocessThinkTag = (content: string) => {
(str: string) => str.replace(/(<\/details>)(?![^\S\r\n]*[\r\n])(?![^\S\r\n]*$)/g, '$1\n'), (str: string) => str.replace(/(<\/details>)(?![^\S\r\n]*[\r\n])(?![^\S\r\n]*$)/g, '$1\n'),
])(content) ])(content)
} }
/**
* Transforms a URI for use in react-markdown, ensuring security and compatibility.
* This function is designed to work with react-markdown v9+ which has stricter
* default URL handling.
*
* Behavior:
* 1. Always allows the custom 'abbr:' protocol.
* 2. Always allows page-local fragments (e.g., "#some-id").
* 3. Always allows protocol-relative URLs (e.g., "//example.com/path").
* 4. Always allows purely relative paths (e.g., "path/to/file", "/abs/path").
* 5. Allows absolute URLs if their scheme is in a permitted list (case-insensitive):
* 'http:', 'https:', 'mailto:', 'xmpp:', 'irc:', 'ircs:'.
* 6. Intelligently distinguishes colons used for schemes from colons within
* paths, query parameters, or fragments of relative-like URLs.
* 7. Returns the original URI if allowed, otherwise returns `undefined` to
* signal that the URI should be removed/disallowed by react-markdown.
*/
export const customUrlTransform = (uri: string): string | undefined => {
const PERMITTED_SCHEME_REGEX = /^(https?|ircs?|mailto|xmpp|abbr):$/i
if (uri.startsWith('#'))
return uri
if (uri.startsWith('//'))
return uri
const colonIndex = uri.indexOf(':')
if (colonIndex === -1)
return uri
const slashIndex = uri.indexOf('/')
const questionMarkIndex = uri.indexOf('?')
const hashIndex = uri.indexOf('#')
if (
(slashIndex !== -1 && colonIndex > slashIndex)
|| (questionMarkIndex !== -1 && colonIndex > questionMarkIndex)
|| (hashIndex !== -1 && colonIndex > hashIndex)
)
return uri
const scheme = uri.substring(0, colonIndex + 1).toLowerCase()
if (PERMITTED_SCHEME_REGEX.test(scheme))
return uri
return undefined
}

@ -487,15 +487,15 @@ const Flowchart = React.forwardRef((props: {
'bg-white': currentTheme === Theme.light, 'bg-white': currentTheme === Theme.light,
'bg-slate-900': currentTheme === Theme.dark, 'bg-slate-900': currentTheme === Theme.dark,
}), }),
mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', { mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', {
'bg-white': currentTheme === Theme.light, 'bg-white': currentTheme === Theme.light,
'bg-slate-900': currentTheme === Theme.dark, 'bg-slate-900': currentTheme === Theme.dark,
}), }),
errorMessage: cn('py-4 px-[26px]', { errorMessage: cn('px-[26px] py-4', {
'text-red-500': currentTheme === Theme.light, 'text-red-500': currentTheme === Theme.light,
'text-red-400': currentTheme === Theme.dark, 'text-red-400': currentTheme === Theme.dark,
}), }),
errorIcon: cn('w-6 h-6', { errorIcon: cn('h-6 w-6', {
'text-red-500': currentTheme === Theme.light, 'text-red-500': currentTheme === Theme.light,
'text-red-400': currentTheme === Theme.dark, 'text-red-400': currentTheme === Theme.dark,
}), }),
@ -503,7 +503,7 @@ const Flowchart = React.forwardRef((props: {
'text-gray-700': currentTheme === Theme.light, 'text-gray-700': currentTheme === Theme.light,
'text-gray-300': currentTheme === Theme.dark, 'text-gray-300': currentTheme === Theme.dark,
}), }),
themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', { themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light, 'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark, 'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
}), }),
@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {
// Style classes for look options // Style classes for look options
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => { const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
return cn( return cn(
'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary', 'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary', look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300', currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white', look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
@ -523,7 +523,7 @@ const Flowchart = React.forwardRef((props: {
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}> <div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
<div className={themeClasses.segmented}> <div className={themeClasses.segmented}>
<div className="msh-segmented-group"> <div className="msh-segmented-group">
<label className="msh-segmented-item flex items-center space-x-1 m-2 w-[200px]"> <label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1">
<div <div
key='classic' key='classic'
className={getLookButtonClass('classic')} className={getLookButtonClass('classic')}
@ -545,7 +545,7 @@ const Flowchart = React.forwardRef((props: {
<div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} /> <div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} />
{isLoading && !svgCode && ( {isLoading && !svgCode && (
<div className='py-4 px-[26px]'> <div className='px-[26px] py-4'>
<LoadingAnim type='text'/> <LoadingAnim type='text'/>
{!isCodeComplete && ( {!isCodeComplete && (
<div className="mt-2 text-sm text-gray-500"> <div className="mt-2 text-sm text-gray-500">
@ -557,7 +557,7 @@ const Flowchart = React.forwardRef((props: {
{svgCode && ( {svgCode && (
<div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}> <div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}>
<div className="absolute left-2 bottom-2 z-[100]"> <div className="absolute bottom-2 left-2 z-[100]">
<button <button
onClick={(e) => { onClick={(e) => {
e.stopPropagation() e.stopPropagation()

@ -274,7 +274,6 @@ const CreateFormPipeline = () => {
{datasource?.type === DatasourceType.websiteCrawl && ( {datasource?.type === DatasourceType.websiteCrawl && (
<WebsiteCrawl <WebsiteCrawl
nodeId={datasource?.nodeId || ''} nodeId={datasource?.nodeId || ''}
variables={[]} // todo: replace with actual variables if needed
headerInfo={{ headerInfo={{
title: datasource.description, title: datasource.description,
docTitle: datasource.docTitle || '', docTitle: datasource.docTitle || '',
@ -284,6 +283,7 @@ const CreateFormPipeline = () => {
onCheckedCrawlResultChange={setWebsitePages} onCheckedCrawlResultChange={setWebsitePages}
onJobIdChange={setWebsiteCrawlJobId} onJobIdChange={setWebsiteCrawlJobId}
onPreview={updateCurrentWebsite} onPreview={updateCurrentWebsite}
usingPublished
/> />
)} )}
{isShowVectorSpaceFull && ( {isShowVectorSpaceFull && (

@ -108,7 +108,7 @@ const PluginItem: FC<Props> = ({
}><RiErrorWarningLine color='red' className="ml-0.5 h-4 w-4 shrink-0 text-text-accent" /></Tooltip>} }><RiErrorWarningLine color='red' className="ml-0.5 h-4 w-4 shrink-0 text-text-accent" /></Tooltip>}
<Badge className='ml-1 shrink-0' <Badge className='ml-1 shrink-0'
text={source === PluginSource.github ? plugin.meta!.version : plugin.version} text={source === PluginSource.github ? plugin.meta!.version : plugin.version}
hasRedCornerMark={(source === PluginSource.marketplace) && !!plugin.latest_unique_identifier && plugin.latest_unique_identifier !== plugin_unique_identifier} hasRedCornerMark={(source === PluginSource.marketplace) && !!plugin.latest_version && plugin.latest_version !== plugin.version}
/> />
</div> </div>
<div className='flex items-center justify-between'> <div className='flex items-center justify-between'>

@ -51,11 +51,13 @@ const FieldItem = ({
className={cn( className={cn(
'flex h-8 cursor-pointer items-center justify-between gap-x-1 rounded-lg border border-components-panel-border-subtle bg-components-panel-on-panel-item-bg py-1 pl-2 shadow-xs hover:shadow-sm', 'flex h-8 cursor-pointer items-center justify-between gap-x-1 rounded-lg border border-components-panel-border-subtle bg-components-panel-on-panel-item-bg py-1 pl-2 shadow-xs hover:shadow-sm',
(isHovering && !readonly) ? 'pr-1' : 'pr-2.5', (isHovering && !readonly) ? 'pr-1' : 'pr-2.5',
readonly && 'cursor-default',
)} )}
onClick={handleOnClickEdit}
> >
<div className='flex grow basis-0 items-center gap-x-1'> <div className='flex grow basis-0 items-center gap-x-1'>
{ {
isHovering (isHovering && !readonly)
? <RiDraggable className='handle h-4 w-4 cursor-all-scroll text-text-quaternary' /> ? <RiDraggable className='handle h-4 w-4 cursor-all-scroll text-text-quaternary' />
: <InputField className='size-4 text-text-accent' /> : <InputField className='size-4 text-text-accent' />
} }

@ -1,5 +1,6 @@
import { import {
memo, memo,
useCallback,
useMemo, useMemo,
} from 'react' } from 'react'
import { ReactSortable } from 'react-sortablejs' import { ReactSortable } from 'react-sortablejs'
@ -7,6 +8,7 @@ import cn from '@/utils/classnames'
import type { InputVar } from '@/models/pipeline' import type { InputVar } from '@/models/pipeline'
import FieldItem from './field-item' import FieldItem from './field-item'
import type { SortableItem } from './types' import type { SortableItem } from './types'
import { isEqual } from 'lodash-es'
type FieldListContainerProps = { type FieldListContainerProps = {
className?: string className?: string
@ -33,11 +35,17 @@ const FieldListContainer = ({
}) })
}, [inputFields]) }, [inputFields])
const handleListSortChange = useCallback((newList: SortableItem[]) => {
if (isEqual(newList, list))
return
onListSortChange(newList)
}, [list, onListSortChange])
return ( return (
<ReactSortable<SortableItem> <ReactSortable<SortableItem>
className={cn(className)} className={cn(className)}
list={list} list={list}
setList={onListSortChange} setList={handleListSortChange}
handle='.handle' handle='.handle'
ghostClass='opacity-50' ghostClass='opacity-50'
group='rag-pipeline-input-field' group='rag-pipeline-input-field'

@ -44,9 +44,11 @@ export const useFieldList = (
const [editingField, setEditingField] = useState<InputVar | undefined>() const [editingField, setEditingField] = useState<InputVar | undefined>()
const [showInputFieldEditor, setShowInputFieldEditor] = useState(false) const [showInputFieldEditor, setShowInputFieldEditor] = useState(false)
const editingFieldIndex = useRef<number>(-1)
const handleOpenInputFieldEditor = useCallback((id?: string) => { const handleOpenInputFieldEditor = useCallback((id?: string) => {
const fieldToEdit = inputFieldsRef.current.find(field => field.variable === id) const index = inputFieldsRef.current.findIndex(field => field.variable === id)
setEditingField(fieldToEdit) editingFieldIndex.current = index
setEditingField(inputFieldsRef.current[index])
setShowInputFieldEditor(true) setShowInputFieldEditor(true)
}, []) }, [])
const handleCancelInputFieldEditor = useCallback(() => { const handleCancelInputFieldEditor = useCallback(() => {
@ -76,7 +78,7 @@ export const useFieldList = (
const handleSubmitField = useCallback((data: InputVar, moreInfo?: MoreInfo) => { const handleSubmitField = useCallback((data: InputVar, moreInfo?: MoreInfo) => {
const newInputFields = produce(inputFieldsRef.current, (draft) => { const newInputFields = produce(inputFieldsRef.current, (draft) => {
const currentIndex = draft.findIndex(field => field.variable === data.variable) const currentIndex = editingFieldIndex.current
if (currentIndex === -1) { if (currentIndex === -1) {
draft.push(data) draft.push(data)
return return

@ -17,7 +17,6 @@ import Datasource from './label-right-content/datasource'
import { useNodes } from 'reactflow' import { useNodes } from 'reactflow'
import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
// import produce from 'immer'
import { useNodesSyncDraft } from '@/app/components/workflow/hooks' import { useNodesSyncDraft } from '@/app/components/workflow/hooks'
import type { InputVar, RAGPipelineVariables } from '@/models/pipeline' import type { InputVar, RAGPipelineVariables } from '@/models/pipeline'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
@ -25,7 +24,6 @@ import Divider from '@/app/components/base/divider'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import PreviewPanel from './preview' import PreviewPanel from './preview'
import { useDebounceFn, useUnmount } from 'ahooks'
type InputFieldDialogProps = { type InputFieldDialogProps = {
readonly?: boolean readonly?: boolean
@ -55,17 +53,7 @@ const InputFieldDialog = ({
} }
const inputFieldsMap = useRef(getInputFieldsMap()) const inputFieldsMap = useRef(getInputFieldsMap())
const { doSyncWorkflowDraft } = useNodesSyncDraft() const { handleSyncWorkflowDraft } = useNodesSyncDraft()
useUnmount(async () => {
await doSyncWorkflowDraft()
})
const { run: syncWorkflowDraft } = useDebounceFn(() => {
doSyncWorkflowDraft()
}, {
wait: 500,
})
const datasourceNodeDataMap = useMemo(() => { const datasourceNodeDataMap = useMemo(() => {
const datasourceNodeDataMap: Record<string, DataSourceNodeType> = {} const datasourceNodeDataMap: Record<string, DataSourceNodeType> = {}
@ -77,7 +65,7 @@ const InputFieldDialog = ({
return datasourceNodeDataMap return datasourceNodeDataMap
}, [nodes]) }, [nodes])
const updateInputFields = useCallback((key: string, value: InputVar[]) => { const updateInputFields = useCallback(async (key: string, value: InputVar[]) => {
inputFieldsMap.current[key] = value inputFieldsMap.current[key] = value
const newRagPipelineVariables: RAGPipelineVariables = [] const newRagPipelineVariables: RAGPipelineVariables = []
Object.keys(inputFieldsMap.current).forEach((key) => { Object.keys(inputFieldsMap.current).forEach((key) => {
@ -90,8 +78,8 @@ const InputFieldDialog = ({
}) })
}) })
setRagPipelineVariables?.(newRagPipelineVariables) setRagPipelineVariables?.(newRagPipelineVariables)
syncWorkflowDraft() handleSyncWorkflowDraft()
}, [setRagPipelineVariables, syncWorkflowDraft]) }, [setRagPipelineVariables, handleSyncWorkflowDraft])
const closePanel = useCallback(() => { const closePanel = useCallback(() => {
setShowInputFieldDialog?.(false) setShowInputFieldDialog?.(false)

@ -3,17 +3,24 @@ import { useTranslation } from 'react-i18next'
import DataSourceOptions from '../../panel/test-run/data-source-options' import DataSourceOptions from '../../panel/test-run/data-source-options'
import Form from './form' import Form from './form'
import type { Datasource } from '../../panel/test-run/types' import type { Datasource } from '../../panel/test-run/types'
import { useStore } from '@/app/components/workflow/store'
import { useDraftPipelinePreProcessingParams } from '@/service/use-pipeline'
type DatasourceProps = { type DatasourceProps = {
onSelect: (dataSource: Datasource) => void onSelect: (dataSource: Datasource) => void
datasourceNodeId: string dataSourceNodeId: string
} }
const DataSource = ({ const DataSource = ({
onSelect: setDatasource, onSelect: setDatasource,
datasourceNodeId, dataSourceNodeId,
}: DatasourceProps) => { }: DatasourceProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const pipelineId = useStore(state => state.pipelineId)
const { data: paramsConfig } = useDraftPipelinePreProcessingParams({
pipeline_id: pipelineId!,
node_id: dataSourceNodeId,
}, !!pipelineId && !!dataSourceNodeId)
return ( return (
<div className='flex flex-col'> <div className='flex flex-col'>
@ -23,10 +30,10 @@ const DataSource = ({
<div className='px-4 py-2'> <div className='px-4 py-2'>
<DataSourceOptions <DataSourceOptions
onSelect={setDatasource} onSelect={setDatasource}
datasourceNodeId={datasourceNodeId} dataSourceNodeId={dataSourceNodeId}
/> />
</div> </div>
<Form variables={[]} /> <Form variables={paramsConfig?.variables || []} />
</div> </div>
) )
} }

@ -40,13 +40,15 @@ const PreviewPanel = ({
<RiCloseLine className='size-4 text-text-tertiary' /> <RiCloseLine className='size-4 text-text-tertiary' />
</button> </button>
</div> </div>
{/* Data source form Preview */}
<DataSource <DataSource
onSelect={setDatasource} onSelect={setDatasource}
datasourceNodeId={datasource?.nodeId || ''} dataSourceNodeId={datasource?.nodeId || ''}
/> />
<div className='px-4 py-2'> <div className='px-4 py-2'>
<Divider type='horizontal' className='bg-divider-subtle' /> <Divider type='horizontal' className='bg-divider-subtle' />
</div> </div>
{/* Process documents form Preview */}
<ProcessDocuments dataSourceNodeId={datasource?.nodeId || ''} /> <ProcessDocuments dataSourceNodeId={datasource?.nodeId || ''} />
</DialogWrapper> </DialogWrapper>
) )

@ -16,7 +16,7 @@ const ProcessDocuments = ({
const { data: paramsConfig } = useDraftPipelineProcessingParams({ const { data: paramsConfig } = useDraftPipelineProcessingParams({
pipeline_id: pipelineId!, pipeline_id: pipelineId!,
node_id: dataSourceNodeId, node_id: dataSourceNodeId,
}) }, !!pipelineId && !!dataSourceNodeId)
return ( return (
<div className='flex flex-col'> <div className='flex flex-col'>

@ -4,12 +4,12 @@ import OptionCard from './option-card'
import type { Datasource } from '../types' import type { Datasource } from '../types'
type DataSourceOptionsProps = { type DataSourceOptionsProps = {
datasourceNodeId: string dataSourceNodeId: string
onSelect: (option: Datasource) => void onSelect: (option: Datasource) => void
} }
const DataSourceOptions = ({ const DataSourceOptions = ({
datasourceNodeId, dataSourceNodeId,
onSelect, onSelect,
}: DataSourceOptionsProps) => { }: DataSourceOptionsProps) => {
const { datasources, options } = useDatasourceOptions() const { datasources, options } = useDatasourceOptions()
@ -34,7 +34,7 @@ const DataSourceOptions = ({
key={option.value} key={option.value}
label={option.label} label={option.label}
nodeData={option.data} nodeData={option.data}
selected={datasourceNodeId === option.value} selected={dataSourceNodeId === option.value}
onClick={handelSelect.bind(null, option.value)} onClick={handelSelect.bind(null, option.value)}
/> />
))} ))}

@ -1,5 +1,5 @@
'use client' 'use client'
import React, { useCallback, useEffect, useState } from 'react' import React, { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import type { CrawlResultItem } from '@/models/datasets' import type { CrawlResultItem } from '@/models/datasets'
import Header from '@/app/components/datasets/create/website/base/header' import Header from '@/app/components/datasets/create/website/base/header'
@ -7,15 +7,17 @@ import Options from './options'
import Crawling from './crawling' import Crawling from './crawling'
import ErrorMessage from './error-message' import ErrorMessage from './error-message'
import CrawledResult from './crawled-result' import CrawledResult from './crawled-result'
import type { RAGPipelineVariables } from '@/models/pipeline' import {
import { useDatasourceNodeRun } from '@/service/use-pipeline' useDatasourceNodeRun,
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' useDraftPipelinePreProcessingParams,
usePublishedPipelineProcessingParams,
} from '@/service/use-pipeline'
import { useStore } from '@/app/components/workflow/store'
const I18N_PREFIX = 'datasetCreation.stepOne.website' const I18N_PREFIX = 'datasetCreation.stepOne.website'
type CrawlerProps = { type CrawlerProps = {
nodeId: string nodeId: string
variables: RAGPipelineVariables
checkedCrawlResult: CrawlResultItem[] checkedCrawlResult: CrawlResultItem[]
onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void
onJobIdChange: (jobId: string) => void onJobIdChange: (jobId: string) => void
@ -25,6 +27,7 @@ type CrawlerProps = {
docLink: string docLink: string
} }
onPreview?: (payload: CrawlResultItem) => void onPreview?: (payload: CrawlResultItem) => void
usingPublished?: boolean
} }
enum Step { enum Step {
@ -35,17 +38,23 @@ enum Step {
const Crawler = ({ const Crawler = ({
nodeId, nodeId,
variables,
checkedCrawlResult, checkedCrawlResult,
headerInfo, headerInfo,
onCheckedCrawlResultChange, onCheckedCrawlResultChange,
onJobIdChange, onJobIdChange,
onPreview, onPreview,
usingPublished = false,
}: CrawlerProps) => { }: CrawlerProps) => {
const { t } = useTranslation() const { t } = useTranslation()
const [step, setStep] = useState<Step>(Step.init) const [step, setStep] = useState<Step>(Step.init)
const [controlFoldOptions, setControlFoldOptions] = useState<number>(0) const [controlFoldOptions, setControlFoldOptions] = useState<number>(0)
const pipelineId = useDatasetDetailContextWithSelector(s => s.dataset?.pipeline_id) const pipelineId = useStore(s => s.pipelineId)
const usePreProcessingParams = useRef(usingPublished ? usePublishedPipelineProcessingParams : useDraftPipelinePreProcessingParams)
const { data: paramsConfig } = usePreProcessingParams.current({
pipeline_id: pipelineId!,
node_id: nodeId,
}, !!pipelineId && !!nodeId)
useEffect(() => { useEffect(() => {
if (step !== Step.init) if (step !== Step.init)
@ -95,7 +104,7 @@ const Crawler = ({
/> />
<div className='mt-2 rounded-xl border border-components-panel-border bg-background-default-subtle'> <div className='mt-2 rounded-xl border border-components-panel-border bg-background-default-subtle'>
<Options <Options
variables={variables} variables={paramsConfig?.variables || []}
isRunning={isRunning} isRunning={isRunning}
controlFoldOptions={controlFoldOptions} controlFoldOptions={controlFoldOptions}
onSubmit={(value) => { onSubmit={(value) => {

@ -1,12 +1,10 @@
'use client' 'use client'
import React from 'react' import React from 'react'
import type { CrawlResultItem } from '@/models/datasets' import type { CrawlResultItem } from '@/models/datasets'
import type { RAGPipelineVariables } from '@/models/pipeline'
import Crawler from './base/crawler' import Crawler from './base/crawler'
type WebsiteCrawlProps = { type WebsiteCrawlProps = {
nodeId: string nodeId: string
variables: RAGPipelineVariables
checkedCrawlResult: CrawlResultItem[] checkedCrawlResult: CrawlResultItem[]
onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void onCheckedCrawlResultChange: (payload: CrawlResultItem[]) => void
onJobIdChange: (jobId: string) => void onJobIdChange: (jobId: string) => void
@ -16,26 +14,27 @@ type WebsiteCrawlProps = {
docLink: string docLink: string
} }
onPreview?: (payload: CrawlResultItem) => void onPreview?: (payload: CrawlResultItem) => void
usingPublished?: boolean
} }
const WebsiteCrawl = ({ const WebsiteCrawl = ({
nodeId, nodeId,
variables,
checkedCrawlResult, checkedCrawlResult,
headerInfo, headerInfo,
onCheckedCrawlResultChange, onCheckedCrawlResultChange,
onJobIdChange, onJobIdChange,
onPreview, onPreview,
usingPublished,
}: WebsiteCrawlProps) => { }: WebsiteCrawlProps) => {
return ( return (
<Crawler <Crawler
nodeId={nodeId} nodeId={nodeId}
variables={variables}
checkedCrawlResult={checkedCrawlResult} checkedCrawlResult={checkedCrawlResult}
headerInfo={headerInfo} headerInfo={headerInfo}
onCheckedCrawlResultChange={onCheckedCrawlResultChange} onCheckedCrawlResultChange={onCheckedCrawlResultChange}
onJobIdChange={onJobIdChange} onJobIdChange={onJobIdChange}
onPreview={onPreview} onPreview={onPreview}
usingPublished={usingPublished}
/> />
) )
} }

@ -117,7 +117,7 @@ const TestRunPanel = () => {
<> <>
<div className='flex flex-col gap-y-4 px-4 py-2'> <div className='flex flex-col gap-y-4 px-4 py-2'>
<DataSourceOptions <DataSourceOptions
datasourceNodeId={datasource?.nodeId || ''} dataSourceNodeId={datasource?.nodeId || ''}
onSelect={setDatasource} onSelect={setDatasource}
/> />
{datasource?.type === DatasourceType.localFile && ( {datasource?.type === DatasourceType.localFile && (
@ -139,7 +139,6 @@ const TestRunPanel = () => {
{datasource?.type === DatasourceType.websiteCrawl && ( {datasource?.type === DatasourceType.websiteCrawl && (
<WebsiteCrawl <WebsiteCrawl
nodeId={datasource?.nodeId || ''} nodeId={datasource?.nodeId || ''}
variables={[]} // todo: replace with actual variables if needed
checkedCrawlResult={websitePages} checkedCrawlResult={websitePages}
headerInfo={{ headerInfo={{
title: datasource.description, title: datasource.description,

@ -280,6 +280,7 @@ const translation = {
'inputPlaceholder': 'Por favor ingresa', 'inputPlaceholder': 'Por favor ingresa',
'content': 'Contenido', 'content': 'Contenido',
'required': 'Requerido', 'required': 'Requerido',
'hide': 'Ocultar',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Nombre de la variable es requerido', varNameRequired: 'Nombre de la variable es requerido',
labelNameRequired: 'Nombre de la etiqueta es requerido', labelNameRequired: 'Nombre de la etiqueta es requerido',

@ -315,6 +315,7 @@ const translation = {
'inputPlaceholder': 'لطفاً وارد کنید', 'inputPlaceholder': 'لطفاً وارد کنید',
'content': 'محتوا', 'content': 'محتوا',
'required': 'مورد نیاز', 'required': 'مورد نیاز',
'hide': 'مخفی کردن',
'errorMsg': { 'errorMsg': {
varNameRequired: 'نام متغیر مورد نیاز است', varNameRequired: 'نام متغیر مورد نیاز است',
labelNameRequired: 'نام برچسب مورد نیاز است', labelNameRequired: 'نام برچسب مورد نیاز است',

@ -268,6 +268,7 @@ const translation = {
'labelName': 'Label Name', 'labelName': 'Label Name',
'inputPlaceholder': 'Please input', 'inputPlaceholder': 'Please input',
'required': 'Required', 'required': 'Required',
'hide': 'Caché',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Variable name is required', varNameRequired: 'Variable name is required',
labelNameRequired: 'Label name is required', labelNameRequired: 'Label name is required',

@ -312,6 +312,7 @@ const translation = {
'inputPlaceholder': 'कृपया इनपुट करें', 'inputPlaceholder': 'कृपया इनपुट करें',
'content': 'सामग्री', 'content': 'सामग्री',
'required': 'आवश्यक', 'required': 'आवश्यक',
'hide': 'छुपाएँ',
'errorMsg': { 'errorMsg': {
varNameRequired: 'वेरिएबल नाम आवश्यक है', varNameRequired: 'वेरिएबल नाम आवश्यक है',
labelNameRequired: 'लेबल नाम आवश्यक है', labelNameRequired: 'लेबल नाम आवश्यक है',

@ -314,6 +314,7 @@ const translation = {
'inputPlaceholder': 'Per favore inserisci', 'inputPlaceholder': 'Per favore inserisci',
'content': 'Contenuto', 'content': 'Contenuto',
'required': 'Richiesto', 'required': 'Richiesto',
'hide': 'Nascondi',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Il nome della variabile è richiesto', varNameRequired: 'Il nome della variabile è richiesto',
labelNameRequired: 'Il nome dell\'etichetta è richiesto', labelNameRequired: 'Il nome dell\'etichetta è richiesto',

@ -359,6 +359,7 @@ const translation = {
'labelName': 'ラベル名', 'labelName': 'ラベル名',
'inputPlaceholder': '入力してください', 'inputPlaceholder': '入力してください',
'required': '必須', 'required': '必須',
'hide': '非表示',
'file': { 'file': {
supportFileTypes: 'サポートされたファイルタイプ', supportFileTypes: 'サポートされたファイルタイプ',
image: { image: {

@ -279,6 +279,7 @@ const translation = {
'labelName': '레이블명', 'labelName': '레이블명',
'inputPlaceholder': '입력하세요', 'inputPlaceholder': '입력하세요',
'required': '필수', 'required': '필수',
'hide': '숨기기',
'errorMsg': { 'errorMsg': {
varNameRequired: '변수명은 필수입니다', varNameRequired: '변수명은 필수입니다',
labelNameRequired: '레이블명은 필수입니다', labelNameRequired: '레이블명은 필수입니다',

@ -309,6 +309,7 @@ const translation = {
'labelName': 'Nazwa etykiety', 'labelName': 'Nazwa etykiety',
'inputPlaceholder': 'Proszę wpisać', 'inputPlaceholder': 'Proszę wpisać',
'required': 'Wymagane', 'required': 'Wymagane',
'hide': 'Ukryj',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Wymagana nazwa zmiennej', varNameRequired: 'Wymagana nazwa zmiennej',
labelNameRequired: 'Wymagana nazwa etykiety', labelNameRequired: 'Wymagana nazwa etykiety',

@ -285,6 +285,7 @@ const translation = {
'labelName': 'Nome do Rótulo', 'labelName': 'Nome do Rótulo',
'inputPlaceholder': 'Por favor, insira', 'inputPlaceholder': 'Por favor, insira',
'required': 'Obrigatório', 'required': 'Obrigatório',
'hide': 'Ocultar',
'errorMsg': { 'errorMsg': {
varNameRequired: 'O nome da variável é obrigatório', varNameRequired: 'O nome da variável é obrigatório',
labelNameRequired: 'O nome do rótulo é obrigatório', labelNameRequired: 'O nome do rótulo é obrigatório',

@ -285,6 +285,7 @@ const translation = {
'labelName': 'Nume etichetă', 'labelName': 'Nume etichetă',
'inputPlaceholder': 'Vă rugăm să introduceți', 'inputPlaceholder': 'Vă rugăm să introduceți',
'required': 'Obligatoriu', 'required': 'Obligatoriu',
'hide': 'Ascundeți',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Numele variabilei este obligatoriu', varNameRequired: 'Numele variabilei este obligatoriu',
labelNameRequired: 'Numele etichetei este obligatoriu', labelNameRequired: 'Numele etichetei este obligatoriu',

@ -322,6 +322,7 @@ const translation = {
'inputPlaceholder': 'Пожалуйста, введите', 'inputPlaceholder': 'Пожалуйста, введите',
'content': 'Содержимое', 'content': 'Содержимое',
'required': 'Обязательно', 'required': 'Обязательно',
'hide': 'Скрыть',
'errorMsg': { 'errorMsg': {
labelNameRequired: 'Имя метки обязательно', labelNameRequired: 'Имя метки обязательно',
varNameCanBeRepeat: 'Имя переменной не может повторяться', varNameCanBeRepeat: 'Имя переменной не может повторяться',

@ -239,4 +239,4 @@ const translation = {
}, },
} }
module.exports = translation export default translation

@ -279,6 +279,7 @@ const translation = {
'labelName': 'Назва мітки', 'labelName': 'Назва мітки',
'inputPlaceholder': 'Будь ласка, введіть', 'inputPlaceholder': 'Будь ласка, введіть',
'required': 'Обов\'язково', 'required': 'Обов\'язково',
'hide': 'Приховати',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Потрібно вказати назву змінної', varNameRequired: 'Потрібно вказати назву змінної',
labelNameRequired: 'Потрібно вказати назву мітки', labelNameRequired: 'Потрібно вказати назву мітки',

@ -279,6 +279,7 @@ const translation = {
'labelName': 'Tên nhãn', 'labelName': 'Tên nhãn',
'inputPlaceholder': 'Vui lòng nhập', 'inputPlaceholder': 'Vui lòng nhập',
'required': 'Bắt buộc', 'required': 'Bắt buộc',
'hide': 'Ẩn',
'errorMsg': { 'errorMsg': {
varNameRequired: 'Tên biến là bắt buộc', varNameRequired: 'Tên biến là bắt buộc',
labelNameRequired: 'Tên nhãn là bắt buộc', labelNameRequired: 'Tên nhãn là bắt buộc',

@ -361,6 +361,7 @@ const translation = {
'inputPlaceholder': '请输入', 'inputPlaceholder': '请输入',
'labelName': '显示名称', 'labelName': '显示名称',
'required': '必填', 'required': '必填',
'hide': '隐藏',
'placeholder': '占位符', 'placeholder': '占位符',
'placeholderPlaceholder': '输入字段为空时显示的文本', 'placeholderPlaceholder': '输入字段为空时显示的文本',
'defaultValue': '默认值', 'defaultValue': '默认值',

@ -264,6 +264,7 @@ const translation = {
'inputPlaceholder': '請輸入', 'inputPlaceholder': '請輸入',
'labelName': '顯示名稱', 'labelName': '顯示名稱',
'required': '必填', 'required': '必填',
'hide': '隱藏',
'errorMsg': { 'errorMsg': {
varNameRequired: '變數名稱必填', varNameRequired: '變數名稱必填',
labelNameRequired: '顯示名稱必填', labelNameRequired: '顯示名稱必填',

@ -142,6 +142,15 @@ export type PipelineProcessingParamsResponse = {
variables: RAGPipelineVariables variables: RAGPipelineVariables
} }
export type PipelinePreProcessingParamsRequest = {
pipeline_id: string
node_id: string
}
export type PipelinePreProcessingParamsResponse = {
variables: RAGPipelineVariables
}
export type PipelineDatasourceNodeRunRequest = { export type PipelineDatasourceNodeRunRequest = {
pipeline_id: string pipeline_id: string
node_id: string node_id: string

@ -10,6 +10,8 @@ import type {
PipelineCheckDependenciesResponse, PipelineCheckDependenciesResponse,
PipelineDatasourceNodeRunRequest, PipelineDatasourceNodeRunRequest,
PipelineDatasourceNodeRunResponse, PipelineDatasourceNodeRunResponse,
PipelinePreProcessingParamsRequest,
PipelinePreProcessingParamsResponse,
PipelineProcessingParamsRequest, PipelineProcessingParamsRequest,
PipelineProcessingParamsResponse, PipelineProcessingParamsResponse,
PipelineTemplateByIdResponse, PipelineTemplateByIdResponse,
@ -136,10 +138,10 @@ export const useDatasourceNodeRun = (
}) })
} }
export const useDraftPipelineProcessingParams = (params: PipelineProcessingParamsRequest) => { export const useDraftPipelineProcessingParams = (params: PipelineProcessingParamsRequest, enabled = true) => {
const { pipeline_id, node_id } = params const { pipeline_id, node_id } = params
return useQuery<PipelineProcessingParamsResponse>({ return useQuery<PipelineProcessingParamsResponse>({
queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id], queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id, node_id],
queryFn: () => { queryFn: () => {
return get<PipelineProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/draft/processing/parameters`, { return get<PipelineProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/draft/processing/parameters`, {
params: { params: {
@ -148,14 +150,14 @@ export const useDraftPipelineProcessingParams = (params: PipelineProcessingParam
}) })
}, },
staleTime: 0, staleTime: 0,
enabled: !!pipeline_id && !!node_id, enabled,
}) })
} }
export const usePublishedPipelineProcessingParams = (params: PipelineProcessingParamsRequest) => { export const usePublishedPipelineProcessingParams = (params: PipelineProcessingParamsRequest) => {
const { pipeline_id, node_id } = params const { pipeline_id, node_id } = params
return useQuery<PipelineProcessingParamsResponse>({ return useQuery<PipelineProcessingParamsResponse>({
queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id], queryKey: [NAME_SPACE, 'pipeline-processing-params', pipeline_id, node_id],
queryFn: () => { queryFn: () => {
return get<PipelineProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/published/processing/parameters`, { return get<PipelineProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/published/processing/parameters`, {
params: { params: {
@ -163,6 +165,7 @@ export const usePublishedPipelineProcessingParams = (params: PipelineProcessingP
}, },
}) })
}, },
staleTime: 0,
}) })
} }
@ -248,3 +251,35 @@ export const useUpdateDataSourceCredentials = (
}, },
}) })
} }
export const useDraftPipelinePreProcessingParams = (params: PipelinePreProcessingParamsRequest, enabled = true) => {
const { pipeline_id, node_id } = params
return useQuery<PipelinePreProcessingParamsResponse>({
queryKey: [NAME_SPACE, 'pipeline-pre-processing-params', pipeline_id, node_id],
queryFn: () => {
return get<PipelinePreProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/draft/pre-processing/parameters`, {
params: {
node_id,
},
})
},
staleTime: 0,
enabled,
})
}
export const usePublishedPipelinePreProcessingParams = (params: PipelinePreProcessingParamsRequest, enabled = true) => {
const { pipeline_id, node_id } = params
return useQuery<PipelinePreProcessingParamsResponse>({
queryKey: [NAME_SPACE, 'pipeline-pre-processing-params', pipeline_id, node_id],
queryFn: () => {
return get<PipelinePreProcessingParamsResponse>(`/rag/pipelines/${pipeline_id}/workflows/published/processing/parameters`, {
params: {
node_id,
},
})
},
staleTime: 0,
enabled,
})
}

Loading…
Cancel
Save