diff --git a/api/.env.example b/api/.env.example index ae7e82c779..7878308588 100644 --- a/api/.env.example +++ b/api/.env.example @@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000 # Prevent Clickjacking ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 1b015b3267..2dcf1710b0 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -2,7 +2,7 @@ import os from typing import Any, Literal, Optional from urllib.parse import parse_qsl, quote_plus -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field +from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from .cache.redis_config import RedisConfig @@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings): ) +class DatasetQueueMonitorConfig(BaseSettings): + """ + Configuration settings for Dataset Queue Monitor + """ + + QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + description="Threshold for dataset queue monitor", + default=200, + ) + QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + description="Emails for dataset queue monitor alert, separated by commas", + default=None, + ) + QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + description="Interval for dataset queue monitor in minutes", + default=30, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -303,5 +322,6 @@ class MiddlewareConfig( BaiduVectorDBConfig, OpenGaussConfig, TableStoreConfig, + DatasetQueueMonitorConfig, ): pass diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c100f53078..27e8dd3fa6 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -369,6 +369,7 @@ class DatasetTagsApi(DatasetApiResource): ) parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) args = parser.parse_args() + args["type"] = "knowledge" tag = TagService.update_tags(args, args.get("tag_id")) binding_count = TagService.get_tag_binding_count(args.get("tag_id")) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 418363ffbb..ab7ab4dcf0 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") - if not dataset.indexing_technique and not args.get("indexing_technique"): + + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique + if not indexing_technique: raise ValueError("indexing_technique is required.") + args["indexing_technique"] = indexing_technique # save file info file = request.files["file"] @@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource): knowledge_config = KnowledgeConfig(**args) DocumentService.document_create_args_validate(knowledge_config) + dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None + if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule: + raise ValueError("process_rule is required.") + try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, knowledge_config=knowledge_config, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + dataset_process_rule=dataset_process_rule, created_from="api", ) except ProviderTokenNotInitError as ex: diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 5017835565..e1c021a44a 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): status: ModelStatus load_balancing_enabled: bool = False + def raise_for_status(self) -> None: + """ + Check model status and raise ValueError if not active. + + :raises ValueError: When model status is not active, with a descriptive message + """ + if self.status == ModelStatus.ACTIVE: + return + + error_messages = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + if self.status in error_messages: + raise ValueError(error_messages[self.status]) + class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 231743bf2a..06fdb089d4 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -41,45 +41,53 @@ class Extensible: extensions = [] position_map: dict[str, int] = {} - # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") - current_dir_path = os.path.dirname(current_path) - - # traverse subdirectories - for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith("__"): - continue - - subdir_path = os.path.join(current_dir_path, subdir_name) - extension_name = subdir_name - if os.path.isdir(subdir_path): + # Get the package name from the module path + package_name = ".".join(cls.__module__.split(".")[:-1]) + + try: + # Get package directory path + package_spec = importlib.util.find_spec(package_name) + if not package_spec or not package_spec.origin: + raise ImportError(f"Could not find package {package_name}") + + package_dir = os.path.dirname(package_spec.origin) + + # Traverse subdirectories + for subdir_name in os.listdir(package_dir): + if subdir_name.startswith("__"): + continue + + subdir_path = os.path.join(package_dir, subdir_name) + if not os.path.isdir(subdir_path): + continue + + extension_name = subdir_name file_names = os.listdir(subdir_path) - # is builtin extension, builtin extension - # in the front-end page and business logic, there are special treatments. + # Check for extension module file + if (extension_name + ".py") not in file_names: + logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + continue + + # Check for builtin flag and position builtin = False - # default position is 0 can not be None for sort_to_dict_by_position_map position = 0 if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position_map[extension_name] = position - if (extension_name + ".py") not in file_names: - logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") - continue - - # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + ".py") - spec = importlib.util.spec_from_file_location(extension_name, py_path) + # Import the extension module + module_name = f"{package_name}.{extension_name}.{extension_name}" + spec = importlib.util.find_spec(module_name) if not spec or not spec.loader: - raise Exception(f"Failed to load module {extension_name} from {py_path}") + raise ImportError(f"Failed to load module {module_name}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) + # Find extension class extension_class = None for name, obj in vars(mod).items(): if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: @@ -87,21 +95,21 @@ class Extensible: break if not extension_class: - logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") + logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") continue + # Load schema if not builtin json_data: dict[str, Any] = {} if not builtin: - if "schema.json" not in file_names: + json_path = os.path.join(subdir_path, "schema.json") + if not os.path.exists(json_path): logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, "schema.json") - json_data = {} - if os.path.exists(json_path): - with open(json_path, encoding="utf-8") as f: - json_data = json.load(f) + with open(json_path, encoding="utf-8") as f: + json_data = json.load(f) + # Create extension extensions.append( ModuleExtension( extension_class=extension_class, @@ -113,6 +121,11 @@ class Extensible: ) ) + except Exception as e: + logging.exception("Error scanning extensions") + raise + + # Sort extensions by position sorted_extensions = sort_to_dict_by_position_map( position_map=position_map, data=extensions, name_func=lambda x: x.name ) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 373ef2bbe2..568149cc37 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): deprecated: bool = False model_config = ConfigDict(protected_namespaces=()) + @property + def support_structure_output(self) -> bool: + return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features + class ParameterRule(BaseModel): """ diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7570200175..488a394679 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. - - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod @@ -416,17 +412,12 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - # Get all provider model records of the workspace - provider_models = ( - db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - .all() - ) - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod @@ -437,17 +428,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = ( - db.session.query(TenantPreferredModelProvider) - .filter(TenantPreferredModelProvider.tenant_id == tenant_id) - .all() - ) - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - + provider_name_to_preferred_provider_type_records_dict = {} + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod @@ -458,18 +446,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = ( - db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() - ) - provider_name_to_provider_model_settings_dict = defaultdict(list) - for provider_model_setting in provider_model_settings: - ( + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_model_setting ) - ) - return provider_name_to_provider_model_settings_dict @staticmethod @@ -492,15 +476,14 @@ class ProviderManager: if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() - ) - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - for provider_load_balancing_config in provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict @@ -626,10 +609,9 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if ( - custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{") - ): + if custom_provider_record.encrypted_config is None: + raise ValueError("No credentials found") + if not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -733,7 +715,7 @@ class ProviderManager: return SystemConfiguration(enabled=False) # Convert provider_records to dict - quota_type_to_provider_records_dict = {} + quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -758,6 +740,11 @@ class ProviderManager: else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] + if provider_record.quota_used is None: + raise ValueError("quota_used is None") + if provider_record.quota_limit is None: + raise ValueError("quota_limit is None") + quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -791,10 +778,9 @@ class ProviderManager: cached_provider_credentials = provider_credentials_cache.get() if not cached_provider_credentials: - try: - provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} + if provider_records and provider_records[0].encrypted_config: + provider_credentials = json.loads(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index 9abe78d6ef..54b65d9a2d 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -720,7 +720,7 @@ STOPWORDS = { "〉", "〈", "…", - " ", + " ", "0", "1", "2", @@ -731,16 +731,6 @@ STOPWORDS = { "7", "8", "9", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", "二", "三", "四", diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 0a3738ac93..6b9dd9c561 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -261,7 +261,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名 current_entity += word else: if current_entity: diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 486b4b01af..36d0688807 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: dict | None = None - structured_output_enabled: bool = False + # We used 'structured_output_enabled' in the past, but it's not a good name. + structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") @field_validator("prompt_config", mode="before") @classmethod @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): if v is None: return PromptConfig() return v + + @property + def structured_output_enabled(self) -> bool: + return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index df8f614db3..ee181cf3bf 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -12,9 +12,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory @@ -74,7 +72,6 @@ from core.workflow.nodes.event import ( from core.workflow.utils.structured_output.entities import ( ResponseFormat, SpecialModelType, - SupportStructuredOutputStatus, ) from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]): llm_usage=usage, ) ) - except LLMNodeError as e: + except ValueError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = node_data_model.name - provider_name = node_data_model.provider + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=node_data_model.provider, + model=node_data_model.name, ) - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials + model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM + provider_model = model.provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, model_type=ModelType.LLM ) if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() # model config - completion_params = node_data_model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise LLMModeRequiredError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + if self.node_data.structured_output_enabled: + if model_schema.support_structure_output: + node_data_model.completion_params = self._handle_native_json_schema( + node_data_model.completion_params, model_schema.parameter_rules + ) + else: + # Set appropriate response format based on model capabilities + self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules) + + return model, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, + mode=node_data_model.mode, + provider_model_bundle=model.provider_model_bundle, + credentials=model.credentials, + parameters=node_data_model.completion_params, stop=stop, ) @@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) - stop = model_config.stop - return filtered_prompt_messages, stop + + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=self.node_data.model.provider, + model=self.node_data.model.name, + ) + model_schema = model.model_type_instance.get_model_schema( + model=self.node_data.model.name, + credentials=model.credentials, + ) + if not model_schema: + raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.") + if self.node_data.structured_output_enabled: + if not model_schema.support_structure_output: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) + return filtered_prompt_messages, model_config.stop def _parse_structured_output(self, result_text: str) -> dict[str, Any]: structured_output: dict[str, Any] = {} @@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] + variable_mapping["#files#"] = node_data.vision.configs.variable_selector if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] @@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") - def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: - """ - Check if the current model supports structured output. - - Returns: - SupportStructuredOutput: The support status of structured output - """ - # Early return if structured output is disabled - if ( - not isinstance(self.node_data, LLMNodeData) - or not self.node_data.structured_output_enabled - or not self.node_data.structured_output - ): - return SupportStructuredOutputStatus.DISABLED - # Get model schema and check if it exists - model_schema = self._fetch_model_schema(self.node_data.model.provider) - if not model_schema: - return SupportStructuredOutputStatus.DISABLED - - # Check if model supports structured output feature - return ( - SupportStructuredOutputStatus.SUPPORTED - if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) - else SupportStructuredOutputStatus.UNSUPPORTED - ) - def _save_multimodal_output_and_convert_result_to_markdown( self, contents: str | list[PromptMessageContentUnionTypes] | None, diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py index 7954acbaee..6491042bfe 100644 --- a/api/core/workflow/utils/structured_output/entities.py +++ b/api/core/workflow/utils/structured_output/entities.py @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): GEMINI = "gemini" OLLAMA = "ollama" - - -class SupportStructuredOutputStatus(StrEnum): - """Constants for structured output support status""" - - SUPPORTED = "supported" - UNSUPPORTED = "unsupported" - DISABLED = "disabled" diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 26bd6b3577..a837552007 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery: "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", "schedule.mail_clean_document_notify_task", + "schedule.queue_monitor_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", "schedule": crontab(minute="0", hour="10", day_of_week="1"), }, + "datasets-queue-monitor": { + "task": "schedule.queue_monitor_task.queue_monitor_task", + "schedule": timedelta( + minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 + ), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/libs/helper.py b/api/libs/helper.py index afc8f31681..463ba3308b 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -18,6 +18,7 @@ from flask_restful import fields from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.file import helpers as file_helpers +from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: @@ -196,7 +197,7 @@ def generate_text_hash(text: str) -> str: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype="application/json") + return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json") else: def generate() -> Generator: diff --git a/api/models/provider.py b/api/models/provider.py index 497cbefc61..1e25f0c90f 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,6 +1,9 @@ +from datetime import datetime from enum import Enum +from typing import Optional -from sqlalchemy import func +from sqlalchemy import func, text +from sqlalchemy.orm import Mapped, mapped_column from .base import Base from .engine import db @@ -51,20 +54,24 @@ class Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used = db.Column(db.DateTime, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_type: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'custom'::character varying") + ) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) + quota_type: Mapped[Optional[str]] = mapped_column( + db.String(40), nullable=True, server_default=text("''::character varying") + ) + quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -104,15 +111,15 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -122,13 +129,13 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -153,22 +160,24 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - account_id = db.Column(StringUUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + payment_status: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + ) + paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -182,15 +191,15 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py new file mode 100644 index 0000000000..e3a7021b9d --- /dev/null +++ b/api/schedule/queue_monitor_task.py @@ -0,0 +1,62 @@ +import logging +from datetime import datetime +from urllib.parse import urlparse + +import click +from flask import render_template +from redis import Redis + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail + +# Create a dedicated Redis connection (using the same configuration as Celery) +celery_broker_url = dify_config.CELERY_BROKER_URL + +parsed = urlparse(celery_broker_url) +host = parsed.hostname or "localhost" +port = parsed.port or 6379 +password = parsed.password or None +redis_db = parsed.path.strip("/") or "1" # type: ignore + +celery_redis = Redis(host=host, port=port, password=password, db=redis_db) + + +@app.celery.task(queue="monitor") +def queue_monitor_task(): + queue_name = "dataset" + threshold = dify_config.QUEUE_MONITOR_THRESHOLD + + try: + queue_length = celery_redis.llen(f"{queue_name}") + logging.info(click.style(f"Start monitor {queue_name}", fg="green")) + logging.info(click.style(f"Queue length: {queue_length}", fg="green")) + + if queue_length >= threshold: + warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}" + logging.warning(click.style(warning_msg, fg="red")) + alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS + if alter_emails: + to_list = alter_emails.split(",") + for to in to_list: + try: + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + html_content = render_template( + "queue_monitor_alert_email_template_en-US.html", + queue_name=queue_name, + queue_length=queue_length, + threshold=threshold, + alert_time=current_time, + ) + mail.send( + to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content + ) + except Exception as e: + logging.exception(click.style("Exception occurred during sending email", fg="red")) + + except Exception as e: + logging.exception(click.style("Exception occurred during queue monitoring", fg="red")) + finally: + if db.session.is_active: + db.session.close() diff --git a/api/services/tag_service.py b/api/services/tag_service.py index be748e8dd1..74c6150b44 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -46,6 +46,8 @@ class TagService: @staticmethod def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list: + if not tag_type or not tag_name: + return [] tags = ( db.session.query(Tag) .filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) @@ -88,7 +90,7 @@ class TagService: @staticmethod def update_tags(args: dict, tag_id: str) -> Tag: - if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]): + if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): raise ValueError("Tag name already exists") tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index f32bc4f187..51b6343fdc 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -5,7 +5,7 @@ import uuid import click from celery import shared_task # type: ignore -from sqlalchemy import func, select +from sqlalchemy import func from sqlalchemy.orm import Session from core.model_manager import ModelManager @@ -68,11 +68,6 @@ def batch_create_segment_to_index_task( model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - word_count_change = 0 - segments_to_insert: list[str] = [] - max_position_stmt = select(func.max(DocumentSegment.position)).where( - DocumentSegment.document_id == dataset_document.id - ) word_count_change = 0 if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens( diff --git a/api/templates/queue_monitor_alert_email_template_en-US.html b/api/templates/queue_monitor_alert_email_template_en-US.html new file mode 100644 index 0000000000..2885210864 --- /dev/null +++ b/api/templates/queue_monitor_alert_email_template_en-US.html @@ -0,0 +1,129 @@ + + + + + + + + +
+
+ Dify Logo +
+

