fix conflict
commit
eb95858e01
@ -0,0 +1,84 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCache(ABC):
|
||||
"""Base class for provider credentials cache"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.cache_key = self._generate_cache_key(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
"""Generate cache key based on subclass implementation"""
|
||||
pass
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
cached_credentials = redis_client.get(self.cache_key)
|
||||
if cached_credentials:
|
||||
try:
|
||||
cached_credentials = cached_credentials.decode("utf-8")
|
||||
return dict(json.loads(cached_credentials))
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
redis_client.delete(self.cache_key)
|
||||
|
||||
|
||||
class SingletonProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool single provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
|
||||
super().__init__(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=provider_type,
|
||||
provider_identity=provider_identity,
|
||||
)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider_type = kwargs["provider_type"]
|
||||
identity_name = kwargs["provider_identity"]
|
||||
identity_id = f"{provider_type}.{identity_name}"
|
||||
return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
pass
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
pass
|
||||
@ -1,51 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||
if cached_provider_credentials:
|
||||
try:
|
||||
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
|
||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
return dict(cached_provider_credentials)
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, credentials: dict) -> None:
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
redis_client.delete(self.cache_key)
|
||||
@ -0,0 +1,141 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import SingletonProviderCredentialsCache
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
|
||||
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
provider_identity=controller.entity.identity.name,
|
||||
)
|
||||
encrypt = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_config_cache=cache,
|
||||
)
|
||||
return encrypt, cache
|
||||
@ -0,0 +1,192 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEncryptionError(Exception):
|
||||
"""OAuth encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemOAuthEncrypter:
|
||||
"""
|
||||
A simple OAuth parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt OAuth parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the OAuth encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Raises:
|
||||
ValueError: If SECRET_KEY is not configured or empty
|
||||
"""
|
||||
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_oauth_params(self, oauth_params: str) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If encryption fails
|
||||
ValueError: If oauth_params is invalid
|
||||
"""
|
||||
|
||||
if not oauth_params:
|
||||
raise ValueError("oauth_params cannot be empty")
|
||||
|
||||
try:
|
||||
# Generate random IV (16 bytes)
|
||||
iv = get_random_bytes(16)
|
||||
|
||||
# Create AES cipher (CBC mode)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(oauth_params.encode("utf-8"), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
combined = iv + encrypted_data
|
||||
|
||||
# Return base64 encoded string
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_oauth_params(self, encrypted_data: str) -> dict[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
|
||||
Raises:
|
||||
OAuthEncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
raise ValueError("encrypted_data must be a string")
|
||||
|
||||
if not encrypted_data:
|
||||
raise ValueError("encrypted_data cannot be empty")
|
||||
|
||||
try:
|
||||
# Base64 decode
|
||||
combined = base64.b64decode(encrypted_data)
|
||||
|
||||
# Check minimum length (IV + at least one AES block)
|
||||
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||
raise ValueError("Invalid encrypted data format")
|
||||
|
||||
# Separate IV and encrypted data
|
||||
iv = combined[:16]
|
||||
encrypted_data_bytes = combined[16:]
|
||||
|
||||
# Create AES cipher
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Decrypt data
|
||||
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
params_json = unpadded_data.decode("utf-8")
|
||||
oauth_params = json.loads(params_json)
|
||||
|
||||
if not isinstance(oauth_params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return oauth_params
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
except Exception as e:
|
||||
raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Create an OAuth encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
return SystemOAuthEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
|
||||
|
||||
|
||||
def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
|
||||
"""
|
||||
Get the global OAuth encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemOAuthEncrypter instance
|
||||
"""
|
||||
global _oauth_encrypter
|
||||
if _oauth_encrypter is None:
|
||||
_oauth_encrypter = SystemOAuthEncrypter()
|
||||
return _oauth_encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_oauth_params(oauth_params: str) -> str:
|
||||
"""
|
||||
Encrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
oauth_params: OAuth parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
|
||||
|
||||
|
||||
def decrypt_system_oauth_params(encrypted_data: str) -> dict[str, Any]:
|
||||
"""
|
||||
Decrypt OAuth parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted OAuth parameters dictionary
|
||||
"""
|
||||
return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
|
||||
@ -0,0 +1,41 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 16081485540c
|
||||
Revises: d28f2004b072
|
||||
Create Date: 2025-05-15 16:35:39.113777
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '16081485540c'
|
||||
down_revision = '2adcbe1f5dfb'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_plugin_auto_upgrade_strategies',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
|
||||
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
|
||||
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
|
||||
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tenant_plugin_auto_upgrade_strategies')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,62 @@
|
||||
"""tool oauth
|
||||
|
||||
Revision ID: 71f5020c6470
|
||||
Revises: 4474872b0ee6
|
||||
Create Date: 2025-06-24 17:05:43.118647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '71f5020c6470'
|
||||
down_revision = '58eb7bdb93fe'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tool_oauth_system_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
|
||||
)
|
||||
op.create_table('tool_oauth_tenant_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('plugin_id', sa.String(length=512), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
|
||||
)
|
||||
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
|
||||
batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
|
||||
batch_op.drop_column('credential_type')
|
||||
batch_op.drop_column('is_default')
|
||||
batch_op.drop_column('name')
|
||||
|
||||
op.drop_table('tool_oauth_tenant_clients')
|
||||
op.drop_table('tool_oauth_system_clients')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,49 @@
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
import app
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
|
||||
|
||||
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
|
||||
|
||||
|
||||
@app.celery.task(queue="plugin")
|
||||
def check_upgradable_plugin_task():
|
||||
click.echo(click.style("Start check upgradable plugin.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
|
||||
click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green"))
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
TenantPluginAutoUpgradeStrategy.strategy_setting
|
||||
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for strategy in strategies:
|
||||
process_tenant_plugin_autoupgrade_check_task.delay(
|
||||
strategy.tenant_id,
|
||||
strategy.strategy_setting,
|
||||
strategy.upgrade_time_of_day,
|
||||
strategy.upgrade_mode,
|
||||
strategy.exclude_plugins,
|
||||
strategy.include_plugins,
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(
|
||||
click.style(
|
||||
"Checked upgradable plugin success latency: {}".format(end_at - start_at),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -0,0 +1,87 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_strategy(
|
||||
tenant_id: str,
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
upgrade_time_of_day: int,
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
tenant_id=tenant_id,
|
||||
strategy_setting=strategy_setting,
|
||||
upgrade_time_of_day=upgrade_time_of_day,
|
||||
upgrade_mode=upgrade_mode,
|
||||
exclude_plugins=exclude_plugins,
|
||||
include_plugins=include_plugins,
|
||||
)
|
||||
session.add(strategy)
|
||||
else:
|
||||
exist_strategy.strategy_setting = strategy_setting
|
||||
exist_strategy.upgrade_time_of_day = upgrade_time_of_day
|
||||
exist_strategy.upgrade_mode = upgrade_mode
|
||||
exist_strategy.exclude_plugins = exclude_plugins
|
||||
exist_strategy.include_plugins = include_plugins
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant_id,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
0,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
[plugin_id],
|
||||
[],
|
||||
)
|
||||
return True
|
||||
else:
|
||||
if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
|
||||
if plugin_id not in exist_strategy.exclude_plugins:
|
||||
new_exclude_plugins = exist_strategy.exclude_plugins.copy()
|
||||
new_exclude_plugins.append(plugin_id)
|
||||
exist_strategy.exclude_plugins = new_exclude_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
|
||||
if plugin_id in exist_strategy.include_plugins:
|
||||
new_include_plugins = exist_strategy.include_plugins.copy()
|
||||
new_include_plugins.remove(plugin_id)
|
||||
exist_strategy.include_plugins = new_include_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
|
||||
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exist_strategy.exclude_plugins = [plugin_id]
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
@ -0,0 +1,166 @@
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.helper import marketplace
|
||||
from core.helper.marketplace import MarketplacePluginDeclaration
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
|
||||
|
||||
|
||||
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
|
||||
|
||||
|
||||
def marketplace_batch_fetch_plugin_manifests(
|
||||
plugin_ids_plain_list: list[str],
|
||||
) -> list[MarketplacePluginDeclaration]:
|
||||
global cached_plugin_manifests
|
||||
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
|
||||
not_included_plugin_ids = [
|
||||
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
|
||||
]
|
||||
if not_included_plugin_ids:
|
||||
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
|
||||
for manifest in manifests:
|
||||
cached_plugin_manifests[manifest.plugin_id] = manifest
|
||||
|
||||
if (
|
||||
len(manifests) == 0
|
||||
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
|
||||
for plugin_id in not_included_plugin_ids:
|
||||
cached_plugin_manifests[plugin_id] = None
|
||||
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin_id in plugin_ids_plain_list:
|
||||
final_manifest = cached_plugin_manifests.get(plugin_id)
|
||||
if final_manifest is not None:
|
||||
result.append(final_manifest)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(queue="plugin")
|
||||
def process_tenant_plugin_autoupgrade_check_task(
|
||||
tenant_id: str,
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
upgrade_time_of_day: int,
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
):
|
||||
try:
|
||||
manager = PluginInstaller()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Checking upgradable plugin for tenant: {}".format(tenant_id),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
|
||||
return
|
||||
|
||||
# 获取需要检查的插件
|
||||
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
|
||||
click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green"))
|
||||
|
||||
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
|
||||
for plugin in all_plugins:
|
||||
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
|
||||
plugin_ids.append(
|
||||
(
|
||||
plugin.plugin_id,
|
||||
plugin.version,
|
||||
plugin.plugin_unique_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
|
||||
# 获取所有插件并移除exclude的
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [
|
||||
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
|
||||
for plugin in all_plugins
|
||||
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
|
||||
]
|
||||
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [
|
||||
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
|
||||
for plugin in all_plugins
|
||||
if plugin.source == PluginInstallationSource.Marketplace
|
||||
]
|
||||
|
||||
if not plugin_ids:
|
||||
return
|
||||
|
||||
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
|
||||
|
||||
manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
|
||||
|
||||
if not manifests:
|
||||
return
|
||||
|
||||
for manifest in manifests:
|
||||
for plugin_id, version, original_unique_identifier in plugin_ids:
|
||||
if manifest.plugin_id != plugin_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
current_version = version
|
||||
latest_version = manifest.latest_version
|
||||
|
||||
def fix_only_checker(latest_version, current_version):
|
||||
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
|
||||
current_version_tuple = tuple(int(val) for val in current_version.split("."))
|
||||
|
||||
if (
|
||||
latest_version_tuple[0] == current_version_tuple[0]
|
||||
and latest_version_tuple[1] == current_version_tuple[1]
|
||||
):
|
||||
return latest_version_tuple[2] != current_version_tuple[2]
|
||||
return False
|
||||
|
||||
version_checker = {
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
|
||||
current_version: latest_version != current_version,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
|
||||
}
|
||||
|
||||
if version_checker[strategy_setting](latest_version, current_version):
|
||||
# 执行升级
|
||||
new_unique_identifier = manifest.latest_package_identifier
|
||||
|
||||
marketplace.record_install_plugin_event(new_unique_identifier)
|
||||
click.echo(
|
||||
click.style(
|
||||
"Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
task_start_resp = manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_unique_identifier,
|
||||
new_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_unique_identifier,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
|
||||
traceback.print_exc()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
|
||||
traceback.print_exc()
|
||||
return
|
||||
@ -0,0 +1,232 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
|
||||
from libs.login import _get_user, current_user, login_required
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Mock user class for testing."""
|
||||
|
||||
def __init__(self, id: str, is_authenticated: bool = True):
|
||||
self.id = id
|
||||
self._is_authenticated = is_authenticated
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
# Mock unauthorized handler
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
# Add a dummy user loader to prevent exceptions
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login.dify_config") as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Mock Flask 2.x ensure_sync
|
||||
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Remove ensure_sync to simulate Flask 1.x
|
||||
if hasattr(setup_app, "ensure_sync"):
|
||||
delattr(setup_app, "ensure_sync")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test cases for _get_user function."""
|
||||
|
||||
def test_get_user_returns_user_from_g(self, app: Flask):
|
||||
"""Test that _get_user returns user from g._login_user."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_user
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
# Mock login manager
|
||||
login_manager = MagicMock()
|
||||
login_manager._load_user = MagicMock()
|
||||
app.login_manager = login_manager
|
||||
|
||||
with app.test_request_context():
|
||||
# Simulate _load_user setting g._login_user
|
||||
def side_effect():
|
||||
g._login_user = mock_user
|
||||
|
||||
login_manager._load_user.side_effect = side_effect
|
||||
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
login_manager._load_user.assert_called_once()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self, app: Flask):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
# Outside of request context
|
||||
user = _get_user()
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
"""Test cases for current_user proxy."""
|
||||
|
||||
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
|
||||
"""Test that current_user proxy returns authenticated user."""
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
# Try to access an attribute that would exist on a real user
|
||||
_ = current_user.id
|
||||
pytest.fail("Should have raised AttributeError")
|
||||
except AttributeError:
|
||||
# This is expected when current_user is None
|
||||
pass
|
||||
|
||||
def test_current_user_proxy_thread_safety(self, app: Flask):
|
||||
"""Test that current_user proxy is thread-safe."""
|
||||
import threading
|
||||
|
||||
results = {}
|
||||
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
results[index] = current_user.id
|
||||
|
||||
# Create multiple threads with different users
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify each thread got its own user
|
||||
for i in range(5):
|
||||
assert results[i] == f"user_{i}"
|
||||
@ -0,0 +1,205 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
class TestPassportService:
|
||||
"""Test PassportService JWT operations"""
|
||||
|
||||
@pytest.fixture
|
||||
def passport_service(self):
|
||||
"""Create PassportService instance with test secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
@pytest.fixture
|
||||
def another_passport_service(self):
|
||||
"""Create another PassportService instance with different secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "another-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
# Core functionality tests
|
||||
def test_should_issue_and_verify_token(self, passport_service):
|
||||
"""Test complete JWT lifecycle: issue and verify"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Verify token format
|
||||
assert isinstance(token, str)
|
||||
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
|
||||
|
||||
# Verify token content
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_different_payload_types(self, passport_service):
|
||||
"""Test issuing and verifying tokens with different payload types"""
|
||||
test_cases = [
|
||||
{"string": "value"},
|
||||
{"number": 42},
|
||||
{"float": 3.14},
|
||||
{"boolean": True},
|
||||
{"null": None},
|
||||
{"array": [1, 2, 3]},
|
||||
{"nested": {"key": "value"}},
|
||||
{"unicode": "中文测试"},
|
||||
{"emoji": "🔐"},
|
||||
{}, # Empty payload
|
||||
]
|
||||
|
||||
for payload in test_cases:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
# Security tests
|
||||
def test_should_reject_modified_token(self, passport_service):
|
||||
"""Test that any modification to token invalidates it"""
|
||||
token = passport_service.issue({"user": "test"})
|
||||
|
||||
# Test multiple modification points
|
||||
test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
|
||||
|
||||
for pos in test_positions:
|
||||
if pos < len(token) and token[pos] != ".":
|
||||
# Change one character
|
||||
tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
|
||||
with pytest.raises(Unauthorized):
|
||||
passport_service.verify(tampered)
|
||||
|
||||
def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
|
||||
"""Test key isolation - token from one service should not work with another"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
another_passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
|
||||
|
||||
def test_should_use_hs256_algorithm(self, passport_service):
|
||||
"""Test that HS256 algorithm is used for signing"""
|
||||
payload = {"test": "data"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Decode header without relying on JWT internals
|
||||
# Use jwt.get_unverified_header which is a public API
|
||||
header = jwt.get_unverified_header(token)
|
||||
assert header["alg"] == "HS256"
|
||||
|
||||
def test_should_reject_token_with_wrong_algorithm(self, passport_service):
|
||||
"""Test rejection of token signed with different algorithm"""
|
||||
payload = {"user_id": "123"}
|
||||
|
||||
# Create token with different algorithm
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
# Create token with HS512 instead of HS256
|
||||
wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
|
||||
|
||||
# Should fail because service expects HS256
|
||||
# InvalidAlgorithmError is now caught by PyJWTError handler
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(wrong_alg_token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
|
||||
# Exception handling tests
|
||||
def test_should_handle_invalid_tokens(self, passport_service):
|
||||
"""Test handling of various invalid token formats"""
|
||||
invalid_tokens = [
|
||||
("not.a.token", "Invalid token."),
|
||||
("invalid-jwt-format", "Invalid token."),
|
||||
("xxx.yyy.zzz", "Invalid token."),
|
||||
("a.b", "Invalid token."), # Missing signature
|
||||
("", "Invalid token."), # Empty string
|
||||
(" ", "Invalid token."), # Whitespace
|
||||
(None, "Invalid token."), # None value
|
||||
# Malformed base64
|
||||
("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
|
||||
]
|
||||
|
||||
for invalid_token, expected_message in invalid_tokens:
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(invalid_token)
|
||||
assert expected_message in str(exc_info.value)
|
||||
|
||||
def test_should_reject_expired_token(self, passport_service):
|
||||
"""Test rejection of expired token"""
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
payload = {"user_id": "123", "exp": past_time.timestamp()}
|
||||
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Token has expired."
|
||||
|
||||
# Configuration tests
|
||||
def test_should_handle_empty_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is empty"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = ""
|
||||
service = PassportService()
|
||||
|
||||
# Empty secret key should still work but is insecure
|
||||
payload = {"test": "data"}
|
||||
token = service.issue(payload)
|
||||
decoded = service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_none_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is None"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = None
|
||||
service = PassportService()
|
||||
|
||||
payload = {"test": "data"}
|
||||
# JWT library will raise TypeError when secret is None
|
||||
with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
|
||||
service.issue(payload)
|
||||
|
||||
# Boundary condition tests
|
||||
def test_should_handle_large_payload(self, passport_service):
|
||||
"""Test handling of large payload"""
|
||||
# Test with 100KB instead of 1MB for faster tests
|
||||
large_data = "x" * (100 * 1024)
|
||||
payload = {"data": large_data}
|
||||
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
|
||||
assert decoded["data"] == large_data
|
||||
|
||||
def test_should_handle_special_characters_in_payload(self, passport_service):
|
||||
"""Test handling of special characters in payload"""
|
||||
special_payloads = [
|
||||
{"special": "!@#$%^&*()"},
|
||||
{"quotes": 'He said "Hello"'},
|
||||
{"backslash": "path\\to\\file"},
|
||||
{"newline": "line1\nline2"},
|
||||
{"unicode": "🔐🔑🛡️"},
|
||||
{"mixed": "Test123!@#中文🔐"},
|
||||
]
|
||||
|
||||
for payload in special_payloads:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_catch_generic_pyjwt_errors(self, passport_service):
|
||||
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
|
||||
# Mock jwt.decode to raise a generic PyJWTError
|
||||
with patch("libs.passport.jwt.decode") as mock_decode:
|
||||
mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify("some-token")
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class ServiceDbTestHelper:
|
||||
"""
|
||||
Helper class for service database query tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def setup_db_query_filter_by_mock(mock_db, query_results):
|
||||
"""
|
||||
Smart database query mock that responds based on model type and query parameters.
|
||||
|
||||
Args:
|
||||
mock_db: Mock database session
|
||||
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
|
||||
Example: {('Account', 'email', 'test@example.com'): mock_account}
|
||||
"""
|
||||
|
||||
def query_side_effect(model):
|
||||
mock_query = MagicMock()
|
||||
|
||||
def filter_by_side_effect(**kwargs):
|
||||
mock_filter_result = MagicMock()
|
||||
|
||||
def first_side_effect():
|
||||
# Find matching result based on model and filter parameters
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_filter_result.first.side_effect = first_side_effect
|
||||
|
||||
# Handle order_by calls for complex queries
|
||||
def order_by_side_effect(*args, **kwargs):
|
||||
mock_order_result = MagicMock()
|
||||
|
||||
def order_first_side_effect():
|
||||
# Look for order_by results in the same query_results dict
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if (
|
||||
model.__name__ == model_name
|
||||
and filter_key == "order_by"
|
||||
and filter_value == "first_available"
|
||||
):
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_order_result.first.side_effect = order_first_side_effect
|
||||
return mock_order_result
|
||||
|
||||
mock_filter_result.order_by.side_effect = order_by_side_effect
|
||||
return mock_filter_result
|
||||
|
||||
mock_query.filter_by.side_effect = filter_by_side_effect
|
||||
return mock_query
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,18 @@
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import Main from '@/app/components/explore/installed-app'
|
||||
|
||||
export type IInstalledAppProps = {
|
||||
params: Promise<{
|
||||
params: {
|
||||
appId: string
|
||||
}>
|
||||
}
|
||||
}
|
||||
|
||||
const InstalledApp: FC<IInstalledAppProps> = async ({ params }) => {
|
||||
// Using Next.js page convention for async server components
|
||||
async function InstalledApp({ params }: IInstalledAppProps) {
|
||||
const appId = (await params).appId
|
||||
return (
|
||||
<Main id={(await params).appId} />
|
||||
<Main id={appId} />
|
||||
)
|
||||
}
|
||||
export default React.memo(InstalledApp)
|
||||
|
||||
export default InstalledApp
|
||||
|
||||
@ -0,0 +1,84 @@
|
||||
'use client'
|
||||
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { removeAccessToken } from '@/app/components/share/utils'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useGetUserCanAccessApp } from '@/service/access-control'
|
||||
import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import React, { useCallback, useEffect } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
const { t } = useTranslation()
|
||||
const updateAppInfo = useWebAppStore(s => s.updateAppInfo)
|
||||
const updateAppParams = useWebAppStore(s => s.updateAppParams)
|
||||
const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta)
|
||||
const updateUserCanAccessApp = useWebAppStore(s => s.updateUserCanAccessApp)
|
||||
const { isFetching: isFetchingAppParams, data: appParams, error: appParamsError } = useGetWebAppParams()
|
||||
const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo()
|
||||
const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta()
|
||||
const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false })
|
||||
|
||||
useEffect(() => {
|
||||
if (appInfo)
|
||||
updateAppInfo(appInfo)
|
||||
if (appParams)
|
||||
updateAppParams(appParams)
|
||||
if (appMeta)
|
||||
updateWebAppMeta(appMeta)
|
||||
updateUserCanAccessApp(Boolean(userCanAccessApp && userCanAccessApp?.result))
|
||||
}, [appInfo, appMeta, appParams, updateAppInfo, updateAppParams, updateUserCanAccessApp, updateWebAppMeta, userCanAccessApp])
|
||||
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const searchParams = useSearchParams()
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
params.set('redirect_url', pathname)
|
||||
return `/webapp-signin?${params.toString()}`
|
||||
}, [searchParams, pathname])
|
||||
|
||||
const backToHome = useCallback(() => {
|
||||
removeAccessToken()
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router])
|
||||
|
||||
if (appInfoError) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable unknownReason={appInfoError.message} />
|
||||
</div>
|
||||
}
|
||||
if (appParamsError) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable unknownReason={appParamsError.message} />
|
||||
</div>
|
||||
}
|
||||
if (appMetaError) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable unknownReason={appMetaError.message} />
|
||||
</div>
|
||||
}
|
||||
if (useCanAccessAppError) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<AppUnavailable unknownReason={useCanAccessAppError.message} />
|
||||
</div>
|
||||
}
|
||||
if (userCanAccessApp && !userCanAccessApp.result) {
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-2'>
|
||||
<AppUnavailable className='h-auto w-auto' code={403} unknownReason='no permission.' />
|
||||
<span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{t('common.userProfile.logout')}</span>
|
||||
</div>
|
||||
}
|
||||
if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
export default React.memo(AuthenticatedLayout)
|
||||
@ -0,0 +1,80 @@
|
||||
'use client'
|
||||
import type { FC, PropsWithChildren } from 'react'
|
||||
import { useEffect } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import AppUnavailable from '@/app/components/base/app-unavailable'
|
||||
import { checkOrSetAccessToken, removeAccessToken, setAccessToken } from '@/app/components/share/utils'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { fetchAccessToken } from '@/service/share'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
|
||||
const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
const { t } = useTranslation()
|
||||
const shareCode = useWebAppStore(s => s.shareCode)
|
||||
const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const redirectUrl = searchParams.get('redirect_url')
|
||||
const tokenFromUrl = searchParams.get('web_sso_token')
|
||||
const message = searchParams.get('message')
|
||||
const code = searchParams.get('code')
|
||||
const getSigninUrl = useCallback(() => {
|
||||
const params = new URLSearchParams(searchParams)
|
||||
params.delete('message')
|
||||
params.delete('code')
|
||||
return `/webapp-signin?${params.toString()}`
|
||||
}, [searchParams])
|
||||
|
||||
const backToHome = useCallback(() => {
|
||||
removeAccessToken()
|
||||
const url = getSigninUrl()
|
||||
router.replace(url)
|
||||
}, [getSigninUrl, router])
|
||||
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
if (message)
|
||||
return
|
||||
if (shareCode && tokenFromUrl && redirectUrl) {
|
||||
localStorage.setItem('webapp_access_token', tokenFromUrl)
|
||||
const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: tokenFromUrl })
|
||||
await setAccessToken(shareCode, tokenResp.access_token)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
return
|
||||
}
|
||||
if (shareCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
|
||||
const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
|
||||
await setAccessToken(shareCode, tokenResp.access_token)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
return
|
||||
}
|
||||
if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) {
|
||||
await checkOrSetAccessToken(shareCode)
|
||||
router.replace(decodeURIComponent(redirectUrl))
|
||||
}
|
||||
})()
|
||||
}, [shareCode, redirectUrl, router, tokenFromUrl, message, webAppAccessMode])
|
||||
|
||||
if (message) {
|
||||
return <div className='flex h-full flex-col items-center justify-center gap-y-4'>
|
||||
<AppUnavailable className='h-auto w-auto' code={code || t('share.common.appUnavailable')} unknownReason={message} />
|
||||
<span className='system-sm-regular cursor-pointer text-text-tertiary' onClick={backToHome}>{code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')}</span>
|
||||
</div>
|
||||
}
|
||||
if (tokenFromUrl) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) {
|
||||
return <div className='flex h-full items-center justify-center'>
|
||||
<Loading />
|
||||
</div>
|
||||
}
|
||||
return <>{children}</>
|
||||
}
|
||||
|
||||
export default Splash
|
||||
@ -0,0 +1,374 @@
|
||||
import React, { useState } from 'react'
|
||||
import { Trans, useTranslation } from 'react-i18next'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import {
|
||||
checkEmailExisted,
|
||||
logout,
|
||||
resetEmail,
|
||||
sendVerifyCode,
|
||||
verifyEmail,
|
||||
} from '@/service/common'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
type Props = {
|
||||
show: boolean
|
||||
onClose: () => void
|
||||
email: string
|
||||
}
|
||||
|
||||
enum STEP {
|
||||
start = 'start',
|
||||
verifyOrigin = 'verifyOrigin',
|
||||
newEmail = 'newEmail',
|
||||
verifyNew = 'verifyNew',
|
||||
}
|
||||
|
||||
const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const router = useRouter()
|
||||
const [step, setStep] = useState<STEP>(STEP.start)
|
||||
const [code, setCode] = useState<string>('')
|
||||
const [mail, setMail] = useState<string>('')
|
||||
const [time, setTime] = useState<number>(0)
|
||||
const [stepToken, setStepToken] = useState<string>('')
|
||||
const [newEmailExited, setNewEmailExited] = useState<boolean>(false)
|
||||
|
||||
const startCount = () => {
|
||||
setTime(60)
|
||||
const timer = setInterval(() => {
|
||||
setTime((prev) => {
|
||||
if (prev <= 0) {
|
||||
clearInterval(timer)
|
||||
return 0
|
||||
}
|
||||
return prev - 1
|
||||
})
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
if (res.data)
|
||||
setStepToken(res.data)
|
||||
}
|
||||
catch (error) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: `Error sending verification code: ${error ? (error as any).message : ''}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const verifyEmailAddress = async (email: string, code: string, token: string, callback?: () => void) => {
|
||||
try {
|
||||
const res = await verifyEmail({
|
||||
email,
|
||||
code,
|
||||
token,
|
||||
})
|
||||
if (res.is_valid) {
|
||||
setStepToken(res.token)
|
||||
callback?.()
|
||||
}
|
||||
else {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: 'Verifying email failed',
|
||||
})
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: `Error verifying email: ${error ? (error as any).message : ''}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
|
||||
const handleVerifyOriginEmail = async () => {
|
||||
await verifyEmailAddress(email, code, stepToken, () => setStep(STEP.newEmail))
|
||||
setCode('')
|
||||
}
|
||||
|
||||
const isValidEmail = (email: string): boolean => {
|
||||
const emailRegex = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$/
|
||||
return emailRegex.test(email)
|
||||
}
|
||||
|
||||
const checkNewEmailExisted = async (email: string) => {
|
||||
try {
|
||||
await checkEmailExisted({
|
||||
email,
|
||||
})
|
||||
setNewEmailExited(false)
|
||||
}
|
||||
catch (error) {
|
||||
setNewEmailExited(false)
|
||||
if ((error as any)?.code === 'email_already_in_use') {
|
||||
setNewEmailExited(true)
|
||||
}
|
||||
else {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: `Error checking email existence: ${error ? (error as any).message : ''}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const handleNewEmailValueChange = (mailAddress: string) => {
|
||||
setMail(mailAddress)
|
||||
if (isValidEmail(mailAddress))
|
||||
checkNewEmailExisted(mailAddress)
|
||||
}
|
||||
|
||||
const sendCodeToNewEmail = async () => {
|
||||
if (!isValidEmail(mail)) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: 'Invalid email format',
|
||||
})
|
||||
return
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
}
|
||||
|
||||
const handleLogout = async () => {
|
||||
await logout({
|
||||
url: '/logout',
|
||||
params: {},
|
||||
})
|
||||
|
||||
localStorage.removeItem('setup_status')
|
||||
localStorage.removeItem('console_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
|
||||
router.push('/signin')
|
||||
}
|
||||
|
||||
const updateEmail = async () => {
|
||||
try {
|
||||
await resetEmail({
|
||||
new_email: mail,
|
||||
token: stepToken,
|
||||
})
|
||||
handleLogout()
|
||||
}
|
||||
catch (error) {
|
||||
notify({
|
||||
type: 'error',
|
||||
message: `Error changing email: ${error ? (error as any).message : ''}`,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const submitNewEmail = async () => {
|
||||
await verifyEmailAddress(mail, code, stepToken, () => updateEmail())
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal
|
||||
isShow={show}
|
||||
onClose={noop}
|
||||
className='!w-[420px] !p-6'
|
||||
>
|
||||
<div className='absolute right-5 top-5 cursor-pointer p-1.5' onClick={onClose}>
|
||||
<RiCloseLine className='h-5 w-5 text-text-tertiary' />
|
||||
</div>
|
||||
{step === STEP.start && (
|
||||
<>
|
||||
<div className='title-2xl-semi-bold pb-3 text-text-primary'>{t('common.account.changeEmail.title')}</div>
|
||||
<div className='space-y-0.5 pb-2 pt-1'>
|
||||
<div className='body-md-medium text-text-warning'>{t('common.account.changeEmail.authTip')}</div>
|
||||
<div className='body-md-regular text-text-secondary'>
|
||||
<Trans
|
||||
i18nKey="common.account.changeEmail.content1"
|
||||
components={{ email: <span className='body-md-medium text-text-primary'></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className='pt-3'></div>
|
||||
<div className='space-y-2'>
|
||||
<Button
|
||||
className='!w-full'
|
||||
variant='primary'
|
||||
onClick={sendCodeToOriginEmail}
|
||||
>
|
||||
{t('common.account.changeEmail.sendVerifyCode')}
|
||||
</Button>
|
||||
<Button
|
||||
className='!w-full'
|
||||
onClick={onClose}
|
||||
>
|
||||
{t('common.operation.cancel')}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.verifyOrigin && (
|
||||
<>
|
||||
<div className='title-2xl-semi-bold pb-3 text-text-primary'>{t('common.account.changeEmail.verifyEmail')}</div>
|
||||
<div className='space-y-0.5 pb-2 pt-1'>
|
||||
<div className='body-md-regular text-text-secondary'>
|
||||
<Trans
|
||||
i18nKey="common.account.changeEmail.content2"
|
||||
components={{ email: <span className='body-md-medium text-text-primary'></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className='pt-3'>
|
||||
<div className='system-sm-medium mb-1 flex h-6 items-center text-text-secondary'>{t('common.account.changeEmail.codeLabel')}</div>
|
||||
<Input
|
||||
className='!w-full'
|
||||
placeholder={t('common.account.changeEmail.codePlaceholder')}
|
||||
value={code}
|
||||
onChange={e => setCode(e.target.value)}
|
||||
maxLength={6}
|
||||
/>
|
||||
</div>
|
||||
<div className='mt-3 space-y-2'>
|
||||
<Button
|
||||
disabled={code.length !== 6}
|
||||
className='!w-full'
|
||||
variant='primary'
|
||||
onClick={handleVerifyOriginEmail}
|
||||
>
|
||||
{t('common.account.changeEmail.continue')}
|
||||
</Button>
|
||||
<Button
|
||||
className='!w-full'
|
||||
onClick={onClose}
|
||||
>
|
||||
{t('common.operation.cancel')}
|
||||
</Button>
|
||||
</div>
|
||||
<div className='system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary'>
|
||||
<span>{t('common.account.changeEmail.resendTip')}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('common.account.changeEmail.resendCount', { count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToOriginEmail} className='system-xs-medium cursor-pointer text-text-accent-secondary'>{t('common.account.changeEmail.resend')}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.newEmail && (
|
||||
<>
|
||||
<div className='title-2xl-semi-bold pb-3 text-text-primary'>{t('common.account.changeEmail.newEmail')}</div>
|
||||
<div className='space-y-0.5 pb-2 pt-1'>
|
||||
<div className='body-md-regular text-text-secondary'>{t('common.account.changeEmail.content3')}</div>
|
||||
</div>
|
||||
<div className='pt-3'>
|
||||
<div className='system-sm-medium mb-1 flex h-6 items-center text-text-secondary'>{t('common.account.changeEmail.emailLabel')}</div>
|
||||
<Input
|
||||
className='!w-full'
|
||||
placeholder={t('common.account.changeEmail.emailPlaceholder')}
|
||||
value={mail}
|
||||
onChange={e => handleNewEmailValueChange(e.target.value)}
|
||||
destructive={newEmailExited}
|
||||
/>
|
||||
{newEmailExited && (
|
||||
<div className='body-xs-regular mt-1 py-0.5 text-text-destructive'>{t('common.account.changeEmail.existingEmail')}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className='mt-3 space-y-2'>
|
||||
<Button
|
||||
disabled={!mail}
|
||||
className='!w-full'
|
||||
variant='primary'
|
||||
onClick={sendCodeToNewEmail}
|
||||
>
|
||||
{t('common.account.changeEmail.sendVerifyCode')}
|
||||
</Button>
|
||||
<Button
|
||||
className='!w-full'
|
||||
onClick={onClose}
|
||||
>
|
||||
{t('common.operation.cancel')}
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.verifyNew && (
|
||||
<>
|
||||
<div className='title-2xl-semi-bold pb-3 text-text-primary'>{t('common.account.changeEmail.verifyNew')}</div>
|
||||
<div className='space-y-0.5 pb-2 pt-1'>
|
||||
<div className='body-md-regular text-text-secondary'>
|
||||
<Trans
|
||||
i18nKey="common.account.changeEmail.content4"
|
||||
components={{ email: <span className='body-md-medium text-text-primary'></span> }}
|
||||
values={{ email: mail }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className='pt-3'>
|
||||
<div className='system-sm-medium mb-1 flex h-6 items-center text-text-secondary'>{t('common.account.changeEmail.codeLabel')}</div>
|
||||
<Input
|
||||
className='!w-full'
|
||||
placeholder={t('common.account.changeEmail.codePlaceholder')}
|
||||
value={code}
|
||||
onChange={e => setCode(e.target.value)}
|
||||
maxLength={6}
|
||||
/>
|
||||
</div>
|
||||
<div className='mt-3 space-y-2'>
|
||||
<Button
|
||||
disabled={code.length !== 6}
|
||||
className='!w-full'
|
||||
variant='primary'
|
||||
onClick={submitNewEmail}
|
||||
>
|
||||
{t('common.account.changeEmail.changeTo', { email: mail })}
|
||||
</Button>
|
||||
<Button
|
||||
className='!w-full'
|
||||
onClick={onClose}
|
||||
>
|
||||
{t('common.operation.cancel')}
|
||||
</Button>
|
||||
</div>
|
||||
<div className='system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary'>
|
||||
<span>{t('common.account.changeEmail.resendTip')}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('common.account.changeEmail.resendCount', { count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToNewEmail} className='system-xs-medium cursor-pointer text-text-accent-secondary'>{t('common.account.changeEmail.resend')}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
export default EmailChangeModal
|
||||
@ -1,9 +0,0 @@
|
||||
.modal {
|
||||
padding: 24px 32px !important;
|
||||
width: 400px !important;
|
||||
}
|
||||
|
||||
.bg {
|
||||
background: linear-gradient(180deg, rgba(217, 45, 32, 0.05) 0%, rgba(217, 45, 32, 0.00) 24.02%), #F9FAFB;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue