feat/datasource
jyong 11 months ago
parent ef0e41de07
commit 678d6ffe2b

@ -8,7 +8,6 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
from controllers.console import api from controllers.console import api
from core.plugin.impl.oauth import OAuthHandler
from libs.login import login_required from libs.login import login_required
from libs.oauth_data_source import NotionOAuth from libs.oauth_data_source import NotionOAuth
@ -110,29 +109,9 @@ class OAuthDataSourceSync(Resource):
return {"result": "success"}, 200 return {"result": "success"}, 200
class DatasourcePluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, datasource_type, datasource_name):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# get all builtin providers
oauth_handler = OAuthHandler()
providers = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
datasource_type,
datasource_name,
system_credentials={}
)
return providers
api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>") api.add_resource(OAuthDataSource, "/oauth/data-source/<string:provider>")
api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>") api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/<string:provider>")
api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>") api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/<string:provider>")
api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync") api.add_resource(OAuthDataSourceSync, "/oauth/data-source/<string:provider>/<uuid:binding_id>/sync")
api.add_resource(DatasourcePluginOauthApi, "/oauth/plugin/datasource/<string:datasoruce_type>/<string:datasource_name>")

@ -2,26 +2,30 @@ import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Optional from typing import Optional
from flask_login import current_user
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized, Forbidden, NotFound
from configs import dify_config from configs import dify_config
from constants.languages import languages from constants.languages import languages
from controllers.console.wraps import account_initialization_required, setup_required
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.login import login_required
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from models import Account
from models.account import AccountStatus from models.account import AccountStatus
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.account_service import AccountService, RegisterService, TenantService from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFoundError, AccountRegisterError from services.errors.account import AccountNotFoundError, AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
from services.feature_service import FeatureService from services.feature_service import FeatureService
from core.plugin.impl.oauth import OAuthHandler
from .. import api from .. import api
@ -181,5 +185,64 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
return account return account
class PluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials
)
return response.model_dump()
class PluginOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
oauth_handler = OAuthHandler()
plugin_oauth_config = db.session.query(DatasourceOauthParamConfig).filter_by(
provider=provider,
plugin_id=plugin_id
).first()
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request
)
datasource_provider = DatasourceProvider(
datasource_name=plugin_oauth_config.datasource_name,
plugin_id=plugin_id,
provider=provider,
auth_type="oauth",
encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
api.add_resource(OAuthLogin, "/oauth/login/<provider>") api.add_resource(OAuthLogin, "/oauth/login/<provider>")
api.add_resource(OAuthCallback, "/oauth/authorize/<provider>") api.add_resource(OAuthCallback, "/oauth/authorize/<provider>")
api.add_resource(PluginOauthApi, "/oauth/plugin/provider/<string:provider>/plugin/<string:plugin_id>")
api.add_resource(PluginOauthCallback, "/oauth/plugin/callback/<string:provider>/plugin/<string:plugin_id>")

@ -0,0 +1,56 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.plugin.impl.datasource import PluginDatasourceManager
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class DatasourcePluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, datasource_type, datasource_name):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# get all builtin providers
manager = PluginDatasourceManager()
providers = manager.get_provider_oauth_url()
return providers
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
"/datasource/<string:datasoruce_type>/<string:datasource_name>/oauth",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

@ -279,7 +279,7 @@ class PublishedRagPipelineRunApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json") parser.add_argument("datasource_type", type=str, required=True, location="json")
parser.add_argument("datasource_info", type=list, required=True, location="json") parser.add_argument("datasource_info_list", type=list, required=True, location="json")
parser.add_argument("start_node_id", type=str, required=True, location="json") parser.add_argument("start_node_id", type=str, required=True, location="json")
parser.add_argument("is_preview", type=bool, required=True, location="json", default=False) parser.add_argument("is_preview", type=bool, required=True, location="json", default=False)
args = parser.parse_args() args = parser.parse_args()

@ -104,7 +104,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id, SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type, SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info, SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from, SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
} }
variable_pool = VariablePool( variable_pool = VariablePool(
@ -199,4 +199,4 @@ class PipelineRunner(WorkflowBasedAppRunner):
if not graph: if not graph:
raise ValueError("graph not found in workflow") raise ValueError("graph not found in workflow")
return graph return graph

@ -0,0 +1,47 @@
from datetime import datetime
from json import JSONDecodeError
from typing import Any, cast
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
from configs import dify_config
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account
from .base import Base
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID
class DatasourceOauthParamConfig(Base): # type: ignore[name-defined]
__tablename__ = "datasource_oauth_params"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
datasource_name: Mapped[str] = db.Column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(StringUUID, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
Loading…
Cancel
Save