Merge branch 'main' into feat/tool-plugin-oauth

pull/22036/head
Harry 11 months ago
commit 8a954c0b19

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="1.4.3", default="1.5.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource):
return { return {
"result": "success", "result": "success",
"invitation_results": invitation_results, "invitation_results": invitation_results,
"tenant_id": str(current_user.current_tenant.id),
}, 201 }, 201
@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource):
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"}, 204 return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):

@ -36,7 +36,6 @@ from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -480,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = AdvancedChatAppRunner( runner = AdvancedChatAppRunner(

@ -26,7 +26,6 @@ from factories import file_factory
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser from models import Account, App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -238,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = AgentChatAppRunner() runner = AgentChatAppRunner()

@ -25,7 +25,6 @@ from factories import file_factory
from models.account import Account from models.account import Account
from models.model import App, EndUser from models.model import App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -224,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = ChatAppRunner() runner = ChatAppRunner()

@ -201,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try: try:
# get message # get message
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app # chatbot app
runner = CompletionAppRunner() runner = CompletionAppRunner()

@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction or "" return introduction or ""
def _get_conversation(self, conversation_id: str): def _get_conversation(self, conversation_id: str) -> Conversation:
""" """
Get conversation by conversation id Get conversation by conversation id
:param conversation_id: conversation id :param conversation_id: conversation id
@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError("Conversation not exists")
return conversation return conversation
def _get_message(self, message_id: str) -> Optional[Message]: def _get_message(self, message_id: str) -> Message:
""" """
Get message by message id Get message by message id
:param message_id: message id :param message_id: message id
@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
""" """
message = db.session.query(Message).filter(Message.id == message_id).first() message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")
return message return message

@ -534,7 +534,7 @@ class IndexingRunner:
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index # create keyword index
create_keyword_thread = threading.Thread( create_keyword_thread = threading.Thread(
target=self._process_keyword_index, target=self._process_keyword_index,
@ -572,7 +572,7 @@ class IndexingRunner:
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
create_keyword_thread.join() create_keyword_thread.join()
indexing_end_at = time.perf_counter() indexing_end_at = time.perf_counter()

@ -15,6 +15,7 @@ class OAuthHandler(BasePluginClient):
user_id: str, user_id: str,
plugin_id: str, plugin_id: str,
provider: str, provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse: ) -> PluginOAuthAuthorizationUrlResponse:
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
@ -25,6 +26,7 @@ class OAuthHandler(BasePluginClient):
"user_id": user_id, "user_id": user_id,
"data": { "data": {
"provider": provider, "provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials, "system_credentials": system_credentials,
}, },
}, },
@ -43,6 +45,7 @@ class OAuthHandler(BasePluginClient):
user_id: str, user_id: str,
plugin_id: str, plugin_id: str,
provider: str, provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
request: Request, request: Request,
) -> PluginOAuthCredentialsResponse: ) -> PluginOAuthCredentialsResponse:
@ -61,6 +64,7 @@ class OAuthHandler(BasePluginClient):
"user_id": user_id, "user_id": user_id,
"data": { "data": {
"provider": provider, "provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials, "system_credentials": system_credentials,
# for json serialization # for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(), "raw_http_request": binascii.hexlify(raw_request_bytes).decode(),

@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
with_keywords = False
if with_keywords: if with_keywords:
keywords_list = kwargs.get("keywords_list") keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset) keyword = Keyword(dataset)
@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
vector.delete_by_ids(node_ids) vector.delete_by_ids(node_ids)
else: else:
vector.delete() vector.delete()
with_keywords = False
if with_keywords: if with_keywords:
keyword = Keyword(dataset) keyword = Keyword(dataset)
if node_ids: if node_ids:

@ -7,6 +7,7 @@ def append_variables_recursively(
): ):
""" """
Append variables recursively Append variables recursively
:param pool: variable pool to append variables to
:param node_id: node id :param node_id: node id
:param variable_key_list: variable key list :param variable_key_list: variable key list
:param variable_value: variable value :param variable_value: variable value

@ -300,7 +300,7 @@ class WorkflowEntry:
return node_instance, generator return node_instance, generator
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", "error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id, node_instance.id,
node_instance.node_type, node_instance.node_type,
node_instance.version(), node_instance.version(),

@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
from .create_document_index import handle from .create_document_index import handle
from .create_installed_app_when_app_created import handle from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle from .create_site_record_when_app_created import handle
from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_message_created import handle
# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle

@ -1,65 +0,0 @@
from datetime import UTC, datetime
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType
@message_was_created.connect
def handle(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
if not system_configuration.current_quota_type:
return
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = message.message_tokens + message.answer_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_config.model)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()

@ -1,20 +0,0 @@
from datetime import UTC, datetime
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider
@message_was_created.connect
def handle(sender, **kwargs):
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider,
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
db.session.commit()

