feat: Add caching mechanism for plugin model schemas (#14898)

pull/14262/head
Yeuoly 1 year ago committed by Bharat Ramanathan
parent e49a13d20a
commit 696fd9b344

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING: if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers") ContextVar("plugin_tool_providers")
) )
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
ContextVar("plugin_model_providers") ContextVar("plugin_model_providers")
) )
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock") ContextVar("plugin_model_providers_lock")
) )
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)

@ -1,8 +1,11 @@
import decimal import decimal
import hashlib
from threading import Lock
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
import contexts
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -139,15 +142,35 @@ class AIModel(BaseModel):
:return: model schema :return: model schema
""" """
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_model_schema( cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
tenant_id=self.tenant_id, # sort credentials
user_id="unknown", sorted_credentials = sorted(credentials.items()) if credentials else []
plugin_id=self.plugin_id, cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
provider=self.provider_name,
model_type=self.model_type.value, try:
model=model, contexts.plugin_model_schemas.get()
credentials=credentials or {}, except LookupError:
) contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=self.plugin_id,
provider=self.provider_name,
model_type=self.model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
""" """

@ -1,3 +1,4 @@
import hashlib
import logging import logging
import os import os
from collections.abc import Sequence from collections.abc import Sequence
@ -206,17 +207,35 @@ class ModelProviderFactory:
Get model schema Get model schema
""" """
plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider)
model_schema = self.plugin_model_manager.get_model_schema( cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}"
tenant_id=self.tenant_id, # sort credentials
user_id="unknown", sorted_credentials = sorted(credentials.items()) if credentials else []
plugin_id=plugin_id, cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials])
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials,
)
return model_schema try:
contexts.plugin_model_schemas.get()
except LookupError:
contexts.plugin_model_schemas.set({})
contexts.plugin_model_schema_lock.set(Lock())
with contexts.plugin_model_schema_lock.get():
if cache_key in contexts.plugin_model_schemas.get():
return contexts.plugin_model_schemas.get()[cache_key]
schema = self.plugin_model_manager.get_model_schema(
tenant_id=self.tenant_id,
user_id="unknown",
plugin_id=plugin_id,
provider=provider_name,
model_type=model_type.value,
model=model,
credentials=credentials or {},
)
if schema:
contexts.plugin_model_schemas.get()[cache_key] = schema
return schema
def get_models( def get_models(
self, self,

Loading…
Cancel
Save