Merge branch 'main' into feat/dependency-relations

# Conflicts:
#	web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsx
pull/21998/head
Minamiyama 9 months ago
commit 02de417615

@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "build/**"
tags:
- "*"

@ -5,17 +5,17 @@
SECRET_KEY=
# Console API base URL
CONSOLE_API_URL=http://127.0.0.1:5001
CONSOLE_WEB_URL=http://127.0.0.1:3000
CONSOLE_API_URL=http://localhost:5001
CONSOLE_WEB_URL=http://localhost:3000
# Service API base URL
SERVICE_API_URL=http://127.0.0.1:5001
SERVICE_API_URL=http://localhost:5001
# Web APP base URL
APP_WEB_URL=http://127.0.0.1:3000
APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
@ -138,8 +138,8 @@ SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone

@ -2,19 +2,22 @@ import base64
import json
import logging
import secrets
from typing import Optional
from typing import Any, Optional
import click
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
def setup_system_tool_oauth_client(provider, client_params):
"""
Setup system tool oauth client
"""
provider_id = ToolProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
try:
# json validate
click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
click.echo(click.style("Client params validated successfully.", fg="green"))
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
return
deleted_count = (
db.session.query(ToolOAuthSystemClient)
.filter_by(
provider=provider_name,
plugin_id=plugin_id,
)
.delete()
)
if deleted_count > 0:
click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
oauth_client = ToolOAuthSystemClient(
provider=provider_name,
plugin_id=plugin_id,
encrypted_oauth_params=oauth_client_params,
)
db.session.add(oauth_client)
db.session.commit()
click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))

@ -1,6 +1,7 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
UNKNOWN_VALUE = "[__UNKNOWN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3

@ -27,19 +27,19 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
description = "Too many password reset emails have been sent. Please try again in 1 minutes."
description = "Too many password reset emails have been sent. Please try again in 1 minute."
code = 429
class EmailChangeRateLimitExceededError(BaseHTTPException):
error_code = "email_change_rate_limit_exceeded"
description = "Too many email change emails have been sent. Please try again in 1 minutes."
description = "Too many email change emails have been sent. Please try again in 1 minute."
code = 429
class OwnerTransferRateLimitExceededError(BaseHTTPException):
error_code = "owner_transfer_rate_limit_exceeded"
description = "Too many owner tansfer emails have been sent. Please try again in 1 minutes."
description = "Too many owner transfer emails have been sent. Please try again in 1 minute."
code = 429

@ -211,10 +211,6 @@ class DatasetApi(Resource):
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
@setup_required

@ -264,11 +264,9 @@ class OwnerTransfer(Resource):
transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
if not transfer_token_data:
print(transfer_token_data, "transfer_token_data")
raise InvalidTokenError()
if transfer_token_data.get("email") != current_user.email:
print(transfer_token_data.get("email"), current_user.email)
raise InvalidEmailError()
AccountService.revoke_owner_transfer_token(args["token"])

@ -1,23 +1,32 @@
import io
from urllib.parse import urlparse
from flask import redirect, send_file
from flask import make_response, redirect, request, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from flask_restful import (
Resource,
reqparse,
)
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.tools.entities.tool_entities import CredentialType
from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required
def post(self, provider):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
user_id = user.id
tenant_id = user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
user_id,
tenant_id,
provider,
args["credential_id"],
)
class ToolBuiltinProviderAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if args["type"] not in CredentialType.values():
raise ValueError(f"Invalid credential type: {args['type']}")
return BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=args["credentials"],
name=args["name"],
api_type=CredentialType.of(args["type"]),
)
@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
result = BuiltinToolManageService.update_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credential_id=args["credential_id"],
credentials=args.get("credentials", None),
name=args.get("name", ""),
)
return result
@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider):
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credentials(
tenant_id=tenant_id,
provider_name=provider,
)
)
@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
def get(self, provider, credential_type):
user = current_user
tenant_id = user.current_tenant_id
return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
return jsonable_encoder(
BuiltinToolManageService.list_builtin_provider_credentials_schema(
provider, CredentialType.of(credential_type), tenant_id
)
)
class ToolApiProviderSchemaApi(Resource):
@ -586,15 +631,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
user_id,
tenant_id,
)
]
@ -631,6 +673,179 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolPluginOAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
# todo check permission
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
tenant_id = user.current_tenant_id
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
oauth_handler = OAuthHandler()
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response
class ToolOAuthCallback(Resource):
@setup_required
def get(self, provider):
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
tool_provider = ToolProviderID(provider)
plugin_id = tool_provider.plugin_id
provider_name = tool_provider.provider_name
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
oauth_handler = OAuthHandler()
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
if oauth_client_params is None:
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
if not credentials:
raise Exception("the plugin credentials failed")
# add credentials to database
BuiltinToolManageService.add_builtin_tool_provider(
user_id=user_id,
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
class ToolBuiltinProviderSetDefaultApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
return BuiltinToolManageService.set_default_provider(
tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
)
class ToolOAuthCustomClient(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider):
parser = reqparse.RequestParser()
parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
args = parser.parse_args()
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
return BuiltinToolManageService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider=provider,
client_params=args.get("client_params", {}),
enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
)
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
@setup_required
@login_required
@account_initialization_required
def delete(self, provider):
return jsonable_encoder(
BuiltinToolManageService.delete_custom_oauth_client_params(
tenant_id=current_user.current_tenant_id, provider=provider
)
)
class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
tenant_id=current_user.current_tenant_id, provider_name=provider
)
)
class ToolBuiltinProviderGetCredentialInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
tenant_id = current_user.current_tenant_id
return jsonable_encoder(
BuiltinToolManageService.get_builtin_tool_provider_credential_info(
tenant_id=tenant_id,
provider=provider,
)
)
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@ -794,17 +1009,33 @@ class ToolMCPCallbackApi(Resource):
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
# tool oauth
api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
api.add_resource(
ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
)
api.add_resource(
ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
"/workspaces/current/tool-provider/builtin/<path:provider>/credential/schema/<path:credential_type>",
)
api.add_resource(
ToolBuiltinProviderGetOauthClientSchemaApi,
"/workspaces/current/tool-provider/builtin/<path:provider>/oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")

@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id,
),
)

@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
credential_id: str | None = None
class AgentPromptEntity(BaseModel):

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC):
@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass

@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
)

@ -39,6 +39,7 @@ class AgentConfigManager:
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))

@ -0,0 +1,84 @@
import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Optional
from extensions.ext_redis import redis_client
class ProviderCredentialsCache(ABC):
"""Base class for provider credentials cache"""
def __init__(self, **kwargs):
self.cache_key = self._generate_cache_key(**kwargs)
@abstractmethod
def _generate_cache_key(self, **kwargs) -> str:
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
try:
cached_credentials = cached_credentials.decode("utf-8")
return dict(json.loads(cached_credentials))
except JSONDecodeError:
return None
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
redis_client.setex(self.cache_key, 86400, json.dumps(config))
def delete(self) -> None:
"""Delete cached provider credentials"""
redis_client.delete(self.cache_key)
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool single provider credentials"""
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
super().__init__(
tenant_id=tenant_id,
provider_type=provider_type,
provider_identity=provider_identity,
)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider_type = kwargs["provider_type"]
identity_name = kwargs["provider_identity"]
identity_id = f"{provider_type}.{identity_name}"
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
class ToolProviderCredentialsCache(ProviderCredentialsCache):
"""Cache for tool provider credentials"""
def __init__(self, tenant_id: str, provider: str, credential_id: str):
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
def _generate_cache_key(self, **kwargs) -> str:
tenant_id = kwargs["tenant_id"]
provider = kwargs["provider"]
credential_id = kwargs["credential_id"]
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> Optional[dict]:
"""Get cached provider credentials"""
return None
def set(self, config: dict[str, Any]) -> None:
"""Cache provider credentials"""
pass
def delete(self) -> None:
"""Delete cached provider credentials"""
pass

@ -1,51 +0,0 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
ENDPOINT = "endpoint"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return dict(cached_provider_credentials)
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

@ -148,9 +148,11 @@ class LLMGenerator:
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:

@ -1,16 +1,20 @@
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
encrypter = ProviderConfigEncrypter(
encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
provider_type=payload.namespace,
provider_identity=payload.identity,
cache=SingletonProviderCredentialsCache(
tenant_id=tenant.id,
provider_type=payload.namespace,
provider_identity=payload.identity,
),
)
if payload.opt == "encrypt":
@ -22,7 +26,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data),
}
elif payload.opt == "clear":
encrypter.delete_tool_credentials_cache()
cache.delete()
return {
"data": {},
}

@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any
from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
tool_type, tenant_id, provider, tool_name, tool_parameters
tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1

@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
)
class InvokeCredentials(BaseModel):
tool_credentials: dict[str, str] = Field(
default_factory=dict,
description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
)
class PluginInvokeContext(BaseModel):
credentials: Optional[InvokeCredentials] = Field(
default_factory=InvokeCredentials,
description="Credentials context for the plugin invocation or backward invocation.",
)
class RequestInvokeTool(BaseModel):
"""
Request to invoke a tool
@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
provider: str
tool: str
tool_parameters: dict
credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel):