Queue Monitoring Alert

+

Our system has detected an abnormal queue status that requires your attention:

+ +
+
Queue Task Alert
+
+ Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}}) +
+
+ +
+

Recommended actions:

+

1. Check the queue processing status in the system dashboard

+

2. Verify if there are any processing bottlenecks

+

3. Consider scaling up workers if needed

+
+ +

Additional Information:

+ +
+ + + diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 5fbee266bd..6aa48b1cbb 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -3,11 +3,16 @@ import os import time import uuid from collections.abc import Generator -from unittest.mock import MagicMock +from decimal import Decimal +from unittest.mock import MagicMock, patch import pytest +from app_factory import create_app +from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey @@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType -from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +@pytest.fixture(scope="session") +def app(): + # Set up storage configuration + os.environ["STORAGE_TYPE"] = "opendal" + os.environ["OPENDAL_SCHEME"] = "fs" + os.environ["OPENDAL_FS_ROOT"] = "storage" + + # Ensure storage directory exists + os.makedirs("storage", exist_ok=True) + + app = create_app() + dify_config.LOGIN_DISABLED = True + return app + + def init_llm_node(config: dict) -> LLMNode: graph_config = { "edges": [ @@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode: graph = Graph.init(graph_config=graph_config) + # Use proper UUIDs for database compatibility + tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c" + workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" + user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" + init_params = GraphInitParams( - tenant_id="1", - app_id="1", + tenant_id=tenant_id, + app_id=app_id, workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", + workflow_id=workflow_id, graph_config=graph_config, - user_id="1", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, call_depth=0, @@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode: return node -def test_execute_llm(setup_model_mock): - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": { - "provider": "langgenius/openai/openai", - "name": "gpt-3.5-turbo", - "mode": "chat", - "completion_params": {}, +def test_execute_llm(app): + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": { + "provider": "langgenius/openai/openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": {}, + }, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.", + }, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, }, - "prompt_template": [ - {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, - {"role": "user", "text": "{{#sys.query#}}"}, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.") + + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "langgenius/openai/openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - # execute node - result = node._run() - assert isinstance(result, Generator) + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() + assert isinstance(result, Generator) - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert item.run_result.outputs is not None - assert item.run_result.outputs.get("text") is not None - assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert item.run_result.outputs is not None + assert item.run_result.outputs.get("text") is not None + assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0 @pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock): +def test_execute_llm_with_jinja2(app, setup_code_executor_mock): """ Test execute LLM node with jinja2 """ - node = init_llm_node( - config={ - "id": "llm", - "data": { - "title": "123", - "type": "llm", - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, - "prompt_config": { - "jinja2_variables": [ - {"variable": "sys_query", "value_selector": ["sys", "query"]}, - {"variable": "output", "value_selector": ["abc", "output"]}, - ] - }, - "prompt_template": [ - { - "role": "system", - "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", - "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", - "edition_type": "jinja2", + with app.app_context(): + node = init_llm_node( + config={ + "id": "llm", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_config": { + "jinja2_variables": [ + {"variable": "sys_query", "value_selector": ["sys", "query"]}, + {"variable": "output", "value_selector": ["abc", "output"]}, + ] }, - { - "role": "user", - "text": "{{#sys.query#}}", - "jinja2_text": "{{sys_query}}", - "edition_type": "basic", - }, - ], - "memory": None, - "context": {"enabled": False}, - "vision": {"enabled": False}, + "prompt_template": [ + { + "role": "system", + "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}", + "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.", + "edition_type": "jinja2", + }, + { + "role": "user", + "text": "{{#sys.query#}}", + "jinja2_text": "{{sys_query}}", + "edition_type": "basic", + }, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + }, }, - }, - ) + ) - credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")} + # Mock db.session.close() + db.session.close = MagicMock() - # Mock db.session.close() - db.session.close = MagicMock() + # Create a proper LLM result with real entities + mock_usage = LLMUsage( + prompt_tokens=30, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.00003"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.00004"), + total_tokens=50, + total_price=Decimal("0.00007"), + currency="USD", + latency=0.5, + ) - node._fetch_model_config = get_mocked_fetch_model_config( - provider="langgenius/openai/openai", - model="gpt-3.5-turbo", - mode="chat", - credentials=credentials, - ) + mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?") + + mock_llm_result = LLMResult( + model="gpt-3.5-turbo", + prompt_messages=[], + message=mock_message, + usage=mock_usage, + ) + + # Create a simple mock model instance that doesn't call real providers + mock_model_instance = MagicMock() + mock_model_instance.invoke_llm.return_value = mock_llm_result + + # Create a simple mock model config with required attributes + mock_model_config = MagicMock() + mock_model_config.mode = "chat" + mock_model_config.provider = "openai" + mock_model_config.model = "gpt-3.5-turbo" + mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b" + + # Mock the _fetch_model_config method + def mock_fetch_model_config_func(_node_data_model): + return mock_model_instance, mock_model_config + + # Also mock ModelManager.get_model_instance to avoid database calls + def mock_get_model_instance(_self, **kwargs): + return mock_model_instance - # execute node - result = node._run() + with ( + patch.object(node, "_fetch_model_config", mock_fetch_model_config_func), + patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance), + ): + # execute node + result = node._run() - for item in result: - if isinstance(item, RunCompletedEvent): - assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.run_result.process_data is not None - assert "sunny" in json.dumps(item.run_result.process_data) - assert "what's the weather today?" in json.dumps(item.run_result.process_data) + for item in result: + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.process_data is not None + assert "sunny" in json.dumps(item.run_result.process_data) + assert "what's the weather today?" in json.dumps(item.run_result.process_data) def test_extract_json(): diff --git a/docker/.env.example b/docker/.env.example index ac9536be03..4cf5e202d0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600 PIP_MIRROR_URL= # https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example -# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss +# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos PLUGIN_STORAGE_TYPE=local PLUGIN_STORAGE_LOCAL_ROOT=/app/storage PLUGIN_WORKING_PATH=/app/storage/cwd @@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 PLUGIN_ALIYUN_OSS_PATH= +# Plugin oss volcengine tos +PLUGIN_VOLCENGINE_TOS_ENDPOINT= +PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= +PLUGIN_VOLCENGINE_TOS_SECRET_KEY= +PLUGIN_VOLCENGINE_TOS_REGION= # ------------------------------ # OTLP Collector Configuration @@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000 # Prevent Clickjacking ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 74a7b87bf9..75bdab1a06 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -184,6 +184,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" volumes: diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index d4a0b94619..8276e2977f 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -121,6 +121,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}" - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 41e86d015f..e559021684 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ENABLE_OTEL: ${ENABLE_OTEL:-false} OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318} OTLP_API_KEY: ${OTLP_API_KEY:-} @@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000} OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000} ALLOW_EMBED: ${ALLOW_EMBED:-false} + QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200} + QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-} + QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30} services: # API service @@ -683,6 +690,10 @@ services: ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-} ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-} + VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-} + VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-} + VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} + VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ports: - "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}" volumes: diff --git a/docker/middleware.env.example b/docker/middleware.env.example index ba6859885b..66037f281c 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID= PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET= PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4 PLUGIN_ALIYUN_OSS_PATH= +# Plugin oss volcengine tos +PLUGIN_VOLCENGINE_TOS_ENDPOINT= +PLUGIN_VOLCENGINE_TOS_ACCESS_KEY= +PLUGIN_VOLCENGINE_TOS_SECRET_KEY= +PLUGIN_VOLCENGINE_TOS_REGION= diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index 42967b96f4..31b9ed87c2 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -4,7 +4,7 @@ import { useContext, useContextSelector } from 'use-context-selector' import { useRouter } from 'next/navigation' import { useCallback, useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' -import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill } from '@remixicon/react' +import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react' import cn from '@/utils/classnames' import type { App } from '@/types/app' import Confirm from '@/app/components/base/confirm' @@ -338,7 +338,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
{app.access_mode === AccessMode.PUBLIC && - + } {app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && @@ -346,6 +346,9 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { {app.access_mode === AccessMode.ORGANIZATION && } + {app.access_mode === AccessMode.EXTERNAL_MEMBERS && + + }
diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index 62569ab26b..112b6a752e 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -87,7 +87,7 @@ const Container = () => { return (
-
+
setActiveTab(newActiveTab)} diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx index b9fab19948..a796b65bae 100644 --- a/web/app/(commonLayout)/datasets/template/template.ja.mdx +++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx @@ -192,15 +192,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - original_document_id が渡されない場合、新しい操作が実行され、process_rule が必要です。 - indexing_technique インデックスモード - - high_quality 高品質: 埋め込みモデルを使用してベクトルデータベースインデックスを構築 - - economy 経済: キーワードテーブルインデックスの反転インデックスを構築 + - high_quality 高品質:埋め込みモデルを使用してベクトルデータベースインデックスを構築 + - economy 経済:キーワードテーブルインデックスの反転インデックスを構築 - doc_form インデックス化された内容の形式 - text_model テキストドキュメントは直接埋め込まれます; `economy` モードではこの形式がデフォルト - hierarchical_model 親子モード - - qa_model Q&A モード: 分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます + - qa_model Q&A モード:分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます - - doc_language Q&A モードでは、ドキュメントの言語を指定します。例: English, Chinese + - doc_language Q&A モードでは、ドキュメントの言語を指定します。例:English, Chinese - process_rule 処理ルール - mode (string) クリーニング、セグメンテーションモード、自動 / カスタム @@ -214,7 +214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - segmentation (object) セグメンテーションルール - separator カスタムセグメント識別子。現在は 1 つの区切り文字のみ設定可能。デフォルトは \n - max_tokens 最大長 (トークン) デフォルトは 1000 - - parent_mode 親チャンクの検索モード: full-doc 全文検索 / paragraph 段落検索 + - parent_mode 親チャンクの検索モード:full-doc 全文検索 / paragraph 段落検索 - subchunk_segmentation (object) 子チャンクルール - separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは *** - max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります @@ -324,7 +324,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - partial_members 一部のメンバー - プロバイダー (オプション、デフォルト: vendor) + プロバイダー (オプション、デフォルト:vendor) - vendor ベンダー - external 外部ナレッジ @@ -415,16 +415,16 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 検索キーワード、オプション - タグIDリスト、オプション + タグ ID リスト、オプション - ページ番号、オプション、デフォルト1 + ページ番号、オプション、デフォルト 1 - 返されるアイテム数、オプション、デフォルト20、範囲1-100 + 返されるアイテム数、オプション、デフォルト 20、範囲 1-100 - すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトはfalse + すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトは false @@ -2013,7 +2013,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 新しいタグ名、必須、最大長50文字 + (text) 新しいタグ名、必須、最大長 50 文字 @@ -2099,10 +2099,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 変更後のタグ名、必須、最大長50文字 + (text) 変更後のタグ名、必須、最大長 50 文字 - (text) タグID、必須 + (text) タグ ID、必須 @@ -2147,7 +2147,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) タグID、必須 + (text) タグ ID、必須 @@ -2188,10 +2188,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (list) タグIDリスト、必須 + (list) タグ ID リスト、必須 - (text) ナレッジベースID、必須 + (text) ナレッジベース ID、必須 @@ -2230,10 +2230,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) タグID、必須 + (text) タグ ID、必須 - (text) ナレッジベースID、必須 + (text) ナレッジベース ID、必須 @@ -2273,7 +2273,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Path - (text) ナレッジベースID + (text) ナレッジベース ID diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index b10f22002a..08ef5d562a 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -207,7 +207,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - doc_language 在 Q&A 模式下,指定文档的语言,例如:EnglishChinese - process_rule 处理规则 - - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 / hierarchical 父子 + - mode (string) 清洗、分段模式,automatic 自动 / custom 自定义 / hierarchical 父子 - rules (object) 自定义规则(自动模式下,该字段为空) - pre_processing_rules (array[object]) 预处理规则 - id (string) 预处理规则的唯一标识符 @@ -234,12 +234,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - hybrid_search 混合检索 - semantic_search 语义检索 - full_text_search 全文检索 - - reranking_enable (bool) 是否开启rerank + - reranking_enable (bool) 是否开启 rerank - reranking_model (object) Rerank 模型配置 - reranking_provider_name (string) Rerank 模型的提供商 - reranking_model_name (string) Rerank 模型的名称 - top_k (int) 召回条数 - - score_threshold_enabled (bool)是否开启召回分数限制 + - score_threshold_enabled (bool) 是否开启召回分数限制 - score_threshold (float) 召回分数限制 @@ -350,12 +350,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - hybrid_search 混合检索 - semantic_search 语义检索 - full_text_search 全文检索 - - reranking_enable (bool) 是否开启rerank + - reranking_enable (bool) 是否开启 rerank - reranking_model (object) Rerank 模型配置 - reranking_provider_name (string) Rerank 模型的提供商 - reranking_model_name (string) Rerank 模型的名称 - top_k (int) 召回条数 - - score_threshold_enabled (bool)是否开启召回分数限制 + - score_threshold_enabled (bool) 是否开启召回分数限制 - score_threshold (float) 召回分数限制 @@ -1322,7 +1322,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 文档 ID - 文档分段ID + 文档分段 ID @@ -1435,7 +1435,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi 文档 ID - 文档分段ID + 文档分段 ID @@ -2404,7 +2404,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 新标签名称,必填,最大长度为50 + (text) 新标签名称,必填,最大长度为 50 @@ -2490,10 +2490,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 修改后的标签名称,必填,最大长度为50 + (text) 修改后的标签名称,必填,最大长度为 50 - (text) 标签ID,必填 + (text) 标签 ID,必填 @@ -2538,7 +2538,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 标签ID,必填 + (text) 标签 ID,必填 @@ -2579,10 +2579,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (list) 标签ID列表,必填 + (list) 标签 ID 列表,必填 - (text) 知识库ID,必填 + (text) 知识库 ID,必填 @@ -2621,10 +2621,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Request Body - (text) 标签ID,必填 + (text) 标签 ID,必填 - (text) 知识库ID,必填 + (text) 知识库 ID,必填 @@ -2664,7 +2664,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi ### Path - (text) 知识库ID + (text) 知识库 ID diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx index 83adbd3cae..8db336a17d 100644 --- a/web/app/(shareLayout)/layout.tsx +++ b/web/app/(shareLayout)/layout.tsx @@ -1,14 +1,42 @@ -import React from 'react' +'use client' +import React, { useEffect, useState } from 'react' import type { FC } from 'react' -import type { Metadata } from 'next' - -export const metadata: Metadata = { - icons: 'data:,', // prevent browser from using default favicon -} +import { usePathname, useSearchParams } from 'next/navigation' +import Loading from '../components/base/loading' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { AccessMode } from '@/models/access-control' +import { getAppAccessModeByAppCode } from '@/service/share' const Layout: FC<{ children: React.ReactNode }> = ({ children }) => { + const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending) + const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode) + const pathname = usePathname() + const searchParams = useSearchParams() + const redirectUrl = searchParams.get('redirect_url') + const [isLoading, setIsLoading] = useState(true) + useEffect(() => { + (async () => { + let appCode: string | null = null + if (redirectUrl) + appCode = redirectUrl?.split('/').pop() || null + else + appCode = pathname.split('/').pop() || null + + if (!appCode) + return + setIsLoading(true) + const ret = await getAppAccessModeByAppCode(appCode) + setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC) + setIsLoading(false) + })() + }, [pathname, redirectUrl, setWebAppAccessMode]) + if (isLoading || isGlobalPending) { + return
+ +
+ } return (
{children} diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx new file mode 100644 index 0000000000..da754794b1 --- /dev/null +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -0,0 +1,96 @@ +'use client' +import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useState } from 'react' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Countdown from '@/app/components/signin/countdown' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Toast from '@/app/components/base/toast' +import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' +import I18NContext from '@/context/i18n' + +export default function CheckCode() { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const email = decodeURIComponent(searchParams.get('email') as string) + const token = decodeURIComponent(searchParams.get('token') as string) + const [code, setVerifyCode] = useState('') + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const verify = async () => { + try { + if (!code.trim()) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.emptyCode'), + }) + return + } + if (!/\d{6}/.test(code)) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.invalidCode'), + }) + return + } + setIsLoading(true) + const ret = await verifyWebAppResetPasswordCode({ email, code, token }) + if (ret.is_valid) { + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(ret.token)) + router.push(`/webapp-reset-password/set-password?${params.toString()}`) + } + } + catch (error) { console.error(error) } + finally { + setIsLoading(false) + } + } + + const resendCode = async () => { + try { + const res = await sendWebAppResetPasswordCode(email, locale) + if (res.result === 'success') { + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(res.data)) + router.replace(`/webapp-reset-password/check-code?${params.toString()}`) + } + } + catch (error) { console.error(error) } + } + + return
+
+ +
+
+

