diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4872702a76..3de8f998f4 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,7 +10,6 @@ from typing import Any, Optional from flask_login import current_user from sqlalchemy import func, select -from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config @@ -385,13 +384,12 @@ class DatasetService: external_knowledge_id: External knowledge identifier external_knowledge_api_id: External knowledge API identifier """ - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) + external_knowledge_binding = ( + db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") # Update binding if values have changed if ( @@ -400,7 +398,6 @@ class DatasetService: ): external_knowledge_binding.external_knowledge_id = external_knowledge_id external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) @staticmethod def _update_internal_dataset(dataset, data, user): diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py index 7c40b1e556..7e0745dcd3 100644 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -118,7 +118,7 @@ class TestDatasetServiceUpdateDataset: @pytest.fixture def mock_external_provider_dependencies(self): """Mock setup for external provider tests.""" - with patch("services.dataset_service.Session") as mock_session: + with patch("sqlalchemy.orm.Session") as mock_session: from extensions.ext_database import db with patch.object(db.__class__, "engine", new_callable=Mock):