feat: workflow node support retry
parent
a725b8bb6e
commit
b99f1a09f4
@ -0,0 +1,33 @@
|
||||
"""add retry_index field to node-execution model
|
||||
|
||||
Revision ID: 348cb0a93d53
|
||||
Revises: cf8f4fc45278
|
||||
Create Date: 2024-12-16 01:23:13.093432
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '348cb0a93d53'
|
||||
down_revision = 'cf8f4fc45278'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.drop_column('retry_index')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,114 @@
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunRetryEvent,
|
||||
)
|
||||
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
|
||||
|
||||
DEFAULT_VALUE_EDGE = [
|
||||
{
|
||||
"id": "start-source-node-target",
|
||||
"source": "start",
|
||||
"target": "node",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
{
|
||||
"id": "node-source-answer-target",
|
||||
"source": "node",
|
||||
"target": "answer",
|
||||
"sourceHandle": "source",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_retry_default_value_partial_success():
|
||||
"""retry default value node with partial success status"""
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_code_node(
|
||||
error_code,
|
||||
"default-value",
|
||||
[{"key": "result", "type": "number", "value": 132123}],
|
||||
{"retry_config": {"max_retries": 2, "retry_interval": 1, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert events[-1].outputs == {"answer": "132123"}
|
||||
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
|
||||
assert len(events) == 11
|
||||
|
||||
|
||||
def test_retry_success():
|
||||
"""retry node with success status"""
|
||||
success_code = """
|
||||
count = 0
|
||||
def main():
|
||||
global count
|
||||
count += 1
|
||||
if count == 1:
|
||||
raise Exception("First attempt fails")
|
||||
if count == 2:
|
||||
return {"result": "success"}
|
||||
"""
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_code_node(
|
||||
success_code,
|
||||
None,
|
||||
None,
|
||||
{"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
assert len(events) == 9
|
||||
|
||||
|
||||
def test_retry_failed():
|
||||
"""retry failed with success status"""
|
||||
error_code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": 1 / 0,
|
||||
}
|
||||
"""
|
||||
|
||||
graph_config = {
|
||||
"edges": DEFAULT_VALUE_EDGE,
|
||||
"nodes": [
|
||||
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||
ContinueOnErrorTestHelper.get_code_node(
|
||||
error_code,
|
||||
None,
|
||||
None,
|
||||
{"retry_config": {"max_retries": 2, "retry_interval": 1, "retry_enabled": True}},
|
||||
),
|
||||
],
|
||||
}
|
||||
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||
events = list(graph_engine.run())
|
||||
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
|
||||
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
|
||||
assert len(events) == 8
|
||||
Loading…
Reference in New Issue