feat: Add draft hash check in workflow (#4251)

pull/4260/head
takatost 2 years ago committed by GitHub
parent a1ab87107b
commit 8f3042e5b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException):
error_code = 'draft_workflow_not_exist' error_code = 'draft_workflow_not_exist'
description = "Draft workflow need to be initialized." description = "Draft workflow need to be initialized."
code = 400 code = 400
class DraftWorkflowNotSync(BaseHTTPException):
error_code = 'draft_workflow_not_sync'
description = "Workflow graph might have been modified, please refresh and resubmit."
code = 400

@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models.model import App, AppMode from models.model import App, AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('graph', type=dict, required=True, nullable=False, location='json')
parser.add_argument('features', type=dict, required=True, nullable=False, location='json') parser.add_argument('features', type=dict, required=True, nullable=False, location='json')
parser.add_argument('hash', type=str, required=False, location='json')
args = parser.parse_args() args = parser.parse_args()
elif 'text/plain' in content_type: elif 'text/plain' in content_type:
try: try:
@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource):
args = { args = {
'graph': data.get('graph'), 'graph': data.get('graph'),
'features': data.get('features') 'features': data.get('features'),
'hash': data.get('hash')
} }
except json.JSONDecodeError: except json.JSONDecodeError:
return {'message': 'Invalid JSON data'}, 400 return {'message': 'Invalid JSON data'}, 400
@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource):
abort(415) abort(415)
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow = workflow_service.sync_draft_workflow(
app_model=app_model, try:
graph=args.get('graph'), workflow = workflow_service.sync_draft_workflow(
features=args.get('features'), app_model=app_model,
account=current_user graph=args.get('graph'),
) features=args.get('features'),
unique_hash=args.get('hash'),
account=current_user
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return { return {
"result": "success", "result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at)
} }

@ -7,6 +7,7 @@ workflow_fields = {
'id': fields.String, 'id': fields.String,
'graph': fields.Raw(attribute='graph_dict'), 'graph': fields.Raw(attribute='graph_dict'),
'features': fields.Raw(attribute='features_dict'), 'features': fields.Raw(attribute='features_dict'),
'hash': fields.String(attribute='unique_hash'),
'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'),
'created_at': TimestampField, 'created_at': TimestampField,
'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True),

@ -4,6 +4,7 @@ from typing import Optional, Union
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper
from models import StringUUID from models import StringUUID
from models.account import Account from models.account import Account
@ -156,6 +157,21 @@ class Workflow(db.Model):
return variables return variables
@property
def unique_hash(self) -> str:
"""
Get hash of workflow.
:return: hash
"""
entity = {
'graph': self.graph_dict,
'features': self.features_dict
}
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
class WorkflowRunTriggeredFrom(Enum): class WorkflowRunTriggeredFrom(Enum):
""" """
Workflow Run Triggered From Enum Workflow Run Triggered From Enum

@ -196,6 +196,7 @@ class AppService:
app_model=app, app_model=app,
graph=workflow.get('graph'), graph=workflow.get('graph'),
features=workflow.get('features'), features=workflow.get('features'),
unique_hash=None,
account=account account=account
) )
workflow_service.publish_workflow( workflow_service.publish_workflow(

@ -1,2 +1,6 @@
class MoreLikeThisDisabledError(Exception): class MoreLikeThisDisabledError(Exception):
pass pass
class WorkflowHashNotEqualError(Exception):
pass

@ -21,6 +21,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom, WorkflowNodeExecutionTriggeredFrom,
WorkflowType, WorkflowType,
) )
from services.errors.app import WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter from services.workflow.workflow_converter import WorkflowConverter
@ -63,13 +64,20 @@ class WorkflowService:
def sync_draft_workflow(self, app_model: App, def sync_draft_workflow(self, app_model: App,
graph: dict, graph: dict,
features: dict, features: dict,
unique_hash: Optional[str],
account: Account) -> Workflow: account: Account) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@throws WorkflowHashNotEqualError
""" """
# fetch draft workflow by app_model # fetch draft workflow by app_model
workflow = self.get_draft_workflow(app_model=app_model) workflow = self.get_draft_workflow(app_model=app_model)
if workflow:
# validate unique hash
if workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# validate features structure # validate features structure
self.validate_features_structure( self.validate_features_structure(
app_model=app_model, app_model=app_model,

Loading…
Cancel
Save