|
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import Float, and_, func, or_, text
|
|
|
|
|
from sqlalchemy import cast as sqlalchemy_cast
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
|
|
|
|
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
|
|
|
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
|
|
|
@ -85,30 +86,31 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
return NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
|
|
|
|
|
)
|
|
|
|
|
# TODO(-LAN-): Move this check outside.
|
|
|
|
|
# check rate limit
|
|
|
|
|
if self.tenant_id:
|
|
|
|
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
|
|
|
|
if knowledge_rate_limit.enabled:
|
|
|
|
|
current_time = int(time.time() * 1000)
|
|
|
|
|
key = f"rate_limit_{self.tenant_id}"
|
|
|
|
|
redis_client.zadd(key, {current_time: current_time})
|
|
|
|
|
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
|
|
|
|
request_count = redis_client.zcard(key)
|
|
|
|
|
if request_count > knowledge_rate_limit.limit:
|
|
|
|
|
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
|
|
|
|
|
if knowledge_rate_limit.enabled:
|
|
|
|
|
current_time = int(time.time() * 1000)
|
|
|
|
|
key = f"rate_limit_{self.tenant_id}"
|
|
|
|
|
redis_client.zadd(key, {current_time: current_time})
|
|
|
|
|
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
|
|
|
|
request_count = redis_client.zcard(key)
|
|
|
|
|
if request_count > knowledge_rate_limit.limit:
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
# add ratelimit record
|
|
|
|
|
rate_limit_log = RateLimitLog(
|
|
|
|
|
tenant_id=self.tenant_id,
|
|
|
|
|
subscription_plan=knowledge_rate_limit.subscription_plan,
|
|
|
|
|
operation="knowledge",
|
|
|
|
|
)
|
|
|
|
|
db.session.add(rate_limit_log)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
return NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
inputs=variables,
|
|
|
|
|
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
|
|
|
|
error_type="RateLimitExceeded",
|
|
|
|
|
)
|
|
|
|
|
session.add(rate_limit_log)
|
|
|
|
|
session.commit()
|
|
|
|
|
return NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
inputs=variables,
|
|
|
|
|
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
|
|
|
|
|
error_type="RateLimitExceeded",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# retrieve knowledge
|
|
|
|
|
try:
|
|
|
|
|
@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
dataset_retrieval = DatasetRetrieval()
|
|
|
|
|
if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value:
|
|
|
|
|
# fetch model config
|
|
|
|
|
model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore
|
|
|
|
|
if node_data.single_retrieval_config is None:
|
|
|
|
|
raise ValueError("single_retrieval_config is required")
|
|
|
|
|
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
|
|
|
|
|
# check model is support tool calling
|
|
|
|
|
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
|
|
|
|
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
|
|
|
|
@ -424,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
raise ValueError("metadata_model_config is required")
|
|
|
|
|
# get metadata model instance
|
|
|
|
|
# fetch model config
|
|
|
|
|
model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore
|
|
|
|
|
model_instance, model_config = self.get_model_config(metadata_model_config)
|
|
|
|
|
# fetch prompt messages
|
|
|
|
|
prompt_template = self._get_prompt_template(
|
|
|
|
|
node_data=node_data,
|
|
|
|
|
@ -550,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|
|
|
|
variable_mapping[node_id + ".query"] = node_data.query_variable_selector
|
|
|
|
|
return variable_mapping
|
|
|
|
|
|
|
|
|
|
def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore
|
|
|
|
|
"""
|
|
|
|
|
Fetch model config
|
|
|
|
|
:param model: model
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
if model is None:
|
|
|
|
|
raise ValueError("model is required")
|
|
|
|
|
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
|
|
|
|
model_name = model.name
|
|
|
|
|
provider_name = model.provider
|
|
|
|
|
|
|
|
|
|
|