refactor(core/app/apps): Decouple WorkflowBasedAppRunner from AppRunner

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21739/head
-LAN- 7 months ago
parent 6ed3aa5d4c
commit 7edd48146b
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, Optional, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -9,13 +9,19 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -24,7 +30,7 @@ from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -241,3 +247,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._publish_event(QueueTextChunkEvent(text=text)) self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(QueueStopEvent(stopped_by=stopped_by)) self._publish_event(QueueStopEvent(stopped_by=stopped_by))
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)

@ -2,7 +2,6 @@ from collections.abc import Mapping
from typing import Any, Optional, cast from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueAgentLogEvent, QueueAgentLogEvent,
@ -70,7 +69,7 @@ from models.model import App
from models.workflow import Workflow from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner): class WorkflowBasedAppRunner:
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._variable_loader = variable_loader self._variable_loader = variable_loader

Loading…
Cancel
Save