From fb9e4a422735f3e0562058fe60cf57564f4bec6b Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 11 Jul 2025 16:05:11 +0800 Subject: [PATCH 1/2] feat(oauth): migrations --- .../versions/2025_05_15_1635-16081485540c_.py | 41 +++++++++++++++++++ ...c_merge_tool_oauth_and_remove_sequence_.py | 25 ----------- ...025_07_04_1705-71f5020c6470_tool_oauth.py} | 12 ++---- 3 files changed, 45 insertions(+), 33 deletions(-) create mode 100644 api/migrations/versions/2025_05_15_1635-16081485540c_.py delete mode 100644 api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py rename api/migrations/versions/{2025_06_24_1705-71f5020c6470_tool_oauth.py => 2025_07_04_1705-71f5020c6470_tool_oauth.py} (85%) diff --git a/api/migrations/versions/2025_05_15_1635-16081485540c_.py b/api/migrations/versions/2025_05_15_1635-16081485540c_.py new file mode 100644 index 0000000000..70ed771391 --- /dev/null +++ b/api/migrations/versions/2025_05_15_1635-16081485540c_.py @@ -0,0 +1,41 @@ +"""empty message + +Revision ID: 16081485540c +Revises: d28f2004b072 +Create Date: 2025-05-15 16:35:39.113777 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '16081485540c' +down_revision = '58eb7bdb93fe' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tenant_plugin_auto_upgrade_strategies', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False), + sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False), + sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False), + sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'), + sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('tenant_plugin_auto_upgrade_strategies') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py b/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py deleted file mode 100644 index a3c51e7e75..0000000000 --- a/api/migrations/versions/2025_06_25_1101-46d46b3f389c_merge_tool_oauth_and_remove_sequence_.py +++ /dev/null @@ -1,25 +0,0 @@ -"""merge tool oauth and remove sequence number branches - -Revision ID: 46d46b3f389c -Revises: 0ab65e1cc7fa, 71f5020c6470 -Create Date: 2025-06-25 11:01:55.215896 - -""" -from alembic import op -import models as models -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '46d46b3f389c' -down_revision = ('0ab65e1cc7fa', '71f5020c6470') -branch_labels = None -depends_on = None - - -def upgrade(): - pass - - -def downgrade(): - pass diff --git a/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py similarity index 85% rename from api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py rename to api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py index ffb4fffe56..32cc08ab1a 100644 --- a/api/migrations/versions/2025_06_24_1705-71f5020c6470_tool_oauth.py +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = '71f5020c6470' -down_revision = '4474872b0ee6' +down_revision = '16081485540c' branch_labels = None depends_on = None @@ -37,29 +37,25 @@ def upgrade(): sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') ) - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique') with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) - batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api_key'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') - + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) batch_op.drop_column('credential_type') batch_op.drop_column('is_default') batch_op.drop_column('name') - with op.batch_alter_table('tool_api_providers', schema=None) as batch_op: - batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id']) - op.drop_table('tool_oauth_tenant_clients') op.drop_table('tool_oauth_system_clients') # ### end Alembic commands ### From adc39f7b0db283e514d930be9a29c0345eb5a906 Mon Sep 17 00:00:00 2001 From: Harry Date: Fri, 11 Jul 2025 16:28:40 +0800 Subject: [PATCH 2/2] feat(oauth): enhance OAuth client management and validation --- ...2025_07_04_1705-71f5020c6470_tool_oauth.py | 1 + api/models/tools.py | 18 +++++---- .../tools/builtin_tools_manage_service.py | 40 ++++++++++++++++--- api/services/tools/mcp_tools_mange_service.py | 7 ++-- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py index 32cc08ab1a..ad73563246 100644 --- a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -44,6 +44,7 @@ def upgrade(): batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index 05a4920a9c..34bc97d006 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -109,7 +109,10 @@ class ApiToolProvider(Base): """ __tablename__ = "tool_api_providers" - __table_args__ = (db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"),) + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_api_provider_pkey"), + db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"), + ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider @@ -326,18 +329,17 @@ class MCPToolProvider(Base): @property def decrypted_credentials(self) -> dict: + from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.configuration import ProviderConfigEncrypter + from core.tools.utils.encryption import create_provider_encrypter provider_controller = MCPToolProviderController._from_db(self) - tool_configuration = ProviderConfigEncrypter( + return create_provider_encrypter( tenant_id=self.tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, - ) - return tool_configuration.decrypt(self.credentials, use_cache=False) + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), + )[0].decrypt(self.credentials) class ToolModelInvoke(Base): diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index fea74ba492..66157fb6b6 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -43,12 +43,22 @@ class BuiltinToolManageService: get builtin tool provider oauth client schema """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return { + + is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ) + is_system_oauth_params_exists = BuiltinToolManageService.is_oauth_system_client_exists(provider_name) + result = { "schema": provider.get_oauth_client_schema(), - "is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled( - tenant_id, provider_name - ), + "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled, + "is_system_oauth_params_exists": is_system_oauth_params_exists, } + if is_oauth_custom_client_enabled: + result["client_params"] = BuiltinToolManageService.get_oauth_client(tenant_id, provider_name) + result["redirect_uri"] = ( + f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback" + ) + return result @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: @@ -415,6 +425,20 @@ class BuiltinToolManageService: session.commit() return {"result": "success"} + @staticmethod + def is_oauth_system_client_exists(provider_name: str) -> bool: + """ + check if oauth system client exists + """ + tool_provider = ToolProviderID(provider_name) + with Session(db.engine).no_autoflush as session: + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + return system_client is not None + @staticmethod def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: """ @@ -685,4 +709,10 @@ class BuiltinToolManageService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) + + return { + "oauth_params": encrypter.mask_tool_credentials( + encrypter.decrypt(custom_oauth_client_params.oauth_params) + ), + "enabled": custom_oauth_client_params.enabled, + } diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_mange_service.py index 7c23abda4b..fda6da5983 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_mange_service.py @@ -7,13 +7,14 @@ from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -69,6 +70,7 @@ class MCPToolManageService: MCPToolProvider.server_url_hash == server_url_hash, MCPToolProvider.server_identifier == server_identifier, ), + MCPToolProvider.tenant_id == tenant_id, ) .first() ) @@ -197,8 +199,7 @@ class MCPToolManageService: tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + provider_config_cache=NoOpProviderCredentialCache(), ) credentials = tool_configuration.encrypt(credentials) mcp_provider.updated_at = datetime.now()