Merge branch 'main' into feat/support-extractor-tools
commit
cae7f7523b
@ -0,0 +1,24 @@
|
||||
from flask_restful import fields
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
@ -0,0 +1,39 @@
|
||||
model: claude-3-5-haiku-20241022
|
||||
label:
|
||||
en_US: claude-3-5-haiku-20241022
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '1.00'
|
||||
output: '5.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
|
||||
|
After Width: | Height: | Size: 356 B |
@ -0,0 +1,37 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
@ -0,0 +1,25 @@
|
||||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
@ -0,0 +1,38 @@
|
||||
provider: x
|
||||
label:
|
||||
en_US: xAI
|
||||
description:
|
||||
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
|
||||
icon_small:
|
||||
en_US: x-ai-logo.svg
|
||||
icon_large:
|
||||
en_US: x-ai-logo.svg
|
||||
help:
|
||||
title:
|
||||
en_US: Get your token from xAI
|
||||
zh_Hans: 从 xAI 获取 token
|
||||
url:
|
||||
en_US: https://x.ai/api
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.x.ai/v1
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
@ -0,0 +1,16 @@
|
||||
class CodeNodeError(ValueError):
|
||||
"""Base class for code node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OutputValidationError(CodeNodeError):
|
||||
"""Raised when there is an output validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DepthLimitError(CodeNodeError):
|
||||
"""Raised when the depth limit is reached."""
|
||||
|
||||
pass
|
||||
@ -0,0 +1,18 @@
|
||||
class HttpRequestNodeError(ValueError):
|
||||
"""Custom error for HTTP request node."""
|
||||
|
||||
|
||||
class AuthorizationConfigError(HttpRequestNodeError):
|
||||
"""Raised when authorization config is missing or invalid."""
|
||||
|
||||
|
||||
class FileFetchError(HttpRequestNodeError):
|
||||
"""Raised when a file cannot be fetched."""
|
||||
|
||||
|
||||
class InvalidHttpMethodError(HttpRequestNodeError):
|
||||
"""Raised when an invalid HTTP method is used."""
|
||||
|
||||
|
||||
class ResponseSizeError(HttpRequestNodeError):
|
||||
"""Raised when the response size exceeds the allowed threshold."""
|
||||
@ -0,0 +1,26 @@
|
||||
class LLMNodeError(ValueError):
|
||||
"""Base class for LLM Node errors."""
|
||||
|
||||
|
||||
class VariableNotFoundError(LLMNodeError):
|
||||
"""Raised when a required variable is not found."""
|
||||
|
||||
|
||||
class InvalidContextStructureError(LLMNodeError):
|
||||
"""Raised when the context structure is invalid."""
|
||||
|
||||
|
||||
class InvalidVariableTypeError(LLMNodeError):
|
||||
"""Raised when the variable type is invalid."""
|
||||
|
||||
|
||||
class ModelNotExistError(LLMNodeError):
|
||||
"""Raised when the specified model does not exist."""
|
||||
|
||||
|
||||
class LLMModeRequiredError(LLMNodeError):
|
||||
"""Raised when LLM mode is required but not provided."""
|
||||
|
||||
|
||||
class NoPromptFoundError(LLMNodeError):
|
||||
"""Raised when no prompt is found in the LLM configuration."""
|
||||
@ -0,0 +1,50 @@
|
||||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
|
||||
"""Raised when the model schema is not found."""
|
||||
|
||||
|
||||
class InvalidInvokeResultError(ParameterExtractorNodeError):
|
||||
"""Raised when the invoke result is invalid."""
|
||||
|
||||
|
||||
class InvalidTextContentTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the text content type is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
|
||||
"""Raised when the number of parameters is invalid."""
|
||||
|
||||
|
||||
class RequiredParameterMissingError(ParameterExtractorNodeError):
|
||||
"""Raised when a required parameter is missing."""
|
||||
|
||||
|
||||
class InvalidSelectValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a select value is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a number value is invalid."""
|
||||
|
||||
|
||||
class InvalidBoolValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a bool value is invalid."""
|
||||
|
||||
|
||||
class InvalidStringValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a string value is invalid."""
|
||||
|
||||
|
||||
class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
"""Raised when an array value is invalid."""
|
||||
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
@ -0,0 +1,204 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = XAILargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(
|
||||
model="gpt-3.5-turbo",
|
||||
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
PromptMessageTool(
|
||||
name="get_stock_price",
|
||||
description="Get the current stock price",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||
"required": ["symbol"],
|
||||
},
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
||||
@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
||||
|
||||
def test_validate_inputs_with_zero():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input 0
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=0,
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
# Test with input "0" (string)
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value="0",
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
def test_validate_input_with_none_for_required_variable():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
for var_type in VariableEntityType:
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=var_type,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert str(exc_info.value) == "test_var is required in input form"
|
||||
@ -0,0 +1,198 @@
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="test: {{#node_id.custom_query#}}",
|
||||
body=HttpRequestNodeBody(
|
||||
type="none",
|
||||
data=[],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.params == {"test": "line1\nline2"}
|
||||
@ -1,99 +0,0 @@
|
||||
'use client'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { jwtDecode } from 'jwt-decode'
|
||||
import dayjs from 'dayjs'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import type { CommonResponse } from '@/models/common'
|
||||
import { fetchNewToken } from '@/service/common'
|
||||
import { fetchWithRetry } from '@/utils'
|
||||
|
||||
dayjs.extend(utc)
|
||||
|
||||
const useRefreshToken = () => {
|
||||
const router = useRouter()
|
||||
const timer = useRef<NodeJS.Timeout>()
|
||||
const advanceTime = useRef<number>(5 * 60 * 1000)
|
||||
|
||||
const getExpireTime = useCallback((token: string) => {
|
||||
if (!token)
|
||||
return 0
|
||||
const decoded = jwtDecode(token)
|
||||
return (decoded.exp || 0) * 1000
|
||||
}, [])
|
||||
|
||||
const getCurrentTimeStamp = useCallback(() => {
|
||||
return dayjs.utc().valueOf()
|
||||
}, [])
|
||||
|
||||
const handleError = useCallback(() => {
|
||||
localStorage?.removeItem('is_refreshing')
|
||||
localStorage?.removeItem('console_token')
|
||||
localStorage?.removeItem('refresh_token')
|
||||
router.replace('/signin')
|
||||
}, [])
|
||||
|
||||
const getNewAccessToken = useCallback(async () => {
|
||||
const currentAccessToken = localStorage?.getItem('console_token')
|
||||
const currentRefreshToken = localStorage?.getItem('refresh_token')
|
||||
if (!currentAccessToken || !currentRefreshToken) {
|
||||
handleError()
|
||||
return new Error('No access token or refresh token found')
|
||||
}
|
||||
if (localStorage?.getItem('is_refreshing') === '1') {
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, 1000)
|
||||
return null
|
||||
}
|
||||
const currentTokenExpireTime = getExpireTime(currentAccessToken)
|
||||
if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) {
|
||||
localStorage?.setItem('is_refreshing', '1')
|
||||
const [e, res] = await fetchWithRetry(fetchNewToken({
|
||||
body: { refresh_token: currentRefreshToken },
|
||||
}) as Promise<CommonResponse & { data: { access_token: string; refresh_token: string } }>)
|
||||
if (e) {
|
||||
handleError()
|
||||
return e
|
||||
}
|
||||
const { access_token, refresh_token } = res.data
|
||||
localStorage?.setItem('is_refreshing', '0')
|
||||
localStorage?.setItem('console_token', access_token)
|
||||
localStorage?.setItem('refresh_token', refresh_token)
|
||||
const newTokenExpireTime = getExpireTime(access_token)
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
||||
}
|
||||
else {
|
||||
const newTokenExpireTime = getExpireTime(currentAccessToken)
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
||||
}
|
||||
return null
|
||||
}, [getExpireTime, getCurrentTimeStamp, handleError])
|
||||
|
||||
const handleVisibilityChange = useCallback(() => {
|
||||
if (document.visibilityState === 'visible')
|
||||
getNewAccessToken()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener('visibilitychange', handleVisibilityChange)
|
||||
return () => {
|
||||
window.removeEventListener('visibilitychange', handleVisibilityChange)
|
||||
clearTimeout(timer.current)
|
||||
localStorage?.removeItem('is_refreshing')
|
||||
}
|
||||
}, [])
|
||||
|
||||
return {
|
||||
getNewAccessToken,
|
||||
}
|
||||
}
|
||||
|
||||
export default useRefreshToken
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue