feat: Allow using file variables directly in the LLM node and support more file types. (#10679)
Co-authored-by: Joel <iamjoel007@gmail.com>pull/11014/head
parent
535c72cad7
commit
c5f7d650b5
@ -1,125 +1,484 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from configs import dify_config
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||||
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
|
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageRole,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
|
||||||
|
from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
|
||||||
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||||
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
|
||||||
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
|
||||||
from core.workflow.nodes.end import EndStreamParam
|
from core.workflow.nodes.end import EndStreamParam
|
||||||
from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
|
from core.workflow.nodes.llm.entities import (
|
||||||
|
ContextConfig,
|
||||||
|
LLMNodeChatModelMessage,
|
||||||
|
LLMNodeData,
|
||||||
|
ModelConfig,
|
||||||
|
VisionConfig,
|
||||||
|
VisionConfigOptions,
|
||||||
|
)
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
|
from models.provider import ProviderType
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
|
||||||
|
|
||||||
|
|
||||||
class TestLLMNode:
|
class MockTokenBufferMemory:
|
||||||
@pytest.fixture
|
def __init__(self, history_messages=None):
|
||||||
def llm_node(self):
|
self.history_messages = history_messages or []
|
||||||
data = LLMNodeData(
|
|
||||||
title="Test LLM",
|
def get_history_prompt_messages(
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
self, max_token_limit: int = 2000, message_limit: Optional[int] = None
|
||||||
prompt_template=[],
|
) -> Sequence[PromptMessage]:
|
||||||
memory=None,
|
if message_limit is not None:
|
||||||
context=ContextConfig(enabled=False),
|
return self.history_messages[-message_limit * 2 :]
|
||||||
vision=VisionConfig(
|
return self.history_messages
|
||||||
enabled=True,
|
|
||||||
configs=VisionConfigOptions(
|
|
||||||
variable_selector=["sys", "files"],
|
@pytest.fixture
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
def llm_node():
|
||||||
),
|
data = LLMNodeData(
|
||||||
),
|
title="Test LLM",
|
||||||
)
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
variable_pool = VariablePool(
|
prompt_template=[],
|
||||||
system_variables={},
|
memory=None,
|
||||||
user_inputs={},
|
context=ContextConfig(enabled=False),
|
||||||
)
|
vision=VisionConfig(
|
||||||
node = LLMNode(
|
enabled=True,
|
||||||
id="1",
|
configs=VisionConfigOptions(
|
||||||
config={
|
variable_selector=["sys", "files"],
|
||||||
"id": "1",
|
detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||||
"data": data.model_dump(),
|
|
||||||
},
|
|
||||||
graph_init_params=GraphInitParams(
|
|
||||||
tenant_id="1",
|
|
||||||
app_id="1",
|
|
||||||
workflow_type=WorkflowType.WORKFLOW,
|
|
||||||
workflow_id="1",
|
|
||||||
graph_config={},
|
|
||||||
user_id="1",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
),
|
||||||
graph=Graph(
|
),
|
||||||
root_node_id="1",
|
)
|
||||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
variable_pool = VariablePool(
|
||||||
answer_dependencies={},
|
system_variables={},
|
||||||
answer_generate_route={},
|
user_inputs={},
|
||||||
),
|
)
|
||||||
end_stream_param=EndStreamParam(
|
node = LLMNode(
|
||||||
end_dependencies={},
|
id="1",
|
||||||
end_stream_variable_selector_mapping={},
|
config={
|
||||||
),
|
"id": "1",
|
||||||
|
"data": data.model_dump(),
|
||||||
|
},
|
||||||
|
graph_init_params=GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config={},
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
),
|
||||||
|
graph=Graph(
|
||||||
|
root_node_id="1",
|
||||||
|
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||||
|
answer_dependencies={},
|
||||||
|
answer_generate_route={},
|
||||||
),
|
),
|
||||||
graph_runtime_state=GraphRuntimeState(
|
end_stream_param=EndStreamParam(
|
||||||
variable_pool=variable_pool,
|
end_dependencies={},
|
||||||
start_at=0,
|
end_stream_variable_selector_mapping={},
|
||||||
),
|
),
|
||||||
)
|
),
|
||||||
return node
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_config():
|
||||||
|
# Create actual provider and model type instances
|
||||||
|
model_provider_factory = ModelProviderFactory()
|
||||||
|
provider_instance = model_provider_factory.get_provider_instance("openai")
|
||||||
|
model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
|
||||||
|
|
||||||
|
# Create a ProviderModelBundle
|
||||||
|
provider_model_bundle = ProviderModelBundle(
|
||||||
|
configuration=ProviderConfiguration(
|
||||||
|
tenant_id="1",
|
||||||
|
provider=provider_instance.get_provider_schema(),
|
||||||
|
preferred_provider_type=ProviderType.CUSTOM,
|
||||||
|
using_provider_type=ProviderType.CUSTOM,
|
||||||
|
system_configuration=SystemConfiguration(enabled=False),
|
||||||
|
custom_configuration=CustomConfiguration(provider=None),
|
||||||
|
model_settings=[],
|
||||||
|
),
|
||||||
|
provider_instance=provider_instance,
|
||||||
|
model_type_instance=model_type_instance,
|
||||||
|
)
|
||||||
|
|
||||||
def test_fetch_files_with_file_segment(self, llm_node):
|
# Create and return a ModelConfigWithCredentialsEntity
|
||||||
file = File(
|
return ModelConfigWithCredentialsEntity(
|
||||||
|
provider="openai",
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
model_schema=AIModelEntity(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
label=I18nObject(en_US="GPT-3.5 Turbo"),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={},
|
||||||
|
),
|
||||||
|
mode="chat",
|
||||||
|
credentials={},
|
||||||
|
parameters={},
|
||||||
|
provider_model_bundle=provider_model_bundle,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_file_segment(llm_node):
|
||||||
|
file = File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="1",
|
||||||
|
)
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == [file]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_file_segment(llm_node):
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
id="1",
|
id="1",
|
||||||
tenant_id="test",
|
tenant_id="test",
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
filename="test.jpg",
|
filename="test1.jpg",
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
related_id="1",
|
related_id="1",
|
||||||
|
),
|
||||||
|
File(
|
||||||
|
id="2",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test2.jpg",
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
related_id="2",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == files
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_none_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_array_any_segment(llm_node):
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
||||||
|
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_files_with_non_existent_variable(llm_node):
|
||||||
|
result = llm_node._fetch_files(selector=["sys", "files"])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
|
||||||
|
prompt_template = []
|
||||||
|
llm_node.node_data.prompt_template = prompt_template
|
||||||
|
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
files = [
|
||||||
|
File(
|
||||||
|
id="1",
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
)
|
)
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
|
]
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
fake_query = faker.sentence()
|
||||||
assert result == [file]
|
|
||||||
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
def test_fetch_files_with_array_file_segment(self, llm_node):
|
user_query=fake_query,
|
||||||
files = [
|
user_files=files,
|
||||||
File(
|
context=None,
|
||||||
id="1",
|
memory=None,
|
||||||
tenant_id="test",
|
model_config=model_config,
|
||||||
type=FileType.IMAGE,
|
prompt_template=prompt_template,
|
||||||
filename="test1.jpg",
|
memory_config=None,
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
vision_enabled=False,
|
||||||
related_id="1",
|
vision_detail=fake_vision_detail,
|
||||||
),
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
File(
|
jinja2_variables=[],
|
||||||
id="2",
|
)
|
||||||
tenant_id="test",
|
|
||||||
type=FileType.IMAGE,
|
assert prompt_messages == [UserPromptMessage(content=fake_query)]
|
||||||
filename="test2.jpg",
|
|
||||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
|
||||||
related_id="2",
|
def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
|
||||||
),
|
# Setup dify config
|
||||||
]
|
dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
|
dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
|
||||||
|
|
||||||
|
# Generate fake values for prompt template
|
||||||
|
fake_assistant_prompt = faker.sentence()
|
||||||
|
fake_query = faker.sentence()
|
||||||
|
fake_context = faker.sentence()
|
||||||
|
fake_window_size = faker.random_int(min=1, max=3)
|
||||||
|
fake_vision_detail = faker.random_element(
|
||||||
|
[ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
|
||||||
|
)
|
||||||
|
fake_remote_url = faker.url()
|
||||||
|
|
||||||
|
# Setup mock memory with history messages
|
||||||
|
mock_history = [
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
UserPromptMessage(content=faker.sentence()),
|
||||||
|
AssistantPromptMessage(content=faker.sentence()),
|
||||||
|
]
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
# Setup memory configuration
|
||||||
assert result == files
|
memory_config = MemoryConfig(
|
||||||
|
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||||
|
window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
|
||||||
|
query_prompt_template=None,
|
||||||
|
)
|
||||||
|
|
||||||
def test_fetch_files_with_none_segment(self, llm_node):
|
memory = MockTokenBufferMemory(history_messages=mock_history)
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
|
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
# Test scenarios covering different file input combinations
|
||||||
assert result == []
|
test_scenarios = [
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="No files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
features=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=None,
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(content=fake_query),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="User files",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[
|
||||||
|
File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_context,
|
||||||
|
role=PromptMessageRole.SYSTEM,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{#context#}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text=fake_assistant_prompt,
|
||||||
|
role=PromptMessageRole.ASSISTANT,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
SystemPromptMessage(content=fake_context),
|
||||||
|
UserPromptMessage(content=fake_context),
|
||||||
|
AssistantPromptMessage(content=fake_assistant_prompt),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
TextPromptMessageContent(data=fake_query),
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=False,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=[
|
||||||
|
UserPromptMessage(
|
||||||
|
content=[
|
||||||
|
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
+ mock_history[fake_window_size * -2 :]
|
||||||
|
+ [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File without vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
filename="test1.jpg",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
LLMNodeTestScenario(
|
||||||
|
description="Prompt template with variable selector of File with video file and vision feature",
|
||||||
|
user_query=fake_query,
|
||||||
|
user_files=[],
|
||||||
|
vision_enabled=True,
|
||||||
|
vision_detail=fake_vision_detail,
|
||||||
|
features=[ModelFeature.VISION],
|
||||||
|
window_size=fake_window_size,
|
||||||
|
prompt_template=[
|
||||||
|
LLMNodeChatModelMessage(
|
||||||
|
text="{{#input.image#}}",
|
||||||
|
role=PromptMessageRole.USER,
|
||||||
|
edition_type="basic",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
expected_messages=mock_history[fake_window_size * -2 :] + [UserPromptMessage(content=fake_query)],
|
||||||
|
file_variables={
|
||||||
|
"input.image": File(
|
||||||
|
tenant_id="test",
|
||||||
|
type=FileType.VIDEO,
|
||||||
|
filename="test1.mp4",
|
||||||
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
|
remote_url=fake_remote_url,
|
||||||
|
extension="mp4",
|
||||||
|
)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
def test_fetch_files_with_array_any_segment(self, llm_node):
|
for scenario in test_scenarios:
|
||||||
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
|
model_config.model_schema.features = scenario.features
|
||||||
|
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
for k, v in scenario.file_variables.items():
|
||||||
assert result == []
|
selector = k.split(".")
|
||||||
|
llm_node.graph_runtime_state.variable_pool.add(selector, v)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
prompt_messages, _ = llm_node._fetch_prompt_messages(
|
||||||
|
user_query=scenario.user_query,
|
||||||
|
user_files=scenario.user_files,
|
||||||
|
context=fake_context,
|
||||||
|
memory=memory,
|
||||||
|
model_config=model_config,
|
||||||
|
prompt_template=scenario.prompt_template,
|
||||||
|
memory_config=memory_config,
|
||||||
|
vision_enabled=scenario.vision_enabled,
|
||||||
|
vision_detail=scenario.vision_detail,
|
||||||
|
variable_pool=llm_node.graph_runtime_state.variable_pool,
|
||||||
|
jinja2_variables=[],
|
||||||
|
)
|
||||||
|
|
||||||
def test_fetch_files_with_non_existent_variable(self, llm_node):
|
# Verify the result
|
||||||
result = llm_node._fetch_files(selector=["sys", "files"])
|
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
|
||||||
assert result == []
|
assert (
|
||||||
|
prompt_messages == scenario.expected_messages
|
||||||
|
), f"Message content mismatch in scenario: {scenario.description}"
|
||||||
|
|||||||
@ -0,0 +1,25 @@
|
|||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
from core.model_runtime.entities.message_entities import PromptMessage
|
||||||
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
|
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
|
||||||
|
|
||||||
|
|
||||||
|
class LLMNodeTestScenario(BaseModel):
|
||||||
|
"""Test scenario for LLM node testing."""
|
||||||
|
|
||||||
|
description: str = Field(..., description="Description of the test scenario")
|
||||||
|
user_query: str = Field(..., description="User query input")
|
||||||
|
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
|
||||||
|
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
|
||||||
|
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
|
||||||
|
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
|
||||||
|
window_size: int = Field(..., description="Window size for memory")
|
||||||
|
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
|
||||||
|
file_variables: Mapping[str, File | Sequence[File]] = Field(
|
||||||
|
default_factory=dict, description="List of file variables"
|
||||||
|
)
|
||||||
|
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")
|
||||||
Loading…
Reference in New Issue