@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
"context": context.model_dump() if context else {},
"data": {
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,

@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
},
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting authorization URL: {e}")
def get_credentials(
self,
@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
@ -50,30 +56,33 @@ class OAuthHandler(BasePluginClient):
Get credentials from the given request.
"""
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
try:
# encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request)
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
# for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""

@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):
@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
credential_type: CredentialType,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name,
"tool": tool_name,
"credentials": credentials,
"credential_type": credential_type,
"tool_parameters": tool_parameters,
},
},

@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
# Process drawing type images
drawing_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)
has_drawing = False
for drawing in drawing_elements:
blip_elements = drawing.findall(
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor):
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
has_drawing = True
paragraph_content.append(image_map[image_part])
# Process pict type images
shape_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
)
for shape in shape_elements:
# Find image data in VML
shape_image = shape.find(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData"
)
if shape_image is not None and shape_image.text:
image_id = shape_image.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
# Find imagedata element in VML
image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
if image_data is not None:
image_id = image_data.get("id") or image_data.get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
)
if image_id and image_id in doc.part.rels:
image_part = doc.part.rels[image_id].target_part
if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())

@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolInvokeFrom
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
CredentialType,
OAuthSchema,
ToolEntity,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
return self.entity.credentials_schema.copy()
:param credential_type: the type of the credential
:return: the credentials schema of the provider
"""
if credential_type == CredentialType.OAUTH2.value:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
if credential_type == CredentialType.API_KEY.value:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
raise ValueError(f"Invalid credential type: {credential_type}")
def get_oauth_client_schema(self) -> list[ProviderConfig]:
"""
returns the oauth client schema of the provider
:return: the oauth client schema
"""
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
def get_supported_credential_types(self) -> list[str]:
"""
returns the credential support type of the provider
"""
types = []
if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
types.append(CredentialType.API_KEY.value)
if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
types.append(CredentialType.OAUTH2.value)
return types
def get_tools(self) -> list[BuiltinTool]:
"""
@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
return (
self.entity.credentials_schema is not None
and len(self.entity.credentials_schema) != 0
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
)
@property
def provider_type(self) -> ToolProviderType:

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
class ToolProviderCredentialApiEntity(BaseModel):
id: str = Field(description="The unique id of the credential")
name: str = Field(description="The name of the credential")
provider: str = Field(description="The provider of the credential")
credential_type: CredentialType = Field(description="The type of the credential")
is_default: bool = Field(
default=False, description="Whether the credential is the default credential for the provider in the workspace"
)
credentials: dict = Field(description="The credentials of the provider")
class ToolProviderCredentialInfoApiEntity(BaseModel):
supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
is_oauth_custom_client_enabled: bool = Field(
default=False, description="Whether the OAuth custom client is enabled for the provider"
)
credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")

@ -370,10 +370,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@ -453,6 +461,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@ -460,3 +469,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
class CredentialType(enum.StrEnum):
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
if self == CredentialType.API_KEY:
return "API KEY"
elif self == CredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == CredentialType.API_KEY
def is_validate_allowed(self):
return self == CredentialType.API_KEY
@classmethod
def values(cls):
return [item.value for item in cls]
@classmethod
def of(cls, credential_type: str) -> "CredentialType":
type_name = credential_type.lower()
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2
else:
raise ValueError(f"Invalid credential type: {credential_type}")

@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
@ -17,6 +18,7 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
from services.tools.mcp_tools_mange_service import MCPToolManageService
@ -24,7 +26,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -68,8 +70,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
get the hardcoded provider
"""
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@ -113,7 +118,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
with contexts.plugin_tool_providers_lock.get():
# double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@ -131,25 +141,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
if tool is None:
raise ToolNotFoundError(f"tool {tool_name} not found")
return tool
return controller
@classmethod
def get_tool_runtime(
@ -160,6 +152,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@ -170,6 +163,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
:param credential_id: the credential id
:return: the tool
"""
@ -193,49 +187,70 @@ class ToolManager:
)
),
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
# get credentials
builtin_provider: BuiltinToolProvider | None = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.first()
)
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
except Exception as e:
builtin_provider = None
logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
# if the provider has been deleted, raise an error
if builtin_provider is None:
raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
# fallback to the default provider
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# use the default provider
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
# decrypt the credentials
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
],
cache=ToolProviderCredentialsCache(
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(builtin_provider.credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@ -245,22 +260,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
# decrypt the credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
provider_type=api_provider.provider_type.value,
provider_identity=api_provider.entity.identity.name,
controller=api_provider,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=decrypted_credentials,
credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@ -320,6 +329,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@ -362,6 +372,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
credential_id=workflow_tool.credential_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()
@ -391,6 +402,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@ -402,6 +414,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@ -551,6 +564,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
@classmethod
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
"""
list all the builtin providers
"""
# according to multi credentials, select the one with is_default=True first, then created_at oldest
# for compatibility with old version
sql = """
SELECT DISTINCT ON (tenant_id, provider) id
FROM tool_builtin_providers
WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
"""
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@ -565,21 +594,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
)
# rewrite db_builtin_providers
for db_provider in db_builtin_providers:
tool_provider_id = str(ToolProviderID(db_provider.provider))
db_provider.provider = tool_provider_id
def find_db_builtin_provider(provider):
return next((x for x in db_builtin_providers if x.provider == provider), None)
# key: provider name, value: provider
db_builtin_providers = {
str(ToolProviderID(provider.provider)): provider
for provider in cls.list_default_builtin_providers(tenant_id)
}
# append builtin providers
for provider in builtin_providers:
@ -591,10 +612,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
db_provider=find_db_builtin_provider(provider.entity.identity.name),
db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@ -604,7 +624,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@ -764,15 +783,12 @@ class ToolManager:
auth_type,
)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
controller=controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)

@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
if use_cache:
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
if use_cache:
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager

@ -0,0 +1,142 @@
from copy import deepcopy
from typing import Any, Optional, Protocol
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.tools.__base.tool_provider import ToolProviderController
class ProviderConfigCache(Protocol):
"""
Interface for provider configuration cache operations
"""
def get(self) -> Optional[dict]:
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
"""Cache provider configuration"""
...
def delete(self) -> None:
"""Delete cached provider configuration"""
...
class ProviderConfigEncrypter:
tenant_id: str
config: list[BasicProviderConfig]
provider_config_cache: ProviderConfigCache
def __init__(
self,
tenant_id: str,
config: list[BasicProviderConfig],
provider_config_cache: ProviderConfigCache,
):
self.tenant_id = tenant_id
self.config = config
self.provider_config_cache = provider_config_cache
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cached_credentials = self.provider_config_cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
self.provider_config_cache.set(data)
return data
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
cache = SingletonProviderCredentialsCache(
tenant_id=tenant_id,
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
)
encrypt = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
provider_config_cache=cache,
)
return encrypt, cache

@ -0,0 +1,187 @@
import base64
import hashlib
import logging
from collections.abc import Mapping
from typing import Any, Optional
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad, unpad
from pydantic import TypeAdapter
from configs import dify_config
logger = logging.getLogger(__name__)
class OAuthEncryptionError(Exception):
"""OAuth encryption/decryption specific error"""
pass
class SystemOAuthEncrypter:
"""
A simple OAuth parameters encrypter using AES-CBC encryption.
This class provides methods to encrypt and decrypt OAuth parameters
using AES-CBC mode with a key derived from the application's SECRET_KEY.
"""
def __init__(self, secret_key: Optional[str] = None):
"""
Initialize the OAuth encrypter.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Raises:
ValueError: If SECRET_KEY is not configured or empty
"""
secret_key = secret_key or dify_config.SECRET_KEY or ""
# Generate a fixed 256-bit key using SHA-256
self.key = hashlib.sha256(secret_key.encode()).digest()
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters.
Args:
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
Returns:
Base64-encoded encrypted string
Raises:
OAuthEncryptionError: If encryption fails
ValueError: If oauth_params is invalid
"""
try:
# Generate random IV (16 bytes)
iv = get_random_bytes(16)
# Create AES cipher (CBC mode)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Encrypt data
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
# Combine IV and encrypted data
combined = iv + encrypted_data
# Return base64 encoded string
return base64.b64encode(combined).decode()
except Exception as e:
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
Raises:
OAuthEncryptionError: If decryption fails
ValueError: If encrypted_data is invalid
"""
if not isinstance(encrypted_data, str):
raise ValueError("encrypted_data must be a string")
if not encrypted_data:
raise ValueError("encrypted_data cannot be empty")
try:
# Base64 decode
combined = base64.b64decode(encrypted_data)
# Check minimum length (IV + at least one AES block)
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
raise ValueError("Invalid encrypted data format")
# Separate IV and encrypted data
iv = combined[:16]
encrypted_data_bytes = combined[16:]
# Create AES cipher
cipher = AES.new(self.key, AES.MODE_CBC, iv)
# Decrypt data
decrypted_data = cipher.decrypt(encrypted_data_bytes)
unpadded_data = unpad(decrypted_data, AES.block_size)
# Parse JSON
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
if not isinstance(oauth_params, dict):
raise ValueError("Decrypted data is not a valid dictionary")
return oauth_params
except Exception as e:
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
# Factory function for creating encrypter instances
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
"""
Create an OAuth encrypter instance.
Args:
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
Returns:
SystemOAuthEncrypter instance
"""
return SystemOAuthEncrypter(secret_key=secret_key)
# Global encrypter instance (for backward compatibility)
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
"""
Get the global OAuth encrypter instance.
Returns:
SystemOAuthEncrypter instance
"""
global _oauth_encrypter
if _oauth_encrypter is None:
_oauth_encrypter = SystemOAuthEncrypter()
return _oauth_encrypter
# Convenience functions for backward compatibility
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
"""
Encrypt OAuth parameters using the global encrypter.
Args:
oauth_params: OAuth parameters dictionary
Returns:
Base64-encoded encrypted string
"""
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
"""
Decrypt OAuth parameters using the global encrypter.
Args:
encrypted_data: Base64-encoded encrypted string
Returns:
Decrypted OAuth parameters dictionary
"""
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)

@ -1,7 +1,9 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
def is_valid_uuid(uuid_str: str | None) -> bool:
if uuid_str is None or len(uuid_str) == 0:
return False
try:
uuid.UUID(uuid_str)
return True

@ -91,8 +91,6 @@ class SegmentType(StrEnum):
return SegmentType.OBJECT
elif isinstance(value, File):
return SegmentType.FILE
elif isinstance(value, str):
return SegmentType.STRING
else:
return None

@ -152,7 +152,6 @@ class VariablePool(BaseModel):
self.variable_dictionary[selector[0]] = {}
return
key, hash_key = self._selector_to_keys(selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):

@ -4,6 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -13,10 +14,16 @@ from core.agent.strategy.plugin import PluginAgentStrategy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.variables.segments import StringSegment
from core.workflow.entities.node_entities import NodeRunResult
@ -84,6 +91,7 @@ class AgentNode(ToolNode):
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@ -94,6 +102,7 @@ class AgentNode(ToolNode):
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
yield RunCompletedEvent(
@ -246,6 +255,7 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
@ -276,6 +286,7 @@ class AgentNode(ToolNode):
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
@ -305,6 +316,27 @@ class AgentNode(ToolNode):
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, Optional
from configs import dify_config
@ -114,8 +115,10 @@ class CodeNode(BaseNode[CodeNodeData]):
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
if precision > dify_config.CODE_MAX_PRECISION:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."

@ -849,7 +849,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
@ -1021,20 +1021,20 @@ def _combine_message_content_with_role(
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
@ -1130,7 +1130,7 @@ def _handle_completion_template(
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:

@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")

@ -20,6 +20,7 @@ def handle(sender, **kwargs):
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,

@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
)
@ -40,6 +41,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
setup_system_tool_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

@ -205,7 +205,7 @@ def init_app(app: DifyApp):
metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT
if not metric_endpoint:
metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces"
metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics"
metric_exporter = HTTPMetricExporter(
endpoint=metric_endpoint,
headers=headers,

@ -0,0 +1,41 @@
"""empty message
Revision ID: 16081485540c
Revises: d28f2004b072
Create Date: 2025-05-15 16:35:39.113777
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '16081485540c'
down_revision = '2adcbe1f5dfb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_plugin_auto_upgrade_strategies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tenant_plugin_auto_upgrade_strategies')
# ### end Alembic commands ###

@ -12,7 +12,7 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4474872b0ee6'
down_revision = '2adcbe1f5dfb'
down_revision = '16081485540c'
branch_labels = None
depends_on = None

@ -0,0 +1,62 @@
"""tool oauth
Revision ID: 71f5020c6470
Revises: 4474872b0ee6
Create Date: 2025-06-24 17:05:43.118647
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '71f5020c6470'
down_revision = '1c9ba48be8e4'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_oauth_system_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
)
op.create_table('tool_oauth_tenant_clients',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('plugin_id', sa.String(length=512), nullable=False),
sa.Column('provider', sa.String(length=255), nullable=False),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
)
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
batch_op.drop_column('credential_type')
batch_op.drop_column('is_default')
batch_op.drop_column('name')
op.drop_table('tool_oauth_tenant_clients')
op.drop_table('tool_oauth_system_clients')
# ### end Alembic commands ###

@ -21,6 +21,43 @@ from .model import Account, App, Tenant
from .types import StringUUID
# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(Base):
__tablename__ = "tool_oauth_system_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
# tenant level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthTenantClient(Base):
__tablename__ = "tool_oauth_tenant_clients"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# tenant id
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
# oauth params of the tool provider
encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
@property
def oauth_params(self) -> dict:
return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@ -29,12 +66,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
# one tenant can only have one tool provider with the same name
db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name: Mapped[str] = mapped_column(
db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
)
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@ -49,6 +88,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
# credential type, e.g., "api-key", "oauth2"
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
@property
def credentials(self) -> dict:
@ -68,7 +112,7 @@ class ApiToolProvider(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
name = db.Column(db.String(255), nullable=False)
name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon
icon = db.Column(db.String(255), nullable=False)
# original schema
@ -281,18 +325,19 @@ class MCPToolProvider(Base):
@property
def decrypted_credentials(self) -> dict:
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
cache=NoOpProviderCredentialCache(),
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base):

@ -575,13 +575,26 @@ class AppDslService:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
# TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
node_data = node.get("data", {})
if not node_data:
continue
data_type = node_data.get("type", "")
if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
dataset_ids = node_data.get("dataset_ids", [])
node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
# filter credential id from tool node
if not include_secret and data_type == NodeType.TOOL.value:
node_data.pop("credential_id", None)
# filter credential id from agent node
if not include_secret and data_type == NodeType.AGENT.value:
for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
tool.pop("credential_id", None)
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
@ -602,7 +615,15 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
export_data["model_config"] = app_model_config.to_dict()
model_config = app_model_config.to_dict()
# TODO: refactor: we need a better way to filter workspace related data from model config
# filter credential id from model config
for tool in model_config.get("agent_mode", {}).get("tools", []):
tool.pop("credential_id", None)
export_data["model_config"] = model_config
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
@ -38,11 +38,9 @@ class PluginParameterService:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# check if credentials are required
@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
credentials = tool_configuration.decrypt(db_record.credentials)
credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")

@ -196,6 +196,17 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
@staticmethod
def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
"""
Check if the plugin is verified
"""
manager = PluginInstaller()
try:
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
except Exception:
return False
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
encrypted_credentials = tool_configuration.encrypt(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@ -297,28 +293,26 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ProviderConfigEncrypter(
encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt(credentials)
credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@ -416,15 +410,13 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
@ -446,7 +438,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@ -474,7 +466,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
tenant_id=tenant_id, tool=tool, labels=labels
)
)

@ -1,28 +1,84 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.entities.api_entities import (
ToolApiEntity,
ToolProviderApiEntity,
ToolProviderCredentialApiEntity,
ToolProviderCredentialInfoApiEntity,
)
from core.tools.entities.tool_entities import CredentialType
from core.tools.errors import ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
"""
delete custom oauth client params
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine) as session:
session.query(ToolOAuthTenantClient).filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
).delete()
session.commit()
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
"""
get builtin tool provider oauth client schema
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
tenant_id, provider_name
)
is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
provider_name
)
result = {
"schema": provider.get_oauth_client_schema(),
"is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
"is_system_oauth_params_exists": is_system_oauth_params_exists,
"client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
"redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
}
return result
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@ -36,27 +92,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@ -65,25 +105,15 @@ class BuiltinToolManageService:
return result
@staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
if builtin_provider is None:
raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@ -92,127 +122,406 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema())
return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
user_id: str,
tenant_id: str,
provider: str,
credential_id: str,
credentials: dict | None = None,
name: str | None = None,
):
"""
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with Session(db.engine) as session:
# get if the provider exists
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try:
if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
original_credentials = encrypter.decrypt(db_provider.credentials)
new_credentials: dict = {
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
for key, value in credentials.items()
}
if CredentialType.of(db_provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, new_credentials)
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
cache.delete()
# update name if provided
if name and name != db_provider.name:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
db_provider.name = name
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
def add_builtin_tool_provider(
user_id: str,
api_type: CredentialType,
tenant_id: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
with Session(db.engine) as session:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
provider_count = (
session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
# check if the provider count is reached the limit
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
# validate credentials if allowed
if CredentialType.of(api_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# generate name if not provided
if name is None or name == "":
name = BuiltinToolManageService.generate_builtin_tool_provider_name(
session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
)
else:
# check if the name is already used
if (
session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider, name=name)
.count()
> 0
):
raise ValueError(f"the credential name '{name}' is already used")
# create encrypter
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(api_type)
],
cache=NoOpProviderCredentialCache(),
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials)
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
@staticmethod
def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_provider_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
db.session.add(provider)
else:
provider.encrypted_credentials = json.dumps(credentials)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# delete cache
tool_configuration.delete_tool_credentials_cache()
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
db.session.commit()
return {"result": "success"}
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, provider=provider_name)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.all()
)
if provider_obj is None:
return {}
if len(providers) == 0:
return []
default_provider = providers[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
credentials: list[ToolProviderCredentialApiEntity] = []
encrypters = {}
for provider in providers:
credential_type = provider.credential_type
if credential_type not in encrypters:
encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
tenant_id, provider, provider.provider, provider_controller
)[0]
encrypter = encrypters[credential_type]
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
@staticmethod
def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
"""
get builtin tool provider credential info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
supported_credential_types = provider_controller.get_supported_credential_types()
credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
credential_info = ToolProviderCredentialInfoApiEntity(
supported_credential_types=supported_credential_types,
is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
credentials=credentials,
)
credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
return credential_info
@staticmethod
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
with Session(db.engine) as session:
db_provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
db.session.delete(provider_obj)
db.session.commit()
session.delete(db_provider)
session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def is_oauth_system_client_exists(provider_name: str) -> bool:
"""
check if oauth system client exists
"""
tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session:
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
return system_client is not None
@staticmethod
def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
"""
check if oauth custom client is enabled
"""
tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
return user_client is not None and user_client.enabled
@staticmethod
def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
"""
get builtin tool provider
"""
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
tool_configuration.delete_tool_credentials_cache()
with Session(db.engine).no_autoflush as session:
user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=tool_provider.provider_name,
plugin_id=tool_provider.plugin_id,
enabled=True,
)
.first()
)
oauth_params: Mapping[str, Any] | None = None
if user_client:
oauth_params = encrypter.decrypt(user_client.oauth_params)
return oauth_params
# only verified provider can use custom oauth client
is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
tenant_id, provider.plugin_unique_identifier
)
if not is_verified:
return oauth_params
return {"result": "success"}
system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
.first()
)
if system_client:
try:
oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
except Exception as e:
raise ValueError(f"Error decrypting system oauth params: {e}")
return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@ -234,9 +543,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@ -275,7 +582,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@ -287,43 +593,153 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
.first()
)
else:
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider_obj is None:
return None
@staticmethod
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
):
"""
setup oauth custom client
"""
if client_params is None and enable_oauth_custom_client is None:
return {"result": "success"}
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
return provider_obj
except Exception:
# it's an old provider without organization
return (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name),
if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
with Session(db.engine) as session:
custom_client_params = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
# if the record does not exist, create a basic record
if custom_client_params is None:
custom_client_params = ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
session.add(custom_client_params)
if client_params is not None:
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
original_params = encrypter.decrypt(custom_client_params.oauth_params)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
if enable_oauth_custom_client is not None:
custom_client_params.enabled = enable_oauth_custom_client
session.commit()
return {"result": "success"}
@staticmethod
def get_custom_oauth_client_params(tenant_id: str, provider: str):
"""
get custom oauth client params
"""
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
custom_oauth_client_params: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
)
.first()
)
if custom_oauth_client_params is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))

@ -7,13 +7,14 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@ -69,6 +70,7 @@ class MCPToolManageService:
MCPToolProvider.server_url_hash == server_url_hash,
MCPToolProvider.server_identifier == server_identifier,
),
MCPToolProvider.tenant_id == tenant_id,
)
.first()
)
@ -197,8 +199,7 @@ class MCPToolManageService:
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()

@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
CredentialType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@ -119,7 +121,12 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
)
}
for name, value in schema.items():
if result.masked_credentials:
@ -136,15 +143,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(
CredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@ -287,16 +302,14 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
controller=provider_controller,
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@ -306,7 +319,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@ -316,7 +328,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials or {},
credentials={},
tenant_id=tenant_id,
)
)
@ -357,6 +369,19 @@ class ToolTransformService:
labels=labels or [],
)
@staticmethod
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=CredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
"""

