Merge branch 'main' into fix/chore-fix
commit
75fe785d88
@ -0,0 +1,51 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class OpenDALScheme(StrEnum):
|
||||||
|
FS = "fs"
|
||||||
|
S3 = "s3"
|
||||||
|
|
||||||
|
|
||||||
|
class OpenDALStorageConfig(BaseSettings):
|
||||||
|
STORAGE_OPENDAL_SCHEME: str = Field(
|
||||||
|
default=OpenDALScheme.FS.value,
|
||||||
|
description="OpenDAL scheme.",
|
||||||
|
)
|
||||||
|
# FS
|
||||||
|
OPENDAL_FS_ROOT: str = Field(
|
||||||
|
default="storage",
|
||||||
|
description="Root path for local storage.",
|
||||||
|
)
|
||||||
|
# S3
|
||||||
|
OPENDAL_S3_ROOT: str = Field(
|
||||||
|
default="/",
|
||||||
|
description="Root path for S3 storage.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_BUCKET: str = Field(
|
||||||
|
default="",
|
||||||
|
description="S3 bucket name.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_ENDPOINT: str = Field(
|
||||||
|
default="https://s3.amazonaws.com",
|
||||||
|
description="S3 endpoint URL.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
|
||||||
|
default="",
|
||||||
|
description="S3 access key ID.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
|
||||||
|
default="",
|
||||||
|
description="S3 secret access key.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_REGION: str = Field(
|
||||||
|
default="",
|
||||||
|
description="S3 region.",
|
||||||
|
)
|
||||||
|
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
|
||||||
|
default="",
|
||||||
|
description="S3 server-side encryption.",
|
||||||
|
)
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from .apollo import ApolloSettingsSourceInfo
|
||||||
|
from .base import RemoteSettingsSource
|
||||||
|
from .enums import RemoteSettingsSourceName
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
|
||||||
|
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
|
||||||
|
description="name of remote config source",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
from configs.remote_settings_sources.base import RemoteSettingsSource
|
||||||
|
|
||||||
|
from .client import ApolloClient
|
||||||
|
|
||||||
|
|
||||||
|
class ApolloSettingsSourceInfo(BaseSettings):
|
||||||
|
"""
|
||||||
|
Packaging build information
|
||||||
|
"""
|
||||||
|
|
||||||
|
APOLLO_APP_ID: Optional[str] = Field(
|
||||||
|
description="apollo app_id",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
APOLLO_CLUSTER: Optional[str] = Field(
|
||||||
|
description="apollo cluster",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
APOLLO_CONFIG_URL: Optional[str] = Field(
|
||||||
|
description="apollo config url",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
APOLLO_NAMESPACE: Optional[str] = Field(
|
||||||
|
description="apollo namespace",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ApolloSettingsSource(RemoteSettingsSource):
|
||||||
|
def __init__(self, configs: Mapping[str, Any]):
|
||||||
|
self.client = ApolloClient(
|
||||||
|
app_id=configs["APOLLO_APP_ID"],
|
||||||
|
cluster=configs["APOLLO_CLUSTER"],
|
||||||
|
config_url=configs["APOLLO_CONFIG_URL"],
|
||||||
|
start_hot_update=False,
|
||||||
|
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
|
||||||
|
)
|
||||||
|
self.namespace = configs["APOLLO_NAMESPACE"]
|
||||||
|
self.remote_configs = self.client.get_all_dicts(self.namespace)
|
||||||
|
|
||||||
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
|
if not isinstance(self.remote_configs, dict):
|
||||||
|
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
|
||||||
|
field_value = self.remote_configs.get(field_name)
|
||||||
|
return field_value, field_name, False
|
||||||
@ -0,0 +1,303 @@
|
|||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .python_3x import http_request, makedirs_wrapper
|
||||||
|
from .utils import (
|
||||||
|
CONFIGURATIONS,
|
||||||
|
NAMESPACE_NAME,
|
||||||
|
NOTIFICATION_ID,
|
||||||
|
get_value_from_dict,
|
||||||
|
init_ip,
|
||||||
|
no_key_cache_key,
|
||||||
|
signature,
|
||||||
|
url_encode_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ApolloClient:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_url,
|
||||||
|
app_id,
|
||||||
|
cluster="default",
|
||||||
|
secret="",
|
||||||
|
start_hot_update=True,
|
||||||
|
change_listener=None,
|
||||||
|
_notification_map=None,
|
||||||
|
):
|
||||||
|
# Core routing parameters
|
||||||
|
self.config_url = config_url
|
||||||
|
self.cluster = cluster
|
||||||
|
self.app_id = app_id
|
||||||
|
|
||||||
|
# Non-core parameters
|
||||||
|
self.ip = init_ip()
|
||||||
|
self.secret = secret
|
||||||
|
|
||||||
|
# Check the parameter variables
|
||||||
|
|
||||||
|
# Private control variables
|
||||||
|
self._cycle_time = 5
|
||||||
|
self._stopping = False
|
||||||
|
self._cache = {}
|
||||||
|
self._no_key = {}
|
||||||
|
self._hash = {}
|
||||||
|
self._pull_timeout = 75
|
||||||
|
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
|
||||||
|
self._long_poll_thread = None
|
||||||
|
self._change_listener = change_listener # "add" "delete" "update"
|
||||||
|
if _notification_map is None:
|
||||||
|
_notification_map = {"application": -1}
|
||||||
|
self._notification_map = _notification_map
|
||||||
|
self.last_release_key = None
|
||||||
|
# Private startup method
|
||||||
|
self._path_checker()
|
||||||
|
if start_hot_update:
|
||||||
|
self._start_hot_update()
|
||||||
|
|
||||||
|
# start the heartbeat thread
|
||||||
|
heartbeat = threading.Thread(target=self._heart_beat)
|
||||||
|
heartbeat.daemon = True
|
||||||
|
heartbeat.start()
|
||||||
|
|
||||||
|
def get_json_from_net(self, namespace="application"):
|
||||||
|
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
|
||||||
|
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||||
|
if code == 200:
|
||||||
|
if not body:
|
||||||
|
logger.error(f"get_json_from_net load configs failed, body is {body}")
|
||||||
|
return None
|
||||||
|
data = json.loads(body)
|
||||||
|
data = data["configurations"]
|
||||||
|
return_data = {CONFIGURATIONS: data}
|
||||||
|
return return_data
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("an error occurred in get_json_from_net")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_value(self, key, default_val=None, namespace="application"):
|
||||||
|
try:
|
||||||
|
# read memory configuration
|
||||||
|
namespace_cache = self._cache.get(namespace)
|
||||||
|
val = get_value_from_dict(namespace_cache, key)
|
||||||
|
if val is not None:
|
||||||
|
return val
|
||||||
|
|
||||||
|
no_key = no_key_cache_key(namespace, key)
|
||||||
|
if no_key in self._no_key:
|
||||||
|
return default_val
|
||||||
|
|
||||||
|
# read the network configuration
|
||||||
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
|
val = get_value_from_dict(namespace_data, key)
|
||||||
|
if val is not None:
|
||||||
|
self._update_cache_and_file(namespace_data, namespace)
|
||||||
|
return val
|
||||||
|
|
||||||
|
# read the file configuration
|
||||||
|
namespace_cache = self._get_local_cache(namespace)
|
||||||
|
val = get_value_from_dict(namespace_cache, key)
|
||||||
|
if val is not None:
|
||||||
|
self._update_cache_and_file(namespace_cache, namespace)
|
||||||
|
return val
|
||||||
|
|
||||||
|
# If all of them are not obtained, the default value is returned
|
||||||
|
# and the local cache is set to None
|
||||||
|
self._set_local_cache_none(namespace, key)
|
||||||
|
return default_val
|
||||||
|
except Exception:
|
||||||
|
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
|
||||||
|
return default_val
|
||||||
|
|
||||||
|
# Set the key of a namespace to none, and do not set default val
|
||||||
|
# to ensure the real-time correctness of the function call.
|
||||||
|
# If the user does not have the same default val twice
|
||||||
|
# and the default val is used here, there may be a problem.
|
||||||
|
def _set_local_cache_none(self, namespace, key):
|
||||||
|
no_key = no_key_cache_key(namespace, key)
|
||||||
|
self._no_key[no_key] = key
|
||||||
|
|
||||||
|
def _start_hot_update(self):
|
||||||
|
self._long_poll_thread = threading.Thread(target=self._listener)
|
||||||
|
# When the asynchronous thread is started, the daemon thread will automatically exit
|
||||||
|
# when the main thread is launched.
|
||||||
|
self._long_poll_thread.daemon = True
|
||||||
|
self._long_poll_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._stopping = True
|
||||||
|
logger.info("Stopping listener...")
|
||||||
|
|
||||||
|
# Call the set callback function, and if it is abnormal, try it out
|
||||||
|
def _call_listener(self, namespace, old_kv, new_kv):
|
||||||
|
if self._change_listener is None:
|
||||||
|
return
|
||||||
|
if old_kv is None:
|
||||||
|
old_kv = {}
|
||||||
|
if new_kv is None:
|
||||||
|
new_kv = {}
|
||||||
|
try:
|
||||||
|
for key in old_kv:
|
||||||
|
new_value = new_kv.get(key)
|
||||||
|
old_value = old_kv.get(key)
|
||||||
|
if new_value is None:
|
||||||
|
# If newValue is empty, it means key, and the value is deleted.
|
||||||
|
self._change_listener("delete", namespace, key, old_value)
|
||||||
|
continue
|
||||||
|
if new_value != old_value:
|
||||||
|
self._change_listener("update", namespace, key, new_value)
|
||||||
|
continue
|
||||||
|
for key in new_kv:
|
||||||
|
new_value = new_kv.get(key)
|
||||||
|
old_value = old_kv.get(key)
|
||||||
|
if old_value is None:
|
||||||
|
self._change_listener("add", namespace, key, new_value)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.warning(str(e))
|
||||||
|
|
||||||
|
def _path_checker(self):
|
||||||
|
if not os.path.isdir(self._cache_file_path):
|
||||||
|
makedirs_wrapper(self._cache_file_path)
|
||||||
|
|
||||||
|
# update the local cache and file cache
|
||||||
|
def _update_cache_and_file(self, namespace_data, namespace="application"):
|
||||||
|
# update the local cache
|
||||||
|
self._cache[namespace] = namespace_data
|
||||||
|
# update the file cache
|
||||||
|
new_string = json.dumps(namespace_data)
|
||||||
|
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
|
||||||
|
if self._hash.get(namespace) == new_hash:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
|
||||||
|
file_path.write_text(new_string)
|
||||||
|
self._hash[namespace] = new_hash
|
||||||
|
|
||||||
|
# get the configuration from the local file
|
||||||
|
def _get_local_cache(self, namespace="application"):
|
||||||
|
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
|
||||||
|
if os.path.isfile(cache_file_path):
|
||||||
|
with open(cache_file_path) as f:
|
||||||
|
result = json.loads(f.readline())
|
||||||
|
return result
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _long_poll(self):
|
||||||
|
notifications = []
|
||||||
|
for key in self._cache:
|
||||||
|
namespace_data = self._cache[key]
|
||||||
|
notification_id = -1
|
||||||
|
if NOTIFICATION_ID in namespace_data:
|
||||||
|
notification_id = self._cache[key][NOTIFICATION_ID]
|
||||||
|
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
|
||||||
|
try:
|
||||||
|
# if the length is 0 it is returned directly
|
||||||
|
if len(notifications) == 0:
|
||||||
|
return
|
||||||
|
url = "{}/notifications/v2".format(self.config_url)
|
||||||
|
params = {
|
||||||
|
"appId": self.app_id,
|
||||||
|
"cluster": self.cluster,
|
||||||
|
"notifications": json.dumps(notifications, ensure_ascii=False),
|
||||||
|
}
|
||||||
|
param_str = url_encode_wrapper(params)
|
||||||
|
url = url + "?" + param_str
|
||||||
|
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
|
||||||
|
http_code = code
|
||||||
|
if http_code == 304:
|
||||||
|
logger.debug("No change, loop...")
|
||||||
|
return
|
||||||
|
if http_code == 200:
|
||||||
|
if not body:
|
||||||
|
logger.error(f"_long_poll load configs failed,body is {body}")
|
||||||
|
return
|
||||||
|
data = json.loads(body)
|
||||||
|
for entry in data:
|
||||||
|
namespace = entry[NAMESPACE_NAME]
|
||||||
|
n_id = entry[NOTIFICATION_ID]
|
||||||
|
logger.info("%s has changes: notificationId=%d", namespace, n_id)
|
||||||
|
self._get_net_and_set_local(namespace, n_id, call_change=True)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.warning("Sleep...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(str(e))
|
||||||
|
|
||||||
|
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
|
||||||
|
namespace_data = self.get_json_from_net(namespace)
|
||||||
|
if not namespace_data:
|
||||||
|
return
|
||||||
|
namespace_data[NOTIFICATION_ID] = n_id
|
||||||
|
old_namespace = self._cache.get(namespace)
|
||||||
|
self._update_cache_and_file(namespace_data, namespace)
|
||||||
|
if self._change_listener is not None and call_change and old_namespace:
|
||||||
|
old_kv = old_namespace.get(CONFIGURATIONS)
|
||||||
|
new_kv = namespace_data.get(CONFIGURATIONS)
|
||||||
|
self._call_listener(namespace, old_kv, new_kv)
|
||||||
|
|
||||||
|
def _listener(self):
|
||||||
|
logger.info("start long_poll")
|
||||||
|
while not self._stopping:
|
||||||
|
self._long_poll()
|
||||||
|
time.sleep(self._cycle_time)
|
||||||
|
logger.info("stopped, long_poll")
|
||||||
|
|
||||||
|
# add the need for endorsement to the header
|
||||||
|
def _sign_headers(self, url):
|
||||||
|
headers = {}
|
||||||
|
if self.secret == "":
|
||||||
|
return headers
|
||||||
|
uri = url[len(self.config_url) : len(url)]
|
||||||
|
time_unix_now = str(int(round(time.time() * 1000)))
|
||||||
|
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
|
||||||
|
headers["Timestamp"] = time_unix_now
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _heart_beat(self):
|
||||||
|
while not self._stopping:
|
||||||
|
for namespace in self._notification_map:
|
||||||
|
self._do_heart_beat(namespace)
|
||||||
|
time.sleep(60 * 10) # 10分钟
|
||||||
|
|
||||||
|
def _do_heart_beat(self, namespace):
|
||||||
|
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
|
||||||
|
try:
|
||||||
|
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
|
||||||
|
if code == 200:
|
||||||
|
if not body:
|
||||||
|
logger.error(f"_do_heart_beat load configs failed,body is {body}")
|
||||||
|
return None
|
||||||
|
data = json.loads(body)
|
||||||
|
if self.last_release_key == data["releaseKey"]:
|
||||||
|
return None
|
||||||
|
self.last_release_key = data["releaseKey"]
|
||||||
|
data = data["configurations"]
|
||||||
|
self._update_cache_and_file(data, namespace)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
logger.exception("an error occurred in _do_heart_beat")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_all_dicts(self, namespace):
|
||||||
|
namespace_data = self._cache.get(namespace)
|
||||||
|
if namespace_data is None:
|
||||||
|
net_namespace_data = self.get_json_from_net(namespace)
|
||||||
|
if not net_namespace_data:
|
||||||
|
return namespace_data
|
||||||
|
namespace_data = net_namespace_data.get(CONFIGURATIONS)
|
||||||
|
if namespace_data:
|
||||||
|
self._update_cache_and_file(namespace_data, namespace)
|
||||||
|
return namespace_data
|
||||||
@ -0,0 +1,41 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
import urllib.request
|
||||||
|
from urllib import parse
|
||||||
|
from urllib.error import HTTPError
|
||||||
|
|
||||||
|
# Create an SSL context that allows for a lower level of security
|
||||||
|
ssl_context = ssl.create_default_context()
|
||||||
|
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
|
||||||
|
ssl_context.check_hostname = False
|
||||||
|
ssl_context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
# Create an opener object and pass in a custom SSL context
|
||||||
|
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
|
||||||
|
|
||||||
|
urllib.request.install_opener(opener)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def http_request(url, timeout, headers={}):
|
||||||
|
try:
|
||||||
|
request = urllib.request.Request(url, headers=headers)
|
||||||
|
res = urllib.request.urlopen(request, timeout=timeout)
|
||||||
|
body = res.read().decode("utf-8")
|
||||||
|
return res.code, body
|
||||||
|
except HTTPError as e:
|
||||||
|
if e.code == 304:
|
||||||
|
logger.warning("http_request error,code is 304, maybe you should check secret")
|
||||||
|
return 304, None
|
||||||
|
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def url_encode(params):
|
||||||
|
return parse.urlencode(params)
|
||||||
|
|
||||||
|
|
||||||
|
def makedirs_wrapper(path):
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
import hashlib
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from .python_3x import url_encode
|
||||||
|
|
||||||
|
# define constants
|
||||||
|
CONFIGURATIONS = "configurations"
|
||||||
|
NOTIFICATION_ID = "notificationId"
|
||||||
|
NAMESPACE_NAME = "namespaceName"
|
||||||
|
|
||||||
|
|
||||||
|
# add timestamps uris and keys
|
||||||
|
def signature(timestamp, uri, secret):
|
||||||
|
import base64
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
string_to_sign = "" + timestamp + "\n" + uri
|
||||||
|
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
|
||||||
|
return base64.b64encode(hmac_code).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def url_encode_wrapper(params):
|
||||||
|
return url_encode(params)
|
||||||
|
|
||||||
|
|
||||||
|
def no_key_cache_key(namespace, key):
|
||||||
|
return "{}{}{}".format(namespace, len(namespace), key)
|
||||||
|
|
||||||
|
|
||||||
|
# Returns whether the obtained value is obtained, and None if it does not
|
||||||
|
def get_value_from_dict(namespace_cache, key):
|
||||||
|
if namespace_cache:
|
||||||
|
kv_data = namespace_cache.get(CONFIGURATIONS)
|
||||||
|
if kv_data is None:
|
||||||
|
return None
|
||||||
|
if key in kv_data:
|
||||||
|
return kv_data[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def init_ip():
|
||||||
|
ip = ""
|
||||||
|
s = None
|
||||||
|
try:
|
||||||
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
s.connect(("8.8.8.8", 53))
|
||||||
|
ip = s.getsockname()[0]
|
||||||
|
finally:
|
||||||
|
if s:
|
||||||
|
s.close()
|
||||||
|
return ip
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic.fields import FieldInfo
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteSettingsSource:
|
||||||
|
def __init__(self, configs: Mapping[str, Any]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||||
|
return value
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteSettingsSourceName(StrEnum):
|
||||||
|
APOLLO = "apollo"
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
model: gemini-2.0-flash-exp
|
||||||
|
label:
|
||||||
|
en_US: Gemini 2.0 Flash Exp
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- document
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 1048576
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
|
- name: max_output_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 8192
|
||||||
|
min: 1
|
||||||
|
max: 8192
|
||||||
|
- name: json_schema
|
||||||
|
use_template: json_schema
|
||||||
|
pricing:
|
||||||
|
input: '0.00'
|
||||||
|
output: '0.00'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
model: gemini-2.0-flash-exp
|
||||||
|
label:
|
||||||
|
en_US: Gemini 2.0 Flash Exp
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- vision
|
||||||
|
- tool-call
|
||||||
|
- stream-tool-call
|
||||||
|
- document
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 1048576
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
required: false
|
||||||
|
- name: max_output_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
default: 8192
|
||||||
|
min: 1
|
||||||
|
max: 8192
|
||||||
|
- name: json_schema
|
||||||
|
use_template: json_schema
|
||||||
|
pricing:
|
||||||
|
input: '0.00'
|
||||||
|
output: '0.00'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
||||||
@ -0,0 +1,10 @@
|
|||||||
|
class BaseNodeError(Exception):
|
||||||
|
"""Base class for node errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultValueTypeError(BaseNodeError):
|
||||||
|
"""Raised when the default value type is invalid."""
|
||||||
|
|
||||||
|
pass
|
||||||
@ -1,62 +0,0 @@
|
|||||||
import os
|
|
||||||
import shutil
|
|
||||||
from collections.abc import Generator
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from flask import current_app
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from extensions.storage.base_storage import BaseStorage
|
|
||||||
|
|
||||||
|
|
||||||
class LocalFsStorage(BaseStorage):
|
|
||||||
"""Implementation for local filesystem storage."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
folder = dify_config.STORAGE_LOCAL_PATH
|
|
||||||
if not os.path.isabs(folder):
|
|
||||||
folder = os.path.join(current_app.root_path, folder)
|
|
||||||
self.folder = folder
|
|
||||||
|
|
||||||
def _build_filepath(self, filename: str) -> str:
|
|
||||||
"""Build the full file path based on the folder and filename."""
|
|
||||||
if not self.folder or self.folder.endswith("/"):
|
|
||||||
return self.folder + filename
|
|
||||||
else:
|
|
||||||
return self.folder + "/" + filename
|
|
||||||
|
|
||||||
def save(self, filename, data):
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
folder = os.path.dirname(filepath)
|
|
||||||
os.makedirs(folder, exist_ok=True)
|
|
||||||
Path(os.path.join(os.getcwd(), filepath)).write_bytes(data)
|
|
||||||
|
|
||||||
def load_once(self, filename: str) -> bytes:
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
if not os.path.exists(filepath):
|
|
||||||
raise FileNotFoundError("File not found")
|
|
||||||
return Path(filepath).read_bytes()
|
|
||||||
|
|
||||||
def load_stream(self, filename: str) -> Generator:
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
if not os.path.exists(filepath):
|
|
||||||
raise FileNotFoundError("File not found")
|
|
||||||
with open(filepath, "rb") as f:
|
|
||||||
while chunk := f.read(4096): # Read in chunks of 4KB
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
def download(self, filename, target_filepath):
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
if not os.path.exists(filepath):
|
|
||||||
raise FileNotFoundError("File not found")
|
|
||||||
shutil.copyfile(filepath, target_filepath)
|
|
||||||
|
|
||||||
def exists(self, filename):
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
return os.path.exists(filepath)
|
|
||||||
|
|
||||||
def delete(self, filename):
|
|
||||||
filepath = self._build_filepath(filename)
|
|
||||||
if os.path.exists(filepath):
|
|
||||||
os.remove(filepath)
|
|
||||||
@ -0,0 +1,72 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import opendal
|
||||||
|
|
||||||
|
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
|
||||||
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
||||||
|
S3_R2_HOSTNAME = "r2.cloudflarestorage.com"
|
||||||
|
S3_R2_COMPATIBLE_KWARGS = {
|
||||||
|
"delete_max_size": "700",
|
||||||
|
"disable_stat_with_override": "true",
|
||||||
|
"region": "auto",
|
||||||
|
}
|
||||||
|
S3_SSE_WITH_AWS_MANAGED_IAM_KWARGS = {
|
||||||
|
"server_side_encryption": "aws:kms",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_r2_endpoint(endpoint: str) -> bool:
|
||||||
|
if not endpoint:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parsed_url = urlparse(endpoint)
|
||||||
|
return bool(parsed_url.hostname and parsed_url.hostname.endswith(S3_R2_HOSTNAME))
|
||||||
|
|
||||||
|
|
||||||
|
class OpenDALStorage(BaseStorage):
|
||||||
|
def __init__(self, scheme: OpenDALScheme, **kwargs):
|
||||||
|
if scheme == OpenDALScheme.FS:
|
||||||
|
Path(kwargs["root"]).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.op = opendal.Operator(scheme=scheme, **kwargs)
|
||||||
|
|
||||||
|
def save(self, filename: str, data: bytes) -> None:
|
||||||
|
self.op.write(path=filename, bs=data)
|
||||||
|
|
||||||
|
def load_once(self, filename: str) -> bytes:
|
||||||
|
if not self.exists(filename):
|
||||||
|
raise FileNotFoundError("File not found")
|
||||||
|
|
||||||
|
return self.op.read(path=filename)
|
||||||
|
|
||||||
|
def load_stream(self, filename: str) -> Generator:
|
||||||
|
if not self.exists(filename):
|
||||||
|
raise FileNotFoundError("File not found")
|
||||||
|
|
||||||
|
batch_size = 4096
|
||||||
|
file = self.op.open(path=filename, mode="rb")
|
||||||
|
while chunk := file.read(batch_size):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def download(self, filename: str, target_filepath: str):
|
||||||
|
if not self.exists(filename):
|
||||||
|
raise FileNotFoundError("File not found")
|
||||||
|
|
||||||
|
with Path(target_filepath).open("wb") as f:
|
||||||
|
f.write(self.op.read(path=filename))
|
||||||
|
|
||||||
|
def exists(self, filename: str) -> bool:
|
||||||
|
# FIXME this is a workaround for opendal python-binding do not have a exists method and no better
|
||||||
|
# error handler here when opendal python-binding has a exists method, we should use it
|
||||||
|
# more https://github.com/apache/opendal/blob/main/bindings/python/src/operator.rs
|
||||||
|
try:
|
||||||
|
return self.op.stat(path=filename).mode.is_file()
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete(self, filename: str):
|
||||||
|
if self.exists(filename):
|
||||||
|
self.op.delete(path=filename)
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""add exceptions_count field to WorkflowRun model
|
||||||
|
|
||||||
|
Revision ID: cf8f4fc45278
|
||||||
|
Revises: 01d6889832f7
|
||||||
|
Create Date: 2024-11-28 05:53:21.576178
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'cf8f4fc45278'
|
||||||
|
down_revision = '01d6889832f7'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('exceptions_count')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,20 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from extensions.storage.opendal_storage import is_r2_endpoint
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("endpoint", "expected"),
|
||||||
|
[
|
||||||
|
("https://bucket.r2.cloudflarestorage.com", True),
|
||||||
|
("https://custom-domain.r2.cloudflarestorage.com/", True),
|
||||||
|
("https://bucket.r2.cloudflarestorage.com/path", True),
|
||||||
|
("https://s3.amazonaws.com", False),
|
||||||
|
("https://storage.googleapis.com", False),
|
||||||
|
("http://localhost:9000", False),
|
||||||
|
("invalid-url", False),
|
||||||
|
("", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_is_r2_endpoint(endpoint: str, expected: bool):
|
||||||
|
assert is_r2_endpoint(endpoint) == expected
|
||||||
@ -0,0 +1,502 @@
|
|||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.graph_engine.entities.event import (
|
||||||
|
GraphRunPartialSucceededEvent,
|
||||||
|
GraphRunSucceededEvent,
|
||||||
|
NodeRunExceptionEvent,
|
||||||
|
NodeRunStreamChunkEvent,
|
||||||
|
)
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
|
from models.enums import UserFrom
|
||||||
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
|
|
||||||
|
class ContinueOnErrorTestHelper:
|
||||||
|
@staticmethod
|
||||||
|
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||||
|
"""Helper method to create a code node configuration"""
|
||||||
|
node = {
|
||||||
|
"id": "node",
|
||||||
|
"data": {
|
||||||
|
"outputs": {"result": {"type": "number"}},
|
||||||
|
"error_strategy": error_strategy,
|
||||||
|
"title": "code",
|
||||||
|
"variables": [],
|
||||||
|
"code_language": "python3",
|
||||||
|
"code": "\n".join([line[4:] for line in code.split("\n")]),
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if default_value:
|
||||||
|
node["data"]["default_value"] = default_value
|
||||||
|
return node
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_http_node(
|
||||||
|
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
|
||||||
|
):
|
||||||
|
"""Helper method to create a http node configuration"""
|
||||||
|
authorization = (
|
||||||
|
{
|
||||||
|
"type": "api-key",
|
||||||
|
"config": {
|
||||||
|
"type": "basic",
|
||||||
|
"api_key": "ak-xxx",
|
||||||
|
"header": "api-key",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if authorization_success
|
||||||
|
else {
|
||||||
|
"type": "api-key",
|
||||||
|
# missing config field
|
||||||
|
}
|
||||||
|
)
|
||||||
|
node = {
|
||||||
|
"id": "node",
|
||||||
|
"data": {
|
||||||
|
"title": "http",
|
||||||
|
"desc": "",
|
||||||
|
"method": "get",
|
||||||
|
"url": "http://example.com",
|
||||||
|
"authorization": authorization,
|
||||||
|
"headers": "X-Header:123",
|
||||||
|
"params": "A:b",
|
||||||
|
"body": None,
|
||||||
|
"type": "http-request",
|
||||||
|
"error_strategy": error_strategy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if default_value:
|
||||||
|
node["data"]["default_value"] = default_value
|
||||||
|
return node
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||||
|
"""Helper method to create a http node configuration"""
|
||||||
|
node = {
|
||||||
|
"id": "node",
|
||||||
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
|
"title": "HTTP Request",
|
||||||
|
"desc": "",
|
||||||
|
"variables": [],
|
||||||
|
"method": "get",
|
||||||
|
"url": "https://api.github.com/issues",
|
||||||
|
"authorization": {"type": "no-auth", "config": None},
|
||||||
|
"headers": "",
|
||||||
|
"params": "",
|
||||||
|
"body": {"type": "none", "data": []},
|
||||||
|
"timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0},
|
||||||
|
"error_strategy": error_strategy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if default_value:
|
||||||
|
node["data"]["default_value"] = default_value
|
||||||
|
return node
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||||
|
"""Helper method to create a tool node configuration"""
|
||||||
|
node = {
|
||||||
|
"id": "node",
|
||||||
|
"data": {
|
||||||
|
"title": "a",
|
||||||
|
"desc": "a",
|
||||||
|
"provider_id": "maths",
|
||||||
|
"provider_type": "builtin",
|
||||||
|
"provider_name": "maths",
|
||||||
|
"tool_name": "eval_expression",
|
||||||
|
"tool_label": "eval_expression",
|
||||||
|
"tool_configurations": {},
|
||||||
|
"tool_parameters": {
|
||||||
|
"expression": {
|
||||||
|
"type": "variable",
|
||||||
|
"value": ["1", "123", "args1"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "tool",
|
||||||
|
"error_strategy": error_strategy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if default_value:
|
||||||
|
node["data"]["default_value"] = default_value
|
||||||
|
return node
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None):
|
||||||
|
"""Helper method to create a llm node configuration"""
|
||||||
|
node = {
|
||||||
|
"id": "node",
|
||||||
|
"data": {
|
||||||
|
"title": "123",
|
||||||
|
"type": "llm",
|
||||||
|
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}},
|
||||||
|
"prompt_template": [
|
||||||
|
{"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."},
|
||||||
|
{"role": "user", "text": "{{#sys.query#}}"},
|
||||||
|
],
|
||||||
|
"memory": None,
|
||||||
|
"context": {"enabled": False},
|
||||||
|
"vision": {"enabled": False},
|
||||||
|
"error_strategy": error_strategy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if default_value:
|
||||||
|
node["data"]["default_value"] = default_value
|
||||||
|
return node
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None):
|
||||||
|
"""Helper method to create a graph engine instance for testing"""
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
variable_pool = {
|
||||||
|
"system_variables": {
|
||||||
|
SystemVariableKey.QUERY: "clear",
|
||||||
|
SystemVariableKey.FILES: [],
|
||||||
|
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||||
|
SystemVariableKey.USER_ID: "aaa",
|
||||||
|
},
|
||||||
|
"user_inputs": user_inputs or {"uid": "takato"},
|
||||||
|
}
|
||||||
|
|
||||||
|
return GraphEngine(
|
||||||
|
tenant_id="111",
|
||||||
|
app_id="222",
|
||||||
|
workflow_type=WorkflowType.CHAT,
|
||||||
|
workflow_id="333",
|
||||||
|
graph_config=graph_config,
|
||||||
|
user_id="444",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
|
call_depth=0,
|
||||||
|
graph=graph,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
max_execution_steps=500,
|
||||||
|
max_execution_time=1200,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_VALUE_EDGE = [
|
||||||
|
{
|
||||||
|
"id": "start-source-node-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "node",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "node-source-answer-target",
|
||||||
|
"source": "node",
|
||||||
|
"target": "answer",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
FAIL_BRANCH_EDGES = [
|
||||||
|
{
|
||||||
|
"id": "start-source-node-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "node",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "node-true-success-target",
|
||||||
|
"source": "node",
|
||||||
|
"target": "success",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "node-false-error-target",
|
||||||
|
"source": "node",
|
||||||
|
"target": "error",
|
||||||
|
"sourceHandle": "fail-branch",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_default_value_continue_on_error():
|
||||||
|
error_code = """
|
||||||
|
def main() -> dict:
|
||||||
|
return {
|
||||||
|
"result": 1 / 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_code_node(
|
||||||
|
error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_fail_branch_continue_on_error():
|
||||||
|
error_code = """
|
||||||
|
def main() -> dict:
|
||||||
|
return {
|
||||||
|
"result": 1 / 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "node node run successfully"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "node node run failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_code_node(error_code),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_node_default_value_continue_on_error():
|
||||||
|
"""Test HTTP node with default value error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_http_node(
|
||||||
|
"default-value", [{"key": "response", "type": "string", "value": "http node got error response"}]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"}
|
||||||
|
for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_http_node_fail_branch_continue_on_error():
|
||||||
|
"""Test HTTP node with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "HTTP request failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_http_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_node_default_value_continue_on_error():
|
||||||
|
"""Test tool node with default value error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_tool_node(
|
||||||
|
"default-value", [{"key": "result", "type": "string", "value": "default tool result"}]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_node_fail_branch_continue_on_error():
|
||||||
|
"""Test HTTP node with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "tool execute successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "tool execute failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_tool_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_node_default_value_continue_on_error():
|
||||||
|
"""Test LLM node with default value error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": DEFAULT_VALUE_EDGE,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"},
|
||||||
|
ContinueOnErrorTestHelper.get_llm_node(
|
||||||
|
"default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_node_fail_branch_continue_on_error():
|
||||||
|
"""Test LLM node with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "LLM request failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_llm_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_status_code_error_http_node_fail_branch_continue_on_error():
|
||||||
|
"""Test HTTP node with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(
|
||||||
|
isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events
|
||||||
|
)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_variable_pool_error_type_variable():
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES,
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "http execute successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {"title": "error", "type": "answer", "answer": "http execute failed"},
|
||||||
|
"id": "error",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_error_status_code_http_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
list(graph_engine.run())
|
||||||
|
error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"])
|
||||||
|
error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"])
|
||||||
|
assert error_message != None
|
||||||
|
assert error_type.value == "HTTPResponseCodeError"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_node_in_fail_branch_continue_on_error():
|
||||||
|
"""Test HTTP node with fail-branch error strategy"""
|
||||||
|
graph_config = {
|
||||||
|
"edges": FAIL_BRANCH_EDGES[:-1],
|
||||||
|
"nodes": [
|
||||||
|
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||||
|
{
|
||||||
|
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
|
||||||
|
"id": "success",
|
||||||
|
},
|
||||||
|
ContinueOnErrorTestHelper.get_http_node(),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
|
||||||
|
events = list(graph_engine.run())
|
||||||
|
|
||||||
|
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
|
||||||
|
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
|
||||||
|
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
|
||||||
@ -1,18 +0,0 @@
|
|||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from extensions.storage.local_fs_storage import LocalFsStorage
|
|
||||||
from tests.unit_tests.oss.__mock.base import (
|
|
||||||
BaseStorageTest,
|
|
||||||
get_example_folder,
|
|
||||||
)
|
|
||||||
from tests.unit_tests.oss.__mock.local import setup_local_fs_mock
|
|
||||||
|
|
||||||
|
|
||||||
class TestLocalFS(BaseStorageTest):
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def setup_method(self, setup_local_fs_mock):
|
|
||||||
"""Executed before each test method."""
|
|
||||||
self.storage = LocalFsStorage()
|
|
||||||
self.storage.folder = get_example_folder()
|
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
from collections.abc import Generator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from configs.middleware.storage.opendal_storage_config import OpenDALScheme
|
||||||
|
from extensions.storage.opendal_storage import OpenDALStorage
|
||||||
|
from tests.unit_tests.oss.__mock.base import (
|
||||||
|
get_example_data,
|
||||||
|
get_example_filename,
|
||||||
|
get_example_filepath,
|
||||||
|
get_opendal_bucket,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenDAL:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_method(self, *args, **kwargs):
|
||||||
|
"""Executed before each test method."""
|
||||||
|
self.storage = OpenDALStorage(
|
||||||
|
scheme=OpenDALScheme.FS,
|
||||||
|
root=get_opendal_bucket(),
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class", autouse=True)
|
||||||
|
def teardown_class(self, request):
|
||||||
|
"""Clean up after all tests in the class."""
|
||||||
|
|
||||||
|
def cleanup():
|
||||||
|
folder = Path(get_opendal_bucket())
|
||||||
|
if folder.exists() and folder.is_dir():
|
||||||
|
for item in folder.iterdir():
|
||||||
|
if item.is_file():
|
||||||
|
item.unlink()
|
||||||
|
elif item.is_dir():
|
||||||
|
item.rmdir()
|
||||||
|
folder.rmdir()
|
||||||
|
|
||||||
|
return cleanup()
|
||||||
|
|
||||||
|
def test_save_and_exists(self):
|
||||||
|
"""Test saving data and checking existence."""
|
||||||
|
filename = get_example_filename()
|
||||||
|
data = get_example_data()
|
||||||
|
|
||||||
|
assert not self.storage.exists(filename)
|
||||||
|
self.storage.save(filename, data)
|
||||||
|
assert self.storage.exists(filename)
|
||||||
|
|
||||||
|
def test_load_once(self):
|
||||||
|
"""Test loading data once."""
|
||||||
|
filename = get_example_filename()
|
||||||
|
data = get_example_data()
|
||||||
|
|
||||||
|
self.storage.save(filename, data)
|
||||||
|
loaded_data = self.storage.load_once(filename)
|
||||||
|
assert loaded_data == data
|
||||||
|
|
||||||
|
def test_load_stream(self):
|
||||||
|
"""Test loading data as a stream."""
|
||||||
|
filename = get_example_filename()
|
||||||
|
data = get_example_data()
|
||||||
|
|
||||||
|
self.storage.save(filename, data)
|
||||||
|
generator = self.storage.load_stream(filename)
|
||||||
|
assert isinstance(generator, Generator)
|
||||||
|
assert next(generator) == data
|
||||||
|
|
||||||
|
def test_download(self):
|
||||||
|
"""Test downloading data to a file."""
|
||||||
|
filename = get_example_filename()
|
||||||
|
filepath = str(Path(get_opendal_bucket()) / filename)
|
||||||
|
data = get_example_data()
|
||||||
|
|
||||||
|
self.storage.save(filename, data)
|
||||||
|
self.storage.download(filename, filepath)
|
||||||
|
|
||||||
|
def test_delete(self):
|
||||||
|
"""Test deleting a file."""
|
||||||
|
filename = get_example_filename()
|
||||||
|
data = get_example_data()
|
||||||
|
|
||||||
|
self.storage.save(filename, data)
|
||||||
|
assert self.storage.exists(filename)
|
||||||
|
|
||||||
|
self.storage.delete(filename)
|
||||||
|
assert not self.storage.exists(filename)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue