refactor: update datasource entity structure and parameter handling

- Renamed and split parameters in DatasourceEntity into first_step_parameters and second_step_parameters.
- Updated validation methods for new parameter structure.
- Adjusted datasource_node to reference first_step_parameters.
- Cleaned up unused imports and improved type hints in workflow.py.
feat/datasource
Yeuoly 1 year ago
parent 5fa2aca2c8
commit 3bfc602561

@ -22,8 +22,8 @@ class DatasourceProviderType(enum.StrEnum):
""" """
ONLINE_DOCUMENT = "online_document" ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE = "website" WEBSITE = "website"
ONLINE_DRIVE = "online_drive"
@classmethod @classmethod
def value_of(cls, value: str) -> "DatasourceProviderType": def value_of(cls, value: str) -> "DatasourceProviderType":
@ -125,14 +125,21 @@ class DatasourceDescription(BaseModel):
class DatasourceEntity(BaseModel): class DatasourceEntity(BaseModel):
identity: DatasourceIdentity identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: Optional[DatasourceDescription] = None description: Optional[DatasourceDescription] = None
output_schema: Optional[dict] = None first_step_parameters: list[DatasourceParameter] = Field(default_factory=list)
second_step_parameters: list[DatasourceParameter] = Field(default_factory=list)
first_step_output_schema: Optional[dict] = None
second_step_output_schema: Optional[dict] = None
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters") has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
@field_validator("parameters", mode="before") @field_validator("first_step_parameters", mode="before")
@classmethod @classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]: def set_first_step_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or []
@field_validator("second_step_parameters", mode="before")
@classmethod
def set_second_step_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or [] return v or []

@ -64,7 +64,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
return return
# get parameters # get parameters
datasource_parameters = datasource_runtime.entity.parameters datasource_parameters = datasource_runtime.entity.first_step_parameters
parameters = self._generate_parameters( parameters = self._generate_parameters(
datasource_parameters=datasource_parameters, datasource_parameters=datasource_parameters,
variable_pool=self.graph_runtime_state.variable_pool, variable_pool=self.graph_runtime_state.variable_pool,

@ -39,7 +39,6 @@ from core.variables.variables import (
from core.workflow.constants import ( from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID, CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID,
PIPELINE_VARIABLE_NODE_ID,
) )
@ -123,6 +122,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = result.model_copy(update={"selector": selector}) result = result.model_copy(update={"selector": selector})
return cast(Variable, result) return cast(Variable, result)
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
if value is None: if value is None:
return NoneSegment() return NoneSegment()

@ -3,7 +3,7 @@ import logging
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, List, Optional, Self, Union from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4 from uuid import uuid4
from core.variables import utils as variable_utils from core.variables import utils as variable_utils
@ -366,11 +366,11 @@ class Workflow(Base):
self._rag_pipeline_variables = "{}" self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = [v for v in variables_dict.values()] results = list(variables_dict.values())
return results return results
@rag_pipeline_variables.setter @rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: List[dict]) -> None: def rag_pipeline_variables(self, values: list[dict]) -> None:
self._rag_pipeline_variables = json.dumps( self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values}, {item["variable"]: item for item in values},
ensure_ascii=False, ensure_ascii=False,

Loading…
Cancel
Save