Merge branch 'langgenius:main' into add-document-status-update

pull/18235/head
GuanMu 1 year ago committed by GitHub
commit 1625148b07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -4,14 +4,10 @@ import platform
import re import re
import urllib.parse import urllib.parse
import warnings import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4 from uuid import uuid4
import httpx import httpx
from constants import DEFAULT_FILE_NUMBER_LIMITS
try: try:
import magic import magic
except ImportError: except ImportError:
@ -31,8 +27,6 @@ except ImportError:
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel): class FileInfo(BaseModel):
filename: str filename: str
@ -89,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype, mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)), size=int(response.headers.get("Content-Length", -1)),
) )
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -1,5 +1,4 @@
from datetime import datetime from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt, RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMetaApi(Resource): class AppMetaApi(Resource):

@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -122,6 +122,8 @@ class SegmentApi(DatasetApiResource):
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
status_list=args["status"], status_list=args["status"],
keyword=args["keyword"], keyword=args["keyword"],
page=page,
limit=limit,
) )
response = { response = {

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMeta(WebApiResource): class AppMeta(WebApiResource):

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value) return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter") type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any): def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value) return init_frontend_parameter(self, self.type, value)

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"], variables: Sequence["VariableEntity"],
tenant_id: str, tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
strict_type_validation=strict_type_validation,
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
config=file_extra_config, config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
) )
# convert to app config # convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
), ),
files=list(system_files), files=list(system_files),
user_id=user.id, user_id=user.id,

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation): class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod @classmethod
def invoke_app( def invoke_app(
cls, cls,

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str filename: str
mimetype: str mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

@ -155,9 +155,28 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items(): for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids: if answer_node_id not in self.rest_node_ids:
continue continue
# exclude current node id # Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies answer_dependencies = self.generate_routes.answer_dependencies
if event.node_id in answer_dependencies[answer_node_id]: edge_mapping = self.graph.edge_mapping.get(event.node_id)
success_edge = (
next(
(
edge
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id) answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, []) answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids

@ -52,6 +52,7 @@ def build_from_mapping(
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
config: FileUploadConfig | None = None, config: FileUploadConfig | None = None,
strict_type_validation: bool = False,
) -> File: ) -> File:
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
@ -69,6 +70,7 @@ def build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=tenant_id, tenant_id=tenant_id,
transfer_method=transfer_method, transfer_method=transfer_method,
strict_type_validation=strict_type_validation,
) )
if config and not _is_file_valid_with_config( if config and not _is_file_valid_with_config(
@ -87,12 +89,14 @@ def build_from_mappings(
mappings: Sequence[Mapping[str, Any]], mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None = None, config: FileUploadConfig | None = None,
tenant_id: str, tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]: ) -> Sequence[File]:
files = [ files = [
build_from_mapping( build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=tenant_id, tenant_id=tenant_id,
config=config, config=config,
strict_type_validation=strict_type_validation,
) )
for mapping in mappings for mapping in mappings
] ]
@ -116,6 +120,7 @@ def _build_from_local_file(
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File: ) -> File:
upload_file_id = mapping.get("upload_file_id") upload_file_id = mapping.get("upload_file_id")
if not upload_file_id: if not upload_file_id:
@ -134,10 +139,16 @@ def _build_from_local_file(
if row is None: if row is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) detected_file_type = _standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
if file_type.value != mapping.get("type", "custom"): specified_type = mapping.get("type", "custom")
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.") raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=row.name, filename=row.name,
@ -158,6 +169,7 @@ def _build_from_remote_url(
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File: ) -> File:
upload_file_id = mapping.get("upload_file_id") upload_file_id = mapping.get("upload_file_id")
if upload_file_id: if upload_file_id:
@ -174,10 +186,21 @@ def _build_from_remote_url(
if upload_file is None: if upload_file is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")
file_type = _standardize_file_type(extension="." + upload_file.extension, mime_type=upload_file.mime_type) detected_file_type = _standardize_file_type(
if file_type.value != mapping.get("type", "custom"): extension="." + upload_file.extension, mime_type=upload_file.mime_type
)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.") raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type)
if specified_type and specified_type != FileType.CUSTOM.value
else detected_file_type
)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=upload_file.name, filename=upload_file.name,
@ -237,6 +260,7 @@ def _build_from_tool_file(
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File: ) -> File:
tool_file = ( tool_file = (
db.session.query(ToolFile) db.session.query(ToolFile)
@ -252,7 +276,16 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),

@ -2120,7 +2120,7 @@ class SegmentService:
dataset_id=dataset.id, dataset_id=dataset.id,
document_id=document.id, document_id=document.id,
segment_id=segment.id, segment_id=segment.id,
position=max_position + 1, position=max_position + 1 if max_position else 1,
index_node_id=index_node_id, index_node_id=index_node_id,
index_node_hash=index_node_hash, index_node_hash=index_node_hash,
content=content, content=content,
@ -2270,7 +2270,13 @@ class SegmentService:
@classmethod @classmethod
def get_segments( def get_segments(
cls, document_id: str, tenant_id: str, status_list: list[str] | None = None, keyword: str | None = None cls,
document_id: str,
tenant_id: str,
status_list: list[str] | None = None,
keyword: str | None = None,
page: int = 1,
limit: int = 20,
): ):
"""Get segments for a document with optional filtering.""" """Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter( query = DocumentSegment.query.filter(
@ -2283,10 +2289,11 @@ class SegmentService:
if keyword: if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
segments = query.order_by(DocumentSegment.position.asc()).all() paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate(
total = len(segments) page=page, per_page=limit, max_per_page=100, error_out=False
)
return segments, total return paginated_segments.items, paginated_segments.total
@classmethod @classmethod
def update_segment_by_id( def update_segment_by_id(

@ -1,14 +1,20 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
NodeRunExceptionEvent, NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunStreamChunkEvent, NodeRunStreamChunkEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
class ContinueOnErrorTestHelper: class ContinueOnErrorTestHelper:
@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
"edges": FAIL_BRANCH_EDGES[:-1], "edges": FAIL_BRANCH_EDGES[:-1],
"nodes": [ "nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{ {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
"id": "success",
},
ContinueOnErrorTestHelper.get_http_node(), ContinueOnErrorTestHelper.get_http_node(),
], ],
} }
@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
def test_stream_output_with_fail_branch_continue_on_error():
"""Test stream output with fail-branch error strategy"""
graph_config = {
"edges": FAIL_BRANCH_EDGES,
"nodes": [
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
{
"data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
"id": "success",
},
{
"data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
"id": "error",
},
ContinueOnErrorTestHelper.get_llm_node(),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
def llm_generator(self):
contents = ["hi", "bye", "good morning"]
yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={},
process_data={},
outputs={},
metadata={
NodeRunMetadataKey.TOTAL_TOKENS: 1,
NodeRunMetadataKey.TOTAL_PRICE: 1,
NodeRunMetadataKey.CURRENCY: "USD",
},
)
)
with patch.object(LLMNode, "_run", new=llm_generator):
events = list(graph_engine.run())
assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)

@ -0,0 +1,198 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from httpx import Response
from factories.file_factory import (
File,
FileTransferMethod,
FileType,
FileUploadConfig,
build_from_mapping,
)
from models import ToolFile, UploadFile
# Test Data
TEST_TENANT_ID = "test_tenant_id"
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
TEST_TOOL_FILE_ID = str(uuid.uuid4())
TEST_REMOTE_URL = "http://example.com/test.jpg"
# Test Config
TEST_CONFIG = FileUploadConfig(
allowed_file_types=["image", "document"],
allowed_file_extensions=[".jpg", ".pdf"],
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
number_limits=10,
)
# Fixtures
@pytest.fixture
def mock_upload_file():
mock = MagicMock(spec=UploadFile)
mock.id = TEST_UPLOAD_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "test.jpg"
mock.extension = "jpg"
mock.mime_type = "image/jpeg"
mock.source_url = TEST_REMOTE_URL
mock.size = 1024
mock.key = "test_key"
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
yield m
@pytest.fixture
def mock_tool_file():
mock = MagicMock(spec=ToolFile)
mock.id = TEST_TOOL_FILE_ID
mock.tenant_id = TEST_TENANT_ID
mock.name = "tool_file.pdf"
mock.file_key = "tool_file.pdf"
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = mock
yield mock
@pytest.fixture
def mock_http_head():
def _mock_response(filename, size, content_type):
return Response(
status_code=200,
headers={
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(size),
"Content-Type": content_type,
},
)
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
yield mock_head
# Helper functions
def local_file_mapping(file_type="image"):
return {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
"type": file_type,
}
def tool_file_mapping(file_type="document"):
return {
"transfer_method": "tool_file",
"tool_file_id": TEST_TOOL_FILE_ID,
"type": file_type,
}
# Tests
def test_build_from_mapping_backward_compatibility(mock_upload_file):
mapping = local_file_mapping(file_type="image")
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert isinstance(file, File)
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
assert file.type == FileType.IMAGE
assert file.related_id == TEST_UPLOAD_FILE_ID
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("document", False, "Detected file type does not match"),
],
)
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("document", True, None),
("image", False, "Detected file type does not match"),
],
)
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
"""Strict type validation for tool_file."""
mapping = tool_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
assert file.type == FileType(file_type)
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
def test_build_from_remote_url(mock_http_head):
mapping = {
"transfer_method": "remote_url",
"url": TEST_REMOTE_URL,
"type": "image",
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
assert file.transfer_method == FileTransferMethod.REMOTE_URL
assert file.type == FileType.IMAGE
assert file.filename == "remote_test.jpg"
assert file.size == 2048
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.query") as mock_query:
mock_query.return_value.filter.return_value.first.return_value = None
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_local_file_not_found():
"""Test UploadFile not found in database."""
with patch("factories.file_factory.db.session.scalar", return_value=None):
mapping = local_file_mapping()
with pytest.raises(ValueError, match="Invalid upload file"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
def test_build_without_type_specification(mock_upload_file):
"""Test the situation where no file type is specified"""
mapping = {
"transfer_method": "local_file",
"upload_file_id": TEST_UPLOAD_FILE_ID,
# leave out the type
}
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
# It should automatically infer the type as "image" based on the file extension
assert file.type == FileType.IMAGE
@pytest.mark.parametrize(
("file_type", "should_pass", "expected_error"),
[
("image", True, None),
("video", False, "File validation failed"),
],
)
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
"""Test the validation of files and configurations"""
mapping = local_file_mapping(file_type=file_type)
if should_pass:
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
assert file is not None
else:
with pytest.raises(ValueError, match=expected_error):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)

@ -16,7 +16,7 @@ import type {
Feedback, Feedback,
} from '../types' } from '../types'
import { CONVERSATION_ID_INFO } from '../constants' import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree } from '../utils' import { buildChatItemTree, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { addFileInfos, sortAgentSorts } from '../../../tools/utils' import { addFileInfos, sortAgentSorts } from '../../../tools/utils'
import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils'
import { import {
@ -106,6 +106,13 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
}, [isInstalledApp, installedAppInfo, appInfo]) }, [isInstalledApp, installedAppInfo, appInfo])
const appId = useMemo(() => appData?.app_id, [appData]) const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => { useEffect(() => {
if (appData?.site.default_language) if (appData?.site.default_language)
changeLanguage(appData.site.default_language) changeLanguage(appData.site.default_language)
@ -124,18 +131,24 @@ export const useChatWithHistory = (installedAppInfo?: InstalledApp) => {
setSidebarCollapseState(localState === 'collapsed') setSidebarCollapseState(localState === 'collapsed')
} }
}, [appId]) }, [appId])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, { const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {}, defaultValue: {},
}) })
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo]) const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => { const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) { if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({ setConversationIdInfo({
...conversationIdInfo, ...conversationIdInfo,
[appId || '']: changeConversationId, [appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
}) })
} }
}, [appId, conversationIdInfo, setConversationIdInfo]) }, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('') const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => { const chatShouldReloadKey = useMemo(() => {

@ -2,8 +2,6 @@ import type { FC } from 'react'
import { memo } from 'react' import { memo } from 'react'
import type { ChatItem } from '../../types' import type { ChatItem } from '../../types'
import { useChatContext } from '../context' import { useChatContext } from '../context'
import Button from '@/app/components/base/button'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
type SuggestedQuestionsProps = { type SuggestedQuestionsProps = {
item: ChatItem item: ChatItem
@ -12,9 +10,6 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
item, item,
}) => { }) => {
const { onSend } = useChatContext() const { onSend } = useChatContext()
const media = useBreakpoints()
const isMobile = media === MediaType.mobile
const klassName = `mr-1 mt-1 ${isMobile ? 'block overflow-hidden text-ellipsis' : ''} max-w-full shrink-0 last:mr-0`
const { const {
isOpeningStatement, isOpeningStatement,
@ -27,14 +22,13 @@ const SuggestedQuestions: FC<SuggestedQuestionsProps> = ({
return ( return (
<div className='flex flex-wrap'> <div className='flex flex-wrap'>
{suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => ( {suggestedQuestions.filter(q => !!q && q.trim()).map((question, index) => (
<Button <div
key={index} key={index}
variant='secondary-accent' className='system-sm-medium mr-1 mt-1 inline-flex max-w-full shrink-0 cursor-pointer flex-wrap rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3.5 py-2 text-components-button-secondary-accent-text shadow-xs last:mr-0 hover:border-components-button-secondary-border-hover hover:bg-components-button-secondary-bg-hover'
className={klassName}
onClick={() => onSend?.(question)} onClick={() => onSend?.(question)}
> >
{question} {question}
</Button>), </div>),
)} )}
</div> </div>
) )

@ -184,7 +184,7 @@ const ChatWrapper = () => {
return null return null
if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) { if (welcomeMessage.suggestedQuestions && welcomeMessage.suggestedQuestions?.length > 0) {
return ( return (
<div className='flex h-[50vh] items-center justify-center px-4 py-12'> <div className={cn('flex items-center justify-center px-4 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<div className='flex max-w-[720px] grow gap-4'> <div className='flex max-w-[720px] grow gap-4'>
<AppIcon <AppIcon
size='xl' size='xl'
@ -202,7 +202,7 @@ const ChatWrapper = () => {
) )
} }
return ( return (
<div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12')}> <div className={cn('flex h-[50vh] flex-col items-center justify-center gap-3 py-12', isMobile ? 'min-h-[30vh] py-0' : 'h-[50vh]')}>
<AppIcon <AppIcon
size='xl' size='xl'
iconType={appData?.site.icon_type} iconType={appData?.site.icon_type}

@ -15,7 +15,7 @@ import type {
Feedback, Feedback,
} from '../types' } from '../types'
import { CONVERSATION_ID_INFO } from '../constants' import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedInputsFromUrlParams } from '../utils' import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams } from '../utils'
import { getProcessedFilesFromResponse } from '../../file-uploader/utils' import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
import { import {
fetchAppInfo, fetchAppInfo,
@ -72,23 +72,36 @@ export const useEmbeddedChatbot = () => {
}, [appInfo]) }, [appInfo])
const appId = useMemo(() => appData?.app_id, [appData]) const appId = useMemo(() => appData?.app_id, [appData])
const [userId, setUserId] = useState<string>()
useEffect(() => {
getProcessedSystemVariablesFromUrlParams().then(({ user_id }) => {
setUserId(user_id)
})
}, [])
useEffect(() => { useEffect(() => {
if (appInfo?.site.default_language) if (appInfo?.site.default_language)
changeLanguage(appInfo.site.default_language) changeLanguage(appInfo.site.default_language)
}, [appInfo]) }, [appInfo])
const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, string>>(CONVERSATION_ID_INFO, { const [conversationIdInfo, setConversationIdInfo] = useLocalStorageState<Record<string, Record<string, string>>>(CONVERSATION_ID_INFO, {
defaultValue: {}, defaultValue: {},
}) })
const currentConversationId = useMemo(() => conversationIdInfo?.[appId || ''] || '', [appId, conversationIdInfo]) const currentConversationId = useMemo(() => conversationIdInfo?.[appId || '']?.[userId || 'DEFAULT'] || '', [appId, conversationIdInfo, userId])
const handleConversationIdInfoChange = useCallback((changeConversationId: string) => { const handleConversationIdInfoChange = useCallback((changeConversationId: string) => {
if (appId) { if (appId) {
let prevValue = conversationIdInfo?.[appId || '']
if (typeof prevValue === 'string')
prevValue = {}
setConversationIdInfo({ setConversationIdInfo({
...conversationIdInfo, ...conversationIdInfo,
[appId || '']: changeConversationId, [appId || '']: {
...prevValue,
[userId || 'DEFAULT']: changeConversationId,
},
}) })
} }
}, [appId, conversationIdInfo, setConversationIdInfo]) }, [appId, conversationIdInfo, setConversationIdInfo, userId])
const [newConversationId, setNewConversationId] = useState('') const [newConversationId, setNewConversationId] = useState('')
const chatShouldReloadKey = useMemo(() => { const chatShouldReloadKey = useMemo(() => {

@ -7,9 +7,11 @@ import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-butt
type ApiServerProps = { type ApiServerProps = {
apiBaseUrl: string apiBaseUrl: string
appId?: string
} }
const ApiServer: FC<ApiServerProps> = ({ const ApiServer: FC<ApiServerProps> = ({
apiBaseUrl, apiBaseUrl,
appId,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
@ -25,7 +27,7 @@ const ApiServer: FC<ApiServerProps> = ({
{t('appApi.ok')} {t('appApi.ok')}
</div> </div>
<SecretKeyButton <SecretKeyButton
className='!h-8 shrink-0' className='!h-8 shrink-0' appId={appId}
/> />
</div> </div>
) )

@ -23,7 +23,7 @@ const DevelopMain = ({ appId }: IDevelopMainProps) => {
<div className='relative flex h-full flex-col overflow-hidden'> <div className='relative flex h-full flex-col overflow-hidden'>
<div className='flex shrink-0 items-center justify-between border-b border-solid border-b-divider-regular px-6 py-2'> <div className='flex shrink-0 items-center justify-between border-b border-solid border-b-divider-regular px-6 py-2'>
<div className='text-lg font-medium text-text-primary'></div> <div className='text-lg font-medium text-text-primary'></div>
<ApiServer apiBaseUrl={appDetail.api_base_url} /> <ApiServer apiBaseUrl={appDetail.api_base_url} appId={appId} />
</div> </div>
<div className='grow overflow-auto px-4 py-4 sm:px-10'> <div className='grow overflow-auto px-4 py-4 sm:px-10'>
<Doc appDetail={appDetail} /> <Doc appDetail={appDetail} />

@ -406,8 +406,7 @@ export type VersionProps = {
export type StrategyParamItem = { export type StrategyParamItem = {
name: string name: string
label: Record<Locale, string> label: Record<Locale, string>
human_description: Record<Locale, string> help: Record<Locale, string>
llm_description: string
placeholder: Record<Locale, string> placeholder: Record<Locale, string>
type: string type: string
scope: string scope: string

@ -2,29 +2,44 @@ import { CONVERSATION_ID_INFO } from '../base/chat/constants'
import { fetchAccessToken } from '@/service/share' import { fetchAccessToken } from '@/service/share'
import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils' import { getProcessedSystemVariablesFromUrlParams } from '../base/chat/utils'
export const isTokenV1 = (token: Record<string, any>) => {
return !token.version
}
export const getInitialTokenV2 = (): Record<string, any> => ({
version: 2,
})
export const checkOrSetAccessToken = async () => { export const checkOrSetAccessToken = async () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
let accessTokenJson = { [sharedToken]: '' } const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = getInitialTokenV2()
try { try {
accessTokenJson = JSON.parse(accessToken) accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
} }
catch { catch {
} }
if (!accessTokenJson[sharedToken]) { if (!accessTokenJson[sharedToken]?.[userId || 'DEFAULT']) {
const sysUserId = (await getProcessedSystemVariablesFromUrlParams()).user_id const res = await fetchAccessToken(sharedToken, userId)
const res = await fetchAccessToken(sharedToken, sysUserId) accessTokenJson[sharedToken] = {
accessTokenJson[sharedToken] = res.access_token ...accessTokenJson[sharedToken],
[userId || 'DEFAULT']: res.access_token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson)) localStorage.setItem('token', JSON.stringify(accessTokenJson))
} }
} }
export const setAccessToken = async (sharedToken: string, token: string) => { export const setAccessToken = async (sharedToken: string, token: string, user_id?: string) => {
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = { [sharedToken]: '' } let accessTokenJson = getInitialTokenV2()
try { try {
accessTokenJson = JSON.parse(accessToken) accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
} }
catch { catch {
@ -32,17 +47,22 @@ export const setAccessToken = async (sharedToken: string, token: string) => {
localStorage.removeItem(CONVERSATION_ID_INFO) localStorage.removeItem(CONVERSATION_ID_INFO)
accessTokenJson[sharedToken] = token accessTokenJson[sharedToken] = {
...accessTokenJson[sharedToken],
[user_id || 'DEFAULT']: token,
}
localStorage.setItem('token', JSON.stringify(accessTokenJson)) localStorage.setItem('token', JSON.stringify(accessTokenJson))
} }
export const removeAccessToken = () => { export const removeAccessToken = () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) const accessToken = localStorage.getItem('token') || JSON.stringify(getInitialTokenV2())
let accessTokenJson = { [sharedToken]: '' } let accessTokenJson = getInitialTokenV2()
try { try {
accessTokenJson = JSON.parse(accessToken) accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
} }
catch { catch {

@ -65,7 +65,7 @@ export const AgentStrategy = memo((props: AgentStrategyProps) => {
switch (schema.type) { switch (schema.type) {
case FormTypeEnum.textInput: { case FormTypeEnum.textInput: {
const def = schema as CredentialFormSchemaTextInput const def = schema as CredentialFormSchemaTextInput
const value = props.value[schema.variable] const value = props.value[schema.variable] || schema.default
const onChange = (value: string) => { const onChange = (value: string) => {
props.onChange({ ...props.value, [schema.variable]: value }) props.onChange({ ...props.value, [schema.variable]: value })
} }

@ -27,6 +27,7 @@ export function strategyParamToCredientialForm(param: StrategyParamItem): Creden
variable: param.name, variable: param.name,
show_on: [], show_on: [],
type: toType(param.type), type: toType(param.type),
tooltip: param.help,
} }
} }
@ -53,6 +54,7 @@ const AgentPanel: FC<NodePanelProps<AgentNodeType>> = (props) => {
outputSchema, outputSchema,
handleMemoryChange, handleMemoryChange,
} = useConfig(props.id, props.data) } = useConfig(props.id, props.data)
console.log('currentStrategy', currentStrategy)
const { t } = useTranslation() const { t } = useTranslation()
const nodeInfo = useMemo(() => { const nodeInfo = useMemo(() => {
if (!runResult) if (!runResult)

@ -287,9 +287,9 @@ const handleStream = (
const baseFetch = base const baseFetch = base
export const upload = (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise<any> => { export const upload = async (options: any, isPublicAPI?: boolean, url?: string, searchParams?: string): Promise<any> => {
const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX const urlPrefix = isPublicAPI ? PUBLIC_API_PREFIX : API_PREFIX
const token = getAccessToken(isPublicAPI) const token = await getAccessToken(isPublicAPI)
const defaultOptions = { const defaultOptions = {
method: 'POST', method: 'POST',
url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''), url: (url ? `${urlPrefix}${url}` : `${urlPrefix}/files/upload`) + (searchParams || ''),
@ -324,7 +324,7 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search
}) })
} }
export const ssePost = ( export const ssePost = async (
url: string, url: string,
fetchOptions: FetchOptionType, fetchOptions: FetchOptionType,
otherOptions: IOtherOptions, otherOptions: IOtherOptions,
@ -385,7 +385,7 @@ export const ssePost = (
if (body) if (body)
options.body = JSON.stringify(body) options.body = JSON.stringify(body)
const accessToken = getAccessToken(isPublicAPI) const accessToken = await getAccessToken(isPublicAPI)
; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`) ; (options.headers as Headers).set('Authorization', `Bearer ${accessToken}`)
globalThis.fetch(urlWithPrefix, options as RequestInit) globalThis.fetch(urlWithPrefix, options as RequestInit)

@ -3,6 +3,8 @@ import ky from 'ky'
import type { IOtherOptions } from './base' import type { IOtherOptions } from './base'
import Toast from '@/app/components/base/toast' import Toast from '@/app/components/base/toast'
import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config' import { API_PREFIX, MARKETPLACE_API_PREFIX, PUBLIC_API_PREFIX } from '@/config'
import { getInitialTokenV2, isTokenV1 } from '@/app/components/share/utils'
import { getProcessedSystemVariablesFromUrlParams } from '@/app/components/base/chat/utils'
const TIME_OUT = 100000 const TIME_OUT = 100000
@ -67,44 +69,34 @@ const beforeErrorToast = (otherOptions: IOtherOptions): BeforeErrorHook => {
} }
} }
export const getPublicToken = () => { export async function getAccessToken(isPublicAPI?: boolean) {
let token = ''
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch { }
token = accessTokenJson[sharedToken]
return token || ''
}
export function getAccessToken(isPublicAPI?: boolean) {
if (isPublicAPI) { if (isPublicAPI) {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0] const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' }) const userId = (await getProcessedSystemVariablesFromUrlParams()).user_id
let accessTokenJson = { [sharedToken]: '' } const accessToken = localStorage.getItem('token') || JSON.stringify({ version: 2 })
let accessTokenJson: Record<string, any> = { version: 2 }
try { try {
accessTokenJson = JSON.parse(accessToken) accessTokenJson = JSON.parse(accessToken)
if (isTokenV1(accessTokenJson))
accessTokenJson = getInitialTokenV2()
} }
catch { catch {
} }
return accessTokenJson[sharedToken] return accessTokenJson[sharedToken]?.[userId || 'DEFAULT']
} }
else { else {
return localStorage.getItem('console_token') || '' return localStorage.getItem('console_token') || ''
} }
} }
const beforeRequestPublicAuthorization: BeforeRequestHook = (request) => { const beforeRequestPublicAuthorization: BeforeRequestHook = async (request) => {
const token = getAccessToken(true) const token = await getAccessToken(true)
request.headers.set('Authorization', `Bearer ${token}`) request.headers.set('Authorization', `Bearer ${token}`)
} }
const beforeRequestAuthorization: BeforeRequestHook = (request) => { const beforeRequestAuthorization: BeforeRequestHook = async (request) => {
const accessToken = getAccessToken() const accessToken = await getAccessToken()
request.headers.set('Authorization', `Bearer ${accessToken}`) request.headers.set('Authorization', `Bearer ${accessToken}`)
} }

Loading…
Cancel
Save