@ -111,7 +111,6 @@
<p class="title">验证您的邮箱变更请求</p>
<div class="description">
<p class="content1">我们收到了一个变更您 Dify 账户关联邮箱地址的请求。</p>
<p class="content2">我们收到了一个变更您 Dify 账户关联邮箱地址的请求。</p>
<p class="content3">此验证码仅在接下来的5分钟内有效</p>
</div>
<div class="code-content">

@ -143,7 +143,7 @@
<div class="warning">Please note:</div>
<ul class="warningList">
<li>The ownership transfer will take effect immediately once confirmed and cannot be undone.</li>
<li>Youll become a admin member, and the new owner will have full control of the workspace.</li>
<li>Youll become an admin member, and the new owner will have full control of the workspace.</li>
</ul>
<p class="tips">If you didnt make this request, please ignore this email or contact support immediately.</p>
</div>

@ -108,7 +108,6 @@
<p class="title">验证您的邮箱变更请求</p>
<div class="description">
<p class="content1">我们收到了一个变更您 Dify 账户关联邮箱地址的请求。</p>
<p class="content2">我们收到了一个变更您 Dify 账户关联邮箱地址的请求。</p>
<p class="content3">此验证码仅在接下来的5分钟内有效</p>
</div>
<div class="code-content">

@ -140,7 +140,7 @@
<div class="warning">Please note:</div>
<ul class="warningList">
<li>The ownership transfer will take effect immediately once confirmed and cannot be undone.</li>
<li>Youll become a admin member, and the new owner will have full control of the workspace.</li>
<li>Youll become an admin member, and the new owner will have full control of the workspace.</li>
</ul>
<p class="tips">If you didnt make this request, please ignore this email or contact support immediately.</p>
</div>

@ -354,3 +354,35 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
def test_execute_code_scientific_notation():
code = """
def main() -> dict:
return {
"result": -8.0E-5
}
"""
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "number",
},
},
"title": "123",
"variables": [],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED

