merge main

pull/21369/head
jZonG 11 months ago
commit 37c6bdec3d

@ -84,10 +84,8 @@ jobs:
elasticsearch elasticsearch
oceanbase oceanbase
- name: Check VDB Ready (TiDB, Oceanbase) - name: Check VDB Ready (TiDB)
run: | run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
uv run --project api python api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

4
.gitignore vendored

@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/* docker/volumes/couchbase/*
docker/volumes/oceanbase/* docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/* docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf
@ -210,3 +211,6 @@ mise.toml
# Next.js build output # Next.js build output
.next/ .next/
# AI Assistant
.roo/

@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration # Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore # support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Weaviate configuration # Weaviate configuration
@ -294,6 +294,13 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30
# Matrixone configration
MATRIXONE_HOST=127.0.0.1
MATRIXONE_PORT=6001
MATRIXONE_USER=dump
MATRIXONE_PASSWORD=111
MATRIXONE_DATABASE=dify
# Lindorm configuration # Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin LINDORM_USERNAME=admin
@ -332,9 +339,11 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp # Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE= MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai> MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
# resend configuration
RESEND_API_KEY= RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com RESEND_API_URL=https://api.resend.com
# smtp configuration # smtp configuration
@ -344,7 +353,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc SMTP_PASSWORD=abc
SMTP_USE_TLS=true SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false SMTP_OPPORTUNISTIC_TLS=false
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration # Sentry configuration
SENTRY_DSN= SENTRY_DSN=

@ -281,6 +281,7 @@ def migrate_knowledge_vector_database():
VectorType.ELASTICSEARCH, VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS, VectorType.OPENGAUSS,
VectorType.TABLESTORE, VectorType.TABLESTORE,
VectorType.MATRIXONE,
} }
lower_collection_vector_types = { lower_collection_vector_types = {
VectorType.ANALYTICDB, VectorType.ANALYTICDB,

@ -609,7 +609,7 @@ class MailConfig(BaseSettings):
""" """
MAIL_TYPE: Optional[str] = Field( MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.", description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
default=None, default=None,
) )
@ -663,6 +663,11 @@ class MailConfig(BaseSettings):
default=50, default=50,
) )
SENDGRID_API_KEY: Optional[str] = Field(
description="API key for SendGrid service",
default=None,
)
class RagEtlConfig(BaseSettings): class RagEtlConfig(BaseSettings):
""" """

@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig from .vdb.oceanbase_config import OceanBaseVectorConfig
@ -323,5 +324,6 @@ class MiddlewareConfig(
OpenGaussConfig, OpenGaussConfig,
TableStoreConfig, TableStoreConfig,
DatasetQueueMonitorConfig, DatasetQueueMonitorConfig,
MatrixoneConfig,
): ):
pass pass

@ -0,0 +1,14 @@
from pydantic import BaseModel, Field
class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
MATRIXONE_METRIC: str = Field(
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
)

@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file) return AppAnnotationService.batch_import_app_annotations(app_id, file)

@ -34,6 +34,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument( parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp" "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
) )
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after, created_at_after=args.created_at__after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
) )
return workflow_app_log_pagination return workflow_app_log_pagination

@ -686,6 +686,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD | VectorType.HUAWEI_CLOUD
| VectorType.TENCENT | VectorType.TENCENT
| VectorType.MATRIXONE
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
@ -733,6 +734,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT | VectorType.TENCENT
| VectorType.HUAWEI_CLOUD | VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
try: try:

@ -135,6 +135,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args") parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args") parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -158,6 +172,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after, created_at_after=args.created_at__after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
) )
return workflow_app_log_pagination return workflow_app_log_pagination

@ -5,7 +5,11 @@ from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -70,6 +74,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id): def post(self, tenant_id):
"""Resource for creating datasets.""" """Resource for creating datasets."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -193,6 +198,7 @@ class DatasetApi(DatasetApiResource):
return data, 200 return data, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id): def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -293,6 +299,7 @@ class DatasetApi(DatasetApiResource):
return result_data, 200 return result_data, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
""" """
Deletes a dataset given its ID. Deletes a dataset given its ID.

@ -19,7 +19,11 @@ from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError, ArchivedDocumentImmutableError,
DocumentIndexingError, DocumentIndexingError,
) )
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
@ -35,6 +39,7 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by text.""" """Create document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -99,6 +104,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Update document by text.""" """Update document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -158,6 +164,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by upload file.""" """Create document by upload file."""
args = {} args = {}
@ -232,6 +239,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file.""" """Update document by upload file."""
args = {} args = {}
@ -302,6 +310,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
class DocumentDeleteApi(DatasetApiResource): class DocumentDeleteApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id): def delete(self, tenant_id, dataset_id, document_id):
"""Delete document.""" """Delete document."""
document_id = str(document_id) document_id = str(document_id)

@ -1,9 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)

@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
@ -14,6 +14,7 @@ from services.metadata_service import MetadataService
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json") parser.add_argument("type", type=str, required=True, nullable=True, location="json")
@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
class DatasetMetadataServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id): def patch(self, tenant_id, dataset_id, metadata_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json") parser.add_argument("name", type=str, required=True, nullable=True, location="json")
@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id): def delete(self, tenant_id, dataset_id, metadata_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action): def post(self, tenant_id, dataset_id, action):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
class DocumentMetadataEditServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)

@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
DatasetApiResource, DatasetApiResource,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
) )
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Create single segment.""" """Create single segment."""
# check dataset # check dataset
@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource):
class DatasetSegmentApi(DatasetApiResource): class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id): def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -162,6 +165,7 @@ class DatasetSegmentApi(DatasetApiResource):
return 204 return 204
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id): def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -236,6 +240,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id): def post(self, tenant_id, dataset_id, document_id, segment_id):
"""Create child chunk.""" """Create child chunk."""
# check dataset # check dataset
@ -332,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks.""" """Resource for updating child chunks."""
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Delete child chunk.""" """Delete child chunk."""
# check dataset # check dataset
@ -370,6 +376,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Update child chunk.""" """Update child chunk."""
# check dataset # check dataset

@ -163,7 +163,7 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
) )
db.session.add(end_user) db.session.add(end_user)
db.session.commit() db.session.commit()
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
exp = int(exp_dt.timestamp()) exp = int(exp_dt.timestamp())
payload = { payload = {
"iss": site.id, "iss": site.id,

@ -138,14 +138,11 @@ class DatasetConfigManager:
if not config.get("dataset_configs"): if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"} config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
if not isinstance(config["dataset_configs"], dict): if not config["dataset_configs"].get("datasets"):
raise ValueError("dataset_configs must be of object type") config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {} "datasets", {}

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -31,6 +31,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
@ -366,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user :param user: account or end user
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution :param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation :param conversation: conversation
:param stream: is stream :param stream: is stream
@ -399,21 +401,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# new thread with request context and contextvars # new thread with request context and contextvars
context = contextvars.copy_context() context = contextvars.copy_context()
@copy_current_request_context worker_thread = threading.Thread(
def worker_with_context(): target=self._generate_worker,
# Run the worker within the copied context kwargs={
return context.run( "flask_app": current_app._get_current_object(), # type: ignore
self._generate_worker, "application_generate_entity": application_generate_entity,
flask_app=current_app._get_current_object(), # type: ignore "queue_manager": queue_manager,
application_generate_entity=application_generate_entity, "conversation_id": conversation.id,
queue_manager=queue_manager, "message_id": message.id,
conversation_id=conversation.id, "context": context,
message_id=message.id, },
context=context,
) )
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
# return response or stream generator # return response or stream generator
@ -449,24 +448,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID :param message_id: message ID
:return: :return:
""" """
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context with preserve_flask_contexts(flask_app, context_vars=context):
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from configs import dify_config from configs import dify_config
@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser from models import Account, App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
@ -182,21 +183,18 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# new thread with request context and contextvars # new thread with request context and contextvars
context = contextvars.copy_context() context = contextvars.copy_context()
@copy_current_request_context worker_thread = threading.Thread(
def worker_with_context(): target=self._generate_worker,
# Run the worker within the copied context kwargs={
return context.run( "flask_app": current_app._get_current_object(), # type: ignore
self._generate_worker, "context": context,
flask_app=current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity,
context=context, "queue_manager": queue_manager,
application_generate_entity=application_generate_entity, "conversation_id": conversation.id,
queue_manager=queue_manager, "message_id": message.id,
conversation_id=conversation.id, },
message_id=message.id,
) )
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
# return response or stream generator # return response or stream generator
@ -229,24 +227,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID :param message_id: message ID
:return: :return:
""" """
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context with preserve_flask_contexts(flask_app, context_vars=context):
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)

@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -29,6 +29,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
@ -194,6 +195,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user :param user: account or end user
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution :param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream :param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
@ -209,20 +211,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread with request context and contextvars # new thread with request context and contextvars
context = contextvars.copy_context() context = contextvars.copy_context()
@copy_current_request_context worker_thread = threading.Thread(
def worker_with_context(): target=self._generate_worker,
# Run the worker within the copied context kwargs={
return context.run( "flask_app": current_app._get_current_object(), # type: ignore
self._generate_worker, "application_generate_entity": application_generate_entity,
flask_app=current_app._get_current_object(), # type: ignore "queue_manager": queue_manager,
application_generate_entity=application_generate_entity, "context": context,
queue_manager=queue_manager, "workflow_thread_pool_id": workflow_thread_pool_id,
context=context, },
workflow_thread_pool_id=workflow_thread_pool_id,
) )
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start() worker_thread.start()
# return response or stream generator # return response or stream generator
@ -408,24 +407,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
:return: :return:
""" """
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context with preserve_flask_contexts(flask_app, context_vars=context):
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# workflow app # workflow app
runner = WorkflowAppRunner( runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,

@ -542,8 +542,6 @@ class LBModelManager:
return config return config
return None
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
""" """
Cooldown model load balancing config Cooldown model load balancing config

@ -251,7 +251,7 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
) )
decrypt_trace_config_key = str(decrypt_trace_config) decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
if tracing_instance is None: if tracing_instance is None:
# create new tracing_instance and update the cache if it absent # create new tracing_instance and update the cache if it absent

@ -156,9 +156,23 @@ class PluginInstallTaskStartResponse(BaseModel):
task_id: str = Field(description="The ID of the install task.") task_id: str = Field(description="The ID of the install task.")
class PluginUploadResponse(BaseModel): class PluginVerification(BaseModel):
"""
Verification of the plugin.
"""
class AuthorizedCategory(StrEnum):
Langgenius = "langgenius"
Partner = "partner"
Community = "community"
authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")
class PluginDecodeResponse(BaseModel):
unique_identifier: str = Field(description="The unique identifier of the plugin.") unique_identifier: str = Field(description="The unique identifier of the plugin.")
manifest: PluginDeclaration manifest: PluginDeclaration
verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information")
class PluginOAuthAuthorizationUrlResponse(BaseModel): class PluginOAuthAuthorizationUrlResponse(BaseModel):

@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
PluginInstallationSource, PluginInstallationSource,
) )
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse,
PluginInstallTask, PluginInstallTask,
PluginInstallTaskStartResponse, PluginInstallTaskStartResponse,
PluginListResponse, PluginListResponse,
PluginUploadResponse,
) )
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
@ -53,7 +53,7 @@ class PluginInstaller(BasePluginClient):
tenant_id: str, tenant_id: str,
pkg: bytes, pkg: bytes,
verify_signature: bool = False, verify_signature: bool = False,
) -> PluginUploadResponse: ) -> PluginDecodeResponse:
""" """
Upload a plugin package and return the plugin unique identifier. Upload a plugin package and return the plugin unique identifier.
""" """
@ -68,7 +68,7 @@ class PluginInstaller(BasePluginClient):
return self._request_with_plugin_daemon_response( return self._request_with_plugin_daemon_response(
"POST", "POST",
f"plugin/{tenant_id}/management/install/upload/package", f"plugin/{tenant_id}/management/install/upload/package",
PluginUploadResponse, PluginDecodeResponse,
files=body, files=body,
data=data, data=data,
) )
@ -176,6 +176,18 @@ class PluginInstaller(BasePluginClient):
params={"plugin_unique_identifier": plugin_unique_identifier}, params={"plugin_unique_identifier": plugin_unique_identifier},
) )
def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse:
"""
Decode a plugin from an identifier.
"""
return self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse,
data={"plugin_unique_identifier": plugin_unique_identifier},
headers={"Content-Type": "application/json"},
)
def fetch_plugin_installation_by_ids( def fetch_plugin_installation_by_ids(
self, tenant_id: str, plugin_ids: Sequence[str] self, tenant_id: str, plugin_ids: Sequence[str]
) -> Sequence[PluginInstallation]: ) -> Sequence[PluginInstallation]:

@ -0,0 +1,233 @@
import json
import logging
import uuid
from functools import wraps
from typing import Any, Optional
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MatrixoneConfig(BaseModel):
host: str = "localhost"
port: int = 6001
user: str = "dump"
password: str = "111"
database: str = "dify"
metric: str = "l2"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config host is required")
if not values["port"]:
raise ValueError("config port is required")
if not values["user"]:
raise ValueError("config user is required")
if not values["password"]:
raise ValueError("config password is required")
if not values["database"]:
raise ValueError("config database is required")
return values
def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
"""
def __init__(self, collection_name: str, config: MatrixoneConfig):
super().__init__(collection_name)
self.config = config
self.collection_name = collection_name.lower()
self.client = None
@property
def collection_name(self):
return self._collection_name
@collection_name.setter
def collection_name(self, value):
self._collection_name = value
def get_type(self) -> str:
return VectorType.MATRIXONE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
return self.add_texts(texts, embeddings)
def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
"""
Create a new client for the collection.
The collection will be created if it doesn't exist.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
client = MoVectorClient(
connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}",
table_name=self.collection_name,
vector_dimension=dimension,
create_table=create_table,
)
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return client
try:
client.create_full_text_index()
except Exception as e:
logger.exception("Failed to create full text index")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return client
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
assert self.client is not None
ids = []
for _, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
ids.append(doc_id)
self.client.insert(
texts=[doc.page_content for doc in documents],
embeddings=embeddings,
metadatas=[doc.metadata for doc in documents],
ids=ids,
)
return ids
@ensure_client
def text_exists(self, id: str) -> bool:
assert self.client is not None
result = self.client.get(ids=[id])
return len(result) > 0
@ensure_client
def delete_by_ids(self, ids: list[str]) -> None:
assert self.client is not None
if not ids:
return
self.client.delete(ids=ids)
@ensure_client
def get_ids_by_metadata_field(self, key: str, value: str):
assert self.client is not None
results = self.client.query_by_metadata(filter={key: value})
return [result.id for result in results]
@ensure_client
def delete_by_metadata_field(self, key: str, value: str) -> None:
assert self.client is not None
self.client.delete(filter={key: value})
@ensure_client
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
results = self.client.query(
query_vector=query_vector,
k=top_k,
filter=filter,
)
docs = []
# TODO: add the score threshold to the query
for result in results:
metadata = result.metadata
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
score_threshold = float(kwargs.get("score_threshold", 0.0))
results = self.client.full_text_query(
keywords=[query],
k=top_k,
filter=filter,
)
docs = []
for result in results:
metadata = result.metadata
if isinstance(metadata, str):
import json
metadata = json.loads(metadata)
score = 1 - result.distance
if score >= score_threshold:
metadata["score"] = score
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def delete(self) -> None:
assert self.client is not None
self.client.delete()
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name))
config = MatrixoneConfig(
host=dify_config.MATRIXONE_HOST or "localhost",
port=dify_config.MATRIXONE_PORT or 6001,
user=dify_config.MATRIXONE_USER or "dump",
password=dify_config.MATRIXONE_PASSWORD or "111",
database=dify_config.MATRIXONE_DATABASE or "dify",
metric=dify_config.MATRIXONE_METRIC or "l2",
)
return MatrixoneVector(collection_name=collection_name, config=config)

@ -164,6 +164,10 @@ class Vector:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case _: case _:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")

@ -29,3 +29,4 @@ class VectorType(StrEnum):
OPENGAUSS = "opengauss" OPENGAUSS = "opengauss"
TABLESTORE = "tablestore" TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud" HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"

@ -41,6 +41,13 @@ class WeaviateVector(BaseVector):
weaviate.connect.connection.has_grpc = False weaviate.connect.connection.has_grpc = False
# Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0,
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
# which does not contain the deprecation check.
if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
weaviate.connect.connection.PYPI_TIMEOUT = 0.001
try: try:
client = weaviate.Client( client = weaviate.Client(
url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None

@ -22,6 +22,7 @@ class FirecrawlApp:
"formats": ["markdown"], "formats": ["markdown"],
"onlyMainContent": True, "onlyMainContent": True,
"timeout": 30000, "timeout": 30000,
"integration": "dify",
} }
if params: if params:
json_data.update(params) json_data.update(params)
@ -39,7 +40,7 @@ class FirecrawlApp:
def crawl_url(self, url, params=None) -> str: def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
headers = self._prepare_headers() headers = self._prepare_headers()
json_data = {"url": url} json_data = {"url": url, "integration": "dify"}
if params: if params:
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers) response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
@ -49,7 +50,6 @@ class FirecrawlApp:
return cast(str, job_id) return cast(str, job_id)
else: else:
self._handle_error(response, "start crawl job") self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable return "" # unreachable
def check_crawl_status(self, job_id) -> dict[str, Any]: def check_crawl_status(self, job_id) -> dict[str, Any]:
@ -82,7 +82,6 @@ class FirecrawlApp:
) )
else: else:
self._handle_error(response, "check crawl status") self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable return {} # unreachable
def _format_crawl_status_response( def _format_crawl_status_response(
@ -126,4 +125,31 @@ class FirecrawlApp:
def _handle_error(self, response, action) -> None: def _handle_error(self, response, action) -> None:
error_message = response.json().get("error", "Unknown error occurred") error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
headers = self._prepare_headers()
json_data = {
"query": query,
"limit": 5,
"lang": "en",
"country": "us",
"timeout": 60000,
"ignoreInvalidURLs": False,
"scrapeOptions": {},
"integration": "dify",
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
return cast(dict[str, Any], response_data)
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "perform search")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to perform search. Status code: {response.status_code}")

@ -79,6 +79,16 @@ class NotionExtractor(BaseExtractor):
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database.""" """Get all the pages from a Notion database."""
assert self._notion_access_token is not None, "Notion access token is required" assert self._notion_access_token is not None, "Notion access token is required"
database_content = []
next_cursor = None
has_more = True
while has_more:
current_query = query_dict.copy()
if next_cursor:
current_query["start_cursor"] = next_cursor
res = requests.post( res = requests.post(
DATABASE_URL_TMPL.format(database_id=database_id), DATABASE_URL_TMPL.format(database_id=database_id),
headers={ headers={
@ -86,15 +96,15 @@ class NotionExtractor(BaseExtractor):
"Content-Type": "application/json", "Content-Type": "application/json",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
}, },
json=query_dict, json=current_query,
) )
data = res.json() response_data = res.json()
database_content = [] if "results" not in response_data or response_data["results"] is None:
if "results" not in data or data["results"] is None: break
return []
for result in data["results"]: for result in response_data["results"]:
properties = result["properties"] properties = result["properties"]
data = {} data = {}
value: Any value: Any
@ -129,6 +139,12 @@ class NotionExtractor(BaseExtractor):
row_content = row_content + f"{key}:{value}\n" row_content = row_content + f"{key}:{value}\n"
database_content.append(row_content) database_content.append(row_content)
has_more = response_data.get("has_more", False)
next_cursor = response_data.get("next_cursor")
if not database_content:
return []
return [Document(page_content="\n".join(database_content))] return [Document(page_content="\n".join(database_content))]
def _get_notion_block_data(self, page_id: str) -> list[str]: def _get_notion_block_data(self, page_id: str) -> list[str]:

@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
try: try:

@ -496,6 +496,8 @@ class DatasetRetrieval:
all_documents = self.calculate_keyword_score(query, all_documents, top_k) all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality": elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id) self._on_query(query, dataset_ids, app_id, user_from, user_id)

@ -9,7 +9,7 @@ from copy import copy, deepcopy
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Optional, cast from typing import Any, Optional, cast
from flask import Flask, current_app, has_request_context from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -537,24 +538,9 @@ class GraphEngine:
""" """
Run parallel nodes Run parallel nodes
""" """
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context with preserve_flask_contexts(flask_app, context_vars=context):
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
try: try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
q.put( q.put(
ParallelBranchRunStartedEvent( ParallelBranchRunStartedEvent(
parallel_id=parallel_id, parallel_id=parallel_id,
@ -653,26 +639,19 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None) retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads # yield control to other threads
time.sleep(0.001) time.sleep(0.001)
generator = node_instance.run() event_stream = node_instance.run()
for item in generator: for event in event_stream:
if isinstance(item, GraphEngineEvent): if isinstance(event, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event # add parallel info to iteration event
item.parallel_id = parallel_id if isinstance(event, BaseIterationEvent | BaseLoopEvent):
item.parallel_start_node_id = parallel_start_node_id event.parallel_id = parallel_id
item.parent_parallel_id = parent_parallel_id event.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id event.parent_parallel_id = parent_parallel_id
elif isinstance(item, BaseLoopEvent): event.parent_parallel_start_node_id = parent_parallel_start_node_id
# add parallel info to loop event yield event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else: else:
if isinstance(item, RunCompletedEvent): if isinstance(event, RunCompletedEvent):
run_result = item.run_result run_result = event.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if ( if (
retries == max_retries retries == max_retries
@ -708,7 +687,7 @@ class GraphEngine:
# if run failed, handle error # if run failed, handle error
run_result = self._handle_continue_on_error( run_result = self._handle_continue_on_error(
node_instance, node_instance,
item.run_result, event.run_result,
self.graph_runtime_state.variable_pool, self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions, handle_exceptions=handle_exceptions,
) )
@ -811,28 +790,28 @@ class GraphEngine:
should_continue_retry = False should_continue_retry = False
break break
elif isinstance(item, RunStreamChunkEvent): elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
id=node_instance.id, id=node_instance.id,
node_id=node_instance.node_id, node_id=node_instance.node_id,
node_type=node_instance.node_type, node_type=node_instance.node_type,
node_data=node_instance.node_data, node_data=node_instance.node_data,
chunk_content=item.chunk_content, chunk_content=event.chunk_content,
from_variable_selector=item.from_variable_selector, from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
) )
elif isinstance(item, RunRetrieverResourceEvent): elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
id=node_instance.id, id=node_instance.id,
node_id=node_instance.node_id, node_id=node_instance.node_id,
node_type=node_instance.node_type, node_type=node_instance.node_type,
node_data=node_instance.node_data, node_data=node_instance.node_data,
retriever_resources=item.retriever_resources, retriever_resources=event.retriever_resources,
context=item.context, context=event.context,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,

@ -214,7 +214,7 @@ class AgentNode(ToolNode):
) )
if tool_runtime.entity.description: if tool_runtime.entity.description:
tool_runtime.entity.description.llm = ( tool_runtime.entity.description.llm = (
extra.get("descrption", "") or tool_runtime.entity.description.llm extra.get("description", "") or tool_runtime.entity.description.llm
) )
for tool_runtime_params in tool_runtime.entity.parameters: for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = ( tool_runtime_params.form = (

@ -57,7 +57,6 @@ class StreamProcessor(ABC):
# The branch_identify parameter is added to ensure that # The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included. # only nodes in the correct logical branch are included.
reachable_node_ids.append(edge.target_node_id)
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids) reachable_node_ids.extend(ids)
else: else:
@ -74,6 +73,8 @@ class StreamProcessor(ABC):
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
if node_id not in self.rest_node_ids:
self.rest_node_ids.append(node_id)
node_ids = [] node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []): for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id: if edge.target_node_id == self.graph.root_node_id:

@ -7,7 +7,7 @@ from datetime import UTC, datetime
from queue import Empty, Queue from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app, has_request_context from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.variables import ArrayVariable, IntegerVariable, NoneVariable
@ -37,6 +37,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from libs.flask_utils import preserve_flask_contexts
from .exc import ( from .exc import (
InvalidIteratorValueError, InvalidIteratorValueError,
@ -583,23 +584,8 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
run single iteration in parallel mode run single iteration in parallel mode
""" """
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
with preserve_flask_contexts(flask_app, context_vars=context):
parallel_mode_run_id = uuid.uuid4().hex parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy() graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool

@ -8,4 +8,5 @@ EMPTY_VALUE_MAPPING = {
SegmentType.ARRAY_STRING: [], SegmentType.ARRAY_STRING: [],
SegmentType.ARRAY_NUMBER: [], SegmentType.ARRAY_NUMBER: [],
SegmentType.ARRAY_OBJECT: [], SegmentType.ARRAY_OBJECT: [],
SegmentType.ARRAY_FILE: [],
} }

@ -1,5 +1,6 @@
from typing import Any from typing import Any
from core.file import File
from core.variables import SegmentType from core.variables import SegmentType
from .enums import Operation from .enums import Operation
@ -85,6 +86,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, int | float) return isinstance(value, int | float)
case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND:
return isinstance(value, dict) return isinstance(value, dict)
case SegmentType.ARRAY_FILE if operation == Operation.APPEND:
return isinstance(value, File)
# Array & Extend / Overwrite # Array & Extend / Overwrite
case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
@ -95,6 +98,8 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
return isinstance(value, list) and all(isinstance(item, int | float) for item in value) return isinstance(value, list) and all(isinstance(item, int | float) for item in value)
case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, dict) for item in value) return isinstance(value, list) and all(isinstance(item, dict) for item in value)
case SegmentType.ARRAY_FILE if operation in {Operation.EXTEND, Operation.OVER_WRITE}:
return isinstance(value, list) and all(isinstance(item, File) for item in value)
case _: case _:
return False return False

@ -54,6 +54,15 @@ class Mail:
use_tls=dify_config.SMTP_USE_TLS, use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
) )
case "sendgrid":
from libs.sendgrid import SendGridClient
if not dify_config.SENDGRID_API_KEY:
raise ValueError("SENDGRID_API_KEY is required for SendGrid mail type")
self._client = SendGridClient(
sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or ""
)
case _: case _:
raise ValueError("Unsupported mail type {}".format(mail_type)) raise ValueError("Unsupported mail type {}".format(mail_type))

@ -101,6 +101,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = ArrayNumberVariable.model_validate(mapping) result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list): case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping) result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
result = ArrayFileVariable.model_validate(mapping)
case _: case _:
raise VariableError(f"not supported value type {value_type}") raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE: if result.size > dify_config.MAX_VARIABLE_SIZE:

@ -0,0 +1,65 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from flask import Flask, g, has_request_context
T = TypeVar("T")
@contextmanager
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
) -> Iterator[None]:
"""
A context manager that handles:
1. flask-login's UserProxy copy
2. ContextVars copy
3. flask_app.app_context()
This context manager ensures that the Flask application context is properly set up,
the current user is preserved across context boundaries, and any provided context variables
are set within the new context.
Note:
This manager aims to allow use current_user cross thread and app context,
but it's not the recommend use, it's better to pass user directly in parameters.
Args:
flask_app: The Flask application instance
context_vars: contextvars.Context object containing context variables to be set in the new context
Yields:
None
Example:
```python
with preserve_flask_contexts(flask_app, context_vars=context_vars):
# Code that needs Flask app context and context variables
# Current user will be preserved if available
```
"""
# Set context variables if provided
if context_vars:
for var, val in context_vars.items():
var.set(val)
# Save current user before entering new app context
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with flask_app.app_context():
try:
# Restore user in new app context if it was saved
if saved_user is not None:
g._login_user = saved_user
# Yield control back to the caller
yield
finally:
# Any cleanup can be added here if needed
pass

@ -0,0 +1,42 @@
import logging
import sendgrid # type: ignore
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
class SendGridClient:
def __init__(self, sendgrid_api_key: str, _from: str):
self.sendgrid_api_key = sendgrid_api_key
self._from = _from
def send(self, mail: dict):
logging.debug("Sending email with SendGrid")
try:
_to = mail["to"]
if not _to:
raise ValueError("SendGridClient: Cannot send email: recipient address is missing.")
sg = sendgrid.SendGridAPIClient(api_key=self.sendgrid_api_key)
from_email = Email(self._from)
to_email = To(_to)
subject = mail["subject"]
content = Content("text/html", mail["html"])
mail = Mail(from_email, to_email, subject, content)
mail_json = mail.get() # type: ignore
response = sg.client.mail.send.post(request_body=mail_json)
logging.debug(response.status_code)
logging.debug(response.body)
logging.debug(response.headers)
except TimeoutError as e:
logging.exception("SendGridClient Timeout occurred while sending email")
raise
except (UnauthorizedError, ForbiddenError) as e:
logging.exception("SendGridClient Authentication failed. Verify that your credentials and the 'from")
raise
except Exception as e:
logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}")
raise

@ -10,7 +10,6 @@ from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import sign_tool_file from core.tools.signature import sign_tool_file
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from services.plugin.plugin_service import PluginService
if TYPE_CHECKING: if TYPE_CHECKING:
from models.workflow import Workflow from models.workflow import Workflow
@ -169,6 +168,7 @@ class App(Base):
@property @property
def deleted_tools(self) -> list: def deleted_tools(self) -> list:
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from services.plugin.plugin_service import PluginService
# get agent mode tools # get agent mode tools
app_model_config = self.app_model_config app_model_config = self.app_model_config

@ -18,4 +18,3 @@ ignore_missing_imports=True
[mypy-flask_restful.inputs] [mypy-flask_restful.inputs]
ignore_missing_imports=True ignore_missing_imports=True

@ -81,6 +81,7 @@ dependencies = [
"weave~=0.51.0", "weave~=0.51.0",
"yarl~=1.18.3", "yarl~=1.18.3",
"webvtt-py~=0.5.1", "webvtt-py~=0.5.1",
"sendgrid~=6.12.3",
] ]
# Before adding new dependency, consider place it in # Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group. # alphabet order (a-z) and suitable group.
@ -202,4 +203,5 @@ vdb = [
"volcengine-compat~=1.0.0", "volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0", "weaviate-client~=3.24.0",
"xinference-client~=1.2.2", "xinference-client~=1.2.2",
"mo-vector~=0.1.13",
] ]

@ -421,7 +421,7 @@ class AppDslService:
# Set icon type # Set icon type
icon_type_value = icon_type or app_data.get("icon_type") icon_type_value = icon_type or app_data.get("icon_type")
if icon_type_value in ["emoji", "link"]: if icon_type_value in ["emoji", "link", "image"]:
icon_type = icon_type_value icon_type = icon_type_value
else: else:
icon_type = "emoji" icon_type = "emoji"

@ -101,7 +101,7 @@ class WeightModel(BaseModel):
class RetrievalModel(BaseModel): class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool reranking_enable: bool
reranking_model: Optional[RerankingModel] = None reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None reranking_mode: Optional[str] = None

@ -0,0 +1,5 @@
from services.errors.base import BaseServiceError
class PluginInstallationForbiddenError(BaseServiceError):
pass

@ -88,6 +88,26 @@ class WebAppAuthModel(BaseModel):
allow_email_password_login: bool = False allow_email_password_login: bool = False
class PluginInstallationScope(StrEnum):
NONE = "none"
OFFICIAL_ONLY = "official_only"
OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners"
ALL = "all"
class PluginInstallationPermissionModel(BaseModel):
# Plugin installation scope possible values:
# none: prohibit all plugin installations
# official_only: allow only Dify official plugins
# official_and_specific_partners: allow official and specific partner plugins
# all: allow installation of all plugins
plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL
# If True, restrict plugin installation to the marketplace only
# Equivalent to ForceEnablePluginVerification
restrict_to_marketplace_only: bool = False
class FeatureModel(BaseModel): class FeatureModel(BaseModel):
billing: BillingModel = BillingModel() billing: BillingModel = BillingModel()
education: EducationModel = EducationModel() education: EducationModel = EducationModel()
@ -128,6 +148,7 @@ class SystemFeatureModel(BaseModel):
license: LicenseModel = LicenseModel() license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel() branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel() webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
class FeatureService: class FeatureService:
@ -291,3 +312,12 @@ class FeatureService:
features.license.workspaces.enabled = license_info["workspaces"]["enabled"] features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
features.license.workspaces.limit = license_info["workspaces"]["limit"] features.license.workspaces.limit = license_info["workspaces"]["limit"]
features.license.workspaces.size = license_info["workspaces"]["used"] features.license.workspaces.size = license_info["workspaces"]["used"]
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[
"pluginInstallationScope"
]
features.plugin_installation_permission.restrict_to_marketplace_only = plugin_installation_info[
"restrictToMarketplaceOnly"
]

@ -3,7 +3,7 @@ import logging
import click import click
from core.entities import DEFAULT_PLUGIN_ID from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db from models.engine import db
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
class PluginDataMigration: class PluginDataMigration:
@classmethod @classmethod
def migrate(cls) -> None: def migrate(cls) -> None:
cls.migrate_db_records("providers", "provider_name") # large table cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name") cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name") cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_default_models", "provider_name") cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_model_settings", "provider_name") cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
cls.migrate_db_records("load_balancing_model_configs", "provider_name") cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets() cls.migrate_datasets()
cls.migrate_db_records("embeddings", "provider_name") # large table cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("dataset_collection_bindings", "provider_name") cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
cls.migrate_db_records("tool_builtin_providers", "provider") cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID)
@classmethod @classmethod
def migrate_datasets(cls) -> None: def migrate_datasets(cls) -> None:
@ -66,9 +66,10 @@ limit 1000"""
fg="white", fg="white",
) )
) )
retrieval_model["reranking_model"]["reranking_provider_name"] = ( # update google to langgenius/gemini/google etc.
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
) retrieval_model["reranking_model"]["reranking_provider_name"]
).to_string()
retrieval_model_changed = True retrieval_model_changed = True
click.echo( click.echo(
@ -86,9 +87,11 @@ limit 1000"""
update_retrieval_model_sql = ", retrieval_model = :retrieval_model" update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model) params["retrieval_model"] = json.dumps(retrieval_model)
params["provider_name"] = ModelProviderID(provider_name).to_string()
sql = f"""update {table_name} sql = f"""update {table_name}
set {provider_column_name} = set {provider_column_name} =
concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) :provider_name
{update_retrieval_model_sql} {update_retrieval_model_sql}
where id = :record_id""" where id = :record_id"""
conn.execute(db.text(sql), params) conn.execute(db.text(sql), params)
@ -122,7 +125,9 @@ limit 1000"""
) )
@classmethod @classmethod
def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0 processed_count = 0
@ -166,7 +171,8 @@ limit 1000"""
) )
try: try:
updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" # update jina to langgenius/jina_tool/jina etc.
updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id)) batch_updates.append((updated_value, record_id))
except Exception as e: except Exception as e:
failed_ids.append(record_id) failed_ids.append(record_id)

@ -17,11 +17,18 @@ from core.plugin.entities.plugin import (
PluginInstallation, PluginInstallation,
PluginInstallationSource, PluginInstallationSource,
) )
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginListResponse, PluginUploadResponse from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse,
PluginInstallTask,
PluginListResponse,
PluginVerification,
)
from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from services.errors.plugin import PluginInstallationForbiddenError
from services.feature_service import FeatureService, PluginInstallationScope
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -86,6 +93,42 @@ class PluginService:
logger.exception("failed to fetch latest plugin version") logger.exception("failed to fetch latest plugin version")
return result return result
@staticmethod
def _check_marketplace_only_permission():
"""
Check if the marketplace only permission is enabled
"""
features = FeatureService.get_system_features()
if features.plugin_installation_permission.restrict_to_marketplace_only:
raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
@staticmethod
def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
"""
Check the plugin installation scope
"""
features = FeatureService.get_system_features()
match features.plugin_installation_permission.plugin_installation_scope:
case PluginInstallationScope.OFFICIAL_ONLY:
if (
plugin_verification is None
or plugin_verification.authorized_category != PluginVerification.AuthorizedCategory.Langgenius
):
raise PluginInstallationForbiddenError("Plugin installation is restricted to official only")
case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS:
if plugin_verification is None or plugin_verification.authorized_category not in [
PluginVerification.AuthorizedCategory.Langgenius,
PluginVerification.AuthorizedCategory.Partner,
]:
raise PluginInstallationForbiddenError(
"Plugin installation is restricted to official and specific partners"
)
case PluginInstallationScope.NONE:
raise PluginInstallationForbiddenError("Installing plugins is not allowed")
case PluginInstallationScope.ALL:
pass
@staticmethod @staticmethod
def get_debugging_key(tenant_id: str) -> str: def get_debugging_key(tenant_id: str) -> str:
""" """
@ -208,6 +251,8 @@ class PluginService:
# check if plugin pkg is already downloaded # check if plugin pkg is already downloaded
manager = PluginInstaller() manager = PluginInstaller()
features = FeatureService.get_system_features()
try: try:
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier) manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
# already downloaded, skip, and record install event # already downloaded, skip, and record install event
@ -215,7 +260,14 @@ class PluginService:
except Exception: except Exception:
# plugin not installed, download and upload pkg # plugin not installed, download and upload pkg
pkg = download_plugin_pkg(new_plugin_unique_identifier) pkg = download_plugin_pkg(new_plugin_unique_identifier)
manager.upload_pkg(tenant_id, pkg, verify_signature=False) response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
return manager.upgrade_plugin( return manager.upgrade_plugin(
tenant_id, tenant_id,
@ -239,6 +291,7 @@ class PluginService:
""" """
Upgrade plugin with github Upgrade plugin with github
""" """
PluginService._check_marketplace_only_permission()
manager = PluginInstaller() manager = PluginInstaller()
return manager.upgrade_plugin( return manager.upgrade_plugin(
tenant_id, tenant_id,
@ -253,33 +306,43 @@ class PluginService:
) )
@staticmethod @staticmethod
def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse: def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
""" """
Upload plugin package files Upload plugin package files
returns: plugin_unique_identifier returns: plugin_unique_identifier
""" """
PluginService._check_marketplace_only_permission()
manager = PluginInstaller() manager = PluginInstaller()
return manager.upload_pkg(tenant_id, pkg, verify_signature) features = FeatureService.get_system_features()
response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
return response
@staticmethod @staticmethod
def upload_pkg_from_github( def upload_pkg_from_github(
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
) -> PluginUploadResponse: ) -> PluginDecodeResponse:
""" """
Install plugin from github release package files, Install plugin from github release package files,
returns plugin_unique_identifier returns plugin_unique_identifier
""" """
PluginService._check_marketplace_only_permission()
pkg = download_with_size_limit( pkg = download_with_size_limit(
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
) )
features = FeatureService.get_system_features()
manager = PluginInstaller() manager = PluginInstaller()
return manager.upload_pkg( response = manager.upload_pkg(
tenant_id, tenant_id,
pkg, pkg,
verify_signature, verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
) )
return response
@staticmethod @staticmethod
def upload_bundle( def upload_bundle(
@ -289,11 +352,15 @@ class PluginService:
Upload a plugin bundle and return the dependencies. Upload a plugin bundle and return the dependencies.
""" """
manager = PluginInstaller() manager = PluginInstaller()
PluginService._check_marketplace_only_permission()
return manager.upload_bundle(tenant_id, bundle, verify_signature) return manager.upload_bundle(tenant_id, bundle, verify_signature)
@staticmethod @staticmethod
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
PluginService._check_marketplace_only_permission()
manager = PluginInstaller() manager = PluginInstaller()
return manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,
plugin_unique_identifiers, plugin_unique_identifiers,
@ -307,6 +374,8 @@ class PluginService:
Install plugin from github release package files, Install plugin from github release package files,
returns plugin_unique_identifier returns plugin_unique_identifier
""" """
PluginService._check_marketplace_only_permission()
manager = PluginInstaller() manager = PluginInstaller()
return manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,
@ -322,28 +391,33 @@ class PluginService:
) )
@staticmethod @staticmethod
def fetch_marketplace_pkg( def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
tenant_id: str, plugin_unique_identifier: str, verify_signature: bool = False
) -> PluginDeclaration:
""" """
Fetch marketplace package Fetch marketplace package
""" """
if not dify_config.MARKETPLACE_ENABLED: if not dify_config.MARKETPLACE_ENABLED:
raise ValueError("marketplace is not enabled") raise ValueError("marketplace is not enabled")
features = FeatureService.get_system_features()
manager = PluginInstaller() manager = PluginInstaller()
try: try:
declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
except Exception: except Exception:
pkg = download_plugin_pkg(plugin_unique_identifier) pkg = download_plugin_pkg(plugin_unique_identifier)
declaration = manager.upload_pkg(tenant_id, pkg, verify_signature).manifest response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
declaration = response.manifest
return declaration return declaration
@staticmethod @staticmethod
def install_from_marketplace_pkg( def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False
):
""" """
Install plugin from marketplace package files, Install plugin from marketplace package files,
returns installation task id returns installation task id
@ -353,15 +427,26 @@ class PluginService:
manager = PluginInstaller() manager = PluginInstaller()
features = FeatureService.get_system_features()
# check if already downloaded # check if already downloaded
for plugin_unique_identifier in plugin_unique_identifiers: for plugin_unique_identifier in plugin_unique_identifiers:
try: try:
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
# already downloaded, skip # already downloaded, skip
except Exception: except Exception:
# plugin not installed, download and upload pkg # plugin not installed, download and upload pkg
pkg = download_plugin_pkg(plugin_unique_identifier) pkg = download_plugin_pkg(plugin_unique_identifier)
manager.upload_pkg(tenant_id, pkg, verify_signature) response = manager.upload_pkg(
tenant_id,
pkg,
verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
# check if the plugin is available to install
PluginService._check_plugin_installation_scope(response.verification)
return manager.install_from_identifiers( return manager.install_from_identifiers(
tenant_id, tenant_id,

@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from models import App, EndUser, WorkflowAppLog, WorkflowRun from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
@ -21,6 +21,8 @@ class WorkflowAppService:
created_at_after: datetime | None = None, created_at_after: datetime | None = None,
page: int = 1, page: int = 1,
limit: int = 20, limit: int = 20,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
) -> dict: ) -> dict:
""" """
Get paginate workflow app logs using SQLAlchemy 2.0 style Get paginate workflow app logs using SQLAlchemy 2.0 style
@ -32,6 +34,8 @@ class WorkflowAppService:
:param created_at_after: filter logs created after this timestamp :param created_at_after: filter logs created after this timestamp
:param page: page number :param page: page number
:param limit: items per page :param limit: items per page
:param created_by_end_user_session_id: filter by end user session id
:param created_by_account: filter by account email
:return: Pagination object :return: Pagination object
""" """
# Build base statement using SQLAlchemy 2.0 style # Build base statement using SQLAlchemy 2.0 style
@ -71,6 +75,26 @@ class WorkflowAppService:
if created_at_after: if created_at_after:
stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after) stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after)
# Filter by end user session id or account email
if created_by_end_user_session_id:
stmt = stmt.join(
EndUser,
and_(
WorkflowAppLog.created_by == EndUser.id,
WorkflowAppLog.created_by_role == CreatorUserRole.END_USER,
EndUser.session_id == created_by_end_user_session_id,
),
)
if created_by_account:
stmt = stmt.join(
Account,
and_(
WorkflowAppLog.created_by == Account.id,
WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT,
Account.email == created_by_account,
),
)
stmt = stmt.order_by(WorkflowAppLog.created_at.desc()) stmt = stmt.order_by(WorkflowAppLog.created_at.desc())
# Get total count using the same filters # Get total count using the same filters

@ -0,0 +1,25 @@
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneConfig, MatrixoneVector
from tests.integration_tests.vdb.test_vector_store import (
AbstractVectorTest,
get_example_text,
setup_mock_redis,
)
class MatrixoneVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = MatrixoneVector(
collection_name=self.collection_name,
config=MatrixoneConfig(
host="localhost", port=6001, user="dump", password="111", database="dify", metric="l2"
),
)
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1
def test_matrixone_vector(setup_mock_redis):
MatrixoneVectorTest().run_all_tests()

@ -1,49 +0,0 @@
import time
import pymysql
def check_oceanbase_ready() -> bool:
try:
connection = pymysql.connect(
host="localhost",
port=2881,
user="root",
password="difyai123456",
)
affected_rows = connection.query("SELECT 1")
return affected_rows == 1
except Exception as e:
print(f"Oceanbase is not ready. Exception: {e}")
return False
finally:
if connection:
connection.close()
def main():
max_attempts = 50
retry_interval_seconds = 2
is_oceanbase_ready = False
for attempt in range(max_attempts):
try:
is_oceanbase_ready = check_oceanbase_ready()
except Exception as e:
print(f"Oceanbase is not ready. Exception: {e}")
is_oceanbase_ready = False
if is_oceanbase_ready:
break
else:
print(f"Attempt {attempt + 1} failed, retry in {retry_interval_seconds} seconds...")
time.sleep(retry_interval_seconds)
if is_oceanbase_ready:
print("Oceanbase is ready.")
else:
print(f"Oceanbase is not ready after {max_attempts} attempting checks.")
exit(1)
if __name__ == "__main__":
main()

@ -0,0 +1,124 @@
import contextvars
import threading
from typing import Optional
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin, current_user, login_user
from libs.flask_utils import preserve_flask_contexts
class User(UserMixin):
"""Simple User class for testing."""
def __init__(self, id: str):
self.id = id
def get_id(self) -> str:
return self.id
@pytest.fixture
def login_app(app: Flask) -> Flask:
"""Set up a Flask app with flask-login."""
# Set a secret key for the app
app.config["SECRET_KEY"] = "test-secret-key"
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str) -> Optional[User]:
if user_id == "test_user":
return User("test_user")
return None
return app
@pytest.fixture
def test_user() -> User:
"""Create a test user."""
return User("test_user")
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
"""
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
This test demonstrates that without the preserve_flask_contexts, we cannot access
current_user in a different thread, even with app_context.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Store the result of the thread execution
result = {"user_accessible": True, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread():
try:
# Try to access current_user in a different thread with app_context
with login_app.app_context():
# This should fail because current_user is not accessible across threads
# without preserve_flask_contexts
result["user_accessible"] = current_user.is_authenticated
except Exception as e:
result["error"] = str(e) # type: ignore
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread)
thread.start()
thread.join()
# Verify that we got an error or current_user is not authenticated
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
"""
Test that current_user is accessible in a different thread with preserve_flask_contexts.
This test demonstrates that with the preserve_flask_contexts, we can access
current_user in a different thread.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Save the context variables
context_vars = contextvars.copy_context()
# Store the result of the thread execution
result = {"user_accessible": False, "user_id": None, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread_with_manager():
try:
# Use preserve_flask_contexts to access current_user in a different thread
with preserve_flask_contexts(login_app, context_vars):
from flask_login import current_user
if current_user:
result["user_accessible"] = True
result["user_id"] = current_user.id
else:
result["user_accessible"] = False
except Exception as e:
result["error"] = str(e) # type: ignore
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread_with_manager)
thread.start()
thread.join()
# Verify that current_user is accessible and has the correct ID
assert result["error"] is None
assert result["user_accessible"] is True
assert result["user_id"] == "test_user"

File diff suppressed because it is too large Load Diff

@ -399,7 +399,7 @@ SUPABASE_URL=your-server-url
# ------------------------------ # ------------------------------
# The type of vector store to use. # The type of vector store to use.
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`. # Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`. # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@ -490,6 +490,13 @@ TIDB_VECTOR_USER=
TIDB_VECTOR_PASSWORD= TIDB_VECTOR_PASSWORD=
TIDB_VECTOR_DATABASE=dify TIDB_VECTOR_DATABASE=dify
# Matrixone vector configurations.
MATRIXONE_HOST=matrixone
MATRIXONE_PORT=6001
MATRIXONE_USER=dump
MATRIXONE_PASSWORD=111
MATRIXONE_DATABASE=dify
# Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant` # Tidb on qdrant configuration, only available when VECTOR_STORE is `tidb_on_qdrant`
TIDB_ON_QDRANT_URL=http://127.0.0.1 TIDB_ON_QDRANT_URL=http://127.0.0.1
TIDB_ON_QDRANT_API_KEY=dify TIDB_ON_QDRANT_API_KEY=dify
@ -719,10 +726,11 @@ NOTION_INTERNAL_SECRET=
# Mail related configuration # Mail related configuration
# ------------------------------ # ------------------------------
# Mail type, support: resend, smtp # Mail type, support: resend, smtp, sendgrid
MAIL_TYPE=resend MAIL_TYPE=resend
# Default send from email address, if not specified # Default send from email address, if not specified
# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM= MAIL_DEFAULT_SEND_FROM=
# API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`.
@ -738,6 +746,9 @@ SMTP_PASSWORD=
SMTP_USE_TLS=true SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false SMTP_OPPORTUNISTIC_TLS=false
# Sendgid configuration
SENDGRID_API_KEY=
# ------------------------------ # ------------------------------
# Others Configuration # Others Configuration
# ------------------------------ # ------------------------------
@ -815,7 +826,8 @@ TEXT_GENERATION_TIMEOUT_MS=60000
# Environment Variables for db Service # Environment Variables for db Service
# ------------------------------ # ------------------------------
PGUSER=${DB_USERNAME} # The name of the default postgres user.
POSTGRES_USER=${DB_USERNAME}
# The password for the default postgres user. # The password for the default postgres user.
POSTGRES_PASSWORD=${DB_PASSWORD} POSTGRES_PASSWORD=${DB_PASSWORD}
# The name of the default postgres database. # The name of the default postgres database.
@ -1067,7 +1079,7 @@ PLUGIN_MEDIA_CACHE_PATH=assets
# Plugin oss bucket # Plugin oss bucket
PLUGIN_STORAGE_OSS_BUCKET= PLUGIN_STORAGE_OSS_BUCKET=
# Plugin oss s3 credentials # Plugin oss s3 credentials
PLUGIN_S3_USE_AWS= PLUGIN_S3_USE_AWS=false
PLUGIN_S3_USE_AWS_MANAGED_IAM=false PLUGIN_S3_USE_AWS_MANAGED_IAM=false
PLUGIN_S3_ENDPOINT= PLUGIN_S3_ENDPOINT=
PLUGIN_S3_USE_PATH_STYLE=false PLUGIN_S3_USE_PATH_STYLE=false

@ -84,7 +84,7 @@ services:
image: postgres:15-alpine image: postgres:15-alpine
restart: always restart: always
environment: environment:
PGUSER: ${PGUSER:-postgres} POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify} POSTGRES_DB: ${POSTGRES_DB:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
@ -451,6 +451,14 @@ services:
OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
OB_SERVER_IP: 127.0.0.1 OB_SERVER_IP: 127.0.0.1
MODE: mini MODE: mini
ports:
- "${OCEANBASE_VECTOR_PORT:-2881}:2881"
healthcheck:
test: [ 'CMD-SHELL', 'obclient -h127.0.0.1 -P2881 -uroot@test -p$${OB_TENANT_PASSWORD} -e "SELECT 1;"' ]
interval: 10s
retries: 30
start_period: 30s
timeout: 10s
# Oracle vector database # Oracle vector database
oracle: oracle:
@ -609,6 +617,18 @@ services:
ports: ports:
- ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123} - ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}
# Matrixone vector store.
matrixone:
hostname: matrixone
image: matrixorigin/matrixone:2.1.1
profiles:
- matrixone
restart: always
volumes:
- ./volumes/matrixone/data:/mo-data
ports:
- ${MATRIXONE_PORT:-6001}:${MATRIXONE_PORT:-6001}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html
# https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites
elasticsearch: elasticsearch:

@ -104,7 +104,7 @@ services:
PLUGIN_PACKAGE_CACHE_PATH: ${PLUGIN_PACKAGE_CACHE_PATH:-plugin_packages} PLUGIN_PACKAGE_CACHE_PATH: ${PLUGIN_PACKAGE_CACHE_PATH:-plugin_packages}
PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets} PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets}
PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-} PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-}
S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-} S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-false}
S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false} S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false}
S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-} S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-}
S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false} S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false}

@ -195,6 +195,11 @@ x-shared-env: &shared-api-worker-env
TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-} TIDB_VECTOR_USER: ${TIDB_VECTOR_USER:-}
TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-} TIDB_VECTOR_PASSWORD: ${TIDB_VECTOR_PASSWORD:-}
TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify} TIDB_VECTOR_DATABASE: ${TIDB_VECTOR_DATABASE:-dify}
MATRIXONE_HOST: ${MATRIXONE_HOST:-matrixone}
MATRIXONE_PORT: ${MATRIXONE_PORT:-6001}
MATRIXONE_USER: ${MATRIXONE_USER:-dump}
MATRIXONE_PASSWORD: ${MATRIXONE_PASSWORD:-111}
MATRIXONE_DATABASE: ${MATRIXONE_DATABASE:-dify}
TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1} TIDB_ON_QDRANT_URL: ${TIDB_ON_QDRANT_URL:-http://127.0.0.1}
TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify} TIDB_ON_QDRANT_API_KEY: ${TIDB_ON_QDRANT_API_KEY:-dify}
TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20} TIDB_ON_QDRANT_CLIENT_TIMEOUT: ${TIDB_ON_QDRANT_CLIENT_TIMEOUT:-20}
@ -322,6 +327,7 @@ x-shared-env: &shared-api-worker-env
SMTP_PASSWORD: ${SMTP_PASSWORD:-} SMTP_PASSWORD: ${SMTP_PASSWORD:-}
SMTP_USE_TLS: ${SMTP_USE_TLS:-true} SMTP_USE_TLS: ${SMTP_USE_TLS:-true}
SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false}
SENDGRID_API_KEY: ${SENDGRID_API_KEY:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72}
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5}
@ -356,7 +362,7 @@ x-shared-env: &shared-api-worker-env
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99} MAX_ITERATIONS_NUM: ${MAX_ITERATIONS_NUM:-99}
TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000}
PGUSER: ${PGUSER:-${DB_USERNAME}} POSTGRES_USER: ${POSTGRES_USER:-${DB_USERNAME}}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-${DB_PASSWORD}}
POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}} POSTGRES_DB: ${POSTGRES_DB:-${DB_DATABASE}}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
@ -467,7 +473,7 @@ x-shared-env: &shared-api-worker-env
PLUGIN_PACKAGE_CACHE_PATH: ${PLUGIN_PACKAGE_CACHE_PATH:-plugin_packages} PLUGIN_PACKAGE_CACHE_PATH: ${PLUGIN_PACKAGE_CACHE_PATH:-plugin_packages}
PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets} PLUGIN_MEDIA_CACHE_PATH: ${PLUGIN_MEDIA_CACHE_PATH:-assets}
PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-} PLUGIN_STORAGE_OSS_BUCKET: ${PLUGIN_STORAGE_OSS_BUCKET:-}
PLUGIN_S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-} PLUGIN_S3_USE_AWS: ${PLUGIN_S3_USE_AWS:-false}
PLUGIN_S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false} PLUGIN_S3_USE_AWS_MANAGED_IAM: ${PLUGIN_S3_USE_AWS_MANAGED_IAM:-false}
PLUGIN_S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-} PLUGIN_S3_ENDPOINT: ${PLUGIN_S3_ENDPOINT:-}
PLUGIN_S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false} PLUGIN_S3_USE_PATH_STYLE: ${PLUGIN_S3_USE_PATH_STYLE:-false}
@ -591,7 +597,7 @@ services:
image: postgres:15-alpine image: postgres:15-alpine
restart: always restart: always
environment: environment:
PGUSER: ${PGUSER:-postgres} POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456} POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-difyai123456}
POSTGRES_DB: ${POSTGRES_DB:-dify} POSTGRES_DB: ${POSTGRES_DB:-dify}
PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata} PGDATA: ${PGDATA:-/var/lib/postgresql/data/pgdata}
@ -958,6 +964,14 @@ services:
OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
OB_SERVER_IP: 127.0.0.1 OB_SERVER_IP: 127.0.0.1
MODE: mini MODE: mini
ports:
- "${OCEANBASE_VECTOR_PORT:-2881}:2881"
healthcheck:
test: [ 'CMD-SHELL', 'obclient -h127.0.0.1 -P2881 -uroot@test -p$${OB_TENANT_PASSWORD} -e "SELECT 1;"' ]
interval: 10s
retries: 30
start_period: 30s
timeout: 10s
# Oracle vector database # Oracle vector database
oracle: oracle:
@ -1116,6 +1130,18 @@ services:
ports: ports:
- ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123} - ${MYSCALE_PORT:-8123}:${MYSCALE_PORT:-8123}
# Matrixone vector store.
matrixone:
hostname: matrixone
image: matrixorigin/matrixone:2.1.1
profiles:
- matrixone
restart: always
volumes:
- ./volumes/matrixone/data:/mo-data
ports:
- ${MATRIXONE_PORT:-6001}:${MATRIXONE_PORT:-6001}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html # https://www.elastic.co/guide/en/elasticsearch/reference/current/settings.html
# https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites # https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-prod-prerequisites
elasticsearch: elasticsearch:

@ -1,7 +1,7 @@
# ------------------------------ # ------------------------------
# Environment Variables for db Service # Environment Variables for db Service
# ------------------------------ # ------------------------------
PGUSER=postgres POSTGRES_USER=postgres
# The password for the default postgres user. # The password for the default postgres user.
POSTGRES_PASSWORD=difyai123456 POSTGRES_PASSWORD=difyai123456
# The name of the default postgres database. # The name of the default postgres database.
@ -133,7 +133,7 @@ PLUGIN_MEDIA_CACHE_PATH=assets
PLUGIN_STORAGE_OSS_BUCKET= PLUGIN_STORAGE_OSS_BUCKET=
# Plugin oss s3 credentials # Plugin oss s3 credentials
PLUGIN_S3_USE_AWS_MANAGED_IAM=false PLUGIN_S3_USE_AWS_MANAGED_IAM=false
PLUGIN_S3_USE_AWS= PLUGIN_S3_USE_AWS=false
PLUGIN_S3_ENDPOINT= PLUGIN_S3_ENDPOINT=
PLUGIN_S3_USE_PATH_STYLE=false PLUGIN_S3_USE_PATH_STYLE=false
PLUGIN_AWS_ACCESS_KEY= PLUGIN_AWS_ACCESS_KEY=

@ -15,7 +15,7 @@ const Overview = async (props: IDevelopProps) => {
} = params } = params
return ( return (
<div className="h-full overflow-scroll bg-chatbot-bg px-4 py-6 sm:px-12"> <div className="h-full overflow-y-auto bg-chatbot-bg px-4 py-6 sm:px-12">
<ApikeyInfoPanel /> <ApikeyInfoPanel />
<ChartView <ChartView
appId={appId} appId={appId}

@ -9,6 +9,7 @@ import { useTranslation } from 'react-i18next'
import { useDebounceFn } from 'ahooks' import { useDebounceFn } from 'ahooks'
import { import {
RiApps2Line, RiApps2Line,
RiDragDropLine,
RiExchange2Line, RiExchange2Line,
RiFile4Line, RiFile4Line,
RiMessage3Line, RiMessage3Line,
@ -16,7 +17,8 @@ import {
} from '@remixicon/react' } from '@remixicon/react'
import AppCard from './AppCard' import AppCard from './AppCard'
import NewAppCard from './NewAppCard' import NewAppCard from './NewAppCard'
import useAppsQueryState from './hooks/useAppsQueryState' import useAppsQueryState from './hooks/use-apps-query-state'
import { useDSLDragDrop } from './hooks/use-dsl-drag-drop'
import type { AppListResponse } from '@/models/app' import type { AppListResponse } from '@/models/app'
import { fetchAppList } from '@/service/apps' import { fetchAppList } from '@/service/apps'
import { useAppContext } from '@/context/app-context' import { useAppContext } from '@/context/app-context'
@ -29,6 +31,7 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st
import TagManagementModal from '@/app/components/base/tag-management' import TagManagementModal from '@/app/components/base/tag-management'
import TagFilter from '@/app/components/base/tag-management/filter' import TagFilter from '@/app/components/base/tag-management/filter'
import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label'
import CreateFromDSLModal from '@/app/components/app/create-from-dsl-modal'
const getKey = ( const getKey = (
pageIndex: number, pageIndex: number,
@ -67,6 +70,9 @@ const Apps = () => {
const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs) const [tagFilterValue, setTagFilterValue] = useState<string[]>(tagIDs)
const [searchKeywords, setSearchKeywords] = useState(keywords) const [searchKeywords, setSearchKeywords] = useState(keywords)
const newAppCardRef = useRef<HTMLDivElement>(null) const newAppCardRef = useRef<HTMLDivElement>(null)
const containerRef = useRef<HTMLDivElement>(null)
const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false)
const [droppedDSLFile, setDroppedDSLFile] = useState<File | undefined>()
const setKeywords = useCallback((keywords: string) => { const setKeywords = useCallback((keywords: string) => {
setQuery(prev => ({ ...prev, keywords })) setQuery(prev => ({ ...prev, keywords }))
}, [setQuery]) }, [setQuery])
@ -74,6 +80,17 @@ const Apps = () => {
setQuery(prev => ({ ...prev, tagIDs })) setQuery(prev => ({ ...prev, tagIDs }))
}, [setQuery]) }, [setQuery])
const handleDSLFileDropped = useCallback((file: File) => {
setDroppedDSLFile(file)
setShowCreateFromDSLModal(true)
}, [])
const { dragging } = useDSLDragDrop({
onDSLFileDropped: handleDSLFileDropped,
containerRef,
enabled: isCurrentWorkspaceEditor,
})
const { data, isLoading, error, setSize, mutate } = useSWRInfinite( const { data, isLoading, error, setSize, mutate } = useSWRInfinite(
(pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords), (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, activeTab, isCreatedByMe, tagIDs, searchKeywords),
fetchAppList, fetchAppList,
@ -151,6 +168,12 @@ const Apps = () => {
return ( return (
<> <>
<div ref={containerRef} className='relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body'>
{dragging && (
<div className="absolute inset-0 z-50 m-0.5 rounded-2xl border-2 border-dashed border-components-dropzone-border-accent bg-[rgba(21,90,239,0.14)] p-2">
</div>
)}
<div className='sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'> <div className='sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
<TabSliderNew <TabSliderNew
value={activeTab} value={activeTab}
@ -188,11 +211,39 @@ const Apps = () => {
&& <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={mutate} />} && <NewAppCard ref={newAppCardRef} className='z-10' onSuccess={mutate} />}
<NoAppsFound /> <NoAppsFound />
</div>} </div>}
{isCurrentWorkspaceEditor && (
<div
className={`flex items-center justify-center gap-2 py-4 ${dragging ? 'text-text-accent' : 'text-text-quaternary'}`}
role="region"
aria-label={t('app.newApp.dropDSLToCreateApp')}
>
<RiDragDropLine className="h-4 w-4" />
<span className="system-xs-regular">{t('app.newApp.dropDSLToCreateApp')}</span>
</div>
)}
<CheckModal /> <CheckModal />
<div ref={anchorRef} className='h-0'> </div> <div ref={anchorRef} className='h-0'> </div>
{showTagManagementModal && ( {showTagManagementModal && (
<TagManagementModal type='app' show={showTagManagementModal} /> <TagManagementModal type='app' show={showTagManagementModal} />
)} )}
</div>
{showCreateFromDSLModal && (
<CreateFromDSLModal
show={showCreateFromDSLModal}
onClose={() => {
setShowCreateFromDSLModal(false)
setDroppedDSLFile(undefined)
}}
onSuccess={() => {
setShowCreateFromDSLModal(false)
setDroppedDSLFile(undefined)
mutate()
}}
droppedFile={droppedDSLFile}
/>
)}
</> </>
) )
} }

@ -0,0 +1,72 @@
import { useEffect, useState } from 'react'
type DSLDragDropHookProps = {
onDSLFileDropped: (file: File) => void
containerRef: React.RefObject<HTMLDivElement>
enabled?: boolean
}
export const useDSLDragDrop = ({ onDSLFileDropped, containerRef, enabled = true }: DSLDragDropHookProps) => {
const [dragging, setDragging] = useState(false)
const handleDragEnter = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
if (e.dataTransfer?.types.includes('Files'))
setDragging(true)
}
const handleDragOver = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
}
const handleDragLeave = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
if (e.relatedTarget === null || !containerRef.current?.contains(e.relatedTarget as Node))
setDragging(false)
}
const handleDrop = (e: DragEvent) => {
e.preventDefault()
e.stopPropagation()
setDragging(false)
if (!e.dataTransfer)
return
const files = [...e.dataTransfer.files]
if (files.length === 0)
return
const file = files[0]
if (file.name.toLowerCase().endsWith('.yaml') || file.name.toLowerCase().endsWith('.yml'))
onDSLFileDropped(file)
}
useEffect(() => {
if (!enabled)
return
const current = containerRef.current
if (current) {
current.addEventListener('dragenter', handleDragEnter)
current.addEventListener('dragover', handleDragOver)
current.addEventListener('dragleave', handleDragLeave)
current.addEventListener('drop', handleDrop)
}
return () => {
if (current) {
current.removeEventListener('dragenter', handleDragEnter)
current.removeEventListener('dragover', handleDragOver)
current.removeEventListener('dragleave', handleDragLeave)
current.removeEventListener('drop', handleDrop)
}
}
}, [containerRef, enabled])
return {
dragging: enabled ? dragging : false,
}
}

@ -8,15 +8,15 @@ import { useRouter } from 'next/navigation'
import { useEffect } from 'react' import { useEffect } from 'react'
export default function DatasetsLayout({ children }: { children: React.ReactNode }) { export default function DatasetsLayout({ children }: { children: React.ReactNode }) {
const { isCurrentWorkspaceEditor } = useAppContext() const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext()
const router = useRouter() const router = useRouter()
useEffect(() => { useEffect(() => {
if (!isCurrentWorkspaceEditor) if (!isCurrentWorkspaceEditor && !isCurrentWorkspaceDatasetOperator)
router.replace('/apps') router.replace('/apps')
}, [isCurrentWorkspaceEditor, router]) }, [isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, router])
if (!isCurrentWorkspaceEditor) if (!isCurrentWorkspaceEditor && !isCurrentWorkspaceDatasetOperator)
return <Loading type='app' /> return <Loading type='app' />
return ( return (
<ExternalKnowledgeApiProvider> <ExternalKnowledgeApiProvider>

@ -54,7 +54,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property> </Property>
<Property name='indexing_technique' type='string' key='indexing_technique'> <Property name='indexing_technique' type='string' key='indexing_technique'>
Index mode Index mode
- <code>high_quality</code> High quality: embedding using embedding model, built as vector database index - <code>high_quality</code> High quality: Embedding using embedding model, built as vector database index
- <code>economy</code> Economy: Build using inverted index of keyword table index - <code>economy</code> Economy: Build using inverted index of keyword table index
</Property> </Property>
<Property name='doc_form' type='string' key='doc_form'> <Property name='doc_form' type='string' key='doc_form'>

@ -55,7 +55,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='indexing_technique' type='string' key='indexing_technique'> <Property name='indexing_technique' type='string' key='indexing_technique'>
索引方式 索引方式
- <code>high_quality</code> 高质量:使用 - <code>high_quality</code> 高质量:使用
ding 模型进行嵌入,构建为向量数据库索引 Embedding 模型进行嵌入,构建为向量数据库索引
- <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建 - <code>economy</code> 经济:使用 keyword table index 的倒排索引进行构建
</Property> </Property>
<Property name='doc_form' type='string' key='doc_form'> <Property name='doc_form' type='string' key='doc_form'>

@ -314,10 +314,10 @@ const AppPublisher = ({
{!isAppAccessSet && <p className='system-xs-regular mt-1 text-text-warning'>{t('app.publishApp.notSetDesc')}</p>} {!isAppAccessSet && <p className='system-xs-regular mt-1 text-text-warning'>{t('app.publishApp.notSetDesc')}</p>}
</div>} </div>}
<div className='flex flex-col gap-y-1 border-t-[0.5px] border-t-divider-regular p-4 pt-3'> <div className='flex flex-col gap-y-1 border-t-[0.5px] border-t-divider-regular p-4 pt-3'>
<Tooltip triggerClassName='flex' disabled={!systemFeatures.webapp_auth.enabled || userCanAccessApp?.result} popupContent={t('app.noAccessPermission')} asChild={false}> <Tooltip triggerClassName='flex' disabled={!systemFeatures.webapp_auth.enabled || appDetail?.access_mode === AccessMode.EXTERNAL_MEMBERS || userCanAccessApp?.result} popupContent={t('app.noAccessPermission')} asChild={false}>
<SuggestedAction <SuggestedAction
className='flex-1' className='flex-1'
disabled={!publishedAt || (systemFeatures.webapp_auth.enabled && !userCanAccessApp?.result)} disabled={!publishedAt || (systemFeatures.webapp_auth.enabled && appDetail?.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result)}
link={appURL} link={appURL}
icon={<RiPlayCircleLine className='h-4 w-4' />} icon={<RiPlayCircleLine className='h-4 w-4' />}
> >
@ -326,10 +326,10 @@ const AppPublisher = ({
</Tooltip> </Tooltip>
{appDetail?.mode === 'workflow' || appDetail?.mode === 'completion' {appDetail?.mode === 'workflow' || appDetail?.mode === 'completion'
? ( ? (
<Tooltip triggerClassName='flex' disabled={!systemFeatures.webapp_auth.enabled || userCanAccessApp?.result} popupContent={t('app.noAccessPermission')} asChild={false}> <Tooltip triggerClassName='flex' disabled={!systemFeatures.webapp_auth.enabled || appDetail.access_mode === AccessMode.EXTERNAL_MEMBERS || userCanAccessApp?.result} popupContent={t('app.noAccessPermission')} asChild={false}>
<SuggestedAction <SuggestedAction
className='flex-1' className='flex-1'
disabled={!publishedAt || (systemFeatures.webapp_auth.enabled && !userCanAccessApp?.result)} disabled={!publishedAt || (systemFeatures.webapp_auth.enabled && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result)}
link={`${appURL}${appURL.includes('?') ? '&' : '?'}mode=batch`} link={`${appURL}${appURL.includes('?') ? '&' : '?'}mode=batch`}
icon={<RiPlayList2Line className='h-4 w-4' />} icon={<RiPlayList2Line className='h-4 w-4' />}
> >

@ -156,12 +156,11 @@ const Debug: FC<IDebug> = ({
} }
let hasEmptyInput = '' let hasEmptyInput = ''
const requiredVars = modelConfig.configs.prompt_variables.filter(({ key, name, required, type }) => { const requiredVars = modelConfig.configs.prompt_variables.filter(({ key, name, required, type }) => {
if (type !== 'string' && type !== 'paragraph' && type !== 'select') if (type !== 'string' && type !== 'paragraph' && type !== 'select' && type !== 'number')
return false return false
const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null) const res = (!key || !key.trim()) || (!name || !name.trim()) || (required || required === undefined || required === null)
return res return res
}) // compatible with old version }) // compatible with old version
// debugger
requiredVars.forEach(({ key, name }) => { requiredVars.forEach(({ key, name }) => {
if (hasEmptyInput) if (hasEmptyInput)
return return

@ -20,6 +20,7 @@ import type {
import { useToastContext } from '@/app/components/base/toast' import { useToastContext } from '@/app/components/base/toast'
import AppIcon from '@/app/components/base/app-icon' import AppIcon from '@/app/components/base/app-icon'
import { noop } from 'lodash-es' import { noop } from 'lodash-es'
import { useDocLink } from '@/context/i18n'
const systemTypes = ['api'] const systemTypes = ['api']
type ExternalDataToolModalProps = { type ExternalDataToolModalProps = {
@ -40,6 +41,7 @@ const ExternalDataToolModal: FC<ExternalDataToolModalProps> = ({
onValidateBeforeSave, onValidateBeforeSave,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const docLink = useDocLink()
const { notify } = useToastContext() const { notify } = useToastContext()
const { locale } = useContext(I18n) const { locale } = useContext(I18n)
const [localeData, setLocaleData] = useState(data.type ? data : { ...data, type: 'api' }) const [localeData, setLocaleData] = useState(data.type ? data : { ...data, type: 'api' })
@ -243,7 +245,7 @@ const ExternalDataToolModal: FC<ExternalDataToolModalProps> = ({
<div className='flex h-9 items-center justify-between text-sm font-medium text-gray-900'> <div className='flex h-9 items-center justify-between text-sm font-medium text-gray-900'>
{t('common.apiBasedExtension.selector.title')} {t('common.apiBasedExtension.selector.title')}
<a <a
href={t('common.apiBasedExtension.linkUrl') || '/'} href={docLink('/guides/extension/api-based-extension/README')}
target='_blank' rel='noopener noreferrer' target='_blank' rel='noopener noreferrer'
className='group flex items-center text-xs font-normal text-gray-500 hover:text-primary-600' className='group flex items-center text-xs font-normal text-gray-500 hover:text-primary-600'
> >

@ -1,7 +1,7 @@
'use client' 'use client'
import type { MouseEventHandler } from 'react' import type { MouseEventHandler } from 'react'
import { useMemo, useRef, useState } from 'react' import { useEffect, useMemo, useRef, useState } from 'react'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
@ -35,6 +35,7 @@ type CreateFromDSLModalProps = {
onClose: () => void onClose: () => void
activeTab?: string activeTab?: string
dslUrl?: string dslUrl?: string
droppedFile?: File
} }
export enum CreateFromDSLModalTab { export enum CreateFromDSLModalTab {
@ -42,11 +43,11 @@ export enum CreateFromDSLModalTab {
FROM_URL = 'from-url', FROM_URL = 'from-url',
} }
const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '' }: CreateFromDSLModalProps) => { const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDSLModalTab.FROM_FILE, dslUrl = '', droppedFile }: CreateFromDSLModalProps) => {
const { push } = useRouter() const { push } = useRouter()
const { t } = useTranslation() const { t } = useTranslation()
const { notify } = useContext(ToastContext) const { notify } = useContext(ToastContext)
const [currentFile, setDSLFile] = useState<File>() const [currentFile, setDSLFile] = useState<File | undefined>(droppedFile)
const [fileContent, setFileContent] = useState<string>() const [fileContent, setFileContent] = useState<string>()
const [currentTab, setCurrentTab] = useState(activeTab) const [currentTab, setCurrentTab] = useState(activeTab)
const [dslUrlValue, setDslUrlValue] = useState(dslUrl) const [dslUrlValue, setDslUrlValue] = useState(dslUrl)
@ -78,6 +79,11 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
const isCreatingRef = useRef(false) const isCreatingRef = useRef(false)
useEffect(() => {
if (droppedFile)
handleFile(droppedFile)
}, [droppedFile])
const onCreate: MouseEventHandler = async () => { const onCreate: MouseEventHandler = async () => {
if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile)
return return

@ -50,6 +50,10 @@ const OPTION_MAP = {
// user_id: 'YOU CAN DEFINE USER ID HERE', // user_id: 'YOU CAN DEFINE USER ID HERE',
// conversation_id: 'YOU CAN DEFINE CONVERSATION ID HERE, IT MUST BE A VALID UUID', // conversation_id: 'YOU CAN DEFINE CONVERSATION ID HERE, IT MUST BE A VALID UUID',
}, },
userVariables: {
// avatar_url: 'YOU CAN DEFINE USER AVATAR URL HERE',
// name: 'YOU CAN DEFINE USER NAME HERE',
},
} }
</script> </script>
<script <script

@ -25,6 +25,7 @@ import SuggestedQuestions from '@/app/components/base/chat/chat/answer/suggested
import { Markdown } from '@/app/components/base/markdown' import { Markdown } from '@/app/components/base/markdown'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import type { FileEntity } from '../../file-uploader/types' import type { FileEntity } from '../../file-uploader/types'
import Avatar from '../../avatar'
const ChatWrapper = () => { const ChatWrapper = () => {
const { const {
@ -49,6 +50,7 @@ const ChatWrapper = () => {
setClearChatList, setClearChatList,
setIsResponding, setIsResponding,
allInputsHidden, allInputsHidden,
initUserVariables,
} = useEmbeddedChatbotContext() } = useEmbeddedChatbotContext()
const appConfig = useMemo(() => { const appConfig = useMemo(() => {
const config = appParams || {} const config = appParams || {}
@ -261,6 +263,14 @@ const ChatWrapper = () => {
switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)} switchSibling={siblingMessageId => setTargetMessageId(siblingMessageId)}
inputDisabled={inputDisabled} inputDisabled={inputDisabled}
isMobile={isMobile} isMobile={isMobile}
questionIcon={
initUserVariables?.avatar_url
? <Avatar
avatar={initUserVariables.avatar_url}
name={initUserVariables.name || 'user'}
size={40}
/> : undefined
}
/> />
) )
} }

@ -52,6 +52,10 @@ export type EmbeddedChatbotContextValue = {
currentConversationInputs: Record<string, any> | null, currentConversationInputs: Record<string, any> | null,
setCurrentConversationInputs: (v: Record<string, any>) => void, setCurrentConversationInputs: (v: Record<string, any>) => void,
allInputsHidden: boolean allInputsHidden: boolean
initUserVariables?: {
name?: string
avatar_url?: string
}
} }
export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({ export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>({
@ -81,5 +85,6 @@ export const EmbeddedChatbotContext = createContext<EmbeddedChatbotContextValue>
currentConversationInputs: {}, currentConversationInputs: {},
setCurrentConversationInputs: noop, setCurrentConversationInputs: noop,
allInputsHidden: false, allInputsHidden: false,
initUserVariables: {},
}) })
export const useEmbeddedChatbotContext = () => useContext(EmbeddedChatbotContext) export const useEmbeddedChatbotContext = () => useContext(EmbeddedChatbotContext)

@ -15,7 +15,7 @@ import type {
Feedback, Feedback,
} from '../types' } from '../types'
import { CONVERSATION_ID_INFO } from '../constants' import { CONVERSATION_ID_INFO } from '../constants'
import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams } from '../utils' import { buildChatItemTree, getProcessedInputsFromUrlParams, getProcessedSystemVariablesFromUrlParams, getProcessedUserVariablesFromUrlParams } from '../utils'
import { getProcessedFilesFromResponse } from '../../file-uploader/utils' import { getProcessedFilesFromResponse } from '../../file-uploader/utils'
import { import {
fetchAppInfo, fetchAppInfo,
@ -169,6 +169,7 @@ export const useEmbeddedChatbot = () => {
const newConversationInputsRef = useRef<Record<string, any>>({}) const newConversationInputsRef = useRef<Record<string, any>>({})
const [newConversationInputs, setNewConversationInputs] = useState<Record<string, any>>({}) const [newConversationInputs, setNewConversationInputs] = useState<Record<string, any>>({})
const [initInputs, setInitInputs] = useState<Record<string, any>>({}) const [initInputs, setInitInputs] = useState<Record<string, any>>({})
const [initUserVariables, setInitUserVariables] = useState<Record<string, any>>({})
const handleNewConversationInputsChange = useCallback((newInputs: Record<string, any>) => { const handleNewConversationInputsChange = useCallback((newInputs: Record<string, any>) => {
newConversationInputsRef.current = newInputs newConversationInputsRef.current = newInputs
setNewConversationInputs(newInputs) setNewConversationInputs(newInputs)
@ -237,7 +238,9 @@ export const useEmbeddedChatbot = () => {
// init inputs from url params // init inputs from url params
(async () => { (async () => {
const inputs = await getProcessedInputsFromUrlParams() const inputs = await getProcessedInputsFromUrlParams()
const userVariables = await getProcessedUserVariablesFromUrlParams()
setInitInputs(inputs) setInitInputs(inputs)
setInitUserVariables(userVariables)
})() })()
}, []) }, [])
useEffect(() => { useEffect(() => {
@ -418,5 +421,6 @@ export const useEmbeddedChatbot = () => {
currentConversationInputs, currentConversationInputs,
setCurrentConversationInputs, setCurrentConversationInputs,
allInputsHidden, allInputsHidden,
initUserVariables,
} }
} }

@ -195,6 +195,7 @@ const EmbeddedChatbotWrapper = () => {
currentConversationInputs, currentConversationInputs,
setCurrentConversationInputs, setCurrentConversationInputs,
allInputsHidden, allInputsHidden,
initUserVariables,
} = useEmbeddedChatbot() } = useEmbeddedChatbot()
return <EmbeddedChatbotContext.Provider value={{ return <EmbeddedChatbotContext.Provider value={{
@ -233,6 +234,7 @@ const EmbeddedChatbotWrapper = () => {
currentConversationInputs, currentConversationInputs,
setCurrentConversationInputs, setCurrentConversationInputs,
allInputsHidden, allInputsHidden,
initUserVariables,
}}> }}>
<Chatbot /> <Chatbot />
</EmbeddedChatbotContext.Provider> </EmbeddedChatbotContext.Provider>

@ -32,7 +32,8 @@ async function getProcessedInputsFromUrlParams(): Promise<Record<string, any>> {
const entriesArray = Array.from(urlParams.entries()) const entriesArray = Array.from(urlParams.entries())
await Promise.all( await Promise.all(
entriesArray.map(async ([key, value]) => { entriesArray.map(async ([key, value]) => {
if (!key.startsWith('sys.')) const prefixArray = ['sys.', 'user.']
if (!prefixArray.some(prefix => key.startsWith(prefix)))
inputs[key] = await decodeBase64AndDecompress(decodeURIComponent(value)) inputs[key] = await decodeBase64AndDecompress(decodeURIComponent(value))
}), }),
) )
@ -52,6 +53,19 @@ async function getProcessedSystemVariablesFromUrlParams(): Promise<Record<string
return systemVariables return systemVariables
} }
async function getProcessedUserVariablesFromUrlParams(): Promise<Record<string, any>> {
const urlParams = new URLSearchParams(window.location.search)
const userVariables: Record<string, any> = {}
const entriesArray = Array.from(urlParams.entries())
await Promise.all(
entriesArray.map(async ([key, value]) => {
if (key.startsWith('user.'))
userVariables[key.slice(5)] = await decodeBase64AndDecompress(decodeURIComponent(value))
}),
)
return userVariables
}
function isValidGeneratedAnswer(item?: ChatItem | ChatItemInTree): boolean { function isValidGeneratedAnswer(item?: ChatItem | ChatItemInTree): boolean {
return !!item && item.isAnswer && !item.id.startsWith('answer-placeholder-') && !item.isOpeningStatement return !!item && item.isAnswer && !item.id.startsWith('answer-placeholder-') && !item.isOpeningStatement
} }
@ -198,6 +212,7 @@ export {
getRawInputsFromUrlParams, getRawInputsFromUrlParams,
getProcessedInputsFromUrlParams, getProcessedInputsFromUrlParams,
getProcessedSystemVariablesFromUrlParams, getProcessedSystemVariablesFromUrlParams,
getProcessedUserVariablesFromUrlParams,
isValidGeneratedAnswer, isValidGeneratedAnswer,
getLastAnswer, getLastAnswer,
buildChatItemTree, buildChatItemTree,

@ -25,6 +25,7 @@ import { useModalContext } from '@/context/modal-context'
import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { CustomConfigurationStatusEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { noop } from 'lodash-es' import { noop } from 'lodash-es'
import { useDocLink } from '@/context/i18n'
const systemTypes = ['openai_moderation', 'keywords', 'api'] const systemTypes = ['openai_moderation', 'keywords', 'api']
@ -46,6 +47,7 @@ const ModerationSettingModal: FC<ModerationSettingModalProps> = ({
onSave, onSave,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const docLink = useDocLink()
const { notify } = useToastContext() const { notify } = useToastContext()
const { locale } = useContext(I18n) const { locale } = useContext(I18n)
const { data: modelProviders, isLoading, mutate } = useSWR('/workspaces/current/model-providers', fetchModelProviders) const { data: modelProviders, isLoading, mutate } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
@ -316,7 +318,7 @@ const ModerationSettingModal: FC<ModerationSettingModalProps> = ({
<div className='flex h-9 items-center justify-between'> <div className='flex h-9 items-center justify-between'>
<div className='text-sm font-medium text-text-primary'>{t('common.apiBasedExtension.selector.title')}</div> <div className='text-sm font-medium text-text-primary'>{t('common.apiBasedExtension.selector.title')}</div>
<a <a
href={t('common.apiBasedExtension.linkUrl') || '/'} href={docLink('/guides/extension/api-based-extension/README')}
target='_blank' rel='noopener noreferrer' target='_blank' rel='noopener noreferrer'
className='group flex items-center text-xs text-text-tertiary hover:text-primary-600' className='group flex items-center text-xs text-text-tertiary hover:text-primary-600'
> >

@ -1,4 +1,4 @@
import { memo, useEffect, useMemo, useRef, useState } from 'react' import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'
import ReactEcharts from 'echarts-for-react' import ReactEcharts from 'echarts-for-react'
import SyntaxHighlighter from 'react-syntax-highlighter' import SyntaxHighlighter from 'react-syntax-highlighter'
import { import {
@ -62,6 +62,17 @@ const getCorrectCapitalizationLanguageName = (language: string) => {
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message // visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
// or use the non-minified dev environment for full errors and additional helpful warnings. // or use the non-minified dev environment for full errors and additional helpful warnings.
// Define ECharts event parameter types
interface EChartsEventParams {
type: string;
seriesIndex?: number;
dataIndex?: number;
name?: string;
value?: any;
currentIndex?: number; // Added for timeline events
[key: string]: any;
}
const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => { const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => {
const { theme } = useTheme() const { theme } = useTheme()
const [isSVG, setIsSVG] = useState(true) const [isSVG, setIsSVG] = useState(true)
@ -70,6 +81,11 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
const echartsRef = useRef<any>(null) const echartsRef = useRef<any>(null)
const contentRef = useRef<string>('') const contentRef = useRef<string>('')
const processedRef = useRef<boolean>(false) // Track if content was successfully processed const processedRef = useRef<boolean>(false) // Track if content was successfully processed
const instanceIdRef = useRef<string>(`chart-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`) // Unique ID for logging
const isInitialRenderRef = useRef<boolean>(true) // Track if this is initial render
const chartInstanceRef = useRef<any>(null) // Direct reference to ECharts instance
const resizeTimerRef = useRef<NodeJS.Timeout | null>(null) // For debounce handling
const finishedEventCountRef = useRef<number>(0) // Track finished event trigger count
const match = /language-(\w+)/.exec(className || '') const match = /language-(\w+)/.exec(className || '')
const language = match?.[1] const language = match?.[1]
const languageShowName = getCorrectCapitalizationLanguageName(language || '') const languageShowName = getCorrectCapitalizationLanguageName(language || '')
@ -85,36 +101,64 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
width: 'auto', width: 'auto',
}) as any, []) }) as any, [])
const echartsOnEvents = useMemo(() => ({ // Debounce resize operations
finished: () => { const debouncedResize = useCallback(() => {
const instance = echartsRef.current?.getEchartsInstance?.() if (resizeTimerRef.current)
if (instance) clearTimeout(resizeTimerRef.current)
instance.resize()
resizeTimerRef.current = setTimeout(() => {
if (chartInstanceRef.current)
chartInstanceRef.current.resize()
resizeTimerRef.current = null
}, 200)
}, [])
// Handle ECharts instance initialization
const handleChartReady = useCallback((instance: any) => {
chartInstanceRef.current = instance
// Force resize to ensure timeline displays correctly
setTimeout(() => {
if (chartInstanceRef.current)
chartInstanceRef.current.resize()
}, 200)
}, [])
// Store event handlers in useMemo to avoid recreating them
const echartsEvents = useMemo(() => ({
finished: (params: EChartsEventParams) => {
// Limit finished event frequency to avoid infinite loops
finishedEventCountRef.current++
if (finishedEventCountRef.current > 3) {
// Stop processing after 3 times to avoid infinite loops
return
}
if (chartInstanceRef.current) {
// Use debounced resize
debouncedResize()
}
}, },
}), [echartsRef]) // echartsRef is stable, so this effectively runs once. }), [debouncedResize])
// Handle container resize for echarts // Handle container resize for echarts
useEffect(() => { useEffect(() => {
if (language !== 'echarts' || !echartsRef.current) return if (language !== 'echarts' || !chartInstanceRef.current) return
const handleResize = () => { const handleResize = () => {
// This gets the echarts instance from the component if (chartInstanceRef.current)
const instance = echartsRef.current?.getEchartsInstance?.() // Use debounced resize
if (instance) debouncedResize()
instance.resize()
} }
window.addEventListener('resize', handleResize) window.addEventListener('resize', handleResize)
// Also manually trigger resize after a short delay to ensure proper sizing
const resizeTimer = setTimeout(handleResize, 200)
return () => { return () => {
window.removeEventListener('resize', handleResize) window.removeEventListener('resize', handleResize)
clearTimeout(resizeTimer) if (resizeTimerRef.current)
clearTimeout(resizeTimerRef.current)
} }
}, [language, echartsRef.current]) }, [language, debouncedResize])
// Process chart data when content changes // Process chart data when content changes
useEffect(() => { useEffect(() => {
// Only process echarts content // Only process echarts content
@ -222,6 +266,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
} }
}, [language, children]) }, [language, children])
// Cache rendered content to avoid unnecessary re-renders
const renderCodeContent = useMemo(() => { const renderCodeContent = useMemo(() => {
const content = String(children).replace(/\n$/, '') const content = String(children).replace(/\n$/, '')
switch (language) { switch (language) {
@ -274,6 +319,9 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
// Success state: show the chart // Success state: show the chart
if (chartState === 'success' && finalChartOption) { if (chartState === 'success' && finalChartOption) {
// Reset finished event counter
finishedEventCountRef.current = 0
return ( return (
<div style={{ <div style={{
minWidth: '300px', minWidth: '300px',
@ -286,13 +334,20 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
}}> }}>
<ErrorBoundary> <ErrorBoundary>
<ReactEcharts <ReactEcharts
ref={echartsRef} ref={(e) => {
if (e && isInitialRenderRef.current) {
echartsRef.current = e
isInitialRenderRef.current = false
}
}}
option={finalChartOption} option={finalChartOption}
style={echartsStyle} style={echartsStyle}
theme={isDarkMode ? 'dark' : undefined} theme={isDarkMode ? 'dark' : undefined}
opts={echartsOpts} opts={echartsOpts}
notMerge={true} notMerge={false}
onEvents={echartsOnEvents} lazyUpdate={false}
onEvents={echartsEvents}
onChartReady={handleChartReady}
/> />
</ErrorBoundary> </ErrorBoundary>
</div> </div>
@ -363,7 +418,7 @@ const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any
</SyntaxHighlighter> </SyntaxHighlighter>
) )
} }
}, [children, language, isSVG, finalChartOption, props, theme, match, chartState, isDarkMode, echartsStyle, echartsOpts, echartsOnEvents]) }, [children, language, isSVG, finalChartOption, props, theme, match, chartState, isDarkMode, echartsStyle, echartsOpts, handleChartReady, echartsEvents])
if (inline || !match) if (inline || !match)
return <code {...props} className={className}>{children}</code> return <code {...props} className={className}>{children}</code>

@ -28,7 +28,7 @@ export const preprocessLaTeX = (content: string) => {
} }
export const preprocessThinkTag = (content: string) => { export const preprocessThinkTag = (content: string) => {
const thinkOpenTagRegex = /<think>\n/g const thinkOpenTagRegex = /(<think>\n)+/g
const thinkCloseTagRegex = /\n<\/think>/g const thinkCloseTagRegex = /\n<\/think>/g
return flow([ return flow([
(str: string) => str.replace(thinkOpenTagRegex, '<details data-think=true>\n'), (str: string) => str.replace(thinkOpenTagRegex, '<details data-think=true>\n'),

@ -165,6 +165,7 @@ const ComponentPicker = ({
isSupportFileVar={isSupportFileVar} isSupportFileVar={isSupportFileVar}
onClose={handleClose} onClose={handleClose}
onBlur={handleClose} onBlur={handleClose}
autoFocus={false}
/> />
</div> </div>
) )

@ -533,6 +533,12 @@ Workflow applications offers non-session support and is ideal for translation, a
<Property name='limit' type='int' key='limit'> <Property name='limit' type='int' key='limit'>
How many chat history messages to return in one request, default is 20. How many chat history messages to return in one request, default is 20.
</Property> </Property>
<Property name='created_by_end_user_session_id' type='str' key='created_by_end_user_session_id'>
Created by which endUser, for example, `abc-123`.
</Property>
<Property name='created_by_account' type='str' key='created_by_account'>
Created by which email account, for example, lizb@test.com.
</Property>
</Properties> </Properties>
### Response ### Response

@ -534,6 +534,12 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='limit' type='int' key='limit'> <Property name='limit' type='int' key='limit'>
1回のリクエストで返すチャット履歴メッセージの数、デフォルトは20。 1回のリクエストで返すチャット履歴メッセージの数、デフォルトは20。
</Property> </Property>
<Property name='created_by_end_user_session_id' type='str' key='created_by_end_user_session_id'>
どのendUserによって作成されたか、例えば、`abc-123`。
</Property>
<Property name='created_by_account' type='str' key='created_by_account'>
どのメールアカウントによって作成されたか、例えば、lizb@test.com。
</Property>
</Properties> </Properties>
### 応答 ### 応答

@ -522,6 +522,12 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等
<Property name='limit' type='int' key='limit'> <Property name='limit' type='int' key='limit'>
每页条数, 默认20. 每页条数, 默认20.
</Property> </Property>
<Property name='created_by_end_user_session_id' type='str' key='created_by_end_user_session_id'>
由哪个endUser创建例如`abc-123`.
</Property>
<Property name='created_by_account' type='str' key='created_by_account'>
由哪个邮箱账户创建例如lizb@test.com.
</Property>
</Properties> </Properties>
### Response ### Response

@ -31,22 +31,22 @@ const WorkplaceSelector = () => {
} }
return ( return (
<Menu as="div" className="relative h-full w-full"> <Menu as="div" className="min-w-0">
{ {
({ open }) => ( ({ open }) => (
<> <>
<MenuButton className={cn( <MenuButton className={cn(
` `
group flex w-full cursor-pointer items-center group flex w-full cursor-pointer items-center
gap-1.5 p-0.5 hover:bg-state-base-hover ${open && 'bg-state-base-hover'} rounded-[10px] p-0.5 hover:bg-state-base-hover ${open && 'bg-state-base-hover'} rounded-[10px]
`, `,
)}> )}>
<div className='flex h-6 w-6 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px]'> <div className='mr-1.5 flex h-6 w-6 shrink-0 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px] max-[800px]:mr-0'>
<span className='h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90'>{currentWorkspace?.name[0]?.toLocaleUpperCase()}</span> <span className='h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90'>{currentWorkspace?.name[0]?.toLocaleUpperCase()}</span>
</div> </div>
<div className='flex flex-row'> <div className='flex min-w-0 items-center'>
<div className={'system-sm-medium max-w-[160px] truncate text-text-secondary'}>{currentWorkspace?.name}</div> <div className={'system-sm-medium min-w-0 max-w-[149px] truncate text-text-secondary max-[800px]:hidden'}>{currentWorkspace?.name}</div>
<RiArrowDownSLine className='h-4 w-4 text-text-secondary' /> <RiArrowDownSLine className='h-4 w-4 shrink-0 text-text-secondary' />
</div> </div>
</MenuButton> </MenuButton>
<Transition <Transition
@ -59,10 +59,11 @@ const WorkplaceSelector = () => {
leaveTo="transform opacity-0 scale-95" leaveTo="transform opacity-0 scale-95"
> >
<MenuItems <MenuItems
anchor="bottom start"
className={cn( className={cn(
` `
shadows-shadow-lg absolute left-[-15px] mt-1 flex max-h-[400px] w-[280px] flex-col items-start overflow-y-auto rounded-xl shadows-shadow-lg absolute left-[-15px] z-[1000] mt-1 flex max-h-[400px] w-[280px] flex-col items-start overflow-y-auto
bg-components-panel-bg-blur backdrop-blur-[5px] rounded-xl bg-components-panel-bg-blur backdrop-blur-[5px]
`, `,
)} )}
> >
@ -73,7 +74,7 @@ const WorkplaceSelector = () => {
{ {
workspaces.map(workspace => ( workspaces.map(workspace => (
<div className='flex items-center gap-2 self-stretch rounded-lg py-1 pl-3 pr-2 hover:bg-state-base-hover' key={workspace.id} onClick={() => handleSwitchWorkspace(workspace.id)}> <div className='flex items-center gap-2 self-stretch rounded-lg py-1 pl-3 pr-2 hover:bg-state-base-hover' key={workspace.id} onClick={() => handleSwitchWorkspace(workspace.id)}>
<div className='flex h-6 w-6 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px]'> <div className='flex h-6 w-6 shrink-0 items-center justify-center rounded-md bg-components-icon-bg-blue-solid text-[13px]'>
<span className='h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90'>{workspace?.name[0]?.toLocaleUpperCase()}</span> <span className='h-6 bg-gradient-to-r from-components-avatar-shape-fill-stop-0 to-components-avatar-shape-fill-stop-100 bg-clip-text align-middle font-semibold uppercase leading-6 text-shadow-shadow-1 opacity-90'>{workspace?.name[0]?.toLocaleUpperCase()}</span>
</div> </div>
<div className='system-md-regular line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary'>{workspace.name}</div> <div className='system-md-regular line-clamp-1 grow cursor-pointer overflow-hidden text-ellipsis text-text-secondary'>{workspace.name}</div>

@ -3,9 +3,11 @@ import {
RiExternalLinkLine, RiExternalLinkLine,
RiPuzzle2Line, RiPuzzle2Line,
} from '@remixicon/react' } from '@remixicon/react'
import { useDocLink } from '@/context/i18n'
const Empty = () => { const Empty = () => {
const { t } = useTranslation() const { t } = useTranslation()
const docLink = useDocLink()
return ( return (
<div className='mb-2 rounded-xl bg-background-section p-6'> <div className='mb-2 rounded-xl bg-background-section p-6'>
@ -15,7 +17,7 @@ const Empty = () => {
<div className='system-sm-medium mb-1 text-text-secondary'>{t('common.apiBasedExtension.title')}</div> <div className='system-sm-medium mb-1 text-text-secondary'>{t('common.apiBasedExtension.title')}</div>
<a <a
className='system-xs-regular flex items-center text-text-accent' className='system-xs-regular flex items-center text-text-accent'
href={t('common.apiBasedExtension.linkUrl') || '/'} href={docLink('/guides/extension/api-based-extension/README')}
target='_blank' rel='noopener noreferrer' target='_blank' rel='noopener noreferrer'
> >
{t('common.apiBasedExtension.link')} {t('common.apiBasedExtension.link')}

@ -1,6 +1,7 @@
import type { FC } from 'react' import type { FC } from 'react'
import { useState } from 'react' import { useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useDocLink } from '@/context/i18n'
import Modal from '@/app/components/base/modal' import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education' import { BookOpen01 } from '@/app/components/base/icons/src/vender/line/education'
@ -29,6 +30,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
onSave, onSave,
}) => { }) => {
const { t } = useTranslation() const { t } = useTranslation()
const docLink = useDocLink()
const [localeData, setLocaleData] = useState(data) const [localeData, setLocaleData] = useState(data)
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const { notify } = useToastContext() const { notify } = useToastContext()
@ -100,7 +102,7 @@ const ApiBasedExtensionModal: FC<ApiBasedExtensionModalProps> = ({
<div className='flex h-9 items-center justify-between text-sm font-medium text-text-primary'> <div className='flex h-9 items-center justify-between text-sm font-medium text-text-primary'>
{t('common.apiBasedExtension.modal.apiEndpoint.title')} {t('common.apiBasedExtension.modal.apiEndpoint.title')}
<a <a
href={t('common.apiBasedExtension.linkUrl') || '/'} href={docLink('/guides/extension/api-based-extension/README')}
target='_blank' rel='noopener noreferrer' target='_blank' rel='noopener noreferrer'
className='group flex items-center text-xs font-normal text-text-accent' className='group flex items-center text-xs font-normal text-text-accent'
> >

@ -17,9 +17,9 @@ import Loading from '@/app/components/base/loading'
import ProviderCard from '@/app/components/plugins/provider-card' import ProviderCard from '@/app/components/plugins/provider-card'
import List from '@/app/components/plugins/marketplace/list' import List from '@/app/components/plugins/marketplace/list'
import type { Plugin } from '@/app/components/plugins/types' import type { Plugin } from '@/app/components/plugins/types'
import { MARKETPLACE_URL_PREFIX } from '@/config'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { getLocaleOnClient } from '@/i18n' import { getLocaleOnClient } from '@/i18n'
import { getMarketplaceUrl } from '@/utils/var'
type InstallFromMarketplaceProps = { type InstallFromMarketplaceProps = {
providers: ModelProvider[] providers: ModelProvider[]
@ -55,7 +55,7 @@ const InstallFromMarketplace = ({
</div> </div>
<div className='mb-2 flex items-center pt-2'> <div className='mb-2 flex items-center pt-2'>
<span className='system-sm-regular pr-1 text-text-tertiary'>{t('common.modelProvider.discoverMore')}</span> <span className='system-sm-regular pr-1 text-text-tertiary'>{t('common.modelProvider.discoverMore')}</span>
<Link target="_blank" href={`${MARKETPLACE_URL_PREFIX}${theme ? `?theme=${theme}` : ''}`} className='system-sm-medium inline-flex items-center text-text-accent'> <Link target="_blank" href={getMarketplaceUrl('', { theme })} className='system-sm-medium inline-flex items-center text-text-accent'>
{t('plugin.marketplace.difyMarketplace')} {t('plugin.marketplace.difyMarketplace')}
<RiArrowRightUpLine className='h-4 w-4' /> <RiArrowRightUpLine className='h-4 w-4' />
</Link> </Link>

@ -96,7 +96,7 @@ const AppNav = () => {
link, link,
} }
}) })
setNavItems(navItems) setNavItems(navItems as any)
} }
}, [appsData, isCurrentWorkspaceEditor, setNavItems]) }, [appsData, isCurrentWorkspaceEditor, setNavItems])
@ -122,7 +122,7 @@ const AppNav = () => {
text={t('common.menus.apps')} text={t('common.menus.apps')}
activeSegment={['apps', 'app']} activeSegment={['apps', 'app']}
link='/apps' link='/apps'
curNav={appDetail} curNav={appDetail as any}
navs={navItems} navs={navItems}
createText={t('common.menus.newApp')} createText={t('common.menus.newApp')}
onCreate={openModal} onCreate={openModal}

@ -48,7 +48,7 @@ const DatasetNav = () => {
text={t('common.menus.datasets')} text={t('common.menus.datasets')}
activeSegment='datasets' activeSegment='datasets'
link='/datasets' link='/datasets'
curNav={currentDataset as Omit<NavItem, 'link'>} curNav={currentDataset as any}
navs={datasetItems.map(dataset => ({ navs={datasetItems.map(dataset => ({
id: dataset.id, id: dataset.id,
name: dataset.name, name: dataset.name,
@ -59,6 +59,7 @@ const DatasetNav = () => {
createText={t('common.menus.newDataset')} createText={t('common.menus.newDataset')}
onCreate={() => router.push(`${basePath}/datasets/create`)} onCreate={() => router.push(`${basePath}/datasets/create`)}
onLoadmore={handleLoadmore} onLoadmore={handleLoadmore}
isApp={false}
/> />
) )
} }

@ -20,22 +20,22 @@ const EnvNav = () => {
return ( return (
<div className={` <div className={`
mr-4 flex h-[22px] items-center rounded-md border px-2 text-xs font-medium mr-1 flex h-[22px] items-center rounded-md border px-2 text-xs font-medium
${headerEnvClassName[langeniusVersionInfo.current_env]} ${headerEnvClassName[langeniusVersionInfo.current_env]}
`}> `}>
{ {
langeniusVersionInfo.current_env === 'TESTING' && ( langeniusVersionInfo.current_env === 'TESTING' && (
<> <>
<Beaker02 className='mr-1 h-3 w-3' /> <Beaker02 className='h-3 w-3' />
{t('common.environment.testing')} <div className='ml-1 max-[1280px]:hidden'>{t('common.environment.testing')}</div>
</> </>
) )
} }
{ {
langeniusVersionInfo.current_env === 'DEVELOPMENT' && ( langeniusVersionInfo.current_env === 'DEVELOPMENT' && (
<> <>
<TerminalSquare className='mr-1 h-3 w-3' /> <TerminalSquare className='h-3 w-3' />
{t('common.environment.development')} <div className='ml-1 max-[1280px]:hidden'>{t('common.environment.development')}</div>
</> </>
) )
} }

@ -27,10 +27,12 @@ const ExploreNav = ({
)}> )}>
{ {
activated activated
? <RiPlanetFill className='mr-2 h-4 w-4' /> ? <RiPlanetFill className='h-4 w-4' />
: <RiPlanetLine className='mr-2 h-4 w-4' /> : <RiPlanetLine className='h-4 w-4' />
} }
<div className='ml-2 max-[1024px]:hidden'>
{t('common.menus.explore')} {t('common.menus.explore')}
</div>
</Link> </Link>
) )
} }

@ -1,9 +1,6 @@
'use client' 'use client'
import { useCallback, useEffect } from 'react' import { useCallback } from 'react'
import Link from 'next/link' import Link from 'next/link'
import { useBoolean } from 'ahooks'
import { useSelectedLayoutSegment } from 'next/navigation'
import { Bars3Icon } from '@heroicons/react/20/solid'
import AccountDropdown from './account-dropdown' import AccountDropdown from './account-dropdown'
import AppNav from './app-nav' import AppNav from './app-nav'
import DatasetNav from './dataset-nav' import DatasetNav from './dataset-nav'
@ -24,17 +21,15 @@ import { Plan } from '../billing/type'
import { useGlobalPublicStore } from '@/context/global-public-context' import { useGlobalPublicStore } from '@/context/global-public-context'
const navClassName = ` const navClassName = `
flex items-center relative mr-0 sm:mr-3 px-3 h-8 rounded-xl flex items-center relative px-3 h-8 rounded-xl
font-medium text-sm font-medium text-sm
cursor-pointer cursor-pointer
` `
const Header = () => { const Header = () => {
const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext() const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext()
const selectedSegment = useSelectedLayoutSegment()
const media = useBreakpoints() const media = useBreakpoints()
const isMobile = media === MediaType.mobile const isMobile = media === MediaType.mobile
const [isShowNavMenu, { toggle, setFalse: hideNavMenu }] = useBoolean(false)
const { enableBilling, plan } = useProviderContext() const { enableBilling, plan } = useProviderContext()
const { setShowPricingModal, setShowAccountSettingModal } = useModalContext() const { setShowPricingModal, setShowAccountSettingModal } = useModalContext()
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
@ -46,23 +41,12 @@ const Header = () => {
setShowAccountSettingModal({ payload: 'billing' }) setShowAccountSettingModal({ payload: 'billing' })
}, [isFreePlan, setShowAccountSettingModal, setShowPricingModal]) }, [isFreePlan, setShowAccountSettingModal, setShowPricingModal])
useEffect(() => { if (isMobile) {
hideNavMenu()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [selectedSegment])
return ( return (
<div className='relative flex flex-1 items-center justify-between bg-background-body'> <div className=''>
<div className='flex items-center justify-between px-2'>
<div className='flex items-center'> <div className='flex items-center'>
{isMobile && <div <Link href="/apps" className='flex h-8 shrink-0 items-center justify-center px-0.5'>
className='flex h-8 w-8 cursor-pointer items-center justify-center'
onClick={toggle}
>
<Bars3Icon className="h-4 w-4 text-gray-500" />
</div>}
{
!isMobile
&& <div className='flex shrink-0 items-center gap-1.5 self-stretch pl-3'>
<Link href="/apps" className='flex h-8 shrink-0 items-center justify-center gap-2 px-0.5'>
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo {systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
? <img ? <img
src={systemFeatures.branding.workspace_logo} src={systemFeatures.branding.workspace_logo}
@ -71,19 +55,33 @@ const Header = () => {
/> />
: <DifyLogo />} : <DifyLogo />}
</Link> </Link>
<div className='font-light text-divider-deep'>/</div> <div className='mx-1.5 shrink-0 font-light text-divider-deep'>/</div>
<div className='flex items-center gap-0.5'>
<WorkspaceProvider> <WorkspaceProvider>
<WorkplaceSelector /> <WorkplaceSelector />
</WorkspaceProvider> </WorkspaceProvider>
{enableBilling ? <PlanBadge allowHover sandboxAsUpgrade plan={plan.type} onClick={handlePlanClick} /> : <LicenseNav />} {enableBilling ? <PlanBadge allowHover sandboxAsUpgrade plan={plan.type} onClick={handlePlanClick} /> : <LicenseNav />}
</div> </div>
<div className='flex items-center'>
<div className='mr-2'>
<PluginsNav />
</div>
<AccountDropdown />
</div>
</div>
<div className='my-1 flex items-center justify-center space-x-1'>
{!isCurrentWorkspaceDatasetOperator && <ExploreNav className={navClassName} />}
{!isCurrentWorkspaceDatasetOperator && <AppNav />}
{(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator) && <DatasetNav />}
{!isCurrentWorkspaceDatasetOperator && <ToolsNav className={navClassName} />}
</div> </div>
</div>
)
} }
</div >
{isMobile && ( return (
<div className='flex'> <div className='flex h-[60px] items-center'>
<Link href="/apps" className='mr-4 flex items-center'> <div className='flex min-w-0 flex-[1] items-center pl-3 pr-2 min-[1280px]:pr-3'>
<Link href="/apps" className='flex h-8 shrink-0 items-center justify-center px-0.5'>
{systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo {systemFeatures.branding.enabled && systemFeatures.branding.workspace_logo
? <img ? <img
src={systemFeatures.branding.workspace_logo} src={systemFeatures.branding.workspace_logo}
@ -92,38 +90,26 @@ const Header = () => {
/> />
: <DifyLogo />} : <DifyLogo />}
</Link> </Link>
<div className='font-light text-divider-deep'>/</div> <div className='mx-1.5 shrink-0 font-light text-divider-deep'>/</div>
<WorkspaceProvider>
<WorkplaceSelector />
</WorkspaceProvider>
{enableBilling ? <PlanBadge allowHover sandboxAsUpgrade plan={plan.type} onClick={handlePlanClick} /> : <LicenseNav />} {enableBilling ? <PlanBadge allowHover sandboxAsUpgrade plan={plan.type} onClick={handlePlanClick} /> : <LicenseNav />}
</div > </div>
)} <div className='flex items-center space-x-2'>
{
!isMobile && (
<div className='absolute left-1/2 top-1/2 flex -translate-x-1/2 -translate-y-1/2 items-center'>
{!isCurrentWorkspaceDatasetOperator && <ExploreNav className={navClassName} />} {!isCurrentWorkspaceDatasetOperator && <ExploreNav className={navClassName} />}
{!isCurrentWorkspaceDatasetOperator && <AppNav />} {!isCurrentWorkspaceDatasetOperator && <AppNav />}
{(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator) && <DatasetNav />} {(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator) && <DatasetNav />}
{!isCurrentWorkspaceDatasetOperator && <ToolsNav className={navClassName} />} {!isCurrentWorkspaceDatasetOperator && <ToolsNav className={navClassName} />}
</div> </div>
) <div className='flex min-w-0 flex-[1] items-center justify-end pl-2 pr-3 min-[1280px]:pl-3'>
}
<div className='flex shrink-0 items-center pr-3'>
<EnvNav /> <EnvNav />
<div className='mr-2'> <div className='mr-2'>
<PluginsNav /> <PluginsNav />
</div> </div>
<AccountDropdown /> <AccountDropdown />
</div> </div>
{
(isMobile && isShowNavMenu) && (
<div className='flex w-full flex-col gap-y-1 p-2'>
{!isCurrentWorkspaceDatasetOperator && <ExploreNav className={navClassName} />}
{!isCurrentWorkspaceDatasetOperator && <AppNav />}
{(isCurrentWorkspaceEditor || isCurrentWorkspaceDatasetOperator) && <DatasetNav />}
{!isCurrentWorkspaceDatasetOperator && <ToolsNav className={navClassName} />}
</div> </div>
) )
}
</div >
)
} }
export default Header export default Header

@ -46,7 +46,7 @@ const Nav = ({
return ( return (
<div className={` <div className={`
mr-0 flex h-8 shrink-0 items-center rounded-xl px-0.5 text-sm font-medium sm:mr-3 flex h-8 max-w-[670px] shrink-0 items-center rounded-xl px-0.5 text-sm font-medium max-[1024px]:max-w-[400px]
${isActivated && 'bg-components-main-nav-nav-button-bg-active font-semibold shadow-md'} ${isActivated && 'bg-components-main-nav-nav-button-bg-active font-semibold shadow-md'}
${!curNav && !isActivated && 'hover:bg-components-main-nav-nav-button-bg-hover'} ${!curNav && !isActivated && 'hover:bg-components-main-nav-nav-button-bg-hover'}
`}> `}>
@ -61,7 +61,7 @@ const Nav = ({
onMouseEnter={() => setHovered(true)} onMouseEnter={() => setHovered(true)}
onMouseLeave={() => setHovered(false)} onMouseLeave={() => setHovered(false)}
> >
<div className='mr-2'> <div>
{ {
(hovered && curNav) (hovered && curNav)
? <ArrowNarrowLeft className='h-4 w-4' /> ? <ArrowNarrowLeft className='h-4 w-4' />
@ -70,8 +70,10 @@ const Nav = ({
: icon : icon
} }
</div> </div>
<div className='ml-2 max-[1024px]:hidden'>
{text} {text}
</div> </div>
</div>
</Link> </Link>
{ {
curNav && isActivated && ( curNav && isActivated && (

@ -53,15 +53,14 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
}, 50), []) }, 50), [])
return ( return (
<div className=""> <Menu as="div" className="relative">
<Menu as="div" className="relative inline-block text-left">
{({ open }) => ( {({ open }) => (
<> <>
<MenuButton className={cn( <MenuButton className={cn(
'hover:hover:bg-components-main-nav-nav-button-bg-active-hover group inline-flex h-7 w-full items-center justify-center rounded-[10px] pl-2 pr-2.5 text-[14px] font-semibold text-components-main-nav-nav-button-text-active', 'hover:hover:bg-components-main-nav-nav-button-bg-active-hover group inline-flex h-7 w-full items-center justify-center rounded-[10px] pl-2 pr-2.5 text-[14px] font-semibold text-components-main-nav-nav-button-text-active',
open && 'bg-components-main-nav-nav-button-bg-active', open && 'bg-components-main-nav-nav-button-bg-active',
)}> )}>
<div className='max-w-[180px] truncate' title={curNav?.name}>{curNav?.name}</div> <div className='max-w-[157px] truncate' title={curNav?.name}>{curNav?.name}</div>
<RiArrowDownSLine <RiArrowDownSLine
className={cn('ml-1 h-3 w-3 shrink-0 opacity-50 group-hover:opacity-100', open && '!opacity-100')} className={cn('ml-1 h-3 w-3 shrink-0 opacity-50 group-hover:opacity-100', open && '!opacity-100')}
aria-hidden="true" aria-hidden="true"
@ -182,7 +181,6 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
</> </>
)} )}
</Menu> </Menu>
</div>
) )
} }

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save