test: refactor thinking tags removal tests

- Remove redundant test cases
pull/21897/head
kimtaewoong 11 months ago
parent 89e3027a9b
commit ebd27801a1

@ -3,6 +3,8 @@ import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
from unittest import mock from unittest import mock
import os
import importlib
import pytest import pytest
@ -35,7 +37,7 @@ from core.workflow.nodes.llm.entities import (
VisionConfigOptions, VisionConfigOptions,
) )
from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.llm.node import LLMNode, LLM_NODE_THINKING_TAGS_ENABLED
from models.enums import UserFrom from models.enums import UserFrom
from models.provider import ProviderType from models.provider import ProviderType
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -703,14 +705,6 @@ with multiple lines
result = llm_node._remove_thinking_tags(input_text) result = llm_node._remove_thinking_tags(input_text)
assert result == expected assert result == expected
def test_mixed_case_removal(self, llm_node):
"""Test mixed case thinking tag removal."""
input_text = "<Think>Mixed case thinking</Think>Response"
expected = "Response"
result = llm_node._remove_thinking_tags(input_text)
assert result == expected
def test_no_thinking_tags(self, llm_node): def test_no_thinking_tags(self, llm_node):
"""Test text without thinking tags remains unchanged.""" """Test text without thinking tags remains unchanged."""
input_text = "Hello, this is a normal response without thinking tags." input_text = "Hello, this is a normal response without thinking tags."
@ -743,14 +737,6 @@ with multiple lines
result = llm_node._remove_thinking_tags(input_text) result = llm_node._remove_thinking_tags(input_text)
assert result == expected assert result == expected
def test_whitespace_after_tag(self, llm_node):
"""Test whitespace removal after thinking tags."""
input_text = "<think>Thinking</think> \n \t Final response"
expected = "Final response"
result = llm_node._remove_thinking_tags(input_text)
assert result == expected
def test_none_input(self, llm_node): def test_none_input(self, llm_node):
"""Test None input handling.""" """Test None input handling."""
result = llm_node._remove_thinking_tags(None) result = llm_node._remove_thinking_tags(None)
@ -761,59 +747,21 @@ with multiple lines
result = llm_node._remove_thinking_tags(123) result = llm_node._remove_thinking_tags(123)
assert result == 123 assert result == 123
def test_complex_real_world_example(self, llm_node):
"""Test with a complex real-world example from DeepSeek-R1."""
input_text = """<think>
Okay, let me try to figure out what the user is asking here. The message is just "gdgd".
That's pretty short and doesn't make much sense on its own. I need to consider different
possibilities.
First, maybe it's a typo or a shorthand. "GDGD" could be an acronym. Let me think about
common acronyms. "GDGD" might stand for "Good Good Good Good" but that seems unlikely.
</think>It looks like your message might be incomplete or unclear. Could you please provide
more context or rephrase your question? I'm here to help!"""
expected = (
"It looks like your message might be incomplete or unclear. Could you please "
"provide more context or rephrase your question? I'm here to help!"
)
result = llm_node._remove_thinking_tags(input_text)
assert result == expected
def test_multiple_whitespace_tags(self, llm_node):
"""Test multiple thinking tags with various whitespace."""
input_text = "<think>First</think> \n<think>Second</think> Final"
expected = "Final"
result = llm_node._remove_thinking_tags(input_text)
assert result == expected
@mock.patch.dict("os.environ", {"LLM_NODE_THINKING_TAGS_ENABLED": "true"}) @mock.patch.dict("os.environ", {"LLM_NODE_THINKING_TAGS_ENABLED": "true"})
def test_environment_variable_enabled(self): def test_environment_variable_enabled(self):
"""Test that environment variable is properly read when enabled.""" """Test that environment variable is properly read when enabled."""
from core.workflow.nodes.llm.node import LLM_NODE_THINKING_TAGS_ENABLED importlib.reload(core.workflow.nodes.llm.node)
assert LLM_NODE_THINKING_TAGS_ENABLED is True assert LLM_NODE_THINKING_TAGS_ENABLED is True
@mock.patch.dict("os.environ", {"LLM_NODE_THINKING_TAGS_ENABLED": "false"}) @mock.patch.dict("os.environ", {"LLM_NODE_THINKING_TAGS_ENABLED": "false"})
def test_environment_variable_disabled(self): def test_environment_variable_disabled(self):
"""Test that environment variable is properly read when disabled.""" """Test that environment variable is properly read when disabled."""
# Need to reimport to get the updated value
import importlib
import core.workflow.nodes.llm.node
importlib.reload(core.workflow.nodes.llm.node) importlib.reload(core.workflow.nodes.llm.node)
from core.workflow.nodes.llm.node import LLM_NODE_THINKING_TAGS_ENABLED
assert LLM_NODE_THINKING_TAGS_ENABLED is False assert LLM_NODE_THINKING_TAGS_ENABLED is False
def test_environment_variable_default(self): def test_environment_variable_default(self):
"""Test that environment variable defaults to True.""" """Test that environment variable defaults to True."""
from core.workflow.nodes.llm.node import LLM_NODE_THINKING_TAGS_ENABLED with mock.patch.dict("os.environ"):
os.environ.pop("LLM_NODE_THINKING_TAGS_ENABLED", None)
# Default should be True for backward compatibility importlib.reload(core.workflow.nodes.llm.node)
assert LLM_NODE_THINKING_TAGS_ENABLED is True assert LLM_NODE_THINKING_TAGS_ENABLED is True

Loading…
Cancel
Save