diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py index ba3c1eb5e0..e3e500e310 100644 --- a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -4,7 +4,7 @@ from constants import UUID_NIL from core.prompt.utils.extract_thread_messages import extract_thread_messages -class TestMessage: +class MockMessage: def __init__(self, id, parent_message_id): self.id = id self.parent_message_id = parent_message_id @@ -14,7 +14,7 @@ class TestMessage: def test_extract_thread_messages_single_message(): - messages = [TestMessage(str(uuid4()), UUID_NIL)] + messages = [MockMessage(str(uuid4()), UUID_NIL)] result = extract_thread_messages(messages) assert len(result) == 1 assert result[0] == messages[0] @@ -23,11 +23,11 @@ def test_extract_thread_messages_single_message(): def test_extract_thread_messages_linear_thread(): id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) messages = [ - TestMessage(id5, id4), - TestMessage(id4, id3), - TestMessage(id3, id2), - TestMessage(id2, id1), - TestMessage(id1, UUID_NIL), + MockMessage(id5, id4), + MockMessage(id4, id3), + MockMessage(id3, id2), + MockMessage(id2, id1), + MockMessage(id1, UUID_NIL), ] result = extract_thread_messages(messages) assert len(result) == 5 @@ -37,10 +37,10 @@ def test_extract_thread_messages_linear_thread(): def test_extract_thread_messages_branched_thread(): id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) messages = [ - TestMessage(id4, id2), - TestMessage(id3, id2), - TestMessage(id2, id1), - TestMessage(id1, UUID_NIL), + MockMessage(id4, id2), + MockMessage(id3, id2), + MockMessage(id2, id1), + MockMessage(id1, UUID_NIL), ] result = extract_thread_messages(messages) assert len(result) == 3 @@ -56,9 +56,9 @@ def test_extract_thread_messages_empty_list(): def test_extract_thread_messages_partially_loaded(): id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) messages = [ - TestMessage(id3, id2), - TestMessage(id2, id1), - TestMessage(id1, id0), + MockMessage(id3, id2), + MockMessage(id2, id1), + MockMessage(id1, id0), ] result = extract_thread_messages(messages) assert len(result) == 3 @@ -68,9 +68,9 @@ def test_extract_thread_messages_partially_loaded(): def test_extract_thread_messages_legacy_messages(): id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()) messages = [ - TestMessage(id3, UUID_NIL), - TestMessage(id2, UUID_NIL), - TestMessage(id1, UUID_NIL), + MockMessage(id3, UUID_NIL), + MockMessage(id2, UUID_NIL), + MockMessage(id1, UUID_NIL), ] result = extract_thread_messages(messages) assert len(result) == 3 @@ -80,11 +80,11 @@ def test_extract_thread_messages_legacy_messages(): def test_extract_thread_messages_mixed_with_legacy_messages(): id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()) messages = [ - TestMessage(id5, id4), - TestMessage(id4, id2), - TestMessage(id3, id2), - TestMessage(id2, UUID_NIL), - TestMessage(id1, UUID_NIL), + MockMessage(id5, id4), + MockMessage(id4, id2), + MockMessage(id3, id2), + MockMessage(id2, UUID_NIL), + MockMessage(id1, UUID_NIL), ] result = extract_thread_messages(messages) assert len(result) == 4