|
|
|
|
@ -1,19 +1,22 @@
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
from flask_login import current_user
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from constants import HIDDEN_VALUE
|
|
|
|
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|
|
|
|
from core.helper import encrypter
|
|
|
|
|
from core.helper.name_generator import generate_incremental_name
|
|
|
|
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
|
|
|
from core.model_runtime.entities.provider_entities import FormType
|
|
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
|
|
from core.plugin.entities.plugin import DatasourceProviderID
|
|
|
|
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
|
|
|
|
from core.tools.entities.tool_entities import CredentialType
|
|
|
|
|
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
from extensions.ext_redis import redis_client
|
|
|
|
|
from models.oauth import DatasourceProvider
|
|
|
|
|
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
@ -26,6 +29,165 @@ class DatasourceProviderService:
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.provider_manager = PluginDatasourceManager()
|
|
|
|
|
|
|
|
|
|
def setup_oauth_custom_client_params(
|
|
|
|
|
self,
|
|
|
|
|
tenant_id: str,
|
|
|
|
|
datasource_provider_id: DatasourceProviderID,
|
|
|
|
|
client_params: dict | None,
|
|
|
|
|
enabled: bool | None,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
setup oauth custom client params
|
|
|
|
|
"""
|
|
|
|
|
if client_params is None and enabled is None:
|
|
|
|
|
return
|
|
|
|
|
provider_controller = PluginDatasourceManager()
|
|
|
|
|
datasource_provider = provider_controller.fetch_datasource_provider(
|
|
|
|
|
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
|
|
|
|
)
|
|
|
|
|
if not datasource_provider.declaration.oauth_schema:
|
|
|
|
|
raise ValueError("Datasource provider oauth schema not found")
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
tenant_oauth_client_params = (
|
|
|
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=datasource_provider_id.provider_name,
|
|
|
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not tenant_oauth_client_params:
|
|
|
|
|
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=datasource_provider_id.provider_name,
|
|
|
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
|
|
|
client_params={},
|
|
|
|
|
enabled=False,
|
|
|
|
|
)
|
|
|
|
|
session.add(tenant_oauth_client_params)
|
|
|
|
|
|
|
|
|
|
if client_params is not None:
|
|
|
|
|
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
|
|
|
|
encrypter, _ = create_provider_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[x.to_basic_provider_config() for x in client_schema],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
original_params = (
|
|
|
|
|
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
|
|
|
|
|
)
|
|
|
|
|
new_params: dict = {
|
|
|
|
|
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
|
|
|
|
for key, value in client_params.items()
|
|
|
|
|
}
|
|
|
|
|
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
|
|
|
|
|
|
|
|
|
|
if enabled is not None:
|
|
|
|
|
tenant_oauth_client_params.enabled = enabled
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
check if system oauth params exist
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
return (
|
|
|
|
|
session.query(DatasourceOauthParamConfig)
|
|
|
|
|
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
|
|
|
|
|
.first()
|
|
|
|
|
is not None
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
|
|
|
|
|
"""
|
|
|
|
|
check if tenant oauth params is enabled
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
return (
|
|
|
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=datasource_provider_id.provider_name,
|
|
|
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
|
|
|
enabled=True,
|
|
|
|
|
)
|
|
|
|
|
.count()
|
|
|
|
|
> 0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_tenant_oauth_client(
|
|
|
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
|
|
|
|
) -> dict[str, Any] | None:
|
|
|
|
|
"""
|
|
|
|
|
get tenant oauth client
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
tenant_oauth_client_params = (
|
|
|
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=datasource_provider_id.provider_name,
|
|
|
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if tenant_oauth_client_params:
|
|
|
|
|
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
|
|
|
|
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def get_oauth_encrypter(
|
|
|
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
|
|
|
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
|
|
|
|
"""
|
|
|
|
|
get oauth encrypter
|
|
|
|
|
"""
|
|
|
|
|
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
|
|
|
|
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
|
|
|
|
)
|
|
|
|
|
if not datasource_provider.declaration.oauth_schema:
|
|
|
|
|
raise ValueError("Datasource provider oauth schema not found")
|
|
|
|
|
|
|
|
|
|
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
|
|
|
|
return create_provider_encrypter(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
config=[x.to_basic_provider_config() for x in client_schema],
|
|
|
|
|
cache=NoOpProviderCredentialCache(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
|
|
|
|
|
"""
|
|
|
|
|
get oauth client
|
|
|
|
|
"""
|
|
|
|
|
provider = datasource_provider_id.provider_name
|
|
|
|
|
plugin_id = datasource_provider_id.plugin_id
|
|
|
|
|
with Session(db.engine).no_autoflush as session:
|
|
|
|
|
# get tenant oauth client params
|
|
|
|
|
tenant_oauth_client_params = (
|
|
|
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=provider,
|
|
|
|
|
plugin_id=plugin_id,
|
|
|
|
|
enabled=True,
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
if tenant_oauth_client_params:
|
|
|
|
|
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
|
|
|
|
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
|
|
|
|
|
|
|
|
|
# fallback to system oauth client params
|
|
|
|
|
oauth_client_params = (
|
|
|
|
|
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
|
|
|
|
)
|
|
|
|
|
if oauth_client_params:
|
|
|
|
|
return oauth_client_params.system_credentials
|
|
|
|
|
|
|
|
|
|
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def generate_next_datasource_provider_name(
|
|
|
|
|
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
|
|
|
|
@ -69,24 +231,29 @@ class DatasourceProviderService:
|
|
|
|
|
credential_type=credential_type,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
if session.query(DatasourceProvider).filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
name=db_provider_name,
|
|
|
|
|
provider=provider_id.provider_name,
|
|
|
|
|
plugin_id=provider_id.plugin_id,
|
|
|
|
|
auth_type=credential_type.value,
|
|
|
|
|
).count() > 0:
|
|
|
|
|
if (
|
|
|
|
|
session.query(DatasourceProvider)
|
|
|
|
|
.filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
name=db_provider_name,
|
|
|
|
|
provider=provider_id.provider_name,
|
|
|
|
|
plugin_id=provider_id.plugin_id,
|
|
|
|
|
auth_type=credential_type.value,
|
|
|
|
|
)
|
|
|
|
|
.count()
|
|
|
|
|
> 0
|
|
|
|
|
):
|
|
|
|
|
db_provider_name = generate_incremental_name(
|
|
|
|
|
[
|
|
|
|
|
provider.name
|
|
|
|
|
for provider in session.query(DatasourceProvider).filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=provider_id.provider_name,
|
|
|
|
|
plugin_id=provider_id.plugin_id,
|
|
|
|
|
)
|
|
|
|
|
],
|
|
|
|
|
db_provider_name,
|
|
|
|
|
)
|
|
|
|
|
[
|
|
|
|
|
provider.name
|
|
|
|
|
for provider in session.query(DatasourceProvider).filter_by(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider=provider_id.provider_name,
|
|
|
|
|
plugin_id=provider_id.plugin_id,
|
|
|
|
|
)
|
|
|
|
|
],
|
|
|
|
|
db_provider_name,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
|
|
|
|
@ -103,7 +270,7 @@ class DatasourceProviderService:
|
|
|
|
|
plugin_id=provider_id.plugin_id,
|
|
|
|
|
auth_type=credential_type.value,
|
|
|
|
|
encrypted_credentials=credentials,
|
|
|
|
|
avatar_url=avatar_url,
|
|
|
|
|
avatar_url=avatar_url or "default",
|
|
|
|
|
)
|
|
|
|
|
session.add(datasource_provider)
|
|
|
|
|
session.commit()
|
|
|
|
|
@ -222,6 +389,7 @@ class DatasourceProviderService:
|
|
|
|
|
"credential": copy_credentials,
|
|
|
|
|
"type": datasource_provider.auth_type,
|
|
|
|
|
"name": datasource_provider.name,
|
|
|
|
|
"avatar_url": datasource_provider.avatar_url,
|
|
|
|
|
"id": datasource_provider.id,
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
@ -239,6 +407,7 @@ class DatasourceProviderService:
|
|
|
|
|
datasources = manager.fetch_installed_datasource_providers(tenant_id)
|
|
|
|
|
datasource_credentials = []
|
|
|
|
|
for datasource in datasources:
|
|
|
|
|
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
|
|
|
|
credentials = self.get_datasource_credentials(
|
|
|
|
|
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
|
|
|
|
)
|
|
|
|
|
@ -302,6 +471,11 @@ class DatasourceProviderService:
|
|
|
|
|
}
|
|
|
|
|
for credential in datasource.declaration.oauth_schema.credentials_schema or []
|
|
|
|
|
],
|
|
|
|
|
"oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
|
|
|
|
|
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
|
|
|
|
|
tenant_id, datasource_provider_id
|
|
|
|
|
),
|
|
|
|
|
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
|
|
|
|
}
|
|
|
|
|
if datasource.declaration.oauth_schema
|
|
|
|
|
else None,
|
|
|
|
|
|