@ -0,0 +1,234 @@
import logging
import time as time_module
from datetime import datetime
from typing import Any, Optional
from pydantic import BaseModel
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
logger = logging.getLogger(__name__)
class _ProviderUpdateFilters(BaseModel):
"""Filters for identifying Provider records to update."""
tenant_id: str
provider_name: str
provider_type: Optional[str] = None
quota_type: Optional[str] = None
class _ProviderUpdateAdditionalFilters(BaseModel):
"""Additional filters for Provider updates."""
quota_limit_check: bool = False
class _ProviderUpdateValues(BaseModel):
"""Values to update in Provider records."""
last_used: Optional[datetime] = None
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
class _ProviderUpdateOperation(BaseModel):
"""A single Provider update operation."""
filters: _ProviderUpdateFilters
values: _ProviderUpdateValues
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
description: str = "unknown"
@message_was_created.connect
def handle(sender: Message, **kwargs):
"""
Consolidated handler for Provider updates when a message is created.
This handler replaces both:
- update_provider_last_used_at_when_message_created
- deduct_quota_when_message_created
By performing all Provider updates in a single transaction, we ensure
consistency and efficiency when updating Provider records.
"""
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
tenant_id = application_generate_entity.app_config.tenant_id
provider_name = application_generate_entity.model_conf.provider
current_time = datetime_utils.naive_utc_now()
# Prepare updates for both scenarios
updates_to_perform: list[_ProviderUpdateOperation] = []
# 1. Always update last_used for the provider
basic_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=provider_name,
),
values=_ProviderUpdateValues(last_used=current_time),
description="basic_last_used_update",
)
updates_to_perform.append(basic_update)
# 2. Check if we need to deduct quota (system provider only)
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if (
provider_configuration.using_provider_type == ProviderType.SYSTEM
and provider_configuration.system_configuration
and provider_configuration.system_configuration.current_quota_type is not None
):
system_configuration = provider_configuration.system_configuration
# Calculate quota usage
used_quota = _calculate_quota_usage(
message=message,
system_configuration=system_configuration,
model_name=model_config.model,
)
if used_quota is not None:
quota_update = _ProviderUpdateOperation(
filters=_ProviderUpdateFilters(
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
),
description="quota_deduction_update",
)
updates_to_perform.append(quota_update)
# Execute all updates
start_time = time_module.perf_counter()
try:
_execute_provider_updates(updates_to_perform)
# Log successful completion with timing
duration = time_module.perf_counter() - start_time
logger.info(
f"Provider updates completed successfully. "
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
f"Tenant: {tenant_id}, Provider: {provider_name}"
)
except Exception as e:
# Log failure with timing and context
duration = time_module.perf_counter() - start_time
logger.exception(
f"Provider updates failed after {duration:.3f}s. "
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
f"Provider: {provider_name}"
)
raise
def _calculate_quota_usage(
*, message: Message, system_configuration: SystemConfiguration, model_name: str
) -> Optional[int]:
"""Calculate quota usage based on message tokens and quota type."""
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return None
break
if quota_unit is None:
return None
try:
if quota_unit == QuotaUnit.TOKENS:
tokens = message.message_tokens + message.answer_tokens
return tokens
if quota_unit == QuotaUnit.CREDITS:
tokens = dify_config.get_model_credits(model_name)
return tokens
elif quota_unit == QuotaUnit.TIMES:
return 1
return None
except Exception as e:
logger.exception("Failed to calculate quota usage")
return None
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
"""Execute all Provider updates in a single transaction."""
if not updates_to_perform:
return
# Use SQLAlchemy's context manager for transaction management
# This automatically handles commit/rollback
with Session(db.engine) as session:
# Use a single transaction for all updates
for update_operation in updates_to_perform:
filters = update_operation.filters
values = update_operation.values
additional_filters = update_operation.additional_filters
description = update_operation.description
# Build the where conditions
where_conditions = [
Provider.tenant_id == filters.tenant_id,
Provider.provider_name == filters.provider_name,
]
# Add additional filters if specified
if filters.provider_type is not None:
where_conditions.append(Provider.provider_type == filters.provider_type)
if filters.quota_type is not None:
where_conditions.append(Provider.quota_type == filters.quota_type)
if additional_filters.quota_limit_check:
where_conditions.append(Provider.quota_limit > Provider.quota_used)
# Prepare values dict for SQLAlchemy update
update_values = {}
if values.last_used is not None:
update_values["last_used"] = values.last_used
if values.quota_used is not None:
update_values["quota_used"] = values.quota_used
# Build and execute the update statement
stmt = update(Provider).where(*where_conditions).values(**update_values)
result = session.execute(stmt)
rows_affected = result.rowcount
logger.debug(
f"Provider update ({description}): {rows_affected} rows affected. "
f"Filters: {filters.model_dump()}, Values: {update_values}"
)
# If no rows were affected for quota updates, log a warning
if rows_affected == 0 and description == "quota_deduction_update":
logger.warning(
f"No Provider rows updated for quota deduction. "
f"This may indicate quota limit exceeded or provider not found. "
f"Filters: {filters.model_dump()}"
)
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")

@ -384,7 +384,7 @@ def get_file_type_by_mime_type(mime_type: str) -> FileType:
class StorageKeyLoader: class StorageKeyLoader:
"""FileKeyLoader load the storage key from database for a list of files. """FileKeyLoader load the storage key from database for a list of files.
This loader is batched, the This loader is batched, the database query count is constant regardless of the input size.
""" """
def __init__(self, session: Session, tenant_id: str) -> None: def __init__(self, session: Session, tenant_id: str) -> None:
@ -445,10 +445,10 @@ class StorageKeyLoader:
if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
upload_file_row = upload_files.get(model_id) upload_file_row = upload_files.get(model_id)
if upload_file_row is None: if upload_file_row is None:
raise ValueError(...) raise ValueError(f"Upload file not found for id: {model_id}")
file._storage_key = upload_file_row.key file._storage_key = upload_file_row.key
elif file.transfer_method == FileTransferMethod.TOOL_FILE: elif file.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file_row = tool_files.get(model_id) tool_file_row = tool_files.get(model_id)
if tool_file_row is None: if tool_file_row is None:
raise ValueError(...) raise ValueError(f"Tool file not found for id: {model_id}")
file._storage_key = tool_file_row.file_key file._storage_key = tool_file_row.file_key

@ -718,7 +718,6 @@ class Conversation(Base):
if "model" in override_model_configs: if "model" in override_model_configs:
app_model_config = AppModelConfig() app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs) app_model_config = app_model_config.from_model_config_dict(override_model_configs)
assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict() model_config = app_model_config.to_dict()
else: else:
model_config["configs"] = override_model_configs model_config["configs"] = override_model_configs
@ -914,11 +913,11 @@ class Message(Base):
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = db.Column(db.Text, nullable=False) query: Mapped[str] = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False) message = db.Column(db.JSON, nullable=False)
message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False) message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer: Mapped[str] = db.Column(db.Text, nullable=False) answer: Mapped[str] = db.Column(db.Text, nullable=False)
answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False) answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001")) answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
parent_message_id = db.Column(StringUUID, nullable=True) parent_message_id = db.Column(StringUUID, nullable=True)

@ -155,6 +155,7 @@ dev = [
"types_setuptools>=80.9.0", "types_setuptools>=80.9.0",
"pandas-stubs~=2.2.3", "pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0", "scipy-stubs>=1.15.3.0",
"types-python-http-client>=3.3.7.20240910",
] ]
############################################################ ############################################################

@ -586,6 +586,10 @@ class DatasetService:
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError:
# If we can't get the embedding model, preserve existing settings # If we can't get the embedding model, preserve existing settings
logging.warning(
f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
f"preserving existing settings"
)
if dataset.embedding_model_provider and dataset.embedding_model: if dataset.embedding_model_provider and dataset.embedding_model:
filtered_data["embedding_model_provider"] = dataset.embedding_model_provider filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
filtered_data["embedding_model"] = dataset.embedding_model filtered_data["embedding_model"] = dataset.embedding_model

@ -1,23 +0,0 @@
from typing import Optional
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
from extensions.ext_database import db
from models.model import App, AppModelConfig
class ModerationService:
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
app_model_config: Optional[AppModelConfig] = None
app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
)
if not app_model_config:
raise ValueError("app model config not found")
name = app_model_config.sensitive_word_avoidance_dict["type"]
config = app_model_config.sensitive_word_avoidance_dict["config"]
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
return moderation.moderation_for_outputs(text)

@ -97,7 +97,7 @@ class VectorService:
vector = Vector(dataset=dataset) vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id]) vector.delete_by_ids([segment.index_node_id])
vector.add_texts([document], duplicate_check=True) vector.add_texts([document], duplicate_check=True)
else:
# update keyword index # update keyword index
keyword = Keyword(dataset) keyword = Keyword(dataset)
keyword.delete_by_ids([segment.index_node_id]) keyword.delete_by_ids([segment.index_node_id])

@ -8,151 +8,298 @@ from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError from services.errors.account import NoPermissionError
class DatasetPermissionTestDataFactory:
"""Factory class for creating test data and mock objects for dataset permission tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
tenant_id: str = "test-tenant-123",
created_by: str = "creator-456",
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.created_by = created_by
dataset.permission = permission
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(
user_id: str = "user-789",
tenant_id: str = "test-tenant-123",
role: TenantAccountRole = TenantAccountRole.NORMAL,
**kwargs,
) -> Mock:
"""Create a mock user with specified attributes."""
user = Mock(spec=Account)
user.id = user_id
user.current_tenant_id = tenant_id
user.current_role = role
for key, value in kwargs.items():
setattr(user, key, value)
return user
@staticmethod
def create_dataset_permission_mock(
dataset_id: str = "dataset-123",
account_id: str = "user-789",
**kwargs,
) -> Mock:
"""Create a mock dataset permission record."""
permission = Mock(spec=DatasetPermission)
permission.dataset_id = dataset_id
permission.account_id = account_id
for key, value in kwargs.items():
setattr(permission, key, value)
return permission
class TestDatasetPermissionService: class TestDatasetPermissionService:
"""Test cases for dataset permission checking functionality""" """
Comprehensive unit tests for DatasetService.check_dataset_permission method.
def setup_method(self):
"""Set up test fixtures""" This test suite covers all permission scenarios including:
# Mock tenant and user - Cross-tenant access restrictions
self.tenant_id = "test-tenant-123" - Owner privilege checks
self.creator_id = "creator-456" - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM)
self.normal_user_id = "normal-789" - Explicit permission checks for PARTIAL_TEAM
self.owner_user_id = "owner-999" - Error conditions and logging
"""
# Mock dataset
self.dataset = Mock(spec=Dataset) @pytest.fixture
self.dataset.id = "dataset-123" def mock_dataset_service_dependencies(self):
self.dataset.tenant_id = self.tenant_id """Common mock setup for dataset service dependencies."""
self.dataset.created_by = self.creator_id with patch("services.dataset_service.db.session") as mock_session:
yield {
# Mock users "db_session": mock_session,
self.creator_user = Mock(spec=Account) }
self.creator_user.id = self.creator_id
self.creator_user.current_tenant_id = self.tenant_id @pytest.fixture
self.creator_user.current_role = TenantAccountRole.EDITOR def mock_logging_dependencies(self):
"""Mock setup for logging tests."""
self.normal_user = Mock(spec=Account) with patch("services.dataset_service.logging") as mock_logging:
self.normal_user.id = self.normal_user_id yield {
self.normal_user.current_tenant_id = self.tenant_id "logging": mock_logging,
self.normal_user.current_role = TenantAccountRole.NORMAL }
self.owner_user = Mock(spec=Account) def _assert_permission_check_passes(self, dataset: Mock, user: Mock):
self.owner_user.id = self.owner_user_id """Helper method to verify that permission check passes without raising exceptions."""
self.owner_user.current_tenant_id = self.tenant_id # Should not raise any exception
self.owner_user.current_role = TenantAccountRole.OWNER DatasetService.check_dataset_permission(dataset, user)
def _assert_permission_check_fails(
self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset."
):
"""Helper method to verify that permission check fails with expected error."""
with pytest.raises(NoPermissionError, match=expected_message):
DatasetService.check_dataset_permission(dataset, user)
def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str):
"""Helper method to verify database query calls for permission checks."""
mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id)
def _assert_database_query_not_called(self, mock_session: Mock):
"""Helper method to verify that database query was not called."""
mock_session.query.assert_not_called()
# ==================== Cross-Tenant Access Tests ====================
def test_permission_check_different_tenant_should_fail(self): def test_permission_check_different_tenant_should_fail(self):
"""Test that users from different tenants cannot access dataset""" """Test that users from different tenants cannot access dataset regardless of other permissions."""
self.normal_user.current_tenant_id = "different-tenant" # Create dataset and user from different tenants
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM
)
user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR
)
# Should fail due to different tenant
self._assert_permission_check_fails(dataset, user)
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): # ==================== Owner Privilege Tests ====================
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
def test_owner_can_access_any_dataset(self): def test_owner_can_access_any_dataset(self):
"""Test that tenant owners can access any dataset regardless of permission""" """Test that tenant owners can access any dataset regardless of permission level."""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME # Create dataset with restrictive permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME)
# Should not raise any exception # Create owner user
DatasetService.check_dataset_permission(self.dataset, self.owner_user) owner_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="owner-999", role=TenantAccountRole.OWNER
)
# Owner should have access regardless of dataset permission
self._assert_permission_check_passes(dataset, owner_user)
# ==================== ONLY_ME Permission Tests ====================
def test_only_me_permission_creator_can_access(self): def test_only_me_permission_creator_can_access(self):
"""Test ONLY_ME permission allows only creator to access""" """Test ONLY_ME permission allows only the dataset creator to access."""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME # Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should be able to access # Creator should be able to access
DatasetService.check_dataset_permission(self.dataset, self.creator_user) self._assert_permission_check_passes(dataset, creator_user)
def test_only_me_permission_others_cannot_access(self): def test_only_me_permission_others_cannot_access(self):
"""Test ONLY_ME permission denies access to non-creators""" """Test ONLY_ME permission denies access to non-creators."""
self.dataset.permission = DatasetPermissionEnum.ONLY_ME # Create dataset with ONLY_ME permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME
)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): # Non-creator should be denied access
DatasetService.check_dataset_permission(self.dataset, self.normal_user) self._assert_permission_check_fails(dataset, normal_user)
# ==================== ALL_TEAM Permission Tests ====================
def test_all_team_permission_allows_access(self): def test_all_team_permission_allows_access(self):
"""Test ALL_TEAM permission allows any team member to access""" """Test ALL_TEAM permission allows any team member to access the dataset."""
self.dataset.permission = DatasetPermissionEnum.ALL_TEAM # Create dataset with ALL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM)
# Create different types of team members
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
editor_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="editor-456", role=TenantAccountRole.EDITOR
)
# All team members should have access
self._assert_permission_check_passes(dataset, normal_user)
self._assert_permission_check_passes(dataset, editor_user)
# Should not raise any exception for team members # ==================== PARTIAL_TEAM Permission Tests ====================
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
DatasetService.check_dataset_permission(self.dataset, self.creator_user)
@patch("services.dataset_service.db.session") def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies):
def test_partial_team_permission_creator_can_access(self, mock_session): """Test PARTIAL_TEAM permission allows creator to access without database query."""
"""Test PARTIAL_TEAM permission allows creator to access""" # Create dataset with PARTIAL_TEAM permission
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Should not raise any exception for creator # Create creator user
DatasetService.check_dataset_permission(self.dataset, self.creator_user) creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Should not query database for creator # Creator should have access without database query
mock_session.query.assert_not_called() self._assert_permission_check_passes(dataset, creator_user)
self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"])
@patch("services.dataset_service.db.session") def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies):
def test_partial_team_permission_with_explicit_permission(self, mock_session): """Test PARTIAL_TEAM permission allows users with explicit permission records."""
"""Test PARTIAL_TEAM permission allows users with explicit permission""" # Create dataset with PARTIAL_TEAM permission
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return a permission record # Mock database query to return a permission record
mock_permission = Mock(spec=DatasetPermission) mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock(
mock_session.query().filter_by().first.return_value = mock_permission dataset_id=dataset.id, account_id=normal_user.id
)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission
# Should not raise any exception # User with explicit permission should have access
DatasetService.check_dataset_permission(self.dataset, self.normal_user) self._assert_permission_check_passes(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
# Verify database was queried correctly def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies):
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id) """Test PARTIAL_TEAM permission denies users without explicit permission records."""
# Create dataset with PARTIAL_TEAM permission
dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
@patch("services.dataset_service.db.session") # Create normal user (not the creator)
def test_partial_team_permission_without_explicit_permission(self, mock_session): normal_user = DatasetPermissionTestDataFactory.create_user_mock(
"""Test PARTIAL_TEAM permission denies users without explicit permission""" user_id="normal-789", role=TenantAccountRole.NORMAL
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM )
# Mock database query to return None (no permission record) # Mock database query to return None (no permission record)
mock_session.query().filter_by().first.return_value = None mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."):
DatasetService.check_dataset_permission(self.dataset, self.normal_user)
# Verify database was queried correctly # User without explicit permission should be denied access
mock_session.query().filter_by.assert_called_with(dataset_id=self.dataset.id, account_id=self.normal_user.id) self._assert_permission_check_fails(dataset, normal_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id)
@patch("services.dataset_service.db.session") def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies):
def test_partial_team_permission_non_creator_without_permission_fails(self, mock_session): """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets."""
"""Test that non-creators without explicit permission are denied access""" # Create dataset with PARTIAL_TEAM permission
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create a different user (not the creator) # Create a different user (not the creator)
other_user = Mock(spec=Account) other_user = DatasetPermissionTestDataFactory.create_user_mock(
other_user.id = "other-user-123" user_id="other-user-123", role=TenantAccountRole.NORMAL
other_user.current_tenant_id = self.tenant_id )
other_user.current_role = TenantAccountRole.NORMAL
# Mock database query to return None (no permission record) # Mock database query to return None (no permission record)
mock_session.query().filter_by().first.return_value = None mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# Non-creator without explicit permission should be denied access
self._assert_permission_check_fails(dataset, other_user)
self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id)
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset."): # ==================== Enum Usage Tests ====================
DatasetService.check_dataset_permission(self.dataset, other_user)
def test_partial_team_permission_uses_correct_enum(self): def test_partial_team_permission_uses_correct_enum(self):
"""Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM""" """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals."""
# This test ensures we're using the enum instead of string literals # Create dataset with PARTIAL_TEAM permission using enum
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM dataset = DatasetPermissionTestDataFactory.create_dataset_mock(
created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM
)
# Create creator user
creator_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="creator-456", role=TenantAccountRole.EDITOR
)
# Creator should always have access # Creator should always have access regardless of permission level
DatasetService.check_dataset_permission(self.dataset, self.creator_user) self._assert_permission_check_passes(dataset, creator_user)
@patch("services.dataset_service.logging") # ==================== Logging Tests ====================
@patch("services.dataset_service.db.session")
def test_permission_denied_logs_debug_message(self, mock_session, mock_logging): def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies):
"""Test that permission denied events are logged""" """Test that permission denied events are properly logged for debugging purposes."""
self.dataset.permission = DatasetPermissionEnum.PARTIAL_TEAM # Create dataset with PARTIAL_TEAM permission
mock_session.query().filter_by().first.return_value = None dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
# Create normal user (not the creator)
normal_user = DatasetPermissionTestDataFactory.create_user_mock(
user_id="normal-789", role=TenantAccountRole.NORMAL
)
# Mock database query to return None (no permission record)
mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None
# Attempt permission check (should fail)
with pytest.raises(NoPermissionError): with pytest.raises(NoPermissionError):
DatasetService.check_dataset_permission(self.dataset, self.normal_user) DatasetService.check_dataset_permission(dataset, normal_user)
# Verify debug message was logged # Verify debug message was logged with correct user and dataset information
mock_logging.debug.assert_called_with( mock_logging_dependencies["logging"].debug.assert_called_with(
f"User {self.normal_user.id} does not have permission to access dataset {self.dataset.id}" f"User {normal_user.id} does not have permission to access dataset {dataset.id}"
) )

File diff suppressed because it is too large Load Diff

@ -2,7 +2,7 @@ x-shared-env: &shared-api-worker-env
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:1.4.3 image: langgenius/dify-api:1.5.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -31,7 +31,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:1.4.3 image: langgenius/dify-api:1.5.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -57,7 +57,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:1.4.3 image: langgenius/dify-web:1.5.0
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}

@ -516,7 +516,7 @@ x-shared-env: &shared-api-worker-env
services: services:
# API service # API service
api: api:
image: langgenius/dify-api:1.4.3 image: langgenius/dify-api:1.5.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -545,7 +545,7 @@ services:
# worker service # worker service
# The Celery worker for processing the queue. # The Celery worker for processing the queue.
worker: worker:
image: langgenius/dify-api:1.4.3 image: langgenius/dify-api:1.5.0
restart: always restart: always
environment: environment:
# Use the shared environment variables. # Use the shared environment variables.
@ -571,7 +571,7 @@ services:
# Frontend web application. # Frontend web application.
web: web:
image: langgenius/dify-web:1.4.3 image: langgenius/dify-web:1.5.0
restart: always restart: always
environment: environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-} CONSOLE_API_URL: ${CONSOLE_API_URL:-}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 60 KiB

After

Width:  |  Height:  |  Size: 187 KiB

@ -0,0 +1,248 @@
import threading
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.event_handlers.update_provider_when_message_created import (
handle,
get_update_stats,
)
from models.provider import ProviderType
from sqlalchemy.exc import OperationalError
class TestProviderUpdateDeadlockPrevention:
"""Test suite for deadlock prevention in Provider updates."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_message = Mock()
self.mock_message.answer_tokens = 100
self.mock_app_config = Mock()
self.mock_app_config.tenant_id = "test-tenant-123"
self.mock_model_conf = Mock()
self.mock_model_conf.provider = "openai"
self.mock_system_config = Mock()
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_provider_config = Mock()
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
self.mock_provider_config.system_configuration = self.mock_system_config
self.mock_provider_bundle = Mock()
self.mock_provider_bundle.configuration = self.mock_provider_config
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
self.mock_generate_entity.app_config = self.mock_app_config
self.mock_generate_entity.model_conf = self.mock_model_conf
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_consolidated_handler_basic_functionality(self, mock_db):
"""Test that the consolidated handler performs both updates correctly."""
# Setup mock query chain
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1 # 1 row affected
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify db.session.query was called
assert mock_db.session.query.called
# Verify commit was called
mock_db.session.commit.assert_called_once()
# Verify no rollback was called
assert not mock_db.session.rollback.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_deadlock_retry_mechanism(self, mock_db):
"""Test that deadlock errors trigger retry logic."""
# Setup mock to raise deadlock error on first attempt, succeed on second
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# First call raises deadlock, second succeeds
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
None, # Success on retry
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify commit was called twice (original + retry)
assert mock_db.session.commit.call_count == 2
# Verify rollback was called once (after first failure)
mock_db.session.rollback.assert_called_once()
@patch("events.event_handlers.update_provider_when_message_created.db")
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
"""Test that retry delays follow exponential backoff pattern."""
# Setup mock to fail twice, succeed on third attempt
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
OperationalError("deadlock detected", None, None),
None, # Success on third attempt
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify sleep was called twice with increasing delays
assert mock_sleep.call_count == 2
# First delay should be around 0.1s + jitter
first_delay = mock_sleep.call_args_list[0][0][0]
assert 0.1 <= first_delay <= 0.3
# Second delay should be around 0.2s + jitter
second_delay = mock_sleep.call_args_list[1][0][0]
assert 0.2 <= second_delay <= 0.4
def test_concurrent_handler_execution(self):
"""Test that multiple handlers can run concurrently without deadlock."""
results = []
errors = []
def run_handler():
try:
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
handle(
self.mock_message,
application_generate_entity=self.mock_generate_entity,
)
results.append("success")
except Exception as e:
errors.append(str(e))
# Run multiple handlers concurrently
threads = []
for _ in range(5):
thread = threading.Thread(target=run_handler)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join(timeout=5)
# Verify all handlers completed successfully
assert len(results) == 5
assert len(errors) == 0
def test_performance_stats_tracking(self):
"""Test that performance statistics are tracked correctly."""
# Reset stats
stats = get_update_stats()
initial_total = stats["total_updates"]
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(
self.mock_message, application_generate_entity=self.mock_generate_entity
)
# Check that stats were updated
updated_stats = get_update_stats()
assert updated_stats["total_updates"] == initial_total + 1
assert updated_stats["successful_updates"] >= initial_total + 1
def test_non_chat_entity_ignored(self):
"""Test that non-chat entities are ignored by the handler."""
# Create a non-chat entity
mock_non_chat_entity = Mock()
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
# Call handler with non-chat entity
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
# Verify no database operations were performed
assert not mock_db.session.query.called
assert not mock_db.session.commit.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_tokens(self, mock_db):
"""Test quota calculation for token-based quotas."""
# Setup token-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_message.answer_tokens = 150
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called with token count
update_calls = mock_query.update.call_args_list
# Should have at least one call with quota_used update
quota_update_found = False
for call in update_calls:
values = call[0][0] # First argument to update()
if "quota_used" in values:
quota_update_found = True
break
assert quota_update_found
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_times(self, mock_db):
"""Test quota calculation for times-based quotas."""
# Setup times-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called
assert mock_query.update.called
assert mock_db.session.commit.called

