diff --git a/api/.env.example b/api/.env.example
index ae7e82c779..7878308588 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -491,3 +491,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
+
+# Dataset queue monitor configuration
+QUEUE_MONITOR_THRESHOLD=200
+# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitor interval in minutes, default is 30 minutes
+QUEUE_MONITOR_INTERVAL=30
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 1b015b3267..2dcf1710b0 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -2,7 +2,7 @@ import os
from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus
-from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
+from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
@@ -256,6 +256,25 @@ class InternalTestConfig(BaseSettings):
)
+class DatasetQueueMonitorConfig(BaseSettings):
+ """
+ Configuration settings for Dataset Queue Monitor
+ """
+
+ QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
+ description="Threshold for dataset queue monitor",
+ default=200,
+ )
+ QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
+ description="Emails for dataset queue monitor alert, separated by commas",
+ default=None,
+ )
+ QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
+ description="Interval for dataset queue monitor in minutes",
+ default=30,
+ )
+
+
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@@ -303,5 +322,6 @@ class MiddlewareConfig(
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
+ DatasetQueueMonitorConfig,
):
pass
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index c100f53078..27e8dd3fa6 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -369,6 +369,7 @@ class DatasetTagsApi(DatasetApiResource):
)
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
args = parser.parse_args()
+ args["type"] = "knowledge"
tag = TagService.update_tags(args, args.get("tag_id"))
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index 418363ffbb..ab7ab4dcf0 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -175,8 +175,11 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
- if not dataset.indexing_technique and not args.get("indexing_technique"):
+
+ indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
+ if not indexing_technique:
raise ValueError("indexing_technique is required.")
+ args["indexing_technique"] = indexing_technique
# save file info
file = request.files["file"]
@@ -206,12 +209,16 @@ class DocumentAddByFileApi(DatasetApiResource):
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
+ dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
+ if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
+ raise ValueError("process_rule is required.")
+
try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
- dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
+ dataset_process_rule=dataset_process_rule,
created_from="api",
)
except ProviderTokenNotInitError as ex:
diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py
index 5017835565..e1c021a44a 100644
--- a/api/core/entities/model_entities.py
+++ b/api/core/entities/model_entities.py
@@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
+ def raise_for_status(self) -> None:
+ """
+ Check model status and raise ValueError if not active.
+
+ :raises ValueError: When model status is not active, with a descriptive message
+ """
+ if self.status == ModelStatus.ACTIVE:
+ return
+
+ error_messages = {
+ ModelStatus.NO_CONFIGURE: "Model is not configured",
+ ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
+ ModelStatus.NO_PERMISSION: "No permission to use this model",
+ ModelStatus.DISABLED: "Model is disabled",
+ }
+
+ if self.status in error_messages:
+ raise ValueError(error_messages[self.status])
+
class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 231743bf2a..06fdb089d4 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -41,45 +41,53 @@ class Extensible:
extensions = []
position_map: dict[str, int] = {}
- # get the path of the current class
- current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
- current_dir_path = os.path.dirname(current_path)
-
- # traverse subdirectories
- for subdir_name in os.listdir(current_dir_path):
- if subdir_name.startswith("__"):
- continue
-
- subdir_path = os.path.join(current_dir_path, subdir_name)
- extension_name = subdir_name
- if os.path.isdir(subdir_path):
+ # Get the package name from the module path
+ package_name = ".".join(cls.__module__.split(".")[:-1])
+
+ try:
+ # Get package directory path
+ package_spec = importlib.util.find_spec(package_name)
+ if not package_spec or not package_spec.origin:
+ raise ImportError(f"Could not find package {package_name}")
+
+ package_dir = os.path.dirname(package_spec.origin)
+
+ # Traverse subdirectories
+ for subdir_name in os.listdir(package_dir):
+ if subdir_name.startswith("__"):
+ continue
+
+ subdir_path = os.path.join(package_dir, subdir_name)
+ if not os.path.isdir(subdir_path):
+ continue
+
+ extension_name = subdir_name
file_names = os.listdir(subdir_path)
- # is builtin extension, builtin extension
- # in the front-end page and business logic, there are special treatments.
+ # Check for extension module file
+ if (extension_name + ".py") not in file_names:
+ logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+ continue
+
+ # Check for builtin flag and position
builtin = False
- # default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True
-
builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position
- if (extension_name + ".py") not in file_names:
- logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
- continue
-
- # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
- py_path = os.path.join(subdir_path, extension_name + ".py")
- spec = importlib.util.spec_from_file_location(extension_name, py_path)
+ # Import the extension module
+ module_name = f"{package_name}.{extension_name}.{extension_name}"
+ spec = importlib.util.find_spec(module_name)
if not spec or not spec.loader:
- raise Exception(f"Failed to load module {extension_name} from {py_path}")
+ raise ImportError(f"Failed to load module {module_name}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
+ # Find extension class
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -87,21 +95,21 @@ class Extensible:
break
if not extension_class:
- logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
+ logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
continue
+ # Load schema if not builtin
json_data: dict[str, Any] = {}
if not builtin:
- if "schema.json" not in file_names:
+ json_path = os.path.join(subdir_path, "schema.json")
+ if not os.path.exists(json_path):
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
- json_path = os.path.join(subdir_path, "schema.json")
- json_data = {}
- if os.path.exists(json_path):
- with open(json_path, encoding="utf-8") as f:
- json_data = json.load(f)
+ with open(json_path, encoding="utf-8") as f:
+ json_data = json.load(f)
+ # Create extension
extensions.append(
ModuleExtension(
extension_class=extension_class,
@@ -113,6 +121,11 @@ class Extensible:
)
)
+ except Exception as e:
+ logging.exception("Error scanning extensions")
+ raise
+
+ # Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 373ef2bbe2..568149cc37 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -160,6 +160,10 @@ class ProviderModel(BaseModel):
deprecated: bool = False
model_config = ConfigDict(protected_namespaces=())
+ @property
+ def support_structure_output(self) -> bool:
+ return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features
+
class ParameterRule(BaseModel):
"""
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 7570200175..488a394679 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -3,7 +3,9 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
+from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
@@ -393,19 +395,13 @@ class ProviderManager:
@staticmethod
def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]:
- """
- Get all provider records of the workspace.
-
- :param tenant_id: workspace id
- :return:
- """
- providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
-
provider_name_to_provider_records_dict = defaultdict(list)
- for provider in providers:
- # TODO: Use provider name with prefix after the data migration
- provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
-
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True)
+ providers = session.scalars(stmt)
+ for provider in providers:
+ # Use provider name with prefix after the data migration
+ provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider)
return provider_name_to_provider_records_dict
@staticmethod
@@ -416,17 +412,12 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- # Get all provider model records of the workspace
- provider_models = (
- db.session.query(ProviderModel)
- .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
- .all()
- )
-
provider_name_to_provider_model_records_dict = defaultdict(list)
- for provider_model in provider_models:
- provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
-
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
+ provider_models = session.scalars(stmt)
+ for provider_model in provider_models:
+ provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model)
return provider_name_to_provider_model_records_dict
@staticmethod
@@ -437,17 +428,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- preferred_provider_types = (
- db.session.query(TenantPreferredModelProvider)
- .filter(TenantPreferredModelProvider.tenant_id == tenant_id)
- .all()
- )
-
- provider_name_to_preferred_provider_type_records_dict = {
- preferred_provider_type.provider_name: preferred_provider_type
- for preferred_provider_type in preferred_provider_types
- }
-
+ provider_name_to_preferred_provider_type_records_dict = {}
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id)
+ preferred_provider_types = session.scalars(stmt)
+ provider_name_to_preferred_provider_type_records_dict = {
+ preferred_provider_type.provider_name: preferred_provider_type
+ for preferred_provider_type in preferred_provider_types
+ }
return provider_name_to_preferred_provider_type_records_dict
@staticmethod
@@ -458,18 +446,14 @@ class ProviderManager:
:param tenant_id: workspace id
:return:
"""
- provider_model_settings = (
- db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all()
- )
-
provider_name_to_provider_model_settings_dict = defaultdict(list)
- for provider_model_setting in provider_model_settings:
- (
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id)
+ provider_model_settings = session.scalars(stmt)
+ for provider_model_setting in provider_model_settings:
provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append(
provider_model_setting
)
- )
-
return provider_name_to_provider_model_settings_dict
@staticmethod
@@ -492,15 +476,14 @@ class ProviderManager:
if not model_load_balancing_enabled:
return {}
- provider_load_balancing_configs = (
- db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all()
- )
-
provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list)
- for provider_load_balancing_config in provider_load_balancing_configs:
- provider_name_to_provider_load_balancing_model_configs_dict[
- provider_load_balancing_config.provider_name
- ].append(provider_load_balancing_config)
+ with Session(db.engine, expire_on_commit=False) as session:
+ stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id)
+ provider_load_balancing_configs = session.scalars(stmt)
+ for provider_load_balancing_config in provider_load_balancing_configs:
+ provider_name_to_provider_load_balancing_model_configs_dict[
+ provider_load_balancing_config.provider_name
+ ].append(provider_load_balancing_config)
return provider_name_to_provider_load_balancing_model_configs_dict
@@ -626,10 +609,9 @@ class ProviderManager:
if not cached_provider_credentials:
try:
# fix origin data
- if (
- custom_provider_record.encrypted_config
- and not custom_provider_record.encrypted_config.startswith("{")
- ):
+ if custom_provider_record.encrypted_config is None:
+ raise ValueError("No credentials found")
+ if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -733,7 +715,7 @@ class ProviderManager:
return SystemConfiguration(enabled=False)
# Convert provider_records to dict
- quota_type_to_provider_records_dict = {}
+ quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {}
for provider_record in provider_records:
if provider_record.provider_type != ProviderType.SYSTEM.value:
continue
@@ -758,6 +740,11 @@ class ProviderManager:
else:
provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type]
+ if provider_record.quota_used is None:
+ raise ValueError("quota_used is None")
+ if provider_record.quota_limit is None:
+ raise ValueError("quota_limit is None")
+
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
@@ -791,10 +778,9 @@ class ProviderManager:
cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials:
- try:
- provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
- except JSONDecodeError:
- provider_credentials = {}
+ provider_credentials: dict[str, Any] = {}
+ if provider_records and provider_records[0].encrypted_config:
+ provider_credentials = json.loads(provider_records[0].encrypted_config)
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py
index 9abe78d6ef..54b65d9a2d 100644
--- a/api/core/rag/datasource/keyword/jieba/stopwords.py
+++ b/api/core/rag/datasource/keyword/jieba/stopwords.py
@@ -720,7 +720,7 @@ STOPWORDS = {
"〉",
"〈",
"…",
- " ",
+ " ",
"0",
"1",
"2",
@@ -731,16 +731,6 @@ STOPWORDS = {
"7",
"8",
"9",
- "0",
- "1",
- "2",
- "3",
- "4",
- "5",
- "6",
- "7",
- "8",
- "9",
"二",
"三",
"四",
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 0a3738ac93..6b9dd9c561 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
- if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
+ if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
current_entity += word
else:
if current_entity:
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 486b4b01af..36d0688807 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData):
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: dict | None = None
- structured_output_enabled: bool = False
+ # We used 'structured_output_enabled' in the past, but it's not a good name.
+ structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
@field_validator("prompt_config", mode="before")
@classmethod
@@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData):
if v is None:
return PromptConfig()
return v
+
+ @property
+ def structured_output_enabled(self) -> bool:
+ return self.structured_output_switch_on and self.structured_output is not None
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index df8f614db3..ee181cf3bf 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -12,9 +12,7 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit
-from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -74,7 +72,6 @@ from core.workflow.nodes.event import (
from core.workflow.utils.structured_output.entities import (
ResponseFormat,
SpecialModelType,
- SupportStructuredOutputStatus,
)
from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -277,7 +274,7 @@ class LLMNode(BaseNode[LLMNodeData]):
llm_usage=usage,
)
)
- except LLMNodeError as e:
+ except ValueError as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -527,65 +524,53 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
- model_name = node_data_model.name
- provider_name = node_data_model.provider
+ if not node_data_model.mode:
+ raise LLMModeRequiredError("LLM mode is required.")
- model_manager = ModelManager()
- model_instance = model_manager.get_model_instance(
- tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
+ model = ModelManager().get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=node_data_model.provider,
+ model=node_data_model.name,
)
- provider_model_bundle = model_instance.provider_model_bundle
- model_type_instance = model_instance.model_type_instance
- model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
- model_credentials = model_instance.credentials
+ model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
# check model
- provider_model = provider_model_bundle.configuration.get_provider_model(
- model=model_name, model_type=ModelType.LLM
+ provider_model = model.provider_model_bundle.configuration.get_provider_model(
+ model=node_data_model.name, model_type=ModelType.LLM
)
if provider_model is None:
- raise ModelNotExistError(f"Model {model_name} not exist.")
-
- if provider_model.status == ModelStatus.NO_CONFIGURE:
- raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
- elif provider_model.status == ModelStatus.NO_PERMISSION:
- raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
- elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
- raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+ raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+ provider_model.raise_for_status()
# model config
- completion_params = node_data_model.completion_params
- stop = []
- if "stop" in completion_params:
- stop = completion_params["stop"]
- del completion_params["stop"]
-
- # get model mode
- model_mode = node_data_model.mode
- if not model_mode:
- raise LLMModeRequiredError("LLM mode is required.")
-
- model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
+ stop: list[str] = []
+ if "stop" in node_data_model.completion_params:
+ stop = node_data_model.completion_params.pop("stop")
+ model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
if not model_schema:
- raise ModelNotExistError(f"Model {model_name} not exist.")
- support_structured_output = self._check_model_structured_output_support()
- if support_structured_output == SupportStructuredOutputStatus.SUPPORTED:
- completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
- elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
- # Set appropriate response format based on model capabilities
- self._set_response_format(completion_params, model_schema.parameter_rules)
- return model_instance, ModelConfigWithCredentialsEntity(
- provider=provider_name,
- model=model_name,
+ raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+
+ if self.node_data.structured_output_enabled:
+ if model_schema.support_structure_output:
+ node_data_model.completion_params = self._handle_native_json_schema(
+ node_data_model.completion_params, model_schema.parameter_rules
+ )
+ else:
+ # Set appropriate response format based on model capabilities
+ self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
+
+ return model, ModelConfigWithCredentialsEntity(
+ provider=node_data_model.provider,
+ model=node_data_model.name,
model_schema=model_schema,
- mode=model_mode,
- provider_model_bundle=provider_model_bundle,
- credentials=model_credentials,
- parameters=completion_params,
+ mode=node_data_model.mode,
+ provider_model_bundle=model.provider_model_bundle,
+ credentials=model.credentials,
+ parameters=node_data_model.completion_params,
stop=stop,
)
@@ -786,13 +771,25 @@ class LLMNode(BaseNode[LLMNodeData]):
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
- support_structured_output = self._check_model_structured_output_support()
- if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED:
- filtered_prompt_messages = self._handle_prompt_based_schema(
- prompt_messages=filtered_prompt_messages,
- )
- stop = model_config.stop
- return filtered_prompt_messages, stop
+
+ model = ModelManager().get_model_instance(
+ tenant_id=self.tenant_id,
+ model_type=ModelType.LLM,
+ provider=self.node_data.model.provider,
+ model=self.node_data.model.name,
+ )
+ model_schema = model.model_type_instance.get_model_schema(
+ model=self.node_data.model.name,
+ credentials=model.credentials,
+ )
+ if not model_schema:
+ raise ModelNotExistError(f"Model {self.node_data.model.name} not exist.")
+ if self.node_data.structured_output_enabled:
+ if not model_schema.support_structure_output:
+ filtered_prompt_messages = self._handle_prompt_based_schema(
+ prompt_messages=filtered_prompt_messages,
+ )
+ return filtered_prompt_messages, model_config.stop
def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
structured_output: dict[str, Any] = {}
@@ -903,7 +900,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_mapping["#context#"] = node_data.context.variable_selector
if node_data.vision.enabled:
- variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value]
+ variable_mapping["#files#"] = node_data.vision.configs.variable_selector
if node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
@@ -1185,32 +1182,6 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
- def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus:
- """
- Check if the current model supports structured output.
-
- Returns:
- SupportStructuredOutput: The support status of structured output
- """
- # Early return if structured output is disabled
- if (
- not isinstance(self.node_data, LLMNodeData)
- or not self.node_data.structured_output_enabled
- or not self.node_data.structured_output
- ):
- return SupportStructuredOutputStatus.DISABLED
- # Get model schema and check if it exists
- model_schema = self._fetch_model_schema(self.node_data.model.provider)
- if not model_schema:
- return SupportStructuredOutputStatus.DISABLED
-
- # Check if model supports structured output feature
- return (
- SupportStructuredOutputStatus.SUPPORTED
- if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features)
- else SupportStructuredOutputStatus.UNSUPPORTED
- )
-
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,
diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py
index 7954acbaee..6491042bfe 100644
--- a/api/core/workflow/utils/structured_output/entities.py
+++ b/api/core/workflow/utils/structured_output/entities.py
@@ -14,11 +14,3 @@ class SpecialModelType(StrEnum):
GEMINI = "gemini"
OLLAMA = "ollama"
-
-
-class SupportStructuredOutputStatus(StrEnum):
- """Constants for structured output support status"""
-
- SUPPORTED = "supported"
- UNSUPPORTED = "unsupported"
- DISABLED = "disabled"
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 26bd6b3577..a837552007 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery:
"schedule.update_tidb_serverless_status_task",
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
+ "schedule.queue_monitor_task",
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
@@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
+ "datasets-queue-monitor": {
+ "task": "schedule.queue_monitor_task.queue_monitor_task",
+ "schedule": timedelta(
+ minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
+ ),
+ },
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
diff --git a/api/libs/helper.py b/api/libs/helper.py
index afc8f31681..463ba3308b 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -18,6 +18,7 @@ from flask_restful import fields
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.file import helpers as file_helpers
+from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
@@ -196,7 +197,7 @@ def generate_text_hash(text: str) -> str:
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
- return Response(response=json.dumps(response), status=200, mimetype="application/json")
+ return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json")
else:
def generate() -> Generator:
diff --git a/api/models/provider.py b/api/models/provider.py
index 497cbefc61..1e25f0c90f 100644
--- a/api/models/provider.py
+++ b/api/models/provider.py
@@ -1,6 +1,9 @@
+from datetime import datetime
from enum import Enum
+from typing import Optional
-from sqlalchemy import func
+from sqlalchemy import func, text
+from sqlalchemy.orm import Mapped, mapped_column
from .base import Base
from .engine import db
@@ -51,20 +54,24 @@ class Provider(Base):
),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying"))
- encrypted_config = db.Column(db.Text, nullable=True)
- is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- last_used = db.Column(db.DateTime, nullable=True)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ provider_type: Mapped[str] = mapped_column(
+ db.String(40), nullable=False, server_default=text("'custom'::character varying")
+ )
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
- quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying"))
- quota_limit = db.Column(db.BigInteger, nullable=True)
- quota_used = db.Column(db.BigInteger, default=0)
+ quota_type: Mapped[Optional[str]] = mapped_column(
+ db.String(40), nullable=True, server_default=text("''::character varying")
+ )
+ quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True)
+ quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self):
return (
@@ -104,15 +111,15 @@ class ProviderModel(Base):
),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- encrypted_config = db.Column(db.Text, nullable=True)
- is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(Base):
@@ -122,13 +129,13 @@ class TenantDefaultModel(Base):
db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(Base):
@@ -138,12 +145,12 @@ class TenantPreferredModelProvider(Base):
db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- preferred_provider_type = db.Column(db.String(40), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(Base):
@@ -153,22 +160,24 @@ class ProviderOrder(Base):
db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- payment_product_id = db.Column(db.String(191), nullable=False)
- payment_id = db.Column(db.String(191))
- transaction_id = db.Column(db.String(191))
- quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1"))
- currency = db.Column(db.String(40))
- total_amount = db.Column(db.Integer)
- payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying"))
- paid_at = db.Column(db.DateTime)
- pay_failed_at = db.Column(db.DateTime)
- refunded_at = db.Column(db.DateTime)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False)
+ payment_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+ transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191))
+ quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1"))
+ currency: Mapped[Optional[str]] = mapped_column(db.String(40))
+ total_amount: Mapped[Optional[int]] = mapped_column(db.Integer)
+ payment_status: Mapped[str] = mapped_column(
+ db.String(40), nullable=False, server_default=text("'wait_pay'::character varying")
+ )
+ paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(Base):
@@ -182,15 +191,15 @@ class ProviderModelSetting(Base):
db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+ load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(Base):
@@ -204,13 +213,13 @@ class LoadBalancingModelConfig(Base):
db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- model_type = db.Column(db.String(40), nullable=False)
- name = db.Column(db.String(255), nullable=False)
- encrypted_config = db.Column(db.Text, nullable=True)
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ model_type: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true"))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
new file mode 100644
index 0000000000..e3a7021b9d
--- /dev/null
+++ b/api/schedule/queue_monitor_task.py
@@ -0,0 +1,62 @@
+import logging
+from datetime import datetime
+from urllib.parse import urlparse
+
+import click
+from flask import render_template
+from redis import Redis
+
+import app
+from configs import dify_config
+from extensions.ext_database import db
+from extensions.ext_mail import mail
+
+# Create a dedicated Redis connection (using the same configuration as Celery)
+celery_broker_url = dify_config.CELERY_BROKER_URL
+
+parsed = urlparse(celery_broker_url)
+host = parsed.hostname or "localhost"
+port = parsed.port or 6379
+password = parsed.password or None
+redis_db = parsed.path.strip("/") or "1" # type: ignore
+
+celery_redis = Redis(host=host, port=port, password=password, db=redis_db)
+
+
+@app.celery.task(queue="monitor")
+def queue_monitor_task():
+ queue_name = "dataset"
+ threshold = dify_config.QUEUE_MONITOR_THRESHOLD
+
+ try:
+ queue_length = celery_redis.llen(f"{queue_name}")
+ logging.info(click.style(f"Start monitor {queue_name}", fg="green"))
+ logging.info(click.style(f"Queue length: {queue_length}", fg="green"))
+
+ if queue_length >= threshold:
+ warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}"
+ logging.warning(click.style(warning_msg, fg="red"))
+ alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
+ if alter_emails:
+ to_list = alter_emails.split(",")
+ for to in to_list:
+ try:
+ current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ html_content = render_template(
+ "queue_monitor_alert_email_template_en-US.html",
+ queue_name=queue_name,
+ queue_length=queue_length,
+ threshold=threshold,
+ alert_time=current_time,
+ )
+ mail.send(
+ to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
+ )
+ except Exception as e:
+ logging.exception(click.style("Exception occurred during sending email", fg="red"))
+
+ except Exception as e:
+ logging.exception(click.style("Exception occurred during queue monitoring", fg="red"))
+ finally:
+ if db.session.is_active:
+ db.session.close()
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index be748e8dd1..74c6150b44 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -46,6 +46,8 @@ class TagService:
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
+ if not tag_type or not tag_name:
+ return []
tags = (
db.session.query(Tag)
.filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
@@ -88,7 +90,7 @@ class TagService:
@staticmethod
def update_tags(args: dict, tag_id: str) -> Tag:
- if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, args["name"]):
+ if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
if not tag:
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index f32bc4f187..51b6343fdc 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -5,7 +5,7 @@ import uuid
import click
from celery import shared_task # type: ignore
-from sqlalchemy import func, select
+from sqlalchemy import func
from sqlalchemy.orm import Session
from core.model_manager import ModelManager
@@ -68,11 +68,6 @@ def batch_create_segment_to_index_task(
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
- word_count_change = 0
- segments_to_insert: list[str] = []
- max_position_stmt = select(func.max(DocumentSegment.position)).where(
- DocumentSegment.document_id == dataset_document.id
- )
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
diff --git a/api/templates/queue_monitor_alert_email_template_en-US.html b/api/templates/queue_monitor_alert_email_template_en-US.html
new file mode 100644
index 0000000000..2885210864
--- /dev/null
+++ b/api/templates/queue_monitor_alert_email_template_en-US.html
@@ -0,0 +1,129 @@
+
+
+
+
+
+
+
+
+
+
+
Queue Monitoring Alert
+
Our system has detected an abnormal queue status that requires your attention:
+
+
+
Queue Task Alert
+
+ Queue "{{queue_name}}" has {{queue_length}} pending tasks (Threshold: {{threshold}})
+
+
+
+
+
Recommended actions:
+
1. Check the queue processing status in the system dashboard
+
2. Verify if there are any processing bottlenecks
+
3. Consider scaling up workers if needed
+
+
+
Additional Information:
+
+ - Alert triggered at: {{alert_time}}
+
+
+
+
+
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 5fbee266bd..6aa48b1cbb 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -3,11 +3,16 @@ import os
import time
import uuid
from collections.abc import Generator
-from unittest.mock import MagicMock
+from decimal import Decimal
+from unittest.mock import MagicMock, patch
import pytest
+from app_factory import create_app
+from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
@@ -19,13 +24,27 @@ from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
-from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+@pytest.fixture(scope="session")
+def app():
+ # Set up storage configuration
+ os.environ["STORAGE_TYPE"] = "opendal"
+ os.environ["OPENDAL_SCHEME"] = "fs"
+ os.environ["OPENDAL_FS_ROOT"] = "storage"
+
+ # Ensure storage directory exists
+ os.makedirs("storage", exist_ok=True)
+
+ app = create_app()
+ dify_config.LOGIN_DISABLED = True
+ return app
+
+
def init_llm_node(config: dict) -> LLMNode:
graph_config = {
"edges": [
@@ -40,13 +59,19 @@ def init_llm_node(config: dict) -> LLMNode:
graph = Graph.init(graph_config=graph_config)
+ # Use proper UUIDs for database compatibility
+ tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+ app_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056c"
+ workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d"
+ user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e"
+
init_params = GraphInitParams(
- tenant_id="1",
- app_id="1",
+ tenant_id=tenant_id,
+ app_id=app_id,
workflow_type=WorkflowType.WORKFLOW,
- workflow_id="1",
+ workflow_id=workflow_id,
graph_config=graph_config,
- user_id="1",
+ user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
@@ -77,115 +102,197 @@ def init_llm_node(config: dict) -> LLMNode:
return node
-def test_execute_llm(setup_model_mock):
- node = init_llm_node(
- config={
- "id": "llm",
- "data": {
- "title": "123",
- "type": "llm",
- "model": {
- "provider": "langgenius/openai/openai",
- "name": "gpt-3.5-turbo",
- "mode": "chat",
- "completion_params": {},
+def test_execute_llm(app):
+ with app.app_context():
+ node = init_llm_node(
+ config={
+ "id": "llm",
+ "data": {
+ "title": "123",
+ "type": "llm",
+ "model": {
+ "provider": "langgenius/openai/openai",
+ "name": "gpt-3.5-turbo",
+ "mode": "chat",
+ "completion_params": {},
+ },
+ "prompt_template": [
+ {
+ "role": "system",
+ "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}.",
+ },
+ {"role": "user", "text": "{{#sys.query#}}"},
+ ],
+ "memory": None,
+ "context": {"enabled": False},
+ "vision": {"enabled": False},
},
- "prompt_template": [
- {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
- {"role": "user", "text": "{{#sys.query#}}"},
- ],
- "memory": None,
- "context": {"enabled": False},
- "vision": {"enabled": False},
},
- },
- )
+ )
- credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
+ credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
- # Mock db.session.close()
- db.session.close = MagicMock()
+ # Create a proper LLM result with real entities
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal("1000"),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal("1000"),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
- node._fetch_model_config = get_mocked_fetch_model_config(
- provider="langgenius/openai/openai",
- model="gpt-3.5-turbo",
- mode="chat",
- credentials=credentials,
- )
+ mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
+
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+
+ # Create a simple mock model instance that doesn't call real providers
+ mock_model_instance = MagicMock()
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create a simple mock model config with required attributes
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "langgenius/openai/openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+ # Mock the _fetch_model_config method
+ def mock_fetch_model_config_func(_node_data_model):
+ return mock_model_instance, mock_model_config
+
+ # Also mock ModelManager.get_model_instance to avoid database calls
+ def mock_get_model_instance(_self, **kwargs):
+ return mock_model_instance
- # execute node
- result = node._run()
- assert isinstance(result, Generator)
+ with (
+ patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+ patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ ):
+ # execute node
+ result = node._run()
+ assert isinstance(result, Generator)
- for item in result:
- if isinstance(item, RunCompletedEvent):
- assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
- assert item.run_result.process_data is not None
- assert item.run_result.outputs is not None
- assert item.run_result.outputs.get("text") is not None
- assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
+ for item in result:
+ if isinstance(item, RunCompletedEvent):
+ assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert item.run_result.process_data is not None
+ assert item.run_result.outputs is not None
+ assert item.run_result.outputs.get("text") is not None
+ assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
-def test_execute_llm_with_jinja2(setup_code_executor_mock, setup_model_mock):
+def test_execute_llm_with_jinja2(app, setup_code_executor_mock):
"""
Test execute LLM node with jinja2
"""
- node = init_llm_node(
- config={
- "id": "llm",
- "data": {
- "title": "123",
- "type": "llm",
- "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
- "prompt_config": {
- "jinja2_variables": [
- {"variable": "sys_query", "value_selector": ["sys", "query"]},
- {"variable": "output", "value_selector": ["abc", "output"]},
- ]
- },
- "prompt_template": [
- {
- "role": "system",
- "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
- "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
- "edition_type": "jinja2",
+ with app.app_context():
+ node = init_llm_node(
+ config={
+ "id": "llm",
+ "data": {
+ "title": "123",
+ "type": "llm",
+ "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
+ "prompt_config": {
+ "jinja2_variables": [
+ {"variable": "sys_query", "value_selector": ["sys", "query"]},
+ {"variable": "output", "value_selector": ["abc", "output"]},
+ ]
},
- {
- "role": "user",
- "text": "{{#sys.query#}}",
- "jinja2_text": "{{sys_query}}",
- "edition_type": "basic",
- },
- ],
- "memory": None,
- "context": {"enabled": False},
- "vision": {"enabled": False},
+ "prompt_template": [
+ {
+ "role": "system",
+ "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}",
+ "jinja2_text": "you are a helpful assistant.\ntoday's weather is {{output}}.",
+ "edition_type": "jinja2",
+ },
+ {
+ "role": "user",
+ "text": "{{#sys.query#}}",
+ "jinja2_text": "{{sys_query}}",
+ "edition_type": "basic",
+ },
+ ],
+ "memory": None,
+ "context": {"enabled": False},
+ "vision": {"enabled": False},
+ },
},
- },
- )
+ )
- credentials = {"openai_api_key": os.environ.get("OPENAI_API_KEY")}
+ # Mock db.session.close()
+ db.session.close = MagicMock()
- # Mock db.session.close()
- db.session.close = MagicMock()
+ # Create a proper LLM result with real entities
+ mock_usage = LLMUsage(
+ prompt_tokens=30,
+ prompt_unit_price=Decimal("0.001"),
+ prompt_price_unit=Decimal("1000"),
+ prompt_price=Decimal("0.00003"),
+ completion_tokens=20,
+ completion_unit_price=Decimal("0.002"),
+ completion_price_unit=Decimal("1000"),
+ completion_price=Decimal("0.00004"),
+ total_tokens=50,
+ total_price=Decimal("0.00007"),
+ currency="USD",
+ latency=0.5,
+ )
- node._fetch_model_config = get_mocked_fetch_model_config(
- provider="langgenius/openai/openai",
- model="gpt-3.5-turbo",
- mode="chat",
- credentials=credentials,
- )
+ mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
+
+ mock_llm_result = LLMResult(
+ model="gpt-3.5-turbo",
+ prompt_messages=[],
+ message=mock_message,
+ usage=mock_usage,
+ )
+
+ # Create a simple mock model instance that doesn't call real providers
+ mock_model_instance = MagicMock()
+ mock_model_instance.invoke_llm.return_value = mock_llm_result
+
+ # Create a simple mock model config with required attributes
+ mock_model_config = MagicMock()
+ mock_model_config.mode = "chat"
+ mock_model_config.provider = "openai"
+ mock_model_config.model = "gpt-3.5-turbo"
+ mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
+
+ # Mock the _fetch_model_config method
+ def mock_fetch_model_config_func(_node_data_model):
+ return mock_model_instance, mock_model_config
+
+ # Also mock ModelManager.get_model_instance to avoid database calls
+ def mock_get_model_instance(_self, **kwargs):
+ return mock_model_instance
- # execute node
- result = node._run()
+ with (
+ patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
+ patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
+ ):
+ # execute node
+ result = node._run()
- for item in result:
- if isinstance(item, RunCompletedEvent):
- assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
- assert item.run_result.process_data is not None
- assert "sunny" in json.dumps(item.run_result.process_data)
- assert "what's the weather today?" in json.dumps(item.run_result.process_data)
+ for item in result:
+ if isinstance(item, RunCompletedEvent):
+ assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert item.run_result.process_data is not None
+ assert "sunny" in json.dumps(item.run_result.process_data)
+ assert "what's the weather today?" in json.dumps(item.run_result.process_data)
def test_extract_json():
diff --git a/docker/.env.example b/docker/.env.example
index ac9536be03..4cf5e202d0 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -1057,7 +1057,7 @@ PLUGIN_MAX_EXECUTION_TIMEOUT=600
PIP_MIRROR_URL=
# https://github.com/langgenius/dify-plugin-daemon/blob/main/.env.example
-# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss
+# Plugin storage type, local aws_s3 tencent_cos azure_blob aliyun_oss volcengine_tos
PLUGIN_STORAGE_TYPE=local
PLUGIN_STORAGE_LOCAL_ROOT=/app/storage
PLUGIN_WORKING_PATH=/app/storage/cwd
@@ -1087,6 +1087,11 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
+# Plugin oss volcengine tos
+PLUGIN_VOLCENGINE_TOS_ENDPOINT=
+PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
+PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
+PLUGIN_VOLCENGINE_TOS_REGION=
# ------------------------------
# OTLP Collector Configuration
@@ -1106,3 +1111,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
+
+# Dataset queue monitor configuration
+QUEUE_MONITOR_THRESHOLD=200
+# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
+QUEUE_MONITOR_ALERT_EMAILS=
+# Monitor interval in minutes, default is 30 minutes
+QUEUE_MONITOR_INTERVAL=30
diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml
index 74a7b87bf9..75bdab1a06 100644
--- a/docker/docker-compose-template.yaml
+++ b/docker/docker-compose-template.yaml
@@ -184,6 +184,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index d4a0b94619..8276e2977f 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -121,6 +121,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DAEMON_PORT:-5002}:${PLUGIN_DAEMON_PORT:-5002}"
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 41e86d015f..e559021684 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -484,6 +484,10 @@ x-shared-env: &shared-api-worker-env
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
PLUGIN_ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
PLUGIN_ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ PLUGIN_VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ PLUGIN_VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ENABLE_OTEL: ${ENABLE_OTEL:-false}
OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318}
OTLP_API_KEY: ${OTLP_API_KEY:-}
@@ -497,6 +501,9 @@ x-shared-env: &shared-api-worker-env
OTEL_BATCH_EXPORT_TIMEOUT: ${OTEL_BATCH_EXPORT_TIMEOUT:-10000}
OTEL_METRIC_EXPORT_TIMEOUT: ${OTEL_METRIC_EXPORT_TIMEOUT:-30000}
ALLOW_EMBED: ${ALLOW_EMBED:-false}
+ QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
+ QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
+ QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
services:
# API service
@@ -683,6 +690,10 @@ services:
ALIYUN_OSS_ACCESS_KEY_SECRET: ${PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET:-}
ALIYUN_OSS_AUTH_VERSION: ${PLUGIN_ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${PLUGIN_ALIYUN_OSS_PATH:-}
+ VOLCENGINE_TOS_ENDPOINT: ${PLUGIN_VOLCENGINE_TOS_ENDPOINT:-}
+ VOLCENGINE_TOS_ACCESS_KEY: ${PLUGIN_VOLCENGINE_TOS_ACCESS_KEY:-}
+ VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-}
+ VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-}
ports:
- "${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}:${PLUGIN_DEBUGGING_PORT:-5003}"
volumes:
diff --git a/docker/middleware.env.example b/docker/middleware.env.example
index ba6859885b..66037f281c 100644
--- a/docker/middleware.env.example
+++ b/docker/middleware.env.example
@@ -152,3 +152,8 @@ PLUGIN_ALIYUN_OSS_ACCESS_KEY_ID=
PLUGIN_ALIYUN_OSS_ACCESS_KEY_SECRET=
PLUGIN_ALIYUN_OSS_AUTH_VERSION=v4
PLUGIN_ALIYUN_OSS_PATH=
+# Plugin oss volcengine tos
+PLUGIN_VOLCENGINE_TOS_ENDPOINT=
+PLUGIN_VOLCENGINE_TOS_ACCESS_KEY=
+PLUGIN_VOLCENGINE_TOS_SECRET_KEY=
+PLUGIN_VOLCENGINE_TOS_REGION=
diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx
index 42967b96f4..31b9ed87c2 100644
--- a/web/app/(commonLayout)/apps/AppCard.tsx
+++ b/web/app/(commonLayout)/apps/AppCard.tsx
@@ -4,7 +4,7 @@ import { useContext, useContextSelector } from 'use-context-selector'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
-import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill } from '@remixicon/react'
+import { RiBuildingLine, RiGlobalLine, RiLockLine, RiMoreFill, RiVerifiedBadgeLine } from '@remixicon/react'
import cn from '@/utils/classnames'
import type { App } from '@/types/app'
import Confirm from '@/app/components/base/confirm'
@@ -338,7 +338,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
{app.access_mode === AccessMode.PUBLIC &&
-
+
}
{app.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS &&
@@ -346,6 +346,9 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
{app.access_mode === AccessMode.ORGANIZATION &&
}
+ {app.access_mode === AccessMode.EXTERNAL_MEMBERS &&
+
+ }
diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx
index 62569ab26b..112b6a752e 100644
--- a/web/app/(commonLayout)/datasets/Container.tsx
+++ b/web/app/(commonLayout)/datasets/Container.tsx
@@ -87,7 +87,7 @@ const Container = () => {
return (
-
+
setActiveTab(newActiveTab)}
diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx
index b9fab19948..a796b65bae 100644
--- a/web/app/(commonLayout)/datasets/template/template.ja.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx
@@ -192,15 +192,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- original_document_id が渡されない場合、新しい操作が実行され、process_rule が必要です。
- indexing_technique インデックスモード
- - high_quality 高品質: 埋め込みモデルを使用してベクトルデータベースインデックスを構築
- - economy 経済: キーワードテーブルインデックスの反転インデックスを構築
+ - high_quality 高品質:埋め込みモデルを使用してベクトルデータベースインデックスを構築
+ - economy 経済:キーワードテーブルインデックスの反転インデックスを構築
- doc_form インデックス化された内容の形式
- text_model テキストドキュメントは直接埋め込まれます; `economy` モードではこの形式がデフォルト
- hierarchical_model 親子モード
- - qa_model Q&A モード: 分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
+ - qa_model Q&A モード:分割されたドキュメントの質問と回答ペアを生成し、質問を埋め込みます
- - doc_language Q&A モードでは、ドキュメントの言語を指定します。例: English, Chinese
+ - doc_language Q&A モードでは、ドキュメントの言語を指定します。例:English, Chinese
- process_rule 処理ルール
- mode (string) クリーニング、セグメンテーションモード、自動 / カスタム
@@ -214,7 +214,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- segmentation (object) セグメンテーションルール
- separator カスタムセグメント識別子。現在は 1 つの区切り文字のみ設定可能。デフォルトは \n
- max_tokens 最大長 (トークン) デフォルトは 1000
- - parent_mode 親チャンクの検索モード: full-doc 全文検索 / paragraph 段落検索
+ - parent_mode 親チャンクの検索モード:full-doc 全文検索 / paragraph 段落検索
- subchunk_segmentation (object) 子チャンクルール
- separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは ***
- max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります
@@ -324,7 +324,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- partial_members 一部のメンバー
- プロバイダー (オプション、デフォルト: vendor)
+ プロバイダー (オプション、デフォルト:vendor)
- vendor ベンダー
- external 外部ナレッジ
@@ -415,16 +415,16 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
検索キーワード、オプション
- タグIDリスト、オプション
+ タグ ID リスト、オプション
- ページ番号、オプション、デフォルト1
+ ページ番号、オプション、デフォルト 1
- 返されるアイテム数、オプション、デフォルト20、範囲1-100
+ 返されるアイテム数、オプション、デフォルト 20、範囲 1-100
- すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトはfalse
+ すべてのデータセットを含めるかどうか(所有者のみ有効)、オプション、デフォルトは false
@@ -2013,7 +2013,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 新しいタグ名、必須、最大長50文字
+ (text) 新しいタグ名、必須、最大長 50 文字
@@ -2099,10 +2099,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 変更後のタグ名、必須、最大長50文字
+ (text) 変更後のタグ名、必須、最大長 50 文字
- (text) タグID、必須
+ (text) タグ ID、必須
@@ -2147,7 +2147,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) タグID、必須
+ (text) タグ ID、必須
@@ -2188,10 +2188,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (list) タグIDリスト、必須
+ (list) タグ ID リスト、必須
- (text) ナレッジベースID、必須
+ (text) ナレッジベース ID、必須
@@ -2230,10 +2230,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) タグID、必須
+ (text) タグ ID、必須
- (text) ナレッジベースID、必須
+ (text) ナレッジベース ID、必須
@@ -2273,7 +2273,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Path
- (text) ナレッジベースID
+ (text) ナレッジベース ID
diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx
index b10f22002a..08ef5d562a 100644
--- a/web/app/(commonLayout)/datasets/template/template.zh.mdx
+++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx
@@ -207,7 +207,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- doc_language 在 Q&A 模式下,指定文档的语言,例如:English、Chinese
- process_rule 处理规则
- - mode (string) 清洗、分段模式 ,automatic 自动 / custom 自定义 / hierarchical 父子
+ - mode (string) 清洗、分段模式,automatic 自动 / custom 自定义 / hierarchical 父子
- rules (object) 自定义规则(自动模式下,该字段为空)
- pre_processing_rules (array[object]) 预处理规则
- id (string) 预处理规则的唯一标识符
@@ -234,12 +234,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- hybrid_search 混合检索
- semantic_search 语义检索
- full_text_search 全文检索
- - reranking_enable (bool) 是否开启rerank
+ - reranking_enable (bool) 是否开启 rerank
- reranking_model (object) Rerank 模型配置
- reranking_provider_name (string) Rerank 模型的提供商
- reranking_model_name (string) Rerank 模型的名称
- top_k (int) 召回条数
- - score_threshold_enabled (bool)是否开启召回分数限制
+ - score_threshold_enabled (bool) 是否开启召回分数限制
- score_threshold (float) 召回分数限制
@@ -350,12 +350,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
- hybrid_search 混合检索
- semantic_search 语义检索
- full_text_search 全文检索
- - reranking_enable (bool) 是否开启rerank
+ - reranking_enable (bool) 是否开启 rerank
- reranking_model (object) Rerank 模型配置
- reranking_provider_name (string) Rerank 模型的提供商
- reranking_model_name (string) Rerank 模型的名称
- top_k (int) 召回条数
- - score_threshold_enabled (bool)是否开启召回分数限制
+ - score_threshold_enabled (bool) 是否开启召回分数限制
- score_threshold (float) 召回分数限制
@@ -1322,7 +1322,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
文档 ID
- 文档分段ID
+ 文档分段 ID
@@ -1435,7 +1435,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
文档 ID
- 文档分段ID
+ 文档分段 ID
@@ -2404,7 +2404,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 新标签名称,必填,最大长度为50
+ (text) 新标签名称,必填,最大长度为 50
@@ -2490,10 +2490,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 修改后的标签名称,必填,最大长度为50
+ (text) 修改后的标签名称,必填,最大长度为 50
- (text) 标签ID,必填
+ (text) 标签 ID,必填
@@ -2538,7 +2538,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 标签ID,必填
+ (text) 标签 ID,必填
@@ -2579,10 +2579,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (list) 标签ID列表,必填
+ (list) 标签 ID 列表,必填
- (text) 知识库ID,必填
+ (text) 知识库 ID,必填
@@ -2621,10 +2621,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Request Body
- (text) 标签ID,必填
+ (text) 标签 ID,必填
- (text) 知识库ID,必填
+ (text) 知识库 ID,必填
@@ -2664,7 +2664,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
### Path
- (text) 知识库ID
+ (text) 知识库 ID
diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx
index 83adbd3cae..8db336a17d 100644
--- a/web/app/(shareLayout)/layout.tsx
+++ b/web/app/(shareLayout)/layout.tsx
@@ -1,14 +1,42 @@
-import React from 'react'
+'use client'
+import React, { useEffect, useState } from 'react'
import type { FC } from 'react'
-import type { Metadata } from 'next'
-
-export const metadata: Metadata = {
- icons: 'data:,', // prevent browser from using default favicon
-}
+import { usePathname, useSearchParams } from 'next/navigation'
+import Loading from '../components/base/loading'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+import { AccessMode } from '@/models/access-control'
+import { getAppAccessModeByAppCode } from '@/service/share'
const Layout: FC<{
children: React.ReactNode
}> = ({ children }) => {
+ const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending)
+ const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode)
+ const pathname = usePathname()
+ const searchParams = useSearchParams()
+ const redirectUrl = searchParams.get('redirect_url')
+ const [isLoading, setIsLoading] = useState(true)
+ useEffect(() => {
+ (async () => {
+ let appCode: string | null = null
+ if (redirectUrl)
+ appCode = redirectUrl?.split('/').pop() || null
+ else
+ appCode = pathname.split('/').pop() || null
+
+ if (!appCode)
+ return
+ setIsLoading(true)
+ const ret = await getAppAccessModeByAppCode(appCode)
+ setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC)
+ setIsLoading(false)
+ })()
+ }, [pathname, redirectUrl, setWebAppAccessMode])
+ if (isLoading || isGlobalPending) {
+ return
+
+
+ }
return (
{children}
diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
new file mode 100644
index 0000000000..da754794b1
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx
@@ -0,0 +1,96 @@
+'use client'
+import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
+import { useTranslation } from 'react-i18next'
+import { useState } from 'react'
+import { useRouter, useSearchParams } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import Countdown from '@/app/components/signin/countdown'
+import Button from '@/app/components/base/button'
+import Input from '@/app/components/base/input'
+import Toast from '@/app/components/base/toast'
+import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common'
+import I18NContext from '@/context/i18n'
+
+export default function CheckCode() {
+ const { t } = useTranslation()
+ const router = useRouter()
+ const searchParams = useSearchParams()
+ const email = decodeURIComponent(searchParams.get('email') as string)
+ const token = decodeURIComponent(searchParams.get('token') as string)
+ const [code, setVerifyCode] = useState('')
+ const [loading, setIsLoading] = useState(false)
+ const { locale } = useContext(I18NContext)
+
+ const verify = async () => {
+ try {
+ if (!code.trim()) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.checkCode.emptyCode'),
+ })
+ return
+ }
+ if (!/\d{6}/.test(code)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.checkCode.invalidCode'),
+ })
+ return
+ }
+ setIsLoading(true)
+ const ret = await verifyWebAppResetPasswordCode({ email, code, token })
+ if (ret.is_valid) {
+ const params = new URLSearchParams(searchParams)
+ params.set('token', encodeURIComponent(ret.token))
+ router.push(`/webapp-reset-password/set-password?${params.toString()}`)
+ }
+ }
+ catch (error) { console.error(error) }
+ finally {
+ setIsLoading(false)
+ }
+ }
+
+ const resendCode = async () => {
+ try {
+ const res = await sendWebAppResetPasswordCode(email, locale)
+ if (res.result === 'success') {
+ const params = new URLSearchParams(searchParams)
+ params.set('token', encodeURIComponent(res.data))
+ router.replace(`/webapp-reset-password/check-code?${params.toString()}`)
+ }
+ }
+ catch (error) { console.error(error) }
+ }
+
+ return
+
+
+
+
+
{t('login.checkCode.checkYourEmail')}
+
+
+
+ {t('login.checkCode.validTime')}
+
+
+
+
+
+
router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
+
+
+
+
{t('login.back')}
+
+
+}
diff --git a/web/app/(shareLayout)/webapp-reset-password/layout.tsx b/web/app/(shareLayout)/webapp-reset-password/layout.tsx
new file mode 100644
index 0000000000..e0ac6b9ad6
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-reset-password/layout.tsx
@@ -0,0 +1,30 @@
+'use client'
+import Header from '@/app/signin/_header'
+
+import cn from '@/utils/classnames'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+
+export default function SignInLayout({ children }: any) {
+ const { systemFeatures } = useGlobalPublicStore()
+ return <>
+
+
+
+
+ {!systemFeatures.branding.enabled &&
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
+
}
+
+
+ >
+}
diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx
new file mode 100644
index 0000000000..96cd4c5805
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx
@@ -0,0 +1,104 @@
+'use client'
+import Link from 'next/link'
+import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react'
+import { useTranslation } from 'react-i18next'
+import { useState } from 'react'
+import { useRouter, useSearchParams } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
+import { emailRegex } from '@/config'
+import Button from '@/app/components/base/button'
+import Input from '@/app/components/base/input'
+import Toast from '@/app/components/base/toast'
+import { sendResetPasswordCode } from '@/service/common'
+import I18NContext from '@/context/i18n'
+import { noop } from 'lodash-es'
+import useDocumentTitle from '@/hooks/use-document-title'
+
+export default function CheckCode() {
+ const { t } = useTranslation()
+ useDocumentTitle('')
+ const searchParams = useSearchParams()
+ const router = useRouter()
+ const [email, setEmail] = useState('')
+ const [loading, setIsLoading] = useState(false)
+ const { locale } = useContext(I18NContext)
+
+ const handleGetEMailVerificationCode = async () => {
+ try {
+ if (!email) {
+ Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
+ return
+ }
+
+ if (!emailRegex.test(email)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.emailInValid'),
+ })
+ return
+ }
+ setIsLoading(true)
+ const res = await sendResetPasswordCode(email, locale)
+ if (res.result === 'success') {
+ localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`)
+ const params = new URLSearchParams(searchParams)
+ params.set('token', encodeURIComponent(res.data))
+ params.set('email', encodeURIComponent(email))
+ router.push(`/webapp-reset-password/check-code?${params.toString()}`)
+ }
+ else if (res.code === 'account_not_found') {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.registrationNotAllowed'),
+ })
+ }
+ else {
+ Toast.notify({
+ type: 'error',
+ message: res.data,
+ })
+ }
+ }
+ catch (error) {
+ console.error(error)
+ }
+ finally {
+ setIsLoading(false)
+ }
+ }
+
+ return
+
+
+
+
+
{t('login.resetPassword')}
+
+ {t('login.resetPasswordDesc')}
+
+
+
+
+
+
+
+
+
+
{t('login.backToLogin')}
+
+
+}
diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
new file mode 100644
index 0000000000..9f9a8ad4e3
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx
@@ -0,0 +1,188 @@
+'use client'
+import { useCallback, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import { useRouter, useSearchParams } from 'next/navigation'
+import cn from 'classnames'
+import { RiCheckboxCircleFill } from '@remixicon/react'
+import { useCountDown } from 'ahooks'
+import Button from '@/app/components/base/button'
+import { changeWebAppPasswordWithToken } from '@/service/common'
+import Toast from '@/app/components/base/toast'
+import Input from '@/app/components/base/input'
+
+const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
+
+const ChangePasswordForm = () => {
+ const { t } = useTranslation()
+ const router = useRouter()
+ const searchParams = useSearchParams()
+ const token = decodeURIComponent(searchParams.get('token') || '')
+
+ const [password, setPassword] = useState('')
+ const [confirmPassword, setConfirmPassword] = useState('')
+ const [showSuccess, setShowSuccess] = useState(false)
+ const [showPassword, setShowPassword] = useState(false)
+ const [showConfirmPassword, setShowConfirmPassword] = useState(false)
+
+ const showErrorMessage = useCallback((message: string) => {
+ Toast.notify({
+ type: 'error',
+ message,
+ })
+ }, [])
+
+ const getSignInUrl = () => {
+ return `/webapp-signin?redirect_url=${searchParams.get('redirect_url') || ''}`
+ }
+
+ const AUTO_REDIRECT_TIME = 5000
+ const [leftTime, setLeftTime] = useState
(undefined)
+ const [countdown] = useCountDown({
+ leftTime,
+ onEnd: () => {
+ router.replace(getSignInUrl())
+ },
+ })
+
+ const valid = useCallback(() => {
+ if (!password.trim()) {
+ showErrorMessage(t('login.error.passwordEmpty'))
+ return false
+ }
+ if (!validPassword.test(password)) {
+ showErrorMessage(t('login.error.passwordInvalid'))
+ return false
+ }
+ if (password !== confirmPassword) {
+ showErrorMessage(t('common.account.notEqual'))
+ return false
+ }
+ return true
+ }, [password, confirmPassword, showErrorMessage, t])
+
+ const handleChangePassword = useCallback(async () => {
+ if (!valid())
+ return
+ try {
+ await changeWebAppPasswordWithToken({
+ url: '/forgot-password/resets',
+ body: {
+ token,
+ new_password: password,
+ password_confirm: confirmPassword,
+ },
+ })
+ setShowSuccess(true)
+ setLeftTime(AUTO_REDIRECT_TIME)
+ }
+ catch (error) {
+ console.error(error)
+ }
+ }, [password, token, valid, confirmPassword])
+
+ return (
+
+ {!showSuccess && (
+
+
+
+ {t('login.changePassword')}
+
+
+ {t('login.changePasswordTip')}
+
+
+
+
+
+ {/* Password */}
+
+
+
+
setPassword(e.target.value)}
+ placeholder={t('login.passwordPlaceholder') || ''}
+ />
+
+
+
+
+
+
{t('login.error.passwordInvalid')}
+
+ {/* Confirm Password */}
+
+
+
+
setConfirmPassword(e.target.value)}
+ placeholder={t('login.confirmPasswordPlaceholder') || ''}
+ />
+
+
+
+
+
+
+
+
+
+
+
+ )}
+ {showSuccess && (
+
+
+
+
+
+
+ {t('login.passwordChangedTip')}
+
+
+
+
+
+
+ )}
+
+ )
+}
+
+export default ChangePasswordForm
diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx
new file mode 100644
index 0000000000..1b8f18c98f
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx
@@ -0,0 +1,115 @@
+'use client'
+import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react'
+import { useTranslation } from 'react-i18next'
+import { useCallback, useState } from 'react'
+import { useRouter, useSearchParams } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import Countdown from '@/app/components/signin/countdown'
+import Button from '@/app/components/base/button'
+import Input from '@/app/components/base/input'
+import Toast from '@/app/components/base/toast'
+import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common'
+import I18NContext from '@/context/i18n'
+import { setAccessToken } from '@/app/components/share/utils'
+import { fetchAccessToken } from '@/service/share'
+
+export default function CheckCode() {
+ const { t } = useTranslation()
+ const router = useRouter()
+ const searchParams = useSearchParams()
+ const email = decodeURIComponent(searchParams.get('email') as string)
+ const token = decodeURIComponent(searchParams.get('token') as string)
+ const [code, setVerifyCode] = useState('')
+ const [loading, setIsLoading] = useState(false)
+ const { locale } = useContext(I18NContext)
+ const redirectUrl = searchParams.get('redirect_url')
+
+ const getAppCodeFromRedirectUrl = useCallback(() => {
+ const appCode = redirectUrl?.split('/').pop()
+ if (!appCode)
+ return null
+
+ return appCode
+ }, [redirectUrl])
+
+ const verify = async () => {
+ try {
+ const appCode = getAppCodeFromRedirectUrl()
+ if (!code.trim()) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.checkCode.emptyCode'),
+ })
+ return
+ }
+ if (!/\d{6}/.test(code)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.checkCode.invalidCode'),
+ })
+ return
+ }
+ if (!redirectUrl || !appCode) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.redirectUrlMissing'),
+ })
+ return
+ }
+ setIsLoading(true)
+ const ret = await webAppEmailLoginWithCode({ email, code, token })
+ if (ret.result === 'success') {
+ localStorage.setItem('webapp_access_token', ret.data.access_token)
+ const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: ret.data.access_token })
+ await setAccessToken(appCode, tokenResp.access_token)
+ router.replace(redirectUrl)
+ }
+ }
+ catch (error) { console.error(error) }
+ finally {
+ setIsLoading(false)
+ }
+ }
+
+ const resendCode = async () => {
+ try {
+ const ret = await sendWebAppEMailLoginCode(email, locale)
+ if (ret.result === 'success') {
+ const params = new URLSearchParams(searchParams)
+ params.set('token', encodeURIComponent(ret.data))
+ router.replace(`/webapp-signin/check-code?${params.toString()}`)
+ }
+ }
+ catch (error) { console.error(error) }
+ }
+
+ return
+
+
+
+
+
{t('login.checkCode.checkYourEmail')}
+
+
+
+ {t('login.checkCode.validTime')}
+
+
+
+
+
+
router.back()} className='flex h-9 cursor-pointer items-center justify-center text-text-tertiary'>
+
+
+
+
{t('login.back')}
+
+
+}
diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx
new file mode 100644
index 0000000000..e9b15ae331
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx
@@ -0,0 +1,80 @@
+'use client'
+import { useRouter, useSearchParams } from 'next/navigation'
+import React, { useCallback, useEffect } from 'react'
+import Toast from '@/app/components/base/toast'
+import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+import { SSOProtocol } from '@/types/feature'
+import Loading from '@/app/components/base/loading'
+import AppUnavailable from '@/app/components/base/app-unavailable'
+
+const ExternalMemberSSOAuth = () => {
+ const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
+ const searchParams = useSearchParams()
+ const router = useRouter()
+
+ const redirectUrl = searchParams.get('redirect_url')
+
+ const showErrorToast = (message: string) => {
+ Toast.notify({
+ type: 'error',
+ message,
+ })
+ }
+
+ const getAppCodeFromRedirectUrl = useCallback(() => {
+ const appCode = redirectUrl?.split('/').pop()
+ if (!appCode)
+ return null
+
+ return appCode
+ }, [redirectUrl])
+
+ const handleSSOLogin = useCallback(async () => {
+ const appCode = getAppCodeFromRedirectUrl()
+ if (!appCode || !redirectUrl) {
+ showErrorToast('redirect url or app code is invalid.')
+ return
+ }
+
+ switch (systemFeatures.webapp_auth.sso_config.protocol) {
+ case SSOProtocol.SAML: {
+ const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl)
+ router.push(samlRes.url)
+ break
+ }
+ case SSOProtocol.OIDC: {
+ const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl)
+ router.push(oidcRes.url)
+ break
+ }
+ case SSOProtocol.OAuth2: {
+ const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl)
+ router.push(oauth2Res.url)
+ break
+ }
+ case '':
+ break
+ default:
+ showErrorToast('SSO protocol is not supported.')
+ }
+ }, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol])
+
+ useEffect(() => {
+ handleSSOLogin()
+ }, [handleSSOLogin])
+
+ if (!systemFeatures.webapp_auth.sso_config.protocol) {
+ return
+ }
+
+ return (
+
+
+
+ )
+}
+
+export default React.memo(ExternalMemberSSOAuth)
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
new file mode 100644
index 0000000000..29af3e3a57
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx
@@ -0,0 +1,68 @@
+import { useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import { useRouter, useSearchParams } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import Input from '@/app/components/base/input'
+import Button from '@/app/components/base/button'
+import { emailRegex } from '@/config'
+import Toast from '@/app/components/base/toast'
+import { sendWebAppEMailLoginCode } from '@/service/common'
+import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown'
+import I18NContext from '@/context/i18n'
+import { noop } from 'lodash-es'
+
+export default function MailAndCodeAuth() {
+ const { t } = useTranslation()
+ const router = useRouter()
+ const searchParams = useSearchParams()
+ const emailFromLink = decodeURIComponent(searchParams.get('email') || '')
+ const [email, setEmail] = useState(emailFromLink)
+ const [loading, setIsLoading] = useState(false)
+ const { locale } = useContext(I18NContext)
+
+ const handleGetEMailVerificationCode = async () => {
+ try {
+ if (!email) {
+ Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
+ return
+ }
+
+ if (!emailRegex.test(email)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.emailInValid'),
+ })
+ return
+ }
+ setIsLoading(true)
+ const ret = await sendWebAppEMailLoginCode(email, locale)
+ if (ret.result === 'success') {
+ localStorage.setItem(COUNT_DOWN_KEY, `${COUNT_DOWN_TIME_MS}`)
+ const params = new URLSearchParams(searchParams)
+ params.set('email', encodeURIComponent(email))
+ params.set('token', encodeURIComponent(ret.data))
+ router.push(`/webapp-signin/check-code?${params.toString()}`)
+ }
+ }
+ catch (error) {
+ console.error(error)
+ }
+ finally {
+ setIsLoading(false)
+ }
+ }
+
+ return (
+ )
+}
diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
new file mode 100644
index 0000000000..d9e56af1b8
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx
@@ -0,0 +1,171 @@
+import Link from 'next/link'
+import { useCallback, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import { useRouter, useSearchParams } from 'next/navigation'
+import { useContext } from 'use-context-selector'
+import Button from '@/app/components/base/button'
+import Toast from '@/app/components/base/toast'
+import { emailRegex } from '@/config'
+import { webAppLogin } from '@/service/common'
+import Input from '@/app/components/base/input'
+import I18NContext from '@/context/i18n'
+import { noop } from 'lodash-es'
+import { setAccessToken } from '@/app/components/share/utils'
+import { fetchAccessToken } from '@/service/share'
+
+type MailAndPasswordAuthProps = {
+ isEmailSetup: boolean
+}
+
+const passwordRegex = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/
+
+export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAuthProps) {
+ const { t } = useTranslation()
+ const { locale } = useContext(I18NContext)
+ const router = useRouter()
+ const searchParams = useSearchParams()
+ const [showPassword, setShowPassword] = useState(false)
+ const emailFromLink = decodeURIComponent(searchParams.get('email') || '')
+ const [email, setEmail] = useState(emailFromLink)
+ const [password, setPassword] = useState('')
+
+ const [isLoading, setIsLoading] = useState(false)
+ const redirectUrl = searchParams.get('redirect_url')
+
+ const getAppCodeFromRedirectUrl = useCallback(() => {
+ const appCode = redirectUrl?.split('/').pop()
+ if (!appCode)
+ return null
+
+ return appCode
+ }, [redirectUrl])
+ const handleEmailPasswordLogin = async () => {
+ const appCode = getAppCodeFromRedirectUrl()
+ if (!email) {
+ Toast.notify({ type: 'error', message: t('login.error.emailEmpty') })
+ return
+ }
+ if (!emailRegex.test(email)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.emailInValid'),
+ })
+ return
+ }
+ if (!password?.trim()) {
+ Toast.notify({ type: 'error', message: t('login.error.passwordEmpty') })
+ return
+ }
+ if (!passwordRegex.test(password)) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.passwordInvalid'),
+ })
+ return
+ }
+ if (!redirectUrl || !appCode) {
+ Toast.notify({
+ type: 'error',
+ message: t('login.error.redirectUrlMissing'),
+ })
+ return
+ }
+ try {
+ setIsLoading(true)
+ const loginData: Record = {
+ email,
+ password,
+ language: locale,
+ remember_me: true,
+ }
+
+ const res = await webAppLogin({
+ url: '/login',
+ body: loginData,
+ })
+ if (res.result === 'success') {
+ localStorage.setItem('webapp_access_token', res.data.access_token)
+ const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: res.data.access_token })
+ await setAccessToken(appCode, tokenResp.access_token)
+ router.replace(redirectUrl)
+ }
+ else {
+ Toast.notify({
+ type: 'error',
+ message: res.data,
+ })
+ }
+ }
+
+ finally {
+ setIsLoading(false)
+ }
+ }
+
+ return
+}
diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx
new file mode 100644
index 0000000000..5d649322ba
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx
@@ -0,0 +1,88 @@
+'use client'
+import { useRouter, useSearchParams } from 'next/navigation'
+import type { FC } from 'react'
+import { useCallback } from 'react'
+import { useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security'
+import Toast from '@/app/components/base/toast'
+import Button from '@/app/components/base/button'
+import { SSOProtocol } from '@/types/feature'
+import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share'
+
+type SSOAuthProps = {
+ protocol: SSOProtocol | ''
+}
+
+const SSOAuth: FC = ({
+ protocol,
+}) => {
+ const router = useRouter()
+ const { t } = useTranslation()
+ const searchParams = useSearchParams()
+
+ const redirectUrl = searchParams.get('redirect_url')
+ const getAppCodeFromRedirectUrl = useCallback(() => {
+ const appCode = redirectUrl?.split('/').pop()
+ if (!appCode)
+ return null
+
+ return appCode
+ }, [redirectUrl])
+
+ const [isLoading, setIsLoading] = useState(false)
+
+ const handleSSOLogin = () => {
+ const appCode = getAppCodeFromRedirectUrl()
+ if (!redirectUrl || !appCode) {
+ Toast.notify({
+ type: 'error',
+ message: 'invalid redirect URL or app code',
+ })
+ return
+ }
+ setIsLoading(true)
+ if (protocol === SSOProtocol.SAML) {
+ fetchMembersSAMLSSOUrl(appCode, redirectUrl).then((res) => {
+ router.push(res.url)
+ }).finally(() => {
+ setIsLoading(false)
+ })
+ }
+ else if (protocol === SSOProtocol.OIDC) {
+ fetchMembersOIDCSSOUrl(appCode, redirectUrl).then((res) => {
+ router.push(res.url)
+ }).finally(() => {
+ setIsLoading(false)
+ })
+ }
+ else if (protocol === SSOProtocol.OAuth2) {
+ fetchMembersOAuth2SSOUrl(appCode, redirectUrl).then((res) => {
+ router.push(res.url)
+ }).finally(() => {
+ setIsLoading(false)
+ })
+ }
+ else {
+ Toast.notify({
+ type: 'error',
+ message: 'invalid SSO protocol',
+ })
+ setIsLoading(false)
+ }
+ }
+
+ return (
+
+ )
+}
+
+export default SSOAuth
diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx
new file mode 100644
index 0000000000..a03364d326
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/layout.tsx
@@ -0,0 +1,25 @@
+'use client'
+
+import cn from '@/utils/classnames'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+import useDocumentTitle from '@/hooks/use-document-title'
+
+export default function SignInLayout({ children }: any) {
+ const { systemFeatures } = useGlobalPublicStore()
+ useDocumentTitle('')
+ return <>
+
+
+ {/*
*/}
+
+ {systemFeatures.branding.enabled === false &&
+ © {new Date().getFullYear()} LangGenius, Inc. All rights reserved.
+
}
+
+
+ >
+}
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
new file mode 100644
index 0000000000..d6bdf607ba
--- /dev/null
+++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
@@ -0,0 +1,176 @@
+import React, { useCallback, useEffect, useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Link from 'next/link'
+import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react'
+import Loading from '@/app/components/base/loading'
+import MailAndCodeAuth from './components/mail-and-code-auth'
+import MailAndPasswordAuth from './components/mail-and-password-auth'
+import SSOAuth from './components/sso-auth'
+import cn from '@/utils/classnames'
+import { LicenseStatus } from '@/types/feature'
+import { IS_CE_EDITION } from '@/config'
+import { useGlobalPublicStore } from '@/context/global-public-context'
+
+const NormalForm = () => {
+ const { t } = useTranslation()
+
+ const [isLoading, setIsLoading] = useState(true)
+ const { systemFeatures } = useGlobalPublicStore()
+ const [authType, updateAuthType] = useState<'code' | 'password'>('password')
+ const [showORLine, setShowORLine] = useState(false)
+ const [allMethodsAreDisabled, setAllMethodsAreDisabled] = useState(false)
+
+ const init = useCallback(async () => {
+ try {
+ setAllMethodsAreDisabled(!systemFeatures.enable_social_oauth_login && !systemFeatures.enable_email_code_login && !systemFeatures.enable_email_password_login && !systemFeatures.sso_enforced_for_signin)
+ setShowORLine((systemFeatures.enable_social_oauth_login || systemFeatures.sso_enforced_for_signin) && (systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login))
+ updateAuthType(systemFeatures.enable_email_password_login ? 'password' : 'code')
+ }
+ catch (error) {
+ console.error(error)
+ setAllMethodsAreDisabled(true)
+ }
+ finally { setIsLoading(false) }
+ }, [systemFeatures])
+ useEffect(() => {
+ init()
+ }, [init])
+ if (isLoading) {
+ return
+
+
+ }
+ if (systemFeatures.license?.status === LicenseStatus.LOST) {
+ return
+
+
+
+
+
+
+
{t('login.licenseLost')}
+
{t('login.licenseLostTip')}
+
+
+
+ }
+ if (systemFeatures.license?.status === LicenseStatus.EXPIRED) {
+ return
+
+
+
+
+
+
+
{t('login.licenseExpired')}
+
{t('login.licenseExpiredTip')}
+
+
+
+ }
+ if (systemFeatures.license?.status === LicenseStatus.INACTIVE) {
+ return
+
+
+
+
+
+
+
{t('login.licenseInactive')}
+
{t('login.licenseInactiveTip')}
+
+
+
+ }
+
+ return (
+ <>
+
+
+
{t('login.pageTitle')}
+ {!systemFeatures.branding.enabled &&
{t('login.welcome')}
}
+
+
+
+ {systemFeatures.sso_enforced_for_signin &&
+
+
}
+
+
+ {showORLine &&
+
+
+ {t('login.or')}
+
+
}
+ {
+ (systemFeatures.enable_email_code_login || systemFeatures.enable_email_password_login) && <>
+ {systemFeatures.enable_email_code_login && authType === 'code' && <>
+
+ {systemFeatures.enable_email_password_login &&
{ updateAuthType('password') }}>
+ {t('login.usePassword')}
+
}
+ >}
+ {systemFeatures.enable_email_password_login && authType === 'password' && <>
+
+ {systemFeatures.enable_email_code_login &&
{ updateAuthType('code') }}>
+ {t('login.useVerificationCode')}
+
}
+ >}
+ >
+ }
+ {allMethodsAreDisabled && <>
+
+
+
+
+
{t('login.noLoginMethod')}
+
{t('login.noLoginMethodTip')}
+
+
+ >}
+ {!systemFeatures.branding.enabled && <>
+
+ {t('login.tosDesc')}
+
+ {t('login.tos')}
+ &
+ {t('login.pp')}
+
+ {IS_CE_EDITION &&
+ {t('login.goToInit')}
+
+ {t('login.setAdminAccount')}
+
}
+ >}
+
+
+
+ >
+ )
+}
+
+export default NormalForm
diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx
index 668c3f312c..c12fde38dd 100644
--- a/web/app/(shareLayout)/webapp-signin/page.tsx
+++ b/web/app/(shareLayout)/webapp-signin/page.tsx
@@ -3,19 +3,20 @@ import { useRouter, useSearchParams } from 'next/navigation'
import type { FC } from 'react'
import React, { useCallback, useEffect } from 'react'
import { useTranslation } from 'react-i18next'
-import { RiDoorLockLine } from '@remixicon/react'
-import cn from '@/utils/classnames'
import Toast from '@/app/components/base/toast'
-import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share'
-import { setAccessToken } from '@/app/components/share/utils'
+import { removeAccessToken, setAccessToken } from '@/app/components/share/utils'
import { useGlobalPublicStore } from '@/context/global-public-context'
-import { SSOProtocol } from '@/types/feature'
import Loading from '@/app/components/base/loading'
import AppUnavailable from '@/app/components/base/app-unavailable'
+import NormalForm from './normalForm'
+import { AccessMode } from '@/models/access-control'
+import ExternalMemberSsoAuth from './components/external-member-sso-auth'
+import { fetchAccessToken } from '@/service/share'
const WebSSOForm: FC = () => {
const { t } = useTranslation()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
+ const webAppAccessMode = useGlobalPublicStore(s => s.webAppAccessMode)
const searchParams = useSearchParams()
const router = useRouter()
@@ -23,10 +24,22 @@ const WebSSOForm: FC = () => {
const tokenFromUrl = searchParams.get('web_sso_token')
const message = searchParams.get('message')
- const showErrorToast = (message: string) => {
+ const getSigninUrl = useCallback(() => {
+ const params = new URLSearchParams(searchParams)
+ params.delete('message')
+ return `/webapp-signin?${params.toString()}`
+ }, [searchParams])
+
+ const backToHome = useCallback(() => {
+ removeAccessToken()
+ const url = getSigninUrl()
+ router.replace(url)
+ }, [getSigninUrl, router])
+
+ const showErrorToast = (msg: string) => {
Toast.notify({
type: 'error',
- message,
+ message: msg,
})
}
@@ -38,102 +51,73 @@ const WebSSOForm: FC = () => {
return appCode
}, [redirectUrl])
- const processTokenAndRedirect = useCallback(async () => {
- const appCode = getAppCodeFromRedirectUrl()
- if (!appCode || !tokenFromUrl || !redirectUrl) {
- showErrorToast('redirect url or app code or token is invalid.')
- return
- }
-
- await setAccessToken(appCode, tokenFromUrl)
- router.push(redirectUrl)
- }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl])
-
- const handleSSOLogin = useCallback(async () => {
- const appCode = getAppCodeFromRedirectUrl()
- if (!appCode || !redirectUrl) {
- showErrorToast('redirect url or app code is invalid.')
- return
- }
-
- switch (systemFeatures.webapp_auth.sso_config.protocol) {
- case SSOProtocol.SAML: {
- const samlRes = await fetchWebSAMLSSOUrl(appCode, redirectUrl)
- router.push(samlRes.url)
- break
- }
- case SSOProtocol.OIDC: {
- const oidcRes = await fetchWebOIDCSSOUrl(appCode, redirectUrl)
- router.push(oidcRes.url)
- break
- }
- case SSOProtocol.OAuth2: {
- const oauth2Res = await fetchWebOAuth2SSOUrl(appCode, redirectUrl)
- router.push(oauth2Res.url)
- break
- }
- case '':
- break
- default:
- showErrorToast('SSO protocol is not supported.')
- }
- }, [getAppCodeFromRedirectUrl, redirectUrl, router, systemFeatures.webapp_auth.sso_config.protocol])
-
useEffect(() => {
- const init = async () => {
- if (message) {
- showErrorToast(message)
+ (async () => {
+ if (message)
return
- }
- if (!tokenFromUrl) {
- await handleSSOLogin()
+ const appCode = getAppCodeFromRedirectUrl()
+ if (appCode && tokenFromUrl && redirectUrl) {
+ localStorage.setItem('webapp_access_token', tokenFromUrl)
+ const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
+ await setAccessToken(appCode, tokenResp.access_token)
+ router.replace(redirectUrl)
return
}
+ if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
+ const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
+ await setAccessToken(appCode, tokenResp.access_token)
+ router.replace(redirectUrl)
+ }
+ })()
+ }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
- await processTokenAndRedirect()
- }
+ useEffect(() => {
+ if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
+ router.replace(redirectUrl)
+ }, [webAppAccessMode, router, redirectUrl])
- init()
- }, [message, processTokenAndRedirect, tokenFromUrl, handleSSOLogin])
- if (tokenFromUrl)
- return
- if (message) {
+ if (tokenFromUrl) {
return
}
- if (systemFeatures.webapp_auth.enabled) {
- if (systemFeatures.webapp_auth.allow_sso) {
- return (
-
- )
- }
- return
-
-
-
-
-
{t('login.webapp.noLoginMethod')}
-
{t('login.webapp.noLoginMethodTip')}
-
-
+ if (message) {
+ return
+
+
{t('share.login.backToHome')}
+
+ }
+ if (!redirectUrl) {
+ showErrorToast('redirect url is invalid.')
+ return
+ }
+ if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC) {
+ return
+
}
- else {
+ if (!systemFeatures.webapp_auth.enabled) {
return
{t('login.webapp.disabled')}
}
+ if (webAppAccessMode && (webAppAccessMode === AccessMode.ORGANIZATION || webAppAccessMode === AccessMode.SPECIFIC_GROUPS_MEMBERS)) {
+ return
+
+
+ }
+
+ if (webAppAccessMode && webAppAccessMode === AccessMode.EXTERNAL_MEMBERS)
+ return
+
+ return
+
+
{t('share.login.backToHome')}
+
}
export default React.memo(WebSSOForm)
diff --git a/web/app/components/app/app-access-control/index.tsx b/web/app/components/app/app-access-control/index.tsx
index 2f15c8ec48..13faaea957 100644
--- a/web/app/components/app/app-access-control/index.tsx
+++ b/web/app/components/app/app-access-control/index.tsx
@@ -1,6 +1,6 @@
'use client'
-import { Dialog } from '@headlessui/react'
-import { RiBuildingLine, RiGlobalLine } from '@remixicon/react'
+import { Description as DialogDescription, DialogTitle } from '@headlessui/react'
+import { RiBuildingLine, RiGlobalLine, RiVerifiedBadgeLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useCallback, useEffect } from 'react'
import Button from '../../base/button'
@@ -67,8 +67,8 @@ export default function AccessControl(props: AccessControlProps) {
return
- {t('app.accessControlDialog.title')}
- {t('app.accessControlDialog.description')}
+ {t('app.accessControlDialog.title')}
+ {t('app.accessControlDialog.description')}
@@ -80,12 +80,20 @@ export default function AccessControl(props: AccessControlProps) {
{t('app.accessControlDialog.accessItems.organization')}
- {!hideTip &&
}
+
+
+
+
+
{t('app.accessControlDialog.accessItems.external')}
+
+ {!hideTip &&
}
+
+
diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx
index f4872f8c99..b30c8f1ba3 100644
--- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx
+++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx
@@ -3,12 +3,10 @@ import { RiAlertFill, RiCloseCircleFill, RiLockLine, RiOrganizationChart } from
import { useTranslation } from 'react-i18next'
import { useCallback, useEffect } from 'react'
import Avatar from '../../base/avatar'
-import Divider from '../../base/divider'
import Tooltip from '../../base/tooltip'
import Loading from '../../base/loading'
import useAccessControlStore from '../../../../context/access-control-store'
import AddMemberOrGroupDialog from './add-member-or-group-pop'
-import { useGlobalPublicStore } from '@/context/global-public-context'
import type { AccessControlAccount, AccessControlGroup } from '@/models/access-control'
import { AccessMode } from '@/models/access-control'
import { useAppWhiteListSubjects } from '@/service/access-control'
@@ -19,11 +17,6 @@ export default function SpecificGroupsOrMembers() {
const setSpecificGroups = useAccessControlStore(s => s.setSpecificGroups)
const setSpecificMembers = useAccessControlStore(s => s.setSpecificMembers)
const { t } = useTranslation()
- const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
- const hideTip = systemFeatures.webapp_auth.enabled
- && (systemFeatures.webapp_auth.allow_sso
- || systemFeatures.webapp_auth.allow_email_password_login
- || systemFeatures.webapp_auth.allow_email_code_login)
const { isPending, data } = useAppWhiteListSubjects(appId, Boolean(appId) && currentMenu === AccessMode.SPECIFIC_GROUPS_MEMBERS)
useEffect(() => {
@@ -37,7 +30,6 @@ export default function SpecificGroupsOrMembers() {
{t('app.accessControlDialog.accessItems.specific')}
- {!hideTip && }
}
@@ -48,10 +40,6 @@ export default function SpecificGroupsOrMembers() {
{t('app.accessControlDialog.accessItems.specific')}
- {!hideTip && <>
-
-
- >}
diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx
index 8d0028c7d7..5825bb72ee 100644
--- a/web/app/components/app/app-publisher/index.tsx
+++ b/web/app/components/app/app-publisher/index.tsx
@@ -9,11 +9,14 @@ import dayjs from 'dayjs'
import {
RiArrowDownSLine,
RiArrowRightSLine,
+ RiBuildingLine,
+ RiGlobalLine,
RiLockLine,
RiPlanetLine,
RiPlayCircleLine,
RiPlayList2Line,
RiTerminalBoxLine,
+ RiVerifiedBadgeLine,
} from '@remixicon/react'
import { useKeyPress } from 'ahooks'
import { getKeyboardKeyCodeBySystem } from '../../workflow/utils'
@@ -276,10 +279,30 @@ const AppPublisher = ({
setShowAppAccessControl(true)
}}>
-
- {appDetail?.access_mode === AccessMode.ORGANIZATION &&
{t('app.accessControlDialog.accessItems.organization')}
}
- {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS &&
{t('app.accessControlDialog.accessItems.specific')}
}
- {appDetail?.access_mode === AccessMode.PUBLIC &&
{t('app.accessControlDialog.accessItems.anyone')}
}
+ {appDetail?.access_mode === AccessMode.ORGANIZATION
+ && <>
+
+
{t('app.accessControlDialog.accessItems.organization')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS
+ && <>
+
+
{t('app.accessControlDialog.accessItems.specific')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.PUBLIC
+ && <>
+
+
{t('app.accessControlDialog.accessItems.anyone')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS
+ && <>
+
+
{t('app.accessControlDialog.accessItems.external')}
+ >
+ }
{!isAppAccessSet && {t('app.publishApp.notSet')}
}
diff --git a/web/app/components/app/overview/appCard.tsx b/web/app/components/app/overview/appCard.tsx
index 9b283cdf5e..9f3b3ac4a6 100644
--- a/web/app/components/app/overview/appCard.tsx
+++ b/web/app/components/app/overview/appCard.tsx
@@ -5,10 +5,13 @@ import { useTranslation } from 'react-i18next'
import {
RiArrowRightSLine,
RiBookOpenLine,
+ RiBuildingLine,
RiEqualizer2Line,
RiExternalLinkLine,
+ RiGlobalLine,
RiLockLine,
RiPaintBrushLine,
+ RiVerifiedBadgeLine,
RiWindowLine,
} from '@remixicon/react'
import SettingsModal from './settings'
@@ -248,11 +251,30 @@ function AppCard({
-
- {appDetail?.access_mode === AccessMode.ORGANIZATION &&
{t('app.accessControlDialog.accessItems.organization')}
}
- {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS &&
{t('app.accessControlDialog.accessItems.specific')}
}
- {appDetail?.access_mode === AccessMode.PUBLIC &&
{t('app.accessControlDialog.accessItems.anyone')}
}
-
+ {appDetail?.access_mode === AccessMode.ORGANIZATION
+ && <>
+
+
{t('app.accessControlDialog.accessItems.organization')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS
+ && <>
+
+
{t('app.accessControlDialog.accessItems.specific')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.PUBLIC
+ && <>
+
+
{t('app.accessControlDialog.accessItems.anyone')}
+ >
+ }
+ {appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS
+ && <>
+
+
{t('app.accessControlDialog.accessItems.external')}
+ >
+ }
{!isAppAccessSet &&
{t('app.publishApp.notSet')}
}
diff --git a/web/app/components/base/app-unavailable.tsx b/web/app/components/base/app-unavailable.tsx
index 4e835cbfcf..928c850262 100644
--- a/web/app/components/base/app-unavailable.tsx
+++ b/web/app/components/base/app-unavailable.tsx
@@ -1,4 +1,5 @@
'use client'
+import classNames from '@/utils/classnames'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
@@ -7,17 +8,19 @@ type IAppUnavailableProps = {
code?: number | string
isUnknownReason?: boolean
unknownReason?: string
+ className?: string
}
const AppUnavailable: FC
= ({
code = 404,
isUnknownReason,
unknownReason,
+ className,
}) => {
const { t } = useTranslation()
return (
-
+
({
- accessMode: AccessMode.SPECIFIC_GROUPS_MEMBERS,
userCanAccess: false,
currentConversationId: '',
appPrevChatTree: [],
diff --git a/web/app/components/base/chat/chat-with-history/hooks.tsx b/web/app/components/base/chat/chat-with-history/hooks.tsx
index 3694666139..32f74e6457 100644
--- a/web/app/components/base/chat/chat-with-history/hooks.tsx
+++ b/web/app/components/base/chat/chat-with-history/hooks.tsx
@@ -16,7 +16,7 @@ import type {
Feedback,
} from '../types'
import { CONVERSATION_ID_INFO } from '../constants'
-import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils'
+import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams, getRawInputsFromUrlParams } from '../utils'
import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import {
@@ -43,9 +43,8 @@ import { useAppFavicon } from '@/hooks/use-app-favicon'
import { InputVarType } from '@/app/components/workflow/types'
import { TransferMethod } from '@/types/app'
import { noop } from 'lodash-es'
-import { useGetAppAccessMode, useGetUserCanAccessApp } from '@/service/access-control'
+import { useGetUserCanAccessApp } from '@/service/access-control'
import { useGlobalPublicStore } from '@/context/global-public-context'
-import { AccessMode } from '@/models/access-control'
function getFormattedChatList(messages: any[]) {
const newChatList: ChatItem[] = []
@@ -77,11 +76,6 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
const isInstalledApp = useMemo(() => !!installedAppInfo, [installedAppInfo])
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
const { data: appInfo, isLoading: appInfoLoading, error: appInfoError } = useSWR(installedAppInfo ? null : 'appInfo', fetchAppInfo)
- const { isPending: isGettingAccessMode, data: appAccessMode } = useGetAppAccessMode({
- appId: installedAppInfo?.app.id || appInfo?.app_id,
- isInstalledApp,
- enabled: systemFeatures.webapp_auth.enabled,
- })
const { isPending: isCheckingPermission, data: userCanAccessResult } = useGetUserCanAccessApp({
appId: installedAppInfo?.app.id || appInfo?.app_id,
isInstalledApp,
@@ -195,6 +189,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
const { t } = useTranslation()
const newConversationInputsRef = useRef>({})
const [newConversationInputs, setNewConversationInputs] = useState>({})
+ const [initInputs, setInitInputs] = useState>({})
const handleNewConversationInputsChange = useCallback((newInputs: Record) => {
newConversationInputsRef.current = newInputs
setNewConversationInputs(newInputs)
@@ -202,20 +197,29 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
const inputsForms = useMemo(() => {
return (appParams?.user_input_form || []).filter((item: any) => !item.external_data_tool).map((item: any) => {
if (item.paragraph) {
+ let value = initInputs[item.paragraph.variable]
+ if (value && item.paragraph.max_length && value.length > item.paragraph.max_length)
+ value = value.slice(0, item.paragraph.max_length)
+
return {
...item.paragraph,
+ default: value || item.default,
type: 'paragraph',
}
}
if (item.number) {
+ const convertedNumber = Number(initInputs[item.number.variable]) ?? undefined
return {
...item.number,
+ default: convertedNumber || item.default,
type: 'number',
}
}
if (item.select) {
+ const isInputInOptions = item.select.options.includes(initInputs[item.select.variable])
return {
...item.select,
+ default: (isInputInOptions ? initInputs[item.select.variable] : undefined) || item.default,
type: 'select',
}
}
@@ -234,17 +238,30 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}
}
+ let value = initInputs[item['text-input'].variable]
+ if (value && item['text-input'].max_length && value.length > item['text-input'].max_length)
+ value = value.slice(0, item['text-input'].max_length)
+
return {
...item['text-input'],
+ default: value || item.default,
type: 'text-input',
}
})
- }, [appParams])
+ }, [initInputs, appParams])
const allInputsHidden = useMemo(() => {
return inputsForms.length > 0 && inputsForms.every(item => item.hide === true)
}, [inputsForms])
+ useEffect(() => {
+ // init inputs from url params
+ (async () => {
+ const inputs = await getRawInputsFromUrlParams()
+ setInitInputs(inputs)
+ })()
+ }, [])
+
useEffect(() => {
const conversationInputs: Record = {}
@@ -362,11 +379,11 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
if (conversationId)
setClearChatList(false)
}, [handleConversationIdInfoChange, setClearChatList])
- const handleNewConversation = useCallback(() => {
+ const handleNewConversation = useCallback(async () => {
currentChatInstanceRef.current.handleStop()
setShowNewConversationItemInList(true)
handleChangeConversation('')
- handleNewConversationInputsChange({})
+ handleNewConversationInputsChange(await getRawInputsFromUrlParams())
setClearChatList(true)
}, [handleChangeConversation, setShowNewConversationItemInList, handleNewConversationInputsChange, setClearChatList])
const handleUpdateConversationList = useCallback(() => {
@@ -469,8 +486,7 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
return {
appInfoError,
- appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && (isGettingAccessMode || isCheckingPermission)),
- accessMode: systemFeatures.webapp_auth.enabled ? appAccessMode?.accessMode : AccessMode.PUBLIC,
+ appInfoLoading: appInfoLoading || (systemFeatures.webapp_auth.enabled && isCheckingPermission),
userCanAccess: systemFeatures.webapp_auth.enabled ? userCanAccessResult?.result : true,
isInstalledApp,
appId,
diff --git a/web/app/components/base/chat/chat-with-history/index.tsx b/web/app/components/base/chat/chat-with-history/index.tsx
index de023e7f58..1fd1383196 100644
--- a/web/app/components/base/chat/chat-with-history/index.tsx
+++ b/web/app/components/base/chat/chat-with-history/index.tsx
@@ -124,7 +124,6 @@ const ChatWithHistoryWrap: FC = ({
const {
appInfoError,
appInfoLoading,
- accessMode,
userCanAccess,
appData,
appParams,
@@ -169,7 +168,6 @@ const ChatWithHistoryWrap: FC = ({
appInfoError,
appInfoLoading,
appData,
- accessMode,
userCanAccess,
appParams,
appMeta,
diff --git a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx
index fd317ccf91..4e50c1cb79 100644
--- a/web/app/components/base/chat/chat-with-history/sidebar/index.tsx
+++ b/web/app/components/base/chat/chat-with-history/sidebar/index.tsx
@@ -19,7 +19,6 @@ import RenameModal from '@/app/components/base/chat/chat-with-history/sidebar/re
import DifyLogo from '@/app/components/base/logo/dify-logo'
import type { ConversationItem } from '@/models/share'
import cn from '@/utils/classnames'
-import { AccessMode } from '@/models/access-control'
import { useGlobalPublicStore } from '@/context/global-public-context'
type Props = {
@@ -30,7 +29,6 @@ const Sidebar = ({ isPanel }: Props) => {
const { t } = useTranslation()
const {
isInstalledApp,
- accessMode,
appData,
handleNewConversation,
pinnedConversationList,
@@ -140,7 +138,7 @@ const Sidebar = ({ isPanel }: Props) => {
)}
-
+
{/* powered by */}
{!appData?.custom_config?.remove_webapp_brand && (
diff --git a/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts b/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts
index bcc3ae628d..51995a4af5 100644
--- a/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts
+++ b/web/app/components/base/chat/chat/answer/__mocks__/markdownContentSVG.ts
@@ -3,7 +3,7 @@ export const markdownContentSVG = `