{t('login.checkCode.checkYourEmail')}

+

+ +
+ {t('login.checkCode.validTime')} +

+
+ +
+ + + setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} /> + + + +
+
+
+
router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'> +
+ +
+ {t('login.back')} +
+
+} diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx new file mode 100644 index 0000000000..e0ac6b9ad6 --- /dev/null +++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx @@ -0,0 +1,30 @@ +'use client' +import Header from '@/app/signin/_header' + +import cn from '@/utils/classnames' +import { useGlobalPublicStore } from '@/context/global-public-context' + +export default function SignInLayout({ children }: any) { + const { systemFeatures } = useGlobalPublicStore() + return <> +
+
+
+
+
+ {children} +
+
+ {!systemFeatures.branding.enabled &&
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. +
} +
+
+ +} diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx new file mode 100644 index 0000000000..96cd4c5805 --- /dev/null +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -0,0 +1,104 @@ +'use client' +import Link from 'next/link' +import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useState } from 'react' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' +import { emailRegex } from '@/config' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Toast from '@/app/components/base/toast' +import { sendResetPasswordCode } from '@/service/common' +import I18NContext from '@/context/i18n' +import { noop } from 'lodash-es' +import useDocumentTitle from '@/hooks/use-document-title' + +export default function CheckCode() { + const { t } = useTranslation() + useDocumentTitle('') + const searchParams = useSearchParams() + const router = useRouter() + const [email, setEmail] = useState('') + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const handleGetEMailVerificationCode = async () => { + try { + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + setIsLoading(true) + const res = await sendResetPasswordCode(email, locale) + if (res.result === 'success') { + localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`) + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(res.data)) + params.set('email', encodeURIComponent(email)) + router.push(`/webapp-reset-password/check-code?${params.toString()}`) + } + else if (res.code === 'account_not_found') { + Toast.notify({ + type: 'error', + message: t('login.error.registrationNotAllowed'), + }) + } + else { + Toast.notify({ + type: 'error', + message: res.data, + }) + } + } + catch (error) { + console.error(error) + } + finally { + setIsLoading(false) + } + } + + return
+
+ +
+
+

