feat(models): Add type hints to WorkflowRun and fix a type error

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/20067/head
-LAN- 1 year ago
parent 997b46bfaa
commit d1599fc4af
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -504,6 +504,13 @@ class WorkflowCycleManager:
else: else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_run.finished_at.timestamp())
if workflow_run.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse( return WorkflowFinishStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -519,7 +526,7 @@ class WorkflowCycleManager:
total_steps=workflow_run.total_steps, total_steps=workflow_run.total_steps,
created_by=created_by, created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()), created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()), finished_at=finished_at_timestamp,
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
exceptions_count=workflow_run.exceptions_count, exceptions_count=workflow_run.exceptions_count,
), ),

@ -425,14 +425,14 @@ class WorkflowRun(Base):
status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error: Mapped[Optional[str]] = mapped_column(db.Text) error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) elapsed_time: Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0"))
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0")) total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime) finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0")) exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
@property @property
def created_by_account(self): def created_by_account(self):
@ -447,7 +447,7 @@ class WorkflowRun(Base):
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
@property @property
def graph_dict(self): def graph_dict(self) -> Mapping[str, Any]:
return json.loads(self.graph) if self.graph else {} return json.loads(self.graph) if self.graph else {}
@property @property
@ -749,12 +749,12 @@ class WorkflowAppLog(Base):
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
workflow_run_id: Mapped[str] = mapped_column(StringUUID) workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False) created_from: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property @property
def workflow_run(self): def workflow_run(self):
@ -779,9 +779,11 @@ class ConversationVariable(Base):
id: Mapped[str] = mapped_column(StringUUID, primary_key=True) id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False) data: Mapped[str] = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True) created_at: Mapped[datetime] = mapped_column(
updated_at = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True
)
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
) )
@ -829,14 +831,14 @@ class WorkflowDraftVariable(Base):
# id is the unique identifier of a draft variable. # id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
created_at = mapped_column( created_at: Mapped[datetime] = mapped_column(
db.DateTime, db.DateTime,
nullable=False, nullable=False,
default=_naive_utc_datetime, default=_naive_utc_datetime,
server_default=func.current_timestamp(), server_default=func.current_timestamp(),
) )
updated_at = mapped_column( updated_at: Mapped[datetime] = mapped_column(
db.DateTime, db.DateTime,
nullable=False, nullable=False,
default=_naive_utc_datetime, default=_naive_utc_datetime,

Loading…
Cancel
Save