diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 8518d34a8e..4046417076 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -85,5 +85,35 @@ class RuleCodeGenerateApi(Resource): return code_result +class RuleStructuredOutputGenerateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + + account = current_user + try: + structured_output = LLMGenerator.generate_structured_output( + tenant_id=account.current_tenant_id, + instruction=args["instruction"], + model_config=args["model_config"], + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + + return structured_output + + api.add_resource(RuleGenerateApi, "/rule-generate") api.add_resource(RuleCodeGenerateApi, "/rule-code-generate") +api.add_resource(RuleStructuredOutputGenerateApi, "/rule-structured-output-generate") diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index dc0009f36e..d4a33645ab 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -16,7 +16,7 @@ from controllers.console.auth.error import ( PasswordMismatchError, ) from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from extensions.ext_database import db from libs.helper import email, extract_remote_ip @@ -30,6 +30,7 @@ from services.feature_service import FeatureService class ForgotPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") @@ -62,6 +63,7 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=str, required=True, location="json") @@ -86,12 +88,21 @@ class ForgotPasswordCheckApi(Resource): AccountService.add_forgot_password_error_rate_limit(args["email"]) raise EmailCodeError() + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + AccountService.reset_forgot_password_error_rate_limit(args["email"]) - return {"is_valid": True, "email": token_data.get("email")} + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} class ForgotPasswordResetApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("token", type=str, required=True, nullable=False, location="json") @@ -107,6 +118,9 @@ class ForgotPasswordResetApi(Resource): reset_data = AccountService.get_reset_password_data(args["token"]) if not reset_data: raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() # Revoke token to prevent reuse AccountService.revoke_reset_password_token(args["token"]) diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 41362e9fa2..16c1dcc441 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -22,7 +22,7 @@ from controllers.console.error import ( EmailSendIpLimitError, NotAllowedCreateWorkspace, ) -from controllers.console.wraps import setup_required +from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created from libs.helper import email, extract_remote_ip from libs.password import valid_password @@ -38,6 +38,7 @@ class LoginApi(Resource): """Resource for user login.""" @setup_required + @email_password_login_enabled def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() @@ -110,6 +111,7 @@ class LogoutApi(Resource): class ResetPasswordSendEmailApi(Resource): @setup_required + @email_password_login_enabled def post(self): parser = reqparse.RequestParser() parser.add_argument("email", type=email, required=True, location="json") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 6caaae87f4..e5e8038ad7 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -210,3 +210,16 @@ def enterprise_license_required(view): return view(*args, **kwargs) return decorated + + +def email_password_login_enabled(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_email_password_login: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 75687f9ae3..d5d2ca60fa 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -10,6 +10,7 @@ from core.llm_generator.prompts import ( GENERATOR_QA_PROMPT, JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE, PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE, + SYSTEM_STRUCTURED_OUTPUT_GENERATE, WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager @@ -340,3 +341,37 @@ class LLMGenerator: answer = cast(str, response.message.content) return answer.strip() + + @classmethod + def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict): + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), + ) + + prompt_messages = [ + SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE), + UserPromptMessage(content=instruction), + ] + model_parameters = model_config.get("model_parameters", {}) + + try: + response = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False + ), + ) + + generated_json_schema = cast(str, response.message.content) + return {"output": generated_json_schema, "error": ""} + + except InvokeError as e: + error = str(e) + return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"} + except Exception as e: + logging.exception(f"Failed to invoke LLM model, model: {model_config.get('name')}") + return {"output": "", "error": f"An unexpected error occurred: {str(e)}"} diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index cf20e60c82..82d22d7f89 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -220,3 +220,110 @@ Here is the task description: {{INPUT_TEXT}} You just need to generate the output """ # noqa: E501 + +SYSTEM_STRUCTURED_OUTPUT_GENERATE = """ +Your task is to convert simple user descriptions into properly formatted JSON Schema definitions. When a user describes data fields they need, generate a complete, valid JSON Schema that accurately represents those fields with appropriate types and requirements. + +## Instructions: + +1. Analyze the user's description of their data needs +2. Identify each property that should be included in the schema +3. Determine the appropriate data type for each property +4. Decide which properties should be required +5. Generate a complete JSON Schema with proper syntax +6. Include appropriate constraints when specified (min/max values, patterns, formats) +7. Provide ONLY the JSON Schema without any additional explanations, comments, or markdown formatting. +8. DO NOT use markdown code blocks (``` or ``` json). Return the raw JSON Schema directly. + +## Examples: + +### Example 1: +**User Input:** I need name and age +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" } + }, + "required": ["name", "age"] +} + +### Example 2: +**User Input:** I want to store information about books including title, author, publication year and optional page count +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "title": { "type": "string" }, + "author": { "type": "string" }, + "publicationYear": { "type": "integer" }, + "pageCount": { "type": "integer" } + }, + "required": ["title", "author", "publicationYear"] +} + +### Example 3: +**User Input:** Create a schema for user profiles with email, password, and age (must be at least 18) +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email" + }, + "password": { + "type": "string", + "minLength": 8 + }, + "age": { + "type": "integer", + "minimum": 18 + } + }, + "required": ["email", "password", "age"] +} + +### Example 4: +**User Input:** I need album schema, the ablum has songs, and each song has name, duration, and artist. +**JSON Schema Output:** +{ + "type": "object", + "properties": { + "properties": { + "songs": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "id": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "aritst": { + "type": "string" + } + }, + "required": [ + "name", + "id", + "duration", + "aritst" + ] + } + } + } + }, + "required": [ + "songs" + ] +} + +Now, generate a JSON Schema based on my description +""" # noqa: E501 diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 3225f03fbd..373ef2bbe2 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum, StrEnum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from core.model_runtime.entities.common_entities import I18nObject @@ -85,6 +85,7 @@ class ModelFeature(Enum): DOCUMENT = "document" VIDEO = "video" AUDIO = "audio" + STRUCTURED_OUTPUT = "structured-output" class DefaultParameterName(StrEnum): @@ -197,6 +198,19 @@ class AIModelEntity(ProviderModel): parameter_rules: list[ParameterRule] = [] pricing: Optional[PriceConfig] = None + @model_validator(mode="after") + def validate_model(self): + supported_schema_keys = ["json_schema"] + schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) + if not schema_key: + return self + if self.features is None: + self.features = [ModelFeature.STRUCTURED_OUTPUT] + else: + if ModelFeature.STRUCTURED_OUTPUT not in self.features: + self.features.append(ModelFeature.STRUCTURED_OUTPUT) + return self + class ModelUsage(BaseModel): pass diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index f402da030f..db07e52f3f 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -39,6 +39,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = ParameterExtractorNodeData( @@ -89,6 +90,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): :param query: str :return: dict """ + # FIXME(-LAN-): Avoid import service into core workflow_service = WorkflowService() node_id = "1919810" node_data = QuestionClassifierNodeData( diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 70c618a631..edaa8c92fa 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -126,9 +126,7 @@ class WordExtractor(BaseExtractor): db.session.add(upload_file) db.session.commit() - image_map[rel.target_part] = ( - f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)" - ) + image_map[rel.target_part] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)" return image_map diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/repository/workflow_node_execution_repository.py index 6dea4566de..9bb790cb0f 100644 --- a/api/core/repository/workflow_node_execution_repository.py +++ b/api/core/repository/workflow_node_execution_repository.py @@ -86,3 +86,12 @@ class WorkflowNodeExecutionRepository(Protocol): execution: The WorkflowNodeExecution instance to update """ ... + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records based on implementation-specific criteria. + + This method is intended to be used for bulk deletion operations, such as removing + all records associated with a specific app_id and tenant_id in multi-tenant implementations. + """ + ... diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 7c8960fe49..da40cbcdea 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -16,7 +16,7 @@ from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey -from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated +from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -251,7 +251,12 @@ class AgentNode(ToolNode): prompt_message.model_dump(mode="json") for prompt_message in prompt_messages ] value["history_prompt_messages"] = history_prompt_messages - value["entity"] = model_schema.model_dump(mode="json") if model_schema else None + if model_schema: + # remove structured output feature to support old version agent plugin + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None result[parameter_name] = value return result @@ -348,3 +353,10 @@ class AgentNode(ToolNode): ) model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_instance, model_schema + + def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features: + if feature.value not in AgentOldVersionModelFeatures: + model_schema.features.remove(feature) + return model_schema diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 87cc7e9824..77e94375bf 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -24,3 +24,18 @@ class AgentNodeData(BaseNodeData): class ParamsAutoGenerated(Enum): CLOSE = 0 OPEN = 1 + + +class AgentOldVersionModelFeatures(Enum): + """ + Enum class for old SDK version llm feature. + """ + + TOOL_CALL = "tool-call" + MULTI_TOOL_CALL = "multi-tool-call" + AGENT_THOUGHT = "agent-thought" + VISION = "vision" + STREAM_TOOL_CALL = "stream-tool-call" + DOCUMENT = "document" + VIDEO = "video" + AUDIO = "audio" diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index bf54fdb80c..486b4b01af 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -65,6 +65,8 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + structured_output: dict | None = None + structured_output_enabled: bool = False @field_validator("prompt_config", mode="before") @classmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index fe0ed3e564..8db7394e54 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -4,6 +4,8 @@ from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast +import json_repair + from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus @@ -27,7 +29,13 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import ModelProviderID @@ -57,6 +65,12 @@ from core.workflow.nodes.event import ( RunRetrieverResourceEvent, RunStreamChunkEvent, ) +from core.workflow.utils.structured_output.entities import ( + ResponseFormat, + SpecialModelType, + SupportStructuredOutputStatus, +) +from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models.model import Conversation @@ -92,6 +106,12 @@ class LLMNode(BaseNode[LLMNodeData]): _node_type = NodeType.LLM def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: + def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: + """Process structured output if enabled""" + if not self.node_data.structured_output_enabled or not self.node_data.structured_output: + return None + return self._parse_structured_output(text) + node_inputs: Optional[dict[str, Any]] = None process_data = None result_text = "" @@ -130,7 +150,6 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(event, RunRetrieverResourceEvent): context = event.context yield event - if context: node_inputs["#context#"] = context @@ -192,7 +211,9 @@ class LLMNode(BaseNode[LLMNodeData]): self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} - + structured_output = process_structured_output(result_text) + if structured_output: + outputs["structured_output"] = structured_output yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -513,7 +534,12 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ModelNotExistError(f"Model {model_name} not exist.") - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: + completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) + elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + # Set appropriate response format based on model capabilities + self._set_response_format(completion_params, model_schema.parameter_rules) return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, model=model_name, @@ -724,10 +750,29 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) - + support_structured_output = self._check_model_structured_output_support() + if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) stop = model_config.stop return filtered_prompt_messages, stop + def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]: + structured_output: dict[str, Any] | list[Any] = {} + try: + parsed = json.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + except json.JSONDecodeError as e: + # if the result_text is not a valid json, try to repair it + parsed = json_repair.loads(result_text) + if not isinstance(parsed, (dict | list)): + raise LLMNodeError(f"Failed to parse structured output: {result_text}") + structured_output = parsed + return structured_output + @classmethod def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: provider_model_bundle = model_instance.provider_model_bundle @@ -926,6 +971,166 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages + def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: + """ + Handle structured output for models with native JSON schema support. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + :return: Updated model parameters with JSON schema configuration + """ + # Process schema according to model requirements + schema = self._fetch_structured_output_schema() + schema_json = self._prepare_schema_for_model(schema) + + # Set JSON schema in parameters + model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False) + + # Set appropriate response format if required by the model + for rule in rules: + if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value + + return model_parameters + + def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]: + """ + Handle structured output for models without native JSON schema support. + This function modifies the prompt messages to include schema-based output requirements. + + Args: + prompt_messages: Original sequence of prompt messages + + Returns: + list[PromptMessage]: Updated prompt messages with structured output requirements + """ + # Convert schema to string format + schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False) + + # Find existing system prompt with schema placeholder + system_prompt = next( + (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)), + None, + ) + structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str) + # Prepare system prompt content + system_prompt_content = ( + structured_output_prompt + "\n\n" + system_prompt.content + if system_prompt and isinstance(system_prompt.content, str) + else structured_output_prompt + ) + system_prompt = SystemPromptMessage(content=system_prompt_content) + + # Extract content from the last user message + + filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)] + updated_prompt = [system_prompt] + filtered_prompts + + return updated_prompt + + def _set_response_format(self, model_parameters: dict, rules: list) -> None: + """ + Set the appropriate response format parameter based on model rules. + + :param model_parameters: Model parameters to update + :param rules: Model parameter rules + """ + for rule in rules: + if rule.name == "response_format": + if ResponseFormat.JSON.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON.value + elif ResponseFormat.JSON_OBJECT.value in rule.options: + model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value + + def _prepare_schema_for_model(self, schema: dict) -> dict: + """ + Prepare JSON schema based on model requirements. + + Different models have different requirements for JSON schema formatting. + This function handles these differences. + + :param schema: The original JSON schema + :return: Processed schema compatible with the current model + """ + + # Deep copy to avoid modifying the original schema + processed_schema = schema.copy() + + # Convert boolean types to string types (common requirement) + convert_boolean_to_string(processed_schema) + + # Apply model-specific transformations + if SpecialModelType.GEMINI in self.node_data.model.name: + remove_additional_properties(processed_schema) + return processed_schema + elif SpecialModelType.OLLAMA in self.node_data.model.provider: + return processed_schema + else: + # Default format with name field + return {"schema": processed_schema, "name": "llm_response"} + + def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: + """ + Fetch model schema + """ + model_name = self.node_data.model.name + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name + ) + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + model_credentials = model_instance.credentials + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_schema + + def _fetch_structured_output_schema(self) -> dict[str, Any]: + """ + Fetch the structured output schema from the node data. + + Returns: + dict[str, Any]: The structured output schema + """ + if not self.node_data.structured_output: + raise LLMNodeError("Please provide a valid structured output schema") + structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + if not structured_output_schema: + raise LLMNodeError("Please provide a valid structured output schema") + + try: + schema = json.loads(structured_output_schema) + if not isinstance(schema, dict): + raise LLMNodeError("structured_output_schema must be a JSON object") + return schema + except json.JSONDecodeError: + raise LLMNodeError("structured_output_schema is not valid JSON format") + + def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: + """ + Check if the current model supports structured output. + + Returns: + SupportStructuredOutput: The support status of structured output + """ + # Early return if structured output is disabled + if ( + not isinstance(self.node_data, LLMNodeData) + or not self.node_data.structured_output_enabled + or not self.node_data.structured_output + ): + return SupportStructuredOutputStatus.DISABLED + # Get model schema and check if it exists + model_schema = self._fetch_model_schema(self.node_data.model.provider) + if not model_schema: + return SupportStructuredOutputStatus.DISABLED + + # Check if model supports structured output feature + return ( + SupportStructuredOutputStatus.SUPPORTED + if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) + else SupportStructuredOutputStatus.UNSUPPORTED + ) + def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: @@ -1064,3 +1269,49 @@ def _handle_completion_template( ) prompt_messages.append(prompt_message) return prompt_messages + + +def remove_additional_properties(schema: dict) -> None: + """ + Remove additionalProperties fields from JSON schema. + Used for models like Gemini that don't support this property. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Remove additionalProperties at current level + schema.pop("additionalProperties", None) + + # Process nested structures recursively + for value in schema.values(): + if isinstance(value, dict): + remove_additional_properties(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + remove_additional_properties(item) + + +def convert_boolean_to_string(schema: dict) -> None: + """ + Convert boolean type specifications to string in JSON schema. + + :param schema: JSON schema to modify in-place + """ + if not isinstance(schema, dict): + return + + # Check for boolean type at current level + if schema.get("type") == "boolean": + schema["type"] = "string" + + # Process nested dictionaries and lists recursively + for value in schema.values(): + if isinstance(value, dict): + convert_boolean_to_string(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + convert_boolean_to_string(item) diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py new file mode 100644 index 0000000000..7954acbaee --- /dev/null +++ b/api/core/workflow/utils/structured_output/entities.py @@ -0,0 +1,24 @@ +from enum import StrEnum + + +class ResponseFormat(StrEnum): + """Constants for model response formats""" + + JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode. + JSON = "JSON" # model's json mode. some model like claude support this mode. + JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias. + + +class SpecialModelType(StrEnum): + """Constants for identifying model types""" + + GEMINI = "gemini" + OLLAMA = "ollama" + + +class SupportStructuredOutputStatus(StrEnum): + """Constants for structured output support status""" + + SUPPORTED = "supported" + UNSUPPORTED = "unsupported" + DISABLED = "disabled" diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py new file mode 100644 index 0000000000..06d9b2056e --- /dev/null +++ b/api/core/workflow/utils/structured_output/prompt.py @@ -0,0 +1,17 @@ +STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format. +constraints: + - You must output in JSON format. + - Do not output boolean value, use string type instead. + - Do not output integer or float value, use number type instead. +eg: + Here is the JSON schema: + {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"} + + Here is the user's question: + My name is John Doe and I am 30 years old. + + output: + {"name": "John Doe", "age": 30} +Here is the JSON schema: +{{schema}} +""" # noqa: E501 diff --git a/api/models/workflow.py b/api/models/workflow.py index 045fa0aaa0..51f2f4cc9f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -630,6 +630,7 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None @property @@ -637,6 +638,7 @@ class WorkflowNodeExecution(Base): from models.model import EndUser created_by_role = CreatedByRole(self.created_by_role) + # TODO(-LAN-): Avoid using db.session.get() here. return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None @property diff --git a/api/pyproject.toml b/api/pyproject.toml index 85679a6359..08f9c1e229 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "gunicorn~=23.0.0", "httpx[socks]~=0.27.0", "jieba==0.42.1", + "json-repair>=0.41.1", "langfuse~=2.51.3", "langsmith~=0.1.77", "mailchimp-transactional~=1.0.50", @@ -163,10 +164,7 @@ storage = [ ############################################################ # [ Tools ] dependency group ############################################################ -tools = [ - "cloudscraper~=1.2.71", - "nltk~=3.9.1", -] +tools = ["cloudscraper~=1.2.71", "nltk~=3.9.1"] ############################################################ # [ VDB ] dependency group diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/repositories/workflow_node_execution/sqlalchemy_repository.py index c9c6e70ff3..0594d816a2 100644 --- a/api/repositories/workflow_node_execution/sqlalchemy_repository.py +++ b/api/repositories/workflow_node_execution/sqlalchemy_repository.py @@ -6,7 +6,7 @@ import logging from collections.abc import Sequence from typing import Optional -from sqlalchemy import UnaryExpression, asc, desc, select +from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -168,3 +168,25 @@ class SQLAlchemyWorkflowNodeExecutionRepository: session.merge(execution) session.commit() + + def clear(self) -> None: + """ + Clear all WorkflowNodeExecution records for the current tenant_id and app_id. + + This method deletes all WorkflowNodeExecution records that match the tenant_id + and app_id (if provided) associated with this repository instance. + """ + with self._session_factory() as session: + stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + + result = session.execute(stmt) + session.commit() + + deleted_count = result.rowcount + logger.info( + f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + + (f" and app {self._app_id}" if self._app_id else "") + ) diff --git a/api/services/account_service.py b/api/services/account_service.py index ada8109067..f930ef910b 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -407,10 +407,8 @@ class AccountService: raise PasswordResetRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) - token = TokenManager.generate_token( - account=account, email=email, token_type="reset_password", additional_data={"code": code} - ) + code, token = cls.generate_reset_password_token(account_email, account) + send_reset_password_mail_task.delay( language=language, to=account_email, @@ -419,6 +417,22 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def generate_reset_password_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="reset_password", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 0ddd18ea27..ff3b33eecd 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,13 +2,14 @@ import threading from typing import Optional import contexts +from core.repository import RepositoryFactory +from core.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.model import App from models.workflow import ( WorkflowNodeExecution, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, ) @@ -127,17 +128,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = ( - db.session.query(WorkflowNodeExecution) - .filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ) - .order_by(WorkflowNodeExecution.index.desc()) - .all() + # Use the repository to get the node executions + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } ) - return node_executions + # Use the repository to get the node executions with ordering + order_config = OrderConfig(order_by=["index"], order_direction="desc") + node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + + return list(node_executions) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 992942fc70..b88c7b296d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder +from core.repository import RepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError @@ -282,8 +283,15 @@ class WorkflowService: workflow_node_execution.created_by = account.id workflow_node_execution.workflow_id = draft_workflow.id - db.session.add(workflow_node_execution) - db.session.commit() + # Use the repository to save the workflow node execution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": app_model.tenant_id, + "app_id": app_model.id, + "session_factory": db.session.get_bind, + } + ) + repository.save(workflow_node_execution) return workflow_node_execution diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index c3910e2be3..4542b1b923 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -7,6 +7,7 @@ from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from core.repository import RepositoryFactory from extensions.ext_database import db from models.dataset import AppDatasetJoin from models.model import ( @@ -30,7 +31,7 @@ from models.model import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecution, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -187,18 +188,20 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution_id).delete( - synchronize_session=False - ) - - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", + # Create a repository instance for WorkflowNodeExecution + repository = RepositoryFactory.create_workflow_node_execution_repository( + params={ + "tenant_id": tenant_id, + "app_id": app_id, + "session_factory": db.session.get_bind, + } ) + # Use the clear method to delete all records for this tenant_id and app_id + repository.clear() + + logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green")) + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: str): diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index f31adab2a8..36847f8a13 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -152,3 +152,27 @@ def test_update(repository, session): # Assert session.merge was called session_obj.merge.assert_called_once_with(execution) + + +def test_clear(repository, session, mocker: MockerFixture): + """Test clear method.""" + session_obj, _ = session + # Set up mock + mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete") + mock_stmt = mocker.MagicMock() + mock_delete.return_value = mock_stmt + mock_stmt.where.return_value = mock_stmt + + # Mock the execute result with rowcount + mock_result = mocker.MagicMock() + mock_result.rowcount = 5 # Simulate 5 records deleted + session_obj.execute.return_value = mock_result + + # Call method + repository.clear() + + # Assert delete was called with correct parameters + mock_delete.assert_called_once_with(WorkflowNodeExecution) + mock_stmt.where.assert_called() + session_obj.execute.assert_called_once_with(mock_stmt) + session_obj.commit.assert_called_once() diff --git a/api/uv.lock b/api/uv.lock index 4ff9c34446..4384e1abb5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.11, <3.13" resolution-markers = [ "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy'", @@ -1178,6 +1177,7 @@ dependencies = [ { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "jieba" }, + { name = "json-repair" }, { name = "langfuse" }, { name = "langsmith" }, { name = "mailchimp-transactional" }, @@ -1346,6 +1346,7 @@ requires-dist = [ { name = "gunicorn", specifier = "~=23.0.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, { name = "jieba", specifier = "==0.42.1" }, + { name = "json-repair", specifier = ">=0.41.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.1.77" }, { name = "mailchimp-transactional", specifier = "~=1.0.50" }, @@ -2524,6 +2525,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, ] +[[package]] +name = "json-repair" +version = "0.41.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/6a/6c7a75a10da6dc807b582f2449034da1ed74415e8899746bdfff97109012/json_repair-0.41.1.tar.gz", hash = "sha256:bba404b0888c84a6b86ecc02ec43b71b673cfee463baf6da94e079c55b136565", size = 31208 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5c/abd7495c934d9af5c263c2245ae30cfaa716c3c0cf027b2b8fa686ee7bd4/json_repair-0.41.1-py3-none-any.whl", hash = "sha256:0e181fd43a696887881fe19fed23422a54b3e4c558b6ff27a86a8c3ddde9ae79", size = 21578 }, +] + [[package]] name = "jsonpath-python" version = "1.0.6" @@ -4074,6 +4084,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914 }, { url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105 }, { url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222 }, + { url = "https://files.pythonhosted.org/packages/1d/e3/0c9679cd66cf5604b1f070bdf4525a0c01a15187be287d8348b2eafb718e/pycryptodome-3.19.1-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:ed932eb6c2b1c4391e166e1a562c9d2f020bfff44a0e1b108f67af38b390ea89", size = 1629005 }, + { url = "https://files.pythonhosted.org/packages/13/75/0d63bf0daafd0580b17202d8a9dd57f28c8487f26146b3e2799b0c5a059c/pycryptodome-3.19.1-pp27-pypy_73-win32.whl", hash = "sha256:81e9d23c0316fc1b45d984a44881b220062336bbdc340aa9218e8d0656587934", size = 1697997 }, ] [[package]] diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index a8f7b755fb..c6d41849ef 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -130,6 +130,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 27d6d660d0..1702a5395f 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -60,6 +60,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e01b9f7e9a..def4b77c65 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -603,6 +603,7 @@ services: HTTP_PROXY: ${SANDBOX_HTTP_PROXY:-http://ssrf_proxy:3128} HTTPS_PROXY: ${SANDBOX_HTTPS_PROXY:-http://ssrf_proxy:3128} SANDBOX_PORT: ${SANDBOX_PORT:-8194} + PIP_MIRROR_URL: ${PIP_MIRROR_URL:-} volumes: - ./volumes/sandbox/dependencies:/dependencies - ./volumes/sandbox/conf:/conf diff --git a/web/app/components/app/configuration/config-var/config-select/index.spec.tsx b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx new file mode 100644 index 0000000000..18df318de3 --- /dev/null +++ b/web/app/components/app/configuration/config-var/config-select/index.spec.tsx @@ -0,0 +1,82 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import ConfigSelect from './index' + +jest.mock('react-sortablejs', () => ({ + ReactSortable: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('ConfigSelect Component', () => { + const defaultProps = { + options: ['Option 1', 'Option 2'], + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders all options', () => { + render() + + defaultProps.options.forEach((option) => { + expect(screen.getByDisplayValue(option)).toBeInTheDocument() + }) + }) + + it('renders add button', () => { + render() + + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) + + it('handles option deletion', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.click(deleteButton) + expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2']) + }) + + it('handles adding new option', () => { + render() + const addButton = screen.getByText('appDebug.variableConfig.addOption') + + fireEvent.click(addButton) + + expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, '']) + }) + + it('applies focus styles on input focus', () => { + render() + const firstInput = screen.getByDisplayValue('Option 1') + + fireEvent.focus(firstInput) + + expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active') + }) + + it('applies delete hover styles', () => { + render() + const optionContainer = screen.getByDisplayValue('Option 1').closest('div') + const deleteButton = optionContainer?.querySelector('div[role="button"]') + + if (!deleteButton) return + fireEvent.mouseEnter(deleteButton) + expect(optionContainer).toHaveClass('border-components-input-border-destructive') + }) + + it('renders empty state correctly', () => { + render() + + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app/configuration/config-var/config-select/index.tsx b/web/app/components/app/configuration/config-var/config-select/index.tsx index d2dc1662c1..40ddaef78f 100644 --- a/web/app/components/app/configuration/config-var/config-select/index.tsx +++ b/web/app/components/app/configuration/config-var/config-select/index.tsx @@ -51,7 +51,7 @@ const ConfigSelect: FC = ({ { const value = e.target.value @@ -67,6 +67,7 @@ const ConfigSelect: FC = ({ onBlur={() => setFocusID(null)} />
{ onChange(options.filter((_, i) => index !== i)) diff --git a/web/app/components/base/chat/chat/answer/index.tsx b/web/app/components/base/chat/chat/answer/index.tsx index 3722556931..a0a9323729 100644 --- a/web/app/components/base/chat/chat/answer/index.tsx +++ b/web/app/components/base/chat/chat/answer/index.tsx @@ -234,4 +234,6 @@ const Answer: FC = ({ ) } -export default memo(Answer) +export default memo(Answer, (prevProps, nextProps) => + prevProps.responding === false && nextProps.responding === false, +) diff --git a/web/app/components/base/checkbox/assets/indeterminate-icon.tsx b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx new file mode 100644 index 0000000000..56df8db6a4 --- /dev/null +++ b/web/app/components/base/checkbox/assets/indeterminate-icon.tsx @@ -0,0 +1,11 @@ +const IndeterminateIcon = () => { + return ( +
+ + + +
+ ) +} + +export default IndeterminateIcon diff --git a/web/app/components/base/checkbox/assets/mixed.svg b/web/app/components/base/checkbox/assets/mixed.svg deleted file mode 100644 index e16b8fc975..0000000000 --- a/web/app/components/base/checkbox/assets/mixed.svg +++ /dev/null @@ -1,5 +0,0 @@ - - - - - diff --git a/web/app/components/base/checkbox/index.module.css b/web/app/components/base/checkbox/index.module.css deleted file mode 100644 index d675607b46..0000000000 --- a/web/app/components/base/checkbox/index.module.css +++ /dev/null @@ -1,10 +0,0 @@ -.mixed { - background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat; - background-size: 12px 12px; - border: none; -} - -.checked.disabled { - background-color: #d0d5dd; - border-color: #d0d5dd; -} \ No newline at end of file diff --git a/web/app/components/base/checkbox/index.spec.tsx b/web/app/components/base/checkbox/index.spec.tsx new file mode 100644 index 0000000000..7ef901aef5 --- /dev/null +++ b/web/app/components/base/checkbox/index.spec.tsx @@ -0,0 +1,67 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import Checkbox from './index' + +describe('Checkbox Component', () => { + const mockProps = { + id: 'test', + } + + it('renders unchecked checkbox by default', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toBeInTheDocument() + expect(checkbox).not.toHaveClass('bg-components-checkbox-bg') + }) + + it('renders checked checkbox when checked prop is true', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg') + expect(screen.getByTestId('check-icon-test')).toBeInTheDocument() + }) + + it('renders indeterminate state correctly', () => { + render() + expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument() + }) + + it('handles click events when not disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).toHaveBeenCalledTimes(1) + }) + + it('does not handle click events when disabled', () => { + const onCheck = jest.fn() + render() + const checkbox = screen.getByTestId('checkbox-test') + + fireEvent.click(checkbox) + expect(onCheck).not.toHaveBeenCalled() + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies custom className when provided', () => { + const customClass = 'custom-class' + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass(customClass) + }) + + it('applies correct styles for disabled checked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) + + it('applies correct styles for disabled unchecked state', () => { + render() + const checkbox = screen.getByTestId('checkbox-test') + expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled') + expect(checkbox).toHaveClass('cursor-not-allowed') + }) +}) diff --git a/web/app/components/base/checkbox/index.tsx b/web/app/components/base/checkbox/index.tsx index b0b0ebca7c..3e47967c62 100644 --- a/web/app/components/base/checkbox/index.tsx +++ b/web/app/components/base/checkbox/index.tsx @@ -1,48 +1,49 @@ import { RiCheckLine } from '@remixicon/react' -import s from './index.module.css' import cn from '@/utils/classnames' +import IndeterminateIcon from './assets/indeterminate-icon' type CheckboxProps = { + id?: string checked?: boolean onCheck?: () => void className?: string disabled?: boolean - mixed?: boolean + indeterminate?: boolean } -const Checkbox = ({ checked, onCheck, className, disabled, mixed }: CheckboxProps) => { - if (!checked) { - return ( -
{ - if (disabled) - return - onCheck?.() - }} - >
- ) - } +const Checkbox = ({ + id, + checked, + onCheck, + className, + disabled, + indeterminate, +}: CheckboxProps) => { + const checkClassName = (checked || indeterminate) + ? 'bg-components-checkbox-bg text-components-checkbox-icon hover:bg-components-checkbox-bg-hover' + : 'border border-components-checkbox-border bg-components-checkbox-bg-unchecked hover:bg-components-checkbox-bg-unchecked-hover hover:border-components-checkbox-border-hover' + const disabledClassName = (checked || indeterminate) + ? 'cursor-not-allowed bg-components-checkbox-bg-disabled-checked text-components-checkbox-icon-disabled hover:bg-components-checkbox-bg-disabled-checked' + : 'cursor-not-allowed border-components-checkbox-border-disabled bg-components-checkbox-bg-disabled hover:border-components-checkbox-border-disabled hover:bg-components-checkbox-bg-disabled' + return (
{ if (disabled) return - onCheck?.() }} + data-testid={`checkbox-${id}`} > - + {!checked && indeterminate && } + {checked && }
) } diff --git a/web/app/components/base/form/components/field/checkbox.tsx b/web/app/components/base/form/components/field/checkbox.tsx new file mode 100644 index 0000000000..855dbd80fe --- /dev/null +++ b/web/app/components/base/form/components/field/checkbox.tsx @@ -0,0 +1,43 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import Checkbox from '../../../checkbox' + +type CheckboxFieldProps = { + label: string; + labelClassName?: string; +} + +const CheckboxField = ({ + label, + labelClassName, +}: CheckboxFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ { + field.handleChange(!field.state.value) + }} + /> +
+ +
+ ) +} + +export default CheckboxField diff --git a/web/app/components/base/form/components/field/number-input.tsx b/web/app/components/base/form/components/field/number-input.tsx new file mode 100644 index 0000000000..fce3143fe1 --- /dev/null +++ b/web/app/components/base/form/components/field/number-input.tsx @@ -0,0 +1,49 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Label from '../label' +import cn from '@/utils/classnames' +import type { InputNumberProps } from '../../../input-number' +import { InputNumber } from '../../../input-number' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const NumberInputField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default NumberInputField diff --git a/web/app/components/base/form/components/field/options.tsx b/web/app/components/base/form/components/field/options.tsx new file mode 100644 index 0000000000..9ff71e50af --- /dev/null +++ b/web/app/components/base/form/components/field/options.tsx @@ -0,0 +1,34 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import Label from '../label' +import ConfigSelect from '@/app/components/app/configuration/config-var/config-select' + +type OptionsFieldProps = { + label: string; + className?: string; + labelClassName?: string; +} + +const OptionsField = ({ + label, + className, + labelClassName, +}: OptionsFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default OptionsField diff --git a/web/app/components/base/form/components/field/select.tsx b/web/app/components/base/form/components/field/select.tsx new file mode 100644 index 0000000000..95af3c0116 --- /dev/null +++ b/web/app/components/base/form/components/field/select.tsx @@ -0,0 +1,51 @@ +import cn from '@/utils/classnames' +import { useFieldContext } from '../..' +import PureSelect from '../../../select/pure' +import Label from '../label' + +type SelectOption = { + value: string + label: string +} + +type SelectFieldProps = { + label: string + options: SelectOption[] + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} + +const SelectField = ({ + label, + options, + isRequired, + showOptional, + tooltip, + className, + labelClassName, +}: SelectFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default SelectField diff --git a/web/app/components/base/form/components/field/text.tsx b/web/app/components/base/form/components/field/text.tsx new file mode 100644 index 0000000000..b2090291a0 --- /dev/null +++ b/web/app/components/base/form/components/field/text.tsx @@ -0,0 +1,48 @@ +import React from 'react' +import { useFieldContext } from '../..' +import Input, { type InputProps } from '../../../input' +import Label from '../label' +import cn from '@/utils/classnames' + +type TextFieldProps = { + label: string + isRequired?: boolean + showOptional?: boolean + tooltip?: string + className?: string + labelClassName?: string +} & Omit + +const TextField = ({ + label, + isRequired, + showOptional, + tooltip, + className, + labelClassName, + ...inputProps +}: TextFieldProps) => { + const field = useFieldContext() + + return ( +
+
+ ) +} + +export default TextField diff --git a/web/app/components/base/form/components/form/submit-button.tsx b/web/app/components/base/form/components/form/submit-button.tsx new file mode 100644 index 0000000000..494d19b843 --- /dev/null +++ b/web/app/components/base/form/components/form/submit-button.tsx @@ -0,0 +1,25 @@ +import { useStore } from '@tanstack/react-form' +import { useFormContext } from '../..' +import Button, { type ButtonProps } from '../../../button' + +type SubmitButtonProps = Omit + +const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => { + const form = useFormContext() + + const [isSubmitting, canSubmit] = useStore(form.store, state => [ + state.isSubmitting, + state.canSubmit, + ]) + + return ( +
+ ) +} + +export default Label diff --git a/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx new file mode 100644 index 0000000000..9ba664fc10 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/contact-fields.tsx @@ -0,0 +1,35 @@ +import { withForm } from '../..' +import { demoFormOpts } from './shared-options' +import { ContactMethods } from './types' + +const ContactFields = withForm({ + ...demoFormOpts, + render: ({ form }) => { + return ( +
+

Contacts

+
+ } + /> + } + /> + ( + + )} + /> +
+
+ ) + }, +}) + +export default ContactFields diff --git a/web/app/components/base/form/form-scenarios/demo/index.tsx b/web/app/components/base/form/form-scenarios/demo/index.tsx new file mode 100644 index 0000000000..f08edee41e --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/index.tsx @@ -0,0 +1,68 @@ +import { useStore } from '@tanstack/react-form' +import { useAppForm } from '../..' +import ContactFields from './contact-fields' +import { demoFormOpts } from './shared-options' +import { UserSchema } from './types' + +const DemoForm = () => { + const form = useAppForm({ + ...demoFormOpts, + validators: { + onSubmit: ({ value }) => { + // Validate the entire form + const result = UserSchema.safeParse(value) + if (!result.success) { + const issues = result.error.issues + console.log('Validation errors:', issues) + return issues[0].message + } + return undefined + }, + }, + onSubmit: ({ value }) => { + console.log('Form submitted:', value) + }, + }) + +const name = useStore(form.store, state => state.values.name) + + return ( +
{ + e.preventDefault() + e.stopPropagation() + form.handleSubmit() + }} + > + ( + + )} + /> + ( + + )} + /> + ( + + )} + /> + { + !!name && ( + + ) + } + + Submit + + + ) +} + +export default DemoForm diff --git a/web/app/components/base/form/form-scenarios/demo/shared-options.tsx b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx new file mode 100644 index 0000000000..8b216c8b90 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/shared-options.tsx @@ -0,0 +1,14 @@ +import { formOptions } from '@tanstack/react-form' + +export const demoFormOpts = formOptions({ + defaultValues: { + name: '', + surname: '', + isAcceptingTerms: false, + contact: { + email: '', + phone: '', + preferredContactMethod: 'email', + }, + }, +}) diff --git a/web/app/components/base/form/form-scenarios/demo/types.ts b/web/app/components/base/form/form-scenarios/demo/types.ts new file mode 100644 index 0000000000..c4e626ef63 --- /dev/null +++ b/web/app/components/base/form/form-scenarios/demo/types.ts @@ -0,0 +1,34 @@ +import { z } from 'zod' + +const ContactMethod = z.union([ + z.literal('email'), + z.literal('phone'), + z.literal('whatsapp'), + z.literal('sms'), +]) + +export const ContactMethods = ContactMethod.options.map(({ value }) => ({ + value, + label: value.charAt(0).toUpperCase() + value.slice(1), +})) + +export const UserSchema = z.object({ + name: z + .string() + .regex(/^[A-Z]/, 'Name must start with a capital letter') + .min(3, 'Name must be at least 3 characters long'), + surname: z + .string() + .min(3, 'Surname must be at least 3 characters long') + .regex(/^[A-Z]/, 'Surname must start with a capital letter'), + isAcceptingTerms: z.boolean().refine(val => val, { + message: 'You must accept the terms and conditions', + }), + contact: z.object({ + email: z.string().email('Invalid email address'), + phone: z.string().optional(), + preferredContactMethod: ContactMethod, + }), +}) + +export type User = z.infer diff --git a/web/app/components/base/form/index.tsx b/web/app/components/base/form/index.tsx new file mode 100644 index 0000000000..aeb482ad02 --- /dev/null +++ b/web/app/components/base/form/index.tsx @@ -0,0 +1,25 @@ +import { createFormHook, createFormHookContexts } from '@tanstack/react-form' +import TextField from './components/field/text' +import NumberInputField from './components/field/number-input' +import CheckboxField from './components/field/checkbox' +import SelectField from './components/field/select' +import OptionsField from './components/field/options' +import SubmitButton from './components/form/submit-button' + +export const { fieldContext, useFieldContext, formContext, useFormContext } + = createFormHookContexts() + +export const { useAppForm, withForm } = createFormHook({ + fieldComponents: { + TextField, + NumberInputField, + CheckboxField, + SelectField, + OptionsField, + }, + formComponents: { + SubmitButton, + }, + fieldContext, + formContext, +}) diff --git a/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg new file mode 100644 index 0000000000..9566fcc0c3 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/solid/general/arrow-down-round-fill.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json new file mode 100644 index 0000000000..4e7da3c801 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.json @@ -0,0 +1,36 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "width": "16", + "height": "16", + "viewBox": "0 0 16 16", + "fill": "none", + "xmlns": "http://www.w3.org/2000/svg" + }, + "children": [ + { + "type": "element", + "name": "g", + "attributes": { + "id": "arrow-down-round-fill" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "id": "Vector", + "d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z", + "fill": "currentColor" + }, + "children": [] + } + ] + } + ] + }, + "name": "ArrowDownRoundFill" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx new file mode 100644 index 0000000000..c766a72b94 --- /dev/null +++ b/web/app/components/base/icons/src/vender/solid/general/ArrowDownRoundFill.tsx @@ -0,0 +1,20 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './ArrowDownRoundFill.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconData } from '@/app/components/base/icons/IconBase' + +const Icon = ( + { + ref, + ...props + }: React.SVGProps & { + ref?: React.RefObject>; + }, +) => + +Icon.displayName = 'ArrowDownRoundFill' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/solid/general/index.ts b/web/app/components/base/icons/src/vender/solid/general/index.ts index 52647905ab..4c4dd9a437 100644 --- a/web/app/components/base/icons/src/vender/solid/general/index.ts +++ b/web/app/components/base/icons/src/vender/solid/general/index.ts @@ -1,4 +1,5 @@ export { default as AnswerTriangle } from './AnswerTriangle' +export { default as ArrowDownRoundFill } from './ArrowDownRoundFill' export { default as CheckCircle } from './CheckCircle' export { default as CheckDone01 } from './CheckDone01' export { default as Download02 } from './Download02' diff --git a/web/app/components/base/input-number/index.spec.tsx b/web/app/components/base/input-number/index.spec.tsx new file mode 100644 index 0000000000..8dfd1184b0 --- /dev/null +++ b/web/app/components/base/input-number/index.spec.tsx @@ -0,0 +1,97 @@ +import { fireEvent, render, screen } from '@testing-library/react' +import { InputNumber } from './index' + +jest.mock('react-i18next', () => ({ + useTranslation: () => ({ + t: (key: string) => key, + }), +})) + +describe('InputNumber Component', () => { + const defaultProps = { + onChange: jest.fn(), + } + + afterEach(() => { + jest.clearAllMocks() + }) + + it('renders input with default values', () => { + render() + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + }) + + it('handles increment button click', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(6) + }) + + it('handles decrement button click', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).toHaveBeenCalledWith(4) + }) + + it('respects max value constraint', () => { + render() + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + + fireEvent.click(incrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('respects min value constraint', () => { + render() + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + fireEvent.click(decrementBtn) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('handles direct input changes', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '42' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(42) + }) + + it('handles empty input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: '' } }) + expect(defaultProps.onChange).toHaveBeenCalledWith(undefined) + }) + + it('handles invalid input', () => { + render() + const input = screen.getByRole('textbox') + + fireEvent.change(input, { target: { value: 'abc' } }) + expect(defaultProps.onChange).not.toHaveBeenCalled() + }) + + it('displays unit when provided', () => { + const unit = 'px' + render() + expect(screen.getByText(unit)).toBeInTheDocument() + }) + + it('disables controls when disabled prop is true', () => { + render() + const input = screen.getByRole('textbox') + const incrementBtn = screen.getByRole('button', { name: /increment/i }) + const decrementBtn = screen.getByRole('button', { name: /decrement/i }) + + expect(input).toBeDisabled() + expect(incrementBtn).toBeDisabled() + expect(decrementBtn).toBeDisabled() + }) +}) diff --git a/web/app/components/base/input-number/index.tsx b/web/app/components/base/input-number/index.tsx index 5b88fc67f8..98efc94462 100644 --- a/web/app/components/base/input-number/index.tsx +++ b/web/app/components/base/input-number/index.tsx @@ -8,7 +8,7 @@ export type InputNumberProps = { value?: number onChange: (value?: number) => void amount?: number - size?: 'sm' | 'md' + size?: 'regular' | 'large' max?: number min?: number defaultValue?: number @@ -19,14 +19,12 @@ export type InputNumberProps = { } & Omit export const InputNumber: FC = (props) => { - const { unit, className, onChange, amount = 1, value, size = 'md', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props + const { unit, className, onChange, amount = 1, value, size = 'regular', max, min, defaultValue, wrapClassName, controlWrapClassName, controlClassName, disabled, ...rest } = props const isValidValue = (v: number) => { - if (max && v > max) + if (typeof max === 'number' && v > max) return false - if (min && v < min) - return false - return true + return !(typeof min === 'number' && v < min) } const inc = () => { @@ -76,29 +74,39 @@ export const InputNumber: FC = (props) => { onChange(parsed) }} unit={unit} + size={size} />
-
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 5f059c3b7f..30fd90aff8 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -30,7 +30,7 @@ export type InputProps = { wrapperClassName?: string styleCss?: CSSProperties unit?: string -} & React.InputHTMLAttributes & VariantProps +} & Omit, 'size'> & VariantProps const Input = ({ size, diff --git a/web/app/components/base/markdown-blocks/music.tsx b/web/app/components/base/markdown-blocks/music.tsx new file mode 100644 index 0000000000..7edd1713c9 --- /dev/null +++ b/web/app/components/base/markdown-blocks/music.tsx @@ -0,0 +1,37 @@ +import abcjs from 'abcjs' +import { useEffect, useRef } from 'react' +import 'abcjs/abcjs-audio.css' + +const MarkdownMusic = ({ children }: { children: React.ReactNode }) => { + const containerRef = useRef(null) + const controlsRef = useRef(null) + + useEffect(() => { + if (containerRef.current && controlsRef.current) { + if (typeof children === 'string') { + const visualObjs = abcjs.renderAbc(containerRef.current, children, { + add_classes: true, // Add classes to SVG elements for cursor tracking + responsive: 'resize', // Make notation responsive + }) + const synthControl = new abcjs.synth.SynthController() + synthControl.load(controlsRef.current, {}, { displayPlay: true }) + const synth = new abcjs.synth.CreateSynth() + const visualObj = visualObjs[0] + synth.init({ visualObj }).then(() => { + synthControl.setTune(visualObj, false) + }) + containerRef.current.style.overflow = 'auto' + } + } + }, [children]) + + return ( +
+
+
+
+ ) +} +MarkdownMusic.displayName = 'MarkdownMusic' + +export default MarkdownMusic diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index d50c397177..52b880affa 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -23,6 +23,7 @@ import VideoGallery from '@/app/components/base/video-gallery' import AudioGallery from '@/app/components/base/audio-gallery' import MarkdownButton from '@/app/components/base/markdown-blocks/button' import MarkdownForm from '@/app/components/base/markdown-blocks/form' +import MarkdownMusic from '@/app/components/base/markdown-blocks/music' import ThinkBlock from '@/app/components/base/markdown-blocks/think-block' import { Theme } from '@/types/app' import useTheme from '@/hooks/use-theme' @@ -51,6 +52,7 @@ const capitalizationLanguageNameMap: Record = { json: 'JSON', latex: 'Latex', svg: 'SVG', + abc: 'ABC', } const getCorrectCapitalizationLanguageName = (language: string) => { if (!language) @@ -137,45 +139,54 @@ const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => const renderCodeContent = useMemo(() => { const content = String(children).replace(/\n$/, '') - if (language === 'mermaid' && isSVG) { - return - } - else if (language === 'echarts') { - return ( -
+ switch (language) { + case 'mermaid': + if (isSVG) + return + break + case 'echarts': + return ( +
+ + + +
+ ) + case 'svg': + if (isSVG) { + return ( + + + + ) + } + break + case 'abc': + return ( - + -
- ) - } - else if (language === 'svg' && isSVG) { - return ( - - - - ) - } - else { - return ( - - {content} - - ) + ) + default: + return ( + + {content} + + ) } - }, [language, match, props, children, chartData, isSVG]) + }, [children, language, isSVG, chartData, props, theme, match]) if (inline || !match) return {children} diff --git a/web/app/components/base/param-item/index.tsx b/web/app/components/base/param-item/index.tsx index 4cae402e3b..03eb5a7c42 100644 --- a/web/app/components/base/param-item/index.tsx +++ b/web/app/components/base/param-item/index.tsx @@ -54,7 +54,7 @@ const ParamItem: FC = ({ className, id, name, noTooltip, tip, step = 0.1, max={max} step={step} amount={step} - size='sm' + size='regular' value={value} onChange={(value) => { onChange(id, value) diff --git a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx index 1a2600d568..1cb33fcc49 100644 --- a/web/app/components/base/prompt-editor/plugins/history-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/history-block/node.tsx @@ -14,7 +14,7 @@ export class HistoryBlockNode extends DecoratorNode { } static clone(node: HistoryBlockNode): HistoryBlockNode { - return new HistoryBlockNode(node.__roleName, node.__onEditRole) + return new HistoryBlockNode(node.__roleName, node.__onEditRole, node.__key) } constructor(roleName: RoleName, onEditRole: () => void, key?: NodeKey) { diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx index 2cf4c95b87..2f6c3374a7 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/component.tsx @@ -11,6 +11,7 @@ import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { RiErrorWarningFill, + RiMoreLine, } from '@remixicon/react' import { useSelectOrDelete } from '../../hooks' import type { WorkflowNodesMap } from './node' @@ -27,26 +28,35 @@ import { Line3 } from '@/app/components/base/icons/src/public/common' import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils' import Tooltip from '@/app/components/base/tooltip' import { isExceptionVariable } from '@/app/components/workflow/utils' +import VarFullPathPanel from '@/app/components/workflow/nodes/_base/components/variable/var-full-path-panel' +import { Type } from '@/app/components/workflow/nodes/llm/types' +import type { ValueSelector } from '@/app/components/workflow/types' type WorkflowVariableBlockComponentProps = { nodeKey: string variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: (payload: { + nodeId: string, + valueSelector: ValueSelector, + }) => Type } const WorkflowVariableBlockComponent = ({ nodeKey, variables, workflowNodesMap = {}, + getVarType, }: WorkflowVariableBlockComponentProps) => { const { t } = useTranslation() const [editor] = useLexicalComposerContext() const [ref, isSelected] = useSelectOrDelete(nodeKey, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND) const variablesLength = variables.length + const isShowAPart = variablesLength > 2 const varName = ( () => { const isSystem = isSystemVar(variables) - const varName = variablesLength >= 3 ? (variables).slice(-2).join('.') : variables[variablesLength - 1] + const varName = variables[variablesLength - 1] return `${isSystem ? 'sys.' : ''}${varName}` } )() @@ -76,7 +86,7 @@ const WorkflowVariableBlockComponent = ({ const Item = (
)} + {isShowAPart && ( +
+ + +
+ )} +
{!isEnv && !isChatVar && } {isEnv && } @@ -126,7 +143,27 @@ const WorkflowVariableBlockComponent = ({ ) } - return Item + if (!node) + return null + + return ( + } + disabled={!isShowAPart} + > +
{Item}
+
+ ) } export default memo(WorkflowVariableBlockComponent) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 05d4505e20..479dce9615 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -9,7 +9,7 @@ import { } from 'lexical' import { mergeRegister } from '@lexical/utils' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' -import type { WorkflowVariableBlockType } from '../../types' +import type { GetVarType, WorkflowVariableBlockType } from '../../types' import { $createWorkflowVariableBlockNode, WorkflowVariableBlockNode, @@ -25,11 +25,13 @@ export type WorkflowVariableBlockProps = { getWorkflowNode: (nodeId: string) => Node onInsert?: () => void onDelete?: () => void + getVarType: GetVarType } const WorkflowVariableBlock = memo(({ workflowNodesMap, onInsert, onDelete, + getVarType, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -48,7 +50,7 @@ const WorkflowVariableBlock = memo(({ INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap) + const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) if (onInsert) @@ -69,7 +71,7 @@ const WorkflowVariableBlock = memo(({ COMMAND_PRIORITY_EDITOR, ), ) - }, [editor, onInsert, onDelete, workflowNodesMap]) + }, [editor, onInsert, onDelete, workflowNodesMap, getVarType]) return null }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx index 0564e6f16d..dce636d92d 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx @@ -2,34 +2,39 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical' import { DecoratorNode } from 'lexical' import type { WorkflowVariableBlockType } from '../../types' import WorkflowVariableBlockComponent from './component' +import type { GetVarType } from '../../types' export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap'] + export type SerializedNode = SerializedLexicalNode & { variables: string[] workflowNodesMap: WorkflowNodesMap + getVarType?: GetVarType } export class WorkflowVariableBlockNode extends DecoratorNode { __variables: string[] __workflowNodesMap: WorkflowNodesMap + __getVarType?: GetVarType static getType(): string { return 'workflow-variable-block' } static clone(node: WorkflowVariableBlockNode): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__key) + return new WorkflowVariableBlockNode(node.__variables, node.__workflowNodesMap, node.__getVarType, node.__key) } isInline(): boolean { return true } - constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, key?: NodeKey) { + constructor(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType: any, key?: NodeKey) { super(key) this.__variables = variables this.__workflowNodesMap = workflowNodesMap + this.__getVarType = getVarType } createDOM(): HTMLElement { @@ -48,12 +53,13 @@ export class WorkflowVariableBlockNode extends DecoratorNode nodeKey={this.getKey()} variables={this.__variables} workflowNodesMap={this.__workflowNodesMap} + getVarType={this.__getVarType!} /> ) } static importJSON(serializedNode: SerializedNode): WorkflowVariableBlockNode { - const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap) + const node = $createWorkflowVariableBlockNode(serializedNode.variables, serializedNode.workflowNodesMap, serializedNode.getVarType) return node } @@ -64,6 +70,7 @@ export class WorkflowVariableBlockNode extends DecoratorNode version: 1, variables: this.getVariables(), workflowNodesMap: this.getWorkflowNodesMap(), + getVarType: this.getVarType(), } } @@ -77,12 +84,17 @@ export class WorkflowVariableBlockNode extends DecoratorNode return self.__workflowNodesMap } + getVarType(): any { + const self = this.getLatest() + return self.__getVarType + } + getTextContent(): string { return `{{#${this.getVariables().join('.')}#}}` } } -export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap): WorkflowVariableBlockNode { - return new WorkflowVariableBlockNode(variables, workflowNodesMap) +export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType): WorkflowVariableBlockNode { + return new WorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) } export function $isWorkflowVariableBlockNode( diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx index 22ebc5d248..288008bbcc 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/workflow-variable-block-replacement-block.tsx @@ -16,6 +16,7 @@ import { VAR_REGEX as REGEX, resetReg } from '@/config' const WorkflowVariableBlockReplacementBlock = ({ workflowNodesMap, + getVarType, onInsert, }: WorkflowVariableBlockType) => { const [editor] = useLexicalComposerContext() @@ -30,8 +31,8 @@ const WorkflowVariableBlockReplacementBlock = ({ onInsert() const nodePathString = textNode.getTextContent().slice(3, -3) - return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap)) - }, [onInsert, workflowNodesMap]) + return $applyNodeReplacement($createWorkflowVariableBlockNode(nodePathString.split('.'), workflowNodesMap, getVarType)) + }, [onInsert, workflowNodesMap, getVarType]) const getMatch = useCallback((text: string) => { const matchArr = REGEX.exec(text) diff --git a/web/app/components/base/prompt-editor/types.ts b/web/app/components/base/prompt-editor/types.ts index 6d0f307c17..0f09fb2473 100644 --- a/web/app/components/base/prompt-editor/types.ts +++ b/web/app/components/base/prompt-editor/types.ts @@ -1,8 +1,10 @@ +import type { Type } from '../../workflow/nodes/llm/types' import type { Dataset } from './plugins/context-block' import type { RoleName } from './plugins/history-block' import type { Node, NodeOutPutVar, + ValueSelector, } from '@/app/components/workflow/types' export type Option = { @@ -54,12 +56,18 @@ export type ExternalToolBlockType = { onAddExternalTool?: () => void } +export type GetVarType = (payload: { + nodeId: string, + valueSelector: ValueSelector, +}) => Type + export type WorkflowVariableBlockType = { show?: boolean variables?: NodeOutPutVar[] workflowNodesMap?: Record> onInsert?: () => void onDelete?: () => void + getVarType?: GetVarType } export type MenuTextMatch = { diff --git a/web/app/components/base/segmented-control/index.tsx b/web/app/components/base/segmented-control/index.tsx new file mode 100644 index 0000000000..bd921e4243 --- /dev/null +++ b/web/app/components/base/segmented-control/index.tsx @@ -0,0 +1,68 @@ +import React from 'react' +import classNames from '@/utils/classnames' +import type { RemixiconComponentType } from '@remixicon/react' +import Divider from '../divider' + +// Updated generic type to allow enum values +type SegmentedControlProps = { + options: { Icon: RemixiconComponentType, text: string, value: T }[] + value: T + onChange: (value: T) => void + className?: string +} + +export const SegmentedControl = ({ + options, + value, + onChange, + className, +}: SegmentedControlProps): JSX.Element => { + const selectedOptionIndex = options.findIndex(option => option.value === value) + + return ( +
+ {options.map((option, index) => { + const { Icon } = option + const isSelected = index === selectedOptionIndex + const isNextSelected = index === selectedOptionIndex - 1 + const isLast = index === options.length - 1 + return ( + + ) + })} +
+ ) +} + +export default React.memo(SegmentedControl) as typeof SegmentedControl diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx index 0f18bebedf..1e274515f8 100644 --- a/web/app/components/base/textarea/index.tsx +++ b/web/app/components/base/textarea/index.tsx @@ -8,8 +8,9 @@ const textareaVariants = cva( { variants: { size: { - regular: 'px-3 radius-md system-sm-regular', - large: 'px-4 radius-lg system-md-regular', + small: 'py-1 rounded-md system-xs-regular', + regular: 'px-3 rounded-md system-sm-regular', + large: 'px-4 rounded-lg system-md-regular', }, }, defaultVariants: { diff --git a/web/app/components/base/tooltip/index.tsx b/web/app/components/base/tooltip/index.tsx index e9b7ab047a..e6c4de31f1 100644 --- a/web/app/components/base/tooltip/index.tsx +++ b/web/app/components/base/tooltip/index.tsx @@ -10,6 +10,7 @@ export type TooltipProps = { position?: Placement triggerMethod?: 'hover' | 'click' triggerClassName?: string + triggerTestId?: string disabled?: boolean popupContent?: React.ReactNode children?: React.ReactNode @@ -24,6 +25,7 @@ const Tooltip: FC = ({ position = 'top', triggerMethod = 'hover', triggerClassName, + triggerTestId, disabled = false, popupContent, children, @@ -91,7 +93,7 @@ const Tooltip: FC = ({ onMouseLeave={() => triggerMethod === 'hover' && handleLeave(true)} asChild={asChild} > - {children ||
} + {children ||
} = (props) => {
}> = (props) => {
}> = ({ const resetList = useCallback(() => { setSelectedSegmentIds([]) invalidSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidSegmentList]) const resetChildList = useCallback(() => { invalidChildSegmentList() - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + }, [invalidChildSegmentList]) const onClickCard = (detail: SegmentDetailModel, isEditMode = false) => { setCurrSegment({ segInfo: detail, showModal: true, isEditMode }) @@ -253,7 +251,7 @@ const Completed: FC = ({ const invalidChunkListEnabled = useInvalid(useChunkListEnabledKey) const invalidChunkListDisabled = useInvalid(useChunkListDisabledKey) - const refreshChunkListWithStatusChanged = () => { + const refreshChunkListWithStatusChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -262,7 +260,7 @@ const Completed: FC = ({ default: invalidSegmentList() } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidSegmentList]) const onChangeSwitch = useCallback(async (enable: boolean, segId?: string) => { const operationApi = enable ? enableSegment : disableSegment @@ -280,8 +278,7 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds, segments]) + }, [datasetId, documentId, selectedSegmentIds, segments, disableSegment, enableSegment, t, notify, refreshChunkListWithStatusChanged]) const { mutateAsync: deleteSegment } = useDeleteSegment() @@ -296,12 +293,11 @@ const Completed: FC = ({ notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, selectedSegmentIds]) + }, [datasetId, documentId, selectedSegmentIds, deleteSegment, resetList, t, notify]) const { mutateAsync: updateSegment } = useUpdateSegment() - const refreshChunkListDataWithDetailChanged = () => { + const refreshChunkListDataWithDetailChanged = useCallback(() => { switch (selectedStatus) { case 'all': invalidChunkListDisabled() @@ -316,7 +312,7 @@ const Completed: FC = ({ invalidChunkListEnabled() break } - } + }, [selectedStatus, invalidChunkListDisabled, invalidChunkListEnabled, invalidChunkListAll]) const handleUpdateSegment = useCallback(async ( segmentId: string, @@ -375,17 +371,18 @@ const Completed: FC = ({ eventEmitter?.emit('update-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, datasetId, documentId]) + }, [segments, datasetId, documentId, updateSegment, docForm, notify, eventEmitter, onCloseSegmentDetail, refreshChunkListDataWithDetailChanged, t]) useEffect(() => { resetList() + // eslint-disable-next-line react-hooks/exhaustive-deps }, [pathname]) useEffect(() => { if (importStatus === ProcessStatus.COMPLETED) resetList() - }, [importStatus, resetList]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [importStatus]) const onCancelBatchOperation = useCallback(() => { setSelectedSegmentIds([]) @@ -430,8 +427,7 @@ const Completed: FC = ({ const count = segmentListData?.total || 0 return `${total} ${t('datasetDocuments.segment.searchResults', { count })}` } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData?.total, mode, parentMode, searchValue, selectedStatus]) + }, [segmentListData, mode, parentMode, searchValue, selectedStatus, t]) const toggleFullScreen = useCallback(() => { setFullScreen(!fullScreen) @@ -449,8 +445,7 @@ const Completed: FC = ({ resetList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segmentListData, limit, currentPage]) + }, [segmentListData, limit, currentPage, resetList]) const { mutateAsync: deleteChildSegment } = useDeleteChildSegment() @@ -470,8 +465,7 @@ const Completed: FC = ({ }, }, ) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [datasetId, documentId, parentMode]) + }, [datasetId, documentId, parentMode, deleteChildSegment, resetList, resetChildList, t, notify]) const handleAddNewChildChunk = useCallback((parentChunkId: string) => { setShowNewChildSegmentModal(true) @@ -490,8 +484,7 @@ const Completed: FC = ({ else { resetChildList() } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [parentMode, currChunkId, segments]) + }, [parentMode, currChunkId, segments, refreshChunkListDataWithDetailChanged, resetChildList]) const viewNewlyAddedChildChunk = useCallback(() => { const totalPages = childChunkListData?.total_pages || 0 @@ -505,8 +498,7 @@ const Completed: FC = ({ resetChildList() currentPage !== totalPages && setCurrentPage(totalPages) } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [childChunkListData, limit, currentPage]) + }, [childChunkListData, limit, currentPage, resetChildList]) const onClickSlice = useCallback((detail: ChildChunkDetail) => { setCurrChildChunk({ childChunkInfo: detail, showModal: true }) @@ -560,8 +552,7 @@ const Completed: FC = ({ eventEmitter?.emit('update-child-segment-done') }, }) - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [segments, childSegments, datasetId, documentId, parentMode]) + }, [segments, datasetId, documentId, parentMode, updateChildSegment, notify, eventEmitter, onCloseChildSegmentDetail, refreshChunkListDataWithDetailChanged, resetChildList, t]) const onClearFilter = useCallback(() => { setInputValue('') @@ -570,6 +561,12 @@ const Completed: FC = ({ setCurrentPage(1) }, []) + const selectDefaultValue = useMemo(() => { + if (selectedStatus === 'all') + return 'all' + return selectedStatus ? 1 : 0 + }, [selectedStatus]) + return ( = ({ @@ -591,7 +588,7 @@ const Completed: FC = ({ = ({ const wordCountText = useMemo(() => { const total = formatNumber(word_count) return `${total} ${t('datasetDocuments.segment.characters', { count: word_count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [word_count]) + }, [word_count, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) if (loading) return diff --git a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx index cea3402499..d3575c18ed 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-detail.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-detail.tsx @@ -86,8 +86,7 @@ const SegmentDetail: FC = ({ const titleText = useMemo(() => { return isEditMode ? t('datasetDocuments.segment.editChunk') : t('datasetDocuments.segment.chunkDetail') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode]) + }, [isEditMode, t]) const isQAModel = useMemo(() => { return docForm === ChunkingMode.qa @@ -98,13 +97,11 @@ const SegmentDetail: FC = ({ const total = formatNumber(isEditMode ? contentLength : segInfo!.word_count as number) const count = isEditMode ? contentLength : segInfo!.word_count as number return `${total} ${t('datasetDocuments.segment.characters', { count })}` - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isEditMode, question.length, answer.length, segInfo?.word_count, isQAModel]) + }, [isEditMode, question.length, answer.length, isQAModel, segInfo, t]) const labelPrefix = useMemo(() => { return isParentChildMode ? t('datasetDocuments.segment.parentChunk') : t('datasetDocuments.segment.chunk') - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [isParentChildMode]) + }, [isParentChildMode, t]) return (
diff --git a/web/app/components/datasets/documents/detail/completed/segment-list.tsx b/web/app/components/datasets/documents/detail/completed/segment-list.tsx index b2351c1b97..f6076e5813 100644 --- a/web/app/components/datasets/documents/detail/completed/segment-list.tsx +++ b/web/app/components/datasets/documents/detail/completed/segment-list.tsx @@ -42,7 +42,7 @@ const SegmentList = ( embeddingAvailable, onClearFilter, }: ISegmentListProps & { - ref: React.RefObject; + ref: React.LegacyRef }, ) => { const mode = useDocumentContext(s => s.mode) diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 8ed878fe56..cb349ee01c 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -202,7 +202,7 @@ export const OperationAction: FC<{ const isListScene = scene === 'list' const onOperate = async (operationName: OperationName) => { - let opApi = deleteDocument + let opApi switch (operationName) { case 'archive': opApi = archiveDocument @@ -490,7 +490,7 @@ const DocumentList: FC = ({ const handleAction = (actionName: DocumentActionType) => { return async () => { - let opApi = deleteDocument + let opApi switch (actionName) { case DocumentActionType.archive: opApi = archiveDocument @@ -527,7 +527,7 @@ const DocumentList: FC = ({ )} diff --git a/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx b/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx index 25e19506d0..fd7bb89bd3 100644 --- a/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx +++ b/web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx @@ -40,7 +40,7 @@ const InputCombined: FC = ({ className={cn(className, 'rounded-l-md')} value={value} onChange={onChange} - size='sm' + size='regular' controlWrapClassName='overflow-hidden' controlClassName='pt-0 pb-0' readOnly={readOnly} diff --git a/web/app/components/header/account-dropdown/workplace-selector/index.tsx b/web/app/components/header/account-dropdown/workplace-selector/index.tsx index a9a886376a..da3f8bae6d 100644 --- a/web/app/components/header/account-dropdown/workplace-selector/index.tsx +++ b/web/app/components/header/account-dropdown/workplace-selector/index.tsx @@ -42,7 +42,7 @@ const WorkplaceSelector = () => { `, )}>
- {currentWorkspace?.name[0]?.toLocaleUpperCase()} + {currentWorkspace?.name[0]?.toLocaleUpperCase()}
{currentWorkspace?.name}
@@ -73,7 +73,7 @@ const WorkplaceSelector = () => { workspaces.map(workspace => (
handleSwitchWorkspace(workspace.id)}>
- {workspace?.name[0]?.toLocaleUpperCase()} + {workspace?.name[0]?.toLocaleUpperCase()}
{workspace.name}
diff --git a/web/app/components/header/account-setting/model-provider-page/declarations.ts b/web/app/components/header/account-setting/model-provider-page/declarations.ts index 39e229cd54..12dd9b3b5b 100644 --- a/web/app/components/header/account-setting/model-provider-page/declarations.ts +++ b/web/app/components/header/account-setting/model-provider-page/declarations.ts @@ -60,6 +60,7 @@ export enum ModelFeatureEnum { video = 'video', document = 'document', audio = 'audio', + StructuredOutput = 'structured-output', } export enum ModelFeatureTextEnum { diff --git a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx index 025cb87dc1..9019051989 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-icon/index.tsx @@ -23,9 +23,9 @@ const ModelIcon: FC = ({ isDeprecated = false, }) => { const language = useLanguage() - if (provider?.provider.includes('openai') && modelName?.includes('gpt-4o')) + if (provider?.provider && ['openai', 'langgenius/openai/openai'].includes(provider.provider) && modelName?.includes('gpt-4o')) return
- if (provider?.provider.includes('openai') && modelName?.startsWith('gpt-4')) + if (provider?.provider && ['openai', 'langgenius/openai/openai'].includes(provider.provider) && modelName?.startsWith('gpt-4')) return
if (provider?.icon_small) { diff --git a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx index 28001bef5e..c5af4ed8a1 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-modal/Form.tsx @@ -376,6 +376,7 @@ function Form< tooltip={tooltip?.[language] || tooltip?.en_US} value={value[variable] || []} onChange={item => handleFormChange(variable, item as any)} + supportCollapse /> {fieldMoreInfo?.(formSchema)} {validating && changeKey === variable && } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 4bb3cbf7d5..3e969d708b 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -10,6 +10,7 @@ import Slider from '@/app/components/base/slider' import Radio from '@/app/components/base/radio' import { SimpleSelect } from '@/app/components/base/select' import TagInput from '@/app/components/base/tag-input' +import { useTranslation } from 'react-i18next' export type ParameterValue = number | string | string[] | boolean | undefined @@ -27,6 +28,7 @@ const ParameterItem: FC = ({ onSwitch, isInWorkflow, }) => { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) diff --git a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx index fc29feaefc..f243d30aff 100644 --- a/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/multiple-tool-selector/index.tsx @@ -2,7 +2,6 @@ import React from 'react' import { useTranslation } from 'react-i18next' import { RiAddLine, - RiArrowDropDownLine, RiQuestionLine, } from '@remixicon/react' import ToolSelector from '@/app/components/plugins/plugin-detail-panel/tool-selector' @@ -13,6 +12,7 @@ import type { ToolValue } from '@/app/components/workflow/block-selector/types' import type { Node } from 'reactflow' import type { NodeOutPutVar } from '@/app/components/workflow/types' import cn from '@/utils/classnames' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' type Props = { disabled?: boolean @@ -98,14 +98,12 @@ const MultipleToolSelector = ({ )} {supportCollapse && ( -
- -
+ )}
{value.length > 0 && ( diff --git a/web/app/components/workflow/hooks/use-workflow-variables.ts b/web/app/components/workflow/hooks/use-workflow-variables.ts index a2863671ed..35637bc775 100644 --- a/web/app/components/workflow/hooks/use-workflow-variables.ts +++ b/web/app/components/workflow/hooks/use-workflow-variables.ts @@ -8,6 +8,8 @@ import type { ValueSelector, Var, } from '@/app/components/workflow/types' +import { useIsChatMode } from './use-workflow' +import { useStoreApi } from 'reactflow' export const useWorkflowVariables = () => { const { t } = useTranslation() @@ -75,3 +77,37 @@ export const useWorkflowVariables = () => { getCurrentVariableType, } } + +export const useWorkflowVariableType = () => { + const store = useStoreApi() + const { + getNodes, + } = store.getState() + const { getCurrentVariableType } = useWorkflowVariables() + + const isChatMode = useIsChatMode() + + const getVarType = ({ + nodeId, + valueSelector, + }: { + nodeId: string, + valueSelector: ValueSelector, + }) => { + const node = getNodes().find(n => n.id === nodeId) + const isInIteration = !!node?.data.isInIteration + const iterationNode = isInIteration ? getNodes().find(n => n.id === node.parentId) : null + const availableNodes = [node] + + const type = getCurrentVariableType({ + parentNode: iterationNode, + valueSelector, + availableNodes, + isChatMode, + isConstant: false, + }) + return type + } + + return getVarType +} diff --git a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx index be57cbca0f..d67b7af1a4 100644 --- a/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx +++ b/web/app/components/workflow/nodes/_base/components/agent-strategy.tsx @@ -133,7 +133,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => { // TODO: maybe empty, handle this onChange={onChange as any} defaultValue={defaultValue} - size='sm' + size='regular' min={def.min} max={def.max} className='w-12' diff --git a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx index 4b36125575..2390dfd74e 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/field-collapse.tsx @@ -4,10 +4,16 @@ import Collapse from '.' type FieldCollapseProps = { title: string children: ReactNode + collapsed?: boolean + onCollapse?: (collapsed: boolean) => void + operations?: ReactNode } const FieldCollapse = ({ title, children, + collapsed, + onCollapse, + operations, }: FieldCollapseProps) => { return (
@@ -15,6 +21,9 @@ const FieldCollapse = ({ trigger={
{title}
} + operations={operations} + collapsed={collapsed} + onCollapse={onCollapse} >
{children} diff --git a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx index 1f39c1c1c5..16fba88a25 100644 --- a/web/app/components/workflow/nodes/_base/components/collapse/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/collapse/index.tsx @@ -1,15 +1,18 @@ -import { useState } from 'react' -import { RiArrowDropRightLine } from '@remixicon/react' +import type { ReactNode } from 'react' +import { useMemo, useState } from 'react' +import { ArrowDownRoundFill } from '@/app/components/base/icons/src/vender/solid/general' import cn from '@/utils/classnames' export { default as FieldCollapse } from './field-collapse' type CollapseProps = { disabled?: boolean - trigger: React.JSX.Element + trigger: React.JSX.Element | ((collapseIcon: React.JSX.Element | null) => React.JSX.Element) children: React.JSX.Element collapsed?: boolean onCollapse?: (collapsed: boolean) => void + operations?: ReactNode + hideCollapseIcon?: boolean } const Collapse = ({ disabled, @@ -17,34 +20,44 @@ const Collapse = ({ children, collapsed, onCollapse, + operations, + hideCollapseIcon, }: CollapseProps) => { const [collapsedLocal, setCollapsedLocal] = useState(true) const collapsedMerged = collapsed !== undefined ? collapsed : collapsedLocal + const collapseIcon = useMemo(() => { + if (disabled) + return null + return ( + + ) + }, [collapsedMerged, disabled]) return ( <> -
{ - if (!disabled) { - setCollapsedLocal(!collapsedMerged) - onCollapse?.(!collapsedMerged) - } - }} - > -
- { - !disabled && ( - - ) - } +
+
{ + if (!disabled) { + setCollapsedLocal(!collapsedMerged) + onCollapse?.(!collapsedMerged) + } + }} + > + {typeof trigger === 'function' ? trigger(collapseIcon) : trigger} + {!hideCollapseIcon && ( +
+ {collapseIcon} +
+ )}
- {trigger} + {operations}
{ !collapsedMerged && children diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index b36abbfb00..cfcbae80f3 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -49,20 +49,23 @@ const ErrorHandle = ({ disabled={!error_strategy} collapsed={collapsed} onCollapse={setCollapsed} + hideCollapseIcon trigger={ -
-
-
- {t('workflow.nodes.common.errorHandle.title')} + collapseIcon => ( +
+
+
+ {t('workflow.nodes.common.errorHandle.title')} +
+ + {collapseIcon}
- +
- -
- } + )} > <> { diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx index 190c748831..d9516dfcf5 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx @@ -50,6 +50,7 @@ const ErrorHandleTypeSelector = ({ > { e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() setOpen(v => !v) }}> + + )} + + + +
+
+
+ +
+
+ ) +} + +export default React.memo(CodeEditor) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx new file mode 100644 index 0000000000..2685182f9f --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/error-message.tsx @@ -0,0 +1,27 @@ +import React from 'react' +import type { FC } from 'react' +import { RiErrorWarningFill } from '@remixicon/react' +import classNames from '@/utils/classnames' + +type ErrorMessageProps = { + message: string +} & React.HTMLAttributes + +const ErrorMessage: FC = ({ + message, + className, +}) => { + return ( +
+ +
+ {message} +
+
+ ) +} + +export default React.memo(ErrorMessage) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx new file mode 100644 index 0000000000..d34836d5b2 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/index.tsx @@ -0,0 +1,34 @@ +import React, { type FC } from 'react' +import Modal from '../../../../../base/modal' +import type { SchemaRoot } from '../../types' +import JsonSchemaConfig from './json-schema-config' + +type JsonSchemaConfigModalProps = { + isShow: boolean + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +const JsonSchemaConfigModal: FC = ({ + isShow, + defaultSchema, + onSave, + onClose, +}) => { + return ( + + + + ) +} + +export default JsonSchemaConfigModal diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx new file mode 100644 index 0000000000..643059adbd --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx @@ -0,0 +1,136 @@ +import React, { type FC, useCallback, useEffect, useRef, useState } from 'react' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem' +import cn from '@/utils/classnames' +import { useTranslation } from 'react-i18next' +import { RiCloseLine } from '@remixicon/react' +import Button from '@/app/components/base/button' +import { checkJsonDepth } from '../../utils' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' +import CodeEditor from './code-editor' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import { useMittContext } from './visual-editor/context' + +type JsonImporterProps = { + onSubmit: (schema: any) => void + updateBtnWidth: (width: number) => void +} + +const JsonImporter: FC = ({ + onSubmit, + updateBtnWidth, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [json, setJson] = useState('') + const [parseError, setParseError] = useState(null) + const importBtnRef = useRef(null) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + + useEffect(() => { + if (importBtnRef.current) { + const rect = importBtnRef.current.getBoundingClientRect() + updateBtnWidth(rect.width) + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleSubmit = useCallback(() => { + try { + const parsedJSON = JSON.parse(json) + if (typeof parsedJSON !== 'object' || Array.isArray(parsedJSON)) { + setParseError(new Error('Root must be an object, not an array or primitive value.')) + return + } + const maxDepth = checkJsonDepth(parsedJSON) + if (maxDepth > JSON_SCHEMA_MAX_DEPTH) { + setParseError({ + type: 'error', + message: `Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`, + }) + return + } + onSubmit(parsedJSON) + setParseError(null) + setOpen(false) + } + catch (e: any) { + if (e instanceof Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + } + }, [onSubmit, json]) + + return ( + + + + + +
+ {/* Title */} +
+
+ +
+
+ {t('workflow.nodes.llm.jsonSchema.import')} +
+
+ {/* Content */} +
+ + {parseError && } +
+ {/* Footer */} +
+ + +
+
+
+
+ ) +} + +export default JsonImporter diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx new file mode 100644 index 0000000000..d125e31dae --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-config.tsx @@ -0,0 +1,301 @@ +import React, { type FC, useCallback, useState } from 'react' +import { type SchemaRoot, Type } from '../../types' +import { RiBracesLine, RiCloseLine, RiExternalLinkLine, RiTimelineView } from '@remixicon/react' +import { SegmentedControl } from '../../../../../base/segmented-control' +import JsonSchemaGenerator from './json-schema-generator' +import Divider from '@/app/components/base/divider' +import JsonImporter from './json-importer' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import VisualEditor from './visual-editor' +import SchemaEditor from './schema-editor' +import { + checkJsonSchemaDepth, + convertBooleanToString, + getValidationErrorMessage, + jsonToSchema, + preValidateSchema, + validateSchemaAgainstDraft7, +} from '../../utils' +import { MittProvider, VisualEditorContextProvider, useMittContext } from './visual-editor/context' +import ErrorMessage from './error-message' +import { useVisualEditorStore } from './visual-editor/store' +import Toast from '@/app/components/base/toast' +import { useGetLanguage } from '@/context/i18n' +import { JSON_SCHEMA_MAX_DEPTH } from '@/config' + +type JsonSchemaConfigProps = { + defaultSchema?: SchemaRoot + onSave: (schema: SchemaRoot) => void + onClose: () => void +} + +enum SchemaView { + VisualEditor = 'visualEditor', + JsonSchema = 'jsonSchema', +} + +const VIEW_TABS = [ + { Icon: RiTimelineView, text: 'Visual Editor', value: SchemaView.VisualEditor }, + { Icon: RiBracesLine, text: 'JSON Schema', value: SchemaView.JsonSchema }, +] + +const DEFAULT_SCHEMA: SchemaRoot = { + type: Type.object, + properties: {}, + required: [], + additionalProperties: false, +} + +const HELP_DOC_URL = { + zh_Hans: 'https://docs.dify.ai/zh-hans/guides/workflow/structured-outputs', + en_US: 'https://docs.dify.ai/guides/workflow/structured-outputs', + ja_JP: 'https://docs.dify.ai/ja-jp/guides/workflow/structured-outputs', +} + +type LocaleKey = keyof typeof HELP_DOC_URL + +const JsonSchemaConfig: FC = ({ + defaultSchema, + onSave, + onClose, +}) => { + const { t } = useTranslation() + const locale = useGetLanguage() as LocaleKey + const [currentTab, setCurrentTab] = useState(SchemaView.VisualEditor) + const [jsonSchema, setJsonSchema] = useState(defaultSchema || DEFAULT_SCHEMA) + const [json, setJson] = useState(JSON.stringify(jsonSchema, null, 2)) + const [btnWidth, setBtnWidth] = useState(0) + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const setAdvancedEditing = useVisualEditorStore(state => state.setAdvancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const setIsAddingNewField = useVisualEditorStore(state => state.setIsAddingNewField) + const setHoveringProperty = useVisualEditorStore(state => state.setHoveringProperty) + const { emit } = useMittContext() + + const updateBtnWidth = useCallback((width: number) => { + setBtnWidth(width + 32) + }, []) + + const handleTabChange = useCallback((value: SchemaView) => { + if (currentTab === value) return + if (currentTab === SchemaView.JsonSchema) { + try { + const schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || isAddingNewField) + emit('quitEditing', { callback: (backup: SchemaRoot) => setJson(JSON.stringify(backup || jsonSchema, null, 2)) }) + else + setJson(JSON.stringify(jsonSchema, null, 2)) + } + + setCurrentTab(value) + }, [currentTab, jsonSchema, json, advancedEditing, isAddingNewField, emit]) + + const handleApplySchema = useCallback((schema: SchemaRoot) => { + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(schema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(schema, null, 2)) + }, [currentTab]) + + const handleSubmit = useCallback((schema: any) => { + const jsonSchema = jsonToSchema(schema) as SchemaRoot + if (currentTab === SchemaView.VisualEditor) + setJsonSchema(jsonSchema) + else if (currentTab === SchemaView.JsonSchema) + setJson(JSON.stringify(jsonSchema, null, 2)) + }, [currentTab]) + + const handleVisualEditorUpdate = useCallback((schema: SchemaRoot) => { + setJsonSchema(schema) + }, []) + + const handleSchemaEditorUpdate = useCallback((schema: string) => { + setJson(schema) + }, []) + + const handleResetDefaults = useCallback(() => { + if (currentTab === SchemaView.VisualEditor) { + setHoveringProperty(null) + advancedEditing && setAdvancedEditing(false) + isAddingNewField && setIsAddingNewField(false) + } + setJsonSchema(DEFAULT_SCHEMA) + setJson(JSON.stringify(DEFAULT_SCHEMA, null, 2)) + }, [currentTab, advancedEditing, isAddingNewField, setAdvancedEditing, setIsAddingNewField, setHoveringProperty]) + + const handleCancel = useCallback(() => { + onClose() + }, [onClose]) + + const handleSave = useCallback(() => { + let schema = jsonSchema + if (currentTab === SchemaView.JsonSchema) { + try { + schema = JSON.parse(json) + setParseError(null) + const result = preValidateSchema(schema) + if (!result.success) { + setValidationError(result.error.message) + return + } + const schemaDepth = checkJsonSchemaDepth(schema) + if (schemaDepth > JSON_SCHEMA_MAX_DEPTH) { + setValidationError(`Schema exceeds maximum depth of ${JSON_SCHEMA_MAX_DEPTH}.`) + return + } + convertBooleanToString(schema) + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + setJsonSchema(schema) + setValidationError('') + } + catch (error) { + setValidationError('') + if (error instanceof Error) + setParseError(error) + else + setParseError(new Error('Invalid JSON')) + return + } + } + else if (currentTab === SchemaView.VisualEditor) { + if (advancedEditing || isAddingNewField) { + Toast.notify({ + type: 'warning', + message: t('workflow.nodes.llm.jsonSchema.warningTips.saveSchema'), + }) + return + } + } + onSave(schema) + onClose() + }, [currentTab, jsonSchema, json, onSave, onClose, advancedEditing, isAddingNewField, t]) + + return ( +
+ {/* Header */} +
+
+ {t('workflow.nodes.llm.jsonSchema.title')} +
+
+ +
+
+ {/* Content */} +
+ {/* Tab */} + + options={VIEW_TABS} + value={currentTab} + onChange={handleTabChange} + /> +
+ {/* JSON Schema Generator */} + + + {/* JSON Schema Importer */} + +
+
+
+ {currentTab === SchemaView.VisualEditor && ( + + )} + {currentTab === SchemaView.JsonSchema && ( + + )} + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ + {t('workflow.nodes.llm.jsonSchema.doc')} + + +
+
+ + +
+
+ + +
+
+
+
+ ) +} + +const JsonSchemaConfigWrapper: FC = (props) => { + return ( + + + + + + ) +} + +export default JsonSchemaConfigWrapper diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx new file mode 100644 index 0000000000..5f1f117086 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx @@ -0,0 +1,7 @@ +import SchemaGeneratorLight from './schema-generator-light' +import SchemaGeneratorDark from './schema-generator-dark' + +export { + SchemaGeneratorLight, + SchemaGeneratorDark, +} diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx new file mode 100644 index 0000000000..ac4793b1e3 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-dark.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorDark = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorDark diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx new file mode 100644 index 0000000000..8b898bde68 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/schema-generator-light.tsx @@ -0,0 +1,15 @@ +const SchemaGeneratorLight = () => { + return ( + + + + + + + + + + ) +} + +export default SchemaGeneratorLight diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx new file mode 100644 index 0000000000..00f57237e5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx @@ -0,0 +1,121 @@ +import React, { type FC, useCallback, useMemo, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { RiArrowLeftLine, RiCloseLine, RiSparklingLine } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Button from '@/app/components/base/button' +import CodeEditor from '../code-editor' +import ErrorMessage from '../error-message' +import { getValidationErrorMessage, validateSchemaAgainstDraft7 } from '../../../utils' +import Loading from '@/app/components/base/loading' + +type GeneratedResultProps = { + schema: SchemaRoot + isGenerating: boolean + onBack: () => void + onRegenerate: () => void + onClose: () => void + onApply: () => void +} + +const GeneratedResult: FC = ({ + schema, + isGenerating, + onBack, + onRegenerate, + onClose, + onApply, +}) => { + const { t } = useTranslation() + const [parseError, setParseError] = useState(null) + const [validationError, setValidationError] = useState('') + + const formatJSON = (json: SchemaRoot) => { + try { + const schema = JSON.stringify(json, null, 2) + setParseError(null) + return schema + } + catch (e) { + if (e instanceof Error) + setParseError(e) + else + setParseError(new Error('Invalid JSON')) + return '' + } + } + + const jsonSchema = useMemo(() => formatJSON(schema), [schema]) + + const handleApply = useCallback(() => { + const validationErrors = validateSchemaAgainstDraft7(schema) + if (validationErrors.length > 0) { + setValidationError(getValidationErrorMessage(validationErrors)) + return + } + onApply() + setValidationError('') + }, [schema, onApply]) + + return ( +
+ { + isGenerating ? ( +
+ +
{t('workflow.nodes.llm.jsonSchema.generating')}
+
+ ) : ( + <> +
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generatedResult')} +
+
+ {t('workflow.nodes.llm.jsonSchema.resultTip')} +
+
+ {/* Content */} +
+ + {parseError && } + {validationError && } +
+ {/* Footer */} +
+ +
+ + +
+
+ + + ) + } +
+ ) +} + +export default React.memo(GeneratedResult) diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx new file mode 100644 index 0000000000..4732499f3a --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/index.tsx @@ -0,0 +1,183 @@ +import React, { type FC, useCallback, useEffect, useState } from 'react' +import type { SchemaRoot } from '../../../types' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import useTheme from '@/hooks/use-theme' +import type { CompletionParams, Model } from '@/types/app' +import { ModelModeType } from '@/types/app' +import { Theme } from '@/types/app' +import { SchemaGeneratorDark, SchemaGeneratorLight } from './assets' +import cn from '@/utils/classnames' +import type { ModelInfo } from './prompt-editor' +import PromptEditor from './prompt-editor' +import GeneratedResult from './generated-result' +import { useGenerateStructuredOutputRules } from '@/service/use-common' +import Toast from '@/app/components/base/toast' +import { type FormValue, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useVisualEditorStore } from '../visual-editor/store' +import { useTranslation } from 'react-i18next' +import { useMittContext } from '../visual-editor/context' + +type JsonSchemaGeneratorProps = { + onApply: (schema: SchemaRoot) => void + crossAxisOffset?: number +} + +enum GeneratorView { + promptEditor = 'promptEditor', + result = 'result', +} + +export const JsonSchemaGenerator: FC = ({ + onApply, + crossAxisOffset, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + const [view, setView] = useState(GeneratorView.promptEditor) + const [model, setModel] = useState({ + name: '', + provider: '', + mode: ModelModeType.completion, + completion_params: {} as CompletionParams, + }) + const [instruction, setInstruction] = useState('') + const [schema, setSchema] = useState(null) + const { theme } = useTheme() + const { + defaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) + const advancedEditing = useVisualEditorStore(state => state.advancedEditing) + const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) + const { emit } = useMittContext() + const SchemaGenerator = theme === Theme.light ? SchemaGeneratorLight : SchemaGeneratorDark + + useEffect(() => { + if (defaultModel) { + setModel(prev => ({ + ...prev, + name: defaultModel.model, + provider: defaultModel.provider.provider, + })) + } + }, [defaultModel]) + + const handleTrigger = useCallback((e: React.MouseEvent) => { + e.stopPropagation() + if (advancedEditing || isAddingNewField) + emit('quitEditing', {}) + setOpen(!open) + }, [open, advancedEditing, isAddingNewField, emit]) + + const onClose = useCallback(() => { + setOpen(false) + }, []) + + const handleModelChange = useCallback((model: ModelInfo) => { + setModel(prev => ({ + ...prev, + provider: model.provider, + name: model.modelId, + mode: model.mode as ModelModeType, + })) + }, []) + + const handleCompletionParamsChange = useCallback((newParams: FormValue) => { + setModel(prev => ({ + ...prev, + completion_params: newParams as CompletionParams, + }), + ) + }, []) + + const { mutateAsync: generateStructuredOutputRules, isPending: isGenerating } = useGenerateStructuredOutputRules() + + const generateSchema = useCallback(async () => { + const { output, error } = await generateStructuredOutputRules({ instruction, model_config: model! }) + if (error) { + Toast.notify({ + type: 'error', + message: error, + }) + setSchema(null) + setView(GeneratorView.promptEditor) + return + } + return output + }, [instruction, model, generateStructuredOutputRules]) + + const handleGenerate = useCallback(async () => { + setView(GeneratorView.result) + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const goBackToPromptEditor = () => { + setView(GeneratorView.promptEditor) + } + + const handleRegenerate = useCallback(async () => { + const output = await generateSchema() + if (output === undefined) return + setSchema(JSON.parse(output)) + }, [generateSchema]) + + const handleApply = () => { + onApply(schema!) + setOpen(false) + } + + return ( + + + + + + {view === GeneratorView.promptEditor && ( + + )} + {view === GeneratorView.result && ( + + )} + + + ) +} + +export default JsonSchemaGenerator diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx new file mode 100644 index 0000000000..9387813ee5 --- /dev/null +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/prompt-editor.tsx @@ -0,0 +1,108 @@ +import React, { useCallback } from 'react' +import type { FC } from 'react' +import { RiCloseLine, RiSparklingFill } from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import Textarea from '@/app/components/base/textarea' +import Tooltip from '@/app/components/base/tooltip' +import Button from '@/app/components/base/button' +import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' +import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' +import type { Model } from '@/types/app' + +export type ModelInfo = { + modelId: string + provider: string + mode?: string + features?: string[] +} + +type PromptEditorProps = { + instruction: string + model: Model + onInstructionChange: (instruction: string) => void + onCompletionParamsChange: (newParams: FormValue) => void + onModelChange: (model: ModelInfo) => void + onClose: () => void + onGenerate: () => void +} + +const PromptEditor: FC = ({ + instruction, + model, + onInstructionChange, + onCompletionParamsChange, + onClose, + onGenerate, + onModelChange, +}) => { + const { t } = useTranslation() + + const handleInstructionChange = useCallback((e: React.ChangeEvent) => { + onInstructionChange(e.target.value) + }, [onInstructionChange]) + + return ( +
+
+ +
+ {/* Title */} +
+
+ {t('workflow.nodes.llm.jsonSchema.generateJsonSchema')} +
+
+ {t('workflow.nodes.llm.jsonSchema.generationTip')} +
+
+ {/* Content */} +
+
+ {t('common.modelProvider.model')} +
+ +
+
+
+ {t('workflow.nodes.llm.jsonSchema.instruction')} + +
+
+