Merge branch 'main' into feat/dependency-relations
# Conflicts: # web/app/components/workflow/nodes/_base/components/workflow-panel/index.tsxpull/21998/head
commit
02de417615
@ -0,0 +1,84 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCache(ABC):
|
||||
"""Base class for provider credentials cache"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.cache_key = self._generate_cache_key(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
"""Generate cache key based on subclass implementation"""
|
||||
pass
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
cached_credentials = redis_client.get(self.cache_key)
|
||||
if cached_credentials:
|
||||
try:
|
||||
cached_credentials = cached_credentials.decode("utf-8")
|
||||
return dict(json.loads(cached_credentials))
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
redis_client.delete(self.cache_key)
|
||||
|
||||
|
||||
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool single provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
|
||||
super().__init__(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=provider_type,
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
identity_name = kwargs["provider_identity"]
|
||||
identity_id = f"{provider_type}.{identity_name}"
|
||||
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
pass
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
pass
|
||||
@ -1,51 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||
if cached_provider_credentials:
|
||||
try:
|
||||
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
|
||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
return dict(cached_provider_credentials)
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, credentials: dict) -> None:
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
redis_client.delete(self.cache_key)
|
||||
@ -0,0 +1,142 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
encrypt = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_config_cache=cache,
|
||||
)
|
||||
return encrypt, cache
|
||||
@ -0,0 +1,187 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Raises:
|
||||
ValueError: If SECRET_KEY is not configured or empty
|
||||
"""
|
||||
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate random IV (16 bytes)
|
||||
iv = get_random_bytes(16)
|
||||
|
||||
# Create AES cipher (CBC mode)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
combined = iv + encrypted_data
|
||||
|
||||
# Return base64 encoded string
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
raise ValueError("encrypted_data must be a string")
|
||||
|
||||
if not encrypted_data:
|
||||
raise ValueError("encrypted_data cannot be empty")
|
||||
|
||||
try:
|
||||
# Base64 decode
|
||||
combined = base64.b64decode(encrypted_data)
|
||||
|
||||
# Check minimum length (IV + at least one AES block)
|
||||
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||
raise ValueError("Invalid encrypted data format")
|
||||
|
||||
# Separate IV and encrypted data
|
||||
iv = combined[:16]
|
||||
encrypted_data_bytes = combined[16:]
|
||||
|
||||
# Create AES cipher
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Decrypt data
|
||||
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
@ -0,0 +1,41 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 16081485540c
|
||||
Revises: d28f2004b072
|
||||
Create Date: 2025-05-15 16:35:39.113777
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '16081485540c'
|
||||
down_revision = '2adcbe1f5dfb'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_plugin_auto_upgrade_strategies',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
|
||||
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
|
||||
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
|
||||
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tenant_plugin_auto_upgrade_strategies')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,62 @@
|
||||
"""tool oauth
|
||||
|
||||
Revision ID: 71f5020c6470
|
||||
Revises: 4474872b0ee6
|
||||
Create Date: 2025-06-24 17:05:43.118647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '71f5020c6470'
|
||||
down_revision = '1c9ba48be8e4'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tool_oauth_system_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
|
||||
)
|
||||
op.create_table('tool_oauth_tenant_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
|
||||
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
|
||||
batch_op.drop_column('credential_type')
|
||||
batch_op.drop_column('is_default')
|
||||
batch_op.drop_column('name')
|
||||
|
||||
op.drop_table('tool_oauth_tenant_clients')
|
||||
op.drop_table('tool_oauth_system_clients')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,496 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import (
|
||||
OAuthCallback,
|
||||
OAuthLogin,
|
||||
_generate_account,
|
||||
_get_account_by_openid_or_email,
|
||||
get_oauth_providers,
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from models.account import AccountStatus
|
||||
from services.errors.account import AccountNotFoundError
|
||||
|
||||
|
||||
class TestGetOAuthProviders:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("github_config", "google_config", "expected_github", "expected_google"),
|
||||
[
|
||||
# Both providers configured
|
||||
(
|
||||
{"id": "github_id", "secret": "github_secret"},
|
||||
{"id": "google_id", "secret": "google_secret"},
|
||||
True,
|
||||
True,
|
||||
),
|
||||
# Only GitHub configured
|
||||
({"id": "github_id", "secret": "github_secret"}, {"id": None, "secret": None}, True, False),
|
||||
# Only Google configured
|
||||
({"id": None, "secret": None}, {"id": "google_id", "secret": "google_secret"}, False, True),
|
||||
# No providers configured
|
||||
({"id": None, "secret": None}, {"id": None, "secret": None}, False, False),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
def test_should_configure_oauth_providers_correctly(
|
||||
self, mock_config, app, github_config, google_config, expected_github, expected_google
|
||||
):
|
||||
mock_config.GITHUB_CLIENT_ID = github_config["id"]
|
||||
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
|
||||
mock_config.GOOGLE_CLIENT_ID = google_config["id"]
|
||||
mock_config.GOOGLE_CLIENT_SECRET = google_config["secret"]
|
||||
mock_config.CONSOLE_API_URL = "http://localhost"
|
||||
|
||||
with app.app_context():
|
||||
providers = get_oauth_providers()
|
||||
|
||||
assert (providers["github"] is not None) == expected_github
|
||||
assert (providers["google"] is not None) == expected_google
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
@pytest.fixture
|
||||
def resource(self):
|
||||
return OAuthLogin()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def mock_oauth_provider(self):
|
||||
provider = MagicMock()
|
||||
provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?..."
|
||||
return provider
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_token"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_handle_oauth_login_with_various_tokens(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app,
|
||||
mock_oauth_provider,
|
||||
invite_token,
|
||||
expected_token,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
query_string = f"invite_token={invite_token}" if invite_token else ""
|
||||
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("provider", "expected_error"),
|
||||
[
|
||||
("invalid_provider", "Invalid provider"),
|
||||
("github", "Invalid provider"), # When GitHub is not configured
|
||||
("google", "Invalid provider"), # When Google is not configured
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_return_error_for_invalid_providers(
|
||||
self, mock_get_providers, resource, app, provider, expected_error
|
||||
):
|
||||
mock_get_providers.return_value = {"github": None, "google": None}
|
||||
|
||||
with app.test_request_context(f"/auth/oauth/{provider}"):
|
||||
response, status_code = resource.get(provider)
|
||||
|
||||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
|
||||
class TestOAuthCallback:
|
||||
@pytest.fixture
|
||||
def resource(self):
|
||||
return OAuthCallback()
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_setup(self):
|
||||
"""Common OAuth setup for callback tests"""
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_access_token.return_value = "access_token"
|
||||
oauth_provider.get_user_info.return_value = OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
account = MagicMock()
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
|
||||
token_pair = MagicMock()
|
||||
token_pair.access_token = "jwt_access_token"
|
||||
token_pair.refresh_token = "jwt_refresh_token"
|
||||
|
||||
return {"provider": oauth_provider, "account": account, "token_pair": token_pair}
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_handle_successful_oauth_callback(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
mock_generate_account.return_value = oauth_setup["account"]
|
||||
mock_account_service.login.return_value = oauth_setup["token_pair"]
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
|
||||
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
|
||||
mock_redirect.assert_called_once_with(
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "expected_error"),
|
||||
[
|
||||
(Exception("OAuth error"), "OAuth process failed"),
|
||||
(ValueError("Invalid token"), "OAuth process failed"),
|
||||
(KeyError("Missing key"), "OAuth process failed"),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_should_handle_oauth_exceptions(
|
||||
self, mock_get_providers, mock_db, resource, app, exception, expected_error
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
|
||||
# Import the real requests module to create a proper exception
|
||||
import requests
|
||||
|
||||
request_exception = requests.exceptions.RequestException("OAuth error")
|
||||
request_exception.response = MagicMock()
|
||||
request_exception.response.text = str(exception)
|
||||
|
||||
mock_oauth_provider = MagicMock()
|
||||
mock_oauth_provider.get_access_token.side_effect = request_exception
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
response, status_code = resource.get("github")
|
||||
|
||||
assert status_code == 400
|
||||
assert response["error"] == expected_error
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("account_status", "expected_redirect"),
|
||||
[
|
||||
(AccountStatus.BANNED.value, "http://localhost:3000/signin?message=Account is banned."),
|
||||
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||
(
|
||||
AccountStatus.CLOSED.value,
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_redirect_based_on_account_status(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
account_status,
|
||||
expected_redirect,
|
||||
):
|
||||
# Mock database session
|
||||
mock_db.session = MagicMock()
|
||||
mock_db.session.rollback = MagicMock()
|
||||
mock_db.session.commit = MagicMock()
|
||||
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
account = MagicMock()
|
||||
account.status = account_status
|
||||
account.id = "123"
|
||||
mock_generate_account.return_value = account
|
||||
|
||||
# Mock login for CLOSED status
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
mock_redirect.assert_called_once_with(expected_redirect)
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
def test_should_activate_pending_account(
|
||||
self,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = AccountStatus.PENDING.value
|
||||
mock_generate_account.return_value = mock_account
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
assert mock_account.status == AccountStatus.ACTIVE.value
|
||||
assert mock_account.initialized_at is not None
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth._generate_account")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_defensive_check_for_closed_account_status(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_account_service,
|
||||
mock_tenant_service,
|
||||
mock_db,
|
||||
mock_generate_account,
|
||||
mock_get_providers,
|
||||
mock_config,
|
||||
resource,
|
||||
app,
|
||||
oauth_setup,
|
||||
):
|
||||
"""Defensive test for CLOSED account status handling in OAuth callback.
|
||||
|
||||
This is a defensive test documenting expected security behavior for CLOSED accounts.
|
||||
|
||||
Current behavior: CLOSED status is NOT checked, allowing closed accounts to login.
|
||||
Expected behavior: CLOSED accounts should be rejected like BANNED accounts.
|
||||
|
||||
Context:
|
||||
- AccountStatus.CLOSED is defined in the enum but never used in production
|
||||
- The close_account() method exists but is never called
|
||||
- Account deletion uses external service instead of status change
|
||||
- All authentication services (OAuth, password, email) don't check CLOSED status
|
||||
|
||||
TODO: If CLOSED status is implemented in the future:
|
||||
1. Update OAuth callback to check for CLOSED status
|
||||
2. Add similar checks to all authentication services for consistency
|
||||
3. Update this test to verify the rejection behavior
|
||||
|
||||
Security consideration: Until properly implemented, CLOSED status provides no protection.
|
||||
"""
|
||||
# Setup
|
||||
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
|
||||
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
|
||||
|
||||
# Create account with CLOSED status
|
||||
closed_account = MagicMock()
|
||||
closed_account.status = AccountStatus.CLOSED.value
|
||||
closed_account.id = "123"
|
||||
closed_account.name = "Closed Account"
|
||||
mock_generate_account.return_value = closed_account
|
||||
|
||||
# Mock successful login (current behavior)
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
# Execute OAuth callback
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
# Verify current behavior: login succeeds (this is NOT ideal)
|
||||
mock_redirect.assert_called_once_with(
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||
)
|
||||
mock_account_service.login.assert_called_once()
|
||||
|
||||
# Document expected behavior in comments:
|
||||
# Expected: mock_redirect.assert_called_once_with(
|
||||
# "http://localhost:3000/signin?message=Account is closed."
|
||||
# )
|
||||
# Expected: mock_account_service.login.assert_not_called()
|
||||
|
||||
|
||||
class TestAccountGeneration:
|
||||
@pytest.fixture
|
||||
def user_info(self):
|
||||
return OAuthUserInfo(id="123", name="Test User", email="test@example.com")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_account(self):
|
||||
account = MagicMock()
|
||||
account.name = "Test User"
|
||||
return account
|
||||
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
@patch("controllers.console.auth.oauth.Account")
|
||||
@patch("controllers.console.auth.oauth.Session")
|
||||
@patch("controllers.console.auth.oauth.select")
|
||||
def test_should_get_account_by_openid_or_email(
|
||||
self, mock_select, mock_session, mock_account_model, mock_db, user_info, mock_account
|
||||
):
|
||||
# Mock db.engine for Session creation
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Test OpenID found
|
||||
mock_account_model.get_by_openid.return_value = mock_account
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
mock_account_model.get_by_openid.assert_called_once_with("github", "123")
|
||||
|
||||
# Test fallback to email
|
||||
mock_account_model.get_by_openid.return_value = None
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.execute.return_value.scalar_one_or_none.return_value = mock_account
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
|
||||
result = _get_account_by_openid_or_email("github", user_info)
|
||||
assert result == mock_account
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("allow_register", "existing_account", "should_create"),
|
||||
[
|
||||
(True, None, True), # New account creation allowed
|
||||
(True, "existing", False), # Existing account
|
||||
(False, None, False), # Registration not allowed
|
||||
],
|
||||
)
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.db")
|
||||
def test_should_handle_account_generation_scenarios(
|
||||
self,
|
||||
mock_db,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
user_info,
|
||||
mock_account,
|
||||
allow_register,
|
||||
existing_account,
|
||||
should_create,
|
||||
):
|
||||
mock_get_account.return_value = mock_account if existing_account else None
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = allow_register
|
||||
mock_register_service.register.return_value = mock_account
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
if not allow_register and not existing_account:
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
_generate_account("github", user_info)
|
||||
else:
|
||||
result = _generate_account("github", user_info)
|
||||
assert result == mock_account
|
||||
|
||||
if should_create:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.tenant_was_created")
|
||||
def test_should_create_workspace_for_account_without_tenant(
|
||||
self,
|
||||
mock_event,
|
||||
mock_account_service,
|
||||
mock_feature_service,
|
||||
mock_tenant_service,
|
||||
mock_get_account,
|
||||
app,
|
||||
user_info,
|
||||
mock_account,
|
||||
):
|
||||
mock_get_account.return_value = mock_account
|
||||
mock_tenant_service.get_join_tenants.return_value = []
|
||||
mock_feature_service.get_system_features.return_value.is_allow_create_workspace = True
|
||||
|
||||
mock_new_tenant = MagicMock()
|
||||
mock_tenant_service.create_tenant.return_value = mock_new_tenant
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
result = _generate_account("github", user_info)
|
||||
|
||||
assert result == mock_account
|
||||
mock_tenant_service.create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_tenant_service.create_tenant_member.assert_called_once_with(
|
||||
mock_new_tenant, mock_account, role="owner"
|
||||
)
|
||||
mock_event.send.assert_called_once_with(mock_new_tenant)
|
||||
@ -0,0 +1,249 @@
|
||||
import urllib.parse
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
|
||||
|
||||
class BaseOAuthTest:
|
||||
"""Base class for OAuth provider tests with common fixtures"""
|
||||
|
||||
@pytest.fixture
|
||||
def oauth_config(self):
|
||||
return {
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
"redirect_uri": "http://localhost/callback",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response(self):
|
||||
response = MagicMock()
|
||||
response.json.return_value = {}
|
||||
return response
|
||||
|
||||
def parse_auth_url(self, url):
|
||||
"""Helper to parse authorization URL"""
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
return parsed, params
|
||||
|
||||
|
||||
class TestGitHubOAuth(BaseOAuthTest):
|
||||
@pytest.fixture
|
||||
def oauth(self, oauth_config):
|
||||
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "github.com"
|
||||
assert parsed.path == "/login/oauth/authorize"
|
||||
assert params["client_id"][0] == oauth_config["client_id"]
|
||||
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||
assert params["scope"][0] == "user:email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("response_data", "expected_token", "should_raise"),
|
||||
[
|
||||
({"access_token": "test_token"}, "test_token", False),
|
||||
({"error": "invalid_grant"}, None, True),
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
mock_response.json.return_value = response_data
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth.get_access_token("test_code")
|
||||
assert "Error in GitHub OAuth" in str(exc_info.value)
|
||||
else:
|
||||
token = oauth.get_access_token("test_code")
|
||||
assert token == expected_token
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_data", "email_data", "expected_email"),
|
||||
[
|
||||
# User with primary email
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[
|
||||
{"email": "secondary@example.com", "primary": False},
|
||||
{"email": "primary@example.com", "primary": True},
|
||||
],
|
||||
"primary@example.com",
|
||||
),
|
||||
# User with no emails - fallback to noreply
|
||||
({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"),
|
||||
# User with only secondary email - fallback to noreply
|
||||
(
|
||||
{"id": 12345, "login": "testuser", "name": "Test User"},
|
||||
[{"email": "secondary@example.com", "primary": False}],
|
||||
"12345+testuser@users.noreply.github.com",
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
|
||||
email_response = MagicMock()
|
||||
email_response.json.return_value = email_data
|
||||
|
||||
mock_get.side_effect = [user_response, email_response]
|
||||
|
||||
user_info = oauth.get_user_info("test_token")
|
||||
|
||||
assert user_info.id == str(user_data["id"])
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("requests.get")
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = requests.exceptions.RequestException("Network error")
|
||||
|
||||
with pytest.raises(requests.exceptions.RequestException):
|
||||
oauth.get_raw_user_info("test_token")
|
||||
|
||||
|
||||
class TestGoogleOAuth(BaseOAuthTest):
|
||||
@pytest.fixture
|
||||
def oauth(self, oauth_config):
|
||||
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
assert parsed.netloc == "accounts.google.com"
|
||||
assert parsed.path == "/o/oauth2/v2/auth"
|
||||
assert params["client_id"][0] == oauth_config["client_id"]
|
||||
assert params["redirect_uri"][0] == oauth_config["redirect_uri"]
|
||||
assert params["response_type"][0] == "code"
|
||||
assert params["scope"][0] == "openid email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("response_data", "expected_token", "should_raise"),
|
||||
[
|
||||
({"access_token": "test_token"}, "test_token", False),
|
||||
({"error": "invalid_grant"}, None, True),
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("requests.post")
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
mock_response.json.return_value = response_data
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
oauth.get_access_token("test_code")
|
||||
assert "Error in Google OAuth" in str(exc_info.value)
|
||||
else:
|
||||
token = oauth.get_access_token("test_code")
|
||||
assert token == expected_token
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
oauth._TOKEN_URL,
|
||||
data={
|
||||
"client_id": oauth_config["client_id"],
|
||||
"client_secret": oauth_config["client_secret"],
|
||||
"code": "test_code",
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": oauth_config["redirect_uri"],
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_data", "expected_name"),
|
||||
[
|
||||
({"sub": "123", "email": "test@example.com", "email_verified": True}, ""),
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
user_info = oauth.get_user_info("test_token")
|
||||
|
||||
assert user_info.id == user_data["sub"]
|
||||
assert user_info.name == expected_name
|
||||
assert user_info.email == user_data["email"]
|
||||
|
||||
mock_get.assert_called_once_with(oauth._USER_INFO_URL, headers={"Authorization": "Bearer test_token"})
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exception_type",
|
||||
[
|
||||
requests.exceptions.HTTPError,
|
||||
requests.exceptions.ConnectionError,
|
||||
requests.exceptions.Timeout,
|
||||
],
|
||||
)
|
||||
@patch("requests.get")
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with pytest.raises(exception_type):
|
||||
oauth.get_raw_user_info("invalid_token")
|
||||
|
||||
|
||||
class TestOAuthUserInfo:
|
||||
@pytest.mark.parametrize(
|
||||
"user_data",
|
||||
[
|
||||
{"id": "123", "name": "Test User", "email": "test@example.com"},
|
||||
{"id": "456", "name": "", "email": "user@domain.com"},
|
||||
{"id": "789", "name": "Another User", "email": "another@test.org"},
|
||||
],
|
||||
)
|
||||
def test_should_create_user_info_dataclass(self, user_data):
|
||||
user_info = OAuthUserInfo(**user_data)
|
||||
|
||||
assert user_info.id == user_data["id"]
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == user_data["email"]
|
||||
@ -0,0 +1,619 @@
|
||||
import base64
|
||||
import hashlib
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad
|
||||
|
||||
from core.tools.utils.system_oauth_encryption import (
|
||||
OAuthEncryptionError,
|
||||
SystemOAuthEncrypter,
|
||||
create_system_oauth_encrypter,
|
||||
decrypt_system_oauth_params,
|
||||
encrypt_system_oauth_params,
|
||||
get_system_oauth_encrypter,
|
||||
)
|
||||
|
||||
|
||||
class TestSystemOAuthEncrypter:
|
||||
"""Test cases for SystemOAuthEncrypter class"""
|
||||
|
||||
def test_init_with_secret_key(self):
|
||||
"""Test initialization with provided secret key"""
|
||||
secret_key = "test_secret_key"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=secret_key)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_none_secret_key(self):
|
||||
"""Test initialization with None secret key falls back to config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = SystemOAuthEncrypter(secret_key=None)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_with_empty_secret_key(self):
|
||||
"""Test initialization with empty secret key"""
|
||||
encrypter = SystemOAuthEncrypter(secret_key="")
|
||||
expected_key = hashlib.sha256(b"").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_init_without_secret_key_uses_config(self):
|
||||
"""Test initialization without secret key uses config"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "default_secret"
|
||||
encrypter = SystemOAuthEncrypter()
|
||||
expected_key = hashlib.sha256(b"default_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_encrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters encryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
# Should be valid base64
|
||||
try:
|
||||
base64.b64decode(encrypted)
|
||||
except Exception:
|
||||
pytest.fail("Encrypted result is not valid base64")
|
||||
|
||||
def test_encrypt_oauth_params_empty_dict(self):
|
||||
"""Test encryption with empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_complex_data(self):
|
||||
"""Test encryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_unicode_data(self):
|
||||
"""Test encryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret", "description": "This is a test case 🚀"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_large_data(self):
|
||||
"""Test encryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_encrypt_oauth_params_invalid_input(self):
|
||||
"""Test encryption with invalid input types"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params(None) # type: ignore
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
|
||||
|
||||
def test_decrypt_oauth_params_basic(self):
|
||||
"""Test basic OAuth parameters decryption"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_empty_dict(self):
|
||||
"""Test decryption of empty dictionary"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_complex_data(self):
|
||||
"""Test decryption with complex data structures"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"metadata": {"issuer": "test_issuer", "expires_in": 3600, "is_active": True},
|
||||
"numeric_value": 42,
|
||||
"boolean_value": False,
|
||||
"null_value": None,
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_unicode_data(self):
|
||||
"""Test decryption with unicode data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"client_secret": "test_secret",
|
||||
"description": "This is a test case 🚀",
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_large_data(self):
|
||||
"""Test decryption with large data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
original_params = {
|
||||
"client_id": "test_id",
|
||||
"large_data": "x" * 10000, # 10KB of data
|
||||
}
|
||||
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == original_params
|
||||
|
||||
def test_decrypt_oauth_params_invalid_base64(self):
|
||||
"""Test decryption with invalid base64 data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params("invalid_base64!")
|
||||
|
||||
def test_decrypt_oauth_params_empty_string(self):
|
||||
"""Test decryption with empty string"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_non_string_input(self):
|
||||
"""Test decryption with non-string input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123) # type: ignore
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(None) # type: ignore
|
||||
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_too_short_data(self):
|
||||
"""Test decryption with too short encrypted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create data that's too short (less than 32 bytes)
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_corrupted_data(self):
|
||||
"""Test decryption with corrupted data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create corrupted data (valid base64 but invalid encrypted content)
|
||||
corrupted_data = base64.b64encode(b"x" * 48).decode() # 48 bytes of garbage
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(corrupted_data)
|
||||
|
||||
def test_decrypt_oauth_params_wrong_key(self):
|
||||
"""Test decryption with wrong key"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
|
||||
original_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
encrypted = encrypter1.encrypt_oauth_params(original_params)
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter2.decrypt_oauth_params(encrypted)
|
||||
|
||||
def test_encryption_decryption_consistency(self):
|
||||
"""Test that encryption and decryption are consistent"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
{"array": [1, 2, 3, "four", {"five": 5}]},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypter.encrypt_oauth_params(original_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_encryption_randomness(self):
|
||||
"""Test that encryption produces different results for same input"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
# Should be different due to random IV
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But should decrypt to same result
|
||||
decrypted1 = encrypter.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter.decrypt_oauth_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
def test_different_secret_keys_produce_different_results(self):
|
||||
"""Test that different secret keys produce different encrypted results"""
|
||||
encrypter1 = SystemOAuthEncrypter("secret1")
|
||||
encrypter2 = SystemOAuthEncrypter("secret2")
|
||||
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted1 = encrypter1.encrypt_oauth_params(oauth_params)
|
||||
encrypted2 = encrypter2.encrypt_oauth_params(oauth_params)
|
||||
|
||||
# Should produce different encrypted results
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
# But each should decrypt correctly with its own key
|
||||
decrypted1 = encrypter1.decrypt_oauth_params(encrypted1)
|
||||
decrypted2 = encrypter2.decrypt_oauth_params(encrypted2)
|
||||
assert decrypted1 == decrypted2 == oauth_params
|
||||
|
||||
@patch("core.tools.utils.system_oauth_encryption.get_random_bytes")
|
||||
def test_encrypt_oauth_params_crypto_error(self, mock_get_random_bytes):
|
||||
"""Test encryption when crypto operation fails"""
|
||||
mock_get_random_bytes.side_effect = Exception("Crypto error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.tools.utils.system_oauth_encryption.TypeAdapter")
|
||||
def test_encrypt_oauth_params_serialization_error(self, mock_type_adapter):
|
||||
"""Test encryption when JSON serialization fails"""
|
||||
mock_type_adapter.return_value.dump_json.side_effect = Exception("Serialization error")
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.encrypt_oauth_params(oauth_params)
|
||||
|
||||
assert "Encryption failed" in str(exc_info.value)
|
||||
|
||||
def test_decrypt_oauth_params_invalid_json(self):
|
||||
"""Test decryption with invalid JSON data"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Create valid encrypted data but with invalid JSON content
|
||||
iv = get_random_bytes(16)
|
||||
cipher = AES.new(encrypter.key, AES.MODE_CBC, iv)
|
||||
invalid_json = b"invalid json content"
|
||||
padded_data = pad(invalid_json, AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
combined = iv + encrypted_data
|
||||
encoded = base64.b64encode(combined).decode()
|
||||
|
||||
with pytest.raises(OAuthEncryptionError):
|
||||
encrypter.decrypt_oauth_params(encoded)
|
||||
|
||||
def test_key_derivation_consistency(self):
|
||||
"""Test that key derivation is consistent"""
|
||||
secret_key = "test_secret"
|
||||
encrypter1 = SystemOAuthEncrypter(secret_key)
|
||||
encrypter2 = SystemOAuthEncrypter(secret_key)
|
||||
|
||||
assert encrypter1.key == encrypter2.key
|
||||
|
||||
# Keys should be 32 bytes (256 bits)
|
||||
assert len(encrypter1.key) == 32
|
||||
|
||||
|
||||
class TestFactoryFunctions:
|
||||
"""Test cases for factory functions"""
|
||||
|
||||
def test_create_system_oauth_encrypter_with_secret(self):
|
||||
"""Test factory function with secret key"""
|
||||
secret_key = "test_secret"
|
||||
encrypter = create_system_oauth_encrypter(secret_key)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(secret_key.encode()).digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_oauth_encrypter_without_secret(self):
|
||||
"""Test factory function without secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter()
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
def test_create_system_oauth_encrypter_with_none_secret(self):
|
||||
"""Test factory function with None secret key"""
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "config_secret"
|
||||
encrypter = create_system_oauth_encrypter(None)
|
||||
|
||||
assert isinstance(encrypter, SystemOAuthEncrypter)
|
||||
expected_key = hashlib.sha256(b"config_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestGlobalEncrypterInstance:
|
||||
"""Test cases for global encrypter instance"""
|
||||
|
||||
def test_get_system_oauth_encrypter_singleton(self):
|
||||
"""Test that get_system_oauth_encrypter returns singleton instance"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
|
||||
encrypter1 = get_system_oauth_encrypter()
|
||||
encrypter2 = get_system_oauth_encrypter()
|
||||
|
||||
assert encrypter1 is encrypter2
|
||||
assert isinstance(encrypter1, SystemOAuthEncrypter)
|
||||
|
||||
def test_get_system_oauth_encrypter_uses_config(self):
|
||||
"""Test that global encrypter uses config"""
|
||||
# Clear the global instance first
|
||||
import core.tools.utils.system_oauth_encryption
|
||||
|
||||
core.tools.utils.system_oauth_encryption._oauth_encrypter = None
|
||||
|
||||
with patch("core.tools.utils.system_oauth_encryption.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "global_secret"
|
||||
encrypter = get_system_oauth_encrypter()
|
||||
|
||||
expected_key = hashlib.sha256(b"global_secret").digest()
|
||||
assert encrypter.key == expected_key
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
|
||||
"""Test cases for convenience functions"""
|
||||
|
||||
def test_encrypt_system_oauth_params(self):
|
||||
"""Test encrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
|
||||
assert isinstance(encrypted, str)
|
||||
assert len(encrypted) > 0
|
||||
|
||||
def test_decrypt_system_oauth_params(self):
|
||||
"""Test decrypt_system_oauth_params convenience function"""
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
encrypted = encrypt_system_oauth_params(oauth_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_convenience_functions_consistency(self):
|
||||
"""Test that convenience functions work consistently"""
|
||||
test_cases = [
|
||||
{},
|
||||
{"simple": "value"},
|
||||
{"client_id": "id", "client_secret": "secret"},
|
||||
{"complex": {"nested": {"deep": "value"}}},
|
||||
{"unicode": "test 🚀"},
|
||||
{"numbers": 42, "boolean": True, "null": None},
|
||||
]
|
||||
|
||||
for original_params in test_cases:
|
||||
encrypted = encrypt_system_oauth_params(original_params)
|
||||
decrypted = decrypt_system_oauth_params(encrypted)
|
||||
assert decrypted == original_params, f"Failed for case: {original_params}"
|
||||
|
||||
def test_convenience_functions_with_errors(self):
|
||||
"""Test convenience functions with error conditions"""
|
||||
# Test encryption with invalid input
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
encrypt_system_oauth_params(None) # type: ignore
|
||||
|
||||
# Test decryption with invalid input
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params("")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
decrypt_system_oauth_params(None) # type: ignore
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test cases for error handling"""
|
||||
|
||||
def test_oauth_encryption_error_inheritance(self):
|
||||
"""Test that OAuthEncryptionError is a proper exception"""
|
||||
error = OAuthEncryptionError("Test error")
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Test error"
|
||||
|
||||
def test_oauth_encryption_error_with_cause(self):
|
||||
"""Test OAuthEncryptionError with cause"""
|
||||
original_error = ValueError("Original error")
|
||||
error = OAuthEncryptionError("Wrapper error")
|
||||
error.__cause__ = original_error
|
||||
|
||||
assert isinstance(error, Exception)
|
||||
assert str(error) == "Wrapper error"
|
||||
assert error.__cause__ is original_error
|
||||
|
||||
def test_error_messages_are_informative(self):
|
||||
"""Test that error messages are informative"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
|
||||
# Test empty string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params("")
|
||||
assert "encrypted_data cannot be empty" in str(exc_info.value)
|
||||
|
||||
# Test non-string error
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(123) # type: ignore
|
||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||
|
||||
# Test invalid format error
|
||||
short_data = base64.b64encode(b"short").decode()
|
||||
with pytest.raises(OAuthEncryptionError) as exc_info:
|
||||
encrypter.decrypt_oauth_params(short_data)
|
||||
assert "Invalid encrypted data format" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test cases for edge cases and boundary conditions"""
|
||||
|
||||
def test_very_long_secret_key(self):
|
||||
"""Test with very long secret key"""
|
||||
long_secret = "x" * 10000
|
||||
encrypter = SystemOAuthEncrypter(long_secret)
|
||||
|
||||
# Key should still be 32 bytes due to SHA-256
|
||||
assert len(encrypter.key) == 32
|
||||
|
||||
# Should still work normally
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_special_characters_in_secret_key(self):
|
||||
"""Test with special characters in secret key"""
|
||||
special_secret = "!@#$%^&*()_+-=[]{}|;':\",./<>?`~test🚀"
|
||||
encrypter = SystemOAuthEncrypter(special_secret)
|
||||
|
||||
oauth_params = {"client_id": "test_id"}
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_empty_values_in_oauth_params(self):
|
||||
"""Test with empty values in oauth params"""
|
||||
oauth_params = {
|
||||
"client_id": "",
|
||||
"client_secret": "",
|
||||
"empty_dict": {},
|
||||
"empty_list": [],
|
||||
"empty_string": "",
|
||||
"zero": 0,
|
||||
"false": False,
|
||||
"none": None,
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_deeply_nested_oauth_params(self):
|
||||
"""Test with deeply nested oauth params"""
|
||||
oauth_params = {"level1": {"level2": {"level3": {"level4": {"level5": {"deep_value": "found"}}}}}}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_oauth_params_with_all_json_types(self):
|
||||
"""Test with all JSON-supported data types"""
|
||||
oauth_params = {
|
||||
"string": "test_string",
|
||||
"integer": 42,
|
||||
"float": 3.14159,
|
||||
"boolean_true": True,
|
||||
"boolean_false": False,
|
||||
"null_value": None,
|
||||
"empty_string": "",
|
||||
"array": [1, "two", 3.0, True, False, None],
|
||||
"object": {"nested_string": "nested_value", "nested_number": 123, "nested_bool": True},
|
||||
}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""Test cases for performance considerations"""
|
||||
|
||||
def test_large_oauth_params(self):
|
||||
"""Test with large oauth params"""
|
||||
large_value = "x" * 100000 # 100KB
|
||||
oauth_params = {"client_id": "test_id", "large_data": large_value}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_many_fields_oauth_params(self):
|
||||
"""Test with many fields in oauth params"""
|
||||
oauth_params = {f"field_{i}": f"value_{i}" for i in range(1000)}
|
||||
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
|
||||
def test_repeated_encryption_decryption(self):
|
||||
"""Test repeated encryption and decryption operations"""
|
||||
encrypter = SystemOAuthEncrypter("test_secret")
|
||||
oauth_params = {"client_id": "test_id", "client_secret": "test_secret"}
|
||||
|
||||
# Test multiple rounds of encryption/decryption
|
||||
for i in range(100):
|
||||
encrypted = encrypter.encrypt_oauth_params(oauth_params)
|
||||
decrypted = encrypter.decrypt_oauth_params(encrypted)
|
||||
assert decrypted == oauth_params
|
||||
@ -1 +1 @@
|
||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
||||
from dify_client.client import ChatClient, CompletionClient, WorkflowClient, KnowledgeBaseClient, DifyClient
|
||||
|
||||
@ -0,0 +1,177 @@
|
||||
import {
|
||||
isValidElement,
|
||||
memo,
|
||||
useMemo,
|
||||
} from 'react'
|
||||
import type { AnyFieldApi } from '@tanstack/react-form'
|
||||
import { useStore } from '@tanstack/react-form'
|
||||
import cn from '@/utils/classnames'
|
||||
import Input from '@/app/components/base/input'
|
||||
import PureSelect from '@/app/components/base/select/pure'
|
||||
import type { FormSchema } from '@/app/components/base/form/types'
|
||||
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||
import { useRenderI18nObject } from '@/hooks/use-i18n'
|
||||
|
||||
export type BaseFieldProps = {
|
||||
fieldClassName?: string
|
||||
labelClassName?: string
|
||||
inputContainerClassName?: string
|
||||
inputClassName?: string
|
||||
formSchema: FormSchema
|
||||
field: AnyFieldApi
|
||||
disabled?: boolean
|
||||
}
|
||||
const BaseField = ({
|
||||
fieldClassName,
|
||||
labelClassName,
|
||||
inputContainerClassName,
|
||||
inputClassName,
|
||||
formSchema,
|
||||
field,
|
||||
disabled,
|
||||
}: BaseFieldProps) => {
|
||||
const renderI18nObject = useRenderI18nObject()
|
||||
const {
|
||||
label,
|
||||
required,
|
||||
placeholder,
|
||||
options,
|
||||
labelClassName: formLabelClassName,
|
||||
show_on = [],
|
||||
} = formSchema
|
||||
|
||||
const memorizedLabel = useMemo(() => {
|
||||
if (isValidElement(label))
|
||||
return label
|
||||
|
||||
if (typeof label === 'string')
|
||||
return label
|
||||
|
||||
if (typeof label === 'object' && label !== null)
|
||||
return renderI18nObject(label as Record<string, string>)
|
||||
}, [label, renderI18nObject])
|
||||
const memorizedPlaceholder = useMemo(() => {
|
||||
if (typeof placeholder === 'string')
|
||||
return placeholder
|
||||
|
||||
if (typeof placeholder === 'object' && placeholder !== null)
|
||||
return renderI18nObject(placeholder as Record<string, string>)
|
||||
}, [placeholder, renderI18nObject])
|
||||
const memorizedOptions = useMemo(() => {
|
||||
return options?.map((option) => {
|
||||
return {
|
||||
label: typeof option.label === 'string' ? option.label : renderI18nObject(option.label),
|
||||
value: option.value,
|
||||
}
|
||||
}) || []
|
||||
}, [options, renderI18nObject])
|
||||
const value = useStore(field.form.store, s => s.values[field.name])
|
||||
const values = useStore(field.form.store, (s) => {
|
||||
return show_on.reduce((acc, condition) => {
|
||||
acc[condition.variable] = s.values[condition.variable]
|
||||
return acc
|
||||
}, {} as Record<string, any>)
|
||||
})
|
||||
const show = useMemo(() => {
|
||||
return show_on.every((condition) => {
|
||||
const conditionValue = values[condition.variable]
|
||||
return conditionValue === condition.value
|
||||
})
|
||||
}, [values, show_on])
|
||||
|
||||
if (!show)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className={cn(fieldClassName)}>
|
||||
<div className={cn(labelClassName, formLabelClassName)}>
|
||||
{memorizedLabel}
|
||||
{
|
||||
required && !isValidElement(label) && (
|
||||
<span className='ml-1 text-text-destructive-secondary'>*</span>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
<div className={cn(inputContainerClassName)}>
|
||||
{
|
||||
formSchema.type === FormTypeEnum.textInput && (
|
||||
<Input
|
||||
id={field.name}
|
||||
name={field.name}
|
||||
className={cn(inputClassName)}
|
||||
value={value || ''}
|
||||
onChange={e => field.handleChange(e.target.value)}
|
||||
onBlur={field.handleBlur}
|
||||
disabled={disabled}
|
||||
placeholder={memorizedPlaceholder}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
formSchema.type === FormTypeEnum.secretInput && (
|
||||
<Input
|
||||
id={field.name}
|
||||
name={field.name}
|
||||
type='password'
|
||||
className={cn(inputClassName)}
|
||||
value={value || ''}
|
||||
onChange={e => field.handleChange(e.target.value)}
|
||||
onBlur={field.handleBlur}
|
||||
disabled={disabled}
|
||||
placeholder={memorizedPlaceholder}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
formSchema.type === FormTypeEnum.textNumber && (
|
||||
<Input
|
||||
id={field.name}
|
||||
name={field.name}
|
||||
type='number'
|
||||
className={cn(inputClassName)}
|
||||
value={value || ''}
|
||||
onChange={e => field.handleChange(e.target.value)}
|
||||
onBlur={field.handleBlur}
|
||||
disabled={disabled}
|
||||
placeholder={memorizedPlaceholder}
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
formSchema.type === FormTypeEnum.select && (
|
||||
<PureSelect
|
||||
value={value}
|
||||
onChange={v => field.handleChange(v)}
|
||||
disabled={disabled}
|
||||
placeholder={memorizedPlaceholder}
|
||||
options={memorizedOptions}
|
||||
triggerPopupSameWidth
|
||||
/>
|
||||
)
|
||||
}
|
||||
{
|
||||
formSchema.type === FormTypeEnum.radio && (
|
||||
<div className='flex items-center space-x-2'>
|
||||
{
|
||||
memorizedOptions.map(option => (
|
||||
<div
|
||||
key={option.value}
|
||||
className={cn(
|
||||
'system-sm-regular hover:bg-components-option-card-option-hover-bg hover:border-components-option-card-option-hover-border flex h-8 grow cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg p-2 text-text-secondary',
|
||||
value === option.value && 'border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary shadow-xs',
|
||||
)}
|
||||
onClick={() => field.handleChange(option.value)}
|
||||
>
|
||||
{option.label}
|
||||
</div>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(BaseField)
|
||||
@ -0,0 +1,115 @@
|
||||
import {
|
||||
memo,
|
||||
useCallback,
|
||||
useImperativeHandle,
|
||||
} from 'react'
|
||||
import type {
|
||||
AnyFieldApi,
|
||||
AnyFormApi,
|
||||
} from '@tanstack/react-form'
|
||||
import { useForm } from '@tanstack/react-form'
|
||||
import type {
|
||||
FormRef,
|
||||
FormSchema,
|
||||
} from '@/app/components/base/form/types'
|
||||
import {
|
||||
BaseField,
|
||||
} from '.'
|
||||
import type {
|
||||
BaseFieldProps,
|
||||
} from '.'
|
||||
import cn from '@/utils/classnames'
|
||||
import {
|
||||
useGetFormValues,
|
||||
useGetValidators,
|
||||
} from '@/app/components/base/form/hooks'
|
||||
|
||||
export type BaseFormProps = {
|
||||
formSchemas?: FormSchema[]
|
||||
defaultValues?: Record<string, any>
|
||||
formClassName?: string
|
||||
ref?: FormRef
|
||||
disabled?: boolean
|
||||
formFromProps?: AnyFormApi
|
||||
} & Pick<BaseFieldProps, 'fieldClassName' | 'labelClassName' | 'inputContainerClassName' | 'inputClassName'>
|
||||
|
||||
const BaseForm = ({
|
||||
formSchemas = [],
|
||||
defaultValues,
|
||||
formClassName,
|
||||
fieldClassName,
|
||||
labelClassName,
|
||||
inputContainerClassName,
|
||||
inputClassName,
|
||||
ref,
|
||||
disabled,
|
||||
formFromProps,
|
||||
}: BaseFormProps) => {
|
||||
const formFromHook = useForm({
|
||||
defaultValues,
|
||||
})
|
||||
const form: any = formFromProps || formFromHook
|
||||
const { getFormValues } = useGetFormValues(form, formSchemas)
|
||||
const { getValidators } = useGetValidators()
|
||||
|
||||
useImperativeHandle(ref, () => {
|
||||
return {
|
||||
getForm() {
|
||||
return form
|
||||
},
|
||||
getFormValues: (option) => {
|
||||
return getFormValues(option)
|
||||
},
|
||||
}
|
||||
}, [form, getFormValues])
|
||||
|
||||
const renderField = useCallback((field: AnyFieldApi) => {
|
||||
const formSchema = formSchemas?.find(schema => schema.name === field.name)
|
||||
|
||||
if (formSchema) {
|
||||
return (
|
||||
<BaseField
|
||||
field={field}
|
||||
formSchema={formSchema}
|
||||
fieldClassName={fieldClassName}
|
||||
labelClassName={labelClassName}
|
||||
inputContainerClassName={inputContainerClassName}
|
||||
inputClassName={inputClassName}
|
||||
disabled={disabled}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}, [formSchemas, fieldClassName, labelClassName, inputContainerClassName, inputClassName, disabled])
|
||||
|
||||
const renderFieldWrapper = useCallback((formSchema: FormSchema) => {
|
||||
const validators = getValidators(formSchema)
|
||||
const {
|
||||
name,
|
||||
} = formSchema
|
||||
|
||||
return (
|
||||
<form.Field
|
||||
key={name}
|
||||
name={name}
|
||||
validators={validators}
|
||||
>
|
||||
{renderField}
|
||||
</form.Field>
|
||||
)
|
||||
}, [renderField, form, getValidators])
|
||||
|
||||
if (!formSchemas?.length)
|
||||
return null
|
||||
|
||||
return (
|
||||
<form
|
||||
className={cn(formClassName)}
|
||||
>
|
||||
{formSchemas.map(renderFieldWrapper)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(BaseForm)
|
||||
@ -0,0 +1,2 @@
|
||||
export { default as BaseForm, type BaseFormProps } from './base-form'
|
||||
export { default as BaseField, type BaseFieldProps } from './base-field'
|
||||
@ -0,0 +1,23 @@
|
||||
import { memo } from 'react'
|
||||
import { BaseForm } from '../../components/base'
|
||||
import type { BaseFormProps } from '../../components/base'
|
||||
|
||||
const AuthForm = ({
|
||||
formSchemas = [],
|
||||
defaultValues,
|
||||
ref,
|
||||
formFromProps,
|
||||
}: BaseFormProps) => {
|
||||
return (
|
||||
<BaseForm
|
||||
ref={ref}
|
||||
formSchemas={formSchemas}
|
||||
defaultValues={defaultValues}
|
||||
formClassName='space-y-4'
|
||||
labelClassName='h-6 flex items-center mb-1 system-sm-medium text-text-secondary'
|
||||
formFromProps={formFromProps}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(AuthForm)
|
||||
@ -0,0 +1,3 @@
|
||||
export * from './use-check-validated'
|
||||
export * from './use-get-form-values'
|
||||
export * from './use-get-validators'
|
||||
@ -0,0 +1,48 @@
|
||||
import { useCallback } from 'react'
|
||||
import type { AnyFormApi } from '@tanstack/react-form'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import type { FormSchema } from '@/app/components/base/form/types'
|
||||
|
||||
export const useCheckValidated = (form: AnyFormApi, FormSchemas: FormSchema[]) => {
|
||||
const { notify } = useToastContext()
|
||||
|
||||
const checkValidated = useCallback(() => {
|
||||
const allError = form?.getAllErrors()
|
||||
const values = form.state.values
|
||||
|
||||
if (allError) {
|
||||
const fields = allError.fields
|
||||
const errorArray = Object.keys(fields).reduce((acc: string[], key: string) => {
|
||||
const currentSchema = FormSchemas.find(schema => schema.name === key)
|
||||
const { show_on = [] } = currentSchema || {}
|
||||
const showOnValues = show_on.reduce((acc, condition) => {
|
||||
acc[condition.variable] = values[condition.variable]
|
||||
return acc
|
||||
}, {} as Record<string, any>)
|
||||
const show = show_on?.every((condition) => {
|
||||
const conditionValue = showOnValues[condition.variable]
|
||||
return conditionValue === condition.value
|
||||
})
|
||||
const errors: any[] = show ? fields[key].errors : []
|
||||
|
||||
return [...acc, ...errors]
|
||||
}, [] as string[])
|
||||
|
||||
if (errorArray.length) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: errorArray[0],
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}, [form, notify, FormSchemas])
|
||||
|
||||
return {
|
||||
checkValidated,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,44 @@
|
||||
import { useCallback } from 'react'
|
||||
import type { AnyFormApi } from '@tanstack/react-form'
|
||||
import { useCheckValidated } from './use-check-validated'
|
||||
import type {
|
||||
FormSchema,
|
||||
GetValuesOptions,
|
||||
} from '../types'
|
||||
import { getTransformedValuesWhenSecretInputPristine } from '../utils'
|
||||
|
||||
export const useGetFormValues = (form: AnyFormApi, formSchemas: FormSchema[]) => {
|
||||
const { checkValidated } = useCheckValidated(form, formSchemas)
|
||||
|
||||
const getFormValues = useCallback((
|
||||
{
|
||||
needCheckValidatedValues,
|
||||
needTransformWhenSecretFieldIsPristine,
|
||||
}: GetValuesOptions,
|
||||
) => {
|
||||
const values = form?.store.state.values || {}
|
||||
if (!needCheckValidatedValues) {
|
||||
return {
|
||||
values,
|
||||
isCheckValidated: false,
|
||||
}
|
||||
}
|
||||
|
||||
if (checkValidated()) {
|
||||
return {
|
||||
values: needTransformWhenSecretFieldIsPristine ? getTransformedValuesWhenSecretInputPristine(formSchemas, form) : values,
|
||||
isCheckValidated: true,
|
||||
}
|
||||
}
|
||||
else {
|
||||
return {
|
||||
values: {},
|
||||
isCheckValidated: false,
|
||||
}
|
||||
}
|
||||
}, [form, checkValidated, formSchemas])
|
||||
|
||||
return {
|
||||
getFormValues,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
import { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { FormSchema } from '../types'
|
||||
|
||||
export const useGetValidators = () => {
|
||||
const { t } = useTranslation()
|
||||
const getValidators = useCallback((formSchema: FormSchema) => {
|
||||
const {
|
||||
name,
|
||||
validators,
|
||||
required,
|
||||
} = formSchema
|
||||
let mergedValidators = validators
|
||||
if (required && !validators) {
|
||||
mergedValidators = {
|
||||
onMount: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
},
|
||||
onChange: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
},
|
||||
onBlur: ({ value }: any) => {
|
||||
if (!value)
|
||||
return t('common.errorMsg.fieldRequired', { field: name })
|
||||
},
|
||||
}
|
||||
}
|
||||
return mergedValidators
|
||||
}, [t])
|
||||
|
||||
return {
|
||||
getValidators,
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
import type {
|
||||
ForwardedRef,
|
||||
ReactNode,
|
||||
} from 'react'
|
||||
import type {
|
||||
AnyFormApi,
|
||||
FieldValidators,
|
||||
} from '@tanstack/react-form'
|
||||
|
||||
export type TypeWithI18N<T = string> = {
|
||||
en_US: T
|
||||
zh_Hans: T
|
||||
[key: string]: T
|
||||
}
|
||||
|
||||
export type FormShowOnObject = {
|
||||
variable: string
|
||||
value: string
|
||||
}
|
||||
|
||||
export enum FormTypeEnum {
|
||||
textInput = 'text-input',
|
||||
textNumber = 'number-input',
|
||||
secretInput = 'secret-input',
|
||||
select = 'select',
|
||||
radio = 'radio',
|
||||
boolean = 'boolean',
|
||||
files = 'files',
|
||||
file = 'file',
|
||||
modelSelector = 'model-selector',
|
||||
toolSelector = 'tool-selector',
|
||||
multiToolSelector = 'array[tools]',
|
||||
appSelector = 'app-selector',
|
||||
dynamicSelect = 'dynamic-select',
|
||||
}
|
||||
|
||||
export type FormOption = {
|
||||
label: TypeWithI18N | string
|
||||
value: string
|
||||
show_on?: FormShowOnObject[]
|
||||
icon?: string
|
||||
}
|
||||
|
||||
export type AnyValidators = FieldValidators<any, any, any, any, any, any, any, any, any, any>
|
||||
|
||||
export type FormSchema = {
|
||||
type: FormTypeEnum
|
||||
name: string
|
||||
label: string | ReactNode | TypeWithI18N
|
||||
required: boolean
|
||||
default?: any
|
||||
tooltip?: string | TypeWithI18N
|
||||
show_on?: FormShowOnObject[]
|
||||
url?: string
|
||||
scope?: string
|
||||
help?: string | TypeWithI18N
|
||||
placeholder?: string | TypeWithI18N
|
||||
options?: FormOption[]
|
||||
labelClassName?: string
|
||||
validators?: AnyValidators
|
||||
}
|
||||
|
||||
export type FormValues = Record<string, any>
|
||||
|
||||
export type GetValuesOptions = {
|
||||
needTransformWhenSecretFieldIsPristine?: boolean
|
||||
needCheckValidatedValues?: boolean
|
||||
}
|
||||
export type FormRefObject = {
|
||||
getForm: () => AnyFormApi
|
||||
getFormValues: (obj: GetValuesOptions) => {
|
||||
values: Record<string, any>
|
||||
isCheckValidated: boolean
|
||||
}
|
||||
}
|
||||
export type FormRef = ForwardedRef<FormRefObject>
|
||||
@ -0,0 +1 @@
|
||||
export * from './secret-input'
|
||||
@ -0,0 +1,29 @@
|
||||
import type { AnyFormApi } from '@tanstack/react-form'
|
||||
import type { FormSchema } from '@/app/components/base/form/types'
|
||||
import { FormTypeEnum } from '@/app/components/base/form/types'
|
||||
|
||||
export const transformFormSchemasSecretInput = (isPristineSecretInputNames: string[], values: Record<string, any>) => {
|
||||
const transformedValues: Record<string, any> = { ...values }
|
||||
|
||||
isPristineSecretInputNames.forEach((name) => {
|
||||
if (transformedValues[name])
|
||||
transformedValues[name] = '[__HIDDEN__]'
|
||||
})
|
||||
|
||||
return transformedValues
|
||||
}
|
||||
|
||||
export const getTransformedValuesWhenSecretInputPristine = (formSchemas: FormSchema[], form: AnyFormApi) => {
|
||||
const values = form?.store.state.values || {}
|
||||
const isPristineSecretInputNames: string[] = []
|
||||
for (let i = 0; i < formSchemas.length; i++) {
|
||||
const schema = formSchemas[i]
|
||||
if (schema.type === FormTypeEnum.secretInput) {
|
||||
const fieldMeta = form?.getFieldMeta(schema.name)
|
||||
if (fieldMeta?.isPristine)
|
||||
isPristineSecretInputNames.push(schema.name)
|
||||
}
|
||||
}
|
||||
|
||||
return transformFormSchemasSecretInput(isPristineSecretInputNames, values)
|
||||
}
|
||||
@ -1,3 +0,0 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9 15L11 17L15.5 12.5M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.1 KiB |
@ -1,3 +0,0 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z" stroke="black" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.1 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.6665 2.66683C11.2865 2.66683 11.5965 2.66683 11.8508 2.73498C12.541 2.91991 13.0801 3.45901 13.265 4.14919C13.3332 4.40352 13.3332 4.71352 13.3332 5.3335V11.4668C13.3332 12.5869 13.3332 13.147 13.1152 13.5748C12.9234 13.9511 12.6175 14.2571 12.2412 14.4488C11.8133 14.6668 11.2533 14.6668 10.1332 14.6668H5.8665C4.7464 14.6668 4.18635 14.6668 3.75852 14.4488C3.3822 14.2571 3.07624 13.9511 2.88449 13.5748C2.6665 13.147 2.6665 12.5869 2.6665 11.4668V5.3335C2.6665 4.71352 2.6665 4.40352 2.73465 4.14919C2.91959 3.45901 3.45868 2.91991 4.14887 2.73498C4.4032 2.66683 4.71319 2.66683 5.33317 2.66683M5.99984 10.0002L7.33317 11.3335L10.3332 8.3335M6.39984 4.00016H9.59984C9.9732 4.00016 10.1599 4.00016 10.3025 3.9275C10.4279 3.86359 10.5299 3.7616 10.5938 3.63616C10.6665 3.49355 10.6665 3.30686 10.6665 2.9335V2.40016C10.6665 2.02679 10.6665 1.84011 10.5938 1.6975C10.5299 1.57206 10.4279 1.47007 10.3025 1.40616C10.1599 1.3335 9.97321 1.3335 9.59984 1.3335H6.39984C6.02647 1.3335 5.83978 1.3335 5.69718 1.40616C5.57174 1.47007 5.46975 1.57206 5.40583 1.6975C5.33317 1.84011 5.33317 2.02679 5.33317 2.40016V2.9335C5.33317 3.30686 5.33317 3.49355 5.40583 3.63616C5.46975 3.7616 5.57174 3.86359 5.69718 3.9275C5.83978 4.00016 6.02647 4.00016 6.39984 4.00016Z" stroke="#1D2939" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.6665 2.66634H11.9998C12.3535 2.66634 12.6926 2.80682 12.9426 3.05687C13.1927 3.30691 13.3332 3.64605 13.3332 3.99967V13.333C13.3332 13.6866 13.1927 14.0258 12.9426 14.2758C12.6926 14.5259 12.3535 14.6663 11.9998 14.6663H3.99984C3.64622 14.6663 3.30708 14.5259 3.05703 14.2758C2.80698 14.0258 2.6665 13.6866 2.6665 13.333V3.99967C2.6665 3.64605 2.80698 3.30691 3.05703 3.05687C3.30708 2.80682 3.64622 2.66634 3.99984 2.66634H5.33317M5.99984 1.33301H9.99984C10.368 1.33301 10.6665 1.63148 10.6665 1.99967V3.33301C10.6665 3.7012 10.368 3.99967 9.99984 3.99967H5.99984C5.63165 3.99967 5.33317 3.7012 5.33317 3.33301V1.99967C5.33317 1.63148 5.63165 1.33301 5.99984 1.33301Z" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 875 B |
@ -1,29 +0,0 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "24",
|
||||
"height": "24",
|
||||
"viewBox": "0 0 24 24",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Clipboard"
|
||||
}
|
||||
@ -1,29 +0,0 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "24",
|
||||
"height": "24",
|
||||
"viewBox": "0 0 24 24",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M16 4C16.93 4 17.395 4 17.7765 4.10222C18.8117 4.37962 19.6204 5.18827 19.8978 6.22354C20 6.60504 20 7.07003 20 8V17.2C20 18.8802 20 19.7202 19.673 20.362C19.3854 20.9265 18.9265 21.3854 18.362 21.673C17.7202 22 16.8802 22 15.2 22H8.8C7.11984 22 6.27976 22 5.63803 21.673C5.07354 21.3854 4.6146 20.9265 4.32698 20.362C4 19.7202 4 18.8802 4 17.2V8C4 7.07003 4 6.60504 4.10222 6.22354C4.37962 5.18827 5.18827 4.37962 6.22354 4.10222C6.60504 4 7.07003 4 8 4M9 15L11 17L15.5 12.5M9.6 6H14.4C14.9601 6 15.2401 6 15.454 5.89101C15.6422 5.79513 15.7951 5.64215 15.891 5.45399C16 5.24008 16 4.96005 16 4.4V3.6C16 3.03995 16 2.75992 15.891 2.54601C15.7951 2.35785 15.6422 2.20487 15.454 2.10899C15.2401 2 14.9601 2 14.4 2H9.6C9.03995 2 8.75992 2 8.54601 2.10899C8.35785 2.20487 8.20487 2.35785 8.10899 2.54601C8 2.75992 8 3.03995 8 3.6V4.4C8 4.96005 8 5.24008 8.10899 5.45399C8.20487 5.64215 8.35785 5.79513 8.54601 5.89101C8.75992 6 9.03995 6 9.6 6Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "2",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "ClipboardCheck"
|
||||
}
|
||||
@ -0,0 +1,29 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M10.6665 2.66634H11.9998C12.3535 2.66634 12.6926 2.80682 12.9426 3.05687C13.1927 3.30691 13.3332 3.64605 13.3332 3.99967V13.333C13.3332 13.6866 13.1927 14.0258 12.9426 14.2758C12.6926 14.5259 12.3535 14.6663 11.9998 14.6663H3.99984C3.64622 14.6663 3.30708 14.5259 3.05703 14.2758C2.80698 14.0258 2.6665 13.6866 2.6665 13.333V3.99967C2.6665 3.64605 2.80698 3.30691 3.05703 3.05687C3.30708 2.80682 3.64622 2.66634 3.99984 2.66634H5.33317M5.99984 1.33301H9.99984C10.368 1.33301 10.6665 1.63148 10.6665 1.99967V3.33301C10.6665 3.7012 10.368 3.99967 9.99984 3.99967H5.99984C5.63165 3.99967 5.33317 3.7012 5.33317 3.33301V1.99967C5.33317 1.63148 5.63165 1.33301 5.99984 1.33301Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "1.5",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Copy"
|
||||
}
|
||||
@ -0,0 +1,29 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M10.6665 2.66683C11.2865 2.66683 11.5965 2.66683 11.8508 2.73498C12.541 2.91991 13.0801 3.45901 13.265 4.14919C13.3332 4.40352 13.3332 4.71352 13.3332 5.3335V11.4668C13.3332 12.5869 13.3332 13.147 13.1152 13.5748C12.9234 13.9511 12.6175 14.2571 12.2412 14.4488C11.8133 14.6668 11.2533 14.6668 10.1332 14.6668H5.8665C4.7464 14.6668 4.18635 14.6668 3.75852 14.4488C3.3822 14.2571 3.07624 13.9511 2.88449 13.5748C2.6665 13.147 2.6665 12.5869 2.6665 11.4668V5.3335C2.6665 4.71352 2.6665 4.40352 2.73465 4.14919C2.91959 3.45901 3.45868 2.91991 4.14887 2.73498C4.4032 2.66683 4.71319 2.66683 5.33317 2.66683M5.99984 10.0002L7.33317 11.3335L10.3332 8.3335M6.39984 4.00016H9.59984C9.9732 4.00016 10.1599 4.00016 10.3025 3.9275C10.4279 3.86359 10.5299 3.7616 10.5938 3.63616C10.6665 3.49355 10.6665 3.30686 10.6665 2.9335V2.40016C10.6665 2.02679 10.6665 1.84011 10.5938 1.6975C10.5299 1.57206 10.4279 1.47007 10.3025 1.40616C10.1599 1.3335 9.97321 1.3335 9.59984 1.3335H6.39984C6.02647 1.3335 5.83978 1.3335 5.69718 1.40616C5.57174 1.47007 5.46975 1.57206 5.40583 1.6975C5.33317 1.84011 5.33317 2.02679 5.33317 2.40016V2.9335C5.33317 3.30686 5.33317 3.49355 5.40583 3.63616C5.46975 3.7616 5.57174 3.86359 5.69718 3.9275C5.83978 4.00016 6.02647 4.00016 6.39984 4.00016Z",
|
||||
"stroke": "currentColor",
|
||||
"stroke-width": "1.5",
|
||||
"stroke-linecap": "round",
|
||||
"stroke-linejoin": "round"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "CopyCheck"
|
||||
}
|
||||
@ -0,0 +1,127 @@
|
||||
import { memo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import {
|
||||
PortalToFollowElem,
|
||||
PortalToFollowElemContent,
|
||||
} from '@/app/components/base/portal-to-follow-elem'
|
||||
import Button from '@/app/components/base/button'
|
||||
import type { ButtonProps } from '@/app/components/base/button'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type ModalProps = {
|
||||
onClose?: () => void
|
||||
size?: 'sm' | 'md'
|
||||
title: string
|
||||
subTitle?: string
|
||||
children?: React.ReactNode
|
||||
confirmButtonText?: string
|
||||
onConfirm?: () => void
|
||||
cancelButtonText?: string
|
||||
onCancel?: () => void
|
||||
showExtraButton?: boolean
|
||||
extraButtonText?: string
|
||||
extraButtonVariant?: ButtonProps['variant']
|
||||
onExtraButtonClick?: () => void
|
||||
footerSlot?: React.ReactNode
|
||||
bottomSlot?: React.ReactNode
|
||||
disabled?: boolean
|
||||
}
|
||||
const Modal = ({
|
||||
onClose,
|
||||
size = 'sm',
|
||||
title,
|
||||
subTitle,
|
||||
children,
|
||||
confirmButtonText,
|
||||
onConfirm,
|
||||
cancelButtonText,
|
||||
onCancel,
|
||||
showExtraButton,
|
||||
extraButtonVariant = 'warning',
|
||||
extraButtonText,
|
||||
onExtraButtonClick,
|
||||
footerSlot,
|
||||
bottomSlot,
|
||||
disabled,
|
||||
}: ModalProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<PortalToFollowElem open>
|
||||
<PortalToFollowElemContent
|
||||
className='z-[9998] flex h-full w-full items-center justify-center bg-background-overlay'
|
||||
onClick={onClose}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
'max-h-[80%] w-[480px] overflow-y-auto rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xs',
|
||||
size === 'sm' && 'w-[480px',
|
||||
size === 'md' && 'w-[640px]',
|
||||
)}
|
||||
onClick={e => e.stopPropagation()}
|
||||
>
|
||||
<div className='title-2xl-semi-bold relative p-6 pb-3 pr-14 text-text-primary'>
|
||||
{title}
|
||||
{
|
||||
subTitle && (
|
||||
<div className='system-xs-regular mt-1 text-text-tertiary'>
|
||||
{subTitle}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<div
|
||||
className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center rounded-lg'
|
||||
onClick={onClose}
|
||||
>
|
||||
<RiCloseLine className='h-5 w-5 text-text-tertiary' />
|
||||
</div>
|
||||
</div>
|
||||
{
|
||||
children && (
|
||||
<div className='px-6 py-3'>{children}</div>
|
||||
)
|
||||
}
|
||||
<div className='flex justify-between p-6 pt-5'>
|
||||
<div>
|
||||
{footerSlot}
|
||||
</div>
|
||||
<div className='flex items-center'>
|
||||
{
|
||||
showExtraButton && (
|
||||
<>
|
||||
<Button
|
||||
variant={extraButtonVariant}
|
||||
onClick={onExtraButtonClick}
|
||||
disabled={disabled}
|
||||
>
|
||||
{extraButtonText || t('common.operation.remove')}
|
||||
</Button>
|
||||
<div className='mx-3 h-4 w-[1px] bg-divider-regular'></div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
disabled={disabled}
|
||||
>
|
||||
{cancelButtonText || t('common.operation.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
className='ml-2'
|
||||
variant='primary'
|
||||
onClick={onConfirm}
|
||||
disabled={disabled}
|
||||
>
|
||||
{confirmButtonText || t('common.operation.save')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{bottomSlot}
|
||||
</div>
|
||||
</PortalToFollowElemContent>
|
||||
</PortalToFollowElem>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(Modal)
|
||||
@ -0,0 +1,50 @@
|
||||
import {
|
||||
memo,
|
||||
useState,
|
||||
} from 'react'
|
||||
import Button from '@/app/components/base/button'
|
||||
import type { ButtonProps } from '@/app/components/base/button'
|
||||
import ApiKeyModal from './api-key-modal'
|
||||
import type { PluginPayload } from '../types'
|
||||
|
||||
export type AddApiKeyButtonProps = {
|
||||
pluginPayload: PluginPayload
|
||||
buttonVariant?: ButtonProps['variant']
|
||||
buttonText?: string
|
||||
disabled?: boolean
|
||||
onUpdate?: () => void
|
||||
}
|
||||
const AddApiKeyButton = ({
|
||||
pluginPayload,
|
||||
buttonVariant = 'secondary-accent',
|
||||
buttonText = 'use api key',
|
||||
disabled,
|
||||
onUpdate,
|
||||
}: AddApiKeyButtonProps) => {
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false)
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
className='w-full'
|
||||
variant={buttonVariant}
|
||||
onClick={() => setIsApiKeyModalOpen(true)}
|
||||
disabled={disabled}
|
||||
>
|
||||
{buttonText}
|
||||
</Button>
|
||||
{
|
||||
isApiKeyModalOpen && (
|
||||
<ApiKeyModal
|
||||
pluginPayload={pluginPayload}
|
||||
onClose={() => setIsApiKeyModalOpen(false)}
|
||||
onUpdate={onUpdate}
|
||||
/>
|
||||
)
|
||||
}
|
||||
</>
|
||||
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(AddApiKeyButton)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue