diff --git a/api/core/file/models.py b/api/core/file/models.py index aa3b5f629c..f61334e7bc 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -51,7 +51,7 @@ class File(BaseModel): # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. related_id: Optional[str] = None filename: Optional[str] = None - extension: Optional[str] = Field(default=None, description="File extension, should contains dot") + extension: Optional[str] = Field(default=None, description="File extension, should contain dot") mime_type: Optional[str] = None size: int = -1 diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py deleted file mode 100644 index 96b2884811..0000000000 --- a/api/core/file/upload_file_parser.py +++ /dev/null @@ -1,67 +0,0 @@ -import base64 -import logging -import time -from typing import Optional - -from configs import dify_config -from constants import IMAGE_EXTENSIONS -from core.helper.url_signer import UrlSigner -from extensions.ext_storage import storage - - -class UploadFileParser: - @classmethod - def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]: - if not upload_file: - return None - - if upload_file.extension not in IMAGE_EXTENSIONS: - return None - - if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url: - return cls.get_signed_temp_image_url(upload_file.id) - else: - # get image file base64 - try: - data = storage.load(upload_file.key) - except FileNotFoundError: - logging.exception(f"File not found: {upload_file.key}") - return None - - encoded_string = base64.b64encode(data).decode("utf-8") - return f"data:{upload_file.mime_type};base64,{encoded_string}" - - @classmethod - def get_signed_temp_image_url(cls, upload_file_id) -> str: - """ - get signed url from upload file - - :param upload_file_id: the id of UploadFile object - :return: - """ - base_url = dify_config.FILES_URL - image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" - - return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview") - - @classmethod - def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - - :param upload_file_id: file id - :param timestamp: timestamp - :param nonce: nonce - :param sign: signature - :return: - """ - result = UrlSigner.verify( - sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview" - ) - - # verify signature - if not result: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py deleted file mode 100644 index 81501d2e4e..0000000000 --- a/api/core/helper/lru_cache.py +++ /dev/null @@ -1,22 +0,0 @@ -from collections import OrderedDict -from typing import Any - - -class LRUCache: - def __init__(self, capacity: int): - self.cache: OrderedDict[Any, Any] = OrderedDict() - self.capacity = capacity - - def get(self, key: Any) -> Any: - if key not in self.cache: - return None - else: - self.cache.move_to_end(key) # move the key to the end of the OrderedDict - return self.cache[key] - - def put(self, key: Any, value: Any) -> None: - if key in self.cache: - self.cache.move_to_end(key) - self.cache[key] = value - if len(self.cache) > self.capacity: - self.cache.popitem(last=False) # pop the first item diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py 
b/api/core/ops/aliyun_trace/aliyun_trace.py
index 163b5d0307..b18a6905fe 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -372,6 +372,7 @@ class AliyunDataTrace(BaseTraceInstance):
     ) -> SpanData:
         process_data = node_execution.process_data or {}
         outputs = node_execution.outputs or {}
+        usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
         return SpanData(
             trace_id=trace_id,
             parent_span_id=workflow_span_id,
@@ -385,9 +386,9 @@ class AliyunDataTrace(BaseTraceInstance):
                 GEN_AI_FRAMEWORK: "dify",
                 GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
                 GEN_AI_SYSTEM: process_data.get("model_provider", ""),
-                GEN_AI_USAGE_INPUT_TOKENS: str(outputs.get("usage", {}).get("prompt_tokens", 0)),
-                GEN_AI_USAGE_OUTPUT_TOKENS: str(outputs.get("usage", {}).get("completion_tokens", 0)),
-                GEN_AI_USAGE_TOTAL_TOKENS: str(outputs.get("usage", {}).get("total_tokens", 0)),
+                GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
+                GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
+                GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
                 GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
                 GEN_AI_COMPLETION: str(outputs.get("text", "")),
                 GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index 0b6834acf3..ffda0885d4 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -213,11 +213,12 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                 if model:
                     node_metadata["ls_model_name"] = model
-                usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {}
-                if usage:
-                    node_metadata["total_tokens"] = usage.get("total_tokens", 0)
-                    node_metadata["prompt_tokens"] = usage.get("prompt_tokens", 0)
-                    node_metadata["completion_tokens"] = usage.get("completion_tokens", 0)
+                outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                if usage_data:
+                    node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
+                    node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
+                    node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
         elif node_execution.node_type == "dataset_retrieval":
             span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
         elif node_execution.node_type == "tool":
@@ -246,14 +247,19 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
                     if model:
                         node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
-                    usage = json.loads(node_execution.outputs).get("usage", {}) if node_execution.outputs else {}
-                    if usage:
-                        node_span.set_attribute(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage.get("total_tokens", 0))
+                    outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+                    usage_data = (
+                        process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                    )
+                    if usage_data:
+                        node_span.set_attribute(
+                            SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
+                        )
                         node_span.set_attribute(
-                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage.get("prompt_tokens", 0)
+                            SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
                         )
                         node_span.set_attribute(
-                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage.get("completion_tokens", 0)
+                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
                         )
             finally:
                 node_span.end(end_time=datetime_to_nanos(finished_at))
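Note: the trace providers touched above and below (Aliyun, Arize Phoenix, Langfuse, LangSmith, Opik) now share one precedence rule: token usage recorded in `process_data` wins, and the usage in `outputs` remains only as a fallback for node executions persisted before this change. A minimal standalone sketch of that rule; `extract_usage` is an illustrative name, since each provider inlines the expression rather than calling a shared helper:

```python
from typing import Any


def extract_usage(process_data: dict[str, Any], outputs: dict[str, Any]) -> dict[str, Any]:
    # process_data takes precedence; outputs is the legacy fallback
    if "usage" in process_data:
        return process_data.get("usage", {})
    return outputs.get("usage", {})


# Older node executions stored usage only in outputs; newer ones store it in process_data.
legacy = extract_usage({}, {"usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}})
current = extract_usage({"usage": {"total_tokens": 20}}, {"usage": {"total_tokens": 999}})
assert legacy["total_tokens"] == 15
assert current["total_tokens"] == 20  # process_data wins even when both exist
```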
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 1d4ae49fc7..a3dbce0e59 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -181,12 +181,9 @@ class LangFuseDataTrace(BaseTraceInstance):
         prompt_tokens = 0
         completion_tokens = 0
         try:
-            if outputs.get("usage"):
-                prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
-                completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
-            else:
-                prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
-                completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
+            usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+            prompt_tokens = usage_data.get("prompt_tokens", 0)
+            completion_tokens = usage_data.get("completion_tokens", 0)
         except Exception:
             logger.error("Failed to extract usage", exc_info=True)
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 8a392940db..f94e5e49d7 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -206,12 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
         prompt_tokens = 0
         completion_tokens = 0
         try:
-            if outputs.get("usage"):
-                prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
-                completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
-            else:
-                prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
-                completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
+            usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+            prompt_tokens = usage_data.get("prompt_tokens", 0)
+            completion_tokens = usage_data.get("completion_tokens", 0)
         except Exception:
             logger.error("Failed to extract usage", exc_info=True)
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index f4d2760ba5..8bedea20fb 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -222,10 +222,10 @@ class OpikDataTrace(BaseTraceInstance):
             )
             try:
-                if outputs.get("usage"):
-                    total_tokens = outputs["usage"].get("total_tokens", 0)
-                    prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
-                    completion_tokens = outputs["usage"].get("completion_tokens", 0)
+                usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+                total_tokens = usage_data.get("total_tokens", 0)
+                prompt_tokens = usage_data.get("prompt_tokens", 0)
+                completion_tokens = usage_data.get("completion_tokens", 0)
             except Exception:
                 logger.error("Failed to extract usage", exc_info=True)
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index d2bf3eb92a..75afe0cdb8 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -122,7 +122,6 @@ class TencentVector(BaseVector):
             metric_type,
             params,
         )
-        index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
         index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
         index_sparse_vector = vdb_index.SparseIndex(
name="sparse_vector", @@ -130,7 +129,7 @@ class TencentVector(BaseVector): index_type=enum.IndexType.SPARSE_INVERTED, metric_type=enum.MetricType.IP, ) - indexes = [index_id, index_vector, index_text, index_metadate] + indexes = [index_id, index_vector, index_metadate] if self._enable_hybrid_search: indexes.append(index_sparse_vector) try: @@ -149,7 +148,7 @@ class TencentVector(BaseVector): index_metadate = vdb_index.FilterIndex( self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER ) - indexes = [index_id, index_vector, index_text, index_metadate] + indexes = [index_id, index_vector, index_metadate] if self._enable_hybrid_search: indexes.append(index_sparse_vector) self._client.create_collection( diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index cdec92aee7..0b3e5eb424 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -17,6 +17,7 @@ from core.workflow.entities.workflow_execution import ( ) from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, CreatorUserRole, @@ -67,7 +68,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 797cce9354..a5feeb0d7c 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -20,6 +20,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.nodes.enums import NodeType from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from libs.helper import extract_tenant_id from models import ( Account, CreatorUserRole, @@ -70,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) ) # Extract tenant_id from user - tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + tenant_id = extract_tenant_id(user) if not tenant_id: raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b5225ce548..9bfb402dc8 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -221,15 +221,6 @@ class LLMNode(BaseNode[LLMNodeData]): jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) - process_data = { - "model_mode": model_config.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages - ), - "model_provider": model_config.provider, - "model_name": model_config.model, - } - # handle invoke result generator = 
self._invoke_llm( node_data_model=self.node_data.model, @@ -253,6 +244,17 @@ class LLMNode(BaseNode[LLMNodeData]): elif isinstance(event, LLMStructuredOutput): structured_output = event + process_data = { + "model_mode": model_config.mode, + "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_config.mode, prompt_messages=prompt_messages + ), + "usage": jsonable_encoder(usage), + "finish_reason": finish_reason, + "model_provider": model_config.provider, + "model_name": model_config.model, + } + outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} if structured_output: outputs["structured_output"] = structured_output.structured_output diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 8d6c2d0a5c..25a534256b 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -253,7 +253,12 @@ class ParameterExtractorNode(BaseNode): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, process_data=process_data, - outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, + outputs={ + "__is_success": 1 if not error else 0, + "__reason": error, + "__usage": jsonable_encoder(usage), + **result, + }, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index a518167cc6..74024ed90c 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -145,7 +145,11 @@ class QuestionClassifierNode(LLMNode): "model_provider": model_config.provider, "model_name": model_config.model, } - outputs = {"class_name": category_name, "class_id": category_id} + outputs = { + "class_name": category_name, + "class_id": category_id, + "usage": jsonable_encoder(usage), + } return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 23cf4c5cab..b62b0b60d6 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config from dify_app import DifyApp +from libs.helper import extract_tenant_id from models import Account, EndUser @@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]): if user: try: current_span = get_current_span() - if isinstance(user, Account) and user.current_tenant_id: - tenant_id = user.current_tenant_id - elif isinstance(user, EndUser): - tenant_id = user.tenant_id - else: + tenant_id = extract_tenant_id(user) + if not tenant_id: return if current_span: current_span.set_attribute("service.tenant.id", tenant_id) diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9f1bef3b36..f00ea71c54 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -17,6 +17,7 @@ class EnvironmentVariableField(fields.Raw): "name": value.name, "value": encrypter.obfuscated_token(value.value), "value_type": value.value_type.value, + "description": value.description, } if 
isinstance(value, Variable): return { @@ -24,6 +25,7 @@ class EnvironmentVariableField(fields.Raw): "name": value.name, "value": value.value, "value_type": value.value_type.value, + "description": value.description, } if isinstance(value, dict): value_type = value.get("value_type") diff --git a/api/libs/helper.py b/api/libs/helper.py index 3f2a630956..48126461a3 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client if TYPE_CHECKING: from models.account import Account + from models.model import EndUser + + +def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None: + """ + Extract tenant_id from Account or EndUser object. + + Args: + user: Account or EndUser object + + Returns: + tenant_id string if available, None otherwise + + Raises: + ValueError: If user is neither Account nor EndUser + """ + from models.account import Account + from models.model import EndUser + + if isinstance(user, Account): + return user.current_tenant_id + elif isinstance(user, EndUser): + return user.tenant_id + else: + raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.") def run(script): diff --git a/api/models/workflow.py b/api/models/workflow.py index 7f01135af3..77d48bec4f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -15,6 +15,7 @@ from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -352,12 +353,7 @@ class Workflow(Base): self._environment_variables = "{}" # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: return [] @@ -384,12 +380,7 @@ class Workflow(Base): return # Get tenant_id from current_user (Account or EndUser) - if isinstance(current_user, Account): - # Account user - tenant_id = current_user.current_tenant_id - else: - # EndUser - tenant_id = current_user.tenant_id + tenant_id = extract_tenant_id(current_user) if not tenant_id: self._environment_variables = "{}" diff --git a/api/services/file_service.py b/api/services/file_service.py index 2d68f30c5a..286535bd18 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -18,6 +18,7 @@ from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage +from libs.helper import extract_tenant_id from models.account import Account from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -61,11 +62,7 @@ class FileService: # generate file key file_uuid = str(uuid.uuid4()) - if isinstance(user, Account): - current_tenant_id = user.current_tenant_id - else: - # end_user - current_tenant_id = user.tenant_id + current_tenant_id = extract_tenant_id(user) file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." 
+ extension diff --git a/api/tests/unit_tests/core/helper/test_url_signer.py b/api/tests/unit_tests/core/helper/test_url_signer.py new file mode 100644 index 0000000000..5af24777de --- /dev/null +++ b/api/tests/unit_tests/core/helper/test_url_signer.py @@ -0,0 +1,194 @@ +from unittest.mock import patch +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.helper.url_signer import SignedUrlParams, UrlSigner + + +class TestUrlSigner: + """Test cases for UrlSigner class""" + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_generate_signed_url_params(self): + """Test generation of signed URL parameters with all required fields""" + sign_key = "test-sign-key" + prefix = "test-prefix" + + params = UrlSigner.get_signed_url_params(sign_key, prefix) + + # Verify the returned object and required fields + assert isinstance(params, SignedUrlParams) + assert params.sign_key == sign_key + assert params.timestamp is not None + assert params.nonce is not None + assert params.sign is not None + + # Verify nonce format (32 character hex string) + assert len(params.nonce) == 32 + assert all(c in "0123456789abcdef" for c in params.nonce) + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_generate_complete_signed_url(self): + """Test generation of complete signed URL with query parameters""" + base_url = "https://example.com/api/test" + sign_key = "test-sign-key" + prefix = "test-prefix" + + signed_url = UrlSigner.get_signed_url(base_url, sign_key, prefix) + + # Parse URL and verify structure + parsed = urlparse(signed_url) + assert f"{parsed.scheme}://{parsed.netloc}{parsed.path}" == base_url + + # Verify query parameters + query_params = parse_qs(parsed.query) + assert "timestamp" in query_params + assert "nonce" in query_params + assert "sign" in query_params + + # Verify each parameter has exactly one value + assert len(query_params["timestamp"]) == 1 + assert len(query_params["nonce"]) == 1 + assert len(query_params["sign"]) == 1 + + # Verify parameter values are not empty + assert query_params["timestamp"][0] + assert query_params["nonce"][0] + assert query_params["sign"][0] + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_verify_valid_signature(self): + """Test verification of valid signature""" + sign_key = "test-sign-key" + prefix = "test-prefix" + + # Generate and verify signature + params = UrlSigner.get_signed_url_params(sign_key, prefix) + + is_valid = UrlSigner.verify( + sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix=prefix + ) + + assert is_valid is True + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + @pytest.mark.parametrize( + ("field", "modifier"), + [ + ("sign_key", lambda _: "wrong-sign-key"), + ("timestamp", lambda t: str(int(t) + 1000)), + ("nonce", lambda _: "different-nonce-123456789012345"), + ("prefix", lambda _: "wrong-prefix"), + ("sign", lambda s: s + "tampered"), + ], + ) + def test_should_reject_invalid_signature_params(self, field, modifier): + """Test signature verification rejects invalid parameters""" + sign_key = "test-sign-key" + prefix = "test-prefix" + + # Generate valid signed parameters + params = UrlSigner.get_signed_url_params(sign_key, prefix) + + # Prepare verification parameters + verify_params = { + "sign_key": sign_key, + "timestamp": params.timestamp, + "nonce": params.nonce, + "sign": params.sign, + "prefix": prefix, + } + + # Modify the specific field + 
verify_params[field] = modifier(verify_params[field]) + + # Verify should fail + is_valid = UrlSigner.verify(**verify_params) + assert is_valid is False + + @patch("configs.dify_config.SECRET_KEY", None) + def test_should_raise_error_without_secret_key(self): + """Test that signing fails when SECRET_KEY is not configured""" + with pytest.raises(Exception) as exc_info: + UrlSigner.get_signed_url_params("key", "prefix") + + assert "SECRET_KEY is not set" in str(exc_info.value) + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_generate_unique_signatures(self): + """Test that different inputs produce different signatures""" + params1 = UrlSigner.get_signed_url_params("key1", "prefix1") + params2 = UrlSigner.get_signed_url_params("key2", "prefix2") + + # Different inputs should produce different signatures + assert params1.sign != params2.sign + assert params1.nonce != params2.nonce + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_handle_special_characters(self): + """Test handling of special characters in parameters""" + special_cases = [ + "test with spaces", + "test/with/slashes", + "test中文字符", + ] + + for sign_key in special_cases: + params = UrlSigner.get_signed_url_params(sign_key, "prefix") + + # Should generate valid signature and verify correctly + is_valid = UrlSigner.verify( + sign_key=sign_key, timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="prefix" + ) + assert is_valid is True + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_ensure_nonce_randomness(self): + """Test that nonce is random for each generation - critical for security""" + sign_key = "test-sign-key" + prefix = "test-prefix" + + # Generate multiple nonces + nonces = set() + for _ in range(5): + params = UrlSigner.get_signed_url_params(sign_key, prefix) + nonces.add(params.nonce) + + # All nonces should be unique + assert len(nonces) == 5 + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + @patch("time.time", return_value=1234567890) + @patch("os.urandom", return_value=b"\xab\xcd\xef\x12\x34\x56\x78\x90\xab\xcd\xef\x12\x34\x56\x78\x90") + def test_should_produce_consistent_signatures(self, mock_urandom, mock_time): + """Test that same inputs produce same signature - ensures deterministic behavior""" + sign_key = "test-sign-key" + prefix = "test-prefix" + + # Generate signature multiple times with same inputs (time and nonce are mocked) + params1 = UrlSigner.get_signed_url_params(sign_key, prefix) + params2 = UrlSigner.get_signed_url_params(sign_key, prefix) + + # With mocked time and random, should produce identical results + assert params1.timestamp == params2.timestamp + assert params1.nonce == params2.nonce + assert params1.sign == params2.sign + + # Verify the signature is valid + assert UrlSigner.verify( + sign_key=sign_key, timestamp=params1.timestamp, nonce=params1.nonce, sign=params1.sign, prefix=prefix + ) + + @patch("configs.dify_config.SECRET_KEY", "test-secret-key-12345") + def test_should_handle_empty_strings(self): + """Test handling of empty string parameters - common edge case""" + # Empty sign_key and prefix should still work + params = UrlSigner.get_signed_url_params("", "") + assert params.sign is not None + + # Should verify correctly + is_valid = UrlSigner.verify( + sign_key="", timestamp=params.timestamp, nonce=params.nonce, sign=params.sign, prefix="" + ) + assert is_valid is True diff --git a/api/tests/unit_tests/libs/test_helper.py 
b/api/tests/unit_tests/libs/test_helper.py new file mode 100644 index 0000000000..b7701055f5 --- /dev/null +++ b/api/tests/unit_tests/libs/test_helper.py @@ -0,0 +1,65 @@ +import pytest + +from libs.helper import extract_tenant_id +from models.account import Account +from models.model import EndUser + + +class TestExtractTenantId: + """Test cases for the extract_tenant_id utility function.""" + + def test_extract_tenant_id_from_account_with_tenant(self): + """Test extracting tenant_id from Account with current_tenant_id.""" + # Create a mock Account object + account = Account() + # Mock the current_tenant_id property + account._current_tenant = type("MockTenant", (), {"id": "account-tenant-123"})() + + tenant_id = extract_tenant_id(account) + assert tenant_id == "account-tenant-123" + + def test_extract_tenant_id_from_account_without_tenant(self): + """Test extracting tenant_id from Account without current_tenant_id.""" + # Create a mock Account object + account = Account() + account._current_tenant = None + + tenant_id = extract_tenant_id(account) + assert tenant_id is None + + def test_extract_tenant_id_from_enduser_with_tenant(self): + """Test extracting tenant_id from EndUser with tenant_id.""" + # Create a mock EndUser object + end_user = EndUser() + end_user.tenant_id = "enduser-tenant-456" + + tenant_id = extract_tenant_id(end_user) + assert tenant_id == "enduser-tenant-456" + + def test_extract_tenant_id_from_enduser_without_tenant(self): + """Test extracting tenant_id from EndUser without tenant_id.""" + # Create a mock EndUser object + end_user = EndUser() + end_user.tenant_id = None + + tenant_id = extract_tenant_id(end_user) + assert tenant_id is None + + def test_extract_tenant_id_with_invalid_user_type(self): + """Test extracting tenant_id with invalid user type raises ValueError.""" + invalid_user = "not_a_user_object" + + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(invalid_user) + + def test_extract_tenant_id_with_none_user(self): + """Test extracting tenant_id with None user raises ValueError.""" + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(None) + + def test_extract_tenant_id_with_dict_user(self): + """Test extracting tenant_id with dict user raises ValueError.""" + dict_user = {"id": "123", "tenant_id": "456"} + + with pytest.raises(ValueError, match="Invalid user type.*Expected Account or EndUser"): + extract_tenant_id(dict_user) diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 69163d48bd..5bc77ad0ef 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -9,6 +9,7 @@ from core.file.models import File from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment +from models.model import EndUser from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable @@ -43,7 +44,7 @@ def test_environment_variables(): ) # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" with ( @@ -90,7 +91,7 @@ def test_update_environment_variables(): ) # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" 
with ( @@ -136,7 +137,7 @@ def test_to_dict(): # Create some EnvironmentVariable instances # Mock current_user as an EndUser - mock_user = mock.Mock() + mock_user = mock.Mock(spec=EndUser) mock_user.tenant_id = "tenant_id" with ( diff --git a/docker/.env.example b/docker/.env.example index e7dbecb413..a403f25cb2 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -826,6 +826,9 @@ MAX_ITERATIONS_NUM=99 # The timeout for the text generation in millisecond TEXT_GENERATION_TIMEOUT_MS=60000 +# Allow rendering unsafe URLs which have "data:" scheme. +ALLOW_UNSAFE_DATA_SCHEME=false + # ------------------------------ # Environment Variables for db Service # ------------------------------ diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index a34f96e945..fd7c78c7e7 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -67,6 +67,7 @@ services: TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} ALLOW_EMBED: ${ALLOW_EMBED:-false} + ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e48b5afd8c..0a95251ff0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -364,6 +364,7 @@ x-shared-env: &shared-api-worker-env MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} @@ -582,6 +583,7 @@ services: TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} CSP_WHITELIST: ${CSP_WHITELIST:-} ALLOW_EMBED: ${ALLOW_EMBED:-false} + ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} diff --git a/web/.env.example b/web/.env.example index c30064ffed..37bfc939eb 100644 --- a/web/.env.example +++ b/web/.env.example @@ -32,6 +32,9 @@ NEXT_PUBLIC_CSP_WHITELIST= # Default is not allow to embed into iframe to prevent Clickjacking: https://owasp.org/www-community/attacks/Clickjacking NEXT_PUBLIC_ALLOW_EMBED= +# Allow rendering unsafe URLs which have "data:" scheme. +NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false + # Github Access Token, used for invoking Github API NEXT_PUBLIC_GITHUB_ACCESS_TOKEN= # The maximum number of top-k value for RAG. 
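The new `ALLOW_UNSAFE_DATA_SCHEME` / `NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME` flag defaults to false because `data:` URIs can smuggle executable HTML or scripts into rendered markdown. A rough Python model of the allow-list check that the TypeScript change below implements; the constant and function names here are illustrative, not part of the codebase:

```python
ALLOW_UNSAFE_DATA_SCHEME = False  # assumption: parsed from the env default shown above


def is_valid_url(url: str) -> bool:
    # Same allow-list as web/app/components/base/markdown-blocks/utils.ts:
    # 'data:' is accepted only when the unsafe-scheme flag is enabled.
    valid_prefixes = ["http:", "https:", "//", "mailto:"]
    if ALLOW_UNSAFE_DATA_SCHEME:
        valid_prefixes.append("data:")
    return any(url.startswith(prefix) for prefix in valid_prefixes)


assert is_valid_url("https://example.com/a.png")
assert not is_valid_url("data:text/html;base64,PHNjcmlwdD4=")  # rejected under the default
```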
diff --git a/web/app/components/base/markdown-blocks/utils.ts b/web/app/components/base/markdown-blocks/utils.ts index 4e9e98dbed..d8df76aefc 100644 --- a/web/app/components/base/markdown-blocks/utils.ts +++ b/web/app/components/base/markdown-blocks/utils.ts @@ -1,3 +1,7 @@ +import { ALLOW_UNSAFE_DATA_SCHEME } from '@/config' + export const isValidUrl = (url: string): boolean => { - return ['http:', 'https:', '//', 'mailto:'].some(prefix => url.startsWith(prefix)) + const validPrefixes = ['http:', 'https:', '//', 'mailto:'] + if (ALLOW_UNSAFE_DATA_SCHEME) validPrefixes.push('data:') + return validPrefixes.some(prefix => url.startsWith(prefix)) } diff --git a/web/app/components/base/markdown/markdown-utils.ts b/web/app/components/base/markdown/markdown-utils.ts index 209fcd0b32..0089bef0ac 100644 --- a/web/app/components/base/markdown/markdown-utils.ts +++ b/web/app/components/base/markdown/markdown-utils.ts @@ -4,6 +4,7 @@ * Includes preprocessing for LaTeX and custom "think" tags. */ import { flow } from 'lodash-es' +import { ALLOW_UNSAFE_DATA_SCHEME } from '@/config' export const preprocessLaTeX = (content: string) => { if (typeof content !== 'string') @@ -86,5 +87,8 @@ export const customUrlTransform = (uri: string): string | undefined => { if (PERMITTED_SCHEME_REGEX.test(scheme)) return uri + if (ALLOW_UNSAFE_DATA_SCHEME && scheme === 'data:') + return uri + return undefined } diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index 53f36be5fb..697d6e3d96 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -101,7 +101,7 @@ const Tooltip: FC = ({ > {popupContent && (
triggerMethod === 'hover' && setHoverPopup()}
diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts
index 304295cfbf..0ef4dc9dea 100644
--- a/web/app/components/workflow/constants.ts
+++ b/web/app/components/workflow/constants.ts
@@ -480,6 +480,10 @@ export const LLM_OUTPUT_STRUCT: Var[] = [
     variable: 'text',
     type: VarType.string,
   },
+  {
+    variable: 'usage',
+    type: VarType.object,
+  },
 ]
 
 export const KNOWLEDGE_RETRIEVAL_OUTPUT_STRUCT: Var[] = [
@@ -501,6 +505,10 @@ export const QUESTION_CLASSIFIER_OUTPUT_STRUCT = [
     variable: 'class_name',
     type: VarType.string,
   },
+  {
+    variable: 'usage',
+    type: VarType.object,
+  },
 ]
 
 export const HTTP_REQUEST_OUTPUT_STRUCT: Var[] = [
@@ -546,6 +554,10 @@ export const PARAMETER_EXTRACTOR_COMMON_STRUCT: Var[] = [
     variable: '__reason',
     type: VarType.string,
   },
+  {
+    variable: '__usage',
+    type: VarType.object,
+  },
 ]
 
 export const FILE_STRUCT: Var[] = [
diff --git a/web/app/components/workflow/nodes/_base/components/variable/utils.ts b/web/app/components/workflow/nodes/_base/components/variable/utils.ts
index 1058f29119..ac95f54757 100644
--- a/web/app/components/workflow/nodes/_base/components/variable/utils.ts
+++ b/web/app/components/workflow/nodes/_base/components/variable/utils.ts
@@ -462,6 +462,7 @@ const formatItem = (
         return {
           variable: `env.${env.name}`,
           type: env.value_type,
+          description: env.description,
         }
       }) as Var[]
       break
@@ -472,7 +473,7 @@ const formatItem = (
         return {
           variable: `conversation.${chatVar.name}`,
           type: chatVar.value_type,
-          des: chatVar.description,
+          description: chatVar.description,
         }
       }) as Var[]
       break
diff --git a/web/app/components/workflow/nodes/http/default.ts b/web/app/components/workflow/nodes/http/default.ts
index 1bd584eeb9..3f9df0178d 100644
--- a/web/app/components/workflow/nodes/http/default.ts
+++ b/web/app/components/workflow/nodes/http/default.ts
@@ -22,6 +22,7 @@ const nodeDefault: NodeDefault = {
       type: BodyType.none,
       data: [],
     },
+    ssl_verify: true,
     timeout: {
       max_connect_timeout: 0,
       max_read_timeout: 0,
diff --git a/web/app/components/workflow/nodes/http/panel.tsx b/web/app/components/workflow/nodes/http/panel.tsx
index 9a07c0ad61..b994910ea0 100644
--- a/web/app/components/workflow/nodes/http/panel.tsx
+++ b/web/app/components/workflow/nodes/http/panel.tsx
@@ -10,6 +10,7 @@ import type { HttpNodeType } from './types'
 import Timeout from './components/timeout'
 import CurlPanel from './components/curl-panel'
 import cn from '@/utils/classnames'
+import Switch from '@/app/components/base/switch'
 import Field from '@/app/components/workflow/nodes/_base/components/field'
 import Split from '@/app/components/workflow/nodes/_base/components/split'
 import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
@@ -47,6 +48,7 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
     showCurlPanel,
     hideCurlPanel,
     handleCurlImport,
+    handleSSLVerifyChange,
   } = useConfig(id, data)
   // To prevent prompt editor in body not update data.
   if (!isDataReady)
@@ -124,6 +126,18 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
           onChange={setBody}
         />
+
+        <Field
+          title={t(`${i18nPrefix}.verifySSL`)}
+          operations={
+            <Switch
+              defaultValue={inputs.ssl_verify}
+              onChange={handleSSLVerifyChange}
+              size='md'
+              disabled={readOnly}
+            />
+          }>
+        </Field>
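The `ssl_verify` flag added to the HTTP node defaults to true; turning it off is intended for targets such as internal services with self-signed certificates. A hypothetical sketch of how an API-side executor might consume it, not the actual Dify executor code (`execute_http_node` and its signature are illustrative):

```python
import httpx


def execute_http_node(url: str, ssl_verify: bool = True) -> httpx.Response:
    # verify=False skips TLS certificate validation, e.g. for self-signed
    # certificates on internal hosts; the safe default matches default.ts above.
    with httpx.Client(verify=ssl_verify) as client:
        return client.get(url)
```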
diff --git a/web/app/components/workflow/nodes/http/use-config.ts b/web/app/components/workflow/nodes/http/use-config.ts
--- a/web/app/components/workflow/nodes/http/use-config.ts
+++ b/web/app/components/workflow/nodes/http/use-config.ts
     setInputs(newInputs)
   }, [inputs, setInputs])
 
+  const handleSSLVerifyChange = useCallback((checked: boolean) => {
+    const newInputs = produce(inputs, (draft: HttpNodeType) => {
+      draft.ssl_verify = checked
+    })
+    setInputs(newInputs)
+  }, [inputs, setInputs])
+
   return {
     readOnly,
     isDataReady,
@@ -164,6 +171,8 @@ const useConfig = (id: string, payload: HttpNodeType) => {
     toggleIsParamKeyValueEdit,
     // body
     setBody,
+    // ssl verify
+    handleSSLVerifyChange,
     // authorization
     isShowAuthorization,
     showAuthorization,
diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx
index 2a71dffa11..471d65ef20 100644
--- a/web/app/components/workflow/nodes/llm/panel.tsx
+++ b/web/app/components/workflow/nodes/llm/panel.tsx
@@ -282,6 +282,11 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
           type='string'
           description={t(`${i18nPrefix}.outputVars.output`)}
         />
+        <VarItem
+          name='usage'
+          type='object'
+          description={t(`${i18nPrefix}.outputVars.usage`)}
+        />
         {inputs.structured_output_enabled && (
           <>
diff --git a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx
index e86a2e3764..a169217609 100644
--- a/web/app/components/workflow/nodes/parameter-extractor/panel.tsx
+++ b/web/app/components/workflow/nodes/parameter-extractor/panel.tsx
@@ -190,12 +190,17 @@ const Panel: FC<NodePanelProps<ParameterExtractorNodeType>> = ({
+        <VarItem
+          name='__usage'
+          type='object'
+          description={t(`${i18nPrefix}.outputVars.usage`)}
+        />
diff --git a/web/app/components/workflow/nodes/question-classifier/panel.tsx b/web/app/components/workflow/nodes/question-classifier/panel.tsx
index 8f6f5eb76d..8cf9ec5f7c 100644
--- a/web/app/components/workflow/nodes/question-classifier/panel.tsx
+++ b/web/app/components/workflow/nodes/question-classifier/panel.tsx
@@ -129,6 +129,11 @@ const Panel: FC<NodePanelProps<QuestionClassifierNodeType>> = ({
           type='string'
           description={t(`${i18nPrefix}.outputVars.className`)}
         />
+        <VarItem
+          name='usage'
+          type='object'
+          description={t(`${i18nPrefix}.outputVars.usage`)}
+        />
diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
index 347c83c155..869317ca6a 100644
--- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
+++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx
@@ -80,7 +80,7 @@ const ChatVariableModal = ({
   const [objectValue, setObjectValue] = React.useState([DEFAULT_OBJECT_VALUE])
   const [editorContent, setEditorContent] = React.useState()
   const [editInJSON, setEditInJSON] = React.useState(false)
-  const [des, setDes] = React.useState('')
+  const [description, setDescription] = React.useState('')
 
   const editorMinHeight = useMemo(() => {
     if (type === ChatVarType.ArrayObject)
@@ -237,7 +237,7 @@ const ChatVariableModal = ({
       name,
       value_type: type,
       value: formatValue(value),
-      description: des,
+      description,
     })
     onClose()
   }
@@ -247,7 +247,7 @@ const ChatVariableModal = ({
     setName(chatVar.name)
     setType(chatVar.value_type)
     setValue(chatVar.value)
-    setDes(chatVar.description)
+    setDescription(chatVar.description)
     setObjectValue(getObjectValue())
     if (chatVar.value_type === ChatVarType.ArrayObject) {
       setEditorContent(JSON.stringify(chatVar.value))
@@ -385,9 +385,9 @@ const ChatVariableModal = ({
-            value={des}
-            onChange={e => setDes(e.target.value)}
+            value={description}
+            onChange={e => setDescription(e.target.value)}
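End to end, the tenant-resolution refactor leaves every call site with one pattern: `extract_tenant_id` handles the Account/EndUser distinction, and the caller decides what an empty result means. A usage sketch matching how the two repository constructors above consume the helper; the `resolve_tenant_id` wrapper is illustrative:

```python
from libs.helper import extract_tenant_id
from models.account import Account
from models.model import EndUser


def resolve_tenant_id(user: Account | EndUser) -> str:
    # Account resolves via current_tenant_id, EndUser via tenant_id;
    # any other type raises ValueError inside the helper itself.
    tenant_id = extract_tenant_id(user)
    if not tenant_id:
        raise ValueError("User must have a tenant_id or current_tenant_id")
    return tenant_id
```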