Merge branch 'main' of github.com:parambharat/dify into tracing-weave

# Conflicts:
#	api/pyproject.toml
#	web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx
pull/14262/head
Bharat Ramanathan 1 year ago
commit 0d160544ae

@ -1,13 +1,13 @@
#!/bin/bash #!/bin/bash
npm add -g pnpm@9.12.2 npm add -g pnpm@10.8.0
cd web && pnpm install cd web && pnpm install
pipx install poetry pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc source /home/vscode/.bashrc

@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
cd api && poetry install cd api && uv sync

@ -1,36 +0,0 @@
name: Setup Poetry and Python
inputs:
python-version:
description: Python version to use and the Poetry installed with
required: true
default: '3.11'
poetry-version:
description: Poetry version to set up
required: true
default: '2.0.1'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true
default: ''
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: pip
- name: Install Poetry
shell: bash
run: pip install poetry==${{ inputs.poetry-version }}
- name: Restore Poetry cache
if: ${{ inputs.poetry-lockfile != '' }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: poetry
cache-dependency-path: ${{ inputs.poetry-lockfile }}

@ -0,0 +1,34 @@
name: Setup UV and Python
inputs:
python-version:
description: Python version to use and the UV installed with
required: true
default: '3.12'
uv-version:
description: UV version to set up
required: true
default: '0.6.14'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
default: ''
enable-cache:
required: true
default: true
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}
enable-cache: ${{ inputs.enable-cache }}
cache-dependency-glob: ${{ inputs.uv-lockfile }}

@ -17,6 +17,9 @@ jobs:
test: test:
name: API Tests name: API Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
shell: bash
strategy: strategy:
matrix: matrix:
python-version: python-version:
@ -27,35 +30,44 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Check dependencies in pyproject.toml
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests - name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
- name: Run mypy - name: MyPy Cache
run: | uses: actions/cache@v4
poetry run -C api python -m mypy --install-types --non-interactive . with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -75,4 +87,4 @@ jobs:
ssrf_proxy ssrf_proxy
- name: Run Workflow - name: Run Workflow
run: poetry run -P api bash dev/pytest/pytest_workflow.sh run: uv run --project api bash dev/pytest/pytest_workflow.sh

@ -24,13 +24,13 @@ jobs:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Install dependencies - name: Install dependencies
run: poetry install -C api run: uv sync --project api
- name: Prepare middleware env - name: Prepare middleware env
run: | run: |
@ -54,6 +54,4 @@ jobs:
- name: Run DB Migration - name: Run DB Migration
env: env:
DEBUG: true DEBUG: true
run: | run: uv run --directory api flask upgrade-db
cd api
poetry run python -m flask upgrade-db

@ -42,6 +42,7 @@ jobs:
with: with:
push: false push: false
context: "{{defaultContext}}:${{ matrix.context }}" context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max

