Merge branch 'langgenius:main' into add-document-status-update

pull/18235/head
GuanMu 1 year ago committed by GitHub
commit 1625148b07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -4,14 +4,10 @@ import platform
import re
import urllib.parse
import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4
import httpx
from constants import DEFAULT_FILE_NUMBER_LIMITS
try:
import magic
except ImportError:
@ -31,8 +27,6 @@ except ImportError:
from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel):
filename: str
@ -89,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)),
)
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -1,5 +1,4 @@
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api
from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp
from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class ExploreAppMetaApi(InstalledAppResource):

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppMetaApi(Resource):

@ -1,6 +1,6 @@
import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()

@ -122,6 +122,8 @@ class SegmentApi(DatasetApiResource):
tenant_id=current_user.current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
page=page,
limit=limit,
)
response = {

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore
from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode
from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict(
features_dict=features_dict, user_input_form=user_input_form
)
return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
class AppMeta(WebApiResource):

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"],
tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
),
strict_type_validation=strict_type_validation,
)
for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files,
tenant_id=app_model.tenant_id,
config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
)
# convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config,
file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=list(system_files),
user_id=user.id,

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod
def invoke_app(
cls,

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str
mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids:
continue
# exclude current node id
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies
if event.node_id in answer_dependencies[answer_node_id]:
edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids

@ -52,6 +52,7 @@ def build_from_mapping(
mapping: Mapping[str, Any],
tenant_id: str,
config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
@ -69,6 +70,7 @@ def build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
)
if config and not _is_file_valid_with_config(
@ -87,12 +89,14 @@ def build_from_mappings(
mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None = None,
tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]:
files = [
build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
strict_type_validation=strict_type_validation,
)
for mapping in mappings
]
@ -116,6 +120,7 @@ def _build_from_local_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if not upload_file_id:
@ -134,10 +139,16 @@ def _build_from_local_file(
if row is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
specified_type = mapping.get("type", "custom")
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),
filename=row.name,
@ -158,6 +169,7 @@ def _build_from_remote_url(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if upload_file_id:
@ -174,10 +186,21 @@ def _build_from_remote_url(
if upload_file is None:
raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type)
if file_type.value != mapping.get("type", "custom"):
detected_file_type = _standardize_file_type(
extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type)
if specified_type and specified_type != FileType.CUSTOM.value
else detected_file_type
)
return File(
id=mapping.get("id"),
filename=upload_file.name,
@ -237,6 +260,7 @@ def _build_from_tool_file(
mapping: Mapping[str, Any],
tenant_id: str,
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
tool_file = (
db.session.query(ToolFile)
@ -252,7 +276,16 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File(
id=mapping.get("id"),

@ -2120,7 +2120,7 @@ class SegmentService:
dataset_id=dataset.id,
document_id=document.id,
segment_id=segment.id,
position=max_position + 1,
position=max_position + 1 if max_position else 1,
index_node_id=index_node_id,
index_node_hash=index_node_hash,
content=content,
@ -2270,7 +2270,13 @@ class SegmentService:
@classmethod
def get_segments(
cls, document_id: str, tenant_id: str, status_list: list[str] | None = None, keyword: str | None = None
cls,
document_id: str,
tenant_id: str,
status_list: list[str] | None = None,
keyword: str | None = None,
page: int = 1,
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter(
@ -2283,10 +2289,11 @@ class SegmentService:
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
segments = query.order_by(DocumentSegment.position.asc()).all()
total = len(segments)
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
page=page, per_page=limit, max_per_page=100, error_out=False
)
return segments, total
return paginated_segments.items, paginated_segments.total
@classmethod
def update_segment_by_id(

@ -1,14 +1,20 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunStreamChunkEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class ContinueOnErrorTestHelper:
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
"edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
{"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
ContinueOnErrorTestHelper.get_http_node(),
],
}
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
def test_stream_output_with_fail_branch_continue_on_error():
"""Test stream output with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
with patch.object(LLMNode, "_run", new=llm_generator):
events = list(graph_engine.run())
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)

@ -0,0 +1,198 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from httpx import Response
from factories.file_factory import (
File,
FileTransferMethod,
FileType,
FileUploadConfig,
build_from_mapping,
)
from models import ToolFile, UploadFile
# Test Data
TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"
# Test Config
TEST_CONFIG = FileUploadConfig(
allowed_file_types=["image", "document"],
allowed_file_extensions=[".jpg", ".pdf"],
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
number_limits=10,
)
# Fixtures
@pytest.fixture
def mock_upload_file():
mock = MagicMock(spec=UploadFile)
mock.id = TEST_UPLOAD_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "test.jpg"
mock.extension = "jpg"
mock.mime_type = "image/jpeg"
mock.source_url = TEST_REMOTE_URL
mock.size = 1024
mock.key = "test_key"
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
yield m
@pytest.fixture
def mock_tool_file():
mock = MagicMock(spec=ToolFile)
mock.id = TEST_TOOL_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "tool_file.pdf"
mock.file_key = "tool_file.pdf"
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = mock
yield mock
@pytest.fixture
def mock_http_head():
def _mock_response(filename, size, content_type):
return Response(
status_code=200,
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(size),
"Content-Type": content_type,
},
)
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
yield mock_head
# Helper functions
def local_file_mapping(file_type="image"):
return {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
"type": file_type,
}
def tool_file_mapping(file_type="document"):
return {
"transfer_method": "tool_file",
"tool_file_id": TEST_TOOL_FILE_ID,
"type": file_type,
}
# Tests
def test_build_from_mapping_backward_compatibility(mock_upload_file):
mapping = local_file_mapping(file_type="image")
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert isinstance(file, File)
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
assert file.type == FileType.IMAGE
assert file.related_id == TEST_UPLOAD_FILE_ID
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("document", False, "Detected file type does not match"),
],
)
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("document", True, None),
("image", False, "Detected file type does not match"),
],
)
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
"""Strict type validation for tool_file."""
mapping = tool_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
def test_build_from_remote_url(mock_http_head):
mapping = {
"transfer_method": "remote_url",
"url": TEST_REMOTE_URL,
"type": "image",
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert file.transfer_method == FileTransferMethod.REMOTE_URL
assert file.type == FileType.IMAGE
assert file.filename == "remote_test.jpg"
assert file.size == 2048
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = None
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_local_file_not_found():
"""Test UploadFile not found in database."""
with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = local_file_mapping()
with pytest.raises(ValueError, match="Invalid upload file"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_build_without_type_specification(mock_upload_file):
"""Test the situation where no file type is specified"""
mapping = {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
# leave out the type
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
# It should automatically infer the type as "image" based on the file extension
assert file.type == FileType.IMAGE
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("video", False, "File validation failed"),
],
)
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
"""Test the validation of files and configurations"""
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
assert file is not None
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)

@ -16,7 +16,7 @@ import type {
Feedback,
} from '../types'
import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree } from '../utils'
import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import {
@ -106,6 +106,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}, [isInstalledApp, installedAppInfo, appInfo])
const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => {
if (appData?.site.default_language)
changeLanguage(appData.site.default_language)
@ -124,18 +131,24 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
setSidebarCollapseState(localState === 'collapsed')
}
}, [appId])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, {
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {},
})
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo])
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({
...conversationIdInfo,
[appId || '']: changeConversationId,
[appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
})
}
}, [appId, conversationIdInfo, setConversationIdInfo])
}, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => {

@ -2,8 +2,6 @@ import type { FC } from 'react'
import { memo } from 'react'
import type { ChatItem } from '../../types'
import { useChatContext } from '../context'
import Button from '@/app/components/base/button'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
type SuggestedQuestionsProps = {
item: ChatItem
@ -12,9 +10,6 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
item,
}) => {
const { onSend } = useChatContext()
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const klassName = `mr-1 mt-1 ${isMobile ? 'block overflow-hidden text-ellipsis' : ''} max-w-full shrink-0 last:mr-0`
const {
isOpeningStatement,
@ -27,14 +22,13 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
return (
<div className='flex flex-wrap'>
{suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => (
<Button
<div
key={index}
variant='secondary-accent'
className={klassName}
className='system-sm-medium mr-1 mt-1 inline-flex max-w-full shrink-0 cursor-pointer flex-wrap rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3.5 py-2 text-components-button-secondary-accent-text shadow-xs last:mr-0 hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover'
onClick={() => onSend?.(question)}
>
{question}
</Button>),
</div>),
)}
</div>
)

@ -184,7 +184,7 @@ const ChatWrapper = () => {
return null
if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) {
return (
<div className='flex h-[50vh] items-center justify-center px-4 py-12'>
<div className={cn('flex items-center justify-center px-4 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<div className='flex max-w-[720px] grow gap-4'>
<AppIcon
size='xl'
@ -202,7 +202,7 @@ const ChatWrapper = () => {
)
}
return (
<div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12')}>
<div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<AppIcon
size='xl'
iconType={appData?.site.icon_type}

@ -15,7 +15,7 @@ import type {
Feedback,
} from '../types'
import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedInputsFromUrlParams } from '../utils'
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
import {
fetchAppInfo,
@ -72,23 +72,36 @@ export const useEmbeddedChatbot = () => {
}, [appInfo])
const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => {
if (appInfo?.site.default_language)
changeLanguage(appInfo.site.default_language)
}, [appInfo])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, {
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {},
})
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo])
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({
...conversationIdInfo,
[appId || '']: changeConversationId,
[appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
})
}
}, [appId, conversationIdInfo, setConversationIdInfo])
}, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => {

@ -7,9 +7,11 @@ import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-butt
type ApiServerProps = {
apiBaseUrl: string
appId?: string
}
const ApiServer: FC<ApiServerProps> = ({
apiBaseUrl,
appId,
}) => {
const { t } = useTranslation()
@ -25,7 +27,7 @@ const ApiServer: FC<ApiServerProps> = ({
{t('appApi.ok')}
</div>
<SecretKeyButton
className='!h-8 shrink-0'
className='!h-8 shrink-0' appId={appId}
/>
</div>
)

@ -23,7 +23,7 @@ const DevelopMain = ({ appId }: IDevelopMainProps) => {
<div className='relative flex h-full flex-col overflow-hidden'>
<div className='flex shrink-0 items-center justify-between border-b border-solid border-b-divider-regular px-6 py-2'>
<div className='text-lg font-medium text-text-primary'></div>
<ApiServer apiBaseUrl={appDetail.api_base_url} />
<ApiServer apiBaseUrl={appDetail.api_base_url} appId={appId} />
</div>
<div className='grow overflow-auto px-4 py-4 sm:px-10'>
<Doc appDetail={appDetail} />

@ -406,8 +406,7 @@ export type VersionProps = {
export type StrategyParamItem = {
name: string
label: Record<Locale, string>
human_description: Record<Locale, string>
llm_description: string
help: Record<Locale, string>
placeholder: Record<Locale, string>
type: string
scope: string

@ -2,29 +2,44 @@ import { CONVERSATION_ID_INFO } from '../base/chat/constants'
import { fetchAccessToken } from '@/service/share'
import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils'
export const isTokenV1 = (token: Record<string, any>) => {
return !token.version
}
export const getInitialTokenV2 = (): Record<string, any> => ({
version: 2,
})
export const checkOrSetAccessToken = async () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {
}
if (!accessTokenJson[sharedToken]) {
const sysUserId = (await getProcessedSystemVariablesFromUrlParams()).user_id
const res = await fetchAccessToken(sharedToken, sysUserId)
accessTokenJson[sharedToken] = res.access_token
if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) {
const res = await fetchAccessToken(sharedToken, userId)
accessTokenJson[sharedToken] = {
...accessTokenJson[sharedToken],
[userId || 'DEFAULT']: res.access_token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
}
export const setAccessToken = async (sharedToken: string, token: string) => {
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
export const setAccessToken = async (sharedToken: string, token: string, user_id?: string) => {
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {
@ -32,17 +47,22 @@ export const setAccessToken = async (sharedToken: string, token: string) => {
localStorage.removeItem(CONVERSATION_ID_INFO)
accessTokenJson[sharedToken] = token
accessTokenJson[sharedToken] = {
...accessTokenJson[sharedToken],
[user_id || 'DEFAULT']: token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
export const removeAccessToken = () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {

@ -65,7 +65,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => {
switch (schema.type) {
case FormTypeEnum.textInput: {
const def = schema as CredentialFormSchemaTextInput
const value = props.value[schema.variable]
const value = props.value[schema.variable] || schema.default
const onChange = (value: string) => {
props.onChange({ ...props.value, [schema.variable]: value })
}

@ -27,6 +27,7 @@ export function strategyParamToCredientialForm(param: StrategyParamItem): Creden
variable: param.name,
show_on: [],
type: toType(param.type),
tooltip: param.help,
}
}
@ -53,6 +54,7 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
outputSchema,
handleMemoryChange,
} = useConfig(props.id, props.data)
console.log('currentStrategy', currentStrategy)
const { t } = useTranslation()
const nodeInfo = useMemo(() => {
if (!runResult)

@ -287,9 +287,9 @@ const handleStream = (
const baseFetch = base
export const upload = (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise<any> => {
export const upload = async (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise<any> => {
const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX
const token = getAccessToken(isPublicAPI)
const token = await getAccessToken(isPublicAPI)
const defaultOptions = {
method: 'POST',
url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''),
@ -324,7 +324,7 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search
})
}
export const ssePost = (
export const ssePost = async (
url: string,
fetchOptions: FetchOptionType,
otherOptions: IOtherOptions,
@ -385,7 +385,7 @@ export const ssePost = (
if (body)
options.body = JSON.stringify(body)
const accessToken = getAccessToken(isPublicAPI)
const accessToken = await getAccessToken(isPublicAPI)
; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`)
globalThis.fetch(urlWithPrefix, options as RequestInit)

@ -3,6 +3,8 @@ import ky from 'ky'
import type { IOtherOptions } from './base'
import Toast from '@/app/components/base/toast'
import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils'
import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
const TIME_OUT = 100000
@ -67,44 +69,34 @@ const beforeErrorToast = (otherOptions: IOtherOptions): BeforeErrorHook => {
}
}
export const getPublicToken = () => {
let token = ''
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch { }
token = accessTokenJson[sharedToken]
return token || ''
}
export function getAccessToken(isPublicAPI?: boolean) {
export async function getAccessToken(isPublicAPI?: boolean) {
if (isPublicAPI) {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
const accessToken = localStorage.getItem('token') || JSON.stringify({ version: 2 })
let accessTokenJson: Record<string, any> = { version: 2 }
try {
accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
}
catch {
}
return accessTokenJson[sharedToken]
return accessTokenJson[sharedToken]?.[userId || 'DEFAULT']
}
else {
return localStorage.getItem('console_token') || ''
}
}
const beforeRequestPublicAuthorization: BeforeRequestHook = (request) => {
const token = getAccessToken(true)
const beforeRequestPublicAuthorization: BeforeRequestHook = async (request) => {
const token = await getAccessToken(true)
request.headers.set('Authorization', `Bearer ${token}`)
}
const beforeRequestAuthorization: BeforeRequestHook = (request) => {
const accessToken = getAccessToken()
const beforeRequestAuthorization: BeforeRequestHook = async (request) => {
const accessToken = await getAccessToken()
request.headers.set('Authorization', `Bearer ${accessToken}`)
}

Loading…
Cancel
Save