feat: plugin OAuth with stateful
parent
366ddb05ae
commit
12c20ec7f6
@ -0,0 +1,67 @@
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from dify_plugin import ToolProvider
|
||||
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
|
||||
from werkzeug import Request
|
||||
|
||||
|
||||
class GithubProvider(ToolProvider):
|
||||
_AUTH_URL = "https://github.com/login/oauth/authorize"
|
||||
_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
_API_USER_URL = "https://api.github.com/user"
|
||||
|
||||
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Generate the authorization URL for the Github OAuth.
|
||||
"""
|
||||
state = secrets.token_urlsafe(16)
|
||||
params = {
|
||||
"client_id": system_credentials["client_id"],
|
||||
"redirect_uri": system_credentials["redirect_uri"],
|
||||
"scope": system_credentials.get("scope", "read:user"),
|
||||
"state": state,
|
||||
# Optionally: allow_signup, login, etc.
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
|
||||
"""
|
||||
Exchange code for access_token.
|
||||
"""
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
if not code:
|
||||
raise ValueError("No code provided")
|
||||
# Optionally: validate state here
|
||||
|
||||
data = {
|
||||
"client_id": system_credentials["client_id"],
|
||||
"client_secret": system_credentials["client_secret"],
|
||||
"code": code,
|
||||
"redirect_uri": system_credentials["redirect_uri"],
|
||||
}
|
||||
headers = {"Accept": "application/json"}
|
||||
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
|
||||
response_json = response.json()
|
||||
access_token = response_json.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError(f"Error in GitHub OAuth: {response_json}")
|
||||
return {"access_token": access_token}
|
||||
|
||||
def _validate_credentials(self, credentials: dict) -> None:
|
||||
try:
|
||||
if "access_token" not in credentials or not credentials.get("access_token"):
|
||||
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credentials['access_token']}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
}
|
||||
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderCredentialValidationError(response.json().get("message"))
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
@ -0,0 +1,66 @@
|
||||
"""add tool oauth credentials
|
||||
|
||||
Revision ID: 99310d2c25a6
|
||||
Revises: 4474872b0ee6
|
||||
Create Date: 2025-06-18 15:06:15.261915
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '99310d2c25a6'
|
||||
down_revision = '4474872b0ee6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tool_oauth_system_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
|
||||
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
|
||||
)
|
||||
op.create_table('tool_oauth_user_clients',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client')
|
||||
)
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||
batch_op.alter_column('credential_type',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
type_=sa.String(length=32),
|
||||
existing_nullable=False,
|
||||
existing_server_default=sa.text("'api_key'::character varying"))
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
|
||||
batch_op.alter_column('credential_type',
|
||||
existing_type=sa.String(length=32),
|
||||
type_=sa.VARCHAR(length=255),
|
||||
existing_nullable=False,
|
||||
existing_server_default=sa.text("'api_key'::character varying"))
|
||||
batch_op.drop_column('default')
|
||||
|
||||
op.drop_table('tool_oauth_user_clients')
|
||||
op.drop_table('tool_oauth_system_clients')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,39 @@
|
||||
"""multiple credential
|
||||
|
||||
Revision ID: 222376193a49
|
||||
Revises: 99310d2c25a6
|
||||
Create Date: 2025-06-19 11:33:46.400455
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '222376193a49'
|
||||
down_revision = '99310d2c25a6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
|
||||
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||
batch_op.drop_column('owner_type')
|
||||
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,33 @@
|
||||
"""multiple credential
|
||||
|
||||
Revision ID: a9306e69af07
|
||||
Revises: 222376193a49
|
||||
Create Date: 2025-06-19 13:53:41.554159
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'a9306e69af07'
|
||||
down_revision = '222376193a49'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,33 @@
|
||||
"""multiple credential
|
||||
|
||||
Revision ID: 6835b906335f
|
||||
Revises: e315d2a83984
|
||||
Create Date: 2025-06-19 13:59:58.107955
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '6835b906335f'
|
||||
down_revision = 'e315d2a83984'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,33 @@
|
||||
"""multiple credential
|
||||
|
||||
Revision ID: e315d2a83984
|
||||
Revises: a9306e69af07
|
||||
Create Date: 2025-06-19 13:59:13.860523
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e315d2a83984'
|
||||
down_revision = 'a9306e69af07'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||
batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id'])
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,53 @@
|
||||
"""multiple credential
|
||||
|
||||
Revision ID: 110e30078dd3
|
||||
Revises: 6835b906335f
|
||||
Create Date: 2025-06-19 15:11:42.688478
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '110e30078dd3'
|
||||
down_revision = '6835b906335f'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.UUID(),
|
||||
type_=sa.String(length=512),
|
||||
existing_nullable=False)
|
||||
|
||||
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False))
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.UUID(),
|
||||
type_=sa.String(length=512),
|
||||
existing_nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.String(length=512),
|
||||
type_=sa.UUID(),
|
||||
existing_nullable=False)
|
||||
batch_op.drop_column('enabled')
|
||||
|
||||
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
|
||||
batch_op.alter_column('plugin_id',
|
||||
existing_type=sa.String(length=512),
|
||||
type_=sa.UUID(),
|
||||
existing_nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -1,7 +1,62 @@
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from core.plugin.impl.base import BasePluginClient
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class OAuthProxyService(BasePluginClient):
|
||||
# Default max age for proxy context parameter in seconds
|
||||
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||
|
||||
@staticmethod
|
||||
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
|
||||
"""
|
||||
Create a proxy context for an OAuth 2.0 authorization request.
|
||||
|
||||
This parameter is a crucial security measure to prevent Cross-Site Request
|
||||
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
|
||||
in a distributed cache (Redis) along with the user's session context.
|
||||
|
||||
The returned nonce should be included as the 'proxy_context' parameter in the
|
||||
authorization URL. Upon callback, the `retrieve_proxy_context` method
|
||||
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||
and mitigating replay attacks.
|
||||
"""
|
||||
seconds, microseconds = redis_client.time()
|
||||
context_id = str(uuid.uuid4())
|
||||
data = {
|
||||
"user_id": user_id,
|
||||
"plugin_id": plugin_id,
|
||||
"tenant_id": tenant_id,
|
||||
"provider": provider,
|
||||
# encode redis time to avoid distribution time skew
|
||||
"timestamp": seconds,
|
||||
}
|
||||
# ignore nonce collision
|
||||
redis_client.setex(
|
||||
f"oauth_proxy_context:{context_id}",
|
||||
OAuthProxyService.__MAX_AGE__,
|
||||
json.dumps(data),
|
||||
)
|
||||
return context_id
|
||||
|
||||
|
||||
class OAuthService(BasePluginClient):
|
||||
@classmethod
|
||||
def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
|
||||
return "1234567890"
|
||||
@staticmethod
|
||||
def use_proxy_context(context_id, max_age=__MAX_AGE__):
|
||||
"""
|
||||
Validate the proxy context parameter.
|
||||
This checks if the context_id is valid and not expired.
|
||||
"""
|
||||
if not context_id:
|
||||
raise ValueError("context_id is required")
|
||||
# get data from redis
|
||||
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
|
||||
if not data:
|
||||
raise ValueError("context_id is invalid")
|
||||
# check if data is expired
|
||||
seconds, microseconds = redis_client.time()
|
||||
state = json.loads(data)
|
||||
if state.get("timestamp") < seconds - max_age:
|
||||
raise ValueError("context_id is expired")
|
||||
return state
|
||||
|
||||
@ -0,0 +1,27 @@
|
||||
|
||||
@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q
|
||||
|
||||
|
||||
|
||||
# set default credential
|
||||
POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default
|
||||
Host: 127.0.0.1:5001
|
||||
Content-Type: application/json
|
||||
Authorization: Bearer {{accessToken}}
|
||||
|
||||
{
|
||||
"id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f"
|
||||
}
|
||||
###
|
||||
|
||||
# get oauth url
|
||||
GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github
|
||||
Host: 127.0.0.1:5001
|
||||
Authorization: Bearer {{accessToken}}
|
||||
|
||||
###
|
||||
|
||||
# get oauth token
|
||||
GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw
|
||||
Host: 127.0.0.1:5001
|
||||
Authorization: Bearer {{accessToken}}
|
||||
Loading…
Reference in New Issue