@ -18,7 +18,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -29,24 +28,27 @@ jobs:
api/** api/**
.github/workflows/style.yml .github/workflows/style.yml
- name: Setup Poetry and Python - name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with:
uv-lockfile: api/uv.lock
enable-cache: false
- name: Install dependencies - name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint run: uv sync --project api --dev
- name: Ruff check - name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: | run: |
poetry run -C api ruff --version uv run --directory api ruff --version
poetry run -C api ruff check ./ uv run --directory api ruff check ./
poetry run -C api ruff format --check ./ uv run --directory api ruff format --check ./
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints - name: Lint hints
if: failure() if: failure()
@ -63,7 +65,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -82,7 +83,7 @@ jobs:
uses: actions/setup-node@v4 uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json
@ -102,7 +103,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -133,7 +133,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -153,6 +152,7 @@ jobs:
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning
DEFAULT_BRANCH: main DEFAULT_BRANCH: main
FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true IGNORE_GITIGNORED_FILES: true

@ -18,7 +18,7 @@ jobs:
strategy: strategy:
matrix: matrix:
node-version: [16, 18, 20] node-version: [16, 18, 20, 22]
defaults: defaults:
run: run:
@ -27,7 +27,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }} - name: Use Node.js ${{ matrix.node-version }}

@ -33,7 +33,7 @@ jobs:
- name: Set up Node.js - name: Set up Node.js
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2 uses: actions/setup-node@v4
with: with:
node-version: 'lts/*' node-version: 'lts/*'

@ -8,7 +8,7 @@ on:
- api/core/rag/datasource/** - api/core/rag/datasource/**
- docker/** - docker/**
- .github/workflows/vdb-tests.yml - .github/workflows/vdb-tests.yml
- api/poetry.lock - api/uv.lock
- api/pyproject.toml - api/pyproject.toml
concurrency: concurrency:
@ -29,22 +29,19 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -80,7 +77,7 @@ jobs:
elasticsearch elasticsearch
- name: Check TiDB Ready - name: Check TiDB Ready
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

@ -23,7 +23,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -31,7 +30,9 @@ jobs:
uses: tj-actions/changed-files@v45 uses: tj-actions/changed-files@v45
with: with:
files: web/** files: web/**
- name: Install pnpm - name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4 uses: pnpm/action-setup@v4
with: with:
version: 10 version: 10
@ -41,7 +42,7 @@ jobs:
uses: actions/setup-node@v4 uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json

1
.gitignore vendored

@ -46,6 +46,7 @@ htmlcov/
.cache .cache
nosetests.xml nosetests.xml
coverage.xml coverage.xml
coverage.json
*.cover *.cover
*.py,cover *.py,cover
.hypothesis/ .hypothesis/

@ -254,8 +254,6 @@ docker compose up -d
- [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。
- [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 - [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。
- [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。 - [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。
- [微信]() 👉:扫描下方二维码,添加微信好友,备注 Dify我们将邀请您加入 Dify 社区。
<img src="./images/wechat.png" alt="wechat" width="100"/>
## 安全问题 ## 安全问题

@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN= MILVUS_TOKEN=
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
MILVUS_ANALYZER_PARAMS=
# MyScale configuration # MyScale configuration
MYSCALE_HOST=127.0.0.1 MYSCALE_HOST=127.0.0.1
@ -189,6 +190,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2 TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false
# ElasticSearch configuration # ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1 ELASTICSEARCH_HOST=127.0.0.1
@ -325,6 +327,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
MULTIMODAL_SEND_FORMAT=base64 MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp # Mail configuration, support: resend, smtp
MAIL_TYPE= MAIL_TYPE=
@ -421,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800 MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
@ -461,3 +470,16 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
MAX_SUBMIT_COUNT=100 MAX_SUBMIT_COUNT=100
# Lockout duration in seconds # Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400 LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
OTEL_MAX_QUEUE_SIZE=2048
OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
OTEL_METRIC_EXPORT_TIMEOUT=30000

@ -3,20 +3,11 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api
# Install Poetry # Install uv
ENV POETRY_VERSION=2.0.1 ENV UV_VERSION=0.6.14
# if you located in China, you can use aliyun mirror to speed up RUN pip install --no-cache-dir uv==${UV_VERSION}
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages FROM base AS packages
@ -27,8 +18,8 @@ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
# Install Python dependencies # Install Python dependencies
COPY pyproject.toml poetry.lock ./ COPY pyproject.toml uv.lock ./
RUN poetry install --sync --no-cache --no-root RUN uv sync --locked
# production stage # production stage
FROM base AS production FROM base AS production

@ -3,7 +3,10 @@
## Usage ## Usage
> [!IMPORTANT] > [!IMPORTANT]
> In the v0.6.12 release, we deprecated `pip` as the package management tool for Dify API Backend service and replaced it with `poetry`. >
> In the v1.3.0 release, `poetry` has been replaced with
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
1. Start the docker-compose stack 1. Start the docker-compose stack
@ -37,19 +40,19 @@
4. Create environment. 4. Create environment.
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand] Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, you need to add the uv package manager, if you don't have it already.
```bash ```bash
poetry self add poetry-plugin-shell pip install uv
# Or on macOS
brew install uv
``` ```
Then, You can execute `poetry shell` to activate the environment.
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.12 uv sync --dev
poetry install
``` ```
6. Run migrate 6. Run migrate
@ -57,21 +60,21 @@
Before the first launch, migrate the database to the latest version. Before the first launch, migrate the database to the latest version.
```bash ```bash
poetry run python -m flask db upgrade uv run flask db upgrade
``` ```
7. Start backend 7. Start backend
```bash ```bash
poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug uv run flask run --host 0.0.0.0 --port=5001 --debug
``` ```
8. Start Dify [web](../web) service. 8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`... 9. Setup your application by visiting `http://localhost:3000`.
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
``` ```
## Testing ## Testing
@ -79,11 +82,11 @@
1. Install dependencies for both the backend and the test environment 1. Install dependencies for both the backend and the test environment
```bash ```bash
poetry install -C api --with dev uv sync --dev
``` ```
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash ```bash
poetry run -P api bash dev/pytest/pytest_all_tests.sh uv run -P api bash dev/pytest/pytest_all_tests.sh
``` ```

@ -51,8 +51,10 @@ def initialize_extensions(app: DifyApp):
ext_login, ext_login,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
ext_repositories,
ext_sentry, ext_sentry,
ext_set_secretkey, ext_set_secretkey,
ext_storage, ext_storage,
@ -73,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate, ext_migrate,
ext_redis, ext_redis,
ext_storage, ext_storage,
ext_repositories,
ext_celery, ext_celery,
ext_login, ext_login,
ext_mail, ext_mail,
@ -81,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix, ext_proxy_fix,
ext_blueprints, ext_blueprints,
ext_commands, ext_commands,
ext_otel,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

@ -9,6 +9,7 @@ from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig from .extra import ExtraServiceConfig
from .feature import FeatureConfig from .feature import FeatureConfig
from .middleware import MiddlewareConfig from .middleware import MiddlewareConfig
from .observability import ObservabilityConfig
from .packaging import PackagingInfo from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource from .remote_settings_sources.apollo import ApolloSettingsSource
@ -59,6 +60,8 @@ class DifyConfig(
MiddlewareConfig, MiddlewareConfig,
# Extra service configs # Extra service configs
ExtraServiceConfig, ExtraServiceConfig,
# Observability configs
ObservabilityConfig,
# Remote source configs # Remote source configs
RemoteSettingsSourceConfig, RemoteSettingsSourceConfig,
# Enterprise feature configs # Enterprise feature configs

@ -12,7 +12,7 @@ from pydantic import (
) )
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings): class SecurityConfig(BaseSettings):
@ -442,7 +442,7 @@ class LoggingConfig(BaseSettings):
class ModelLoadBalanceConfig(BaseSettings): class ModelLoadBalanceConfig(BaseSettings):
""" """
Configuration for model load balancing Configuration for model load balancing and token counting
""" """
MODEL_LB_ENABLED: bool = Field( MODEL_LB_ENABLED: bool = Field(
@ -450,6 +450,11 @@ class ModelLoadBalanceConfig(BaseSettings):
default=False, default=False,
) )
PLUGIN_BASED_TOKEN_COUNTING_ENABLED: bool = Field(
description="Enable or disable plugin based token counting. If disabled, token counting will return 0.",
default=False,
)
class BillingConfig(BaseSettings): class BillingConfig(BaseSettings):
""" """
@ -514,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100, default=100,
) )
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """
@ -848,6 +858,11 @@ class AccountConfig(BaseSettings):
default=5, default=5,
) )
EDUCATION_ENABLED: bool = Field(
description="whether to enable education identity",
default=False,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order

@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
"older versions", "older versions",
default=True, default=True,
) )
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

@ -48,3 +48,8 @@ class TencentVectorDBConfig(BaseSettings):
description="Name of the specific Tencent Vector Database to connect to", description="Name of the specific Tencent Vector Database to connect to",
default=None, default=None,
) )
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features",
default=False,
)

@ -0,0 +1,9 @@
from configs.observability.otel.otel_config import OTelConfig
class ObservabilityConfig(OTelConfig):
"""
Observability configuration settings
"""
pass

@ -0,0 +1,44 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OTelConfig(BaseSettings):
"""
OpenTelemetry configuration settings
"""
ENABLE_OTEL: bool = Field(
description="Whether to enable OpenTelemetry",
default=False,
)
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
)
OTLP_API_KEY: str = Field(
description="OTLP API key",
default="",
)
OTEL_EXPORTER_TYPE: str = Field(
description="OTEL exporter type",
default="otlp",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
default=5000, description="Batch export schedule delay in milliseconds"
)
OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor")
OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size")
OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds")
OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, description="Batch export timeout in milliseconds")
OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds")

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="1.1.3", default="1.2.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -270,7 +270,7 @@ class ApolloClient:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10分钟 time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace): def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip) url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)

@ -3,6 +3,8 @@ from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000" UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])

@ -4,8 +4,6 @@ import platform
import re import re
import urllib.parse import urllib.parse
import warnings import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4 from uuid import uuid4
import httpx import httpx
@ -29,8 +27,6 @@ except ImportError:
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel): class FileInfo(BaseModel):
filename: str filename: str
@ -87,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype, mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)), size=int(response.headers.get("Content-Length", -1)),
) )
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -89,7 +89,7 @@ class AnnotationReplyActionStatusApi(Resource):
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
@ -226,7 +226,7 @@ class AnnotationBatchImportStatusApi(Resource):
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
if job_status == "error": if job_status == "error":

@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -23,6 +24,7 @@ class AppImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: if not current_user.is_editor:

@ -1,5 +1,4 @@
from datetime import datetime from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
if "code" in request.args: if "code" in request.args:
code = request.args.get("code") code = request.args.get("code", "")
if not code:
return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:

@ -99,44 +99,57 @@ class ForgotPasswordResetApi(Resource):
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
new_password = args["new_password"] # Validate passwords match
password_confirm = args["password_confirm"] if args["new_password"] != args["password_confirm"]:
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError() raise PasswordMismatchError()
token = args["token"] # Validate token and get reset data
reset_data = AccountService.get_reset_password_data(token) reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
if reset_data is None:
raise InvalidTokenError() raise InvalidTokenError()
AccountService.revoke_reset_password_token(token) # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode() password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(new_password, salt) email = reset_data.get("email", "")
base64_password_hashed = base64.b64encode(password_hashed).decode()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account: if account:
account.password = base64_password_hashed self._update_existing_account(account, password_hashed, salt, session)
account.password_salt = base64_salt else:
db.session.commit() self._create_new_account(email, args["password_confirm"])
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace: return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
else:
def _create_new_account(self, email, password):
# Create new account if allowed
try: try:
account = AccountService.create_account_and_tenant( AccountService.create_account_and_tenant(
email=reset_data.get("email", ""), email=email,
name=reset_data.get("email", ""), name=email,
password=password_confirm, password=password,
interface_language=languages[0], interface_language=languages[0],
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
@ -144,8 +157,6 @@ class ForgotPasswordResetApi(Resource):
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
return {"result": "success"}
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")

@ -641,7 +641,6 @@ class DatasetRetrievalSettingApi(Resource):
VectorType.RELYT VectorType.RELYT
| VectorType.TIDB_VECTOR | VectorType.TIDB_VECTOR
| VectorType.CHROMA | VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS | VectorType.PGVECTO_RS
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
@ -665,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.OPENGAUSS | VectorType.OPENGAUSS
| VectorType.OCEANBASE | VectorType.OCEANBASE
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
@ -688,7 +688,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.RELYT | VectorType.RELYT
| VectorType.TIDB_VECTOR | VectorType.TIDB_VECTOR
| VectorType.CHROMA | VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS | VectorType.PGVECTO_RS
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
@ -710,6 +709,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.OPENGAUSS | VectorType.OPENGAUSS
| VectorType.OCEANBASE | VectorType.OCEANBASE
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

@ -398,7 +398,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
indexing_cache_key = "segment_batch_import_{}".format(job_id) indexing_cache_key = "segment_batch_import_{}".format(job_id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200

@ -21,12 +21,6 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class ExternalApiTemplateListApi(Resource): class ExternalApiTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required

@ -14,18 +14,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required

@ -14,7 +14,12 @@ class WebsiteCrawlApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json" "provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
) )
parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
@ -34,7 +39,9 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args") parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
# get crawl status # get crawl status
try: try:

@ -103,6 +103,18 @@ class AccountInFreezeError(BaseHTTPException):
) )
class EducationVerifyLimitError(BaseHTTPException):
error_code = "education_verify_limit"
description = "Rate limit exceeded"
code = 429
class EducationActivateLimitError(BaseHTTPException):
error_code = "education_activate_limit"
description = "Rate limit exceeded"
code = 429
class CompilanceRateLimitError(BaseHTTPException): class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit" error_code = "compilance_rate_limit"
description = "Rate limit exceeded for downloading compliance report." description = "Rate limit exceeded for downloading compliance report."

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):

@ -15,7 +15,13 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError, InvalidInvitationCodeError,
RepeatPasswordNotMatchError, RepeatPasswordNotMatchError,
) )
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone from libs.helper import TimestampField, timezone
@ -280,8 +286,6 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("feedback", type=str, required=True, location="json") parser.add_argument("feedback", type=str, required=True, location="json")
@ -292,6 +296,79 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email)
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("institution", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.is_active(account.id)
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("page", type=int, required=False, location="args", default=0)
parser.add_argument("limit", type=int, required=False, location="args", default=20)
args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
# Register API resources # Register API resources
api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile") api.add_resource(AccountProfileApi, "/account/profile")
@ -305,5 +382,8 @@ api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete") api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

@ -49,6 +49,23 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins}) return jsonable_encoder({"plugins": plugins})
class PluginListLatestVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -232,11 +249,36 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
args["plugin_unique_identifier"],
)
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -260,7 +302,7 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -281,7 +323,7 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self, task_id: str): def get(self, task_id: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -295,7 +337,7 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str): def post(self, task_id: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -309,7 +351,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -323,7 +365,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str): def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -337,7 +379,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -360,7 +402,7 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -391,7 +433,7 @@ class PluginUninstallApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json") req.add_argument("plugin_installation_id", type=str, required=True, location="json")
@ -453,6 +495,7 @@ class PluginFetchPermissionApi(Resource):
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list") api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
@ -470,6 +513,7 @@ api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

@ -216,6 +216,23 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201 return {"id": upload_file.id}, 201
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.name = args["name"]
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
@ -223,3 +240,4 @@ api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

@ -54,6 +54,17 @@ def only_edition_self_hosted(view):
return decorated return decorated
def cloud_edition_billing_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
return decorated
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view): def interceptor(view):
@wraps(view) @wraps(view)

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt, RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

@ -6,5 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp) api = ExternalApi(bp)
from . import index from . import index
from .app import app, audio, completion, conversation, file, message, workflow from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMetaApi(Resource): class AppMetaApi(Resource):

@ -1,3 +1,4 @@
import json
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
@ -10,7 +11,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields from fields.message_fields import agent_thought_fields, feedback_fields
from fields.raws import FilesContainedField from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
@ -28,7 +29,11 @@ class MessageListApi(Resource):
"answer": fields.String(attribute="re_sign_file_url_answer"), "answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"status": fields.String, "status": fields.String,

@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -1,6 +1,6 @@
from flask import request from flask import request
from flask_restful import marshal, reqparse # type: ignore from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
@ -12,7 +12,8 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
def _validate_name(name): def _validate_name(name):
@ -21,6 +22,12 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(DatasetApiResource): class DatasetListApi(DatasetApiResource):
"""Resource for datasets.""" """Resource for datasets."""
@ -114,8 +121,11 @@ class DatasetListApi(DatasetApiResource):
nullable=True, nullable=True,
required=False, required=False,
) )
args = parser.parse_args() parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -127,6 +137,11 @@ class DatasetListApi(DatasetApiResource):
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -137,6 +152,145 @@ class DatasetListApi(DatasetApiResource):
class DatasetApi(DatasetApiResource): class DatasetApi(DatasetApiResource):
"""Resource for dataset.""" """Resource for dataset."""
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
""" """
Deletes a dataset given its ID. Deletes a dataset given its ID.
@ -158,6 +312,7 @@ class DatasetApi(DatasetApiResource):
try: try:
if DatasetService.delete_dataset(dataset_id_str, current_user): if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204 return {"result": "success"}, 204
else: else:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -49,7 +49,9 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument( parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
) )
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -57,7 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -114,7 +116,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -172,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"): if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -239,7 +241,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -303,7 +305,7 @@ class DocumentDeleteApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
@ -341,7 +343,7 @@ class DocumentListApi(DatasetApiResource):
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at)) query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items

@ -13,18 +13,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

@ -117,14 +117,13 @@ class SegmentApi(DatasetApiResource):
parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("keyword", type=str, default=None, location="args")
args = parser.parse_args() args = parser.parse_args()
status_list = args["status"]
keyword = args["keyword"]
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
status_list=args["status"], status_list=args["status"],
keyword=args["keyword"], keyword=args["keyword"],
page=page,
limit=limit,
) )
response = { response = {

@ -0,0 +1,21 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService
class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")

@ -59,6 +59,27 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model kwargs["app_model"] = app_model
if fetch_user_arg: if fetch_user_arg:

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMeta(WebApiResource): class AppMeta(WebApiResource):

@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String, "status": fields.String,
"error": fields.String, "error": fields.String,
} }

@ -191,7 +191,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly # action is final answer, return final answer directly
try: try:
if isinstance(scratchpad.action.action_input, dict): if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input) final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str): elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input final_answer = scratchpad.action.action_input
else: else:

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value) return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter") type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any): def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value) return init_frontend_parameter(self, self.type, value)
@ -70,11 +71,20 @@ class AgentStrategyIdentity(ToolIdentity):
pass pass
class AgentFeature(enum.StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""
HISTORY_MESSAGES = "history-messages"
class AgentStrategyEntity(BaseModel): class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list) parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy") description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -1,6 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FileUploadConfig from core.file import FileUploadConfig
@ -18,7 +19,7 @@ class FileUploadConfigManager:
if file_upload_dict.get("enabled"): if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods", []) transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
file_upload_dict["image_config"] = { file_upload_dict["image_config"] = {
"number_limits": file_upload_dict.get("number_limits", 1), "number_limits": file_upload_dict.get("number_limits", DEFAULT_FILE_NUMBER_LIMITS),
"transfer_methods": transform_methods, "transfer_methods": transform_methods,
} }

@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp

@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
query = application_generate_entity.query query = application_generate_entity.query
files = application_generate_entity.files files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
)
memory = None memory = None
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"], variables: Sequence["VariableEntity"],
tenant_id: str, tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
strict_type_validation=strict_type_validation,
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

@ -61,20 +61,6 @@ class ChatAppRunner(AppRunner):
) )
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
memory = None memory = None
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)

@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
) )
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(

@ -153,6 +153,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation" query = application_generate_entity.query or "New conversation"
else: else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation") query = next(iter(application_generate_entity.inputs.values()), "New conversation")
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: if not conversation:

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
config=file_extra_config, config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
) )
# convert to app config # convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
), ),
files=list(system_files), files=list(system_files),
user_id=user.id, user_id=user.id,

@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event, event=event,
) )
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_failed_response: if node_failed_response:
yield node_failed_response yield node_failed_response
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id workflow_app_log.created_by = self._user_id
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
*, *,
@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution.node_execution_id).where( # Use the instance repository to find running executions for a workflow run
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
WorkflowNodeExecution.app_id == workflow_run.app_id, workflow_run_id=workflow_run.id
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated # Update the cache with the retrieved executions
running_workflow_node_executions = [ for execution in running_workflow_node_executions:
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id if execution.node_execution_id:
] self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None) now = datetime.now(UTC).replace(tzinfo=None)
@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start( def _handle_node_execution_start(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4()) workflow_node_execution.id = str(uuid4())
@ -315,17 +328,14 @@ class WorkflowCycleManage:
) )
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
self, *, session: Session, event: QueueNodeSucceededEvent workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution) # Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
self, self,
*, *,
session: Session,
event: QueueNodeFailedEvent event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
workflow_node_execution = self._get_workflow_node_execution( workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
@ -387,14 +395,14 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, self,
*, *,
@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response( def _workflow_node_start_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeStartedEvent, event: QueueNodeStartedEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeSucceededEvent event: QueueNodeSucceededEvent
| QueueNodeFailedEvent | QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response( def _workflow_node_retry_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeRetryEvent, event: QueueNodeRetryEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse: ) -> ParallelBranchStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchStartStreamResponse( return ParallelBranchStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchFinishedStreamResponse( return ParallelBranchFinishedStreamResponse(
task_id=task_id, task_id=task_id,
@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse: ) -> IterationNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeStartStreamResponse( return IterationNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response( def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse: ) -> IterationNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeNextStreamResponse( return IterationNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response( def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response( def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse: ) -> LoopNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeStartStreamResponse( return LoopNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response( def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse: ) -> LoopNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeNextStreamResponse( return LoopNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response( def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse: ) -> LoopNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeCompletedStreamResponse( return LoopNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -934,11 +920,22 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._workflow_node_executions: # First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}") raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution) # Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """

@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list): def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
) )

@ -146,6 +146,7 @@ class BasicProviderConfig(BaseModel):
BOOLEAN = CommonParameterType.BOOLEAN.value BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
@classmethod @classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type": def value_of(cls, value: str) -> "ProviderConfig.Type":

@ -48,21 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
) )
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
ssl_verify = kwargs.pop("ssl_verify")
retries = 0 retries = 0
while retries <= max_retries: while retries <= max_retries:
try: try:
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = { proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
} }
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
else: else:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:

@ -44,6 +44,7 @@ class TokenBufferMemory:
Message.created_at, Message.created_at,
Message.workflow_run_id, Message.workflow_run_id,
Message.parent_message_id, Message.parent_message_id,
Message.answer_tokens,
) )
.filter( .filter(
Message.conversation_id == self.conversation.id, Message.conversation_id == self.conversation.id,
@ -63,7 +64,7 @@ class TokenBufferMemory:
thread_messages = extract_thread_messages(messages) thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory # for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer: if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0) thread_messages.pop(0)
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))

@ -177,7 +177,7 @@ class ModelInstance:
) )
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
) -> int: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm

@ -58,7 +58,7 @@ class Callback(ABC):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -88,7 +88,7 @@ class Callback(ABC):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -74,7 +74,7 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -104,7 +104,7 @@ class LoggingCallback(Callback):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -192,7 +192,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
``` ```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate. Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
- Model Credentials Validation - Model Credentials Validation

@ -367,7 +367,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
- Returns - Returns
Text converted speech stream Text converted speech stream.
### Moderation ### Moderation

@ -179,7 +179,7 @@ provider_credential_schema:
""" """
``` ```
有时候也许你不需要直接返回0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens这个方法位于`AIModel`基类中它会使用GPT2的Tokenizer进行计算但是只能作为替代方法并不完全准确。 有时候,也许你不需要直接返回 0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的 tokens并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`这个方法位于`AIModel`基类中,它会使用 GPT2 Tokenizer 进行计算,但是只能作为替代方法,并不完全准确。
- 模型凭据校验 - 模型凭据校验

@ -1,8 +1,9 @@
from collections.abc import Sequence
from decimal import Decimal from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
@ -107,7 +108,7 @@ class LLMResult(BaseModel):
id: Optional[str] = None id: Optional[str] = None
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage message: AssistantPromptMessage
usage: LLMUsage usage: LLMUsage
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
""" """
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta delta: LLMResultChunkDelta

@ -1,5 +1,6 @@
import logging import logging
import time import time
import uuid
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel): class LargeLanguageModel(AIModel):
""" """
Model class for large language model. Model class for large language model.
@ -45,7 +98,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
""" """
Invoke large language model Invoke large language model
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = [] tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result: for chunk in result:
if isinstance(chunk.delta.message.content, str): if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list): elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content) content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls: if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls) _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage() usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint
@ -205,22 +227,26 @@ class LargeLanguageModel(AIModel):
user=user, user=user,
callbacks=callbacks, callbacks=callbacks,
) )
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
result.prompt_messages = prompt_messages
return result return result
raise NotImplementedError("unsupported invoke result type", type(result))
def _invoke_result_generator( def _invoke_result_generator(
self, self,
model: str, model: str,
result: Generator, result: Generator,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator: ) -> Generator[LLMResultChunk, None, None]:
""" """
Invoke result generator Invoke result generator
@ -235,6 +261,10 @@ class LargeLanguageModel(AIModel):
try: try:
for chunk in result: for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
chunk.prompt_messages = prompt_messages
yield chunk yield chunk
self._trigger_new_chunk_callbacks( self._trigger_new_chunk_callbacks(
@ -295,6 +325,7 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling :param tools: tools for tool calling
:return: :return:
""" """
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_llm_num_tokens( return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -307,6 +338,7 @@ class LargeLanguageModel(AIModel):
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=tools, tools=tools,
) )
return 0
def _calc_response_usage( def _calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
@ -401,7 +433,7 @@ class LargeLanguageModel(AIModel):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -448,7 +480,7 @@ class LargeLanguageModel(AIModel):
model: str, model: str,
result: LLMResult, result: LLMResult,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from langfuse import Langfuse # type: ignore from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.config_entity import LangfuseConfig
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum, UnitEnum,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
) )
self.add_trace(langfuse_trace_data=trace_data) self.add_trace(langfuse_trace_data=trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
.all()
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
) )
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id
@ -213,9 +196,24 @@ class LangFuseDataTrace(BaseTraceInstance):
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0) total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
# add generation # add generation
generation_usage = GenerationUsage( generation_usage = GenerationUsage(
input=prompt_tokens,
output=completion_tokens,
total=total_token, total=total_token,
unit=UnitEnum.TOKENS,
) )
node_generation_data = LangfuseGeneration( node_generation_data = LangfuseGeneration(

@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client from langsmith import Client
from langsmith.schemas import RunBase from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.config_entity import LangSmithConfig
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run) self.add_run(langsmith_run)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
) )
if not node_execution: for node_execution in workflow_node_executions:
continue
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id
@ -199,6 +186,7 @@ class LangSmithDataTrace(BaseTraceInstance):
) )
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm run_type = LangSmithRunType.llm
metadata.update( metadata.update(
@ -212,9 +200,23 @@ class LangSmithDataTrace(BaseTraceInstance):
else: else:
run_type = LangSmithRunType.tool run_type = LangSmithRunType.tool
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order) node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel( langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens, total_tokens=node_total_tokens,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
name=node_type, name=node_type,
inputs=inputs, inputs=inputs,
run_type=run_type, run_type=run_type,

@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7 from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig from core.ops.entities.config_entity import OpikConfig
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
} }
self.add_trace(trace_data) self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
) )
if not node_execution: for node_execution in workflow_node_executions:
continue
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

@ -460,7 +460,7 @@ class TraceTask:
"version": workflow_run_version, "version": workflow_run_version,
"total_tokens": total_tokens, "total_tokens": total_tokens,
"file_list": file_list, "file_list": file_list,
"triggered_form": workflow_run.triggered_from, "triggered_from": workflow_run.triggered_from,
"user_id": user_id, "user_id": user_id,
} }

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation): class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod @classmethod
def invoke_app( def invoke_app(
cls, cls,

@ -131,7 +131,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError("The selector must be a dictionary.") raise ValueError("The selector must be a dictionary.")
return value return value
case PluginParameterType.TOOLS_SELECTOR: case PluginParameterType.TOOLS_SELECTOR:
if not isinstance(value, list): if value and not isinstance(value, list):
raise ValueError("The tools selector must be a list.") raise ValueError("The tools selector must be a list.")
return value return value
case _: case _:
@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
init frontend parameter by rule init frontend parameter by rule
""" """
parameter_value = value parameter_value = value
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR: if not parameter_value and parameter_value != 0:
# get default value # get default value
parameter_value = rule.default parameter_value = rule.default
if not parameter_value and rule.required: if not parameter_value and rule.required:

@ -70,6 +70,9 @@ class PluginDeclaration(BaseModel):
models: Optional[list[str]] = Field(default_factory=list) models: Optional[list[str]] = Field(default_factory=list)
endpoints: Optional[list[str]] = Field(default_factory=list) endpoints: Optional[list[str]] = Field(default_factory=list)
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
@ -86,6 +89,7 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None agent_strategy: Optional[AgentStrategyProviderEntity] = None
meta: Meta
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -120,8 +124,6 @@ class PluginEntity(PluginInstallation):
name: str name: str
installation_id: str installation_id: str
version: str version: str
latest_version: Optional[str] = None
latest_unique_identifier: Optional[str] = None
@model_validator(mode="after") @model_validator(mode="after")
def set_plugin_id(self): def set_plugin_id(self):

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str filename: str
mimetype: str mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

@ -82,7 +82,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API Make a stream request to the plugin daemon inner API
""" """
response = self._request(method, path, headers, data, params, files, stream=True) response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(): for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip() line = line.decode("utf-8").strip()
if line.startswith("data:"): if line.startswith("data:"):
line = line[5:].strip() line = line[5:].strip()
@ -168,16 +168,18 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
line_data = None
try: try:
line_data = json.loads(line) rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore except (ValueError, TypeError):
except Exception:
# TODO modify this when line_data has code and message # TODO modify this when line_data has code and message
if line_data and "error" in line_data: try:
raise ValueError(line_data["error"]) line_data = json.loads(line)
else: except (ValueError, TypeError):
raise ValueError(line) raise ValueError(line)
# If the dictionary contains the `error` key, use its value as the argument
# for `ValueError`.
# Otherwise, use the `line` to provide better contextual information about the error.
raise ValueError(line_data.get("error", line))
if rep.code != 0: if rep.code != 0:
if rep.code == -500: if rep.code == -500:

@ -110,7 +110,62 @@ class PluginToolManager(BasePluginManager):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
return response
class FileChunk:
"""
Only used for internal processing.
"""
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
# Get blob chunk information
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
# If this is the final chunk, yield a complete blob message
if is_end:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
meta=resp.meta,
)
else:
# Check if file is too large (30MB limit)
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
else:
yield resp
def validate_provider_credentials( def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]

@ -28,7 +28,7 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
}, },
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
}, },
"stop": ["用户:"], "stop": ["用户"],
} }
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
@ -41,5 +41,5 @@ BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["用户:"], "stop": ["用户"],
} }

@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)
# Get All provider model settings # Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list]: ) -> dict[str, list[Provider]]:
""" """
Initialize trial provider records if not exists. Initialize trial provider records if not exists.
@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider( new_provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. # TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name, provider_name=ModelProviderID(provider_name).provider_name,
@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0, quota_used=0,
is_valid=True, is_valid=True,
) )
db.session.add(provider_record) db.session.add(new_provider_record)
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
@ -556,11 +566,14 @@ class ProviderManager:
) )
.first() .first()
) )
if provider_record and not provider_record.is_valid: if not existed_provider_record:
provider_record.is_valid = True continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record) provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict

@ -27,9 +27,26 @@ class CleanProcessor:
pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)" pattern = r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)"
text = re.sub(pattern, "", text) text = re.sub(pattern, "", text)
# Remove URL # Remove URL but keep Markdown image URLs
pattern = r"https?://[^\s]+" # First, temporarily replace Markdown image URLs with a placeholder
text = re.sub(pattern, "", text) markdown_image_pattern = r"!\[.*?\]\((https?://[^\s)]+)\)"
placeholders: list[str] = []
def replace_with_placeholder(match, placeholders=placeholders):
url = match.group(1)
placeholder = f"__MARKDOWN_IMAGE_URL_{len(placeholders)}__"
placeholders.append(url)
return f"![image]({placeholder})"
text = re.sub(markdown_image_pattern, replace_with_placeholder, text)
# Now remove all remaining URLs
url_pattern = r"https?://[^\s)]+"
text = re.sub(url_pattern, "", text)
# Finally, restore the Markdown image URLs
for i, url in enumerate(placeholders):
text = text.replace(f"__MARKDOWN_IMAGE_URL_{i}__", url)
return text return text
def filter_string(self, text): def filter_string(self, text):

@ -46,7 +46,7 @@ class RetrievalService:
if not query: if not query:
return [] return []
dataset = cls._get_dataset(dataset_id) dataset = cls._get_dataset(dataset_id)
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: if not dataset:
return [] return []
all_documents: list[Document] = [] all_documents: list[Document] = []

@ -139,6 +139,7 @@ class AnalyticdbVectorBySql:
) )
if embedding_dimension is not None: if embedding_dimension is not None:
index_name = f"{self._collection_name}_embedding_idx" index_name = f"{self._collection_name}_embedding_idx"
try:
cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN") cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
cur.execute( cur.execute(
f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) " f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
@ -146,6 +147,9 @@ class AnalyticdbVectorBySql:
f"pq_enable=0, external_storage=0)" f"pq_enable=0, external_storage=0)"
) )
cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)") cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
except Exception as e:
if "already exists" not in str(e):
raise e
redis_client.set(collection_exist_cache_key, 1, ex=3600) redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
@ -177,9 +181,11 @@ class AnalyticdbVectorBySql:
return cur.fetchone() is not None return cur.fetchone() is not None
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
with self._get_cursor() as cur: with self._get_cursor() as cur:
try: try:
cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),)) cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id = ANY(%s)", (ids,))
except Exception as e: except Exception as e:
if "does not exist" not in str(e): if "does not exist" not in str(e):
raise e raise e
@ -240,7 +246,7 @@ class AnalyticdbVectorBySql:
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name} FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause} WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
ORDER BY score DESC ORDER BY score DESC, id DESC
LIMIT {top_k}""", LIMIT {top_k}""",
(f"'{query}'", f"'{query}'"), (f"'{query}'", f"'{query}'"),
) )

@ -1,10 +1,13 @@
import copy import copy
import json import json
import logging import logging
import time
from typing import Any, Optional from typing import Any, Optional
from opensearchpy import OpenSearch from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.field import Field
@ -77,10 +80,43 @@ class LindormVectorStore(BaseVector):
def refresh(self): def refresh(self):
self._client.indices.refresh(index=self._collection_name) self._client.indices.refresh(index=self._collection_name)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): def add_texts(
actions = [] self,
documents: list[Document],
embeddings: list[list[float]],
batch_size: int = 64,
timeout: int = 60,
**kwargs,
):
logger.info(f"Total documents to add: {len(documents)}")
uuids = self._get_uuids(documents) uuids = self._get_uuids(documents)
for i in range(len(documents)):
total_docs = len(documents)
num_batches = (total_docs + batch_size - 1) // batch_size
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
)
def _bulk_with_retry(actions):
try:
response = self._client.bulk(actions, timeout=timeout)
if response["errors"]:
error_items = [item for item in response["items"] if "error" in item["index"]]
error_msg = f"Bulk indexing had {len(error_items)} errors"
logger.exception(error_msg)
raise Exception(error_msg)
return response
except Exception:
logger.exception("Bulk indexing error")
raise
for batch_num in range(num_batches):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, total_docs)
actions = []
for i in range(start_idx, end_idx):
action_header = { action_header = {
"index": { "index": {
"_index": self.collection_name.lower(), "_index": self.collection_name.lower(),
@ -89,19 +125,29 @@ class LindormVectorStore(BaseVector):
} }
action_values: dict[str, Any] = { action_values: dict[str, Any] = {
Field.CONTENT_KEY.value: documents[i].page_content, Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata, Field.METADATA_KEY.value: documents[i].metadata,
} }
if self._using_ugc: if self._using_ugc:
action_header["index"]["routing"] = self._routing action_header["index"]["routing"] = self._routing
if self._routing_field is not None: if self._routing_field is not None:
action_values[self._routing_field] = self._routing action_values[self._routing_field] = self._routing
actions.append(action_header) actions.append(action_header)
actions.append(action_values) actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]: # logger.info(f"Processing batch {batch_num + 1}/{num_batches} (documents {start_idx + 1} to {end_idx})")
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}") try:
_bulk_with_retry(actions)
# logger.info(f"Successfully processed batch {batch_num + 1}")
# simple latency to avoid too many requests in a short time
if batch_num < num_batches - 1:
time.sleep(0.5)
except Exception:
logger.exception(f"Failed to process batch {batch_num + 1}")
raise
def get_ids_by_metadata_field(self, key: str, value: str): def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = { query: dict[str, Any] = {
@ -121,19 +167,51 @@ class LindormVectorStore(BaseVector):
self.delete_by_ids(ids) self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None: def delete_by_ids(self, ids: list[str]) -> None:
params = {} """Delete documents by their IDs in batch.
if self._using_ugc:
params["routing"] = self._routing Args:
ids: List of document IDs to delete
"""
if not ids:
return
params = {"routing": self._routing} if self._using_ugc else {}
# 1. First check if collection exists
if not self._client.indices.exists(index=self._collection_name):
logger.warning(f"Collection {self._collection_name} does not exist")
return
# 2. Batch process deletions
actions = []
for id in ids: for id in ids:
if self._client.exists(index=self._collection_name, id=id, params=params): if self._client.exists(index=self._collection_name, id=id, params=params):
params = {} actions.append(
if self._using_ugc: {
params["routing"] = self._routing "_op_type": "delete",
self._client.delete(index=self._collection_name, id=id, params=params) "_index": self._collection_name,
self.refresh() "_id": id,
**params, # Include routing if using UGC
}
)
else: else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.") logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
# 3. Perform bulk deletion if there are valid documents to delete
if actions:
try:
helpers.bulk(self._client, actions)
except BulkIndexError as e:
for error in e.errors:
delete_error = error.get("delete", {})
status = delete_error.get("status")
doc_id = delete_error.get("_id")
if status == 404:
logger.warning(f"Document not found for deletion: {doc_id}")
else:
logger.exception(f"Error deleting document: {error}")
def delete(self) -> None: def delete(self) -> None:
if self._using_ugc: if self._using_ugc:
routing_filter_query = { routing_filter_query = {
@ -169,7 +247,7 @@ class LindormVectorStore(BaseVector):
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
filters = [] filters = []
if document_ids_filter: if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}}) filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs) query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try: try:
@ -212,7 +290,7 @@ class LindormVectorStore(BaseVector):
filters = kwargs.get("filter", []) filters = kwargs.get("filter", [])
document_ids_filter = kwargs.get("document_ids_filter") document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter: if document_ids_filter:
filters.append({"terms": {"metadata.document_id": document_ids_filter}}) filters.append({"terms": {"metadata.document_id.keyword": document_ids_filter}})
routing = self._routing routing = self._routing
full_text_query = default_text_search_query( full_text_query = default_text_search_query(
query_text=query, query_text=query,
@ -226,6 +304,7 @@ class LindormVectorStore(BaseVector):
routing=routing, routing=routing,
routing_field=self._routing_field, routing_field=self._routing_field,
) )
response = self._client.search(index=self._collection_name, body=full_text_query) response = self._client.search(index=self._collection_name, body=full_text_query)
docs = [] docs = []
for hit in response["hits"]["hits"]: for hit in response["hits"]["hits"]:
@ -435,7 +514,7 @@ def default_vector_search_query(
**kwargs, **kwargs,
) -> dict: ) -> dict:
if filters is not None: if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type filter_type = "pre_filter" if filter_type is None else filter_type
if not isinstance(filters, list): if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}") raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext: dict[str, Any] = {"lvector": {}} final_ext: dict[str, Any] = {"lvector": {}}

@ -32,6 +32,7 @@ class MilvusConfig(BaseModel):
batch_size: int = 100 # Batch size for operations batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search enable_hybrid_search: bool = False # Flag to enable hybrid search
analyzer_params: Optional[str] = None # Analyzer params
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -58,6 +59,7 @@ class MilvusConfig(BaseModel):
"user": self.user, "user": self.user,
"password": self.password, "password": self.password,
"db_name": self.database, "db_name": self.database,
"analyzer_params": self.analyzer_params,
} }
@ -300,14 +302,19 @@ class MilvusVector(BaseVector):
# Create the text field, enable_analyzer will be set True to support milvus automatically # Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md # transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append( content_field_kwargs: dict[str, Any] = {
FieldSchema( "max_length": 65_535,
Field.CONTENT_KEY.value, "enable_analyzer": self._hybrid_search_enabled,
DataType.VARCHAR, }
max_length=65_535, if (
enable_analyzer=self._hybrid_search_enabled, self._hybrid_search_enabled
) and self._client_config.analyzer_params is not None
) and self._client_config.analyzer_params.strip()
):
content_field_kwargs["analyzer_params"] = self._client_config.analyzer_params
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, **content_field_kwargs))
# Create the primary key field # Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True)) fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors # Create the vector field, supports binary or float vectors
@ -383,5 +390,6 @@ class MilvusVectorFactory(AbstractVectorFactory):
password=dify_config.MILVUS_PASSWORD or "", password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "", database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False, enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
analyzer_params=dify_config.MILVUS_ANALYZER_PARAMS or "",
), ),
) )

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save