From 39d3f58082f977eb4ba9c632fde18d9d824e1d94 Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Wed, 2 Jul 2025 11:33:00 +0800
Subject: [PATCH] r2

---
 .../datasets/rag_pipeline/datasource_auth.py  |  2 ++
 ..._15_1558-b35c3db83d09_add_pipeline_info.py |  2 +-
 ...2_1132-15e40b74a6d2_add_pipeline_info_9.py | 33 +++++++++++++++++++
 api/models/oauth.py                           |  1 +
 api/services/datasource_provider_service.py   | 13 +++++++-
 .../customized/customized_retrieval.py        |  1 +
 6 files changed, 50 insertions(+), 2 deletions(-)
 create mode 100644 api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py

diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
index 21a7b998f0..7f7b6a7867 100644
--- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
+++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py
@@ -96,6 +96,7 @@ class DatasourceAuth(Resource):
 
         parser = reqparse.RequestParser()
         parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
         args = parser.parse_args()
@@ -108,6 +109,7 @@ class DatasourceAuth(Resource):
                 provider=args["provider"],
                 plugin_id=args["plugin_id"],
                 credentials=args["credentials"],
+                name=args["name"],
             )
         except CredentialsValidateFailedError as ex:
             raise ValueError(str(ex))
diff --git a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
index 503842b797..961589a87e 100644
--- a/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
+++ b/api/migrations/versions/2025_05_15_1558-b35c3db83d09_add_pipeline_info.py
@@ -12,7 +12,7 @@ from sqlalchemy.dialects import postgresql
 
 # revision identifiers, used by Alembic.
 revision = 'b35c3db83d09'
-down_revision = '4474872b0ee6'
+down_revision = '0ab65e1cc7fa'
 branch_labels = None
 depends_on = None
 
diff --git a/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py
new file mode 100644
index 0000000000..82c5991775
--- /dev/null
+++ b/api/migrations/versions/2025_07_02_1132-15e40b74a6d2_add_pipeline_info_9.py
@@ -0,0 +1,33 @@
+"""add_pipeline_info_9
+
+Revision ID: 15e40b74a6d2
+Revises: a1025f709c06
+Create Date: 2025-07-02 11:32:44.125790
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '15e40b74a6d2'
+down_revision = 'a1025f709c06'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('name', sa.String(length=255), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
+        batch_op.drop_column('name')
+
+    # ### end Alembic commands ###
diff --git a/api/models/oauth.py b/api/models/oauth.py
index b1b09e5d45..84bc29931e 100644
--- a/api/models/oauth.py
+++ b/api/models/oauth.py
@@ -29,6 +29,7 @@ class DatasourceProvider(Base):
     )
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     tenant_id = db.Column(StringUUID, nullable=False)
+    name: Mapped[str] = db.Column(db.String(255), nullable=False)
     provider: Mapped[str] = db.Column(db.String(255), nullable=False)
     plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
     auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index fa01fe0afe..bca0081417 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -22,7 +22,7 @@ class DatasourceProviderService:
         self.provider_manager = PluginDatasourceManager()
 
     def datasource_provider_credentials_validate(
-        self, tenant_id: str, provider: str, plugin_id: str, credentials: dict
+        self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str
     ) -> None:
         """
         validate datasource provider credentials.
@@ -31,6 +31,15 @@ class DatasourceProviderService:
         :param provider:
         :param credentials:
         """
+        # check whether the authorization name already exists
+        datasource_provider = (
+            db.session.query(DatasourceProvider)
+            .filter_by(tenant_id=tenant_id, name=name)
+            .first()
+        )
+        if datasource_provider:
+            raise ValueError("Authorization name already exists")
+
         credential_valid = self.provider_manager.validate_provider_credentials(
             tenant_id=tenant_id,
             user_id=current_user.id,
@@ -55,6 +64,7 @@ class DatasourceProviderService:
                 credentials[key] = encrypter.encrypt_token(tenant_id, value)
         datasource_provider = DatasourceProvider(
             tenant_id=tenant_id,
+            name=name,
             provider=provider,
             plugin_id=plugin_id,
             auth_type="api_key",
@@ -120,6 +130,7 @@ class DatasourceProviderService:
                 {
                     "credentials": copy_credentials,
                     "type": datasource_provider.auth_type,
+                    "name": datasource_provider.name,
                 }
             )
 
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index 7280408889..3380d23ec4 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -38,6 +38,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         pipeline_customized_templates = (
             db.session.query(PipelineCustomizedTemplate)
             .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
+            .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
             .all()
         )
         recommended_pipelines_results = []