chore: add redis for builtin plugin service

pull/20118/head
G.Wood-Sun 1 year ago
parent 916c415b4b
commit 813d156382

@ -16,6 +16,7 @@ from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider from models.tools import BuiltinToolProvider
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
@ -23,6 +24,16 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService: class BuiltinToolManageService:
REDIS_KEY_PREFIX = "builtin_tools:user:{user_id}:tenant:{tenant_id}"
REDIS_TTL = 60 * 10 # Cache expires in 10 minutes
@staticmethod
def get_redis_key(user_id: str, tenant_id: str) -> str:
"""
Generate the Redis key for caching.
"""
return BuiltinToolManageService.REDIS_KEY_PREFIX.format(user_id=user_id, tenant_id=tenant_id)
@staticmethod @staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
""" """
@ -166,6 +177,8 @@ class BuiltinToolManageService:
tool_configuration.delete_tool_credentials_cache() tool_configuration.delete_tool_credentials_cache()
db.session.commit() db.session.commit()
redis_key = BuiltinToolManageService.get_redis_key(user_id, tenant_id)
redis_client.delete(redis_key)
return {"result": "success"} return {"result": "success"}
@staticmethod @staticmethod
@ -201,6 +214,8 @@ class BuiltinToolManageService:
db.session.delete(provider_obj) db.session.delete(provider_obj)
db.session.commit() db.session.commit()
redis_key = BuiltinToolManageService.get_redis_key(user_id, tenant_id)
redis_client.delete(redis_key)
# delete cache # delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
@ -229,62 +244,80 @@ class BuiltinToolManageService:
""" """
list builtin tools list builtin tools
""" """
# get all builtin providers redis_key = BuiltinToolManageService.get_redis_key(user_id, tenant_id)
provider_controllers = ToolManager.list_builtin_providers(tenant_id) result: list[ToolProviderApiEntity] = []
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
# rewrite db_providers
for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
# find provider
def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[ToolProviderApiEntity] = [] try:
# Try to get from Redis cache
for provider_controller in provider_controllers: cached_data = redis_client.get(redis_key)
if cached_data:
try: try:
# handle include, exclude # Deserialize cached data directly into a list of dictionaries
if is_filtered( deserialized_data = json.loads(cached_data)
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore return [ToolProviderApiEntity(**item) for item in deserialized_data]
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore except (json.JSONDecodeError, TypeError) as e:
data=provider_controller, raise ValueError(f"Failed to deserialize cached data: {e}")
name_func=lambda x: x.identity.name,
): # get all builtin providers
continue provider_controllers = ToolManager.list_builtin_providers(tenant_id)
# convert provider controller to user provider with db.session.no_autoflush:
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( # get all user added providers
provider_controller=provider_controller, db_providers: list[BuiltinToolProvider] = (
db_provider=find_provider(provider_controller.entity.identity.name), db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
decrypt_credentials=True, )
)
# add icon # rewrite db_providers
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider))
tools = provider_controller.get_tools()
for tool in tools or []: # find provider
user_builtin_provider.tools.append( def find_provider(provider):
ToolTransformService.convert_tool_entity_to_api_entity( return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
tenant_id=tenant_id,
tool=tool, for provider_controller in provider_controllers:
credentials=user_builtin_provider.original_credentials, try:
labels=ToolLabelManager.get_tool_labels(provider_controller), # handle include, exclude
) if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True,
) )
result.append(user_builtin_provider) # add icon
except Exception as e: ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
raise e
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
result = BuiltinToolProviderSort.sort(result)
return BuiltinToolProviderSort.sort(result) # Serialize and store in Redis
serialized_result = json.dumps([jsonable_encoder(item) for item in result])
redis_client.setex(redis_key, BuiltinToolManageService.REDIS_TTL, serialized_result)
except Exception as e:
raise ValueError(str(e))
return result
@staticmethod @staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:

Loading…
Cancel
Save