feat: Remove tenant_id context variable in favor of current_user

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20240/head
-LAN- 1 year ago
parent 637d225317
commit 9ab54b7dae
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -11,10 +11,6 @@ if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
"""
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
"""

@ -158,7 +158,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -240,7 +239,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -316,7 +314,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
extras={"auto_generate_conversation_name": False},
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

@ -135,7 +135,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_run_id=workflow_run_id,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -282,7 +281,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@ -359,7 +357,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_run_id=str(uuid.uuid4()),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())

@ -5,7 +5,6 @@ from flask import Response, request
from flask_login import user_loaded_from_request, user_logged_in
from werkzeug.exceptions import NotFound, Unauthorized
import contexts
from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
@ -82,8 +81,8 @@ def on_user_logged_in(_sender, user):
Note: AccountService.load_logged_in_account will populate user.current_tenant_id
through the load_user method, which calls account.set_tenant_id().
"""
if user and isinstance(user, Account) and user.current_tenant_id:
contexts.tenant_id.set(user.current_tenant_id)
# tenant_id context variable removed - using current_user.current_tenant_id directly
pass
@login_manager.unauthorized_handler

@ -6,6 +6,8 @@ from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
from flask_login import current_user
from core.variables import utils as variable_utils
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
@ -17,7 +19,6 @@ import sqlalchemy as sa
from sqlalchemy import UniqueConstraint, func
from sqlalchemy.orm import Mapped, mapped_column
import contexts
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
@ -274,7 +275,16 @@ class Workflow(Base):
if self._environment_variables is None:
self._environment_variables = "{}"
tenant_id = contexts.tenant_id.get()
# Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
results = [
@ -297,7 +307,17 @@ class Workflow(Base):
self._environment_variables = "{}"
return
tenant_id = contexts.tenant_id.get()
# Get tenant_id from current_user (Account or EndUser)
if isinstance(current_user, Account):
# Account user
tenant_id = current_user.current_tenant_id
else:
# EndUser
tenant_id = current_user.tenant_id
if not tenant_id:
self._environment_variables = "{}"
return
value = list(value)
if any(var for var in value if not var.id):

@ -2,14 +2,13 @@ import json
from unittest import mock
from uuid import uuid4
import contexts
from constants import HIDDEN_VALUE
from core.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable
from models.workflow import Workflow, WorkflowNodeExecution
def test_environment_variables():
contexts.tenant_id.set("tenant_id")
# tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance
workflow = Workflow(
@ -51,7 +50,7 @@ def test_environment_variables():
def test_update_environment_variables():
contexts.tenant_id.set("tenant_id")
# tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance
workflow = Workflow(
@ -104,7 +103,7 @@ def test_update_environment_variables():
def test_to_dict():
contexts.tenant_id.set("tenant_id")
# tenant_id context variable removed - using current_user.current_tenant_id directly
# Create a Workflow instance
workflow = Workflow(

Loading…
Cancel
Save