{t('login.resetPassword')}

+

+ {t('login.resetPasswordDesc')} +

+
+ +
+ +
+ +
+ setEmail(e.target.value)} /> +
+
+ +
+
+
+
+
+
+ +
+ +
+ {t('login.backToLogin')} + +
+} diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx new file mode 100644 index 0000000000..9f9a8ad4e3 --- /dev/null +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -0,0 +1,188 @@ +'use client' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import cn from 'classnames' +import { RiCheckboxCircleFill } from '@remixicon/react' +import { useCountDown } from 'ahooks' +import Button from '@/app/components/base/button' +import { changeWebAppPasswordWithToken } from '@/service/common' +import Toast from '@/app/components/base/toast' +import Input from '@/app/components/base/input' + +const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +const ChangePasswordForm = () => { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const token = decodeURIComponent(searchParams.get('token') || '') + + const [password, setPassword] = useState('') + const [confirmPassword, setConfirmPassword] = useState('') + const [showSuccess, setShowSuccess] = useState(false) + const [showPassword, setShowPassword] = useState(false) + const [showConfirmPassword, setShowConfirmPassword] = useState(false) + + const showErrorMessage = useCallback((message: string) => { + Toast.notify({ + type: 'error', + message, + }) + }, []) + + const getSignInUrl = () => { + return `/webapp-signin?redirect_url=${searchParams.get('redirect_url') || ''}` + } + + const AUTO_REDIRECT_TIME = 5000 + const [leftTime, setLeftTime] = useState(undefined) + const [countdown] = useCountDown({ + leftTime, + onEnd: () => { + router.replace(getSignInUrl()) + }, + }) + + const valid = useCallback(() => { + if (!password.trim()) { + showErrorMessage(t('login.error.passwordEmpty')) + return false + } + if (!validPassword.test(password)) { + showErrorMessage(t('login.error.passwordInvalid')) + return false + } + if (password !== confirmPassword) { + showErrorMessage(t('common.account.notEqual')) + return false + } + return true + }, [password, confirmPassword, showErrorMessage, t]) + + const handleChangePassword = useCallback(async () => { + if (!valid()) + return + try { + await changeWebAppPasswordWithToken({ + url: '/forgot-password/resets', + body: { + token, + new_password: password, + password_confirm: confirmPassword, + }, + }) + setShowSuccess(true) + setLeftTime(AUTO_REDIRECT_TIME) + } + catch (error) { + console.error(error) + } + }, [password, token, valid, confirmPassword]) + + return ( +
+ {!showSuccess && ( +
+
+

+ {t('login.changePassword')} +

+

+ {t('login.changePasswordTip')} +

+
+ +
+
+ {/* Password */} +
+ +
+ setPassword(e.target.value)} + placeholder={t('login.passwordPlaceholder') || ''} + /> + +
+ +
+
+
{t('login.error.passwordInvalid')}
+
+ {/* Confirm Password */} +
+ +
+ setConfirmPassword(e.target.value)} + placeholder={t('login.confirmPasswordPlaceholder') || ''} + /> +
+ +
+
+
+
+ +
+
+
+
+ )} + {showSuccess && ( +
+
+
+ +
+

+ {t('login.passwordChangedTip')} +

+
+
+ +
+
+ )} +
+ ) +} + +export default ChangePasswordForm diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx new file mode 100644 index 0000000000..1b8f18c98f --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -0,0 +1,115 @@ +'use client' +import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useCallback, useState } from 'react' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Countdown from '@/app/components/signin/countdown' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import Toast from '@/app/components/base/toast' +import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' +import I18NContext from '@/context/i18n' +import { setAccessToken } from '@/app/components/share/utils' +import { fetchAccessToken } from '@/service/share' + +export default function CheckCode() { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const email = decodeURIComponent(searchParams.get('email') as string) + const token = decodeURIComponent(searchParams.get('token') as string) + const [code, setVerifyCode] = useState('') + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + const redirectUrl = searchParams.get('redirect_url') + + const getAppCodeFromRedirectUrl = useCallback(() => { + const appCode = redirectUrl?.split('/').pop() + if (!appCode) + return null + + return appCode + }, [redirectUrl]) + + const verify = async () => { + try { + const appCode = getAppCodeFromRedirectUrl() + if (!code.trim()) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.emptyCode'), + }) + return + } + if (!/\d{6}/.test(code)) { + Toast.notify({ + type: 'error', + message: t('login.checkCode.invalidCode'), + }) + return + } + if (!redirectUrl || !appCode) { + Toast.notify({ + type: 'error', + message: t('login.error.redirectUrlMissing'), + }) + return + } + setIsLoading(true) + const ret = await webAppEmailLoginWithCode({ email, code, token }) + if (ret.result === 'success') { + localStorage.setItem('webapp_access_token', ret.data.access_token) + const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token }) + await setAccessToken(appCode, tokenResp.access_token) + router.replace(redirectUrl) + } + } + catch (error) { console.error(error) } + finally { + setIsLoading(false) + } + } + + const resendCode = async () => { + try { + const ret = await sendWebAppEMailLoginCode(email, locale) + if (ret.result === 'success') { + const params = new URLSearchParams(searchParams) + params.set('token', encodeURIComponent(ret.data)) + router.replace(`/webapp-signin/check-code?${params.toString()}`) + } + } + catch (error) { console.error(error) } + } + + return
+
+ +
+
+

{t('login.checkCode.checkYourEmail')}

+

+ +
+ {t('login.checkCode.validTime')} +