@ -0,0 +1,496 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import (
OAuthCallback,
OAuthLogin,
_generate_account,
_get_account_by_openid_or_email,
get_oauth_providers,
)
from libs.oauth import OAuthUserInfo
from models.account import AccountStatus
from services.errors.account import AccountNotFoundError
class TestGetOAuthProviders:
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.mark.parametrize(
("github_config", "google_config", "expected_github", "expected_google"),
[
# Both providers configured
(
{"id": "github_id", "secret": "github_secret"},
{"id": "google_id", "secret": "google_secret"},
True,
True,
),
# Only GitHub configured
({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
# Only Google configured
({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
# No providers configured
({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
],
)
@patch("controllers.console.auth.oauth.dify_config")
def test_should_configure_oauth_providers_correctly(
self, mock_config, app, github_config, google_config, expected_github, expected_google
):
mock_config.GITHUB_CLIENT_ID = github_config["id"]
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
mock_config.GOOGLE_CLIENT_ID = google_config["id"]
mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
mock_config.CONSOLE_API_URL = "http://localhost"
with app.app_context():
providers = get_oauth_providers()
assert (providers["github"] is not None) == expected_github
assert (providers["google"] is not None) == expected_google
class TestOAuthLogin:
@pytest.fixture
def resource(self):
return OAuthLogin()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def mock_oauth_provider(self):
provider = MagicMock()
provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
return provider
@pytest.mark.parametrize(
("invite_token", "expected_token"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_handle_oauth_login_with_various_tokens(
self,
mock_redirect,
mock_get_providers,
resource,
app,
mock_oauth_provider,
invite_token,
expected_token,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
query_string = f"invite_token={invite_token}" if invite_token else ""
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@pytest.mark.parametrize(
("provider", "expected_error"),
[
("invalid_provider", "Invalid provider"),
("github", "Invalid provider"), # When GitHub is not configured
("google", "Invalid provider"), # When Google is not configured
],
)
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_return_error_for_invalid_providers(
self, mock_get_providers, resource, app, provider, expected_error
):
mock_get_providers.return_value = {"github": None, "google": None}
with app.test_request_context(f"/auth/oauth/{provider}"):
response, status_code = resource.get(provider)
assert status_code == 400
assert response["error"] == expected_error
class TestOAuthCallback:
@pytest.fixture
def resource(self):
return OAuthCallback()
@pytest.fixture
def app(self):
app = Flask(__name__)
app.config["TESTING"] = True
return app
@pytest.fixture
def oauth_setup(self):
"""Common OAuth setup for callback tests"""
oauth_provider = MagicMock()
oauth_provider.get_access_token.return_value = "access_token"
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
account = MagicMock()
account.status = AccountStatus.ACTIVE.value
token_pair = MagicMock()
token_pair.access_token = "jwt_access_token"
token_pair.refresh_token = "jwt_refresh_token"
return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.redirect")
def test_should_handle_successful_oauth_callback(
self,
mock_redirect,
mock_tenant_service,
mock_account_service,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_generate_account.return_value = oauth_setup["account"]
mock_account_service.login.return_value = oauth_setup["token_pair"]
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
mock_redirect.assert_called_once_with(
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
)
@pytest.mark.parametrize(
("exception", "expected_error"),
[
(Exception("OAuth error"), "OAuth process failed"),
(ValueError("Invalid token"), "OAuth process failed"),
(KeyError("Missing key"), "OAuth process failed"),
],
)
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions(
self, mock_get_providers, mock_db, resource, app, exception, expected_error
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception
import requests
request_exception = requests.exceptions.RequestException("OAuth error")
request_exception.response = MagicMock()
request_exception.response.text = str(exception)
mock_oauth_provider = MagicMock()
mock_oauth_provider.get_access_token.side_effect = request_exception
mock_get_providers.return_value = {"github": mock_oauth_provider}
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
response, status_code = resource.get("github")
assert status_code == 400
assert response["error"] == expected_error
@pytest.mark.parametrize(
("account_status", "expected_redirect"),
[
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED.value,
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
),
],
)
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.redirect")
def test_should_redirect_based_on_account_status(
self,
mock_redirect,
mock_generate_account,
mock_get_providers,
mock_config,
mock_db,
mock_tenant_service,
mock_account_service,
resource,
app,
oauth_setup,
account_status,
expected_redirect,
):
# Mock database session
mock_db.session = MagicMock()
mock_db.session.rollback = MagicMock()
mock_db.session.commit = MagicMock()
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
account = MagicMock()
account.status = account_status
account.id = "123"
mock_generate_account.return_value = account
# Mock login for CLOSED status
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_account_service.login.return_value = mock_token_pair
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
mock_redirect.assert_called_once_with(expected_redirect)
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
def test_should_activate_pending_account(
self,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
mock_account = MagicMock()
mock_account.status = AccountStatus.PENDING.value
mock_generate_account.return_value = mock_account
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
assert mock_account.status == AccountStatus.ACTIVE.value
assert mock_account.initialized_at is not None
mock_db.session.commit.assert_called_once()
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth._generate_account")
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.redirect")
def test_defensive_check_for_closed_account_status(
self,
mock_redirect,
mock_account_service,
mock_tenant_service,
mock_db,
mock_generate_account,
mock_get_providers,
mock_config,
resource,
app,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
This is a defensive test documenting expected security behavior for CLOSED accounts.
Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
Context:
- AccountStatus.CLOSED is defined in the enum but never used in production
- The close_account() method exists but is never called
- Account deletion uses external service instead of status change
- All authentication services (OAuth, password, email) don't check CLOSED status
TODO: If CLOSED status is implemented in the future:
1. Update OAuth callback to check for CLOSED status
2. Add similar checks to all authentication services for consistency
3. Update this test to verify the rejection behavior
Security consideration: Until properly implemented, CLOSED status provides no protection.
"""
# Setup
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
# Create account with CLOSED status
closed_account = MagicMock()
closed_account.status = AccountStatus.CLOSED.value
closed_account.id = "123"
closed_account.name = "Closed Account"
mock_generate_account.return_value = closed_account
# Mock successful login (current behavior)
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_account_service.login.return_value = mock_token_pair
# Execute OAuth callback
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
# Verify current behavior: login succeeds (this is NOT ideal)
mock_redirect.assert_called_once_with(
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
)
mock_account_service.login.assert_called_once()
# Document expected behavior in comments:
# Expected: mock_redirect.assert_called_once_with(
# "http://localhost:3000/signin?message=Account is closed."
# )
# Expected: mock_account_service.login.assert_not_called()
class TestAccountGeneration:
@pytest.fixture
def user_info(self):
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
@pytest.fixture
def mock_account(self):
account = MagicMock()
account.name = "Test User"
return account
@patch("controllers.console.auth.oauth.db")
@patch("controllers.console.auth.oauth.Account")
@patch("controllers.console.auth.oauth.Session")
@patch("controllers.console.auth.oauth.select")
def test_should_get_account_by_openid_or_email(
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
):
# Mock db.engine for Session creation
mock_db.engine = MagicMock()
# Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
# Test fallback to email
mock_account_model.get_by_openid.return_value = None
mock_session_instance = MagicMock()
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
mock_session.return_value.__enter__.return_value = mock_session_instance
result = _get_account_by_openid_or_email("github", user_info)
assert result == mock_account
@pytest.mark.parametrize(
("allow_register", "existing_account", "should_create"),
[
(True, None, True), # New account creation allowed
(True, "existing", False), # Existing account
(False, None, False), # Registration not allowed
],
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.db")
def test_should_handle_account_generation_scenarios(
self,
mock_db,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app,
user_info,
mock_account,
allow_register,
existing_account,
should_create,
):
mock_get_account.return_value = mock_account if existing_account else None
mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
mock_register_service.register.return_value = mock_account
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
if not allow_register and not existing_account:
with pytest.raises(AccountNotFoundError):
_generate_account("github", user_info)
else:
result = _generate_account("github", user_info)
assert result == mock_account
if should_create:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
@patch("controllers.console.auth.oauth.TenantService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.tenant_was_created")
def test_should_create_workspace_for_account_without_tenant(
self,
mock_event,
mock_account_service,
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
user_info,
mock_account,
):
mock_get_account.return_value = mock_account
mock_tenant_service.get_join_tenants.return_value = []
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
mock_new_tenant = MagicMock()
mock_tenant_service.create_tenant.return_value = mock_new_tenant
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
result = _generate_account("github", user_info)
assert result == mock_account
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
mock_tenant_service.create_tenant_member.assert_called_once_with(
mock_new_tenant, mock_account, role="owner"
)
mock_event.send.assert_called_once_with(mock_new_tenant)

@ -376,7 +376,7 @@ class TestSegmentDumpAndLoad:
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
)
def test_invlaid_value_for_discriminator(self):
def test_invalid_value_for_discriminator(self):
# Test invalid cases
assert get_segment_discriminator({"value_type": "invalid"}) is None
assert get_segment_discriminator({}) is None

@ -0,0 +1,249 @@
import urllib.parse
from unittest.mock import MagicMock, patch
import pytest
import requests
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
class BaseOAuthTest:
"""Base class for OAuth provider tests with common fixtures"""
@pytest.fixture
def oauth_config(self):
return {
"client_id": "test_client_id",
"client_secret": "test_client_secret",
"redirect_uri": "http://localhost/callback",
}
@pytest.fixture
def mock_response(self):
response = MagicMock()
response.json.return_value = {}
return response
def parse_auth_url(self, url):
"""Helper to parse authorization URL"""
parsed = urllib.parse.urlparse(url)
params = urllib.parse.parse_qs(parsed.query)
return parsed, params
class TestGitHubOAuth(BaseOAuthTest):
@pytest.fixture
def oauth(self, oauth_config):
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
assert parsed.netloc == "github.com"
assert parsed.path == "/login/oauth/authorize"
assert params["client_id"][0] == oauth_config["client_id"]
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
assert params["scope"][0] == "user:email"
if expected_state:
assert params["state"][0] == expected_state
else:
assert "state" not in params
@pytest.mark.parametrize(
("response_data", "expected_token", "should_raise"),
[
({"access_token": "test_token"}, "test_token", False),
({"error": "invalid_grant"}, None, True),
({}, None, True),
],
)
@patch("requests.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
):
mock_response.json.return_value = response_data
mock_post.return_value = mock_response
if should_raise:
with pytest.raises(ValueError) as exc_info:
oauth.get_access_token("test_code")
assert "Error in GitHub OAuth" in str(exc_info.value)
else:
token = oauth.get_access_token("test_code")
assert token == expected_token
@pytest.mark.parametrize(
("user_data", "email_data", "expected_email"),
[
# User with primary email
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[
{"email": "secondary@example.com", "primary": False},
{"email": "primary@example.com", "primary": True},
],
"primary@example.com",
),
# User with no emails - fallback to noreply
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
# User with only secondary email - fallback to noreply
(
{"id": 12345, "login": "testuser", "name": "Test User"},
[{"email": "secondary@example.com", "primary": False}],
"12345+testuser@users.noreply.github.com",
),
],
)
@patch("requests.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock()
user_response.json.return_value = user_data
email_response = MagicMock()
email_response.json.return_value = email_data
mock_get.side_effect = [user_response, email_response]
user_info = oauth.get_user_info("test_token")
assert user_info.id == str(user_data["id"])
assert user_info.name == user_data["name"]
assert user_info.email == expected_email
@patch("requests.get")
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = requests.exceptions.RequestException("Network error")
with pytest.raises(requests.exceptions.RequestException):
oauth.get_raw_user_info("test_token")
class TestGoogleOAuth(BaseOAuthTest):
@pytest.fixture
def oauth(self, oauth_config):
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
assert parsed.netloc == "accounts.google.com"
assert parsed.path == "/o/oauth2/v2/auth"
assert params["client_id"][0] == oauth_config["client_id"]
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
assert params["response_type"][0] == "code"
assert params["scope"][0] == "openid email"
if expected_state:
assert params["state"][0] == expected_state
else:
assert "state" not in params
@pytest.mark.parametrize(
("response_data", "expected_token", "should_raise"),
[
({"access_token": "test_token"}, "test_token", False),
({"error": "invalid_grant"}, None, True),
({}, None, True),
],
)
@patch("requests.post")
def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
):
mock_response.json.return_value = response_data
mock_post.return_value = mock_response
if should_raise:
with pytest.raises(ValueError) as exc_info:
oauth.get_access_token("test_code")
assert "Error in Google OAuth" in str(exc_info.value)
else:
token = oauth.get_access_token("test_code")
assert token == expected_token
mock_post.assert_called_once_with(
oauth._TOKEN_URL,
data={
"client_id": oauth_config["client_id"],
"client_secret": oauth_config["client_secret"],
"code": "test_code",
"grant_type": "authorization_code",
"redirect_uri": oauth_config["redirect_uri"],
},
headers={"Accept": "application/json"},
)
@pytest.mark.parametrize(
("user_data", "expected_name"),
[
({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
],
)
@patch("requests.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data
mock_get.return_value = mock_response
user_info = oauth.get_user_info("test_token")
assert user_info.id == user_data["sub"]
assert user_info.name == expected_name
assert user_info.email == user_data["email"]
mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
@pytest.mark.parametrize(
"exception_type",
[
requests.exceptions.HTTPError,
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
],
)
@patch("requests.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error")
mock_get.return_value = mock_response
with pytest.raises(exception_type):
oauth.get_raw_user_info("invalid_token")
class TestOAuthUserInfo:
@pytest.mark.parametrize(
"user_data",
[
{"id": "123", "name": "Test User", "email": "test@example.com"},
{"id": "456", "name": "", "email": "user@domain.com"},
{"id": "789", "name": "Another User", "email": "another@test.org"},
],
)
def test_should_create_user_info_dataclass(self, user_data):
user_info = OAuthUserInfo(**user_data)
assert user_info.id == user_data["id"]
assert user_info.name == user_data["name"]
assert user_info.email == user_data["email"]

@ -0,0 +1,619 @@
import base64
import hashlib
from unittest.mock import patch
import pytest
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
from Crypto.Util.Padding import pad
from core.tools.utils.system_oauth_encryption import (
OAuthEncryptionError,
SystemOAuthEncrypter,
create_system_oauth_encrypter,
decrypt_system_oauth_params,
encrypt_system_oauth_params,
get_system_oauth_encrypter,
)
class TestSystemOAuthEncrypter:
"""Test cases for SystemOAuthEncrypter class"""
def test_init_with_secret_key(self):
"""Test initialization with provided secret key"""
secret_key = "test_secret_key"
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_init_with_none_secret_key(self):
"""Test initialization with None secret key falls back to config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = SystemOAuthEncrypter(secret_key=None)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_init_with_empty_secret_key(self):
"""Test initialization with empty secret key"""
encrypter = SystemOAuthEncrypter(secret_key="")
expected_key = hashlib.sha256(b"").digest()
assert encrypter.key == expected_key
def test_init_without_secret_key_uses_config(self):
"""Test initialization without secret key uses config"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "default_secret"
encrypter = SystemOAuthEncrypter()
expected_key = hashlib.sha256(b"default_secret").digest()
assert encrypter.key == expected_key
def test_encrypt_oauth_params_basic(self):
"""Test basic OAuth parameters encryption"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
# Should be valid base64
try:
base64.b64decode(encrypted)
except Exception:
pytest.fail("Encrypted result is not valid base64")
def test_encrypt_oauth_params_empty_dict(self):
"""Test encryption with empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_complex_data(self):
"""Test encryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_unicode_data(self):
"""Test encryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_large_data(self):
"""Test encryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_encrypt_oauth_params_invalid_input(self):
"""Test encryption with invalid input types"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params(None) # type: ignore
with pytest.raises(Exception): # noqa: B017
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
def test_decrypt_oauth_params_basic(self):
"""Test basic OAuth parameters decryption"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_empty_dict(self):
"""Test decryption of empty dictionary"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_complex_data(self):
"""Test decryption with complex data structures"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"scopes": ["read", "write", "admin"],
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
"numeric_value": 42,
"boolean_value": False,
"null_value": None,
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_unicode_data(self):
"""Test decryption with unicode data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"client_secret": "test_secret",
"description": "This is a test case 🚀",
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_large_data(self):
"""Test decryption with large data"""
encrypter = SystemOAuthEncrypter("test_secret")
original_params = {
"client_id": "test_id",
"large_data": "x" * 10000, # 10KB of data
}
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params
def test_decrypt_oauth_params_invalid_base64(self):
"""Test decryption with invalid base64 data"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params("invalid_base64!")
def test_decrypt_oauth_params_empty_string(self):
"""Test decryption with empty string"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
def test_decrypt_oauth_params_non_string_input(self):
"""Test decryption with non-string input"""
encrypter = SystemOAuthEncrypter("test_secret")
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(None) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
def test_decrypt_oauth_params_too_short_data(self):
"""Test decryption with too short encrypted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create data that's too short (less than 32 bytes)
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
def test_decrypt_oauth_params_corrupted_data(self):
"""Test decryption with corrupted data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create corrupted data (valid base64 but invalid encrypted content)
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(corrupted_data)
def test_decrypt_oauth_params_wrong_key(self):
"""Test decryption with wrong key"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypter1.encrypt_oauth_params(original_params)
with pytest.raises(OAuthEncryptionError):
encrypter2.decrypt_oauth_params(encrypted)
def test_encryption_decryption_consistency(self):
"""Test that encryption and decryption are consistent"""
encrypter = SystemOAuthEncrypter("test_secret")
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
{"array": [1, 2, 3, "four", {"five": 5}]},
]
for original_params in test_cases:
encrypted = encrypter.encrypt_oauth_params(original_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_encryption_randomness(self):
"""Test that encryption produces different results for same input"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
# Should be different due to random IV
assert encrypted1 != encrypted2
# But should decrypt to same result
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
def test_different_secret_keys_produce_different_results(self):
"""Test that different secret keys produce different encrypted results"""
encrypter1 = SystemOAuthEncrypter("secret1")
encrypter2 = SystemOAuthEncrypter("secret2")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
# Should produce different encrypted results
assert encrypted1 != encrypted2
# But each should decrypt correctly with its own key
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
assert decrypted1 == decrypted2 == oauth_params
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
"""Test encryption when crypto operation fails"""
mock_get_random_bytes.side_effect = Exception("Crypto error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
"""Test encryption when JSON serialization fails"""
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id"}
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.encrypt_oauth_params(oauth_params)
assert "Encryption failed" in str(exc_info.value)
def test_decrypt_oauth_params_invalid_json(self):
"""Test decryption with invalid JSON data"""
encrypter = SystemOAuthEncrypter("test_secret")
# Create valid encrypted data but with invalid JSON content
iv = get_random_bytes(16)
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
invalid_json = b"invalid json content"
padded_data = pad(invalid_json, AES.block_size)
encrypted_data = cipher.encrypt(padded_data)
combined = iv + encrypted_data
encoded = base64.b64encode(combined).decode()
with pytest.raises(OAuthEncryptionError):
encrypter.decrypt_oauth_params(encoded)
def test_key_derivation_consistency(self):
"""Test that key derivation is consistent"""
secret_key = "test_secret"
encrypter1 = SystemOAuthEncrypter(secret_key)
encrypter2 = SystemOAuthEncrypter(secret_key)
assert encrypter1.key == encrypter2.key
# Keys should be 32 bytes (256 bits)
assert len(encrypter1.key) == 32
class TestFactoryFunctions:
"""Test cases for factory functions"""
def test_create_system_oauth_encrypter_with_secret(self):
"""Test factory function with secret key"""
secret_key = "test_secret"
encrypter = create_system_oauth_encrypter(secret_key)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(secret_key.encode()).digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_without_secret(self):
"""Test factory function without secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter()
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
def test_create_system_oauth_encrypter_with_none_secret(self):
"""Test factory function with None secret key"""
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "config_secret"
encrypter = create_system_oauth_encrypter(None)
assert isinstance(encrypter, SystemOAuthEncrypter)
expected_key = hashlib.sha256(b"config_secret").digest()
assert encrypter.key == expected_key
class TestGlobalEncrypterInstance:
"""Test cases for global encrypter instance"""
def test_get_system_oauth_encrypter_singleton(self):
"""Test that get_system_oauth_encrypter returns singleton instance"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
encrypter1 = get_system_oauth_encrypter()
encrypter2 = get_system_oauth_encrypter()
assert encrypter1 is encrypter2
assert isinstance(encrypter1, SystemOAuthEncrypter)
def test_get_system_oauth_encrypter_uses_config(self):
"""Test that global encrypter uses config"""
# Clear the global instance first
import core.tools.utils.system_oauth_encryption
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
mock_config.SECRET_KEY = "global_secret"
encrypter = get_system_oauth_encrypter()
expected_key = hashlib.sha256(b"global_secret").digest()
assert encrypter.key == expected_key
class TestConvenienceFunctions:
"""Test cases for convenience functions"""
def test_encrypt_system_oauth_params(self):
"""Test encrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
assert isinstance(encrypted, str)
assert len(encrypted) > 0
def test_decrypt_system_oauth_params(self):
"""Test decrypt_system_oauth_params convenience function"""
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
encrypted = encrypt_system_oauth_params(oauth_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == oauth_params
def test_convenience_functions_consistency(self):
"""Test that convenience functions work consistently"""
test_cases = [
{},
{"simple": "value"},
{"client_id": "id", "client_secret": "secret"},
{"complex": {"nested": {"deep": "value"}}},
{"unicode": "test 🚀"},
{"numbers": 42, "boolean": True, "null": None},
]
for original_params in test_cases:
encrypted = encrypt_system_oauth_params(original_params)
decrypted = decrypt_system_oauth_params(encrypted)
assert decrypted == original_params, f"Failed for case: {original_params}"
def test_convenience_functions_with_errors(self):
"""Test convenience functions with error conditions"""
# Test encryption with invalid input
with pytest.raises(Exception): # noqa: B017
encrypt_system_oauth_params(None) # type: ignore
# Test decryption with invalid input
with pytest.raises(ValueError):
decrypt_system_oauth_params("")
with pytest.raises(ValueError):
decrypt_system_oauth_params(None) # type: ignore
class TestErrorHandling:
"""Test cases for error handling"""
def test_oauth_encryption_error_inheritance(self):
"""Test that OAuthEncryptionError is a proper exception"""
error = OAuthEncryptionError("Test error")
assert isinstance(error, Exception)
assert str(error) == "Test error"
def test_oauth_encryption_error_with_cause(self):
"""Test OAuthEncryptionError with cause"""
original_error = ValueError("Original error")
error = OAuthEncryptionError("Wrapper error")
error.__cause__ = original_error
assert isinstance(error, Exception)
assert str(error) == "Wrapper error"
assert error.__cause__ is original_error
def test_error_messages_are_informative(self):
"""Test that error messages are informative"""
encrypter = SystemOAuthEncrypter("test_secret")
# Test empty string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params("")
assert "encrypted_data cannot be empty" in str(exc_info.value)
# Test non-string error
with pytest.raises(ValueError) as exc_info:
encrypter.decrypt_oauth_params(123) # type: ignore
assert "encrypted_data must be a string" in str(exc_info.value)
# Test invalid format error
short_data = base64.b64encode(b"short").decode()
with pytest.raises(OAuthEncryptionError) as exc_info:
encrypter.decrypt_oauth_params(short_data)
assert "Invalid encrypted data format" in str(exc_info.value)
class TestEdgeCases:
"""Test cases for edge cases and boundary conditions"""
def test_very_long_secret_key(self):
"""Test with very long secret key"""
long_secret = "x" * 10000
encrypter = SystemOAuthEncrypter(long_secret)
# Key should still be 32 bytes due to SHA-256
assert len(encrypter.key) == 32
# Should still work normally
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_special_characters_in_secret_key(self):
"""Test with special characters in secret key"""
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
encrypter = SystemOAuthEncrypter(special_secret)
oauth_params = {"client_id": "test_id"}
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_empty_values_in_oauth_params(self):
"""Test with empty values in oauth params"""
oauth_params = {
"client_id": "",
"client_secret": "",
"empty_dict": {},
"empty_list": [],
"empty_string": "",
"zero": 0,
"false": False,
"none": None,
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_deeply_nested_oauth_params(self):
"""Test with deeply nested oauth params"""
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_oauth_params_with_all_json_types(self):
"""Test with all JSON-supported data types"""
oauth_params = {
"string": "test_string",
"integer": 42,
"float": 3.14159,
"boolean_true": True,
"boolean_false": False,
"null_value": None,
"empty_string": "",
"array": [1, "two", 3.0, True, False, None],
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
class TestPerformance:
"""Test cases for performance considerations"""
def test_large_oauth_params(self):
"""Test with large oauth params"""
large_value = "x" * 100000 # 100KB
oauth_params = {"client_id": "test_id", "large_data": large_value}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_many_fields_oauth_params(self):
"""Test with many fields in oauth params"""
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
encrypter = SystemOAuthEncrypter("test_secret")
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params
def test_repeated_encryption_decryption(self):
"""Test repeated encryption and decryption operations"""
encrypter = SystemOAuthEncrypter("test_secret")
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
# Test multiple rounds of encryption/decryption
for i in range(100):
encrypted = encrypter.encrypt_oauth_params(oauth_params)
decrypted = encrypter.decrypt_oauth_params(encrypted)
assert decrypted == oauth_params

@ -183,3 +183,42 @@ rename_conversation_response.raise_for_status()
print('[rename result]')
print(rename_conversation_response.json())
```
* Using the Workflow Client
```python
import json
import requests
from dify_client import WorkflowClient
api_key = "your_api_key"
# Initialize Workflow Client
client = WorkflowClient(api_key)
# Prepare parameters for Workflow Client
user_id = "your_user_id"
context = "previous user interaction / metadata"
user_prompt = "What is the capital of France?"
inputs = {
"context": context,
"user_prompt": user_prompt,
# Add other input fields expected by your workflow (e.g., additional context, task parameters)
}
# Set response mode (default: streaming)
response_mode = "blocking"
# Run the workflow
response = client.run(inputs=inputs, response_mode=response_mode, user=user_id)
response.raise_for_status()
# Parse result
result = json.loads(response.text)
answer = result.get("data").get("outputs")
print(answer["answer"])
```

@ -1 +1 @@
from dify_client.client import ChatClient, CompletionClient, DifyClient
from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient

@ -308,7 +308,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
const EditTimeText = useMemo(() => {
const timeText = formatTime({
date: (app.updated_at || app.created_at) * 1000,
dateFormat: 'MM/DD/YYYY h:mm',
dateFormat: `${t('datasetDocuments.segment.dateTimeFormat')}`,
})
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
// eslint-disable-next-line react-hooks/exhaustive-deps

@ -113,8 +113,8 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}
const isValidEmail = (email: string): boolean => {
const emailRegex = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/
return emailRegex.test(email)
const rfc5322emailRegex = /^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$/
return rfc5322emailRegex.test(email) && email.length <= 254
}
const checkNewEmailExisted = async (email: string) => {

@ -17,8 +17,8 @@ import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import cn from '@/utils/classnames'
import type { PromptRole, PromptVariable } from '@/models/debug'
import {
Clipboard,
ClipboardCheck,
Copy,
CopyCheck,
} from '@/app/components/base/icons/src/vender/line/files'
import Button from '@/app/components/base/button'
import Tooltip from '@/app/components/base/tooltip'
@ -188,13 +188,13 @@ const AdvancedPromptInput: FC<Props> = ({
)}
{!isCopied
? (
<Clipboard className='h-6 w-6 cursor-pointer p-1 text-text-tertiary' onClick={() => {
<Copy className='h-6 w-6 cursor-pointer p-1 text-text-tertiary' onClick={() => {
copy(value)
setIsCopied(true)
}} />
)
: (
<ClipboardCheck className='h-6 w-6 p-1 text-text-tertiary' />
<CopyCheck className='h-6 w-6 p-1 text-text-tertiary' />
)}
</div>
</div>

@ -18,7 +18,6 @@ import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Toast from '@/app/components/base/toast'
import ConfigContext from '@/context/debug-configuration'
import type { AgentTool } from '@/types/app'
import { type Collection, CollectionType } from '@/app/components/tools/types'
@ -26,8 +25,6 @@ import { MAX_TOOLS_NUM } from '@/config'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Tooltip from '@/app/components/base/tooltip'
import { DefaultToolIcon } from '@/app/components/base/icons/src/public/other'
import ConfigCredential from '@/app/components/tools/setting/build-in/config-credentials'
import { updateBuiltInToolCredential } from '@/service/tools'
import cn from '@/utils/classnames'
import ToolPicker from '@/app/components/workflow/block-selector/tool-picker'
import type { ToolDefaultValue, ToolValue } from '@/app/components/workflow/block-selector/types'
@ -57,13 +54,7 @@ const AgentTools: FC = () => {
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const [currentTool, setCurrentTool] = useState<AgentToolWithMoreInfo>(null)
const currentCollection = useMemo(() => {
if (!currentTool) return null
const collection = collectionList.find(collection => canFindTool(collection.id, currentTool?.provider_id) && collection.type === currentTool?.provider_type)
return collection
}, [currentTool, collectionList])
const [isShowSettingTool, setIsShowSettingTool] = useState(false)
const [isShowSettingAuth, setShowSettingAuth] = useState(false)
const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => {
const collection = collectionList.find(
collection =>
@ -100,17 +91,6 @@ const AgentTools: FC = () => {
formattingChangedDispatcher()
}
const handleToolAuthSetting = (value: AgentToolWithMoreInfo) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === value?.collection?.id && item.tool_name === value?.tool_name)
if (tool)
(tool as AgentTool).notAuthor = false
})
setModelConfig(newModelConfig)
setIsShowSettingTool(false)
formattingChangedDispatcher()
}
const [isDeleting, setIsDeleting] = useState<number>(-1)
const getToolValue = (tool: ToolDefaultValue) => {
return {
@ -144,6 +124,20 @@ const AgentTools: FC = () => {
return item.provider_name
}
const handleAuthorizationItemClick = useCallback((credentialId: string) => {
const newModelConfig = produce(modelConfig, (draft) => {
const tool = (draft.agentConfig.tools).find((item: any) => item.provider_id === currentTool?.provider_id)
if (tool)
(tool as AgentTool).credential_id = credentialId
})
setCurrentTool({
...currentTool,
credential_id: credentialId,
} as any)
setModelConfig(newModelConfig)
formattingChangedDispatcher()
}, [currentTool, modelConfig, setModelConfig, formattingChangedDispatcher])
return (
<>
<Panel
@ -299,7 +293,7 @@ const AgentTools: FC = () => {
{item.notAuthor && (
<Button variant='secondary' size='small' onClick={() => {
setCurrentTool(item)
setShowSettingAuth(true)
setIsShowSettingTool(true)
}}>
{t('tools.notAuthorized')}
<Indicator className='ml-2' color='orange' />
@ -319,21 +313,8 @@ const AgentTools: FC = () => {
isModel={currentTool?.collection?.type === CollectionType.model}
onSave={handleToolSettingChange}
onHide={() => setIsShowSettingTool(false)}
/>
)}
{isShowSettingAuth && (
<ConfigCredential
collection={currentCollection as any}
onCancel={() => setShowSettingAuth(false)}
onSaved={async (value) => {
await updateBuiltInToolCredential((currentCollection as any).name, value)
Toast.notify({
type: 'success',
message: t('common.api.actionSuccess'),
})
handleToolAuthSetting(currentTool)
setShowSettingAuth(false)
}}
credentialId={currentTool?.credential_id}
onAuthorizationItemClick={handleAuthorizationItemClick}
/>
)}
</>

@ -14,7 +14,6 @@ import Icon from '@/app/components/plugins/card/base/card-icon'
import OrgInfo from '@/app/components/plugins/card/base/org-info'
import Description from '@/app/components/plugins/card/base/description'
import TabSlider from '@/app/components/base/tab-slider-plain'
import Button from '@/app/components/base/button'
import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form'
import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema'
@ -25,6 +24,10 @@ import I18n from '@/context/i18n'
import { getLanguage } from '@/i18n/language'
import cn from '@/utils/classnames'
import type { ToolWithProvider } from '@/app/components/workflow/types'
import {
AuthCategory,
PluginAuthInAgent,
} from '@/app/components/plugins/plugin-auth'
type Props = {
showBackButton?: boolean
@ -36,6 +39,8 @@ type Props = {
readonly?: boolean
onHide: () => void
onSave?: (value: Record<string, any>) => void
credentialId?: string
onAuthorizationItemClick?: (id: string) => void
}
const SettingBuiltInTool: FC<Props> = ({
@ -48,6 +53,8 @@ const SettingBuiltInTool: FC<Props> = ({
readonly,
onHide,
onSave,
credentialId,
onAuthorizationItemClick,
}) => {
const { locale } = useContext(I18n)
const language = getLanguage(locale)
@ -197,8 +204,20 @@ const SettingBuiltInTool: FC<Props> = ({
</div>
<div className='system-md-semibold mt-1 text-text-primary'>{currTool?.label[language]}</div>
{!!currTool?.description[language] && (
<Description className='mt-3' text={currTool.description[language]} descriptionLineRows={2}></Description>
<Description className='mb-2 mt-3 h-auto' text={currTool.description[language]} descriptionLineRows={2}></Description>
)}
{
collection.allow_delete && collection.type === CollectionType.builtIn && (
<PluginAuthInAgent
pluginPayload={{
provider: collection.name,
category: AuthCategory.tool,
}}
credentialId={credentialId}
onAuthorizationItemClick={onAuthorizationItemClick}
/>
)
}
</div>
{/* form */}
<div className='h-full'>

@ -6,8 +6,8 @@ import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import cn from '@/utils/classnames'
import {
Clipboard,
ClipboardCheck,
Copy,
CopyCheck,
} from '@/app/components/base/icons/src/vender/line/files'
import PromptEditor from '@/app/components/base/prompt-editor'
import type { ExternalDataTool } from '@/models/common'
@ -81,13 +81,13 @@ const Editor: FC<Props> = ({
<div className={cn(s.optionWrap, 'items-center space-x-1')}>
{!isCopied
? (
<Clipboard className='h-6 w-6 cursor-pointer p-1 text-gray-500' onClick={() => {
<Copy className='h-6 w-6 cursor-pointer p-1 text-gray-500' onClick={() => {
copy(value)
setIsCopied(true)
}} />
)
: (
<ClipboardCheck className='h-6 w-6 p-1 text-gray-500' />
<CopyCheck className='h-6 w-6 p-1 text-gray-500' />
)}
</div>
</div>

@ -1,8 +1,8 @@
import { useEffect, useRef } from 'react'
import cn from '@/utils/classnames'
type AutoHeightTextareaProps =
& React.DetailedHTMLProps<React.TextareaHTMLAttributes<HTMLTextAreaElement>, HTMLTextAreaElement>
type AutoHeightTextareaProps
= & React.DetailedHTMLProps<React.TextareaHTMLAttributes<HTMLTextAreaElement>, HTMLTextAreaElement>
& { outerClassName?: string }
const AutoHeightTextarea = (

@ -117,7 +117,7 @@ const Question: FC<QuestionProps> = ({
</div>
<div
ref={contentRef}
className='w-full rounded-2xl bg-background-gradient-bg-fill-chat-bubble-bg-3 px-4 py-3 text-sm text-text-primary'
className='bg-background-gradient-bg-fill-chat-bubble-bg-3 w-full rounded-2xl px-4 py-3 text-sm text-text-primary'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
>
{

@ -5,8 +5,8 @@ import { debounce } from 'lodash-es'
import copy from 'copy-to-clipboard'
import Tooltip from '../tooltip'
import {
Clipboard,
ClipboardCheck,
Copy,
CopyCheck,
} from '@/app/components/base/icons/src/vender/line/files'
type Props = {
@ -39,10 +39,10 @@ export const CopyIcon = ({ content }: Props) => {
<div onMouseLeave={onMouseLeave}>
{!isCopied
? (
<Clipboard className='mx-1 h-3.5 w-3.5 cursor-pointer text-text-tertiary' onClick={onClickCopy} />
<Copy className='mx-1 h-3.5 w-3.5 cursor-pointer text-text-tertiary' onClick={onClickCopy} />
)
: (
<ClipboardCheck className='mx-1 h-3.5 w-3.5 text-text-tertiary' />
<CopyCheck className='mx-1 h-3.5 w-3.5 text-text-tertiary' />
)
}
</div>

@ -0,0 +1,177 @@
import {
isValidElement,
memo,
useMemo,
} from 'react'
import type { AnyFieldApi } from '@tanstack/react-form'
import { useStore } from '@tanstack/react-form'
import cn from '@/utils/classnames'
import Input from '@/app/components/base/input'
import PureSelect from '@/app/components/base/select/pure'
import type { FormSchema } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
import { useRenderI18nObject } from '@/hooks/use-i18n'
export type BaseFieldProps = {
fieldClassName?: string
labelClassName?: string
inputContainerClassName?: string
inputClassName?: string
formSchema: FormSchema
field: AnyFieldApi
disabled?: boolean
}
const BaseField = ({
fieldClassName,
labelClassName,
inputContainerClassName,
inputClassName,
formSchema,
field,
disabled,
}: BaseFieldProps) => {
const renderI18nObject = useRenderI18nObject()
const {
label,
required,
placeholder,
options,
labelClassName: formLabelClassName,
show_on = [],
} = formSchema
const memorizedLabel = useMemo(() => {
if (isValidElement(label))
return label
if (typeof label === 'string')
return label
if (typeof label === 'object' && label !== null)
return renderI18nObject(label as Record<string, string>)
}, [label, renderI18nObject])
const memorizedPlaceholder = useMemo(() => {
if (typeof placeholder === 'string')
return placeholder
if (typeof placeholder === 'object' && placeholder !== null)
return renderI18nObject(placeholder as Record<string, string>)
}, [placeholder, renderI18nObject])
const memorizedOptions = useMemo(() => {
return options?.map((option) => {
return {
label: typeof option.label === 'string' ? option.label : renderI18nObject(option.label),
value: option.value,
}
}) || []
}, [options, renderI18nObject])
const value = useStore(field.form.store, s => s.values[field.name])
const values = useStore(field.form.store, (s) => {
return show_on.reduce((acc, condition) => {
acc[condition.variable] = s.values[condition.variable]
return acc
}, {} as Record<string, any>)
})
const show = useMemo(() => {
return show_on.every((condition) => {
const conditionValue = values[condition.variable]
return conditionValue === condition.value
})
}, [values, show_on])
if (!show)
return null
return (
<div className={cn(fieldClassName)}>
<div className={cn(labelClassName, formLabelClassName)}>
{memorizedLabel}
{
required && !isValidElement(label) && (
<span className='ml-1 text-text-destructive-secondary'>*</span>
)
}
</div>
<div className={cn(inputContainerClassName)}>
{
formSchema.type === FormTypeEnum.textInput && (
<Input
id={field.name}
name={field.name}
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.secretInput && (
<Input
id={field.name}
name={field.name}
type='password'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.textNumber && (
<Input
id={field.name}
name={field.name}
type='number'
className={cn(inputClassName)}
value={value || ''}
onChange={e => field.handleChange(e.target.value)}
onBlur={field.handleBlur}
disabled={disabled}
placeholder={memorizedPlaceholder}
/>
)
}
{
formSchema.type === FormTypeEnum.select && (
<PureSelect
value={value}
onChange={v => field.handleChange(v)}
disabled={disabled}
placeholder={memorizedPlaceholder}
options={memorizedOptions}
triggerPopupSameWidth
/>
)
}
{
formSchema.type === FormTypeEnum.radio && (
<div className='flex items-center space-x-2'>
{
memorizedOptions.map(option => (
<div
key={option.value}
className={cn(
'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary',
value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs',
)}
onClick={() => field.handleChange(option.value)}
>
{option.label}
</div>
))
}
</div>
)
}
</div>
</div>
)
}
export default memo(BaseField)

@ -0,0 +1,115 @@
import {
memo,
useCallback,
useImperativeHandle,
} from 'react'
import type {
AnyFieldApi,
AnyFormApi,
} from '@tanstack/react-form'
import { useForm } from '@tanstack/react-form'
import type {
FormRef,
FormSchema,
} from '@/app/components/base/form/types'
import {
BaseField,
} from '.'
import type {
BaseFieldProps,
} from '.'
import cn from '@/utils/classnames'
import {
useGetFormValues,
useGetValidators,
} from '@/app/components/base/form/hooks'
export type BaseFormProps = {
formSchemas?: FormSchema[]
defaultValues?: Record<string, any>
formClassName?: string
ref?: FormRef
disabled?: boolean
formFromProps?: AnyFormApi
} & Pick<BaseFieldProps, 'fieldClassName' | 'labelClassName' | 'inputContainerClassName' | 'inputClassName'>
const BaseForm = ({
formSchemas = [],
defaultValues,
formClassName,
fieldClassName,
labelClassName,
inputContainerClassName,
inputClassName,
ref,
disabled,
formFromProps,
}: BaseFormProps) => {
const formFromHook = useForm({
defaultValues,
})
const form: any = formFromProps || formFromHook
const { getFormValues } = useGetFormValues(form, formSchemas)
const { getValidators } = useGetValidators()
useImperativeHandle(ref, () => {
return {
getForm() {
return form
},
getFormValues: (option) => {
return getFormValues(option)
},
}
}, [form, getFormValues])
const renderField = useCallback((field: AnyFieldApi) => {
const formSchema = formSchemas?.find(schema => schema.name === field.name)
if (formSchema) {
return (
<BaseField
field={field}
formSchema={formSchema}
fieldClassName={fieldClassName}
labelClassName={labelClassName}
inputContainerClassName={inputContainerClassName}
inputClassName={inputClassName}
disabled={disabled}
/>
)
}
return null
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled])
const renderFieldWrapper = useCallback((formSchema: FormSchema) => {
const validators = getValidators(formSchema)
const {
name,
} = formSchema
return (
<form.Field
key={name}
name={name}
validators={validators}
>
{renderField}
</form.Field>
)
}, [renderField, form, getValidators])
if (!formSchemas?.length)
return null
return (
<form
className={cn(formClassName)}
>
{formSchemas.map(renderFieldWrapper)}
</form>
)
}
export default memo(BaseForm)

@ -0,0 +1,2 @@
export { default as BaseForm, type BaseFormProps } from './base-form'
export { default as BaseField, type BaseFieldProps } from './base-field'

@ -0,0 +1,23 @@
import { memo } from 'react'
import { BaseForm } from '../../components/base'
import type { BaseFormProps } from '../../components/base'
const AuthForm = ({
formSchemas = [],
defaultValues,
ref,
formFromProps,
}: BaseFormProps) => {
return (
<BaseForm
ref={ref}
formSchemas={formSchemas}
defaultValues={defaultValues}
formClassName='space-y-4'
labelClassName='h-6 flex items-center mb-1 system-sm-medium text-text-secondary'
formFromProps={formFromProps}
/>
)
}
export default memo(AuthForm)

@ -0,0 +1,3 @@
export * from './use-check-validated'
export * from './use-get-form-values'
export * from './use-get-validators'

@ -0,0 +1,48 @@
import { useCallback } from 'react'
import type { AnyFormApi } from '@tanstack/react-form'
import { useToastContext } from '@/app/components/base/toast'
import type { FormSchema } from '@/app/components/base/form/types'
export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => {
const { notify } = useToastContext()
const checkValidated = useCallback(() => {
const allError = form?.getAllErrors()
const values = form.state.values
if (allError) {
const fields = allError.fields
const errorArray = Object.keys(fields).reduce((acc: string[], key: string) => {
const currentSchema = FormSchemas.find(schema => schema.name === key)
const { show_on = [] } = currentSchema || {}
const showOnValues = show_on.reduce((acc, condition) => {
acc[condition.variable] = values[condition.variable]
return acc
}, {} as Record<string, any>)
const show = show_on?.every((condition) => {
const conditionValue = showOnValues[condition.variable]
return conditionValue === condition.value
})
const errors: any[] = show ? fields[key].errors : []
return [...acc, ...errors]
}, [] as string[])
if (errorArray.length) {
notify({
type: 'error',
message: errorArray[0],
})
return false
}
return true
}
return true
}, [form, notify, FormSchemas])
return {
checkValidated,
}
}

@ -0,0 +1,44 @@
import { useCallback } from 'react'
import type { AnyFormApi } from '@tanstack/react-form'
import { useCheckValidated } from './use-check-validated'
import type {
FormSchema,
GetValuesOptions,
} from '../types'
import { getTransformedValuesWhenSecretInputPristine } from '../utils'
export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => {
const { checkValidated } = useCheckValidated(form, formSchemas)
const getFormValues = useCallback((
{
needCheckValidatedValues,
needTransformWhenSecretFieldIsPristine,
}: GetValuesOptions,
) => {
const values = form?.store.state.values || {}
if (!needCheckValidatedValues) {
return {
values,
isCheckValidated: false,
}
}
if (checkValidated()) {
return {
values: needTransformWhenSecretFieldIsPristine ? getTransformedValuesWhenSecretInputPristine(formSchemas, form) : values,
isCheckValidated: true,
}
}
else {
return {
values: {},
isCheckValidated: false,
}
}
}, [form, checkValidated, formSchemas])
return {
getFormValues,
}
}

@ -0,0 +1,36 @@
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import type { FormSchema } from '../types'
export const useGetValidators = () => {
const { t } = useTranslation()
const getValidators = useCallback((formSchema: FormSchema) => {
const {
name,
validators,
required,
} = formSchema
let mergedValidators = validators
if (required && !validators) {
mergedValidators = {
onMount: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
onChange: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
onBlur: ({ value }: any) => {
if (!value)
return t('common.errorMsg.fieldRequired', { field: name })
},
}
}
return mergedValidators
}, [t])
return {
getValidators,
}
}

@ -0,0 +1,76 @@
import type {
ForwardedRef,
ReactNode,
} from 'react'
import type {
AnyFormApi,
FieldValidators,
} from '@tanstack/react-form'
export type TypeWithI18N<T = string> = {
en_US: T
zh_Hans: T
[key: string]: T
}
export type FormShowOnObject = {
variable: string
value: string
}
export enum FormTypeEnum {
textInput = 'text-input',
textNumber = 'number-input',
secretInput = 'secret-input',
select = 'select',
radio = 'radio',
boolean = 'boolean',
files = 'files',
file = 'file',
modelSelector = 'model-selector',
toolSelector = 'tool-selector',
multiToolSelector = 'array[tools]',
appSelector = 'app-selector',
dynamicSelect = 'dynamic-select',
}
export type FormOption = {
label: TypeWithI18N | string
value: string
show_on?: FormShowOnObject[]
icon?: string
}
export type AnyValidators = FieldValidators<any, any, any, any, any, any, any, any, any, any>
export type FormSchema = {
type: FormTypeEnum
name: string
label: string | ReactNode | TypeWithI18N
required: boolean
default?: any
tooltip?: string | TypeWithI18N
show_on?: FormShowOnObject[]
url?: string
scope?: string
help?: string | TypeWithI18N
placeholder?: string | TypeWithI18N
options?: FormOption[]
labelClassName?: string
validators?: AnyValidators
}
export type FormValues = Record<string, any>
export type GetValuesOptions = {
needTransformWhenSecretFieldIsPristine?: boolean
needCheckValidatedValues?: boolean
}
export type FormRefObject = {
getForm: () => AnyFormApi
getFormValues: (obj: GetValuesOptions) => {
values: Record<string, any>
isCheckValidated: boolean
}
}
export type FormRef = ForwardedRef<FormRefObject>

@ -0,0 +1 @@
export * from './secret-input'

@ -0,0 +1,29 @@
import type { AnyFormApi } from '@tanstack/react-form'
import type { FormSchema } from '@/app/components/base/form/types'
import { FormTypeEnum } from '@/app/components/base/form/types'
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
const transformedValues: Record<string, any> = { ...values }
isPristineSecretInputNames.forEach((name) => {
if (transformedValues[name])
transformedValues[name] = '[__HIDDEN__]'
})
return transformedValues
}
export const getTransformedValuesWhenSecretInputPristine = (formSchemas: FormSchema[], form: AnyFormApi) => {
const values = form?.store.state.values || {}
const isPristineSecretInputNames: string[] = []
for (let i = 0; i < formSchemas.length; i++) {
const schema = formSchemas[i]
if (schema.type === FormTypeEnum.secretInput) {
const fieldMeta = form?.getFieldMeta(schema.name)
if (fieldMeta?.isPristine)
isPristineSecretInputNames.push(schema.name)
}
}
return transformFormSchemasSecretInput(isPristineSecretInputNames, values)
}

@ -1,3 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9 15L11 17L15.5 12.5M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

@ -1,3 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

Before

Width:  |  Height:  |  Size: 1.1 KiB

@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.6665 2.66683C11.2865 2.66683 11.5965 2.66683 11.8508 2.73498C12.541 2.91991 13.0801 3.45901 13.265 4.14919C13.3332 4.40352 13.3332 4.71352 13.3332 5.3335V11.4668C13.3332 12.5869 13.3332 13.147 13.1152 13.5748C12.9234 13.9511 12.6175 14.2571 12.2412 14.4488C11.8133 14.6668 11.2533 14.6668 10.1332 14.6668H5.8665C4.7464 14.6668 4.18635 14.6668 3.75852 14.4488C3.3822 14.2571 3.07624 13.9511 2.88449 13.5748C2.6665 13.147 2.6665 12.5869 2.6665 11.4668V5.3335C2.6665 4.71352 2.6665 4.40352 2.73465 4.14919C2.91959 3.45901 3.45868 2.91991 4.14887 2.73498C4.4032 2.66683 4.71319 2.66683 5.33317 2.66683M5.99984 10.0002L7.33317 11.3335L10.3332 8.3335M6.39984 4.00016H9.59984C9.9732 4.00016 10.1599 4.00016 10.3025 3.9275C10.4279 3.86359 10.5299 3.7616 10.5938 3.63616C10.6665 3.49355 10.6665 3.30686 10.6665 2.9335V2.40016C10.6665 2.02679 10.6665 1.84011 10.5938 1.6975C10.5299 1.57206 10.4279 1.47007 10.3025 1.40616C10.1599 1.3335 9.97321 1.3335 9.59984 1.3335H6.39984C6.02647 1.3335 5.83978 1.3335 5.69718 1.40616C5.57174 1.47007 5.46975 1.57206 5.40583 1.6975C5.33317 1.84011 5.33317 2.02679 5.33317 2.40016V2.9335C5.33317 3.30686 5.33317 3.49355 5.40583 3.63616C5.46975 3.7616 5.57174 3.86359 5.69718 3.9275C5.83978 4.00016 6.02647 4.00016 6.39984 4.00016Z" stroke="#1D2939" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.6665 2.66634H11.9998C12.3535 2.66634 12.6926 2.80682 12.9426 3.05687C13.1927 3.30691 13.3332 3.64605 13.3332 3.99967V13.333C13.3332 13.6866 13.1927 14.0258 12.9426 14.2758C12.6926 14.5259 12.3535 14.6663 11.9998 14.6663H3.99984C3.64622 14.6663 3.30708 14.5259 3.05703 14.2758C2.80698 14.0258 2.6665 13.6866 2.6665 13.333V3.99967C2.6665 3.64605 2.80698 3.30691 3.05703 3.05687C3.30708 2.80682 3.64622 2.66634 3.99984 2.66634H5.33317M5.99984 1.33301H9.99984C10.368 1.33301 10.6665 1.63148 10.6665 1.99967V3.33301C10.6665 3.7012 10.368 3.99967 9.99984 3.99967H5.99984C5.63165 3.99967 5.33317 3.7012 5.33317 3.33301V1.99967C5.33317 1.63148 5.63165 1.33301 5.99984 1.33301Z" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 875 B

@ -1,29 +0,0 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "24",
"height": "24",
"viewBox": "0 0 24 24",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z",
"stroke": "currentColor",
"stroke-width": "2",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
},
"name": "Clipboard"
}

@ -1,29 +0,0 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "24",
"height": "24",
"viewBox": "0 0 24 24",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9 15L11 17L15.5 12.5M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z",
"stroke": "currentColor",
"stroke-width": "2",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
},
"name": "ClipboardCheck"
}

@ -0,0 +1,29 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "16",
"height": "16",
"viewBox": "0 0 16 16",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M10.6665 2.66634H11.9998C12.3535 2.66634 12.6926 2.80682 12.9426 3.05687C13.1927 3.30691 13.3332 3.64605 13.3332 3.99967V13.333C13.3332 13.6866 13.1927 14.0258 12.9426 14.2758C12.6926 14.5259 12.3535 14.6663 11.9998 14.6663H3.99984C3.64622 14.6663 3.30708 14.5259 3.05703 14.2758C2.80698 14.0258 2.6665 13.6866 2.6665 13.333V3.99967C2.6665 3.64605 2.80698 3.30691 3.05703 3.05687C3.30708 2.80682 3.64622 2.66634 3.99984 2.66634H5.33317M5.99984 1.33301H9.99984C10.368 1.33301 10.6665 1.63148 10.6665 1.99967V3.33301C10.6665 3.7012 10.368 3.99967 9.99984 3.99967H5.99984C5.63165 3.99967 5.33317 3.7012 5.33317 3.33301V1.99967C5.33317 1.63148 5.63165 1.33301 5.99984 1.33301Z",
"stroke": "currentColor",
"stroke-width": "1.5",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
},
"name": "Copy"
}

@ -2,7 +2,7 @@
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Clipboard.json'
import data from './Copy.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
@ -15,6 +15,6 @@ const Icon = (
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'Clipboard'
Icon.displayName = 'Copy'
export default Icon

@ -0,0 +1,29 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "16",
"height": "16",
"viewBox": "0 0 16 16",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"d": "M10.6665 2.66683C11.2865 2.66683 11.5965 2.66683 11.8508 2.73498C12.541 2.91991 13.0801 3.45901 13.265 4.14919C13.3332 4.40352 13.3332 4.71352 13.3332 5.3335V11.4668C13.3332 12.5869 13.3332 13.147 13.1152 13.5748C12.9234 13.9511 12.6175 14.2571 12.2412 14.4488C11.8133 14.6668 11.2533 14.6668 10.1332 14.6668H5.8665C4.7464 14.6668 4.18635 14.6668 3.75852 14.4488C3.3822 14.2571 3.07624 13.9511 2.88449 13.5748C2.6665 13.147 2.6665 12.5869 2.6665 11.4668V5.3335C2.6665 4.71352 2.6665 4.40352 2.73465 4.14919C2.91959 3.45901 3.45868 2.91991 4.14887 2.73498C4.4032 2.66683 4.71319 2.66683 5.33317 2.66683M5.99984 10.0002L7.33317 11.3335L10.3332 8.3335M6.39984 4.00016H9.59984C9.9732 4.00016 10.1599 4.00016 10.3025 3.9275C10.4279 3.86359 10.5299 3.7616 10.5938 3.63616C10.6665 3.49355 10.6665 3.30686 10.6665 2.9335V2.40016C10.6665 2.02679 10.6665 1.84011 10.5938 1.6975C10.5299 1.57206 10.4279 1.47007 10.3025 1.40616C10.1599 1.3335 9.97321 1.3335 9.59984 1.3335H6.39984C6.02647 1.3335 5.83978 1.3335 5.69718 1.40616C5.57174 1.47007 5.46975 1.57206 5.40583 1.6975C5.33317 1.84011 5.33317 2.02679 5.33317 2.40016V2.9335C5.33317 3.30686 5.33317 3.49355 5.40583 3.63616C5.46975 3.7616 5.57174 3.86359 5.69718 3.9275C5.83978 4.00016 6.02647 4.00016 6.39984 4.00016Z",
"stroke": "currentColor",
"stroke-width": "1.5",
"stroke-linecap": "round",
"stroke-linejoin": "round"
},
"children": []
}
]
},
"name": "CopyCheck"
}

@ -2,7 +2,7 @@
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './ClipboardCheck.json'
import data from './CopyCheck.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
@ -15,6 +15,6 @@ const Icon = (
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'ClipboardCheck'
Icon.displayName = 'CopyCheck'
export default Icon

@ -1,5 +1,5 @@
export { default as ClipboardCheck } from './ClipboardCheck'
export { default as Clipboard } from './Clipboard'
export { default as CopyCheck } from './CopyCheck'
export { default as Copy } from './Copy'
export { default as File02 } from './File02'
export { default as FileArrow01 } from './FileArrow01'
export { default as FileCheck02 } from './FileCheck02'

@ -0,0 +1,127 @@
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import { RiCloseLine } from '@remixicon/react'
import {
PortalToFollowElem,
PortalToFollowElemContent,
} from '@/app/components/base/portal-to-follow-elem'
import Button from '@/app/components/base/button'
import type { ButtonProps } from '@/app/components/base/button'
import cn from '@/utils/classnames'
type ModalProps = {
onClose?: () => void
size?: 'sm' | 'md'
title: string
subTitle?: string
children?: React.ReactNode
confirmButtonText?: string
onConfirm?: () => void
cancelButtonText?: string
onCancel?: () => void
showExtraButton?: boolean
extraButtonText?: string
extraButtonVariant?: ButtonProps['variant']
onExtraButtonClick?: () => void
footerSlot?: React.ReactNode
bottomSlot?: React.ReactNode
disabled?: boolean
}
const Modal = ({
onClose,
size = 'sm',
title,
subTitle,
children,
confirmButtonText,
onConfirm,
cancelButtonText,
onCancel,
showExtraButton,
extraButtonVariant = 'warning',
extraButtonText,
onExtraButtonClick,
footerSlot,
bottomSlot,
disabled,
}: ModalProps) => {
const { t } = useTranslation()
return (
<PortalToFollowElem open>
<PortalToFollowElemContent
className='z-[9998] flex h-full w-full items-center justify-center bg-background-overlay'
onClick={onClose}
>
<div
className={cn(
'max-h-[80%] w-[480px] overflow-y-auto rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xs',
size === 'sm' && 'w-[480px',
size === 'md' && 'w-[640px]',
)}
onClick={e => e.stopPropagation()}
>
<div className='title-2xl-semi-bold relative p-6 pb-3 pr-14 text-text-primary'>
{title}
{
subTitle && (
<div className='system-xs-regular mt-1 text-text-tertiary'>
{subTitle}
</div>
)
}
<div
className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg'
onClick={onClose}
>
<RiCloseLine className='h-5 w-5 text-text-tertiary' />
</div>
</div>
{
children && (
<div className='px-6 py-3'>{children}</div>
)
}
<div className='flex justify-between p-6 pt-5'>
<div>
{footerSlot}
</div>
<div className='flex items-center'>
{
showExtraButton && (
<>
<Button
variant={extraButtonVariant}
onClick={onExtraButtonClick}
disabled={disabled}
>
{extraButtonText || t('common.operation.remove')}
</Button>
<div className='mx-3 h-4 w-[1px] bg-divider-regular'></div>
</>
)
}
<Button
onClick={onCancel}
disabled={disabled}
>
{cancelButtonText || t('common.operation.cancel')}
</Button>
<Button
className='ml-2'
variant='primary'
onClick={onConfirm}
disabled={disabled}
>
{confirmButtonText || t('common.operation.save')}
</Button>
</div>
</div>
{bottomSlot}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default memo(Modal)

@ -39,6 +39,9 @@ type PureSelectProps = {
itemClassName?: string
title?: string
},
placeholder?: string
disabled?: boolean
triggerPopupSameWidth?: boolean
}
const PureSelect = ({
options,
@ -47,6 +50,9 @@ const PureSelect = ({
containerProps,
triggerProps,
popupProps,
placeholder,
disabled,
triggerPopupSameWidth,
}: PureSelectProps) => {
const { t } = useTranslation()
const {
@ -74,7 +80,7 @@ const PureSelect = ({
}, [onOpenChange])
const selectedOption = options.find(option => option.value === value)
const triggerText = selectedOption?.label || t('common.placeholder.select')
const triggerText = selectedOption?.label || placeholder || t('common.placeholder.select')
return (
<PortalToFollowElem
@ -82,6 +88,7 @@ const PureSelect = ({
offset={offset || 4}
open={mergedOpen}
onOpenChange={handleOpenChange}
triggerPopupSameWidth={triggerPopupSameWidth}
>
<PortalToFollowElemTrigger
onClick={() => handleOpenChange(!mergedOpen)}
@ -135,6 +142,7 @@ const PureSelect = ({
)}
title={option.label}
onClick={() => {
if (disabled) return
onChange?.(option.value)
handleOpenChange(false)
}}

@ -66,7 +66,7 @@ const ChildSegmentDetail: FC<IChildSegmentDetailProps> = ({
const EditTimeText = useMemo(() => {
const timeText = formatTime({
date: (childChunkInfo?.updated_at ?? 0) * 1000,
dateFormat: 'MM/DD/YYYY h:mm:ss',
dateFormat: `${t('datasetDocuments.segment.dateTimeFormat')}`,
})
return `${t('datasetDocuments.segment.editedAt')} ${timeText}`
// eslint-disable-next-line react-hooks/exhaustive-deps
@ -74,7 +74,7 @@ const ChildSegmentDetail: FC<IChildSegmentDetailProps> = ({
return (
<div className={'flex h-full flex-col'}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'py-3 pr-4 pl-6 border border-divider-subtle' : 'pt-3 pr-3 pl-4')}>
<div className={classNames('flex items-center justify-between', fullScreen ? 'border border-divider-subtle py-3 pl-6 pr-4' : 'pl-4 pr-3 pt-3')}>
<div className='flex flex-col'>
<div className='system-xl-semibold text-text-primary'>{t('datasetDocuments.segment.editChildChunk')}</div>
<div className='flex items-center gap-x-2'>
@ -107,8 +107,8 @@ const ChildSegmentDetail: FC<IChildSegmentDetailProps> = ({
</div>
</div>
</div>
<div className={classNames('flex grow w-full', fullScreen ? 'flex-row justify-center px-6 pt-6' : 'py-3 px-4')}>
<div className={classNames('break-all overflow-hidden whitespace-pre-line h-full', fullScreen ? 'w-1/2' : 'w-full')}>
<div className={classNames('flex w-full grow', fullScreen ? 'flex-row justify-center px-6 pt-6' : 'px-4 py-3')}>
<div className={classNames('h-full overflow-hidden whitespace-pre-line break-all', fullScreen ? 'w-1/2' : 'w-full')}>
<ChunkContent
docForm={docForm}
question={content}

@ -6,7 +6,7 @@ import {
RiClipboardLine,
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { ClipboardCheck } from '../../base/icons/src/vender/line/files'
import { CopyCheck } from '../../base/icons/src/vender/line/files'
import Tooltip from '../../base/tooltip'
import cn from '@/utils/classnames'
import ActionButton from '@/app/components/base/action-button'
@ -44,7 +44,7 @@ const KeyValueItem: FC<Props> = ({
}
}, [isCopied])
const CopyIcon = isCopied ? ClipboardCheck : RiClipboardLine
const CopyIcon = isCopied ? CopyCheck : RiClipboardLine
return (
<div className='flex items-center gap-1'>

@ -0,0 +1,50 @@
import {
memo,
useState,
} from 'react'
import Button from '@/app/components/base/button'
import type { ButtonProps } from '@/app/components/base/button'
import ApiKeyModal from './api-key-modal'
import type { PluginPayload } from '../types'
export type AddApiKeyButtonProps = {
pluginPayload: PluginPayload
buttonVariant?: ButtonProps['variant']
buttonText?: string
disabled?: boolean
onUpdate?: () => void
}
const AddApiKeyButton = ({
pluginPayload,
buttonVariant = 'secondary-accent',
buttonText = 'use api key',
disabled,
onUpdate,
}: AddApiKeyButtonProps) => {
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false)
return (
<>
<Button
className='w-full'
variant={buttonVariant}
onClick={() => setIsApiKeyModalOpen(true)}
disabled={disabled}
>
{buttonText}
</Button>
{
isApiKeyModalOpen && (
<ApiKeyModal
pluginPayload={pluginPayload}
onClose={() => setIsApiKeyModalOpen(false)}
onUpdate={onUpdate}
/>
)
}
</>
)
}
export default memo(AddApiKeyButton)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save