feat: refactor provider name generation to use incremental naming & enforce unique constraints

feat/rag-2
Harry 7 months ago
parent 23a5ff410e
commit 7364d051d2

@ -0,0 +1,42 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
providers: Sequence[Any], credential_type: CredentialType, fallback_context: str = "provider"
) -> str:
try:
return generate_incremental_name(
[provider.name for provider in providers],
f"{credential_type.get_name()}",
)
except Exception as e:
logger.warning(f"Error generating next provider name for {fallback_context}: {str(e)}")
return f"{credential_type.get_name()} 1"
def generate_incremental_name(
names: Sequence[str],
default_pattern: str,
) -> str:
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for name in names:
if not name:
continue
match = re.match(pattern, name.strip())
if match:
numbers.append(int(match.group(1)))
if not numbers:
return f"{default_pattern} 1"
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"

@ -1,35 +0,0 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
providers: Sequence[Any],
credential_type: CredentialType,
fallback_context: str = "provider"
) -> str:
try:
default_pattern = f"{credential_type.get_name()}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
if match:
numbers.append(int(match.group(1)))
if not numbers:
return f"{default_pattern} 1"
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {fallback_context}: {str(e)}")
return f"{credential_type.get_name()} 1"

@ -0,0 +1,33 @@
"""add_pipeline_info_13
Revision ID: fcb46171d891
Revises: 2008609cf2bb
Create Date: 2025-07-18 21:34:31.914500
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fcb46171d891'
down_revision = '2008609cf2bb'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.create_unique_constraint('datasource_provider_unique_name', ['tenant_id', 'plugin_id', 'provider', 'name'])
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('datasource_providers', schema=None) as batch_op:
batch_op.drop_constraint('datasource_provider_unique_name', type_='unique')
# ### end Alembic commands ###

@ -25,6 +25,7 @@ class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@ -35,6 +36,6 @@ class DatasourceProvider(Base):
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)

@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.helper import encrypter
from core.helper.provider_name_generator import generate_provider_name
from core.helper.name_generator import generate_incremental_name
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
@ -40,7 +40,10 @@ class DatasourceProviderService:
)
.all()
)
return generate_provider_name(db_providers, credential_type, f"datasource provider {provider_id}")
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
def add_datasource_oauth_provider(
self,
@ -57,15 +60,33 @@ class DatasourceProviderService:
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
raise ValueError("name is already exists")
db_provider_name = name
if not db_provider_name:
db_provider_name = self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
else:
if session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
).count() > 0:
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"

@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.provider_name_generator import generate_provider_name
from core.plugin.entities.plugin import ToolProviderID
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -309,8 +309,10 @@ class BuiltinToolManageService:
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_provider_name(db_providers, credential_type, f"builtin tool provider {provider}")
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
@staticmethod
def get_builtin_tool_provider_credentials(

Loading…
Cancel
Save