feat: regenerate in `Chat`, `agent` and `Chatflow` app (#7661)
parent
b32a7713e0
commit
8c51d06222
@ -1 +1,2 @@
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
@ -0,0 +1,22 @@
|
||||
from constants import UUID_NIL
|
||||
|
||||
|
||||
def extract_thread_messages(messages: list[dict]) -> list[dict]:
|
||||
thread_messages = []
|
||||
next_message = None
|
||||
|
||||
for message in messages:
|
||||
if not message.parent_message_id:
|
||||
# If the message is regenerated and does not have a parent message, it is the start of a new thread
|
||||
thread_messages.append(message)
|
||||
break
|
||||
|
||||
if not next_message:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
else:
|
||||
if next_message in {message.id, UUID_NIL}:
|
||||
thread_messages.append(message)
|
||||
next_message = message.parent_message_id
|
||||
|
||||
return thread_messages
|
||||
@ -0,0 +1,36 @@
|
||||
"""add parent_message_id to messages
|
||||
|
||||
Revision ID: d57ba9ebb251
|
||||
Revises: 675b5321501b
|
||||
Create Date: 2024-09-11 10:12:45.826265
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'd57ba9ebb251'
|
||||
down_revision = '675b5321501b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('parent_message_id', models.types.StringUUID(), nullable=True))
|
||||
|
||||
# Set parent_message_id for existing messages to uuid_nil() to distinguish them from new messages with actual parent IDs or NULLs
|
||||
op.execute('UPDATE messages SET parent_message_id = uuid_nil() WHERE parent_message_id IS NULL')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('messages', schema=None) as batch_op:
|
||||
batch_op.drop_column('parent_message_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,91 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
|
||||
|
||||
class TestMessage:
|
||||
def __init__(self, id, parent_message_id):
|
||||
self.id = id
|
||||
self.parent_message_id = parent_message_id
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
def test_extract_thread_messages_single_message():
|
||||
messages = [TestMessage(str(uuid4()), UUID_NIL)]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 1
|
||||
assert result[0] == messages[0]
|
||||
|
||||
|
||||
def test_extract_thread_messages_linear_thread():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id3),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 5
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_branched_thread():
|
||||
id1, id2, id3, id4 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id4, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_empty_list():
|
||||
messages = []
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_extract_thread_messages_partially_loaded():
|
||||
id0, id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, id1),
|
||||
TestMessage(id1, id0),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_legacy_messages():
|
||||
id1, id2, id3 = str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id3, UUID_NIL),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert [msg["id"] for msg in result] == [id3, id2, id1]
|
||||
|
||||
|
||||
def test_extract_thread_messages_mixed_with_legacy_messages():
|
||||
id1, id2, id3, id4, id5 = str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4()), str(uuid4())
|
||||
messages = [
|
||||
TestMessage(id5, id4),
|
||||
TestMessage(id4, id2),
|
||||
TestMessage(id3, id2),
|
||||
TestMessage(id2, UUID_NIL),
|
||||
TestMessage(id1, UUID_NIL),
|
||||
]
|
||||
result = extract_thread_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert [msg["id"] for msg in result] == [id5, id4, id2, id1]
|
||||
@ -1 +1,2 @@
|
||||
export const CONVERSATION_ID_INFO = 'conversationIdInfo'
|
||||
export const UUID_NIL = '00000000-0000-0000-0000-000000000000'
|
||||
|
||||
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor"><path d="M5.46257 4.43262C7.21556 2.91688 9.5007 2 12 2C17.5228 2 22 6.47715 22 12C22 14.1361 21.3302 16.1158 20.1892 17.7406L17 12H20C20 7.58172 16.4183 4 12 4C9.84982 4 7.89777 4.84827 6.46023 6.22842L5.46257 4.43262ZM18.5374 19.5674C16.7844 21.0831 14.4993 22 12 22C6.47715 22 2 17.5228 2 12C2 9.86386 2.66979 7.88416 3.8108 6.25944L7 12H4C4 16.4183 7.58172 20 12 20C14.1502 20 16.1022 19.1517 17.5398 17.7716L18.5374 19.5674Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 524 B |
@ -0,0 +1,23 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"xmlns": "http://www.w3.org/2000/svg",
|
||||
"viewBox": "0 0 24 24",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M5.46257 4.43262C7.21556 2.91688 9.5007 2 12 2C17.5228 2 22 6.47715 22 12C22 14.1361 21.3302 16.1158 20.1892 17.7406L17 12H20C20 7.58172 16.4183 4 12 4C9.84982 4 7.89777 4.84827 6.46023 6.22842L5.46257 4.43262ZM18.5374 19.5674C16.7844 21.0831 14.4993 22 12 22C6.47715 22 2 17.5228 2 12C2 9.86386 2.66979 7.88416 3.8108 6.25944L7 12H4C4 16.4183 7.58172 20 12 20C14.1502 20 16.1022 19.1517 17.5398 17.7716L18.5374 19.5674Z"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Refresh"
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './Refresh.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
|
||||
props,
|
||||
ref,
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />)
|
||||
|
||||
Icon.displayName = 'Refresh'
|
||||
|
||||
export default Icon
|
||||
@ -0,0 +1,31 @@
|
||||
'use client'
|
||||
import { t } from 'i18next'
|
||||
import { Refresh } from '../icons/src/vender/line/general'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type Props = {
|
||||
className?: string
|
||||
onClick?: () => void
|
||||
}
|
||||
|
||||
const RegenerateBtn = ({ className, onClick }: Props) => {
|
||||
return (
|
||||
<div className={`${className}`}>
|
||||
<Tooltip
|
||||
popupContent={t('appApi.regenerate') as string}
|
||||
>
|
||||
<div
|
||||
className={'box-border p-0.5 flex items-center justify-center rounded-md bg-white cursor-pointer'}
|
||||
onClick={() => onClick?.()}
|
||||
style={{
|
||||
boxShadow: '0px 4px 8px -2px rgba(16, 24, 40, 0.1), 0px 2px 4px -2px rgba(16, 24, 40, 0.06)',
|
||||
}}
|
||||
>
|
||||
<Refresh className="p-[3.5px] w-6 h-6 text-[#667085] hover:bg-gray-50" />
|
||||
</div>
|
||||
</Tooltip>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default RegenerateBtn
|
||||
Loading…
Reference in New Issue