feat: Remove tenant_id context variable in favor of current_user

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20240/head
-LAN- 1 year ago
parent 637d225317
commit 9ab54b7dae
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -11,10 +11,6 @@ if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
""" """
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
""" """

@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id=node_id, inputs=args["inputs"] node_id=node_id, inputs=args["inputs"]
), ),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras={"auto_generate_conversation_name": False}, extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())

@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=workflow_run_id, workflow_run_id=workflow_run_id,
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -282,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
), ),
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -359,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()), workflow_run_id=str(uuid.uuid4()),
) )
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock()) contexts.plugin_tool_providers_lock.set(threading.Lock())

@ -5,7 +5,6 @@ from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
import contexts
from configs import dify_config from configs import dify_config
from dify_app import DifyApp from dify_app import DifyApp
from extensions.ext_database import db from extensions.ext_database import db
@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user):
Note: AccountService.load_logged_in_account will populate user.current_tenant_id Note: AccountService.load_logged_in_account will populate user.current_tenant_id
through the load_user method, which calls account.set_tenant_id(). through the load_user method, which calls account.set_tenant_id().
""" """
if user and isinstance(user, Account) and user.current_tenant_id: # tenant_id context variable removed - using current_user.current_tenant_id directly
contexts.tenant_id.set(user.current_tenant_id) pass
@login_manager.unauthorized_handler @login_manager.unauthorized_handler

@ -6,6 +6,8 @@ from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4 from uuid import uuid4
from flask_login import current_user
from core.variables import utils as variable_utils from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment from factories.variable_factory import build_segment
@ -17,7 +19,6 @@ import sqlalchemy as sa
from sqlalchemy import UniqueConstraint, func from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable from core.variables import SecretVariable, Segment, SegmentType, Variable
@ -274,7 +275,16 @@ class Workflow(Base):
if self._environment_variables is None: if self._environment_variables is None:
self._environment_variables = "{}" self._environment_variables = "{}"
tenant_id = contexts.tenant_id.get() # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [ results = [
@ -297,7 +307,17 @@ class Workflow(Base):
self._environment_variables = "{}" self._environment_variables = "{}"
return return
tenant_id = contexts.tenant_id.get() # Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
self._environment_variables = "{}"
return
value = list(value) value = list(value)
if any(var for var in value if not var.id): if any(var for var in value if not var.id):

@ -2,14 +2,13 @@ import json
from unittest import mock from unittest import mock
from uuid import uuid4 from uuid import uuid4
import contexts
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution from models.workflow import Workflow, WorkflowNodeExecution
def test_environment_variables(): def test_environment_variables():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(
@ -51,7 +50,7 @@ def test_environment_variables():
def test_update_environment_variables(): def test_update_environment_variables():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(
@ -104,7 +103,7 @@ def test_update_environment_variables():
def test_to_dict(): def test_to_dict():
contexts.tenant_id.set("tenant_id") # tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance # Create a Workflow instance
workflow = Workflow( workflow = Workflow(

Loading…
Cancel
Save