fix(api): conversation variable should be editable.

pull/20699/head
QuantumGhost 11 months ago
parent 8f8465cd9f
commit b46a56c272

@ -1206,6 +1206,7 @@ class WorkflowDraftVariable(Base):
description=description, description=description,
node_execution_id=None, node_execution_id=None,
) )
variable.editable = True
return variable return variable
@classmethod @classmethod

@ -217,9 +217,6 @@ class WorkflowDraftVariableService:
return variable return variable
def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
# If a variable does not allow updating, it makes no sence to resetting it.
if not variable.editable:
return variable
conv_var_by_name = {i.name: i for i in workflow.conversation_variables} conv_var_by_name = {i.name: i for i in workflow.conversation_variables}
conv_var = conv_var_by_name.get(variable.name) conv_var = conv_var_by_name.get(variable.name)
@ -238,6 +235,9 @@ class WorkflowDraftVariableService:
return variable return variable
def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
# If a variable does not allow updating, it makes no sence to resetting it.
if not variable.editable:
return variable
# No execution record for this variable, delete the variable instead. # No execution record for this variable, delete the variable instead.
if variable.node_execution_id is None: if variable.node_execution_id is None:
self._session.delete(instance=variable) self._session.delete(instance=variable)

@ -26,6 +26,7 @@ from core.variables.segments import (
ArrayNumberSegment, ArrayNumberSegment,
ArrayObjectSegment, ArrayObjectSegment,
ArrayStringSegment, ArrayStringSegment,
FileSegment,
FloatSegment, FloatSegment,
IntegerSegment, IntegerSegment,
NoneSegment, NoneSegment,
@ -551,6 +552,25 @@ class TestBuildSegmentWithType:
assert result.value == test_obj assert result.value == test_obj
assert result.value_type == SegmentType.OBJECT assert result.value_type == SegmentType.OBJECT
def test_file_type(self):
"""Test building a file segment with correct type."""
test_file = File(
id="test_file_id",
tenant_id="test_tenant_id",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://test.example.com/test-file.png",
filename="test-file",
extension=".png",
mime_type="image/png",
size=1000,
storage_key="test_storage_key",
)
result = build_segment_with_type(SegmentType.FILE, test_file)
assert isinstance(result, FileSegment)
assert result.value == test_file
assert result.value_type == SegmentType.FILE
def test_none_type(self): def test_none_type(self):
"""Test building a none segment with None value.""" """Test building a none segment with None value."""
result = build_segment_with_type(SegmentType.NONE, None) result = build_segment_with_type(SegmentType.NONE, None)
@ -811,7 +831,6 @@ class TestBuildSegmentValueErrors:
self.ValueErrorTestCase( self.ValueErrorTestCase(
name="generic_object", description="generic object (unsupported type)", test_value=object() name="generic_object", description="generic object (unsupported type)", test_value=object()
), ),
self.ValueErrorTestCase(name="nested_list", description="nested list (unsupported type)", test_value=[[1]]),
] ]
def test_build_segment_unsupported_types(self): def test_build_segment_unsupported_types(self):
@ -822,8 +841,9 @@ class TestBuildSegmentValueErrors:
# Use test value directly # Use test value directly
test_value = test_case.test_value test_value = test_case.test_value
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info: # noqa: PT012
variable_factory.build_segment(test_value) segment = variable_factory.build_segment(test_value)
pytest.fail(f"Test case {index} ({test_case.name}) should raise ValueError but not, result={segment}")
error_message = str(exc_info.value) error_message = str(exc_info.value)
assert "not supported value" in error_message, ( assert "not supported value" in error_message, (

Loading…
Cancel
Save