Merge branch 'main' into feat/knowledge-dark-mode

pull/13379/head
twwu 1 year ago
commit 49674507c6

@ -8,7 +8,7 @@ inputs:
poetry-version: poetry-version:
description: Poetry version to set up description: Poetry version to set up
required: true required: true
default: '1.8.4' default: '2.0.1'
poetry-lockfile: poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from description: Path to the Poetry lockfile to restore cache from
required: true required: true

@ -42,25 +42,23 @@ jobs:
run: poetry install -C api --with dev run: poetry install -C api --with dev
- name: Check dependencies in pyproject.toml - name: Check dependencies in pyproject.toml
run: poetry run -C api bash dev/pytest/pytest_artifacts.sh run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests - name: Run Unit tests
run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh
- name: Run ModelRuntime - name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh run: poetry run -P api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests - name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py run: poetry run -P api python dev/pytest/pytest_config_tests.py
- name: Run Tool - name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh run: poetry run -P api bash dev/pytest/pytest_tools.sh
- name: Run mypy - name: Run mypy
run: | run: |
pushd api poetry run -C api python -m mypy --install-types --non-interactive .
poetry run python -m mypy --install-types --non-interactive .
popd
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -80,4 +78,4 @@ jobs:
ssrf_proxy ssrf_proxy
- name: Run Workflow - name: Run Workflow
run: poetry run -C api bash dev/pytest/pytest_workflow.sh run: poetry run -P api bash dev/pytest/pytest_workflow.sh

@ -38,12 +38,12 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: | run: |
poetry run -C api ruff --version poetry run -C api ruff --version
poetry run -C api ruff check ./api poetry run -C api ruff check ./
poetry run -C api ruff format --check ./api poetry run -C api ruff format --check ./
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints - name: Lint hints
if: failure() if: failure()

@ -70,4 +70,4 @@ jobs:
tidb tidb
- name: Test Vector Stores - name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh run: poetry run -P api bash dev/pytest/pytest_vdb.sh

