Merge branch 'main' into feat/support-extractor-tools
commit
cae7f7523b
@ -0,0 +1,24 @@
|
|||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
parameters__system_parameters = {
|
||||||
|
"image_file_size_limit": fields.Integer,
|
||||||
|
"video_file_size_limit": fields.Integer,
|
||||||
|
"audio_file_size_limit": fields.Integer,
|
||||||
|
"file_size_limit": fields.Integer,
|
||||||
|
"workflow_file_upload_limit": fields.Integer,
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters_fields = {
|
||||||
|
"opening_statement": fields.String,
|
||||||
|
"suggested_questions": fields.Raw,
|
||||||
|
"suggested_questions_after_answer": fields.Raw,
|
||||||
|
"speech_to_text": fields.Raw,
|
||||||
|
"text_to_speech": fields.Raw,
|
||||||
|
"retriever_resource": fields.Raw,
|
||||||
|
"annotation_reply": fields.Raw,
|
||||||
|
"more_like_this": fields.Raw,
|
||||||
|
"user_input_form": fields.Raw,
|
||||||
|
"sensitive_word_avoidance": fields.Raw,
|
||||||
|
"file_upload": fields.Raw,
|
||||||
|
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||||
|
}
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
model: claude-3-5-haiku-20241022
|
||||||
|
label:
|
||||||
|
en_US: claude-3-5-haiku-20241022
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 200000
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
required: true
|
||||||
|
default: 8192
|
||||||
|
min: 1
|
||||||
|
max: 8192
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
pricing:
|
||||||
|
input: '1.00'
|
||||||
|
output: '5.00'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
||||||
@ -0,0 +1 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
|
||||||
|
After Width: | Height: | Size: 356 B |
@ -0,0 +1,37 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from yarl import URL
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
)
|
||||||
|
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
prompt_messages: list[PromptMessage],
|
||||||
|
model_parameters: dict,
|
||||||
|
tools: Optional[list[PromptMessageTool]] = None,
|
||||||
|
stop: Optional[list[str]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[LLMResult, Generator]:
|
||||||
|
self._add_custom_parameters(credentials)
|
||||||
|
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
self._add_custom_parameters(credentials)
|
||||||
|
super().validate_credentials(model, credentials)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_custom_parameters(credentials) -> None:
|
||||||
|
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||||
|
credentials["mode"] = LLMMode.CHAT.value
|
||||||
|
credentials["function_calling_type"] = "tool_call"
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class XAIProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_instance = self.get_model_instance(ModelType.LLM)
|
||||||
|
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||||
|
raise ex
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
provider: x
|
||||||
|
label:
|
||||||
|
en_US: xAI
|
||||||
|
description:
|
||||||
|
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
|
||||||
|
icon_small:
|
||||||
|
en_US: x-ai-logo.svg
|
||||||
|
icon_large:
|
||||||
|
en_US: x-ai-logo.svg
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: Get your token from xAI
|
||||||
|
zh_Hans: 从 xAI 获取 token
|
||||||
|
url:
|
||||||
|
en_US: https://x.ai/api
|
||||||
|
supported_model_types:
|
||||||
|
- llm
|
||||||
|
configurate_methods:
|
||||||
|
- predefined-model
|
||||||
|
provider_credential_schema:
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
- variable: endpoint_url
|
||||||
|
label:
|
||||||
|
en_US: API Base
|
||||||
|
type: text-input
|
||||||
|
required: false
|
||||||
|
default: https://api.x.ai/v1
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Base
|
||||||
|
en_US: Enter your API Base
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
class CodeNodeError(ValueError):
|
||||||
|
"""Base class for code node errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OutputValidationError(CodeNodeError):
|
||||||
|
"""Raised when there is an output validation error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DepthLimitError(CodeNodeError):
|
||||||
|
"""Raised when the depth limit is reached."""
|
||||||
|
|
||||||
|
pass
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
class HttpRequestNodeError(ValueError):
|
||||||
|
"""Custom error for HTTP request node."""
|
||||||
|
|
||||||
|
|
||||||
|
class AuthorizationConfigError(HttpRequestNodeError):
|
||||||
|
"""Raised when authorization config is missing or invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class FileFetchError(HttpRequestNodeError):
|
||||||
|
"""Raised when a file cannot be fetched."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidHttpMethodError(HttpRequestNodeError):
|
||||||
|
"""Raised when an invalid HTTP method is used."""
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseSizeError(HttpRequestNodeError):
|
||||||
|
"""Raised when the response size exceeds the allowed threshold."""
|
||||||
@ -0,0 +1,26 @@
|
|||||||
|
class LLMNodeError(ValueError):
|
||||||
|
"""Base class for LLM Node errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class VariableNotFoundError(LLMNodeError):
|
||||||
|
"""Raised when a required variable is not found."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidContextStructureError(LLMNodeError):
|
||||||
|
"""Raised when the context structure is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidVariableTypeError(LLMNodeError):
|
||||||
|
"""Raised when the variable type is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotExistError(LLMNodeError):
|
||||||
|
"""Raised when the specified model does not exist."""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMModeRequiredError(LLMNodeError):
|
||||||
|
"""Raised when LLM mode is required but not provided."""
|
||||||
|
|
||||||
|
|
||||||
|
class NoPromptFoundError(LLMNodeError):
|
||||||
|
"""Raised when no prompt is found in the LLM configuration."""
|
||||||
@ -0,0 +1,50 @@
|
|||||||
|
class ParameterExtractorNodeError(ValueError):
|
||||||
|
"""Base error for ParameterExtractorNode."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelTypeError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the model is not a Large Language Model."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the model schema is not found."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidInvokeResultError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the invoke result is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTextContentTypeError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the text content type is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the number of parameters is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredParameterMissingError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when a required parameter is missing."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSelectValueError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when a select value is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidNumberValueError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when a number value is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidBoolValueError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when a bool value is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidStringValueError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when a string value is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when an array value is invalid."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||||
|
"""Raised when the model mode is invalid."""
|
||||||
@ -0,0 +1,204 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
|
from core.model_runtime.entities.message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
|
||||||
|
|
||||||
|
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||||
|
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||||
|
|
||||||
|
|
||||||
|
def test_predefined_models():
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
model_schemas = model.predefined_models()
|
||||||
|
|
||||||
|
assert len(model_schemas) >= 1
|
||||||
|
assert isinstance(model_schemas[0], AIModelEntity)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||||
|
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
# model name to gpt-3.5-turbo because of mocking
|
||||||
|
model.validate_credentials(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
|
||||||
|
)
|
||||||
|
|
||||||
|
model.validate_credentials(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("XAI_API_KEY"),
|
||||||
|
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||||
|
"mode": "chat",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||||
|
def test_invoke_chat_model(setup_openai_mock):
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
|
||||||
|
result = model.invoke(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("XAI_API_KEY"),
|
||||||
|
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||||
|
"mode": "chat",
|
||||||
|
},
|
||||||
|
prompt_messages=[
|
||||||
|
SystemPromptMessage(
|
||||||
|
content="You are a helpful AI assistant.",
|
||||||
|
),
|
||||||
|
UserPromptMessage(content="Hello World!"),
|
||||||
|
],
|
||||||
|
model_parameters={
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_p": 1.0,
|
||||||
|
"presence_penalty": 0.0,
|
||||||
|
"frequency_penalty": 0.0,
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
stop=["How"],
|
||||||
|
stream=False,
|
||||||
|
user="foo",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResult)
|
||||||
|
assert len(result.message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||||
|
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
|
||||||
|
result = model.invoke(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("XAI_API_KEY"),
|
||||||
|
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||||
|
"mode": "chat",
|
||||||
|
},
|
||||||
|
prompt_messages=[
|
||||||
|
SystemPromptMessage(
|
||||||
|
content="You are a helpful AI assistant.",
|
||||||
|
),
|
||||||
|
UserPromptMessage(
|
||||||
|
content="what's the weather today in London?",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||||
|
tools=[
|
||||||
|
PromptMessageTool(
|
||||||
|
name="get_weather",
|
||||||
|
description="Determine weather in my location",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||||
|
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
PromptMessageTool(
|
||||||
|
name="get_stock_price",
|
||||||
|
description="Get the current stock price",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||||
|
"required": ["symbol"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
stream=False,
|
||||||
|
user="foo",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, LLMResult)
|
||||||
|
assert isinstance(result.message, AssistantPromptMessage)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||||
|
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
|
||||||
|
result = model.invoke(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("XAI_API_KEY"),
|
||||||
|
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||||
|
"mode": "chat",
|
||||||
|
},
|
||||||
|
prompt_messages=[
|
||||||
|
SystemPromptMessage(
|
||||||
|
content="You are a helpful AI assistant.",
|
||||||
|
),
|
||||||
|
UserPromptMessage(content="Hello World!"),
|
||||||
|
],
|
||||||
|
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||||
|
stream=True,
|
||||||
|
user="foo",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, Generator)
|
||||||
|
|
||||||
|
for chunk in result:
|
||||||
|
assert isinstance(chunk, LLMResultChunk)
|
||||||
|
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||||
|
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||||
|
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||||
|
if chunk.delta.finish_reason is not None:
|
||||||
|
assert chunk.delta.usage is not None
|
||||||
|
assert chunk.delta.usage.completion_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens():
|
||||||
|
model = XAILargeLanguageModel()
|
||||||
|
|
||||||
|
num_tokens = model.get_num_tokens(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||||
|
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_tokens == 10
|
||||||
|
|
||||||
|
num_tokens = model.get_num_tokens(
|
||||||
|
model="grok-beta",
|
||||||
|
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||||
|
prompt_messages=[
|
||||||
|
SystemPromptMessage(
|
||||||
|
content="You are a helpful AI assistant.",
|
||||||
|
),
|
||||||
|
UserPromptMessage(content="Hello World!"),
|
||||||
|
],
|
||||||
|
tools=[
|
||||||
|
PromptMessageTool(
|
||||||
|
name="get_weather",
|
||||||
|
description="Determine weather in my location",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||||
|
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_tokens == 77
|
||||||
@ -0,0 +1,52 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||||
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_inputs_with_zero():
|
||||||
|
base_app_generator = BaseAppGenerator()
|
||||||
|
|
||||||
|
var = VariableEntity(
|
||||||
|
variable="test_var",
|
||||||
|
label="test_var",
|
||||||
|
type=VariableEntityType.NUMBER,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with input 0
|
||||||
|
result = base_app_generator._validate_inputs(
|
||||||
|
variable_entity=var,
|
||||||
|
value=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
# Test with input "0" (string)
|
||||||
|
result = base_app_generator._validate_inputs(
|
||||||
|
variable_entity=var,
|
||||||
|
value="0",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_input_with_none_for_required_variable():
|
||||||
|
base_app_generator = BaseAppGenerator()
|
||||||
|
|
||||||
|
for var_type in VariableEntityType:
|
||||||
|
var = VariableEntity(
|
||||||
|
variable="test_var",
|
||||||
|
label="test_var",
|
||||||
|
type=var_type,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with input None
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
base_app_generator._validate_inputs(
|
||||||
|
variable_entity=var,
|
||||||
|
value=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "test_var is required in input form"
|
||||||
@ -0,0 +1,198 @@
|
|||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.nodes.http_request import (
|
||||||
|
BodyData,
|
||||||
|
HttpRequestNodeAuthorization,
|
||||||
|
HttpRequestNodeBody,
|
||||||
|
HttpRequestNodeData,
|
||||||
|
)
|
||||||
|
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||||
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_with_json_body_and_number_variable():
|
||||||
|
# Prepare the variable pool
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
variable_pool.add(["pre_node_id", "number"], 42)
|
||||||
|
|
||||||
|
# Prepare the node data
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="Test JSON Body with Number Variable",
|
||||||
|
method="post",
|
||||||
|
url="https://api.example.com/data",
|
||||||
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
|
headers="Content-Type: application/json",
|
||||||
|
params="",
|
||||||
|
body=HttpRequestNodeBody(
|
||||||
|
type="json",
|
||||||
|
data=[
|
||||||
|
BodyData(
|
||||||
|
key="",
|
||||||
|
type="text",
|
||||||
|
value='{"number": {{#pre_node_id.number#}}}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the Executor
|
||||||
|
executor = Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the executor's data
|
||||||
|
assert executor.method == "post"
|
||||||
|
assert executor.url == "https://api.example.com/data"
|
||||||
|
assert executor.headers == {"Content-Type": "application/json"}
|
||||||
|
assert executor.params == {}
|
||||||
|
assert executor.json == {"number": 42}
|
||||||
|
assert executor.data is None
|
||||||
|
assert executor.files is None
|
||||||
|
assert executor.content is None
|
||||||
|
|
||||||
|
# Check the raw request (to_log method)
|
||||||
|
raw_request = executor.to_log()
|
||||||
|
assert "POST /data HTTP/1.1" in raw_request
|
||||||
|
assert "Host: api.example.com" in raw_request
|
||||||
|
assert "Content-Type: application/json" in raw_request
|
||||||
|
assert '{"number": 42}' in raw_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_with_json_body_and_object_variable():
|
||||||
|
# Prepare the variable pool
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||||
|
|
||||||
|
# Prepare the node data
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="Test JSON Body with Object Variable",
|
||||||
|
method="post",
|
||||||
|
url="https://api.example.com/data",
|
||||||
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
|
headers="Content-Type: application/json",
|
||||||
|
params="",
|
||||||
|
body=HttpRequestNodeBody(
|
||||||
|
type="json",
|
||||||
|
data=[
|
||||||
|
BodyData(
|
||||||
|
key="",
|
||||||
|
type="text",
|
||||||
|
value="{{#pre_node_id.object#}}",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the Executor
|
||||||
|
executor = Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the executor's data
|
||||||
|
assert executor.method == "post"
|
||||||
|
assert executor.url == "https://api.example.com/data"
|
||||||
|
assert executor.headers == {"Content-Type": "application/json"}
|
||||||
|
assert executor.params == {}
|
||||||
|
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||||
|
assert executor.data is None
|
||||||
|
assert executor.files is None
|
||||||
|
assert executor.content is None
|
||||||
|
|
||||||
|
# Check the raw request (to_log method)
|
||||||
|
raw_request = executor.to_log()
|
||||||
|
assert "POST /data HTTP/1.1" in raw_request
|
||||||
|
assert "Host: api.example.com" in raw_request
|
||||||
|
assert "Content-Type: application/json" in raw_request
|
||||||
|
assert '"name": "John Doe"' in raw_request
|
||||||
|
assert '"age": 30' in raw_request
|
||||||
|
assert '"email": "john@example.com"' in raw_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_executor_with_json_body_and_nested_object_variable():
|
||||||
|
# Prepare the variable pool
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables={},
|
||||||
|
user_inputs={},
|
||||||
|
)
|
||||||
|
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||||
|
|
||||||
|
# Prepare the node data
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="Test JSON Body with Nested Object Variable",
|
||||||
|
method="post",
|
||||||
|
url="https://api.example.com/data",
|
||||||
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
|
headers="Content-Type: application/json",
|
||||||
|
params="",
|
||||||
|
body=HttpRequestNodeBody(
|
||||||
|
type="json",
|
||||||
|
data=[
|
||||||
|
BodyData(
|
||||||
|
key="",
|
||||||
|
type="text",
|
||||||
|
value='{"object": {{#pre_node_id.object#}}}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the Executor
|
||||||
|
executor = Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check the executor's data
|
||||||
|
assert executor.method == "post"
|
||||||
|
assert executor.url == "https://api.example.com/data"
|
||||||
|
assert executor.headers == {"Content-Type": "application/json"}
|
||||||
|
assert executor.params == {}
|
||||||
|
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||||
|
assert executor.data is None
|
||||||
|
assert executor.files is None
|
||||||
|
assert executor.content is None
|
||||||
|
|
||||||
|
# Check the raw request (to_log method)
|
||||||
|
raw_request = executor.to_log()
|
||||||
|
assert "POST /data HTTP/1.1" in raw_request
|
||||||
|
assert "Host: api.example.com" in raw_request
|
||||||
|
assert "Content-Type: application/json" in raw_request
|
||||||
|
assert '"object": {' in raw_request
|
||||||
|
assert '"name": "John Doe"' in raw_request
|
||||||
|
assert '"age": 30' in raw_request
|
||||||
|
assert '"email": "john@example.com"' in raw_request
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_selectors_from_template_with_newline():
|
||||||
|
variable_pool = VariablePool()
|
||||||
|
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||||
|
node_data = HttpRequestNodeData(
|
||||||
|
title="Test JSON Body with Nested Object Variable",
|
||||||
|
method="post",
|
||||||
|
url="https://api.example.com/data",
|
||||||
|
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||||
|
headers="Content-Type: application/json",
|
||||||
|
params="test: {{#node_id.custom_query#}}",
|
||||||
|
body=HttpRequestNodeBody(
|
||||||
|
type="none",
|
||||||
|
data=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
executor = Executor(
|
||||||
|
node_data=node_data,
|
||||||
|
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert executor.params == {"test": "line1\nline2"}
|
||||||
@ -1,99 +0,0 @@
|
|||||||
'use client'
|
|
||||||
import { useCallback, useEffect, useRef } from 'react'
|
|
||||||
import { jwtDecode } from 'jwt-decode'
|
|
||||||
import dayjs from 'dayjs'
|
|
||||||
import utc from 'dayjs/plugin/utc'
|
|
||||||
import { useRouter } from 'next/navigation'
|
|
||||||
import type { CommonResponse } from '@/models/common'
|
|
||||||
import { fetchNewToken } from '@/service/common'
|
|
||||||
import { fetchWithRetry } from '@/utils'
|
|
||||||
|
|
||||||
dayjs.extend(utc)
|
|
||||||
|
|
||||||
const useRefreshToken = () => {
|
|
||||||
const router = useRouter()
|
|
||||||
const timer = useRef<NodeJS.Timeout>()
|
|
||||||
const advanceTime = useRef<number>(5 * 60 * 1000)
|
|
||||||
|
|
||||||
const getExpireTime = useCallback((token: string) => {
|
|
||||||
if (!token)
|
|
||||||
return 0
|
|
||||||
const decoded = jwtDecode(token)
|
|
||||||
return (decoded.exp || 0) * 1000
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
const getCurrentTimeStamp = useCallback(() => {
|
|
||||||
return dayjs.utc().valueOf()
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
const handleError = useCallback(() => {
|
|
||||||
localStorage?.removeItem('is_refreshing')
|
|
||||||
localStorage?.removeItem('console_token')
|
|
||||||
localStorage?.removeItem('refresh_token')
|
|
||||||
router.replace('/signin')
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
const getNewAccessToken = useCallback(async () => {
|
|
||||||
const currentAccessToken = localStorage?.getItem('console_token')
|
|
||||||
const currentRefreshToken = localStorage?.getItem('refresh_token')
|
|
||||||
if (!currentAccessToken || !currentRefreshToken) {
|
|
||||||
handleError()
|
|
||||||
return new Error('No access token or refresh token found')
|
|
||||||
}
|
|
||||||
if (localStorage?.getItem('is_refreshing') === '1') {
|
|
||||||
clearTimeout(timer.current)
|
|
||||||
timer.current = setTimeout(() => {
|
|
||||||
getNewAccessToken()
|
|
||||||
}, 1000)
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
const currentTokenExpireTime = getExpireTime(currentAccessToken)
|
|
||||||
if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) {
|
|
||||||
localStorage?.setItem('is_refreshing', '1')
|
|
||||||
const [e, res] = await fetchWithRetry(fetchNewToken({
|
|
||||||
body: { refresh_token: currentRefreshToken },
|
|
||||||
}) as Promise<CommonResponse & { data: { access_token: string; refresh_token: string } }>)
|
|
||||||
if (e) {
|
|
||||||
handleError()
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
const { access_token, refresh_token } = res.data
|
|
||||||
localStorage?.setItem('is_refreshing', '0')
|
|
||||||
localStorage?.setItem('console_token', access_token)
|
|
||||||
localStorage?.setItem('refresh_token', refresh_token)
|
|
||||||
const newTokenExpireTime = getExpireTime(access_token)
|
|
||||||
clearTimeout(timer.current)
|
|
||||||
timer.current = setTimeout(() => {
|
|
||||||
getNewAccessToken()
|
|
||||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
const newTokenExpireTime = getExpireTime(currentAccessToken)
|
|
||||||
clearTimeout(timer.current)
|
|
||||||
timer.current = setTimeout(() => {
|
|
||||||
getNewAccessToken()
|
|
||||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
|
||||||
}
|
|
||||||
return null
|
|
||||||
}, [getExpireTime, getCurrentTimeStamp, handleError])
|
|
||||||
|
|
||||||
const handleVisibilityChange = useCallback(() => {
|
|
||||||
if (document.visibilityState === 'visible')
|
|
||||||
getNewAccessToken()
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
window.addEventListener('visibilitychange', handleVisibilityChange)
|
|
||||||
return () => {
|
|
||||||
window.removeEventListener('visibilitychange', handleVisibilityChange)
|
|
||||||
clearTimeout(timer.current)
|
|
||||||
localStorage?.removeItem('is_refreshing')
|
|
||||||
}
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
return {
|
|
||||||
getNewAccessToken,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export default useRefreshToken
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue