feat: plugin OAuth with stateful
parent
366ddb05ae
commit
12c20ec7f6
@ -0,0 +1,67 @@
|
|||||||
|
import secrets
|
||||||
|
import urllib.parse
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from dify_plugin import ToolProvider
|
||||||
|
from dify_plugin.errors.tool import ToolProviderCredentialValidationError
|
||||||
|
from werkzeug import Request
|
||||||
|
|
||||||
|
|
||||||
|
class GithubProvider(ToolProvider):
|
||||||
|
_AUTH_URL = "https://github.com/login/oauth/authorize"
|
||||||
|
_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||||
|
_API_USER_URL = "https://api.github.com/user"
|
||||||
|
|
||||||
|
def _oauth_get_authorization_url(self, system_credentials: Mapping[str, Any]) -> str:
|
||||||
|
"""
|
||||||
|
Generate the authorization URL for the Github OAuth.
|
||||||
|
"""
|
||||||
|
state = secrets.token_urlsafe(16)
|
||||||
|
params = {
|
||||||
|
"client_id": system_credentials["client_id"],
|
||||||
|
"redirect_uri": system_credentials["redirect_uri"],
|
||||||
|
"scope": system_credentials.get("scope", "read:user"),
|
||||||
|
"state": state,
|
||||||
|
# Optionally: allow_signup, login, etc.
|
||||||
|
}
|
||||||
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
def _oauth_get_credentials(self, system_credentials: Mapping[str, Any], request: Request) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Exchange code for access_token.
|
||||||
|
"""
|
||||||
|
code = request.args.get("code")
|
||||||
|
state = request.args.get("state")
|
||||||
|
if not code:
|
||||||
|
raise ValueError("No code provided")
|
||||||
|
# Optionally: validate state here
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"client_id": system_credentials["client_id"],
|
||||||
|
"client_secret": system_credentials["client_secret"],
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": system_credentials["redirect_uri"],
|
||||||
|
}
|
||||||
|
headers = {"Accept": "application/json"}
|
||||||
|
response = requests.post(self._TOKEN_URL, data=data, headers=headers, timeout=10)
|
||||||
|
response_json = response.json()
|
||||||
|
access_token = response_json.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise ValueError(f"Error in GitHub OAuth: {response_json}")
|
||||||
|
return {"access_token": access_token}
|
||||||
|
|
||||||
|
def _validate_credentials(self, credentials: dict) -> None:
|
||||||
|
try:
|
||||||
|
if "access_token" not in credentials or not credentials.get("access_token"):
|
||||||
|
raise ToolProviderCredentialValidationError("GitHub API Access Token is required.")
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {credentials['access_token']}",
|
||||||
|
"Accept": "application/vnd.github+json",
|
||||||
|
}
|
||||||
|
response = requests.get(self._API_USER_URL, headers=headers, timeout=10)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ToolProviderCredentialValidationError(response.json().get("message"))
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
||||||
@ -0,0 +1,66 @@
|
|||||||
|
"""add tool oauth credentials
|
||||||
|
|
||||||
|
Revision ID: 99310d2c25a6
|
||||||
|
Revises: 4474872b0ee6
|
||||||
|
Create Date: 2025-06-18 15:06:15.261915
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '99310d2c25a6'
|
||||||
|
down_revision = '4474872b0ee6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tool_oauth_system_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
|
||||||
|
sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
|
||||||
|
)
|
||||||
|
op.create_table('tool_oauth_user_clients',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_oauth_user_client_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_user_client')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
|
||||||
|
batch_op.alter_column('credential_type',
|
||||||
|
existing_type=sa.VARCHAR(length=255),
|
||||||
|
type_=sa.String(length=32),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text("'api_key'::character varying"))
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||||
|
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['tenant_id', 'provider', 'credential_type'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
|
||||||
|
batch_op.alter_column('credential_type',
|
||||||
|
existing_type=sa.String(length=32),
|
||||||
|
type_=sa.VARCHAR(length=255),
|
||||||
|
existing_nullable=False,
|
||||||
|
existing_server_default=sa.text("'api_key'::character varying"))
|
||||||
|
batch_op.drop_column('default')
|
||||||
|
|
||||||
|
op.drop_table('tool_oauth_user_clients')
|
||||||
|
op.drop_table('tool_oauth_system_clients')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
"""multiple credential
|
||||||
|
|
||||||
|
Revision ID: 222376193a49
|
||||||
|
Revises: 99310d2c25a6
|
||||||
|
Create Date: 2025-06-19 11:33:46.400455
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '222376193a49'
|
||||||
|
down_revision = '99310d2c25a6'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('owner_type', sa.Text(), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('owner_type')
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'credential_type'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""multiple credential
|
||||||
|
|
||||||
|
Revision ID: a9306e69af07
|
||||||
|
Revises: 222376193a49
|
||||||
|
Create Date: 2025-06-19 13:53:41.554159
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'a9306e69af07'
|
||||||
|
down_revision = '222376193a49'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.create_unique_constraint('unique_builtin_tool_provider', ['provider', 'tenant_id', 'default'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint('unique_builtin_tool_provider', type_='unique')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""multiple credential
|
||||||
|
|
||||||
|
Revision ID: 6835b906335f
|
||||||
|
Revises: e315d2a83984
|
||||||
|
Create Date: 2025-06-19 13:59:58.107955
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '6835b906335f'
|
||||||
|
down_revision = 'e315d2a83984'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['provider', 'tenant_id', 'default'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""multiple credential
|
||||||
|
|
||||||
|
Revision ID: e315d2a83984
|
||||||
|
Revises: a9306e69af07
|
||||||
|
Create Date: 2025-06-19 13:59:13.860523
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'e315d2a83984'
|
||||||
|
down_revision = 'a9306e69af07'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||||
|
batch_op.drop_constraint(batch_op.f('unique_api_tool_provider'), type_='unique')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
|
||||||
|
batch_op.create_unique_constraint(batch_op.f('unique_api_tool_provider'), ['name', 'tenant_id'])
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
"""multiple credential
|
||||||
|
|
||||||
|
Revision ID: 110e30078dd3
|
||||||
|
Revises: 6835b906335f
|
||||||
|
Create Date: 2025-06-19 15:11:42.688478
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '110e30078dd3'
|
||||||
|
down_revision = '6835b906335f'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('plugin_id',
|
||||||
|
existing_type=sa.UUID(),
|
||||||
|
type_=sa.String(length=512),
|
||||||
|
existing_nullable=False)
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False))
|
||||||
|
batch_op.alter_column('plugin_id',
|
||||||
|
existing_type=sa.UUID(),
|
||||||
|
type_=sa.String(length=512),
|
||||||
|
existing_nullable=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('tool_oauth_user_clients', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('plugin_id',
|
||||||
|
existing_type=sa.String(length=512),
|
||||||
|
type_=sa.UUID(),
|
||||||
|
existing_nullable=False)
|
||||||
|
batch_op.drop_column('enabled')
|
||||||
|
|
||||||
|
with op.batch_alter_table('tool_oauth_system_clients', schema=None) as batch_op:
|
||||||
|
batch_op.alter_column('plugin_id',
|
||||||
|
existing_type=sa.String(length=512),
|
||||||
|
type_=sa.UUID(),
|
||||||
|
existing_nullable=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -1,7 +1,62 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
from core.plugin.impl.base import BasePluginClient
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthProxyService(BasePluginClient):
|
||||||
|
# Default max age for proxy context parameter in seconds
|
||||||
|
__MAX_AGE__ = 5 * 60 # 5 minutes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_proxy_context(user_id, tenant_id, plugin_id, provider):
|
||||||
|
"""
|
||||||
|
Create a proxy context for an OAuth 2.0 authorization request.
|
||||||
|
|
||||||
|
This parameter is a crucial security measure to prevent Cross-Site Request
|
||||||
|
Forgery (CSRF) attacks. It works by generating a unique nonce and storing it
|
||||||
|
in a distributed cache (Redis) along with the user's session context.
|
||||||
|
|
||||||
|
The returned nonce should be included as the 'proxy_context' parameter in the
|
||||||
|
authorization URL. Upon callback, the `retrieve_proxy_context` method
|
||||||
|
is used to verify the state, ensuring the request's integrity and authenticity,
|
||||||
|
and mitigating replay attacks.
|
||||||
|
"""
|
||||||
|
seconds, microseconds = redis_client.time()
|
||||||
|
context_id = str(uuid.uuid4())
|
||||||
|
data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"provider": provider,
|
||||||
|
# encode redis time to avoid distribution time skew
|
||||||
|
"timestamp": seconds,
|
||||||
|
}
|
||||||
|
# ignore nonce collision
|
||||||
|
redis_client.setex(
|
||||||
|
f"oauth_proxy_context:{context_id}",
|
||||||
|
OAuthProxyService.__MAX_AGE__,
|
||||||
|
json.dumps(data),
|
||||||
|
)
|
||||||
|
return context_id
|
||||||
|
|
||||||
|
|
||||||
class OAuthService(BasePluginClient):
|
@staticmethod
|
||||||
@classmethod
|
def use_proxy_context(context_id, max_age=__MAX_AGE__):
|
||||||
def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
|
"""
|
||||||
return "1234567890"
|
Validate the proxy context parameter.
|
||||||
|
This checks if the context_id is valid and not expired.
|
||||||
|
"""
|
||||||
|
if not context_id:
|
||||||
|
raise ValueError("context_id is required")
|
||||||
|
# get data from redis
|
||||||
|
data = redis_client.getdel(f"oauth_proxy_context:{context_id}")
|
||||||
|
if not data:
|
||||||
|
raise ValueError("context_id is invalid")
|
||||||
|
# check if data is expired
|
||||||
|
seconds, microseconds = redis_client.time()
|
||||||
|
state = json.loads(data)
|
||||||
|
if state.get("timestamp") < seconds - max_age:
|
||||||
|
raise ValueError("context_id is expired")
|
||||||
|
return state
|
||||||
|
|||||||
@ -0,0 +1,27 @@
|
|||||||
|
|
||||||
|
@accessToken=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiYjM4Y2Y5N2MtODNiYS00MWI3LWEyZjMtMzZlOTgzZjE4YmQ5IiwiZXhwIjoxNzUwNDE3NDI0LCJpc3MiOiJTRUxGX0hPU1RFRCIsInN1YiI6IkNvbnNvbGUgQVBJIFBhc3Nwb3J0In0.pPCkISnSmnu3hOCyEVTIJoNeWxtx7E9LNy0cDQUy__Q
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# set default credential
|
||||||
|
POST /console/api/workspaces/current/tool-provider/builtin/langgenius/github/github/set-default
|
||||||
|
Host: 127.0.0.1:5001
|
||||||
|
Content-Type: application/json
|
||||||
|
Authorization: Bearer {{accessToken}}
|
||||||
|
|
||||||
|
{
|
||||||
|
"id": "55fb78d2-0ce6-4496-9488-3b8d9f40818f"
|
||||||
|
}
|
||||||
|
###
|
||||||
|
|
||||||
|
# get oauth url
|
||||||
|
GET /console/api/oauth/plugin/tool?plugin_id=c58a1845-f3a4-4d93-b749-a71e9998b702/github&provider=github
|
||||||
|
Host: 127.0.0.1:5001
|
||||||
|
Authorization: Bearer {{accessToken}}
|
||||||
|
|
||||||
|
###
|
||||||
|
|
||||||
|
# get oauth token
|
||||||
|
GET /console/api/oauth/plugin/tool/callback?state=734072c2-d8ed-4b0b-8ed8-4efd69d15a4f&code=e2d68a6216a3b7d70d2f&state=NQCjFkMKtf32XCMHc8KBdw
|
||||||
|
Host: 127.0.0.1:5001
|
||||||
|
Authorization: Bearer {{accessToken}}
|
||||||
Loading…
Reference in New Issue