@ -53,10 +53,12 @@ ignore = [
"FURB152", # math-constant "FURB152", # math-constant
"UP007", # non-pep604-annotation "UP007", # non-pep604-annotation
"UP032", # f-string "UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters "B005", # strip-with-multi-characters
"B006", # mutable-argument-default "B006", # mutable-argument-default
"B007", # unused-loop-control-variable "B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg "B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except "B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict "B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function "N806", # non-lowercase-variable-in-function

@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api
# Install Poetry # Install Poetry
ENV POETRY_VERSION=1.8.4 ENV POETRY_VERSION=2.0.1
# if you located in China, you can use aliyun mirror to speed up # if you located in China, you can use aliyun mirror to speed up
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/ # RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/

@ -79,5 +79,5 @@
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash ```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh poetry run -P api bash dev/pytest/pytest_all_tests.sh
``` ```

@ -146,7 +146,7 @@ class EndpointConfig(BaseSettings):
) )
CONSOLE_WEB_URL: str = Field( CONSOLE_WEB_URL: str = Field(
description="Base URL for the console web interface," "used for frontend references and CORS configuration", description="Base URL for the console web interface,used for frontend references and CORS configuration",
default="", default="",
) )

@ -181,7 +181,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
""" """
HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field( HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
description="Mode for fetching app templates: remote, db, or builtin" " default to remote,", description="Mode for fetching app templates: remote, db, or builtin default to remote,",
default="remote", default="remote",
) )

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.15.0", default="0.15.1",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
app = App.query.filter(App.id == args["app_id"]).first() app = App.query.filter(App.id == args["app_id"]).first()
if not app: if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found') raise NotFound(f"App '{args['app_id']}' is not found")
site = app.site site = app.site
if not site: if not site:

@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models import App, AppMode
from services.audio_service import AudioService from services.audio_service import AudioService
from services.errors.audio import ( from services.errors.audio import (
AudioTooLargeServiceError, AudioTooLargeServiceError,
@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
def post(self, app_model): def post(self, app_model: App):
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
try: try:
@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice") voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception: except Exception:
voice = None voice = None

@ -10,12 +10,7 @@ from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
account_initialization_required,
cloud_edition_billing_rate_limit_check,
enterprise_license_required,
setup_required,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -98,7 +93,6 @@ class DatasetListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
@ -213,7 +207,6 @@ class DatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id): def patch(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -317,7 +310,6 @@ class DatasetApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id): def delete(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
@ -465,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -627,8 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
vector_type = dify_config.VECTOR_STORE vector_type = dify_config.VECTOR_STORE
match vector_type: match vector_type:
case ( case (
VectorType.MILVUS VectorType.RELYT
| VectorType.RELYT
| VectorType.PGVECTOR | VectorType.PGVECTOR
| VectorType.TIDB_VECTOR | VectorType.TIDB_VECTOR
| VectorType.CHROMA | VectorType.CHROMA
@ -653,6 +644,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TIDB_ON_QDRANT | VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM | VectorType.LINDORM
| VectorType.COUCHBASE | VectorType.COUCHBASE
| VectorType.MILVUS
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

@ -27,7 +27,6 @@ from controllers.console.datasets.error import (
) )
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
setup_required, setup_required,
) )
@ -231,7 +230,6 @@ class DatasetDocumentListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(documents_and_batch_fields) @marshal_with(documents_and_batch_fields)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -287,7 +285,6 @@ class DatasetDocumentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id): def delete(self, dataset_id):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -311,7 +308,6 @@ class DatasetInitApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(dataset_and_document_fields) @marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self): def post(self):
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor: if not current_user.is_editor:
@ -354,8 +350,7 @@ class DatasetInitApi(Resource):
) )
except InvokeAuthorizationError: except InvokeAuthorizationError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -530,8 +525,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
return response.model_dump(), 200 return response.model_dump(), 200
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -684,7 +678,6 @@ class DocumentProcessingApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
@ -721,7 +714,6 @@ class DocumentDeleteApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id): def delete(self, dataset_id, document_id):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id) document_id = str(document_id)
@ -790,7 +782,6 @@ class DocumentStatusApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, action): def patch(self, dataset_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -886,7 +877,6 @@ class DocumentPauseApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id): def patch(self, dataset_id, document_id):
"""pause document.""" """pause document."""
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -919,7 +909,6 @@ class DocumentRecoverApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id): def patch(self, dataset_id, document_id):
"""recover document.""" """recover document."""
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -949,7 +938,6 @@ class DocumentRetryApi(DocumentResource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
"""retry document.""" """retry document."""

@ -19,7 +19,6 @@ from controllers.console.datasets.error import (
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
setup_required, setup_required,
) )
@ -107,7 +106,6 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id): def delete(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -139,7 +137,6 @@ class DatasetDocumentSegmentApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
@ -171,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -195,7 +191,6 @@ class DatasetDocumentSegmentAddApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -221,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -246,7 +240,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -272,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -307,7 +299,6 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id): def delete(self, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -345,7 +336,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id): def post(self, dataset_id, document_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -412,7 +402,6 @@ class ChildChunkAddApi(Resource):
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id, document_id, segment_id): def post(self, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -445,8 +434,7 @@ class ChildChunkAddApi(Resource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -511,7 +499,6 @@ class ChildChunkAddApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id): def patch(self, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -555,7 +542,6 @@ class ChildChunkUpdateApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -600,7 +586,6 @@ class ChildChunkUpdateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)

@ -2,11 +2,7 @@ from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.console.wraps import ( from controllers.console.wraps import account_initialization_required, setup_required
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from libs.login import login_required from libs.login import login_required
@ -14,7 +10,6 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)

@ -1,6 +1,5 @@
import json import json
import os import os
import time
from functools import wraps from functools import wraps
from flask import abort, request from flask import abort, request
@ -8,7 +7,6 @@ from flask_login import current_user # type: ignore
from configs import dify_config from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_redis import redis_client
from models.model import DifySetup from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService from services.operation_service import OperationService
@ -68,9 +66,7 @@ def cloud_edition_billing_resource_check(resource: str):
elif resource == "apps" and 0 < apps.limit <= apps.size: elif resource == "apps" and 0 < apps.limit <= apps.size:
abort(403, "The number of apps has reached the limit of your subscription.") abort(403, "The number of apps has reached the limit of your subscription.")
elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
abort( abort(403, "The capacity of the vector space has reached the limit of your subscription.")
403, "The capacity of the knowledge storage space has reached the limit of your subscription."
)
elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
# The api of file upload is used in the multiple places, # The api of file upload is used in the multiple places,
# so we need to check the source of the request from datasets # so we need to check the source of the request from datasets
@ -115,33 +111,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
return interceptor return interceptor
def cloud_edition_billing_rate_limit_check(resource: str):
def interceptor(view):
@wraps(view)
def decorated(*args, **kwargs):
if resource == "knowledge":
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{current_user.current_tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
abort(
403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
)
return view(*args, **kwargs)
return decorated
return interceptor
def cloud_utm_record(view): def cloud_utm_record(view):
@wraps(view) @wraps(view)
def decorated(*args, **kwargs): def decorated(*args, **kwargs):

@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)

@ -1,4 +1,3 @@
import time
from collections.abc import Callable from collections.abc import Callable
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from enum import Enum from enum import Enum
@ -14,7 +13,6 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.model import ApiToken, App, EndUser from models.model import ApiToken, App, EndUser
@ -141,35 +139,6 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
return interceptor return interceptor
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
    """Build a view decorator that enforces the knowledge request rate limit.

    The API token is always validated first. Only when ``resource`` is
    ``"knowledge"`` and the tenant's knowledge rate limit is enabled is the
    request counted against a 60-second window kept in a Redis sorted set.

    :param resource: resource kind; only ``"knowledge"`` is rate limited
    :param api_token_type: token scope passed to ``validate_and_get_api_token``
    :return: a decorator wrapping the view
    :raises Forbidden: (from the wrapped view) when the window count exceeds the limit
    """

    def interceptor(view):
        @wraps(view)
        def decorated(*args, **kwargs):
            api_token = validate_and_get_api_token(api_token_type)
            if resource != "knowledge":
                return view(*args, **kwargs)

            rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
            if not rate_limit.enabled:
                return view(*args, **kwargs)

            now_ms = int(time.time() * 1000)
            window_key = f"rate_limit_{api_token.tenant_id}"
            # Record this request first, then discard entries scored more than
            # 60000 ms in the past, and finally count what is left.
            redis_client.zadd(window_key, {now_ms: now_ms})
            redis_client.zremrangebyscore(window_key, 0, now_ms - 60000)
            if redis_client.zcard(window_key) > rate_limit.limit:
                raise Forbidden(
                    "Sorry, you have reached the knowledge base request rate limit of your subscription."
                )
            return view(*args, **kwargs)

        return decorated

    return interceptor
def validate_dataset_token(view=None): def validate_dataset_token(view=None):
def decorator(view): def decorator(view):
@wraps(view) @wraps(view)
@ -226,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
update_stmt = ( update_stmt = (
update(ApiToken) update(ApiToken)
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope) .where(
ApiToken.token == auth_token,
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
ApiToken.type == scope,
)
.values(last_used_at=current_time) .values(last_used_at=current_time)
.returning(ApiToken) .returning(ApiToken)
) )

@ -172,7 +172,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought( self.save_agent_thought(
agent_thought=agent_thought, agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else "", tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {}, tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={}, tool_invoke_meta={},
thought=scratchpad.thought or "", thought=scratchpad.thought or "",

@ -167,8 +167,7 @@ class AppQueueManager:
else: else:
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"): if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError( raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances " "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
"that cause thread safety issues is not allowed."
) )

@ -89,6 +89,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
Conversation.id == conversation_id, Conversation.id == conversation_id,
Conversation.app_id == app_model.id, Conversation.app_id == app_model.id,
Conversation.status == "normal", Conversation.status == "normal",
Conversation.is_deleted.is_(False),
] ]
if isinstance(user, Account): if isinstance(user, Account):

@ -145,7 +145,7 @@ class MessageCycleManage:
# get extension # get extension
if "." in message_file.url: if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}' extension = f".{message_file.url.split('.')[-1]}"
if len(extension) > 10: if len(extension) > 10:
extension = ".bin" extension = ".bin"
else: else:

@ -62,8 +62,9 @@ class ApiExternalDataTool(ExternalDataTool):
if not api_based_extension: if not api_based_extension:
raise ValueError( raise ValueError(
"[External data tool] API query failed, variable: {}, " "[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
"error: api_based_extension_id is invalid".format(self.variable) self.variable
)
) )
# decrypt api_key # decrypt api_key

@ -90,7 +90,7 @@ class File(BaseModel):
def markdown(self) -> str: def markdown(self) -> str:
url = self.generate_url() url = self.generate_url()
if self.type == FileType.IMAGE: if self.type == FileType.IMAGE:
text = f'![{self.filename or ""}]({url})' text = f"![{self.filename or ''}]({url})"
else: else:
text = f"[{self.filename or url}]({url})" text = f"[{self.filename or url}]({url})"

@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, " "Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n" "and keeping each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response" "MAKE SURE your output is the SAME language as the Assistant's latest response. "
"The output must be an array in JSON format following the specified schema:\n" "The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n' '["question1","question2","question3"]\n'
) )

@ -108,7 +108,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel):
ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model) ai_model_entity = self._get_ai_model_entity(base_model_name=base_model_name, model=model)
if not ai_model_entity: if not ai_model_entity:
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
try: try:
client = AzureOpenAI(**self._to_credential_kwargs(credentials)) client = AzureOpenAI(**self._to_credential_kwargs(credentials))

@ -130,7 +130,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
raise CredentialsValidateFailedError("Base Model Name is required") raise CredentialsValidateFailedError("Base Model Name is required")
if not self._get_ai_model_entity(credentials["base_model_name"], model): if not self._get_ai_model_entity(credentials["base_model_name"], model):
raise CredentialsValidateFailedError(f'Base Model Name {credentials["base_model_name"]} is invalid') raise CredentialsValidateFailedError(f"Base Model Name {credentials['base_model_name']} is invalid")
try: try:
credentials_kwargs = self._to_credential_kwargs(credentials) credentials_kwargs = self._to_credential_kwargs(credentials)

@ -70,7 +70,7 @@ class BedrockRerankModel(RerankModel):
rerankingConfiguration = { rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL", "type": "BEDROCK_RERANKING_MODEL",
"bedrockRerankingConfiguration": { "bedrockRerankingConfiguration": {
"numberOfResults": top_n, "numberOfResults": min(top_n, len(text_sources)),
"modelConfiguration": { "modelConfiguration": {
"modelArn": model_package_arn, "modelArn": model_package_arn,
}, },

@ -1,2 +1,3 @@
- deepseek-chat - deepseek-chat
- deepseek-coder - deepseek-coder
- deepseek-reasoner

@ -10,7 +10,7 @@ features:
- stream-tool-call - stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 64000
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature

@ -10,7 +10,7 @@ features:
- stream-tool-call - stream-tool-call
model_properties: model_properties:
mode: chat mode: chat
context_size: 128000 context_size: 64000
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature

@ -0,0 +1,21 @@
model: deepseek-reasoner
label:
zh_Hans: deepseek-reasoner
en_US: deepseek-reasoner
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 64000
parameter_rules:
- name: max_tokens
use_template: max_tokens
min: 1
max: 8192
default: 4096
pricing:
input: "4"
output: "16"
unit: "0.000001"
currency: RMB

@ -24,9 +24,6 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials) self._add_custom_parameters(credentials)
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:

@ -1,5 +1,6 @@
- gemini-2.0-flash-exp - gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219 - gemini-2.0-flash-thinking-exp-1219
- gemini-2.0-flash-thinking-exp-01-21
- gemini-1.5-pro - gemini-1.5-pro
- gemini-1.5-pro-latest - gemini-1.5-pro-latest
- gemini-1.5-pro-001 - gemini-1.5-pro-001

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
en_US: Gemini 2.0 Flash Thinking Exp 01-21
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -162,9 +162,9 @@ class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel
@staticmethod @staticmethod
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str): def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
try: try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}' url = f"{HUGGINGFACE_ENDPOINT_API}{credentials['huggingface_namespace']}"
headers = { headers = {
"Authorization": f'Bearer {credentials["huggingfacehub_api_token"]}', "Authorization": f"Bearer {credentials['huggingfacehub_api_token']}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }

@ -34,6 +34,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel): class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = { model_apis = {
"minimax-text-01": MinimaxChatCompletionPro,
"abab7-chat-preview": MinimaxChatCompletionPro, "abab7-chat-preview": MinimaxChatCompletionPro,
"abab6.5t-chat": MinimaxChatCompletionPro, "abab6.5t-chat": MinimaxChatCompletionPro,
"abab6.5s-chat": MinimaxChatCompletionPro, "abab6.5s-chat": MinimaxChatCompletionPro,

@ -0,0 +1,46 @@
model: minimax-text-01
label:
en_US: Minimax-Text-01
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 1000192
parameter_rules:
- name: temperature
use_template: temperature
min: 0.01
max: 1
default: 0.1
- name: top_p
use_template: top_p
min: 0.01
max: 1
default: 0.95
- name: max_tokens
use_template: max_tokens
required: true
default: 2048
min: 1
max: 1000192
- name: mask_sensitive_info
type: boolean
default: true
label:
zh_Hans: 隐私保护
en_US: Moderate
help:
zh_Hans: 对输出中易涉及隐私问题的文本信息进行打码目前包括但不限于邮箱、域名、链接、证件号、家庭住址等默认true即开启打码
en_US: Mask the sensitive info of the generated content, such as email/domain/link/address/phone/id..
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
pricing:
input: '0.001'
output: '0.008'
unit: '0.001'
currency: RMB

@ -44,9 +44,6 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
self._add_custom_parameters(credentials) self._add_custom_parameters(credentials)
self._add_function_call(model, credentials) self._add_function_call(model, credentials)
user = user[:32] if user else None user = user[:32] if user else None
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:

@ -1,5 +1,6 @@
import json import json
import logging import logging
import re
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -621,11 +622,19 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
# o1 compatibility # o1 compatibility
block_as_stream = False
if model.startswith("o1"): if model.startswith("o1"):
if "max_tokens" in model_parameters: if "max_tokens" in model_parameters:
model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
del model_parameters["max_tokens"] del model_parameters["max_tokens"]
if re.match(r"^o1(-\d{4}-\d{2}-\d{2})?$", model):
if stream:
block_as_stream = True
stream = False
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs: if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"] del extra_model_kwargs["stop"]
@ -642,7 +651,45 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if stream: if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
def _handle_chat_block_as_stream_response(
    self,
    block_result: LLMResult,
    prompt_messages: list[PromptMessage],
    stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
    """
    Re-emit a completed blocking chat result as a single-chunk stream.

    :param block_result: blocking LLM result to replay
    :param prompt_messages: prompt messages
    :param stop: optional stop words; applied to plain-text content before it is emitted
    :return: llm response chunk generator yielding exactly one chunk
    """
    message = block_result.message
    content = message.content

    # Stop words can only be enforced on plain text; other content shapes are
    # passed through untouched.
    if stop and isinstance(content, str):
        # Write the truncated text back onto the message so the emitted chunk
        # actually honours the stop words.
        message.content = self.enforce_stop_tokens(content, stop)

    yield LLMResultChunk(
        model=block_result.model,
        prompt_messages=prompt_messages,
        system_fingerprint=block_result.system_fingerprint,
        delta=LLMResultChunkDelta(
            index=0,
            message=message,
            finish_reason="stop",
            usage=block_result.usage,
        ),
    )
def _handle_chat_generate_response( def _handle_chat_generate_response(
self, self,

@ -29,9 +29,6 @@ class SiliconflowLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials) self._add_custom_parameters(credentials)
# {"response_format": "json_object"} need convert to {"response_format": {"type": "json_object"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:

@ -21,7 +21,7 @@ class SparkLLMClient:
domain = api_domain domain = api_domain
model_api_configs = { model_api_configs = {
"spark-lite": {"version": "v1.1", "chat_domain": "general"}, "spark-lite": {"version": "v1.1", "chat_domain": "lite"},
"spark-pro": {"version": "v3.1", "chat_domain": "generalv3"}, "spark-pro": {"version": "v3.1", "chat_domain": "generalv3"},
"spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"}, "spark-pro-128k": {"version": "pro-128k", "chat_domain": "pro-128k"},
"spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"}, "spark-max": {"version": "v3.5", "chat_domain": "generalv3.5"},

@ -257,8 +257,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
for index, response in enumerate(responses): for index, response in enumerate(responses):
if response.status_code not in {200, HTTPStatus.OK}: if response.status_code not in {200, HTTPStatus.OK}:
raise ServiceUnavailableError( raise ServiceUnavailableError(
f"Failed to invoke model {model}, status code: {response.status_code}, " f"Failed to invoke model {model}, status code: {response.status_code}, message: {response.message}"
f"message: {response.message}"
) )
resp_finish_reason = response.output.choices[0].finish_reason resp_finish_reason = response.output.choices[0].finish_reason

@ -146,7 +146,7 @@ class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion": elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value completion_type = LLMMode.COMPLETION.value
else: else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,

@ -41,15 +41,15 @@ class BaiduAccessToken:
resp = response.json() resp = response.json()
if "error" in resp: if "error" in resp:
if resp["error"] == "invalid_client": if resp["error"] == "invalid_client":
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}') raise InvalidAPIKeyError(f"Invalid API key or secret key: {resp['error_description']}")
elif resp["error"] == "unknown_error": elif resp["error"] == "unknown_error":
raise InternalServerError(f'Internal server error: {resp["error_description"]}') raise InternalServerError(f"Internal server error: {resp['error_description']}")
elif resp["error"] == "invalid_request": elif resp["error"] == "invalid_request":
raise BadRequestError(f'Bad request: {resp["error_description"]}') raise BadRequestError(f"Bad request: {resp['error_description']}")
elif resp["error"] == "rate_limit_exceeded": elif resp["error"] == "rate_limit_exceeded":
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}') raise RateLimitReachedError(f"Rate limit reached: {resp['error_description']}")
else: else:
raise Exception(f'Unknown error: {resp["error_description"]}') raise Exception(f"Unknown error: {resp['error_description']}")
return resp["access_token"] return resp["access_token"]

@ -406,7 +406,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
elif credentials["completion_type"] == "completion": elif credentials["completion_type"] == "completion":
completion_type = LLMMode.COMPLETION.value completion_type = LLMMode.COMPLETION.value
else: else:
raise ValueError(f'completion_type {credentials["completion_type"]} is not supported') raise ValueError(f"completion_type {credentials['completion_type']} is not supported")
else: else:
extra_args = XinferenceHelper.get_xinference_extra_parameter( extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials["server_url"], server_url=credentials["server_url"],
@ -472,7 +472,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
api_key = credentials.get("api_key") or "abc" api_key = credentials.get("api_key") or "abc"
client = OpenAI( client = OpenAI(
base_url=f'{credentials["server_url"]}/v1', base_url=f"{credentials['server_url']}/v1",
api_key=api_key, api_key=api_key,
max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES), max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES),
timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT), timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT),

@ -6,6 +6,7 @@ from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum): class TracingProviderEnum(Enum):
LANGFUSE = "langfuse" LANGFUSE = "langfuse"
LANGSMITH = "langsmith" LANGSMITH = "langsmith"
OPIK = "opik"
class BaseTracingConfig(BaseModel): class BaseTracingConfig(BaseModel):
@ -56,5 +57,36 @@ class LangSmithConfig(BaseTracingConfig):
return v return v
class OpikConfig(BaseTracingConfig):
    """
    Tracing configuration for the Opik provider.
    """

    api_key: str | None = None
    project: str | None = None
    workspace: str | None = None
    url: str = "https://www.comet.com/opik/api/"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        # A missing or empty project name falls back to the default project.
        return "Default Project" if v in (None, "") else v

    @field_validator("url")
    @classmethod
    def url_validator(cls, v, info: ValidationInfo):
        # A missing or empty url falls back to the default endpoint.
        endpoint = "https://www.comet.com/opik/api/" if v in (None, "") else v

        if not endpoint.startswith(("https://", "http://")):
            raise ValueError("url must start with https:// or http://")
        if not endpoint.endswith("/api/"):
            raise ValueError("url should ends with /api/")

        return endpoint
OPS_FILE_PATH = "ops_trace/" OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

@ -0,0 +1,469 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast
from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__)
def wrap_dict(key_name, data):
    """Return ``data`` as-is when it is a dict, otherwise wrap it as ``{key_name: data}``."""
    return data if isinstance(data, dict) else {key_name: data}
def wrap_metadata(metadata, **kwargs):
    """Add common metadata to all Traces and Spans.

    :param metadata: base metadata mapping (``None`` is treated as empty); it is
        copied, never modified
    :param kwargs: extra entries to add; they override same-named keys,
        including ``created_from``
    :return: a new dict containing ``metadata``, ``created_from="dify"`` and ``kwargs``
    """
    # Work on a copy so the caller's dict (for example the metadata attached to
    # a trace info object) does not silently accumulate these keys.
    wrapped = dict(metadata or {})
    wrapped["created_from"] = "dify"
    wrapped.update(kwargs)
    return wrapped
def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
    """Convert a Dify timestamp and UUIDv4 into the UUIDv7 that Opik needs.

    Dify identifies most messages and objects with UUIDv4, and the type hints
    of BaseTraceInfo allow both the start time and the id to be null. A missing
    timestamp is replaced with the current time and a missing id with a fresh
    random UUIDv4, so in that case the result does not identify any existing
    object.
    """
    timestamp = datetime.now() if user_datetime is None else user_datetime
    source_uuid = str(uuid.uuid4()) if user_uuid is None else user_uuid
    return uuid4_to_uuid7(timestamp, source_uuid)
class OpikDataTrace(BaseTraceInstance):
def __init__(
    self,
    opik_config: OpikConfig,
):
    """Create the Opik client from ``opik_config`` and keep the project name and file base URL."""
    super().__init__(opik_config)
    client_options = {
        "project_name": opik_config.project,
        "workspace": opik_config.workspace,
        "host": opik_config.url,
        "api_key": opik_config.api_key,
    }
    self.opik_client = Opik(**client_options)
    self.project = opik_config.project
    # Prefix for file links; read from the FILES_URL environment variable,
    # with a local default when it is not set.
    self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
    """Pass ``trace_info`` to every handler whose trace-info type it is an instance of."""
    handlers = (
        (WorkflowTraceInfo, self.workflow_trace),
        (MessageTraceInfo, self.message_trace),
        (ModerationTraceInfo, self.moderation_trace),
        (SuggestedQuestionTraceInfo, self.suggested_question_trace),
        (DatasetRetrievalTraceInfo, self.dataset_retrieval_trace),
        (ToolTraceInfo, self.tool_trace),
        (GenerateNameTraceInfo, self.generate_name_trace),
    )
    # Every entry is tested in order with no early exit, so an object matching
    # several types reaches each matching handler.
    for info_type, handler in handlers:
        if isinstance(trace_info, info_type):
            handler(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
    """Report a workflow run: one trace, an optional root span, and one span per node execution.

    When ``trace_info.message_id`` is set the trace is keyed by the message id
    and a root "workflow" span keyed by the workflow run id is added under it;
    otherwise only a trace keyed by the workflow run id is created. Afterwards
    every ``WorkflowNodeExecution`` row of the run is loaded from the database
    and sent as its own span.

    :param trace_info: workflow run data to report
    """
    dify_trace_id = trace_info.workflow_run_id
    opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
    workflow_metadata = wrap_metadata(
        trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
    )
    root_span_id = None

    if trace_info.message_id:
        # The run belongs to a message: re-key the trace by the message id.
        dify_trace_id = trace_info.message_id
        opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)

        trace_data = {
            "id": opik_trace_id,
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": workflow_metadata,
            "input": wrap_dict("input", trace_info.workflow_run_inputs),
            "output": wrap_dict("output", trace_info.workflow_run_outputs),
            "tags": ["message", "workflow"],
            "project_name": self.project,
        }
        self.add_trace(trace_data)

        # Root span for the workflow itself, keyed by the workflow run id.
        root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
        span_data = {
            "id": root_span_id,
            "parent_span_id": None,
            "trace_id": opik_trace_id,
            "name": TraceTaskName.WORKFLOW_TRACE.value,
            "input": wrap_dict("input", trace_info.workflow_run_inputs),
            "output": wrap_dict("output", trace_info.workflow_run_outputs),
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": workflow_metadata,
            "tags": ["workflow"],
            "project_name": self.project,
        }

        self.add_span(span_data)
    else:
        # No message: the trace is keyed by the workflow run id and no root
        # span is created.
        trace_data = {
            "id": opik_trace_id,
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": workflow_metadata,
            "input": wrap_dict("input", trace_info.workflow_run_inputs),
            "output": wrap_dict("output", trace_info.workflow_run_outputs),
            "tags": ["workflow"],
            "project_name": self.project,
        }
        self.add_trace(trace_data)

    # Fetch the ids of all node executions that belong to this workflow run.
    workflow_nodes_execution_id_records = (
        db.session.query(WorkflowNodeExecution.id)
        .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
        .all()
    )

    for node_execution_id_record in workflow_nodes_execution_id_records:
        # Load only the columns needed for the span, one row per query.
        node_execution = (
            db.session.query(
                WorkflowNodeExecution.id,
                WorkflowNodeExecution.tenant_id,
                WorkflowNodeExecution.app_id,
                WorkflowNodeExecution.title,
                WorkflowNodeExecution.node_type,
                WorkflowNodeExecution.status,
                WorkflowNodeExecution.inputs,
                WorkflowNodeExecution.outputs,
                WorkflowNodeExecution.created_at,
                WorkflowNodeExecution.elapsed_time,
                WorkflowNodeExecution.process_data,
                WorkflowNodeExecution.execution_metadata,
            )
            .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
            .first()
        )

        if not node_execution:
            continue

        node_execution_id = node_execution.id
        tenant_id = node_execution.tenant_id
        app_id = node_execution.app_id
        node_name = node_execution.title
        node_type = node_execution.node_type
        status = node_execution.status
        # LLM nodes report the "prompts" entry of process_data as their input;
        # every other node type reports its stored inputs.
        if node_type == "llm":
            inputs = (
                json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
            )
        else:
            inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
        outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
        created_at = node_execution.created_at or datetime.now()
        elapsed_time = node_execution.elapsed_time
        # NOTE(review): timedelta() raises TypeError if elapsed_time is None —
        # confirm the column can never be null.
        finished_at = created_at + timedelta(seconds=elapsed_time)

        execution_metadata = (
            json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
        )
        metadata = execution_metadata.copy()
        metadata.update(
            {
                "workflow_run_id": trace_info.workflow_run_id,
                "node_execution_id": node_execution_id,
                "tenant_id": tenant_id,
                "app_id": app_id,
                # The value stored under "app_name" is the node's title.
                "app_name": node_name,
                "node_type": node_type,
                "status": status,
            }
        )

        process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}

        provider = None
        model = None
        total_tokens = 0
        completion_tokens = 0
        prompt_tokens = 0

        # Nodes whose process_data has model_mode == "chat" are reported as
        # "llm" spans with provider, model and usage; all others as "tool".
        if process_data and process_data.get("model_mode") == "chat":
            run_type = "llm"
            provider = process_data.get("model_provider", None)
            model = process_data.get("model_name", "")
            metadata.update(
                {
                    "ls_provider": provider,
                    "ls_model_name": model,
                }
            )

            # Usage extraction is best-effort: failures are logged and the
            # token counts keep their zero defaults.
            try:
                if outputs.get("usage"):
                    total_tokens = outputs["usage"].get("total_tokens", 0)
                    prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
                    completion_tokens = outputs["usage"].get("completion_tokens", 0)
            except Exception:
                logger.error("Failed to extract usage", exc_info=True)

        else:
            run_type = "tool"

        # NOTE(review): the parent id is derived here independently of
        # root_span_id. When workflow_app_log_id is set, or when no root span
        # was created above, it may not match any span sent by this method —
        # confirm this is intended.
        parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

        # Fall back to the token total recorded in the execution metadata.
        if not total_tokens:
            total_tokens = execution_metadata.get("total_tokens", 0)

        span_data = {
            "trace_id": opik_trace_id,
            "id": prepare_opik_uuid(created_at, node_execution_id),
            "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
            "name": node_type,
            "type": run_type,
            "start_time": created_at,
            "end_time": finished_at,
            "metadata": wrap_metadata(metadata),
            "input": wrap_dict("input", inputs),
            "output": wrap_dict("output", outputs),
            "tags": ["node_execution"],
            "project_name": self.project,
            "usage": {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
            },
            "model": model,
            "provider": provider,
        }

        self.add_span(span_data)
def message_trace(self, trace_info: MessageTraceInfo):
    """Report a message as one trace plus a single "llm" span carrying token usage.

    Nothing is sent when ``trace_info.message_data`` is missing. The metadata
    attached to ``trace_info`` is enriched with the account id, the file list
    and, when the end user can be found, the end user's session id.
    """
    # Collect the file URLs attached to the message.
    attached_files = cast(list[str], trace_info.file_list) or []
    message_file: Optional[MessageFile] = trace_info.message_file_data
    if message_file is not None:
        if message_file:
            attached_files.append(f"{self.file_base_url}/{message_file.url}")
        else:
            attached_files.append("")

    message = trace_info.message_data
    if message is None:
        return

    metadata = trace_info.metadata
    metadata["user_id"] = message.from_account_id
    metadata["file_list"] = attached_files

    end_user_id = message.from_end_user_id
    if end_user_id:
        end_user: Optional[EndUser] = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
        if end_user is not None:
            metadata["end_user_id"] = end_user.session_id

    trace = self.add_trace(
        {
            "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": trace_info.inputs,
            "output": message.answer,
            "tags": ["message", str(trace_info.conversation_mode)],
            "project_name": self.project,
        }
    )

    # The span is attached to the trace that was just created.
    self.add_span(
        {
            "trace_id": trace.id,
            "name": "llm",
            "type": "llm",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": {"input": trace_info.inputs},
            "output": {"output": message.answer},
            "tags": ["llm", str(trace_info.conversation_mode)],
            "usage": {
                "completion_tokens": trace_info.answer_tokens,
                "prompt_tokens": trace_info.message_tokens,
                "total_tokens": trace_info.total_tokens,
            },
            "project_name": self.project,
        }
    )
def moderation_trace(self, trace_info: ModerationTraceInfo):
    """Send a moderation result to Opik as a "tool" span.

    Skipped when ``trace_info.message_data`` is None. Start and end times
    fall back to the message's ``created_at`` / ``updated_at`` when the
    trace info does not carry them.
    """
    message = trace_info.message_data
    if message is None:
        return

    started_at = trace_info.start_time or message.created_at

    self.add_span(
        {
            "trace_id": prepare_opik_uuid(started_at, trace_info.message_id),
            "name": TraceTaskName.MODERATION_TRACE.value,
            "type": "tool",
            "start_time": started_at,
            "end_time": trace_info.end_time or message.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            "tags": ["moderation"],
        }
    )
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
    """Send the suggested follow-up questions to Opik as a "tool" span.

    Skipped when ``trace_info.message_data`` is None. Missing start/end
    times are taken from the message's ``created_at`` / ``updated_at``.
    """
    if trace_info.message_data is None:
        return

    source_message = trace_info.message_data
    began_at = trace_info.start_time or source_message.created_at

    span_fields = {
        "trace_id": prepare_opik_uuid(began_at, trace_info.message_id),
        "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
        "type": "tool",
        "start_time": began_at,
        "end_time": trace_info.end_time or source_message.updated_at,
        "metadata": wrap_metadata(trace_info.metadata),
        "input": wrap_dict("input", trace_info.inputs),
        "output": wrap_dict("output", trace_info.suggested_question),
        "tags": ["suggested_question"],
    }
    self.add_span(span_fields)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
    """Send retrieved dataset documents to Opik as a "tool" span.

    Skipped when ``trace_info.message_data`` is None. Missing start/end
    times are taken from the message's ``created_at`` / ``updated_at``.
    """
    message = trace_info.message_data
    if message is None:
        return

    span_start = trace_info.start_time or message.created_at

    self.add_span(
        {
            "trace_id": prepare_opik_uuid(span_start, trace_info.message_id),
            "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            "type": "tool",
            "start_time": span_start,
            "end_time": trace_info.end_time or message.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {"documents": trace_info.documents},
            "tags": ["dataset_retrieval"],
        }
    )
def tool_trace(self, trace_info: ToolTraceInfo):
    """Send a tool invocation to Opik as a "tool" span named after the tool.

    The tool name is used both as the span name and as its second tag.
    """
    tool_name = trace_info.tool_name
    started_at = trace_info.start_time

    span_fields = {
        "trace_id": prepare_opik_uuid(started_at, trace_info.message_id),
        "name": tool_name,
        "type": "tool",
        "start_time": started_at,
        "end_time": trace_info.end_time,
        "metadata": wrap_metadata(trace_info.metadata),
        "input": wrap_dict("input", trace_info.tool_inputs),
        "output": wrap_dict("output", trace_info.tool_outputs),
        "tags": ["tool", tool_name],
    }
    self.add_span(span_fields)
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
    """Send a conversation-name generation to Opik as a trace plus one span.

    The trace is created first; the span is then attached to the id of the
    trace object returned by ``add_trace``.
    """
    task_name = TraceTaskName.GENERATE_NAME_TRACE.value
    started_at = trace_info.start_time
    ended_at = trace_info.end_time

    created_trace = self.add_trace(
        {
            "id": prepare_opik_uuid(started_at, trace_info.message_id),
            "name": task_name,
            "start_time": started_at,
            "end_time": ended_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": trace_info.inputs,
            "output": trace_info.outputs,
            "tags": ["generate_name"],
            "project_name": self.project,
        }
    )

    self.add_span(
        {
            "trace_id": created_trace.id,
            "name": task_name,
            "start_time": started_at,
            "end_time": ended_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.outputs),
            "tags": ["generate_name"],
        }
    )
def add_trace(self, opik_trace_data: dict) -> Trace:
    """Create an Opik trace and return it.

    Args:
        opik_trace_data: Keyword arguments forwarded unchanged to
            ``self.opik_client.trace``.

    Returns:
        The object returned by ``self.opik_client.trace``.

    Raises:
        ValueError: If anything inside the try block raises. The original
            exception is chained as ``__cause__``.
    """
    try:
        trace = self.opik_client.trace(**opik_trace_data)
        logger.debug("Opik Trace created successfully")
        return trace
    except Exception as e:
        # Chain explicitly so the underlying failure stays the direct cause.
        raise ValueError(f"Opik Failed to create trace: {str(e)}") from e
def add_span(self, opik_span_data: dict):
    """Create an Opik span.

    Args:
        opik_span_data: Keyword arguments forwarded unchanged to
            ``self.opik_client.span``.

    Raises:
        ValueError: If anything inside the try block raises. The original
            exception is chained as ``__cause__``.
    """
    try:
        self.opik_client.span(**opik_span_data)
        logger.debug("Opik Span created successfully")
    except Exception as e:
        # Chain explicitly so the underlying failure stays the direct cause.
        raise ValueError(f"Opik Failed to create span: {str(e)}") from e
def api_check(self):
    """Verify the Opik credentials by calling ``self.opik_client.auth_check``.

    Returns:
        True when the auth check completes without raising.

    Raises:
        ValueError: If the auth check raises. The failure is logged with its
            traceback and the original exception is chained as ``__cause__``.
    """
    try:
        self.opik_client.auth_check()
        return True
    except Exception as e:
        logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
        # Chain explicitly so the underlying failure stays the direct cause.
        raise ValueError(f"Opik API check failed: {str(e)}") from e
def get_project_url(self):
    """Return the Opik URL for ``self.project``.

    Returns:
        Whatever ``self.opik_client.get_project_url`` returns for the
        configured project name.

    Raises:
        ValueError: If the client call raises. The failure is logged with its
            traceback and the original exception is chained as ``__cause__``.
    """
    try:
        return self.opik_client.get_project_url(project_name=self.project)
    except Exception as e:
        logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
        # Chain explicitly so the underlying failure stays the direct cause.
        raise ValueError(f"Opik get run url failed: {str(e)}") from e

@ -17,6 +17,7 @@ from core.ops.entities.config_entity import (
OPS_FILE_PATH, OPS_FILE_PATH,
LangfuseConfig, LangfuseConfig,
LangSmithConfig, LangSmithConfig,
OpikConfig,
TracingProviderEnum, TracingProviderEnum,
) )
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
@ -32,6 +33,7 @@ from core.ops.entities.trace_entity import (
) )
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data from core.ops.utils import get_message_data
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -52,6 +54,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
"other_keys": ["project", "endpoint"], "other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace, "trace_instance": LangSmithDataTrace,
}, },
TracingProviderEnum.OPIK.value: {
"config_class": OpikConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"],
"trace_instance": OpikDataTrace,
},
} }

@ -22,7 +22,12 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity from core.model_runtime.entities.provider_entities import (
ConfigurateMethod,
CredentialFormSchema,
FormType,
ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers import model_provider_factory
from extensions import ext_hosting_provider from extensions import ext_hosting_provider
from extensions.ext_database import db from extensions.ext_database import db
@ -835,6 +840,13 @@ class ProviderManager:
:return: :return:
""" """
# Get provider model credential secret variables # Get provider model credential secret variables
if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
if provider_entity.provider_credential_schema
else []
)
else:
model_credential_secret_variables = self._extract_secret_variables( model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema if provider_entity.model_credential_schema

@ -31,7 +31,7 @@ class FirecrawlApp:
"markdown": data.get("markdown"), "markdown": data.get("markdown"),
} }
else: else:
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}') raise Exception(f"Failed to scrape URL. Error: {response_data['error']}")
elif response.status_code in {402, 409, 500}: elif response.status_code in {402, 409, 500}:
error_message = response.json().get("error", "Unknown error occurred") error_message = response.json().get("error", "Unknown error occurred")

@ -358,8 +358,7 @@ class NotionExtractor(BaseExtractor):
if not data_source_binding: if not data_source_binding:
raise Exception( raise Exception(
f"No notion data source binding found for tenant {tenant_id} " f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
f"and notion workspace {notion_workspace_id}"
) )
return cast(str, data_source_binding.access_token) return cast(str, data_source_binding.access_token)

@ -127,7 +127,7 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to create task: {response.get("msg")}') raise Exception(f"Failed to create task: {response.get('msg')}")
return response.get("data", {}).get("id") return response.get("data", {}).get("id")
@ -222,7 +222,7 @@ class AIPPTGenerateToolAdapter:
elif model == "wenxin": elif model == "wenxin":
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to generate content: {response.get("msg")}') raise Exception(f"Failed to generate content: {response.get('msg')}")
return response.get("data", "") return response.get("data", "")
@ -254,7 +254,7 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}') raise Exception(f"Failed to generate ppt: {response.get('msg')}")
id = response.get("data", {}).get("id") id = response.get("data", {}).get("id")
cover_url = response.get("data", {}).get("cover_url") cover_url = response.get("data", {}).get("cover_url")
@ -270,7 +270,7 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}') raise Exception(f"Failed to generate ppt: {response.get('msg')}")
export_code = response.get("data") export_code = response.get("data")
if not export_code: if not export_code:
@ -290,7 +290,7 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to generate ppt: {response.get("msg")}') raise Exception(f"Failed to generate ppt: {response.get('msg')}")
if response.get("msg") == "导出中": if response.get("msg") == "导出中":
current_iteration += 1 current_iteration += 1
@ -343,7 +343,7 @@ class AIPPTGenerateToolAdapter:
raise Exception(f"Failed to connect to aippt: {response.text}") raise Exception(f"Failed to connect to aippt: {response.text}")
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
token = response.get("data", {}).get("token") token = response.get("data", {}).get("token")
expire = response.get("data", {}).get("time_expire") expire = response.get("data", {}).get("time_expire")
@ -379,7 +379,7 @@ class AIPPTGenerateToolAdapter:
if cls._style_cache[key]["expire"] < now: if cls._style_cache[key]["expire"] < now:
del cls._style_cache[key] del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}' key = f"{credentials['aippt_access_key']}#@#{user_id}"
if key in cls._style_cache: if key in cls._style_cache:
return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"] return cls._style_cache[key]["colors"], cls._style_cache[key]["styles"]
@ -396,11 +396,11 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
colors = [ colors = [
{ {
"id": f'id-{item.get("id")}', "id": f"id-{item.get('id')}",
"name": item.get("name"), "name": item.get("name"),
"en_name": item.get("en_name", item.get("name")), "en_name": item.get("en_name", item.get("name")),
} }
@ -408,7 +408,7 @@ class AIPPTGenerateToolAdapter:
] ]
styles = [ styles = [
{ {
"id": f'id-{item.get("id")}', "id": f"id-{item.get('id')}",
"name": item.get("title"), "name": item.get("title"),
} }
for item in response.get("data", {}).get("suit_style") or [] for item in response.get("data", {}).get("suit_style") or []
@ -454,7 +454,7 @@ class AIPPTGenerateToolAdapter:
response = response.json() response = response.json()
if response.get("code") != 0: if response.get("code") != 0:
raise Exception(f'Failed to connect to aippt: {response.get("msg")}') raise Exception(f"Failed to connect to aippt: {response.get('msg')}")
if len(response.get("data", {}).get("list") or []) > 0: if len(response.get("data", {}).get("list") or []) > 0:
return response.get("data", {}).get("list")[0].get("id") return response.get("data", {}).get("list")[0].get("id")

@ -229,8 +229,7 @@ class NovaReelTool(BuiltinTool):
if async_mode: if async_mode:
return self.create_text_message( return self.create_text_message(
f"Video generation started.\nInvocation ARN: {invocation_arn}\n" f"Video generation started.\nInvocation ARN: {invocation_arn}\nVideo will be available at: {video_uri}"
f"Video will be available at: {video_uri}"
) )
return self._wait_for_completion(bedrock, s3_client, invocation_arn) return self._wait_for_completion(bedrock, s3_client, invocation_arn)

@ -65,7 +65,7 @@ class BaiduFieldTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result: if "trans_result" in result:
result_text = result["trans_result"][0]["dst"] result_text = result["trans_result"][0]["dst"]
else: else:
result_text = f'{result["error_code"]}: {result["error_msg"]}' result_text = f"{result['error_code']}: {result['error_msg']}"
return self.create_text_message(str(result_text)) return self.create_text_message(str(result_text))
except requests.RequestException as e: except requests.RequestException as e:

@ -52,7 +52,7 @@ class BaiduLanguageTool(BuiltinTool, BaiduTranslateToolBase):
result_text = "" result_text = ""
if result["error_code"] != 0: if result["error_code"] != 0:
result_text = f'{result["error_code"]}: {result["error_msg"]}' result_text = f"{result['error_code']}: {result['error_msg']}"
else: else:
result_text = result["data"]["src"] result_text = result["data"]["src"]
result_text = self.mapping_result(description_language, result_text) result_text = self.mapping_result(description_language, result_text)

@ -58,7 +58,7 @@ class BaiduTranslateTool(BuiltinTool, BaiduTranslateToolBase):
if "trans_result" in result: if "trans_result" in result:
result_text = result["trans_result"][0]["dst"] result_text = result["trans_result"][0]["dst"]
else: else:
result_text = f'{result["error_code"]}: {result["error_msg"]}' result_text = f"{result['error_code']}: {result['error_msg']}"
return self.create_text_message(str(result_text)) return self.create_text_message(str(result_text))
except requests.RequestException as e: except requests.RequestException as e:

@ -30,7 +30,7 @@ class BingSearchTool(BuiltinTool):
headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language} headers = {"Ocp-Apim-Subscription-Key": subscription_key, "Accept-Language": accept_language}
query = quote(query) query = quote(query)
server_url = f'{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={",".join(filters)}' server_url = f"{server_url}?q={query}&mkt={market_code}&count={limit}&responseFilter={','.join(filters)}"
response = get(server_url, headers=headers) response = get(server_url, headers=headers)
if response.status_code != 200: if response.status_code != 200:
@ -47,23 +47,23 @@ class BingSearchTool(BuiltinTool):
results = [] results = []
if search_results: if search_results:
for result in search_results: for result in search_results:
url = f': {result["url"]}' if "url" in result else "" url = f": {result['url']}" if "url" in result else ""
results.append(self.create_text_message(text=f'{result["name"]}{url}')) results.append(self.create_text_message(text=f"{result['name']}{url}"))
if entities: if entities:
for entity in entities: for entity in entities:
url = f': {entity["url"]}' if "url" in entity else "" url = f": {entity['url']}" if "url" in entity else ""
results.append(self.create_text_message(text=f'{entity.get("name", "")}{url}')) results.append(self.create_text_message(text=f"{entity.get('name', '')}{url}"))
if news: if news:
for news_item in news: for news_item in news:
url = f': {news_item["url"]}' if "url" in news_item else "" url = f": {news_item['url']}" if "url" in news_item else ""
results.append(self.create_text_message(text=f'{news_item.get("name", "")}{url}')) results.append(self.create_text_message(text=f"{news_item.get('name', '')}{url}"))
if related_searches: if related_searches:
for related in related_searches: for related in related_searches:
url = f': {related["displayText"]}' if "displayText" in related else "" url = f": {related['displayText']}" if "displayText" in related else ""
results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}')) results.append(self.create_text_message(text=f"{related.get('displayText', '')}{url}"))
return results return results
elif result_type == "json": elif result_type == "json":
@ -106,29 +106,29 @@ class BingSearchTool(BuiltinTool):
text = "" text = ""
if search_results: if search_results:
for i, result in enumerate(search_results): for i, result in enumerate(search_results):
text += f'{i + 1}: {result.get("name", "")} - {result.get("snippet", "")}\n' text += f"{i + 1}: {result.get('name', '')} - {result.get('snippet', '')}\n"
if computation and "expression" in computation and "value" in computation: if computation and "expression" in computation and "value" in computation:
text += "\nComputation:\n" text += "\nComputation:\n"
text += f'{computation["expression"]} = {computation["value"]}\n' text += f"{computation['expression']} = {computation['value']}\n"
if entities: if entities:
text += "\nEntities:\n" text += "\nEntities:\n"
for entity in entities: for entity in entities:
url = f'- {entity["url"]}' if "url" in entity else "" url = f"- {entity['url']}" if "url" in entity else ""
text += f'{entity.get("name", "")}{url}\n' text += f"{entity.get('name', '')}{url}\n"
if news: if news:
text += "\nNews:\n" text += "\nNews:\n"
for news_item in news: for news_item in news:
url = f'- {news_item["url"]}' if "url" in news_item else "" url = f"- {news_item['url']}" if "url" in news_item else ""
text += f'{news_item.get("name", "")}{url}\n' text += f"{news_item.get('name', '')}{url}\n"
if related_searches: if related_searches:
text += "\n\nRelated Searches:\n" text += "\n\nRelated Searches:\n"
for related in related_searches: for related in related_searches:
url = f'- {related["webSearchUrl"]}' if "webSearchUrl" in related else "" url = f"- {related['webSearchUrl']}" if "webSearchUrl" in related else ""
text += f'{related.get("displayText", "")}{url}\n' text += f"{related.get('displayText', '')}{url}\n"
return self.create_text_message(text=self.summary(user_id=user_id, content=text)) return self.create_text_message(text=self.summary(user_id=user_id, content=text))

@ -83,5 +83,5 @@ class DIDApp:
if status["status"] == "done": if status["status"] == "done":
return status return status
elif status["status"] == "error" or status["status"] == "rejected": elif status["status"] == "error" or status["status"] == "rejected":
raise HTTPError(f'Talks {id} failed: {status["status"]} {status.get("error", {}).get("description")}') raise HTTPError(f"Talks {id} failed: {status['status']} {status.get('error', {}).get('description')}")
time.sleep(poll_interval) time.sleep(poll_interval)

@ -20,33 +20,33 @@ class SendEmailToolParameters(BaseModel):
encrypt_method: str encrypt_method: str
def send_mail(parmas: SendEmailToolParameters): def send_mail(params: SendEmailToolParameters):
timeout = 60 timeout = 60
msg = MIMEMultipart("alternative") msg = MIMEMultipart("alternative")
msg["From"] = parmas.email_account msg["From"] = params.email_account
msg["To"] = parmas.sender_to msg["To"] = params.sender_to
msg["Subject"] = parmas.subject msg["Subject"] = params.subject
msg.attach(MIMEText(parmas.email_content, "plain")) msg.attach(MIMEText(params.email_content, "plain"))
msg.attach(MIMEText(parmas.email_content, "html")) msg.attach(MIMEText(params.email_content, "html"))
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
if parmas.encrypt_method.upper() == "SSL": if params.encrypt_method.upper() == "SSL":
try: try:
with smtplib.SMTP_SSL(parmas.smtp_server, parmas.smtp_port, context=ctx, timeout=timeout) as server: with smtplib.SMTP_SSL(params.smtp_server, params.smtp_port, context=ctx, timeout=timeout) as server:
server.login(parmas.email_account, parmas.email_password) server.login(params.email_account, params.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string()) server.sendmail(params.email_account, params.sender_to, msg.as_string())
return True return True
except Exception as e: except Exception as e:
logging.exception("send email failed") logging.exception("send email failed")
return False return False
else: # NONE or TLS else: # NONE or TLS
try: try:
with smtplib.SMTP(parmas.smtp_server, parmas.smtp_port, timeout=timeout) as server: with smtplib.SMTP(params.smtp_server, params.smtp_port, timeout=timeout) as server:
if parmas.encrypt_method.upper() == "TLS": if params.encrypt_method.upper() == "TLS":
server.starttls(context=ctx) server.starttls(context=ctx)
server.login(parmas.email_account, parmas.email_password) server.login(params.email_account, params.email_password)
server.sendmail(parmas.email_account, parmas.sender_to, msg.as_string()) server.sendmail(params.email_account, params.sender_to, msg.as_string())
return True return True
except Exception as e: except Exception as e:
logging.exception("send email failed") logging.exception("send email failed")

@ -74,7 +74,7 @@ class FirecrawlApp:
if response is None: if response is None:
raise HTTPError("Failed to initiate crawl after multiple retries") raise HTTPError("Failed to initiate crawl after multiple retries")
elif response.get("success") == False: elif response.get("success") == False:
raise HTTPError(f'Failed to crawl: {response.get("error")}') raise HTTPError(f"Failed to crawl: {response.get('error')}")
job_id: str = response["id"] job_id: str = response["id"]
if wait: if wait:
return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval) return self._monitor_job_status(job_id=job_id, poll_interval=poll_interval)
@ -100,7 +100,7 @@ class FirecrawlApp:
if status["status"] == "completed": if status["status"] == "completed":
return status return status
elif status["status"] == "failed": elif status["status"] == "failed":
raise HTTPError(f'Job {job_id} failed: {status["error"]}') raise HTTPError(f"Job {job_id} failed: {status['error']}")
time.sleep(poll_interval) time.sleep(poll_interval)

@ -37,8 +37,9 @@ class GaodeRepositoriesTool(BuiltinTool):
CityCode = City_data["districts"][0]["adcode"] CityCode = City_data["districts"][0]["adcode"]
weatherInfo_response = s.request( weatherInfo_response = s.request(
method="GET", method="GET",
url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json" url="{url}/weather/weatherInfo?city={citycode}&extensions=all&key={apikey}&output=json".format(
"".format(url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")), url=api_domain, citycode=CityCode, apikey=self.runtime.credentials.get("api_key")
),
) )
weatherInfo_data = weatherInfo_response.json() weatherInfo_data = weatherInfo_response.json()
if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK": if weatherInfo_response.status_code == 200 and weatherInfo_data.get("info") == "OK":

@ -11,19 +11,21 @@ class GitlabFilesTool(BuiltinTool):
def _invoke( def _invoke(
self, user_id: str, tool_parameters: dict[str, Any] self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
project = tool_parameters.get("project", "")
repository = tool_parameters.get("repository", "") repository = tool_parameters.get("repository", "")
project = tool_parameters.get("project", "")
branch = tool_parameters.get("branch", "") branch = tool_parameters.get("branch", "")
path = tool_parameters.get("path", "") path = tool_parameters.get("path", "")
file_path = tool_parameters.get("file_path", "")
if not project and not repository: if not repository and not project:
return self.create_text_message("Either project or repository is required") return self.create_text_message("Either repository or project is required")
if not branch: if not branch:
return self.create_text_message("Branch is required") return self.create_text_message("Branch is required")
if not path: if not path and not file_path:
return self.create_text_message("Path is required") return self.create_text_message("Either path or file_path is required")
access_token = self.runtime.credentials.get("access_tokens") access_token = self.runtime.credentials.get("access_tokens")
headers = {"PRIVATE-TOKEN": access_token}
site_url = self.runtime.credentials.get("site_url") site_url = self.runtime.credentials.get("site_url")
if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"): if "access_tokens" not in self.runtime.credentials or not self.runtime.credentials.get("access_tokens"):
@ -31,33 +33,45 @@ class GitlabFilesTool(BuiltinTool):
if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"): if "site_url" not in self.runtime.credentials or not self.runtime.credentials.get("site_url"):
site_url = "https://gitlab.com" site_url = "https://gitlab.com"
# Get file content
if repository: if repository:
result = self.fetch_files(site_url, access_token, repository, branch, path, is_repository=True) # URL encode the repository path
identifier = urllib.parse.quote(repository, safe="")
else: else:
result = self.fetch_files(site_url, access_token, project, branch, path, is_repository=False) identifier = self.get_project_id(site_url, access_token, project)
if not identifier:
raise Exception(f"Project '{project}' not found.)")
return [self.create_json_message(item) for item in result] # Get file content
if path:
results = self.fetch_files(site_url, headers, identifier, branch, path)
return [self.create_json_message(item) for item in results]
else:
result = self.fetch_file(site_url, headers, identifier, branch, file_path)
return [self.create_json_message(result)]
@staticmethod
def fetch_file(
site_url: str,
headers: dict[str, str],
identifier: str,
branch: str,
path: str,
) -> dict[str, Any]:
encoded_file_path = urllib.parse.quote(path, safe="")
file_url = f"{site_url}/api/v4/projects/{identifier}/repository/files/{encoded_file_path}/raw?ref={branch}"
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
return {"path": path, "branch": branch, "content": file_content}
def fetch_files( def fetch_files(
self, site_url: str, access_token: str, identifier: str, branch: str, path: str, is_repository: bool self, site_url: str, headers: dict[str, str], identifier: str, branch: str, path: str
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
domain = site_url
headers = {"PRIVATE-TOKEN": access_token}
results = [] results = []
try: try:
if is_repository: tree_url = f"{site_url}/api/v4/projects/{identifier}/repository/tree?path={path}&ref={branch}"
# URL encode the repository path
encoded_identifier = urllib.parse.quote(identifier, safe="")
tree_url = f"{domain}/api/v4/projects/{encoded_identifier}/repository/tree?path={path}&ref={branch}"
else:
# Get project ID from project name
project_id = self.get_project_id(site_url, access_token, identifier)
if not project_id:
return self.create_text_message(f"Project '{identifier}' not found.")
tree_url = f"{domain}/api/v4/projects/{project_id}/repository/tree?path={path}&ref={branch}"
response = requests.get(tree_url, headers=headers) response = requests.get(tree_url, headers=headers)
response.raise_for_status() response.raise_for_status()
items = response.json() items = response.json()
@ -65,26 +79,10 @@ class GitlabFilesTool(BuiltinTool):
for item in items: for item in items:
item_path = item["path"] item_path = item["path"]
if item["type"] == "tree": # It's a directory if item["type"] == "tree": # It's a directory
results.extend( results.extend(self.fetch_files(site_url, headers, identifier, branch, item_path))
self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository)
)
else: # It's a file else: # It's a file
encoded_item_path = urllib.parse.quote(item_path, safe="") result = self.fetch_file(site_url, headers, identifier, branch, item_path)
if is_repository: results.append(result)
file_url = (
f"{domain}/api/v4/projects/{encoded_identifier}/repository/files"
f"/{encoded_item_path}/raw?ref={branch}"
)
else:
file_url = (
f"{domain}/api/v4/projects/{project_id}/repository/files"
f"{encoded_item_path}/raw?ref={branch}"
)
file_response = requests.get(file_url, headers=headers)
file_response.raise_for_status()
file_content = file_response.text
results.append({"path": item_path, "branch": branch, "content": file_content})
except requests.RequestException as e: except requests.RequestException as e:
print(f"Error fetching data from GitLab: {e}") print(f"Error fetching data from GitLab: {e}")

@ -29,7 +29,7 @@ parameters:
zh_Hans: 项目 zh_Hans: 项目
human_description: human_description:
en_US: project en_US: project
zh_Hans: 项目 zh_Hans: 项目(和仓库路径二选一,都填写以仓库路径优先)
llm_description: Project for GitLab llm_description: Project for GitLab
form: llm form: llm
- name: branch - name: branch
@ -45,12 +45,21 @@ parameters:
form: llm form: llm
- name: path - name: path
type: string type: string
required: true
label: label:
en_US: path en_US: path
zh_Hans: 文件路径 zh_Hans: 文件
human_description: human_description:
en_US: path en_US: path
zh_Hans: 文件夹
llm_description: Dir path for GitLab
form: llm
- name: file_path
type: string
label:
en_US: file_path
zh_Hans: 文件路径 zh_Hans: 文件路径
human_description:
en_US: file_path
zh_Hans: 文件路径(和文件夹二选一,都填写以文件夹优先)
llm_description: File path for GitLab llm_description: File path for GitLab
form: llm form: llm

@ -110,7 +110,7 @@ class ListWorksheetRecordsTool(BuiltinTool):
result["rows"].append(self.get_row_field_value(row, schema)) result["rows"].append(self.get_row_field_value(row, schema))
return self.create_text_message(json.dumps(result, ensure_ascii=False)) return self.create_text_message(json.dumps(result, ensure_ascii=False))
else: else:
result_text = f"Found {result['total']} rows in worksheet \"{worksheet_name}\"." result_text = f'Found {result["total"]} rows in worksheet "{worksheet_name}".'
if result["total"] > 0: if result["total"] > 0:
result_text += ( result_text += (
f" The following are {min(limit, result['total'])}" f" The following are {min(limit, result['total'])}"

@ -28,4 +28,4 @@ class BaseStabilityAuthorization:
""" """
This method is responsible for generating the authorization headers. This method is responsible for generating the authorization headers.
""" """
return {"Authorization": f'Bearer {credentials.get("api_key", "")}'} return {"Authorization": f"Bearer {credentials.get('api_key', '')}"}

@ -38,7 +38,7 @@ class VannaProvider(BuiltinToolProviderController):
tool_parameters={ tool_parameters={
"model": "chinook", "model": "chinook",
"db_type": "SQLite", "db_type": "SQLite",
"url": f'{self._get_protocol_and_main_domain(credentials["base_url"])}/Chinook.sqlite', "url": f"{self._get_protocol_and_main_domain(credentials['base_url'])}/Chinook.sqlite",
"query": "What are the top 10 customers by sales?", "query": "What are the top 10 customers by sales?",
}, },
) )

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str: def parse_results(res: dict) -> str:
"""Process response from Serply Job Search.""" """Process response from Serply Job Search."""
jobs = res.get("jobs", []) jobs = res.get("jobs", [])
if not jobs: if not res or "jobs" not in res:
raise ValueError(f"Got error from Serply: {res}") raise ValueError(f"Got error from Serply: {res}")
string = [] string = []

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str: def parse_results(res: dict) -> str:
"""Process response from Serply News Search.""" """Process response from Serply News Search."""
news = res.get("entries", []) news = res.get("entries", [])
if not news: if not res or "entries" not in res:
raise ValueError(f"Got error from Serply: {res}") raise ValueError(f"Got error from Serply: {res}")
string = [] string = []

@ -43,7 +43,7 @@ class SerplyApi:
def parse_results(res: dict) -> str: def parse_results(res: dict) -> str:
"""Process response from Serply News Search.""" """Process response from Serply News Search."""
articles = res.get("articles", []) articles = res.get("articles", [])
if not articles: if not res or "articles" not in res:
raise ValueError(f"Got error from Serply: {res}") raise ValueError(f"Got error from Serply: {res}")
string = [] string = []

@ -42,7 +42,7 @@ class SerplyApi:
def parse_results(res: dict) -> str: def parse_results(res: dict) -> str:
"""Process response from Serply Web Search.""" """Process response from Serply Web Search."""
results = res.get("results", []) results = res.get("results", [])
if not results: if not res or "results" not in res:
raise ValueError(f"Got error from Serply: {res}") raise ValueError(f"Got error from Serply: {res}")
string = [] string = []

@ -84,9 +84,9 @@ class ApiTool(Tool):
if "api_key_header_prefix" in credentials: if "api_key_header_prefix" in credentials:
api_key_header_prefix = credentials["api_key_header_prefix"] api_key_header_prefix = credentials["api_key_header_prefix"]
if api_key_header_prefix == "basic" and credentials["api_key_value"]: if api_key_header_prefix == "basic" and credentials["api_key_value"]:
credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}' credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
elif api_key_header_prefix == "bearer" and credentials["api_key_value"]: elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}' credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
elif api_key_header_prefix == "custom": elif api_key_header_prefix == "custom":
pass pass

@ -29,7 +29,7 @@ class ToolFileMessageTransformer:
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
) )
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}' url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"
result.append( result.append(
ToolInvokeMessage( ToolInvokeMessage(
@ -122,4 +122,4 @@ class ToolFileMessageTransformer:
@classmethod @classmethod
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str: def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
return f'/files/tools/{tool_file_id}{extension or ".bin"}' return f"/files/tools/{tool_file_id}{extension or '.bin'}"

@ -5,6 +5,7 @@ from json import loads as json_loads
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import Optional from typing import Optional
from flask import request
from requests import get from requests import get
from yaml import YAMLError, safe_load # type: ignore from yaml import YAMLError, safe_load # type: ignore
@ -29,6 +30,10 @@ class ApiBasedToolSchemaParser:
raise ToolProviderNotFoundError("No server found in the openapi yaml.") raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"] server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces # list all interfaces
interfaces = [] interfaces = []
@ -144,7 +149,7 @@ class ApiBasedToolSchemaParser:
if not path: if not path:
path = str(uuid.uuid4()) path = str(uuid.uuid4())
interface["operation"]["operationId"] = f'{path}_{interface["method"]}' interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append( bundles.append(
ApiToolBundle( ApiToolBundle(

@ -134,6 +134,10 @@ class ArrayStringSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_STRING value_type: SegmentType = SegmentType.ARRAY_STRING
value: Sequence[str] value: Sequence[str]
@property
def text(self) -> str:
return json.dumps(self.value)
class ArrayNumberSegment(ArraySegment): class ArrayNumberSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_NUMBER value_type: SegmentType = SegmentType.ARRAY_NUMBER

@ -1,6 +1,7 @@
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
from typing import Optional
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@ -48,10 +49,15 @@ class StreamProcessor(ABC):
# we remove the node maybe shortcut the answer node, so comment this code for now # we remove the node maybe shortcut the answer node, so comment this code for now
# there is not effect on the answer node and the workflow, when we have a better solution # there is not effect on the answer node and the workflow, when we have a better solution
# we can open this code. Issues: #11542 #9560 #10638 #10564 # we can open this code. Issues: #11542 #9560 #10638 #10564
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id) # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
if "answer" in ids: # if "answer" in ids:
continue # continue
else: # else:
# reachable_node_ids.extend(ids)
# The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included.
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids) reachable_node_ids.extend(ids)
else: else:
unreachable_first_node_ids.append(edge.target_node_id) unreachable_first_node_ids.append(edge.target_node_id)
@ -59,14 +65,19 @@ class StreamProcessor(ABC):
for node_id in unreachable_first_node_ids: for node_id in unreachable_first_node_ids:
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
node_ids = [] node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []): for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id: if edge.target_node_id == self.graph.root_node_id:
continue continue
# Only follow edges that match the branch_identify or have no run_condition
if edge.run_condition and edge.run_condition.branch_identify:
if not branch_identify or edge.run_condition.branch_identify != branch_identify:
continue
node_ids.append(edge.target_node_id) node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
return node_ids return node_ids
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:

@ -253,9 +253,9 @@ class Executor:
) )
if executor_response.size > threshold_size: if executor_response.size > threshold_size:
raise ResponseSizeError( raise ResponseSizeError(
f'{"File" if executor_response.is_file else "Text"} size is too large,' f"{'File' if executor_response.is_file else 'Text'} size is too large,"
f' max size is {threshold_size / 1024 / 1024:.2f} MB,' f" max size is {threshold_size / 1024 / 1024:.2f} MB,"
f' but current size is {executor_response.readable_size}.' f" but current size is {executor_response.readable_size}."
) )
return executor_response return executor_response
@ -338,7 +338,7 @@ class Executor:
if self.auth.config and self.auth.config.header: if self.auth.config and self.auth.config.header:
authorization_header = self.auth.config.header authorization_header = self.auth.config.header
if k.lower() == authorization_header.lower(): if k.lower() == authorization_header.lower():
raw += f'{k}: {"*" * len(v)}\r\n' raw += f"{k}: {'*' * len(v)}\r\n"
continue continue
raw += f"{k}: {v}\r\n" raw += f"{k}: {v}\r\n"

@ -1,5 +1,4 @@
import logging import logging
import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Any, cast
@ -20,10 +19,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from services.feature_service import FeatureService
from .entities import KnowledgeRetrievalNodeData from .entities import KnowledgeRetrievalNodeData
from .exc import ( from .exc import (
@ -64,23 +61,6 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
) )
# check rate limit
if self.tenant_id:
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge # retrieve knowledge
try: try:
results = self._fetch_dataset_retriever(node_data=self.node_data, query=query) results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)

@ -1,4 +1,5 @@
import json import json
from collections.abc import Sequence
from typing import Any, cast from typing import Any, cast
from core.variables import SegmentType, Variable from core.variables import SegmentType, Variable
@ -31,7 +32,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
inputs = self.node_data.model_dump() inputs = self.node_data.model_dump()
process_data: dict[str, Any] = {} process_data: dict[str, Any] = {}
# NOTE: This node has no outputs # NOTE: This node has no outputs
updated_variables: list[Variable] = [] updated_variable_selectors: list[Sequence[str]] = []
try: try:
for item in self.node_data.items: for item in self.node_data.items:
@ -98,7 +99,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
value=item.value, value=item.value,
) )
variable = variable.model_copy(update={"value": updated_value}) variable = variable.model_copy(update={"value": updated_value})
updated_variables.append(variable) self.graph_runtime_state.variable_pool.add(variable.selector, variable)
updated_variable_selectors.append(variable.selector)
except VariableOperatorNodeError as e: except VariableOperatorNodeError as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
@ -107,9 +109,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
error=str(e), error=str(e),
) )
# The `updated_variable_selectors` is a list contains list[str] which not hashable,
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
# Update variables # Update variables
for variable in updated_variables: for selector in updated_variable_selectors:
self.graph_runtime_state.variable_pool.add(variable.selector, variable) variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value process_data[variable.name] = variable.value
if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID:

@ -26,7 +26,7 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime, tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name, provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type, provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}', identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}",
) )
manager.delete_tool_parameters_cache() manager.delete_tool_parameters_cache()
except: except:

@ -1,6 +1,6 @@
from flask_restful import fields # type: ignore from flask_restful import fields # type: ignore
from libs.helper import TimestampField from libs.helper import AvatarUrlField, TimestampField
simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String} simple_account_fields = {"id": fields.String, "name": fields.String, "email": fields.String}
@ -8,6 +8,7 @@ account_fields = {
"id": fields.String, "id": fields.String,
"name": fields.String, "name": fields.String,
"avatar": fields.String, "avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String, "email": fields.String,
"is_password_set": fields.Boolean, "is_password_set": fields.Boolean,
"interface_language": fields.String, "interface_language": fields.String,
@ -22,6 +23,7 @@ account_with_role_fields = {
"id": fields.String, "id": fields.String,
"name": fields.String, "name": fields.String,
"avatar": fields.String, "avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String, "email": fields.String,
"last_login_at": TimestampField, "last_login_at": TimestampField,
"last_active_at": TimestampField, "last_active_at": TimestampField,

@ -41,6 +41,18 @@ class AppIconUrlField(fields.Raw):
return None return None
class AvatarUrlField(fields.Raw):
def output(self, key, obj):
if obj is None:
return None
from models.account import Account
if isinstance(obj, Account) and obj.avatar is not None:
return file_helpers.get_signed_file_url(obj.avatar)
return None
class TimestampField(fields.Raw): class TimestampField(fields.Raw):
def format(self, value) -> int: def format(self, value) -> int:
return int(value.timestamp()) return int(value.timestamp())

@ -13,6 +13,7 @@ from typing import Any, cast
from sqlalchemy import func from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
from configs import dify_config from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
@ -515,7 +516,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
tenant_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False) dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False) document_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False) position: Mapped[int]
content = db.Column(db.Text, nullable=False) content = db.Column(db.Text, nullable=False)
answer = db.Column(db.Text, nullable=True) answer = db.Column(db.Text, nullable=True)
word_count = db.Column(db.Integer, nullable=False) word_count = db.Column(db.Integer, nullable=False)

1076
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -1,9 +1,10 @@
[project] [project]
name = "dify-api" name = "dify-api"
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
dynamic = [ "dependencies" ]
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core>=2.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry] [tool.poetry]
@ -59,6 +60,7 @@ numpy = "~1.26.4"
oci = "~2.135.1" oci = "~2.135.1"
openai = "~1.52.0" openai = "~1.52.0"
openpyxl = "~3.1.5" openpyxl = "~3.1.5"
opik = "~1.3.4"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] } pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
pandas-stubs = "~2.2.3.241009" pandas-stubs = "~2.2.3.241009"
psycogreen = "~1.0.2" psycogreen = "~1.0.2"
@ -190,4 +192,4 @@ pytest-mock = "~3.14.0"
optional = true optional = true
[tool.poetry.group.lint.dependencies] [tool.poetry.group.lint.dependencies]
dotenv-linter = "~0.5.0" dotenv-linter = "~0.5.0"
ruff = "~0.8.1" ruff = "~0.9.2"

@ -1,7 +1,7 @@
import logging import logging
import uuid import uuid
from enum import StrEnum from enum import StrEnum
from typing import Optional, cast from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@ -139,15 +139,6 @@ class AppDslService:
status=ImportStatus.FAILED, status=ImportStatus.FAILED,
error="Empty content from url", error="Empty content from url",
) )
try:
content = cast(bytes, content).decode("utf-8")
except UnicodeDecodeError as e:
return Import(
id=import_id,
status=ImportStatus.FAILED,
error=f"Error decoding content: {e}",
)
except Exception as e: except Exception as e:
return Import( return Import(
id=import_id, id=import_id,

@ -82,7 +82,7 @@ class AudioService:
from app import app from app import app
from extensions.ext_database import db from extensions.ext_database import db
def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None):
with app.app_context(): with app.app_context():
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app_model.workflow workflow = app_model.workflow
@ -95,6 +95,8 @@ class AudioService:
voice = features_dict["text_to_speech"].get("voice") if voice is None else voice voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
else: else:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
if not text_to_speech_dict.get("enabled"): if not text_to_speech_dict.get("enabled"):

@ -19,14 +19,6 @@ class BillingService:
billing_info = cls._send_request("GET", "/subscription/info", params=params) billing_info = cls._send_request("GET", "/subscription/info", params=params)
return billing_info return billing_info
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
params = {"tenant_id": tenant_id}
knowledge_rate_limit = cls._send_request("GET", "/subscription/knowledge-rate-limit", params=params)
return knowledge_rate_limit.get("limit", 10)
@classmethod @classmethod
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}

@ -4,6 +4,7 @@ import logging
import random import random
import time import time
import uuid import uuid
from collections import Counter
from typing import Any, Optional from typing import Any, Optional
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
@ -221,8 +222,7 @@ class DatasetService:
) )
except LLMBadRequestError: except LLMBadRequestError:
raise ValueError( raise ValueError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
"in the Settings -> Model Provider."
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ValueError(f"The dataset in unavailable, due to: {ex.description}") raise ValueError(f"The dataset in unavailable, due to: {ex.description}")
@ -859,7 +859,7 @@ class DocumentService:
position = DocumentService.get_documents_position(dataset.id) position = DocumentService.get_documents_position(dataset.id)
document_ids = [] document_ids = []
duplicate_document_ids = [] duplicate_document_ids = []
if knowledge_config.data_source.info_list.data_source_type == "upload_file": if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
for file_id in upload_file_list: for file_id in upload_file_list:
file = ( file = (
@ -901,7 +901,7 @@ class DocumentService:
document = DocumentService.build_document( document = DocumentService.build_document(
dataset, dataset,
dataset_process_rule.id, # type: ignore dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form, knowledge_config.doc_form,
knowledge_config.doc_language, knowledge_config.doc_language,
data_source_info, data_source_info,
@ -916,8 +916,8 @@ class DocumentService:
document_ids.append(document.id) document_ids.append(document.id)
documents.append(document) documents.append(document)
position += 1 position += 1
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if not notion_info_list: if not notion_info_list:
raise ValueError("No notion info list found.") raise ValueError("No notion info list found.")
exist_page_ids = [] exist_page_ids = []
@ -956,7 +956,7 @@ class DocumentService:
document = DocumentService.build_document( document = DocumentService.build_document(
dataset, dataset,
dataset_process_rule.id, # type: ignore dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form, knowledge_config.doc_form,
knowledge_config.doc_language, knowledge_config.doc_language,
data_source_info, data_source_info,
@ -976,8 +976,8 @@ class DocumentService:
# delete not selected documents # delete not selected documents
if len(exist_document) > 0: if len(exist_document) > 0:
clean_notion_document_task.delay(list(exist_document.values()), dataset.id) clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if not website_info: if not website_info:
raise ValueError("No website info list found.") raise ValueError("No website info list found.")
urls = website_info.urls urls = website_info.urls
@ -996,7 +996,7 @@ class DocumentService:
document = DocumentService.build_document( document = DocumentService.build_document(
dataset, dataset,
dataset_process_rule.id, # type: ignore dataset_process_rule.id, # type: ignore
knowledge_config.data_source.info_list.data_source_type, knowledge_config.data_source.info_list.data_source_type, # type: ignore
knowledge_config.doc_form, knowledge_config.doc_form,
knowledge_config.doc_language, knowledge_config.doc_language,
data_source_info, data_source_info,
@ -1195,20 +1195,20 @@ class DocumentService:
if features.billing.enabled: if features.billing.enabled:
count = 0 count = 0
if knowledge_config.data_source.info_list.data_source_type == "upload_file": if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
upload_file_list = ( upload_file_list = (
knowledge_config.data_source.info_list.file_info_list.file_ids knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
if knowledge_config.data_source.info_list.file_info_list if knowledge_config.data_source.info_list.file_info_list # type: ignore
else [] else []
) )
count = len(upload_file_list) count = len(upload_file_list)
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
notion_info_list = knowledge_config.data_source.info_list.notion_info_list notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
if notion_info_list: if notion_info_list:
for notion_info in notion_info_list: for notion_info in notion_info_list:
count = count + len(notion_info.pages) count = count + len(notion_info.pages)
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
website_info = knowledge_config.data_source.info_list.website_info_list website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
if website_info: if website_info:
count = len(website_info.urls) count = len(website_info.urls)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
@ -1239,7 +1239,7 @@ class DocumentService:
dataset = Dataset( dataset = Dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
name="", name="",
data_source_type=knowledge_config.data_source.info_list.data_source_type, data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore
indexing_technique=knowledge_config.indexing_technique, indexing_technique=knowledge_config.indexing_technique,
created_by=account.id, created_by=account.id,
embedding_model=knowledge_config.embedding_model, embedding_model=knowledge_config.embedding_model,
@ -1611,8 +1611,11 @@ class SegmentService:
segment.answer = args.answer segment.answer = args.answer
segment.word_count += len(args.answer) if args.answer else 0 segment.word_count += len(args.answer) if args.answer else 0
word_count_change = segment.word_count - word_count_change word_count_change = segment.word_count - word_count_change
keyword_changed = False
if args.keywords: if args.keywords:
if Counter(segment.keywords) != Counter(args.keywords):
segment.keywords = args.keywords segment.keywords = args.keywords
keyword_changed = True
segment.enabled = True segment.enabled = True
segment.disabled_at = None segment.disabled_at = None
segment.disabled_by = None segment.disabled_by = None
@ -1623,13 +1626,6 @@ class SegmentService:
document.word_count = max(0, document.word_count + word_count_change) document.word_count = max(0, document.word_count + word_count_change)
db.session.add(document) db.session.add(document)
# update segment index task # update segment index task
if args.enabled:
VectorService.create_segments_vector(
[args.keywords] if args.keywords else None,
[segment],
dataset,
document.doc_form,
)
if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
# regenerate child chunks # regenerate child chunks
# get embedding model instance # get embedding model instance
@ -1662,6 +1658,14 @@ class SegmentService:
VectorService.generate_child_chunks( VectorService.generate_child_chunks(
segment, document, dataset, embedding_model_instance, processing_rule, True segment, document, dataset, embedding_model_instance, processing_rule, True
) )
elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX):
if args.enabled or keyword_changed:
VectorService.create_segments_vector(
[args.keywords] if args.keywords else None,
[segment],
dataset,
document.doc_form,
)
else: else:
segment_hash = helper.generate_text_hash(content) segment_hash = helper.generate_text_hash(content)
tokens = 0 tokens = 0

@ -97,7 +97,7 @@ class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None original_document_id: Optional[str] = None
duplicate: bool = True duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"] indexing_technique: Literal["high_quality", "economy"]
data_source: DataSource data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model" doc_form: str = "text_model"

@ -155,7 +155,7 @@ class ExternalDatasetService:
if custom_parameters: if custom_parameters:
for parameter in custom_parameters: for parameter in custom_parameters:
if parameter.get("required", False) and not process_parameter.get(parameter.get("name")): if parameter.get("required", False) and not process_parameter.get(parameter.get("name")):
raise ValueError(f'{parameter.get("name")} is required') raise ValueError(f"{parameter.get('name')} is required")
@staticmethod @staticmethod
def process_external_api( def process_external_api(

@ -41,7 +41,6 @@ class FeatureModel(BaseModel):
members: LimitationModel = LimitationModel(size=0, limit=1) members: LimitationModel = LimitationModel(size=0, limit=1)
apps: LimitationModel = LimitationModel(size=0, limit=10) apps: LimitationModel = LimitationModel(size=0, limit=10)
vector_space: LimitationModel = LimitationModel(size=0, limit=5) vector_space: LimitationModel = LimitationModel(size=0, limit=5)
knowledge_rate_limit: int = 10
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
docs_processing: str = "standard" docs_processing: str = "standard"
@ -53,11 +52,6 @@ class FeatureModel(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class KnowledgeRateLimitModel(BaseModel):
enabled: bool = False
limit: int = 10
class SystemFeatureModel(BaseModel): class SystemFeatureModel(BaseModel):
sso_enforced_for_signin: bool = False sso_enforced_for_signin: bool = False
sso_enforced_for_signin_protocol: str = "" sso_enforced_for_signin_protocol: str = ""
@ -85,14 +79,6 @@ class FeatureService:
return features return features
@classmethod
def get_knowledge_rate_limit(cls, tenant_id: str):
knowledge_rate_limit = KnowledgeRateLimitModel()
if dify_config.BILLING_ENABLED and tenant_id:
knowledge_rate_limit.enabled = True
knowledge_rate_limit.limit = BillingService.get_knowledge_rate_limit(tenant_id)
return knowledge_rate_limit
@classmethod @classmethod
def get_system_features(cls) -> SystemFeatureModel: def get_system_features(cls) -> SystemFeatureModel:
system_features = SystemFeatureModel() system_features = SystemFeatureModel()
@ -158,9 +144,6 @@ class FeatureService:
if "model_load_balancing_enabled" in billing_info: if "model_load_balancing_enabled" in billing_info:
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
if "knowledge_rate_limit" in billing_info:
features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
@classmethod @classmethod
def _fulfill_params_from_enterprise(cls, features): def _fulfill_params_from_enterprise(cls, features):
enterprise_info = EnterpriseService.get_info() enterprise_info = EnterpriseService.get_info()

@ -59,6 +59,15 @@ class OpsService:
except Exception: except Exception:
new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"}) new_decrypt_tracing_config.update({"project_url": "https://smith.langchain.com/"})
if tracing_provider == "opik" and (
"project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
):
try:
project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
trace_config_data.tracing_config = new_decrypt_tracing_config trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict() return trace_config_data.to_dict()
@ -92,7 +101,7 @@ class OpsService:
if tracing_provider == "langfuse": if tracing_provider == "langfuse":
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key) project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
elif tracing_provider == "langsmith": elif tracing_provider in ("langsmith", "opik"):
project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider) project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
else: else:
project_url = None project_url = None

@ -5,7 +5,8 @@ import uuid
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
@ -18,7 +19,12 @@ from services.vector_service import VectorService
@shared_task(queue="dataset") @shared_task(queue="dataset")
def batch_create_segment_to_index_task( def batch_create_segment_to_index_task(
job_id: str, content: list, dataset_id: str, document_id: str, tenant_id: str, user_id: str job_id: str,
content: list,
dataset_id: str,
document_id: str,
tenant_id: str,
user_id: str,
): ):
""" """
Async batch create segment to index Async batch create segment to index
@ -37,15 +43,20 @@ def batch_create_segment_to_index_task(
indexing_cache_key = "segment_batch_import_{}".format(job_id) indexing_cache_key = "segment_batch_import_{}".format(job_id)
try: try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() with Session(db.engine) as session:
dataset = session.get(Dataset, dataset_id)
if not dataset: if not dataset:
raise ValueError("Dataset not exist.") raise ValueError("Dataset not exist.")
dataset_document = db.session.query(Document).filter(Document.id == document_id).first() dataset_document = session.get(Document, document_id)
if not dataset_document: if not dataset_document:
raise ValueError("Document not exist.") raise ValueError("Document not exist.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.") raise ValueError("Document is not available.")
document_segments = [] document_segments = []
embedding_model = None embedding_model = None
@ -58,25 +69,24 @@ def batch_create_segment_to_index_task(
model=dataset.embedding_model, model=dataset.embedding_model,
) )
word_count_change = 0 word_count_change = 0
segments_to_insert: list[str] = [] # Explicitly type hint the list as List[str] segments_to_insert: list[str] = []
max_position_stmt = select(func.max(DocumentSegment.position)).where(
DocumentSegment.document_id == dataset_document.id
)
max_position = session.scalar(max_position_stmt) or 1
for segment in content: for segment in content:
content_str = segment["content"] content_str = segment["content"]
doc_id = str(uuid.uuid4()) doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content_str) segment_hash = helper.generate_text_hash(content_str)
# calc embedding use tokens # calc embedding use tokens
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0 tokens = embedding_model.get_text_embedding_num_tokens(texts=[content_str]) if embedding_model else 0
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment( segment_document = DocumentSegment(
tenant_id=tenant_id, tenant_id=tenant_id,
dataset_id=dataset_id, dataset_id=dataset_id,
document_id=document_id, document_id=document_id,
index_node_id=doc_id, index_node_id=doc_id,
index_node_hash=segment_hash, index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1, position=max_position,
content=content_str, content=content_str,
word_count=len(content_str), word_count=len(content_str),
tokens=tokens, tokens=tokens,
@ -85,23 +95,28 @@ def batch_create_segment_to_index_task(
status="completed", status="completed",
completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
) )
max_position += 1
if dataset_document.doc_form == "qa_model": if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"] segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"]) segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count word_count_change += segment_document.word_count
db.session.add(segment_document) session.add(segment_document)
document_segments.append(segment_document) document_segments.append(segment_document)
segments_to_insert.append(str(segment)) # Cast to string if needed segments_to_insert.append(str(segment)) # Cast to string if needed
# update document word count # update document word count
dataset_document.word_count += word_count_change dataset_document.word_count += word_count_change
db.session.add(dataset_document) session.add(dataset_document)
# add index to db # add index to db
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit() session.commit()
redis_client.setex(indexing_cache_key, 600, "completed") redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style("Segment batch created job: {} latency: {}".format(job_id, end_at - start_at), fg="green") click.style(
"Segment batch created job: {} latency: {}".format(job_id, end_at - start_at),
fg="green",
)
) )
except Exception as e: except Exception as e:
logging.exception("Segments batch created index failed") logging.exception("Segments batch created index failed")

@ -44,6 +44,6 @@ def test_duplicated_dependency_crossing_groups() -> None:
dependency_names = list(dependencies.keys()) dependency_names = list(dependencies.keys())
all_dependency_names.extend(dependency_names) all_dependency_names.extend(dependency_names)
expected_all_dependency_names = set(all_dependency_names) expected_all_dependency_names = set(all_dependency_names)
assert sorted(expected_all_dependency_names) == sorted( assert sorted(expected_all_dependency_names) == sorted(all_dependency_names), (
all_dependency_names "Duplicated dependencies crossing groups are found"
), "Duplicated dependencies crossing groups are found" )

@ -89,9 +89,9 @@ class TestOpenSearchVector:
print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits") print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")
assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}" assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
assert ( assert hits_by_vector[0].metadata["document_id"] == self.example_doc_id, (
hits_by_vector[0].metadata["document_id"] == self.example_doc_id f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}" )
def test_get_ids_by_metadata_field(self): def test_get_ids_by_metadata_field(self):
mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}} mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}

@ -438,9 +438,9 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Verify the result # Verify the result
assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}" assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
assert ( assert prompt_messages == scenario.expected_messages, (
prompt_messages == scenario.expected_messages f"Message content mismatch in scenario: {scenario.description}"
), f"Message content mismatch in scenario: {scenario.description}" )
def test_handle_list_messages_basic(llm_node): def test_handle_list_messages_basic(llm_node):

@ -401,8 +401,7 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
prompt_template = PromptTemplateEntity( prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED, prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n" prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ",
"Human: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
), ),
) )

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save