feat: implement OAuth credentials refresh mechanism and update expires_at handling

pull/22744/head
Yeuoly 7 months ago
parent 5d986c2cdf
commit 7fa952b1a2

@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials( credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_client_params, system_credentials=oauth_client_params,
request=request, request=request,
).credentials )
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials: if not credentials:
raise Exception("the plugin credentials failed") raise Exception("the plugin credentials failed")
@ -758,8 +761,8 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
expires_at=credentials.get("expires_at"),
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

@ -187,6 +187,9 @@ class PluginOAuthCredentialsResponse(BaseModel):
) )
expires_at: int | None = Field(description="The expires at time of the credentials. UTC timestamp.") expires_at: int | None = Field(description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
expires_at: int = Field(
default=-1, description="The timestamp of the credentials expires at, -1 means never expires."
)
class PluginListResponse(BaseModel): class PluginListResponse(BaseModel):

@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e: except Exception as e:
raise ValueError(f"Error getting credentials: {e}") raise ValueError(f"Error getting credentials: {e}")
def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
Convert a Request object to raw HTTP data. Convert a Request object to raw HTTP data.

@ -1,16 +1,19 @@
import json import json
import logging import logging
import mimetypes import mimetypes
from collections.abc import Generator import time
from collections.abc import Generator, Mapping
from os import listdir, path from os import listdir, path
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import TypeAdapter
from yarl import URL from yarl import URL
import contexts import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime from core.tools.__base.tool_runtime import ToolRuntime
@ -21,6 +24,7 @@ from core.tools.plugin_tool.tool import PluginTool
from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING: if TYPE_CHECKING:
@ -244,12 +248,44 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
), ),
) )
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and builtin_provider.expires_at < int(time.time()):
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_name)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
return cast( return cast(
BuiltinTool, BuiltinTool,
builtin_tool.fork_tool_runtime( builtin_tool.fork_tool_runtime(
runtime=ToolRuntime( runtime=ToolRuntime(
tenant_id=tenant_id, tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials), credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type), credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={}, runtime_parameters={},
invoke_from=invoke_from, invoke_from=invoke_from,

@ -20,7 +20,7 @@ depends_on = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.Integer(), server_default=sa.text('2147483647'), nullable=False)) batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('2147483647'), nullable=False))
# ### end Alembic commands ### # ### end Alembic commands ###

@ -93,7 +93,7 @@ class BuiltinToolProvider(Base):
credential_type: Mapped[str] = mapped_column( credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
) )
expires_at: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("2147483647")) expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:

@ -213,8 +213,8 @@ class BuiltinToolManageService:
tenant_id: str, tenant_id: str,
provider: str, provider: str,
credentials: dict, credentials: dict,
expires_at: int = -1,
name: str | None = None, name: str | None = None,
expires_at: int | None = None,
): ):
""" """
add builtin tool provider add builtin tool provider

Loading…
Cancel
Save