refactor: Simplifies ModelMode instantiation

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 10 months ago
parent 5feddefca9
commit e7bea600c6
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -118,7 +118,7 @@ class AppRunner:
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template

@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
@classmethod
def value_of(cls, value: str) -> "ModelMode":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
prompt_file_contents: dict[str, Any] = {}
@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
model_mode = ModelMode.value_of(model_config.mode)
model_mode = ModelMode(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,

@ -1137,7 +1137,7 @@ class DatasetRetrieval:
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
model_mode = ModelMode.value_of(mode)
model_mode = ModelMode(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]

@ -510,13 +510,8 @@ class KnowledgeRetrievalNode(BaseNode):
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
metadata_model_config = node_data.metadata_model_config
if metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance
# fetch model config
model_instance, model_config = self.get_model_config(metadata_model_config)
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
@ -707,7 +702,7 @@ class KnowledgeRetrievalNode(BaseNode):
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
model_mode = ModelMode(node_data.metadata_model_config.mode)
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []

@ -420,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
"""
Generate prompt engineering prompt.
"""
model_mode = ModelMode.value_of(data.model.mode)
model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
@ -716,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@ -743,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text

@ -334,7 +334,7 @@ class QuestionClassifierNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
model_mode = ModelMode.value_of(node_data.model.mode)
model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:

Loading…
Cancel
Save