@ -256,7 +256,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx
</div> </div>
{/* description */} {/* description */}
{appDetail.description && ( {appDetail.description && (
<div className='system-xs-regular overflow-wrap-anywhere w-full max-w-full whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div> <div className='system-xs-regular overflow-wrap-anywhere max-h-[105px] w-full max-w-full overflow-y-auto whitespace-normal break-words text-text-tertiary'>{appDetail.description}</div>
)} )}
{/* operations */} {/* operations */}
<div className='flex flex-wrap items-center gap-1 self-stretch'> <div className='flex flex-wrap items-center gap-1 self-stretch'>

@ -32,6 +32,10 @@ export const PromptMenuItem = memo(({
return return
onMouseEnter() onMouseEnter()
}} }}
onMouseDown={(e) => {
e.preventDefault()
e.stopPropagation()
}}
onClick={() => { onClick={() => {
if (disabled) if (disabled)
return return

@ -52,8 +52,8 @@ const StepThree = ({ datasetId, datasetName, indexingType, creationCache, retrie
datasetId={datasetId || creationCache?.dataset?.id || ''} datasetId={datasetId || creationCache?.dataset?.id || ''}
batchId={creationCache?.batch || ''} batchId={creationCache?.batch || ''}
documents={creationCache?.documents as FullDocumentDetail[]} documents={creationCache?.documents as FullDocumentDetail[]}
indexingType={indexingType || creationCache?.dataset?.indexing_technique} indexingType={creationCache?.dataset?.indexing_technique || indexingType}
retrievalMethod={retrievalMethod || creationCache?.dataset?.retrieval_model?.search_method} retrievalMethod={creationCache?.dataset?.retrieval_model_dict?.search_method || retrievalMethod}
/> />
</div> </div>
</div> </div>

@ -575,6 +575,7 @@ const StepTwo = ({
onSuccess(data) { onSuccess(data) {
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string) updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(data) updateResultCache && updateResultCache(data)
updateRetrievalMethodCache && updateRetrievalMethodCache(retrievalConfig.search_method as string)
}, },
}) })
} }

@ -1,4 +1,4 @@
import React, { type FC, useMemo, useState } from 'react' import React, { type FC, useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { import {
RiCloseLine, RiCloseLine,
@ -16,8 +16,10 @@ import { useSegmentListContext } from './index'
import { ChunkingMode, type SegmentDetailModel } from '@/models/datasets' import { ChunkingMode, type SegmentDetailModel } from '@/models/datasets'
import { useEventEmitterContextContext } from '@/context/event-emitter' import { useEventEmitterContextContext } from '@/context/event-emitter'
import { formatNumber } from '@/utils/format' import { formatNumber } from '@/utils/format'
import classNames from '@/utils/classnames' import cn from '@/utils/classnames'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../../create/step-two'
type ISegmentDetailProps = { type ISegmentDetailProps = {
segInfo?: Partial<SegmentDetailModel> & { id: string } segInfo?: Partial<SegmentDetailModel> & { id: string }
@ -48,6 +50,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const mode = useDocumentContext(s => s.mode) const mode = useDocumentContext(s => s.mode)
const parentMode = useDocumentContext(s => s.parentMode) const parentMode = useDocumentContext(s => s.parentMode)
const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
eventEmitter?.useSubscription((v) => { eventEmitter?.useSubscription((v) => {
if (v === 'update-segment') if (v === 'update-segment')
@ -56,56 +59,41 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
setLoading(false) setLoading(false)
}) })
const handleCancel = () => { const handleCancel = useCallback(() => {
onCancel() onCancel()
} }, [onCancel])
const handleSave = () => { const handleSave = useCallback(() => {
onUpdate(segInfo?.id || '', question, answer, keywords) onUpdate(segInfo?.id || '', question, answer, keywords)
} }, [onUpdate, segInfo?.id, question, answer, keywords])
const handleRegeneration = () => { const handleRegeneration = useCallback(() => {
setShowRegenerationModal(true) setShowRegenerationModal(true)
} }, [])
const onCancelRegeneration = () => { const onCancelRegeneration = useCallback(() => {
setShowRegenerationModal(false) setShowRegenerationModal(false)
} }, [])
const onConfirmRegeneration = () => { const onConfirmRegeneration = useCallback(() => {
onUpdate(segInfo?.id || '', question, answer, keywords, true) onUpdate(segInfo?.id || '', question, answer, keywords, true)
} }, [onUpdate, segInfo?.id, question, answer, keywords])
const isParentChildMode = useMemo(() => {
return mode === 'hierarchical'
}, [mode])
const isFullDocMode = useMemo(() => {
return mode === 'hierarchical' && parentMode === 'full-doc'
}, [mode, parentMode])
const titleText = useMemo(() => {
return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
}, [isEditMode, t])
const isQAModel = useMemo(() => {
return docForm === ChunkingMode.qa
}, [docForm])
const wordCountText = useMemo(() => { const wordCountText = useMemo(() => {
const contentLength = isQAModel ? (question.length + answer.length) : question.length const contentLength = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number) const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number)
const count = isEditMode ? contentLength : segInfo!.word_count as number const count = isEditMode ? contentLength : segInfo!.word_count as number
return `${total} ${t('datasetDocuments.segment.characters', { count })}` return `${total} ${t('datasetDocuments.segment.characters', { count })}`
}, [isEditMode, question.length, answer.length, isQAModel, segInfo, t]) }, [isEditMode, question.length, answer.length, docForm, segInfo, t])
const labelPrefix = useMemo(() => { const isFullDocMode = mode === 'hierarchical' && parentMode === 'full-doc'
return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') const titleText = isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail')
}, [isParentChildMode, t]) const labelPrefix = mode === 'hierarchical' ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk')
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
return ( return (
<div className={'flex h-full flex-col'}> <div className={'flex h-full flex-col'}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}> <div className={cn('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
<div className='flex flex-col'> <div className='flex flex-col'>
<div className='system-xl-semibold text-text-primary'>{titleText}</div> <div className='system-xl-semibold text-text-primary'>{titleText}</div>
<div className='flex items-center gap-x-2'> <div className='flex items-center gap-x-2'>
@ -134,12 +122,12 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
</div> </div>
</div> </div>
</div> </div>
<div className={classNames( <div className={cn(
'flex grow', 'flex grow',
fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4', fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3',
!isEditMode && 'pb-0 overflow-hidden', !isEditMode && 'overflow-hidden pb-0',
)}> )}>
<div className={classNames(isEditMode ? 'break-all whitespace-pre-line overflow-hidden' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}> <div className={cn(isEditMode ? 'overflow-hidden whitespace-pre-line break-all' : 'overflow-y-auto', fullScreen ? 'w-1/2' : 'grow')}>
<ChunkContent <ChunkContent
docForm={docForm} docForm={docForm}
question={question} question={question}
@ -149,7 +137,7 @@ const SegmentDetail: FC<ISegmentDetailProps> = ({
isEditMode={isEditMode} isEditMode={isEditMode}
/> />
</div> </div>
{mode === 'custom' && <Keywords {isECOIndexing && <Keywords
className={fullScreen ? 'w-1/5' : ''} className={fullScreen ? 'w-1/5' : ''}
actionType={isEditMode ? 'edit' : 'view'} actionType={isEditMode ? 'edit' : 'view'}
segInfo={segInfo} segInfo={segInfo}

@ -1,4 +1,4 @@
import { memo, useMemo, useRef, useState } from 'react' import { memo, useCallback, useMemo, useRef, useState } from 'react'
import type { FC } from 'react' import type { FC } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
@ -12,7 +12,6 @@ import Keywords from './completed/common/keywords'
import ChunkContent from './completed/common/chunk-content' import ChunkContent from './completed/common/chunk-content'
import AddAnother from './completed/common/add-another' import AddAnother from './completed/common/add-another'
import Dot from './completed/common/dot' import Dot from './completed/common/dot'
import { useDocumentContext } from './index'
import { useStore as useAppStore } from '@/app/components/app/store' import { useStore as useAppStore } from '@/app/components/app/store'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import { ChunkingMode, type SegmentUpdater } from '@/models/datasets' import { ChunkingMode, type SegmentUpdater } from '@/models/datasets'
@ -20,6 +19,8 @@ import classNames from '@/utils/classnames'
import { formatNumber } from '@/utils/format' import { formatNumber } from '@/utils/format'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { useAddSegment } from '@/service/knowledge/use-segment' import { useAddSegment } from '@/service/knowledge/use-segment'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { IndexingType } from '../../create/step-two'
type NewSegmentModalProps = { type NewSegmentModalProps = {
onCancel: () => void onCancel: () => void
@ -44,13 +45,14 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
const [addAnother, setAddAnother] = useState(true) const [addAnother, setAddAnother] = useState(true)
const fullScreen = useSegmentListContext(s => s.fullScreen) const fullScreen = useSegmentListContext(s => s.fullScreen)
const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen) const toggleFullScreen = useSegmentListContext(s => s.toggleFullScreen)
const mode = useDocumentContext(s => s.mode) const indexingTechnique = useDatasetDetailContextWithSelector(s => s.dataset?.indexing_technique)
const { appSidebarExpand } = useAppStore(useShallow(state => ({ const { appSidebarExpand } = useAppStore(useShallow(state => ({
appSidebarExpand: state.appSidebarExpand, appSidebarExpand: state.appSidebarExpand,
}))) })))
const refreshTimer = useRef<any>(null) const refreshTimer = useRef<any>(null)
const CustomButton = <> const CustomButton = useMemo(() => (
<>
<Divider type='vertical' className='mx-1 h-3 bg-divider-regular' /> <Divider type='vertical' className='mx-1 h-3 bg-divider-regular' />
<button <button
type='button' type='button'
@ -62,21 +64,18 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
{t('common.operation.view')} {t('common.operation.view')}
</button> </button>
</> </>
), [viewNewlyAddedChunk, t])
const isQAModel = useMemo(() => { const handleCancel = useCallback((actionType: 'esc' | 'add' = 'esc') => {
return docForm === ChunkingMode.qa
}, [docForm])
const handleCancel = (actionType: 'esc' | 'add' = 'esc') => {
if (actionType === 'esc' || !addAnother) if (actionType === 'esc' || !addAnother)
onCancel() onCancel()
} }, [onCancel, addAnother])
const { mutateAsync: addSegment } = useAddSegment() const { mutateAsync: addSegment } = useAddSegment()
const handleSave = async () => { const handleSave = useCallback(async () => {
const params: SegmentUpdater = { content: '' } const params: SegmentUpdater = { content: '' }
if (isQAModel) { if (docForm === ChunkingMode.qa) {
if (!question.trim()) { if (!question.trim()) {
return notify({ return notify({
type: 'error', type: 'error',
@ -129,21 +128,27 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
setLoading(false) setLoading(false)
}, },
}) })
} }, [docForm, keywords, addSegment, datasetId, documentId, question, answer, notify, t, appSidebarExpand, CustomButton, handleCancel, onSave])
const wordCountText = useMemo(() => { const wordCountText = useMemo(() => {
const count = isQAModel ? (question.length + answer.length) : question.length const count = docForm === ChunkingMode.qa ? (question.length + answer.length) : question.length
return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}` return `${formatNumber(count)} ${t('datasetDocuments.segment.characters', { count })}`
// eslint-disable-next-line react-hooks/exhaustive-deps }, [question.length, answer.length, docForm, t])
}, [question.length, answer.length, isQAModel])
const isECOIndexing = indexingTechnique === IndexingType.ECONOMICAL
return ( return (
<div className={'flex h-full flex-col'}> <div className={'flex h-full flex-col'}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}> <div
className={classNames(
'flex items-center justify-between',
fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3',
)}
>
<div className='flex flex-col'> <div className='flex flex-col'>
<div className='system-xl-semibold text-text-primary'>{ <div className='system-xl-semibold text-text-primary'>
t('datasetDocuments.segment.addChunk') {t('datasetDocuments.segment.addChunk')}
}</div> </div>
<div className='flex items-center gap-x-2'> <div className='flex items-center gap-x-2'>
<SegmentIndexTag label={t('datasetDocuments.segment.newChunk')!} /> <SegmentIndexTag label={t('datasetDocuments.segment.newChunk')!} />
<Dot /> <Dot />
@ -171,8 +176,8 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
</div> </div>
</div> </div>
</div> </div>
<div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center px-6 pt-6 gap-x-8' : 'flex-col gap-y-1 py-3 px-4')}> <div className={classNames('flex grow', fullScreen ? 'w-full flex-row justify-center gap-x-8 px-6 pt-6' : 'flex-col gap-y-1 px-4 py-3')}>
<div className={classNames('break-all overflow-hidden whitespace-pre-line', fullScreen ? 'w-1/2' : 'grow')}> <div className={classNames('overflow-hidden whitespace-pre-line break-all', fullScreen ? 'w-1/2' : 'grow')}>
<ChunkContent <ChunkContent
docForm={docForm} docForm={docForm}
question={question} question={question}
@ -182,7 +187,7 @@ const NewSegmentModal: FC<NewSegmentModalProps> = ({
isEditMode={true} isEditMode={true}
/> />
</div> </div>
{mode === 'custom' && <Keywords {isECOIndexing && <Keywords
className={fullScreen ? 'w-1/5' : ''} className={fullScreen ? 'w-1/5' : ''}
actionType='add' actionType='add'
keywords={keywords} keywords={keywords}

@ -15,7 +15,7 @@ const Empty: FC = () => {
<div className='system-xs-regular text-text-tertiary'>{t('workflow.debug.variableInspect.emptyTip')}</div> <div className='system-xs-regular text-text-tertiary'>{t('workflow.debug.variableInspect.emptyTip')}</div>
<a <a
className='system-xs-regular cursor-pointer text-text-accent' className='system-xs-regular cursor-pointer text-text-accent'
href='https://docs.dify.ai/guides/workflow/debug-and-preview/variable-inspect' href='https://docs.dify.ai/en/guides/workflow/debug-and-preview/variable-inspect'
target='_blank' target='_blank'
rel='noopener noreferrer'> rel='noopener noreferrer'>
{t('workflow.debug.variableInspect.emptyLink')} {t('workflow.debug.variableInspect.emptyLink')}

@ -213,7 +213,7 @@ export default combine(
settings: { settings: {
tailwindcss: { tailwindcss: {
// These are the default values but feel free to customize // These are the default values but feel free to customize
callees: ['classnames', 'clsx', 'ctl', 'cn'], callees: ['classnames', 'clsx', 'ctl', 'cn', 'classNames'],
config: 'tailwind.config.js', // returned from `loadConfig()` utility if not provided config: 'tailwind.config.js', // returned from `loadConfig()` utility if not provided
cssFiles: [ cssFiles: [
'**/*.css', '**/*.css',

@ -1,6 +1,6 @@
{ {
"name": "dify-web", "name": "dify-web",
"version": "1.4.3", "version": "1.5.0",
"private": true, "private": true,
"engines": { "engines": {
"node": ">=v22.11.0" "node": ">=v22.11.0"

Loading…
Cancel
Save