+
+ +
+ + setVerifyCode(e.target.value)} max-length={6} className='mt-1' placeholder={t('login.checkCode.verificationCodePlaceholder') as string} /> + + + +
+
+
+
router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'> +
+ +
+ {t('login.back')} +
+
+} diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx new file mode 100644 index 0000000000..e9b15ae331 --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -0,0 +1,80 @@ +'use client' +import { useRouter, useSearchParams } from 'next/navigation' +import React, { useCallback, useEffect } from 'react' +import Toast from '@/app/components/base/toast' +import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' +import { useGlobalPublicStore } from '@/context/global-public-context' +import { SSOProtocol } from '@/types/feature' +import Loading from '@/app/components/base/loading' +import AppUnavailable from '@/app/components/base/app-unavailable' + +const ExternalMemberSSOAuth = () => { + const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + const searchParams = useSearchParams() + const router = useRouter() + + const redirectUrl = searchParams.get('redirect_url') + + const showErrorToast = (message: string) => { + Toast.notify({ + type: 'error', + message, + }) + } + + const getAppCodeFromRedirectUrl = useCallback(() => { + const appCode = redirectUrl?.split('/').pop() + if (!appCode) + return null + + return appCode + }, [redirectUrl]) + + const handleSSOLogin = useCallback(async () => { + const appCode = getAppCodeFromRedirectUrl() + if (!appCode || !redirectUrl) { + showErrorToast('redirect url or app code is invalid.') + return + } + + switch (systemFeatures.webapp_auth.sso_config.protocol) { + case SSOProtocol.SAML: { + const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl) + router.push(samlRes.url) + break + } + case SSOProtocol.OIDC: { + const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl) + router.push(oidcRes.url) + break + } + case SSOProtocol.OAuth2: { + const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl) + router.push(oauth2Res.url) + break + } + case '': + break + default: + showErrorToast('SSO protocol is not supported.') + } + }, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol]) + + useEffect(() => { + handleSSOLogin() + }, [handleSSOLogin]) + + if (!systemFeatures.webapp_auth.sso_config.protocol) { + return
+ +
+ } + + return ( +
+ +
+ ) +} + +export default React.memo(ExternalMemberSSOAuth) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx new file mode 100644 index 0000000000..29af3e3a57 --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -0,0 +1,68 @@ +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Input from '@/app/components/base/input' +import Button from '@/app/components/base/button' +import { emailRegex } from '@/config' +import Toast from '@/app/components/base/toast' +import { sendWebAppEMailLoginCode } from '@/service/common' +import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' +import I18NContext from '@/context/i18n' +import { noop } from 'lodash-es' + +export default function MailAndCodeAuth() { + const { t } = useTranslation() + const router = useRouter() + const searchParams = useSearchParams() + const emailFromLink = decodeURIComponent(searchParams.get('email') || '') + const [email, setEmail] = useState(emailFromLink) + const [loading, setIsLoading] = useState(false) + const { locale } = useContext(I18NContext) + + const handleGetEMailVerificationCode = async () => { + try { + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + setIsLoading(true) + const ret = await sendWebAppEMailLoginCode(email, locale) + if (ret.result === 'success') { + localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`) + const params = new URLSearchParams(searchParams) + params.set('email', encodeURIComponent(email)) + params.set('token', encodeURIComponent(ret.data)) + router.push(`/webapp-signin/check-code?${params.toString()}`) + } + } + catch (error) { + console.error(error) + } + finally { + setIsLoading(false) + } + } + + return (
+ +
+ +
+ setEmail(e.target.value)} /> +
+
+ +
+
+
+ ) +} diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx new file mode 100644 index 0000000000..d9e56af1b8 --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -0,0 +1,171 @@ +import Link from 'next/link' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useRouter, useSearchParams } from 'next/navigation' +import { useContext } from 'use-context-selector' +import Button from '@/app/components/base/button' +import Toast from '@/app/components/base/toast' +import { emailRegex } from '@/config' +import { webAppLogin } from '@/service/common' +import Input from '@/app/components/base/input' +import I18NContext from '@/context/i18n' +import { noop } from 'lodash-es' +import { setAccessToken } from '@/app/components/share/utils' +import { fetchAccessToken } from '@/service/share' + +type MailAndPasswordAuthProps = { + isEmailSetup: boolean +} + +const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ + +export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) { + const { t } = useTranslation() + const { locale } = useContext(I18NContext) + const router = useRouter() + const searchParams = useSearchParams() + const [showPassword, setShowPassword] = useState(false) + const emailFromLink = decodeURIComponent(searchParams.get('email') || '') + const [email, setEmail] = useState(emailFromLink) + const [password, setPassword] = useState('') + + const [isLoading, setIsLoading] = useState(false) + const redirectUrl = searchParams.get('redirect_url') + + const getAppCodeFromRedirectUrl = useCallback(() => { + const appCode = redirectUrl?.split('/').pop() + if (!appCode) + return null + + return appCode + }, [redirectUrl]) + const handleEmailPasswordLogin = async () => { + const appCode = getAppCodeFromRedirectUrl() + if (!email) { + Toast.notify({ type: 'error', message: t('login.error.emailEmpty') }) + return + } + if (!emailRegex.test(email)) { + Toast.notify({ + type: 'error', + message: t('login.error.emailInValid'), + }) + return + } + if (!password?.trim()) { + Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') }) + return + } + if (!passwordRegex.test(password)) { + Toast.notify({ + type: 'error', + message: t('login.error.passwordInvalid'), + }) + return + } + if (!redirectUrl || !appCode) { + Toast.notify({ + type: 'error', + message: t('login.error.redirectUrlMissing'), + }) + return + } + try { + setIsLoading(true) + const loginData: Record = { + email, + password, + language: locale, + remember_me: true, + } + + const res = await webAppLogin({ + url: '/login', + body: loginData, + }) + if (res.result === 'success') { + localStorage.setItem('webapp_access_token', res.data.access_token) + const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token }) + await setAccessToken(appCode, tokenResp.access_token) + router.replace(redirectUrl) + } + else { + Toast.notify({ + type: 'error', + message: res.data, + }) + } + } + + finally { + setIsLoading(false) + } + } + + return
+
+ +
+ setEmail(e.target.value)} + id="email" + type="email" + autoComplete="email" + placeholder={t('login.emailPlaceholder') || ''} + tabIndex={1} + /> +
+
+ +
+ +
+ setPassword(e.target.value)} + onKeyDown={(e) => { + if (e.key === 'Enter') + handleEmailPasswordLogin() + }} + type={showPassword ? 'text' : 'password'} + autoComplete="current-password" + placeholder={t('login.passwordPlaceholder') || ''} + tabIndex={2} + /> +
+ +
+
+
+ +
+ +
+
+} diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx new file mode 100644 index 0000000000..5d649322ba --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -0,0 +1,88 @@ +'use client' +import { useRouter, useSearchParams } from 'next/navigation' +import type { FC } from 'react' +import { useCallback } from 'react' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' +import Toast from '@/app/components/base/toast' +import Button from '@/app/components/base/button' +import { SSOProtocol } from '@/types/feature' +import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share' + +type SSOAuthProps = { + protocol: SSOProtocol | '' +} + +const SSOAuth: FC = ({ + protocol, +}) => { + const router = useRouter() + const { t } = useTranslation() + const searchParams = useSearchParams() + + const redirectUrl = searchParams.get('redirect_url') + const getAppCodeFromRedirectUrl = useCallback(() => { + const appCode = redirectUrl?.split('/').pop() + if (!appCode) + return null + + return appCode + }, [redirectUrl]) + + const [isLoading, setIsLoading] = useState(false) + + const handleSSOLogin = () => { + const appCode = getAppCodeFromRedirectUrl() + if (!redirectUrl || !appCode) { + Toast.notify({ + type: 'error', + message: 'invalid redirect URL or app code', + }) + return + } + setIsLoading(true) + if (protocol === SSOProtocol.SAML) { + fetchMembersSAMLSSOUrl(appCode, redirectUrl).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === SSOProtocol.OIDC) { + fetchMembersOIDCSSOUrl(appCode, redirectUrl).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else if (protocol === SSOProtocol.OAuth2) { + fetchMembersOAuth2SSOUrl(appCode, redirectUrl).then((res) => { + router.push(res.url) + }).finally(() => { + setIsLoading(false) + }) + } + else { + Toast.notify({ + type: 'error', + message: 'invalid SSO protocol', + }) + setIsLoading(false) + } + } + + return ( + + ) +} + +export default SSOAuth diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx new file mode 100644 index 0000000000..a03364d326 --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -0,0 +1,25 @@ +'use client' + +import cn from '@/utils/classnames' +import { useGlobalPublicStore } from '@/context/global-public-context' +import useDocumentTitle from '@/hooks/use-document-title' + +export default function SignInLayout({ children }: any) { + const { systemFeatures } = useGlobalPublicStore() + useDocumentTitle('') + return <> +
+
+ {/*
*/} +
+
+ {children} +
+
+ {systemFeatures.branding.enabled === false &&
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved. +
} +
+
+ +} diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx new file mode 100644 index 0000000000..d6bdf607ba --- /dev/null +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -0,0 +1,176 @@ +import React, { useCallback, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import Link from 'next/link' +import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react' +import Loading from '@/app/components/base/loading' +import MailAndCodeAuth from './components/mail-and-code-auth' +import MailAndPasswordAuth from './components/mail-and-password-auth' +import SSOAuth from './components/sso-auth' +import cn from '@/utils/classnames' +import { LicenseStatus } from '@/types/feature' +import { IS_CE_EDITION } from '@/config' +import { useGlobalPublicStore } from '@/context/global-public-context' + +const NormalForm = () => { + const { t } = useTranslation() + + const [isLoading, setIsLoading] = useState(true) + const { systemFeatures } = useGlobalPublicStore() + const [authType, updateAuthType] = useState<'code' | 'password'>('password') + const [showORLine, setShowORLine] = useState(false) + const [allMethodsAreDisabled, setAllMethodsAreDisabled] = useState(false) + + const init = useCallback(async () => { + try { + setAllMethodsAreDisabled(!systemFeatures.enable_social_oauth_login && !systemFeatures.enable_email_code_login && !systemFeatures.enable_email_password_login && !systemFeatures.sso_enforced_for_signin) + setShowORLine((systemFeatures.enable_social_oauth_login || systemFeatures.sso_enforced_for_signin) && (systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login)) + updateAuthType(systemFeatures.enable_email_password_login ? 'password' : 'code') + } + catch (error) { + console.error(error) + setAllMethodsAreDisabled(true) + } + finally { setIsLoading(false) } + }, [systemFeatures]) + useEffect(() => { + init() + }, [init]) + if (isLoading) { + return
+ +
+ } + if (systemFeatures.license?.status === LicenseStatus.LOST) { + return
+
+
+
+ + +
+

{t('login.licenseLost')}

+

{t('login.licenseLostTip')}

+
+
+
+ } + if (systemFeatures.license?.status === LicenseStatus.EXPIRED) { + return
+
+
+
+ + +
+

{t('login.licenseExpired')}

+

{t('login.licenseExpiredTip')}

+
+
+
+ } + if (systemFeatures.license?.status === LicenseStatus.INACTIVE) { + return
+
+
+
+ + +
+

{t('login.licenseInactive')}

+

{t('login.licenseInactiveTip')}

+
+
+
+ } + + return ( + <> +
+
+

{t('login.pageTitle')}

+ {!systemFeatures.branding.enabled &&

{t('login.welcome')}

} +
+
+
+ {systemFeatures.sso_enforced_for_signin &&
+ +
} +
+ + {showORLine &&
+ +
+ {t('login.or')} +
+
} + { + (systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login) && <> + {systemFeatures.enable_email_code_login && authType === 'code' && <> + + {systemFeatures.enable_email_password_login &&
{ updateAuthType('password') }}> + {t('login.usePassword')} +
} + } + {systemFeatures.enable_email_password_login && authType === 'password' && <> + + {systemFeatures.enable_email_code_login &&
{ updateAuthType('code') }}> + {t('login.useVerificationCode')} +
} + } + + } + {allMethodsAreDisabled && <> +
+
+ +
+

{t('login.noLoginMethod')}

+

{t('login.noLoginMethodTip')}

+
+
+ +
+ } + {!systemFeatures.branding.enabled && <> +
+ {t('login.tosDesc')} +   + {t('login.tos')} +  &  + {t('login.pp')} +
+ {IS_CE_EDITION &&
+ {t('login.goToInit')} +   + {t('login.setAdminAccount')} +
} + } + +
+
+ + ) +} + +export default NormalForm diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index 668c3f312c..c12fde38dd 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -3,19 +3,20 @@ import { useRouter, useSearchParams } from 'next/navigation' import type { FC } from 'react' import React, { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' -import { RiDoorLockLine } from '@remixicon/react' -import cn from '@/utils/classnames' import Toast from '@/app/components/base/toast' -import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' -import { setAccessToken } from '@/app/components/share/utils' +import { removeAccessToken, setAccessToken } from '@/app/components/share/utils' import { useGlobalPublicStore } from '@/context/global-public-context' -import { SSOProtocol } from '@/types/feature' import Loading from '@/app/components/base/loading' import AppUnavailable from '@/app/components/base/app-unavailable' +import NormalForm from './normalForm' +import { AccessMode } from '@/models/access-control' +import ExternalMemberSsoAuth from './components/external-member-sso-auth' +import { fetchAccessToken } from '@/service/share' const WebSSOForm: FC = () => { const { t } = useTranslation() const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + const webAppAccessMode = useGlobalPublicStore(s => s.webAppAccessMode) const searchParams = useSearchParams() const router = useRouter() @@ -23,10 +24,22 @@ const WebSSOForm: FC = () => { const tokenFromUrl = searchParams.get('web_sso_token') const message = searchParams.get('message') - const showErrorToast = (message: string) => { + const getSigninUrl = useCallback(() => { + const params = new URLSearchParams(searchParams) + params.delete('message') + return `/webapp-signin?${params.toString()}` + }, [searchParams]) + + const backToHome = useCallback(() => { + removeAccessToken() + const url = getSigninUrl() + router.replace(url) + }, [getSigninUrl, router]) + + const showErrorToast = (msg: string) => { Toast.notify({ type: 'error', - message, + message: msg, }) } @@ -38,102 +51,73 @@ const WebSSOForm: FC = () => { return appCode }, [redirectUrl]) - const processTokenAndRedirect = useCallback(async () => { - const appCode = getAppCodeFromRedirectUrl() - if (!appCode || !tokenFromUrl || !redirectUrl) { - showErrorToast('redirect url or app code or token is invalid.') - return - } - - await setAccessToken(appCode, tokenFromUrl) - router.push(redirectUrl) - }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl]) - - const handleSSOLogin = useCallback(async () => { - const appCode = getAppCodeFromRedirectUrl() - if (!appCode || !redirectUrl) { - showErrorToast('redirect url or app code is invalid.') - return - } - - switch (systemFeatures.webapp_auth.sso_config.protocol) { - case SSOProtocol.SAML: { - const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl) - router.push(samlRes.url) - break - } - case SSOProtocol.OIDC: { - const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl) - router.push(oidcRes.url) - break - } - case SSOProtocol.OAuth2: { - const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl) - router.push(oauth2Res.url) - break - } - case '': - break - default: - showErrorToast('SSO protocol is not supported.') - } - }, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol]) - useEffect(() => { - const init = async () => { - if (message) { - showErrorToast(message) + (async () => { + if (message) return - } - if (!tokenFromUrl) { - await handleSSOLogin() + const appCode = getAppCodeFromRedirectUrl() + if (appCode && tokenFromUrl && redirectUrl) { + localStorage.setItem('webapp_access_token', tokenFromUrl) + const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl }) + await setAccessToken(appCode, tokenResp.access_token) + router.replace(redirectUrl) return } + if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) { + const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') }) + await setAccessToken(appCode, tokenResp.access_token) + router.replace(redirectUrl) + } + })() + }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message]) - await processTokenAndRedirect() - } + useEffect(() => { + if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl) + router.replace(redirectUrl) + }, [webAppAccessMode, router, redirectUrl]) - init() - }, [message, processTokenAndRedirect, tokenFromUrl, handleSSOLogin]) - if (tokenFromUrl) - return
- if (message) { + if (tokenFromUrl) { return
- +
} - if (systemFeatures.webapp_auth.enabled) { - if (systemFeatures.webapp_auth.allow_sso) { - return ( -
-
- -
-
- ) - } - return
-
-
- -
-

{t('login.webapp.noLoginMethod')}

-

{t('login.webapp.noLoginMethodTip')}

-
-
- -
+ if (message) { + return
+ + {t('share.login.backToHome')} +
+ } + if (!redirectUrl) { + showErrorToast('redirect url is invalid.') + return
+ +
+ } + if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC) { + return
+
} - else { + if (!systemFeatures.webapp_auth.enabled) { return

{t('login.webapp.disabled')}

} + if (webAppAccessMode && (webAppAccessMode === AccessMode.ORGANIZATION || webAppAccessMode === AccessMode.SPECIFIC_GROUPS_MEMBERS)) { + return
+ +
+ } + + if (webAppAccessMode && webAppAccessMode === AccessMode.EXTERNAL_MEMBERS) + return + + return
+ + {t('share.login.backToHome')} +
} export default React.memo(WebSSOForm) diff --git a/web/app/components/app/app-access-control/index.tsx b/web/app/components/app/app-access-control/index.tsx index 2f15c8ec48..13faaea957 100644 --- a/web/app/components/app/app-access-control/index.tsx +++ b/web/app/components/app/app-access-control/index.tsx @@ -1,6 +1,6 @@ 'use client' -import { Dialog } from '@headlessui/react' -import { RiBuildingLine, RiGlobalLine } from '@remixicon/react' +import { Description as DialogDescription, DialogTitle } from '@headlessui/react' +import { RiBuildingLine, RiGlobalLine, RiVerifiedBadgeLine } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useCallback, useEffect } from 'react' import Button from '../../base/button' @@ -67,8 +67,8 @@ export default function AccessControl(props: AccessControlProps) { return
- {t('app.accessControlDialog.title')} - {t('app.accessControlDialog.description')} + {t('app.accessControlDialog.title')} + {t('app.accessControlDialog.description')}
@@ -80,12 +80,20 @@ export default function AccessControl(props: AccessControlProps) {

{t('app.accessControlDialog.accessItems.organization')}

- {!hideTip && }
+ +
+
+ +

{t('app.accessControlDialog.accessItems.external')}

+
+ {!hideTip && } +
+
diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx index f4872f8c99..b30c8f1ba3 100644 --- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx +++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx @@ -3,12 +3,10 @@ import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from import { useTranslation } from 'react-i18next' import { useCallback, useEffect } from 'react' import Avatar from '../../base/avatar' -import Divider from '../../base/divider' import Tooltip from '../../base/tooltip' import Loading from '../../base/loading' import useAccessControlStore from '../../../../context/access-control-store' import AddMemberOrGroupDialog from './add-member-or-group-pop' -import { useGlobalPublicStore } from '@/context/global-public-context' import type { AccessControlAccount, AccessControlGroup } from '@/models/access-control' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects } from '@/service/access-control' @@ -19,11 +17,6 @@ export default function SpecificGroupsOrMembers() { const setSpecificGroups = useAccessControlStore(s => s.setSpecificGroups) const setSpecificMembers = useAccessControlStore(s => s.setSpecificMembers) const { t } = useTranslation() - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - const hideTip = systemFeatures.webapp_auth.enabled - && (systemFeatures.webapp_auth.allow_sso - || systemFeatures.webapp_auth.allow_email_password_login - || systemFeatures.webapp_auth.allow_email_code_login) const { isPending, data } = useAppWhiteListSubjects(appId, Boolean(appId) && currentMenu === AccessMode.SPECIFIC_GROUPS_MEMBERS) useEffect(() => { @@ -37,7 +30,6 @@ export default function SpecificGroupsOrMembers() {

{t('app.accessControlDialog.accessItems.specific')}

- {!hideTip && }
} @@ -48,10 +40,6 @@ export default function SpecificGroupsOrMembers() {

{t('app.accessControlDialog.accessItems.specific')}

- {!hideTip && <> - - - }
diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 8d0028c7d7..5825bb72ee 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -9,11 +9,14 @@ import dayjs from 'dayjs' import { RiArrowDownSLine, RiArrowRightSLine, + RiBuildingLine, + RiGlobalLine, RiLockLine, RiPlanetLine, RiPlayCircleLine, RiPlayList2Line, RiTerminalBoxLine, + RiVerifiedBadgeLine, } from '@remixicon/react' import { useKeyPress } from 'ahooks' import { getKeyboardKeyCodeBySystem } from '../../workflow/utils' @@ -276,10 +279,30 @@ const AppPublisher = ({ setShowAppAccessControl(true) }}>
- - {appDetail?.access_mode === AccessMode.ORGANIZATION &&

{t('app.accessControlDialog.accessItems.organization')}

} - {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS &&

{t('app.accessControlDialog.accessItems.specific')}

} - {appDetail?.access_mode === AccessMode.PUBLIC &&

{t('app.accessControlDialog.accessItems.anyone')}

} + {appDetail?.access_mode === AccessMode.ORGANIZATION + && <> + +

{t('app.accessControlDialog.accessItems.organization')}

+ + } + {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS + && <> + +

{t('app.accessControlDialog.accessItems.specific')}

+ + } + {appDetail?.access_mode === AccessMode.PUBLIC + && <> + +

{t('app.accessControlDialog.accessItems.anyone')}

+ + } + {appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS + && <> + +

{t('app.accessControlDialog.accessItems.external')}

+ + }
{!isAppAccessSet &&

{t('app.publishApp.notSet')}

}
diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx index 9b283cdf5e..9f3b3ac4a6 100644 --- a/web/app/components/app/overview/appCard.tsx +++ b/web/app/components/app/overview/appCard.tsx @@ -5,10 +5,13 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightSLine, RiBookOpenLine, + RiBuildingLine, RiEqualizer2Line, RiExternalLinkLine, + RiGlobalLine, RiLockLine, RiPaintBrushLine, + RiVerifiedBadgeLine, RiWindowLine, } from '@remixicon/react' import SettingsModal from './settings' @@ -248,11 +251,30 @@ function AppCard({
- - {appDetail?.access_mode === AccessMode.ORGANIZATION &&

{t('app.accessControlDialog.accessItems.organization')}

} - {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS &&

{t('app.accessControlDialog.accessItems.specific')}

} - {appDetail?.access_mode === AccessMode.PUBLIC &&

{t('app.accessControlDialog.accessItems.anyone')}

} -
+ {appDetail?.access_mode === AccessMode.ORGANIZATION + && <> + +

{t('app.accessControlDialog.accessItems.organization')}

+ + } + {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS + && <> + +

{t('app.accessControlDialog.accessItems.specific')}

+ + } + {appDetail?.access_mode === AccessMode.PUBLIC + && <> + +

{t('app.accessControlDialog.accessItems.anyone')}

+ + } + {appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS + && <> + +

{t('app.accessControlDialog.accessItems.external')}

+ + }
{!isAppAccessSet &&

{t('app.publishApp.notSet')}

}
diff --git a/web/app/components/base/app-unavailable.tsx b/web/app/components/base/app-unavailable.tsx index 4e835cbfcf..928c850262 100644 --- a/web/app/components/base/app-unavailable.tsx +++ b/web/app/components/base/app-unavailable.tsx @@ -1,4 +1,5 @@ 'use client' +import classNames from '@/utils/classnames' import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' @@ -7,17 +8,19 @@ type IAppUnavailableProps = { code?: number | string isUnknownReason?: boolean unknownReason?: string + className?: string } const AppUnavailable: FC = ({ code = 404, isUnknownReason, unknownReason, + className, }) => { const { t } = useTranslation() return ( -
+

({ - accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, userCanAccess: false, currentConversationId: '', appPrevChatTree: [], diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx index 3694666139..32f74e6457 100644 --- a/web/app/components/base/chat/chat-with-history/hooks.tsx +++ b/web/app/components/base/chat/chat-with-history/hooks.tsx @@ -16,7 +16,7 @@ import type { Feedback, } from '../types' import { CONVERSATION_ID_INFO } from '../constants' -import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils' +import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams, getRawInputsFromUrlParams } from '../utils' import { addFileInfos, sortAgentSorts } from '../../../tools/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import { @@ -43,9 +43,8 @@ import { useAppFavicon } from '@/hooks/use-app-favicon' import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' import { noop } from 'lodash-es' -import { useGetAppAccessMode, useGetUserCanAccessApp } from '@/service/access-control' +import { useGetUserCanAccessApp } from '@/service/access-control' import { useGlobalPublicStore } from '@/context/global-public-context' -import { AccessMode } from '@/models/access-control' function getFormattedChatList(messages: any[]) { const newChatList: ChatItem[] = [] @@ -77,11 +76,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo]) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR(installedAppInfo ? null : 'appInfo', fetchAppInfo) - const { isPending: isGettingAccessMode, data: appAccessMode } = useGetAppAccessMode({ - appId: installedAppInfo?.app.id || appInfo?.app_id, - isInstalledApp, - enabled: systemFeatures.webapp_auth.enabled, - }) const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({ appId: installedAppInfo?.app.id || appInfo?.app_id, isInstalledApp, @@ -195,6 +189,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const { t } = useTranslation() const newConversationInputsRef = useRef>({}) const [newConversationInputs, setNewConversationInputs] = useState>({}) + const [initInputs, setInitInputs] = useState>({}) const handleNewConversationInputsChange = useCallback((newInputs: Record) => { newConversationInputsRef.current = newInputs setNewConversationInputs(newInputs) @@ -202,20 +197,29 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { const inputsForms = useMemo(() => { return (appParams?.user_input_form || []).filter((item: any) => !item.external_data_tool).map((item: any) => { if (item.paragraph) { + let value = initInputs[item.paragraph.variable] + if (value && item.paragraph.max_length && value.length > item.paragraph.max_length) + value = value.slice(0, item.paragraph.max_length) + return { ...item.paragraph, + default: value || item.default, type: 'paragraph', } } if (item.number) { + const convertedNumber = Number(initInputs[item.number.variable]) ?? undefined return { ...item.number, + default: convertedNumber || item.default, type: 'number', } } if (item.select) { + const isInputInOptions = item.select.options.includes(initInputs[item.select.variable]) return { ...item.select, + default: (isInputInOptions ? initInputs[item.select.variable] : undefined) || item.default, type: 'select', } } @@ -234,17 +238,30 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { } } + let value = initInputs[item['text-input'].variable] + if (value && item['text-input'].max_length && value.length > item['text-input'].max_length) + value = value.slice(0, item['text-input'].max_length) + return { ...item['text-input'], + default: value || item.default, type: 'text-input', } }) - }, [appParams]) + }, [initInputs, appParams]) const allInputsHidden = useMemo(() => { return inputsForms.length > 0 && inputsForms.every(item => item.hide === true) }, [inputsForms]) + useEffect(() => { + // init inputs from url params + (async () => { + const inputs = await getRawInputsFromUrlParams() + setInitInputs(inputs) + })() + }, []) + useEffect(() => { const conversationInputs: Record = {} @@ -362,11 +379,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { if (conversationId) setClearChatList(false) }, [handleConversationIdInfoChange, setClearChatList]) - const handleNewConversation = useCallback(() => { + const handleNewConversation = useCallback(async () => { currentChatInstanceRef.current.handleStop() setShowNewConversationItemInList(true) handleChangeConversation('') - handleNewConversationInputsChange({}) + handleNewConversationInputsChange(await getRawInputsFromUrlParams()) setClearChatList(true) }, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList]) const handleUpdateConversationList = useCallback(() => { @@ -469,8 +486,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => { return { appInfoError, - appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && (isGettingAccessMode || isCheckingPermission)), - accessMode: systemFeatures.webapp_auth.enabled ? appAccessMode?.accessMode : AccessMode.PUBLIC, + appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission), userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true, isInstalledApp, appId, diff --git a/web/app/components/base/chat/chat-with-history/index.tsx b/web/app/components/base/chat/chat-with-history/index.tsx index de023e7f58..1fd1383196 100644 --- a/web/app/components/base/chat/chat-with-history/index.tsx +++ b/web/app/components/base/chat/chat-with-history/index.tsx @@ -124,7 +124,6 @@ const ChatWithHistoryWrap: FC = ({ const { appInfoError, appInfoLoading, - accessMode, userCanAccess, appData, appParams, @@ -169,7 +168,6 @@ const ChatWithHistoryWrap: FC = ({ appInfoError, appInfoLoading, appData, - accessMode, userCanAccess, appParams, appMeta, diff --git a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx index fd317ccf91..4e50c1cb79 100644 --- a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx +++ b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx @@ -19,7 +19,6 @@ import RenameModal from '@/app/components/base/chat/chat-with-history/sidebar/re import DifyLogo from '@/app/components/base/logo/dify-logo' import type { ConversationItem } from '@/models/share' import cn from '@/utils/classnames' -import { AccessMode } from '@/models/access-control' import { useGlobalPublicStore } from '@/context/global-public-context' type Props = { @@ -30,7 +29,6 @@ const Sidebar = ({ isPanel }: Props) => { const { t } = useTranslation() const { isInstalledApp, - accessMode, appData, handleNewConversation, pinnedConversationList, @@ -140,7 +138,7 @@ const Sidebar = ({ isPanel }: Props) => { )}

- + {/* powered by */}
{!appData?.custom_config?.remove_webapp_brand && ( diff --git a/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts b/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts index bcc3ae628d..51995a4af5 100644 --- a/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts +++ b/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts @@ -3,7 +3,7 @@ export const markdownContentSVG = ` - 创意Logo设计 + 创意 Logo 设计 diff --git a/web/app/components/base/chat/embedded-chatbot/context.tsx b/web/app/components/base/chat/embedded-chatbot/context.tsx index 5964efd806..d24265ed9e 100644 --- a/web/app/components/base/chat/embedded-chatbot/context.tsx +++ b/web/app/components/base/chat/embedded-chatbot/context.tsx @@ -15,10 +15,8 @@ import type { ConversationItem, } from '@/models/share' import { noop } from 'lodash-es' -import { AccessMode } from '@/models/access-control' export type EmbeddedChatbotContextValue = { - accessMode?: AccessMode userCanAccess?: boolean appInfoError?: any appInfoLoading?: boolean @@ -58,7 +56,6 @@ export type EmbeddedChatbotContextValue = { export const EmbeddedChatbotContext = createContext({ userCanAccess: false, - accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS, currentConversationId: '', appPrevChatList: [], pinnedConversationList: [], diff --git a/web/app/components/base/chat/embedded-chatbot/hooks.tsx b/web/app/components/base/chat/embedded-chatbot/hooks.tsx index 40c56eca7b..0158e8d041 100644 --- a/web/app/components/base/chat/embedded-chatbot/hooks.tsx +++ b/web/app/components/base/chat/embedded-chatbot/hooks.tsx @@ -36,9 +36,8 @@ import { InputVarType } from '@/app/components/workflow/types' import { TransferMethod } from '@/types/app' import { addFileInfos, sortAgentSorts } from '@/app/components/tools/utils' import { noop } from 'lodash-es' -import { useGetAppAccessMode, useGetUserCanAccessApp } from '@/service/access-control' +import { useGetUserCanAccessApp } from '@/service/access-control' import { useGlobalPublicStore } from '@/context/global-public-context' -import { AccessMode } from '@/models/access-control' function getFormattedChatList(messages: any[]) { const newChatList: ChatItem[] = [] @@ -70,11 +69,6 @@ export const useEmbeddedChatbot = () => { const isInstalledApp = false const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR('appInfo', fetchAppInfo) - const { isPending: isGettingAccessMode, data: appAccessMode } = useGetAppAccessMode({ - appId: appInfo?.app_id, - isInstalledApp, - enabled: systemFeatures.webapp_auth.enabled, - }) const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp, @@ -385,8 +379,7 @@ export const useEmbeddedChatbot = () => { return { appInfoError, - appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && (isGettingAccessMode || isCheckingPermission)), - accessMode: systemFeatures.webapp_auth.enabled ? appAccessMode?.accessMode : AccessMode.PUBLIC, + appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission), userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true, isInstalledApp, allowResetChat, diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index bdac5998fe..8e044375d3 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -15,6 +15,17 @@ async function decodeBase64AndDecompress(base64String: string) { } } +async function getRawInputsFromUrlParams(): Promise> { + const urlParams = new URLSearchParams(window.location.search) + const inputs: Record = {} + const entriesArray = Array.from(urlParams.entries()) + entriesArray.forEach(([key, value]) => { + if (!key.startsWith('sys.')) + inputs[key] = decodeURIComponent(value) + }) + return inputs +} + async function getProcessedInputsFromUrlParams(): Promise> { const urlParams = new URLSearchParams(window.location.search) const inputs: Record = {} @@ -184,6 +195,7 @@ function getThreadMessages(tree: ChatItemInTree[], targetMessageId?: string): Ch } export { + getRawInputsFromUrlParams, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, isValidGeneratedAnswer, diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 66d5b46ba7..8e1b2148c5 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -231,7 +231,7 @@ export const useFile = (fileConfig: FileUpload) => { url: res.url, } if (!isAllowedFileExtension(res.name, res.mime_type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) { - notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${file.type}` }) handleRemoveFile(uploadingFile.id) } if (!checkSizeLimit(newFile.supportFileType, newFile.size)) @@ -257,7 +257,7 @@ export const useFile = (fileConfig: FileUpload) => { const handleLocalFileUpload = useCallback((file: File) => { if (!isAllowedFileExtension(file.name, file.type, fileConfig.allowed_file_types || [], fileConfig.allowed_file_extensions || [])) { - notify({ type: 'error', message: t('common.fileUploader.fileExtensionNotSupport') }) + notify({ type: 'error', message: `${t('common.fileUploader.fileExtensionNotSupport')} ${file.type}` }) return } const allowedFileTypes = fileConfig.allowed_file_types diff --git a/web/app/components/base/file-uploader/utils.spec.ts b/web/app/components/base/file-uploader/utils.spec.ts index c8cf9fbe74..4a3408ef00 100644 --- a/web/app/components/base/file-uploader/utils.spec.ts +++ b/web/app/components/base/file-uploader/utils.spec.ts @@ -22,7 +22,7 @@ import { FILE_EXTS } from '../prompt-editor/constants' jest.mock('mime', () => ({ __esModule: true, default: { - getExtension: jest.fn(), + getAllExtensions: jest.fn(), }, })) @@ -58,12 +58,27 @@ describe('file-uploader utils', () => { describe('getFileExtension', () => { it('should get extension from mimetype', () => { - jest.mocked(mime.getExtension).mockReturnValue('pdf') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf'])) expect(getFileExtension('file', 'application/pdf')).toBe('pdf') }) + it('should get extension from mimetype and file name 1', () => { + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf'])) + expect(getFileExtension('file.pdf', 'application/pdf')).toBe('pdf') + }) + + it('should get extension from mimetype with multiple ext candidates with filename hint', () => { + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['der', 'crt', 'pem'])) + expect(getFileExtension('file.pem', 'application/x-x509-ca-cert')).toBe('pem') + }) + + it('should get extension from mimetype with multiple ext candidates without filename hint', () => { + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['der', 'crt', 'pem'])) + expect(getFileExtension('file', 'application/x-x509-ca-cert')).toBe('der') + }) + it('should get extension from filename if mimetype fails', () => { - jest.mocked(mime.getExtension).mockReturnValue(null) + jest.mocked(mime.getAllExtensions).mockReturnValue(null) expect(getFileExtension('file.txt', '')).toBe('txt') expect(getFileExtension('file.txt.docx', '')).toBe('docx') expect(getFileExtension('file', '')).toBe('') @@ -76,157 +91,157 @@ describe('file-uploader utils', () => { describe('getFileAppearanceType', () => { it('should identify gif files', () => { - jest.mocked(mime.getExtension).mockReturnValue('gif') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['gif'])) expect(getFileAppearanceType('image.gif', 'image/gif')) .toBe(FileAppearanceTypeEnum.gif) }) it('should identify image files', () => { - jest.mocked(mime.getExtension).mockReturnValue('jpg') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['jpg'])) expect(getFileAppearanceType('image.jpg', 'image/jpeg')) .toBe(FileAppearanceTypeEnum.image) - jest.mocked(mime.getExtension).mockReturnValue('jpeg') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['jpeg'])) expect(getFileAppearanceType('image.jpeg', 'image/jpeg')) .toBe(FileAppearanceTypeEnum.image) - jest.mocked(mime.getExtension).mockReturnValue('png') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['png'])) expect(getFileAppearanceType('image.png', 'image/png')) .toBe(FileAppearanceTypeEnum.image) - jest.mocked(mime.getExtension).mockReturnValue('webp') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['webp'])) expect(getFileAppearanceType('image.webp', 'image/webp')) .toBe(FileAppearanceTypeEnum.image) - jest.mocked(mime.getExtension).mockReturnValue('svg') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['svg'])) expect(getFileAppearanceType('image.svg', 'image/svgxml')) .toBe(FileAppearanceTypeEnum.image) }) it('should identify video files', () => { - jest.mocked(mime.getExtension).mockReturnValue('mp4') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mp4'])) expect(getFileAppearanceType('video.mp4', 'video/mp4')) .toBe(FileAppearanceTypeEnum.video) - jest.mocked(mime.getExtension).mockReturnValue('mov') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mov'])) expect(getFileAppearanceType('video.mov', 'video/quicktime')) .toBe(FileAppearanceTypeEnum.video) - jest.mocked(mime.getExtension).mockReturnValue('mpeg') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mpeg'])) expect(getFileAppearanceType('video.mpeg', 'video/mpeg')) .toBe(FileAppearanceTypeEnum.video) - jest.mocked(mime.getExtension).mockReturnValue('webm') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['webm'])) expect(getFileAppearanceType('video.web', 'video/webm')) .toBe(FileAppearanceTypeEnum.video) }) it('should identify audio files', () => { - jest.mocked(mime.getExtension).mockReturnValue('mp3') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mp3'])) expect(getFileAppearanceType('audio.mp3', 'audio/mpeg')) .toBe(FileAppearanceTypeEnum.audio) - jest.mocked(mime.getExtension).mockReturnValue('m4a') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['m4a'])) expect(getFileAppearanceType('audio.m4a', 'audio/mp4')) .toBe(FileAppearanceTypeEnum.audio) - jest.mocked(mime.getExtension).mockReturnValue('wav') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['wav'])) expect(getFileAppearanceType('audio.wav', 'audio/vnd.wav')) .toBe(FileAppearanceTypeEnum.audio) - jest.mocked(mime.getExtension).mockReturnValue('amr') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['amr'])) expect(getFileAppearanceType('audio.amr', 'audio/AMR')) .toBe(FileAppearanceTypeEnum.audio) - jest.mocked(mime.getExtension).mockReturnValue('mpga') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mpga'])) expect(getFileAppearanceType('audio.mpga', 'audio/mpeg')) .toBe(FileAppearanceTypeEnum.audio) }) it('should identify code files', () => { - jest.mocked(mime.getExtension).mockReturnValue('html') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['html'])) expect(getFileAppearanceType('index.html', 'text/html')) .toBe(FileAppearanceTypeEnum.code) }) it('should identify PDF files', () => { - jest.mocked(mime.getExtension).mockReturnValue('pdf') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf'])) expect(getFileAppearanceType('doc.pdf', 'application/pdf')) .toBe(FileAppearanceTypeEnum.pdf) }) it('should identify markdown files', () => { - jest.mocked(mime.getExtension).mockReturnValue('md') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['md'])) expect(getFileAppearanceType('file.md', 'text/markdown')) .toBe(FileAppearanceTypeEnum.markdown) - jest.mocked(mime.getExtension).mockReturnValue('markdown') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['markdown'])) expect(getFileAppearanceType('file.markdown', 'text/markdown')) .toBe(FileAppearanceTypeEnum.markdown) - jest.mocked(mime.getExtension).mockReturnValue('mdx') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['mdx'])) expect(getFileAppearanceType('file.mdx', 'text/mdx')) .toBe(FileAppearanceTypeEnum.markdown) }) it('should identify excel files', () => { - jest.mocked(mime.getExtension).mockReturnValue('xlsx') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xlsx'])) expect(getFileAppearanceType('doc.xlsx', 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')) .toBe(FileAppearanceTypeEnum.excel) - jest.mocked(mime.getExtension).mockReturnValue('xls') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xls'])) expect(getFileAppearanceType('doc.xls', 'application/vnd.ms-excel')) .toBe(FileAppearanceTypeEnum.excel) }) it('should identify word files', () => { - jest.mocked(mime.getExtension).mockReturnValue('doc') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['doc'])) expect(getFileAppearanceType('doc.doc', 'application/msword')) .toBe(FileAppearanceTypeEnum.word) - jest.mocked(mime.getExtension).mockReturnValue('docx') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['docx'])) expect(getFileAppearanceType('doc.docx', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document')) .toBe(FileAppearanceTypeEnum.word) }) it('should identify word files', () => { - jest.mocked(mime.getExtension).mockReturnValue('ppt') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['ppt'])) expect(getFileAppearanceType('doc.ppt', 'application/vnd.ms-powerpoint')) .toBe(FileAppearanceTypeEnum.ppt) - jest.mocked(mime.getExtension).mockReturnValue('pptx') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pptx'])) expect(getFileAppearanceType('doc.pptx', 'application/vnd.openxmlformats-officedocument.presentationml.presentation')) .toBe(FileAppearanceTypeEnum.ppt) }) it('should identify document files', () => { - jest.mocked(mime.getExtension).mockReturnValue('txt') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['txt'])) expect(getFileAppearanceType('file.txt', 'text/plain')) .toBe(FileAppearanceTypeEnum.document) - jest.mocked(mime.getExtension).mockReturnValue('csv') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['csv'])) expect(getFileAppearanceType('file.csv', 'text/csv')) .toBe(FileAppearanceTypeEnum.document) - jest.mocked(mime.getExtension).mockReturnValue('msg') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['msg'])) expect(getFileAppearanceType('file.msg', 'application/vnd.ms-outlook')) .toBe(FileAppearanceTypeEnum.document) - jest.mocked(mime.getExtension).mockReturnValue('eml') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['eml'])) expect(getFileAppearanceType('file.eml', 'message/rfc822')) .toBe(FileAppearanceTypeEnum.document) - jest.mocked(mime.getExtension).mockReturnValue('xml') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['xml'])) expect(getFileAppearanceType('file.xml', 'application/rssxml')) .toBe(FileAppearanceTypeEnum.document) - jest.mocked(mime.getExtension).mockReturnValue('epub') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['epub'])) expect(getFileAppearanceType('file.epub', 'application/epubzip')) .toBe(FileAppearanceTypeEnum.document) }) it('should handle null mime extension', () => { - jest.mocked(mime.getExtension).mockReturnValue(null) + jest.mocked(mime.getAllExtensions).mockReturnValue(null) expect(getFileAppearanceType('file.txt', 'text/plain')) .toBe(FileAppearanceTypeEnum.document) }) @@ -360,7 +375,7 @@ describe('file-uploader utils', () => { describe('isAllowedFileExtension', () => { it('should validate allowed file extensions', () => { - jest.mocked(mime.getExtension).mockReturnValue('pdf') + jest.mocked(mime.getAllExtensions).mockReturnValue(new Set(['pdf'])) expect(isAllowedFileExtension( 'test.pdf', 'application/pdf', diff --git a/web/app/components/base/file-uploader/utils.ts b/web/app/components/base/file-uploader/utils.ts index e05c0b2087..9b5a449481 100644 --- a/web/app/components/base/file-uploader/utils.ts +++ b/web/app/components/base/file-uploader/utils.ts @@ -42,19 +42,38 @@ export const fileUpload: FileUpload = ({ }) } +const additionalExtensionMap = new Map([ + ['text/x-markdown', ['md']], +]) + export const getFileExtension = (fileName: string, fileMimetype: string, isRemote?: boolean) => { let extension = '' - if (fileMimetype) - extension = mime.getExtension(fileMimetype) || '' + let extensions = new Set() + if (fileMimetype) { + const extensionsFromMimeType = mime.getAllExtensions(fileMimetype) || new Set() + const additionalExtensions = additionalExtensionMap.get(fileMimetype) || [] + extensions = new Set([ + ...extensionsFromMimeType, + ...additionalExtensions, + ]) + } - if (fileName && !extension) { + let extensionInFileName = '' + if (fileName) { const fileNamePair = fileName.split('.') const fileNamePairLength = fileNamePair.length - if (fileNamePairLength > 1) - extension = fileNamePair[fileNamePairLength - 1] + if (fileNamePairLength > 1) { + extensionInFileName = fileNamePair[fileNamePairLength - 1].toLowerCase() + if (extensions.has(extensionInFileName)) + extension = extensionInFileName + } + } + if (!extension) { + if (extensions.size > 0) + extension = extensions.values().next().value.toLowerCase() else - extension = '' + extension = extensionInFileName } if (isRemote) diff --git a/web/app/components/base/markdown-blocks/button.tsx b/web/app/components/base/markdown-blocks/button.tsx index 81a3f30660..4646b12921 100644 --- a/web/app/components/base/markdown-blocks/button.tsx +++ b/web/app/components/base/markdown-blocks/button.tsx @@ -1,7 +1,7 @@ import { useChatContext } from '@/app/components/base/chat/chat/context' import Button from '@/app/components/base/button' import cn from '@/utils/classnames' - +import { isValidUrl } from './utils' const MarkdownButton = ({ node }: any) => { const { onSend } = useChatContext() const variant = node.properties.dataVariant @@ -9,25 +9,17 @@ const MarkdownButton = ({ node }: any) => { const link = node.properties.dataLink const size = node.properties.dataSize - function is_valid_url(url: string): boolean { - try { - const parsed_url = new URL(url) - return ['http:', 'https:'].includes(parsed_url.protocol) - } - catch { - return false - } - } - return {children} + + return {children || 'Download'} } } diff --git a/web/app/components/base/markdown-blocks/utils.ts b/web/app/components/base/markdown-blocks/utils.ts new file mode 100644 index 0000000000..4e9e98dbed --- /dev/null +++ b/web/app/components/base/markdown-blocks/utils.ts @@ -0,0 +1,3 @@ +export const isValidUrl = (url: string): boolean => { + return ['http:', 'https:', '//', 'mailto:'].some(prefix => url.startsWith(prefix)) +} diff --git a/web/app/components/base/markdown/index.tsx b/web/app/components/base/markdown/index.tsx index 0e0dc41cf2..1e50e6745b 100644 --- a/web/app/components/base/markdown/index.tsx +++ b/web/app/components/base/markdown/index.tsx @@ -7,7 +7,7 @@ import RemarkGfm from 'remark-gfm' import RehypeRaw from 'rehype-raw' import { flow } from 'lodash-es' import cn from '@/utils/classnames' -import { preprocessLaTeX, preprocessThinkTag } from './markdown-utils' +import { customUrlTransform, preprocessLaTeX, preprocessThinkTag } from './markdown-utils' import { AudioBlock, CodeBlock, @@ -65,6 +65,7 @@ export function Markdown(props: { content: string; className?: string; customDis } }, ]} + urlTransform={customUrlTransform} disallowedElements={['iframe', 'head', 'html', 'meta', 'link', 'style', 'body', ...(props.customDisallowedElements || [])]} components={{ code: CodeBlock, diff --git a/web/app/components/base/markdown/markdown-utils.ts b/web/app/components/base/markdown/markdown-utils.ts index d77b2ddccf..0aa385a1d1 100644 --- a/web/app/components/base/markdown/markdown-utils.ts +++ b/web/app/components/base/markdown/markdown-utils.ts @@ -36,3 +36,52 @@ export const preprocessThinkTag = (content: string) => { (str: string) => str.replace(/(<\/details>)(?![^\S\r\n]*[\r\n])(?![^\S\r\n]*$)/g, '$1\n'), ])(content) } + +/** + * Transforms a URI for use in react-markdown, ensuring security and compatibility. + * This function is designed to work with react-markdown v9+ which has stricter + * default URL handling. + * + * Behavior: + * 1. Always allows the custom 'abbr:' protocol. + * 2. Always allows page-local fragments (e.g., "#some-id"). + * 3. Always allows protocol-relative URLs (e.g., "//example.com/path"). + * 4. Always allows purely relative paths (e.g., "path/to/file", "/abs/path"). + * 5. Allows absolute URLs if their scheme is in a permitted list (case-insensitive): + * 'http:', 'https:', 'mailto:', 'xmpp:', 'irc:', 'ircs:'. + * 6. Intelligently distinguishes colons used for schemes from colons within + * paths, query parameters, or fragments of relative-like URLs. + * 7. Returns the original URI if allowed, otherwise returns `undefined` to + * signal that the URI should be removed/disallowed by react-markdown. + */ +export const customUrlTransform = (uri: string): string | undefined => { + const PERMITTED_SCHEME_REGEX = /^(https?|ircs?|mailto|xmpp|abbr):$/i + + if (uri.startsWith('#')) + return uri + + if (uri.startsWith('//')) + return uri + + const colonIndex = uri.indexOf(':') + + if (colonIndex === -1) + return uri + + const slashIndex = uri.indexOf('/') + const questionMarkIndex = uri.indexOf('?') + const hashIndex = uri.indexOf('#') + + if ( + (slashIndex !== -1 && colonIndex > slashIndex) + || (questionMarkIndex !== -1 && colonIndex > questionMarkIndex) + || (hashIndex !== -1 && colonIndex > hashIndex) + ) + return uri + + const scheme = uri.substring(0, colonIndex + 1).toLowerCase() + if (PERMITTED_SCHEME_REGEX.test(scheme)) + return uri + + return undefined +} diff --git a/web/app/components/base/mermaid/index.tsx b/web/app/components/base/mermaid/index.tsx index a0332ce819..31eaffb813 100644 --- a/web/app/components/base/mermaid/index.tsx +++ b/web/app/components/base/mermaid/index.tsx @@ -487,15 +487,15 @@ const Flowchart = React.forwardRef((props: { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', { + mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - errorMessage: cn('py-4 px-[26px]', { + errorMessage: cn('px-[26px] py-4', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), - errorIcon: cn('w-6 h-6', { + errorIcon: cn('h-6 w-6', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), @@ -503,7 +503,7 @@ const Flowchart = React.forwardRef((props: { 'text-gray-700': currentTheme === Theme.light, 'text-gray-300': currentTheme === Theme.dark, }), - themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', { + themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', { 'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light, 'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark, }), @@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: { // Style classes for look options const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => { return cn( - 'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary', + 'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary', look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary', currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300', look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white', @@ -523,7 +523,7 @@ const Flowchart = React.forwardRef((props: {
} className={themeClasses.container}>
-