Merge branch 'main' of github.com:parambharat/dify into tracing-weave

# Conflicts:
#	api/pyproject.toml
#	web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config-popup.tsx
pull/14262/head
Bharat Ramanathan 1 year ago
commit 0d160544ae

@ -1,13 +1,13 @@
#!/bin/bash #!/bin/bash
npm add -g pnpm@9.12.2 npm add -g pnpm@10.8.0
cd web && pnpm install cd web && pnpm install
pipx install poetry pipx install uv
echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc source /home/vscode/.bashrc

@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
cd api && poetry install cd api && uv sync

@ -1,36 +0,0 @@
name: Setup Poetry and Python
inputs:
python-version:
description: Python version to use and the Poetry installed with
required: true
default: '3.11'
poetry-version:
description: Poetry version to set up
required: true
default: '2.0.1'
poetry-lockfile:
description: Path to the Poetry lockfile to restore cache from
required: true
default: ''
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: pip
- name: Install Poetry
shell: bash
run: pip install poetry==${{ inputs.poetry-version }}
- name: Restore Poetry cache
if: ${{ inputs.poetry-lockfile != '' }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
cache: poetry
cache-dependency-path: ${{ inputs.poetry-lockfile }}

@ -0,0 +1,34 @@
name: Setup UV and Python
inputs:
python-version:
description: Python version to use and the UV installed with
required: true
default: '3.12'
uv-version:
description: UV version to set up
required: true
default: '0.6.14'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
default: ''
enable-cache:
required: true
default: true
runs:
using: composite
steps:
- name: Set up Python ${{ inputs.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
version: ${{ inputs.uv-version }}
python-version: ${{ inputs.python-version }}
enable-cache: ${{ inputs.enable-cache }}
cache-dependency-glob: ${{ inputs.uv-lockfile }}

@ -17,6 +17,9 @@ jobs:
test: test:
name: API Tests name: API Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
defaults:
run:
shell: bash
strategy: strategy:
matrix: matrix:
python-version: python-version:
@ -27,35 +30,44 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Check dependencies in pyproject.toml
run: poetry run -P api bash dev/pytest/pytest_artifacts.sh
- name: Run Unit tests - name: Run Unit tests
run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests - name: Run dify config tests
run: poetry run -P api python dev/pytest/pytest_config_tests.py run: uv run --project api dev/pytest/pytest_config_tests.py
- name: Run mypy - name: MyPy Cache
run: | uses: actions/cache@v4
poetry run -C api python -m mypy --install-types --non-interactive . with:
path: api/.mypy_cache
key: mypy-${{ matrix.python-version }}-${{ runner.os }}-${{ hashFiles('api/uv.lock') }}
- name: Run MyPy Checks
run: dev/mypy-check
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -75,4 +87,4 @@ jobs:
ssrf_proxy ssrf_proxy
- name: Run Workflow - name: Run Workflow
run: poetry run -P api bash dev/pytest/pytest_workflow.sh run: uv run --project api bash dev/pytest/pytest_workflow.sh

@ -24,13 +24,13 @@ jobs:
fetch-depth: 0 fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Install dependencies - name: Install dependencies
run: poetry install -C api run: uv sync --project api
- name: Prepare middleware env - name: Prepare middleware env
run: | run: |
@ -54,6 +54,4 @@ jobs:
- name: Run DB Migration - name: Run DB Migration
env: env:
DEBUG: true DEBUG: true
run: | run: uv run --directory api flask upgrade-db
cd api
poetry run python -m flask upgrade-db

@ -42,6 +42,7 @@ jobs:
with: with:
push: false push: false
context: "{{defaultContext}}:${{ matrix.context }}" context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
platforms: ${{ matrix.platform }} platforms: ${{ matrix.platform }}
cache-from: type=gha cache-from: type=gha
cache-to: type=gha,mode=max cache-to: type=gha,mode=max

@ -18,7 +18,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -29,24 +28,27 @@ jobs:
api/** api/**
.github/workflows/style.yml .github/workflows/style.yml
- name: Setup Poetry and Python - name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with:
uv-lockfile: api/uv.lock
enable-cache: false
- name: Install dependencies - name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint run: uv sync --project api --dev
- name: Ruff check - name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: | run: |
poetry run -C api ruff --version uv run --directory api ruff --version
poetry run -C api ruff check ./ uv run --directory api ruff check ./
poetry run -C api ruff format --check ./ uv run --directory api ruff format --check ./
- name: Dotenv check - name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example run: uv run --project api dotenv-linter ./api/.env.example ./web/.env.example
- name: Lint hints - name: Lint hints
if: failure() if: failure()
@ -63,7 +65,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -82,7 +83,7 @@ jobs:
uses: actions/setup-node@v4 uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json
@ -102,7 +103,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -133,7 +133,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -153,6 +152,7 @@ jobs:
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning
DEFAULT_BRANCH: main DEFAULT_BRANCH: main
FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true IGNORE_GITIGNORED_FILES: true

@ -18,7 +18,7 @@ jobs:
strategy: strategy:
matrix: matrix:
node-version: [16, 18, 20] node-version: [16, 18, 20, 22]
defaults: defaults:
run: run:
@ -27,7 +27,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }} - name: Use Node.js ${{ matrix.node-version }}

@ -33,7 +33,7 @@ jobs:
- name: Set up Node.js - name: Set up Node.js
if: env.FILES_CHANGED == 'true' if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2 uses: actions/setup-node@v4
with: with:
node-version: 'lts/*' node-version: 'lts/*'

@ -8,7 +8,7 @@ on:
- api/core/rag/datasource/** - api/core/rag/datasource/**
- docker/** - docker/**
- .github/workflows/vdb-tests.yml - .github/workflows/vdb-tests.yml
- api/poetry.lock - api/uv.lock
- api/pyproject.toml - api/pyproject.toml
concurrency: concurrency:
@ -29,22 +29,19 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Setup Poetry and Python ${{ matrix.python-version }} - name: Setup UV and Python
uses: ./.github/actions/setup-poetry uses: ./.github/actions/setup-uv
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
poetry-lockfile: api/poetry.lock uv-lockfile: api/uv.lock
- name: Check Poetry lockfile - name: Check UV lockfile
run: | run: uv lock --project api --check
poetry check -C api --lock
poetry show -C api
- name: Install dependencies - name: Install dependencies
run: poetry install -C api --with dev run: uv sync --project api --dev
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
@ -80,7 +77,7 @@ jobs:
elasticsearch elasticsearch
- name: Check TiDB Ready - name: Check TiDB Ready
run: poetry run -P api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores - name: Test Vector Stores
run: poetry run -P api bash dev/pytest/pytest_vdb.sh run: uv run --project api bash dev/pytest/pytest_vdb.sh

@ -23,7 +23,6 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0
persist-credentials: false persist-credentials: false
- name: Check changed files - name: Check changed files
@ -31,7 +30,9 @@ jobs:
uses: tj-actions/changed-files@v45 uses: tj-actions/changed-files@v45
with: with:
files: web/** files: web/**
- name: Install pnpm - name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4 uses: pnpm/action-setup@v4
with: with:
version: 10 version: 10
@ -41,7 +42,7 @@ jobs:
uses: actions/setup-node@v4 uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
node-version: 20 node-version: 22
cache: pnpm cache: pnpm
cache-dependency-path: ./web/package.json cache-dependency-path: ./web/package.json

1
.gitignore vendored

@ -46,6 +46,7 @@ htmlcov/
.cache .cache
nosetests.xml nosetests.xml
coverage.xml coverage.xml
coverage.json
*.cover *.cover
*.py,cover *.py,cover
.hypothesis/ .hypothesis/

@ -254,8 +254,6 @@ docker compose up -d
- [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。
- [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。 - [X(Twitter)](https://twitter.com/dify_ai)。👉:分享您的应用程序并与社区交流。
- [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。 - [商业许可](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)。👉:有关商业用途许可 Dify.AI 的商业咨询。
- [微信]() 👉:扫描下方二维码,添加微信好友,备注 Dify我们将邀请您加入 Dify 社区。
<img src="./images/wechat.png" alt="wechat" width="100"/>
## 安全问题 ## 安全问题

@ -165,6 +165,7 @@ MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN= MILVUS_TOKEN=
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
MILVUS_ANALYZER_PARAMS=
# MyScale configuration # MyScale configuration
MYSCALE_HOST=127.0.0.1 MYSCALE_HOST=127.0.0.1
@ -189,6 +190,7 @@ TENCENT_VECTOR_DB_USERNAME=dify
TENCENT_VECTOR_DB_DATABASE=dify TENCENT_VECTOR_DB_DATABASE=dify
TENCENT_VECTOR_DB_SHARD=1 TENCENT_VECTOR_DB_SHARD=1
TENCENT_VECTOR_DB_REPLICAS=2 TENCENT_VECTOR_DB_REPLICAS=2
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH=false
# ElasticSearch configuration # ElasticSearch configuration
ELASTICSEARCH_HOST=127.0.0.1 ELASTICSEARCH_HOST=127.0.0.1
@ -325,6 +327,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
MULTIMODAL_SEND_FORMAT=base64 MULTIMODAL_SEND_FORMAT=base64
PROMPT_GENERATION_MAX_TOKENS=512 PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp # Mail configuration, support: resend, smtp
MAIL_TYPE= MAIL_TYPE=
@ -421,6 +424,12 @@ WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3 WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800 MAX_VARIABLE_SIZE=204800
# Workflow storage configuration
# Options: rdbms, hybrid
# rdbms: Use only the relational database (default)
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
# App configuration # App configuration
APP_MAX_EXECUTION_TIME=1200 APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0 APP_MAX_ACTIVE_REQUESTS=0
@ -461,3 +470,16 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
MAX_SUBMIT_COUNT=100 MAX_SUBMIT_COUNT=100
# Lockout duration in seconds # Lockout duration in seconds
LOGIN_LOCKOUT_DURATION=86400 LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_TYPE=otlp
OTEL_SAMPLING_RATE=0.1
OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000
OTEL_MAX_QUEUE_SIZE=2048
OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
OTEL_METRIC_EXPORT_TIMEOUT=30000

@ -3,20 +3,11 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api
# Install Poetry # Install uv
ENV POETRY_VERSION=2.0.1 ENV UV_VERSION=0.6.14
# if you located in China, you can use aliyun mirror to speed up RUN pip install --no-cache-dir uv==${UV_VERSION}
# RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
RUN pip install --no-cache-dir poetry==${POETRY_VERSION}
# Configure Poetry
ENV POETRY_CACHE_DIR=/tmp/poetry_cache
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
ENV POETRY_VIRTUALENVS_CREATE=true
ENV POETRY_REQUESTS_TIMEOUT=15
FROM base AS packages FROM base AS packages
@ -27,8 +18,8 @@ RUN apt-get update \
&& apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev && apt-get install -y --no-install-recommends gcc g++ libc-dev libffi-dev libgmp-dev libmpfr-dev libmpc-dev
# Install Python dependencies # Install Python dependencies
COPY pyproject.toml poetry.lock ./ COPY pyproject.toml uv.lock ./
RUN poetry install --sync --no-cache --no-root RUN uv sync --locked
# production stage # production stage
FROM base AS production FROM base AS production

@ -3,7 +3,10 @@
## Usage ## Usage
> [!IMPORTANT] > [!IMPORTANT]
> In the v0.6.12 release, we deprecated `pip` as the package management tool for Dify API Backend service and replaced it with `poetry`. >
> In the v1.3.0 release, `poetry` has been replaced with
> [`uv`](https://docs.astral.sh/uv/) as the package manager
> for Dify API backend service.
1. Start the docker-compose stack 1. Start the docker-compose stack
@ -37,19 +40,19 @@
4. Create environment. 4. Create environment.
Dify API service uses [Poetry](https://python-poetry.org/docs/) to manage dependencies. First, you need to add the poetry shell plugin, if you don't have it already, in order to run in a virtual environment. [Note: Poetry shell is no longer a native command so you need to install the poetry plugin beforehand] Dify API service uses [UV](https://docs.astral.sh/uv/) to manage dependencies.
First, you need to add the uv package manager, if you don't have it already.
```bash ```bash
poetry self add poetry-plugin-shell pip install uv
# Or on macOS
brew install uv
``` ```
Then, You can execute `poetry shell` to activate the environment.
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.12 uv sync --dev
poetry install
``` ```
6. Run migrate 6. Run migrate
@ -57,21 +60,21 @@
Before the first launch, migrate the database to the latest version. Before the first launch, migrate the database to the latest version.
```bash ```bash
poetry run python -m flask db upgrade uv run flask db upgrade
``` ```
7. Start backend 7. Start backend
```bash ```bash
poetry run python -m flask run --host 0.0.0.0 --port=5001 --debug uv run flask run --host 0.0.0.0 --port=5001 --debug
``` ```
8. Start Dify [web](../web) service. 8. Start Dify [web](../web) service.
9. Setup your application by visiting `http://localhost:3000`... 9. Setup your application by visiting `http://localhost:3000`.
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
``` ```
## Testing ## Testing
@ -79,11 +82,11 @@
1. Install dependencies for both the backend and the test environment 1. Install dependencies for both the backend and the test environment
```bash ```bash
poetry install -C api --with dev uv sync --dev
``` ```
2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml` 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`
```bash ```bash
poetry run -P api bash dev/pytest/pytest_all_tests.sh uv run -P api bash dev/pytest/pytest_all_tests.sh
``` ```

@ -51,8 +51,10 @@ def initialize_extensions(app: DifyApp):
ext_login, ext_login,
ext_mail, ext_mail,
ext_migrate, ext_migrate,
ext_otel,
ext_proxy_fix, ext_proxy_fix,
ext_redis, ext_redis,
ext_repositories,
ext_sentry, ext_sentry,
ext_set_secretkey, ext_set_secretkey,
ext_storage, ext_storage,
@ -73,6 +75,7 @@ def initialize_extensions(app: DifyApp):
ext_migrate, ext_migrate,
ext_redis, ext_redis,
ext_storage, ext_storage,
ext_repositories,
ext_celery, ext_celery,
ext_login, ext_login,
ext_mail, ext_mail,
@ -81,6 +84,7 @@ def initialize_extensions(app: DifyApp):
ext_proxy_fix, ext_proxy_fix,
ext_blueprints, ext_blueprints,
ext_commands, ext_commands,
ext_otel,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

@ -9,6 +9,7 @@ from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig from .extra import ExtraServiceConfig
from .feature import FeatureConfig from .feature import FeatureConfig
from .middleware import MiddlewareConfig from .middleware import MiddlewareConfig
from .observability import ObservabilityConfig
from .packaging import PackagingInfo from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource from .remote_settings_sources.apollo import ApolloSettingsSource
@ -59,6 +60,8 @@ class DifyConfig(
MiddlewareConfig, MiddlewareConfig,
# Extra service configs # Extra service configs
ExtraServiceConfig, ExtraServiceConfig,
# Observability configs
ObservabilityConfig,
# Remote source configs # Remote source configs
RemoteSettingsSourceConfig, RemoteSettingsSourceConfig,
# Enterprise feature configs # Enterprise feature configs

@ -12,7 +12,7 @@ from pydantic import (
) )
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from configs.feature.hosted_service import HostedServiceConfig from .hosted_service import HostedServiceConfig
class SecurityConfig(BaseSettings): class SecurityConfig(BaseSettings):
@ -442,7 +442,7 @@ class LoggingConfig(BaseSettings):
class ModelLoadBalanceConfig(BaseSettings): class ModelLoadBalanceConfig(BaseSettings):
""" """
Configuration for model load balancing Configuration for model load balancing and token counting
""" """
MODEL_LB_ENABLED: bool = Field( MODEL_LB_ENABLED: bool = Field(
@ -450,6 +450,11 @@ class ModelLoadBalanceConfig(BaseSettings):
default=False, default=False,
) )
PLUGIN_BASED_TOKEN_COUNTING_ENABLED: bool = Field(
description="Enable or disable plugin based token counting. If disabled, token counting will return 0.",
default=False,
)
class BillingConfig(BaseSettings): class BillingConfig(BaseSettings):
""" """
@ -514,6 +519,11 @@ class WorkflowNodeExecutionConfig(BaseSettings):
default=100, default=100,
) )
WORKFLOW_NODE_EXECUTION_STORAGE: str = Field(
default="rdbms",
description="Storage backend for WorkflowNodeExecution. Options: 'rdbms', 'hybrid'",
)
class AuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """
@ -848,6 +858,11 @@ class AccountConfig(BaseSettings):
default=5, default=5,
) )
EDUCATION_ENABLED: bool = Field(
description="whether to enable education identity",
default=False,
)
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order

@ -39,3 +39,8 @@ class MilvusConfig(BaseSettings):
"older versions", "older versions",
default=True, default=True,
) )
MILVUS_ANALYZER_PARAMS: Optional[str] = Field(
description='Milvus text analyzer parameters, e.g., {"type": "chinese"} for Chinese segmentation support.',
default=None,
)

@ -48,3 +48,8 @@ class TencentVectorDBConfig(BaseSettings):
description="Name of the specific Tencent Vector Database to connect to", description="Name of the specific Tencent Vector Database to connect to",
default=None, default=None,
) )
TENCENT_VECTOR_DB_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features",
default=False,
)

@ -0,0 +1,9 @@
from configs.observability.otel.otel_config import OTelConfig
class ObservabilityConfig(OTelConfig):
"""
Observability configuration settings
"""
pass

@ -0,0 +1,44 @@
from pydantic import Field
from pydantic_settings import BaseSettings
class OTelConfig(BaseSettings):
"""
OpenTelemetry configuration settings
"""
ENABLE_OTEL: bool = Field(
description="Whether to enable OpenTelemetry",
default=False,
)
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
)
OTLP_API_KEY: str = Field(
description="OTLP API key",
default="",
)
OTEL_EXPORTER_TYPE: str = Field(
description="OTEL exporter type",
default="otlp",
)
OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)")
OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field(
default=5000, description="Batch export schedule delay in milliseconds"
)
OTEL_MAX_QUEUE_SIZE: int = Field(default=2048, description="Maximum queue size for the batch span processor")
OTEL_MAX_EXPORT_BATCH_SIZE: int = Field(default=512, description="Maximum export batch size")
OTEL_METRIC_EXPORT_INTERVAL: int = Field(default=60000, description="Metric export interval in milliseconds")
OTEL_BATCH_EXPORT_TIMEOUT: int = Field(default=10000, description="Batch export timeout in milliseconds")
OTEL_METRIC_EXPORT_TIMEOUT: int = Field(default=30000, description="Metric export timeout in milliseconds")

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="1.1.3", default="1.2.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -270,7 +270,7 @@ class ApolloClient:
while not self._stopping: while not self._stopping:
for namespace in self._notification_map: for namespace in self._notification_map:
self._do_heart_beat(namespace) self._do_heart_beat(namespace)
time.sleep(60 * 10) # 10分钟 time.sleep(60 * 10) # 10 minutes
def _do_heart_beat(self, namespace): def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip) url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)

@ -3,6 +3,8 @@ from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]" HIDDEN_VALUE = "[__HIDDEN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000" UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])

@ -4,8 +4,6 @@ import platform
import re import re
import urllib.parse import urllib.parse
import warnings import warnings
from collections.abc import Mapping
from typing import Any
from uuid import uuid4 from uuid import uuid4
import httpx import httpx
@ -29,8 +27,6 @@ except ImportError:
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config
class FileInfo(BaseModel): class FileInfo(BaseModel):
filename: str filename: str
@ -87,38 +83,3 @@ def guess_file_info_from_response(response: httpx.Response):
mimetype=mimetype, mimetype=mimetype,
size=int(response.headers.get("Content-Length", -1)), size=int(response.headers.get("Content-Length", -1)),
) )
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": 3,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -89,7 +89,7 @@ class AnnotationReplyActionStatusApi(Resource):
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key) cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
@ -226,7 +226,7 @@ class AnnotationBatchImportStatusApi(Resource):
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
job_status = cache_result.decode() job_status = cache_result.decode()
error_msg = "" error_msg = ""
if job_status == "error": if job_status == "error":

@ -8,6 +8,7 @@ from werkzeug.exceptions import Forbidden
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_resource_check,
setup_required, setup_required,
) )
from extensions.ext_database import db from extensions.ext_database import db
@ -23,6 +24,7 @@ class AppImportApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(app_import_fields) @marshal_with(app_import_fields)
@cloud_edition_billing_resource_check("apps")
def post(self): def post(self):
# Check user role first # Check user role first
if not current_user.is_editor: if not current_user.is_editor:

@ -1,5 +1,4 @@
from datetime import datetime from dateutil.parser import isoparse
from flask_restful import Resource, marshal_with, reqparse # type: ignore from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -41,10 +40,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -74,7 +74,9 @@ class OAuthDataSourceBinding(Resource):
if not oauth_provider: if not oauth_provider:
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
if "code" in request.args: if "code" in request.args:
code = request.args.get("code") code = request.args.get("code", "")
if not code:
return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:

@ -99,44 +99,57 @@ class ForgotPasswordResetApi(Resource):
parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
new_password = args["new_password"] # Validate passwords match
password_confirm = args["password_confirm"] if args["new_password"] != args["password_confirm"]:
if str(new_password).strip() != str(password_confirm).strip():
raise PasswordMismatchError() raise PasswordMismatchError()
token = args["token"] # Validate token and get reset data
reset_data = AccountService.get_reset_password_data(token) reset_data = AccountService.get_reset_password_data(args["token"])
if not reset_data:
if reset_data is None:
raise InvalidTokenError() raise InvalidTokenError()
AccountService.revoke_reset_password_token(token) # Revoke token to prevent reuse
AccountService.revoke_reset_password_token(args["token"])
# Generate secure salt and hash password
salt = secrets.token_bytes(16) salt = secrets.token_bytes(16)
base64_salt = base64.b64encode(salt).decode() password_hashed = hash_password(args["new_password"], salt)
password_hashed = hash_password(new_password, salt) email = reset_data.get("email", "")
base64_password_hashed = base64.b64encode(password_hashed).decode()
with Session(db.engine) as session: with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none() account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
if account: if account:
account.password = base64_password_hashed self._update_existing_account(account, password_hashed, salt, session)
account.password_salt = base64_salt else:
db.session.commit() self._create_new_account(email, args["password_confirm"])
tenant = TenantService.get_join_tenants(account)
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace: return {"result": "success"}
def _update_existing_account(self, account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()
# Create workspace if needed
if (
not TenantService.get_join_tenants(account)
and FeatureService.get_system_features().is_allow_create_workspace
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
else:
def _create_new_account(self, email, password):
# Create new account if allowed
try: try:
account = AccountService.create_account_and_tenant( AccountService.create_account_and_tenant(
email=reset_data.get("email", ""), email=email,
name=reset_data.get("email", ""), name=email,
password=password_confirm, password=password,
interface_language=languages[0], interface_language=languages[0],
) )
except WorkSpaceNotAllowedCreateError: except WorkSpaceNotAllowedCreateError:
@ -144,8 +157,6 @@ class ForgotPasswordResetApi(Resource):
except AccountRegisterError: except AccountRegisterError:
raise AccountInFreezeError() raise AccountInFreezeError()
return {"result": "success"}
api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")

@ -641,7 +641,6 @@ class DatasetRetrievalSettingApi(Resource):
VectorType.RELYT VectorType.RELYT
| VectorType.TIDB_VECTOR | VectorType.TIDB_VECTOR
| VectorType.CHROMA | VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS | VectorType.PGVECTO_RS
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
@ -665,6 +664,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.OPENGAUSS | VectorType.OPENGAUSS
| VectorType.OCEANBASE | VectorType.OCEANBASE
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
@ -688,7 +688,6 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.RELYT | VectorType.RELYT
| VectorType.TIDB_VECTOR | VectorType.TIDB_VECTOR
| VectorType.CHROMA | VectorType.CHROMA
| VectorType.TENCENT
| VectorType.PGVECTO_RS | VectorType.PGVECTO_RS
| VectorType.BAIDU | VectorType.BAIDU
| VectorType.VIKINGDB | VectorType.VIKINGDB
@ -710,6 +709,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.OPENGAUSS | VectorType.OPENGAUSS
| VectorType.OCEANBASE | VectorType.OCEANBASE
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

@ -398,7 +398,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
indexing_cache_key = "segment_batch_import_{}".format(job_id) indexing_cache_key = "segment_batch_import_{}".format(job_id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is None: if cache_result is None:
raise ValueError("The job is not exist.") raise ValueError("The job does not exist.")
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200

@ -21,12 +21,6 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if description and len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class ExternalApiTemplateListApi(Resource): class ExternalApiTemplateListApi(Resource):
@setup_required @setup_required
@login_required @login_required

@ -14,18 +14,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateApi(Resource): class DatasetMetadataCreateApi(Resource):
@setup_required @setup_required
@login_required @login_required

@ -14,7 +14,12 @@ class WebsiteCrawlApi(Resource):
def post(self): def post(self):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
"provider", type=str, choices=["firecrawl", "jinareader"], required=True, nullable=True, location="json" "provider",
type=str,
choices=["firecrawl", "watercrawl", "jinareader"],
required=True,
nullable=True,
location="json",
) )
parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
@ -34,7 +39,9 @@ class WebsiteCrawlStatusApi(Resource):
@account_initialization_required @account_initialization_required
def get(self, job_id: str): def get(self, job_id: str):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, choices=["firecrawl", "jinareader"], required=True, location="args") parser.add_argument(
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args() args = parser.parse_args()
# get crawl status # get crawl status
try: try:

@ -103,6 +103,18 @@ class AccountInFreezeError(BaseHTTPException):
) )
class EducationVerifyLimitError(BaseHTTPException):
error_code = "education_verify_limit"
description = "Rate limit exceeded"
code = 429
class EducationActivateLimitError(BaseHTTPException):
error_code = "education_activate_limit"
description = "Rate limit exceeded"
code = 429
class CompilanceRateLimitError(BaseHTTPException): class CompilanceRateLimitError(BaseHTTPException):
error_code = "compilance_rate_limit" error_code = "compilance_rate_limit"
description = "Rate limit exceeded for downloading compliance report." description = "Rate limit exceeded for downloading compliance report."

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.console import api from controllers.console import api
from controllers.console.app.error import AppUnavailableError from controllers.console.app.error import AppUnavailableError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import AppMode, InstalledApp from models.model import AppMode, InstalledApp
from services.app_service import AppService from services.app_service import AppService
@ -36,9 +36,7 @@ class AppParameterApi(InstalledAppResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class ExploreAppMetaApi(InstalledAppResource): class ExploreAppMetaApi(InstalledAppResource):

@ -15,7 +15,13 @@ from controllers.console.workspace.error import (
InvalidInvitationCodeError, InvalidInvitationCodeError,
RepeatPasswordNotMatchError, RepeatPasswordNotMatchError,
) )
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.helper import TimestampField, timezone from libs.helper import TimestampField, timezone
@ -280,8 +286,6 @@ class AccountDeleteApi(Resource):
class AccountDeleteUpdateFeedbackApi(Resource): class AccountDeleteUpdateFeedbackApi(Resource):
@setup_required @setup_required
def post(self): def post(self):
account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=str, required=True, location="json") parser.add_argument("email", type=str, required=True, location="json")
parser.add_argument("feedback", type=str, required=True, location="json") parser.add_argument("feedback", type=str, required=True, location="json")
@ -292,6 +296,79 @@ class AccountDeleteUpdateFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
class EducationVerifyApi(Resource):
verify_fields = {
"token": fields.String,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(verify_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.verify(account.id, account.email)
class EducationApi(Resource):
status_fields = {
"result": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account = current_user
parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, location="json")
parser.add_argument("institution", type=str, required=True, location="json")
parser.add_argument("role", type=str, required=True, location="json")
args = parser.parse_args()
return BillingService.EducationIdentity.activate(account, args["token"], args["institution"], args["role"])
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(status_fields)
def get(self):
account = current_user
return BillingService.EducationIdentity.is_active(account.id)
class EducationAutoCompleteApi(Resource):
data_fields = {
"data": fields.List(fields.String),
"curr_page": fields.Integer,
"has_next": fields.Boolean,
}
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("keywords", type=str, required=True, location="args")
parser.add_argument("page", type=int, required=False, location="args", default=0)
parser.add_argument("limit", type=int, required=False, location="args", default=20)
args = parser.parse_args()
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
# Register API resources # Register API resources
api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile") api.add_resource(AccountProfileApi, "/account/profile")
@ -305,5 +382,8 @@ api.add_resource(AccountIntegrateApi, "/account/integrates")
api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify") api.add_resource(AccountDeleteVerifyApi, "/account/delete/verify")
api.add_resource(AccountDeleteApi, "/account/delete") api.add_resource(AccountDeleteApi, "/account/delete")
api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
# api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')

@ -49,6 +49,23 @@ class PluginListApi(Resource):
return jsonable_encoder({"plugins": plugins}) return jsonable_encoder({"plugins": plugins})
class PluginListLatestVersionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
req = reqparse.RequestParser()
req.add_argument("plugin_ids", type=list, required=True, location="json")
args = req.parse_args()
try:
versions = PluginService.list_latest_versions(args["plugin_ids"])
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"versions": versions})
class PluginListInstallationsFromIdsApi(Resource): class PluginListInstallationsFromIdsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -232,11 +249,36 @@ class PluginInstallFromMarketplaceApi(Resource):
return jsonable_encoder(response) return jsonable_encoder(response)
class PluginFetchMarketplacePkgApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(install_required=True)
def get(self):
tenant_id = current_user.current_tenant_id
parser = reqparse.RequestParser()
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
args = parser.parse_args()
try:
return jsonable_encoder(
{
"manifest": PluginService.fetch_marketplace_pkg(
tenant_id,
args["plugin_unique_identifier"],
)
}
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
class PluginFetchManifestApi(Resource): class PluginFetchManifestApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -260,7 +302,7 @@ class PluginFetchInstallTasksApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self): def get(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -281,7 +323,7 @@ class PluginFetchInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def get(self, task_id: str): def get(self, task_id: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -295,7 +337,7 @@ class PluginDeleteInstallTaskApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str): def post(self, task_id: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -309,7 +351,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -323,7 +365,7 @@ class PluginDeleteInstallTaskItemApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self, task_id: str, identifier: str): def post(self, task_id: str, identifier: str):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -337,7 +379,7 @@ class PluginUpgradeFromMarketplaceApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -360,7 +402,7 @@ class PluginUpgradeFromGithubApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -391,7 +433,7 @@ class PluginUninstallApi(Resource):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@plugin_permission_required(debug_required=True) @plugin_permission_required(install_required=True)
def post(self): def post(self):
req = reqparse.RequestParser() req = reqparse.RequestParser()
req.add_argument("plugin_installation_id", type=str, required=True, location="json") req.add_argument("plugin_installation_id", type=str, required=True, location="json")
@ -453,6 +495,7 @@ class PluginFetchPermissionApi(Resource):
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list") api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids") api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon") api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg") api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
@ -470,6 +513,7 @@ api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all") api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marketplace/pkg")
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change") api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")

@ -216,6 +216,23 @@ class WebappLogoWorkspaceApi(Resource):
return {"id": upload_file.id}, 201 return {"id": upload_file.id}, 201
class WorkspaceInfoApi(Resource):
@setup_required
@login_required
@account_initialization_required
# Change workspace name
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
tenant.name = args["name"]
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants
api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants
api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info
@ -223,3 +240,4 @@ api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated
api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant
api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config")
api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload")
api.add_resource(WorkspaceInfoApi, "/workspaces/info") # POST for changing workspace info

@ -54,6 +54,17 @@ def only_edition_self_hosted(view):
return decorated return decorated
def cloud_edition_billing_enabled(view):
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
return decorated
def cloud_edition_billing_resource_check(resource: str): def cloud_edition_billing_resource_check(resource: str):
def interceptor(view): def interceptor(view):
@wraps(view) @wraps(view)

@ -13,6 +13,7 @@ from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocatio
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestFetchAppInfo,
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt, RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
@ -278,6 +279,17 @@ class PluginUploadFileRequestApi(Resource):
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
class PluginFetchAppInfoApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestFetchAppInfo)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestFetchAppInfo):
return BaseBackwardsInvocationResponse(
data=PluginAppBackwardsInvocation.fetch_app_info(payload.app_id, tenant_model.id)
).model_dump()
api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
@ -291,3 +303,4 @@ api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
api.add_resource(PluginFetchAppInfoApi, "/fetch/app/info")

@ -6,5 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp) api = ExternalApi(bp)
from . import index from . import index
from .app import app, audio, completion, conversation, file, message, workflow from .app import annotation, app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models

@ -0,0 +1,107 @@
from flask import request
from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
from controllers.service_api import api
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
annotation_fields,
)
from libs.login import current_user
from models.model import App, EndUser
from services.annotation_service import AppAnnotationService
class AnnotationReplyActionApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
def post(self, app_model: App, end_user: EndUser, action):
parser = reqparse.RequestParser()
parser.add_argument("score_threshold", required=True, type=float, location="json")
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
parser.add_argument("embedding_model_name", required=True, type=str, location="json")
args = parser.parse_args()
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
else:
raise ValueError("Unsupported annotation reply action")
return result, 200
class AnnotationReplyActionStatusApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, job_id, action):
job_id = str(job_id)
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
cache_result = redis_client.get(app_annotation_job_key)
if cache_result is None:
raise ValueError("The job does not exist.")
job_status = cache_result.decode()
error_msg = ""
if job_status == "error":
app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id))
error_msg = redis_client.get(app_annotation_error_key).decode()
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
class AnnotationListApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser):
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
return annotation
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
def post(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
parser = reqparse.RequestParser()
parser.add_argument("question", required=True, type=str, location="json")
parser.add_argument("answer", required=True, type=str, location="json")
args = parser.parse_args()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
return annotation
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def delete(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
return {"result": "success"}, 200
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/<string:action>")
api.add_resource(AnnotationReplyActionStatusApi, "/apps/annotation-reply/<string:action>/status/<uuid:job_id>")
api.add_resource(AnnotationListApi, "/apps/annotations")
api.add_resource(AnnotationUpdateDeleteApi, "/apps/annotations/<uuid:annotation_id>")

@ -1,10 +1,10 @@
from flask_restful import Resource, marshal_with # type: ignore from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.app.error import AppUnavailableError
from controllers.service_api.wraps import validate_app_token from controllers.service_api.wraps import validate_app_token
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -32,9 +32,7 @@ class AppParameterApi(Resource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMetaApi(Resource): class AppMetaApi(Resource):

@ -1,3 +1,4 @@
import json
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
@ -10,7 +11,7 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import message_file_fields from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields from fields.message_fields import agent_thought_fields, feedback_fields
from fields.raws import FilesContainedField from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value from libs.helper import TimestampField, uuid_value
from models.model import App, AppMode, EndUser from models.model import App, AppMode, EndUser
@ -28,7 +29,11 @@ class MessageListApi(Resource):
"answer": fields.String(attribute="re_sign_file_url_answer"), "answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)), "message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"status": fields.String, "status": fields.String,

@ -1,6 +1,6 @@
import logging import logging
from datetime import datetime
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range # type: ignore from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -140,10 +140,10 @@ class WorkflowAppLogApi(Resource):
args.status = WorkflowRunStatus(args.status) if args.status else None args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before: if args.created_at__before:
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00")) args.created_at__before = isoparse(args.created_at__before)
if args.created_at__after: if args.created_at__after:
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00")) args.created_at__after = isoparse(args.created_at__after)
# get paginate workflow app logs # get paginate workflow app logs
workflow_app_service = WorkflowAppService() workflow_app_service = WorkflowAppService()

@ -1,6 +1,6 @@
from flask import request from flask import request
from flask_restful import marshal, reqparse # type: ignore from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
@ -12,7 +12,8 @@ from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields from fields.dataset_fields import dataset_detail_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetService from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
def _validate_name(name): def _validate_name(name):
@ -21,6 +22,12 @@ def _validate_name(name):
return name return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetListApi(DatasetApiResource): class DatasetListApi(DatasetApiResource):
"""Resource for datasets.""" """Resource for datasets."""
@ -114,8 +121,11 @@ class DatasetListApi(DatasetApiResource):
nullable=True, nullable=True,
required=False, required=False,
) )
args = parser.parse_args() parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
try: try:
dataset = DatasetService.create_empty_dataset( dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -127,6 +137,11 @@ class DatasetListApi(DatasetApiResource):
provider=args["provider"], provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"], external_knowledge_api_id=args["external_knowledge_api_id"],
external_knowledge_id=args["external_knowledge_id"], external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
) )
except services.errors.dataset.DatasetNameDuplicateError: except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError() raise DatasetNameDuplicateError()
@ -137,6 +152,145 @@ class DatasetListApi(DatasetApiResource):
class DatasetApi(DatasetApiResource): class DatasetApi(DatasetApiResource):
"""Resource for dataset.""" """Resource for dataset."""
def get(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
model_names = []
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
if data["indexing_technique"] == "high_quality":
item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
if item_model in model_names:
data["embedding_available"] = True
else:
data["embedding_available"] = False
else:
data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
parser.add_argument(
"indexing_technique",
type=str,
location="json",
choices=Dataset.INDEXING_TECHNIQUE_LIST,
nullable=True,
help="Invalid indexing technique.",
)
parser.add_argument(
"permission",
type=str,
location="json",
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
help="Invalid permission.",
)
parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
parser.add_argument(
"embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
)
parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
parser.add_argument(
"external_retrieval_model",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid external retrieval model.",
)
parser.add_argument(
"external_knowledge_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge id.",
)
parser.add_argument(
"external_knowledge_api_id",
type=str,
required=False,
nullable=True,
location="json",
help="Invalid external knowledge api id.",
)
args = parser.parse_args()
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
current_user, dataset, data.get("permission"), data.get("partial_member_list")
)
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, data.get("partial_member_list")
)
# clear partial member list when permission is only_me or all_team_members
elif (
data.get("permission") == DatasetPermissionEnum.ONLY_ME
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
result_data.update({"partial_member_list": partial_member_list})
return result_data, 200
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
""" """
Deletes a dataset given its ID. Deletes a dataset given its ID.
@ -158,6 +312,7 @@ class DatasetApi(DatasetApiResource):
try: try:
if DatasetService.delete_dataset(dataset_id_str, current_user): if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
return {"result": "success"}, 204 return {"result": "success"}, 204
else: else:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -49,7 +49,9 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument( parser.add_argument(
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
) )
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") parser.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -57,7 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not args["indexing_technique"]:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -114,7 +116,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -172,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
if not dataset.indexing_technique and not args.get("indexing_technique"): if not dataset.indexing_technique and not args.get("indexing_technique"):
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
@ -239,7 +241,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
# indexing_technique is already set in dataset since this is an update # indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique args["indexing_technique"] = dataset.indexing_technique
@ -303,7 +305,7 @@ class DocumentDeleteApi(DatasetApiResource):
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset is not exist.") raise ValueError("Dataset does not exist.")
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
@ -341,7 +343,7 @@ class DocumentListApi(DatasetApiResource):
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at)) query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items

@ -13,18 +13,6 @@ from services.entities.knowledge_entities.knowledge_entities import (
from services.metadata_service import MetadataService from services.metadata_service import MetadataService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()

@ -117,14 +117,13 @@ class SegmentApi(DatasetApiResource):
parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("keyword", type=str, default=None, location="args")
args = parser.parse_args() args = parser.parse_args()
status_list = args["status"]
keyword = args["keyword"]
segments, total = SegmentService.get_segments( segments, total = SegmentService.get_segments(
document_id=document_id, document_id=document_id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
status_list=args["status"], status_list=args["status"],
keyword=args["keyword"], keyword=args["keyword"],
page=page,
limit=limit,
) )
response = { response = {

@ -0,0 +1,21 @@
from flask_login import current_user # type: ignore
from flask_restful import Resource # type: ignore
from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
from core.model_runtime.utils.encoders import jsonable_encoder
from services.model_provider_service import ModelProviderService
class ModelProviderAvailableModelApi(Resource):
@validate_dataset_token
def get(self, _, model_type):
tenant_id = current_user.current_tenant_id
model_provider_service = ModelProviderService()
models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)
return jsonable_encoder({"data": models})
api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>")

@ -59,6 +59,27 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else:
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
kwargs["app_model"] = app_model kwargs["app_model"] = app_model
if fetch_user_arg: if fetch_user_arg:

@ -1,10 +1,10 @@
from flask_restful import marshal_with # type: ignore from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers
from controllers.web import api from controllers.web import api
from controllers.web.error import AppUnavailableError from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from models.model import App, AppMode from models.model import App, AppMode
from services.app_service import AppService from services.app_service import AppService
@ -31,9 +31,7 @@ class AppParameterApi(WebApiResource):
user_input_form = features_dict.get("user_input_form", []) user_input_form = features_dict.get("user_input_form", [])
return controller_helpers.get_parameters_from_feature_dict( return get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
features_dict=features_dict, user_input_form=user_input_form
)
class AppMeta(WebApiResource): class AppMeta(WebApiResource):

@ -46,6 +46,7 @@ class MessageListApi(WebApiResource):
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField, "created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String, "status": fields.String,
"error": fields.String, "error": fields.String,
} }

@ -191,7 +191,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# action is final answer, return final answer directly # action is final answer, return final answer directly
try: try:
if isinstance(scratchpad.action.action_input, dict): if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input) final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str): elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input final_answer = scratchpad.action.action_input
else: else:

@ -52,6 +52,7 @@ class AgentStrategyParameter(PluginParameter):
return cast_parameter_value(self, value) return cast_parameter_value(self, value)
type: AgentStrategyParameterType = Field(..., description="The type of the parameter") type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
help: Optional[I18nObject] = None
def init_frontend_parameter(self, value: Any): def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value) return init_frontend_parameter(self, self.type, value)
@ -70,11 +71,20 @@ class AgentStrategyIdentity(ToolIdentity):
pass pass
class AgentFeature(enum.StrEnum):
"""
Agent Feature, used to describe the features of the agent strategy.
"""
HISTORY_MESSAGES = "history-messages"
class AgentStrategyEntity(BaseModel): class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list) parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy") description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

@ -0,0 +1,45 @@
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""
return {
"opening_statement": features_dict.get("opening_statement"),
"suggested_questions": features_dict.get("suggested_questions", []),
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
"user_input_form": user_input_form,
"sensitive_word_avoidance": features_dict.get(
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
),
"file_upload": features_dict.get(
"file_upload",
{
"image": {
"enabled": False,
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
"detail": "high",
"transfer_methods": ["remote_url", "local_file"],
}
},
),
"system_parameters": {
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
},
}

@ -1,6 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FileUploadConfig from core.file import FileUploadConfig
@ -18,7 +19,7 @@ class FileUploadConfigManager:
if file_upload_dict.get("enabled"): if file_upload_dict.get("enabled"):
transform_methods = file_upload_dict.get("allowed_file_upload_methods", []) transform_methods = file_upload_dict.get("allowed_file_upload_methods", [])
file_upload_dict["image_config"] = { file_upload_dict["image_config"] = {
"number_limits": file_upload_dict.get("number_limits", 1), "number_limits": file_upload_dict.get("number_limits", DEFAULT_FILE_NUMBER_LIMITS),
"transfer_methods": transform_methods, "transfer_methods": transform_methods,
} }

@ -320,10 +320,9 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -341,11 +340,10 @@ class AdvancedChatAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -363,11 +361,10 @@ class AdvancedChatAppGenerateTaskPipeline:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -383,18 +380,15 @@ class AdvancedChatAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session, event=event event=event
) )
node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_finish_resp: if node_finish_resp:
yield node_finish_resp yield node_finish_resp

@ -53,20 +53,6 @@ class AgentChatAppRunner(AppRunner):
query = application_generate_entity.query query = application_generate_entity.query
files = application_generate_entity.files files = application_generate_entity.files
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=dict(inputs),
files=list(files),
query=query,
)
memory = None memory = None
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)

@ -17,6 +17,7 @@ class BaseAppGenerator:
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
variables: Sequence["VariableEntity"], variables: Sequence["VariableEntity"],
tenant_id: str, tenant_id: str,
strict_type_validation: bool = False,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
@ -37,6 +38,7 @@ class BaseAppGenerator:
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
), ),
strict_type_validation=strict_type_validation,
) )
for k, v in user_inputs.items() for k, v in user_inputs.items()
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE

@ -61,20 +61,6 @@ class ChatAppRunner(AppRunner):
) )
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
memory = None memory = None
if application_generate_entity.conversation_id: if application_generate_entity.conversation_id:
# get memory of conversation (read-only) # get memory of conversation (read-only)

@ -54,20 +54,6 @@ class CompletionAppRunner(AppRunner):
) )
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
# Pre-calculate the number of tokens of the prompt messages,
# and return the rest number of tokens by model context token size limit and max token size limit.
# If the rest number of tokens is not enough, raise exception.
# Include: prompt template, inputs, query(optional), files(optional)
# Not Include: memory, external data, dataset context
self.get_pre_calculate_rest_tokens(
app_record=app_record,
model_config=application_generate_entity.model_conf,
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query,
)
# organize all inputs and template to prompt messages # organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional) # Include: prompt template, inputs, query(optional), files(optional)
prompt_messages, stop = self.organize_prompt_messages( prompt_messages, stop = self.organize_prompt_messages(

@ -153,6 +153,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation" query = application_generate_entity.query or "New conversation"
else: else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation") query = next(iter(application_generate_entity.inputs.values()), "New conversation")
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: if not conversation:

@ -92,6 +92,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
config=file_extra_config, config=file_extra_config,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
) )
# convert to app config # convert to app config
@ -114,7 +115,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs( inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id user_inputs=inputs,
variables=app_config.variables,
tenant_id=app_model.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
), ),
files=list(system_files), files=list(system_files),
user_id=user.id, user_id=user.id,

@ -279,10 +279,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( response = self._workflow_cycle_manager._workflow_node_retry_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -300,10 +299,9 @@ class WorkflowAppGenerateTaskPipeline:
session=session, workflow_run_id=self._workflow_run_id session=session, workflow_run_id=self._workflow_run_id
) )
workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event workflow_run=workflow_run, event=event
) )
node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
@ -313,17 +311,14 @@ class WorkflowAppGenerateTaskPipeline:
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success(
session=session, event=event event=event
) )
node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
@ -334,18 +329,14 @@ class WorkflowAppGenerateTaskPipeline:
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent, | QueueNodeExceptionEvent,
): ):
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
session=session,
event=event, event=event,
) )
node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response(
session=session,
event=event, event=event,
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution, workflow_node_execution=workflow_node_execution,
) )
session.commit()
if node_failed_response: if node_failed_response:
yield node_failed_response yield node_failed_response
@ -627,6 +618,7 @@ class WorkflowAppGenerateTaskPipeline:
workflow_app_log.created_by = self._user_id workflow_app_log.created_by = self._user_id
session.add(workflow_app_log) session.add(workflow_app_log)
session.commit()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -49,12 +49,14 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -80,6 +82,21 @@ class WorkflowCycleManage:
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables self._workflow_system_variables = workflow_system_variables
# Initialize the session factory and repository
# We use the global db engine instead of the session passed to methods
# Disable expire_on_commit to avoid the need for merging objects
self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
params={
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"session_factory": self._session_factory,
}
)
# We'll still keep the cache for backward compatibility and performance
# but use the repository for database operations
def _handle_workflow_run_start( def _handle_workflow_run_start(
self, self,
*, *,
@ -254,19 +271,15 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution.node_execution_id).where( # Use the instance repository to find running executions for a workflow run
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions(
WorkflowNodeExecution.app_id == workflow_run.app_id, workflow_run_id=workflow_run.id
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
) )
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated # Update the cache with the retrieved executions
running_workflow_node_executions = [ for execution in running_workflow_node_executions:
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id if execution.node_execution_id:
] self._workflow_node_executions[execution.node_execution_id] = execution
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None) now = datetime.now(UTC).replace(tzinfo=None)
@ -288,7 +301,7 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start( def _handle_node_execution_start(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4()) workflow_node_execution.id = str(uuid4())
@ -315,17 +328,14 @@ class WorkflowCycleManage:
) )
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success( def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
self, *, session: Session, event: QueueNodeSucceededEvent workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
) -> WorkflowNodeExecution:
workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
@ -344,13 +354,13 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution) # Use the instance repository to update the workflow node execution
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
self, self,
*, *,
session: Session,
event: QueueNodeFailedEvent event: QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent | QueueNodeInLoopFailedEvent
@ -361,9 +371,7 @@ class WorkflowCycleManage:
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
workflow_node_execution = self._get_workflow_node_execution( workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id)
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
@ -387,14 +395,14 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
:param workflow_run: workflow run
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
@ -439,15 +447,12 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution) # Use the instance repository to save the workflow node execution
self._workflow_node_execution_repository.save(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution return workflow_node_execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, self,
*, *,
@ -455,7 +460,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -521,14 +525,10 @@ class WorkflowCycleManage:
def _workflow_node_start_to_stream_response( def _workflow_node_start_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeStartedEvent, event: QueueNodeStartedEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -571,7 +571,6 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeSucceededEvent event: QueueNodeSucceededEvent
| QueueNodeFailedEvent | QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
@ -580,8 +579,6 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -621,13 +618,10 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response( def _workflow_node_retry_to_stream_response(
self, self,
*, *,
session: Session,
event: QueueNodeRetryEvent, event: QueueNodeRetryEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id: if not workflow_node_execution.workflow_run_id:
@ -668,7 +662,6 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse: ) -> ParallelBranchStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchStartStreamResponse( return ParallelBranchStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -692,7 +685,6 @@ class WorkflowCycleManage:
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return ParallelBranchFinishedStreamResponse( return ParallelBranchFinishedStreamResponse(
task_id=task_id, task_id=task_id,
@ -713,7 +705,6 @@ class WorkflowCycleManage:
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse: ) -> IterationNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeStartStreamResponse( return IterationNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -735,7 +726,6 @@ class WorkflowCycleManage:
def _workflow_iteration_next_to_stream_response( def _workflow_iteration_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse: ) -> IterationNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeNextStreamResponse( return IterationNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -759,7 +749,6 @@ class WorkflowCycleManage:
def _workflow_iteration_completed_to_stream_response( def _workflow_iteration_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -790,7 +779,6 @@ class WorkflowCycleManage:
def _workflow_loop_start_to_stream_response( def _workflow_loop_start_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse: ) -> LoopNodeStartStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeStartStreamResponse( return LoopNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -812,7 +800,6 @@ class WorkflowCycleManage:
def _workflow_loop_next_to_stream_response( def _workflow_loop_next_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse: ) -> LoopNodeNextStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeNextStreamResponse( return LoopNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
@ -836,7 +823,6 @@ class WorkflowCycleManage:
def _workflow_loop_completed_to_stream_response( def _workflow_loop_completed_to_stream_response(
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse: ) -> LoopNodeCompletedStreamResponse:
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
_ = session _ = session
return LoopNodeCompletedStreamResponse( return LoopNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
@ -934,11 +920,22 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
if node_execution_id not in self._workflow_node_executions: # First check the cache for performance
if node_execution_id in self._workflow_node_executions:
cached_execution = self._workflow_node_executions[node_execution_id]
# No need to merge with session since expire_on_commit=False
return cached_execution
# If not in cache, use the instance repository to get by node_execution_id
execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id)
if not execution:
raise ValueError(f"Workflow node execution not found: {node_execution_id}") raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return session.merge(cached_workflow_node_execution) # Update cache
self._workflow_node_executions[node_execution_id] = execution
return execution
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
""" """

@ -6,7 +6,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
@ -71,29 +70,6 @@ class DatasetIndexToolCallbackHandler:
def return_retriever_resource_info(self, resource: list): def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info.""" """Handle return_retriever_resource_info."""
if resource and len(resource) > 0:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get("position") or 0,
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" in item else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()
self._queue_manager.publish( self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
) )

@ -146,6 +146,7 @@ class BasicProviderConfig(BaseModel):
BOOLEAN = CommonParameterType.BOOLEAN.value BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
@classmethod @classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type": def value_of(cls, value: str) -> "ProviderConfig.Type":

@ -48,21 +48,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT, write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
) )
if "ssl_verify" not in kwargs:
kwargs["ssl_verify"] = HTTP_REQUEST_NODE_SSL_VERIFY
ssl_verify = kwargs.pop("ssl_verify")
retries = 0 retries = 0
while retries <= max_retries: while retries <= max_retries:
try: try:
if dify_config.SSRF_PROXY_ALL_URL: if dify_config.SSRF_PROXY_ALL_URL:
with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL: elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = { proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL), "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=ssl_verify),
"https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL), "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=ssl_verify),
} }
with httpx.Client(mounts=proxy_mounts, verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(mounts=proxy_mounts, verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
else: else:
with httpx.Client(verify=HTTP_REQUEST_NODE_SSL_VERIFY) as client: with httpx.Client(verify=ssl_verify) as client:
response = client.request(method=method, url=url, **kwargs) response = client.request(method=method, url=url, **kwargs)
if response.status_code not in STATUS_FORCELIST: if response.status_code not in STATUS_FORCELIST:

@ -44,6 +44,7 @@ class TokenBufferMemory:
Message.created_at, Message.created_at,
Message.workflow_run_id, Message.workflow_run_id,
Message.parent_message_id, Message.parent_message_id,
Message.answer_tokens,
) )
.filter( .filter(
Message.conversation_id == self.conversation.id, Message.conversation_id == self.conversation.id,
@ -63,7 +64,7 @@ class TokenBufferMemory:
thread_messages = extract_thread_messages(messages) thread_messages = extract_thread_messages(messages)
# for newly created message, its answer is temporarily empty, we don't need to add it to memory # for newly created message, its answer is temporarily empty, we don't need to add it to memory
if thread_messages and not thread_messages[0].answer: if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0) thread_messages.pop(0)
messages = list(reversed(thread_messages)) messages = list(reversed(thread_messages))

@ -177,7 +177,7 @@ class ModelInstance:
) )
def get_llm_num_tokens( def get_llm_num_tokens(
self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None self, prompt_messages: Sequence[PromptMessage], tools: Optional[Sequence[PromptMessageTool]] = None
) -> int: ) -> int:
""" """
Get number of tokens for llm Get number of tokens for llm

@ -10,7 +10,7 @@
- 支持 5 种模型类型的能力调用 - 支持 5 种模型类型的能力调用
- `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - `LLM` - LLM 文本补全、对话,预计算 tokens 能力
- `Text Embedding Model` - 文本 Embedding ,预计算 tokens 能力 - `Text Embedding Model` - 文本 Embedding预计算 tokens 能力
- `Rerank Model` - 分段 Rerank 能力 - `Rerank Model` - 分段 Rerank 能力
- `Speech-to-text Model` - 语音转文本能力 - `Speech-to-text Model` - 语音转文本能力
- `Text-to-speech Model` - 文本转语音能力 - `Text-to-speech Model` - 文本转语音能力
@ -57,11 +57,11 @@ Model Runtime 分三层:
提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。
对于供应商/模型凭据,有两种情况 对于供应商/模型凭据,有两种情况
- 如OpenAI这类中心化供应商需要定义如**api_key**这类的鉴权凭据 - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据
- 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据,就像下面这样,当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。
![Alt text](docs/zh_Hans/images/index/image.png) ![Alt text](docs/zh_Hans/images/index/image.png)
当配置好凭据后就可以通过DifyRuntime的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。
- 最底层为模型层 - 最底层为模型层
@ -69,9 +69,9 @@ Model Runtime 分三层:
在这里我们需要先区分模型参数与模型凭据。 在这里我们需要先区分模型参数与模型凭据。
- 模型参数(**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等这些参数是由用户在前端页面上进行调整的因此需要在后端定义参数的规则以便前端页面进行展示和调整。在DifyRuntime中他们的参数名一般为**model_parameters: dict[str, any]**。 - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。
- 模型凭据(**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在DifyRuntime中他们的参数名一般为**credentials: dict[str, any]**Provider层的credentials会直接被传递到这一层不需要再单独定义。 - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。
## 下一步 ## 下一步
@ -81,7 +81,7 @@ Model Runtime 分三层:
![Alt text](docs/zh_Hans/images/index/image-1.png) ![Alt text](docs/zh_Hans/images/index/image-1.png)
### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型) ### [为已存在的供应商新增模型 👈🏻](./docs/zh_Hans/provider_scale_out.md#增加模型)
当添加后对应供应商的模型列表中将会出现一个新的预定义模型供用户选择如GPT-3.5 GPT-4 ChatGLM3-6b等而对于支持自定义模型的供应商则不需要新增模型。 当添加后,对应供应商的模型列表中将会出现一个新的预定义模型供用户选择,如 GPT-3.5 GPT-4 ChatGLM3-6b 等,而对于支持自定义模型的供应商,则不需要新增模型。
![Alt text](docs/zh_Hans/images/index/image-2.png) ![Alt text](docs/zh_Hans/images/index/image-2.png)

@ -58,7 +58,7 @@ class Callback(ABC):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -88,7 +88,7 @@ class Callback(ABC):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -74,7 +74,7 @@ class LoggingCallback(Callback):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -104,7 +104,7 @@ class LoggingCallback(Callback):
result: LLMResult, result: LLMResult,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -102,12 +102,12 @@ provider_credential_schema:
```yaml ```yaml
- variable: server_url - variable: server_url
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器 URL
en_US: Server url en_US: Server url
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx en_US: Enter the url of your Xinference, for example https://example.com/xxx
``` ```
@ -116,12 +116,12 @@ provider_credential_schema:
```yaml ```yaml
- variable: model_uid - variable: model_uid
label: label:
zh_Hans: 模型UID zh_Hans: 模型 UID
en_US: Model uid en_US: Model uid
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入您的Model UID zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid en_US: Enter the model uid
``` ```
@ -192,7 +192,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
``` ```
Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens. This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate. Sometimes, you might not want to return 0 directly. In such cases, you can use `self._get_num_tokens_by_gpt2(text: str)` to get pre-computed tokens and ensure environment variable `PLUGIN_BASED_TOKEN_COUNTING_ENABLED` is set to `true`, This method is provided by the `AIModel` base class, and it uses GPT2's Tokenizer for calculation. However, it should be noted that this is only a substitute and may not be fully accurate.
- Model Credentials Validation - Model Credentials Validation

@ -367,7 +367,7 @@ Inherit the `__base.text2speech_model.Text2SpeechModel` base class and implement
- Returns - Returns
Text converted speech stream Text converted speech stream.
### Moderation ### Moderation

@ -6,14 +6,14 @@
需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。 需要注意的是,对于自定义模型,每一个模型的接入都需要填写一个完整的供应商凭据。
而不同于预定义模型自定义供应商接入时永远会拥有如下两个参数不需要在供应商yaml中定义。 而不同于预定义模型,自定义供应商接入时永远会拥有如下两个参数,不需要在供应商 yaml 中定义。
![Alt text](images/index/image-3.png) ![Alt text](images/index/image-3.png)
在前文中,我们已经知道了供应商无需实现`validate_provider_credential`Runtime会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。 在前文中,我们已经知道了供应商无需实现`validate_provider_credential`Runtime 会自行根据用户在此选择的模型类型和模型名称调用对应的模型层的`validate_credentials`来进行验证。
### 编写供应商yaml ### 编写供应商 yaml
我们首先要确定,接入的这个供应商支持哪些类型的模型。 我们首先要确定,接入的这个供应商支持哪些类型的模型。
@ -26,7 +26,7 @@
- `tts` 文字转语音 - `tts` 文字转语音
- `moderation` 审查 - `moderation` 审查
`Xinference`支持`LLM`和`Text Embedding`和Rerank那么我们开始编写`xinference.yaml`。 `Xinference`支持`LLM`和`Text Embedding`和 Rerank那么我们开始编写`xinference.yaml`。
```yaml ```yaml
provider: xinference #确定供应商标识 provider: xinference #确定供应商标识
@ -42,17 +42,17 @@ help: # 帮助
zh_Hans: 如何部署 Xinference zh_Hans: 如何部署 Xinference
url: url:
en_US: https://github.com/xorbitsai/inference en_US: https://github.com/xorbitsai/inference
supported_model_types: # 支持的模型类型Xinference同时支持LLM/Text Embedding/Rerank supported_model_types: # 支持的模型类型Xinference 同时支持 LLM/Text Embedding/Rerank
- llm - llm
- text-embedding - text-embedding
- rerank - rerank
configurate_methods: # 因为Xinference为本地部署的供应商并且没有预定义模型需要用什么模型需要根据Xinference的文档自己部署所以这里只支持自定义模型 configurate_methods: # 因为 Xinference 为本地部署的供应商,并且没有预定义模型,需要用什么模型需要根据 Xinference 的文档自己部署,所以这里只支持自定义模型
- customizable-model - customizable-model
provider_credential_schema: provider_credential_schema:
credential_form_schemas: credential_form_schemas:
``` ```
随后我们需要思考在Xinference中定义一个模型需要哪些凭据 随后,我们需要思考在 Xinference 中定义一个模型需要哪些凭据
- 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写 - 它支持三种不同的模型,因此,我们需要有`model_type`来指定这个模型的类型,它有三种类型,所以我们这么编写
```yaml ```yaml
@ -88,28 +88,28 @@ provider_credential_schema:
zh_Hans: 填写模型名称 zh_Hans: 填写模型名称
en_US: Input model name en_US: Input model name
``` ```
- 填写Xinference本地部署的地址 - 填写 Xinference 本地部署的地址
```yaml ```yaml
- variable: server_url - variable: server_url
label: label:
zh_Hans: 服务器URL zh_Hans: 服务器 URL
en_US: Server url en_US: Server url
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
en_US: Enter the url of your Xinference, for example https://example.com/xxx en_US: Enter the url of your Xinference, for example https://example.com/xxx
``` ```
- 每个模型都有唯一的model_uid因此需要在这里定义 - 每个模型都有唯一的 model_uid因此需要在这里定义
```yaml ```yaml
- variable: model_uid - variable: model_uid
label: label:
zh_Hans: 模型UID zh_Hans: 模型 UID
en_US: Model uid en_US: Model uid
type: text-input type: text-input
required: true required: true
placeholder: placeholder:
zh_Hans: 在此输入您的Model UID zh_Hans: 在此输入您的 Model UID
en_US: Enter the model uid en_US: Enter the model uid
``` ```
现在,我们就完成了供应商的基础定义。 现在,我们就完成了供应商的基础定义。
@ -145,7 +145,7 @@ provider_credential_schema:
""" """
``` ```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现): 在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python ```python
def _invoke(self, stream: bool, **kwargs) \ def _invoke(self, stream: bool, **kwargs) \
@ -179,7 +179,7 @@ provider_credential_schema:
""" """
``` ```
有时候也许你不需要直接返回0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的tokens这个方法位于`AIModel`基类中它会使用GPT2的Tokenizer进行计算但是只能作为替代方法并不完全准确。 有时候,也许你不需要直接返回 0所以你可以使用`self._get_num_tokens_by_gpt2(text: str)`来获取预计算的 tokens并确保环境变量`PLUGIN_BASED_TOKEN_COUNTING_ENABLED`设置为`true`这个方法位于`AIModel`基类中,它会使用 GPT2 Tokenizer 进行计算,但是只能作为替代方法,并不完全准确。
- 模型凭据校验 - 模型凭据校验
@ -196,13 +196,13 @@ provider_credential_schema:
""" """
``` ```
- 模型参数Schema - 模型参数 Schema
与自定义类型不同由于没有在yaml文件中定义一个模型支持哪些参数因此我们需要动态时间模型参数的Schema。 与自定义类型不同,由于没有在 yaml 文件中定义一个模型支持哪些参数,因此,我们需要动态时间模型参数的 Schema。
如Xinference支持`max_tokens` `temperature` `top_p` 这三个模型参数。 Xinference 支持`max_tokens` `temperature` `top_p` 这三个模型参数。
但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`我们这里举例A模型支持`top_k`B模型不支持`top_k`那么我们需要在这里动态生成模型参数的Schema如下所示 但是有的供应商根据不同的模型支持不同的参数,如供应商`OpenLLM`支持`top_k`,但是并不是这个供应商提供的所有模型都支持`top_k`,我们这里举例 A 模型支持`top_k`B 模型不支持`top_k`,那么我们需要在这里动态生成模型参数的 Schema如下所示
```python ```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:

@ -687,7 +687,7 @@ class LLMUsage(ModelUsage):
total_tokens: int # 总使用 token 数 total_tokens: int # 总使用 token 数
total_price: Decimal # 总费用 total_price: Decimal # 总费用
currency: str # 货币单位 currency: str # 货币单位
latency: float # 请求耗时(s) latency: float # 请求耗时 (s)
``` ```
--- ---
@ -717,7 +717,7 @@ class EmbeddingUsage(ModelUsage):
price_unit: Decimal # 价格单位,即单价基于多少 tokens price_unit: Decimal # 价格单位,即单价基于多少 tokens
total_price: Decimal # 总费用 total_price: Decimal # 总费用
currency: str # 货币单位 currency: str # 货币单位
latency: float # 请求耗时(s) latency: float # 请求耗时 (s)
``` ```
--- ---

@ -95,7 +95,7 @@ pricing: # 价格信息
""" """
``` ```
在实现时需要注意使用两个函数来返回数据分别用于处理同步返回和流式返回因为Python会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现): 在实现时,需要注意使用两个函数来返回数据,分别用于处理同步返回和流式返回,因为 Python 会将函数中包含 `yield` 关键字的函数识别为生成器函数,返回的数据类型固定为 `Generator`,因此同步和流式返回需要分别实现,就像下面这样(注意下面例子使用了简化参数,实际实现时需要按照上面的参数列表进行实现):
```python ```python
def _invoke(self, stream: bool, **kwargs) \ def _invoke(self, stream: bool, **kwargs) \

@ -8,13 +8,13 @@
- `customizable-model` 自定义模型 - `customizable-model` 自定义模型
用户需要新增每个模型的凭据配置如Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。 用户需要新增每个模型的凭据配置,如 Xinference它同时支持 LLM 和 Text Embedding但是每个模型都有唯一的**model_uid**,如果想要将两者同时接入,就需要为每个模型配置一个**model_uid**。
- `fetch-from-remote` 从远程获取 - `fetch-from-remote` 从远程获取
`predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。 `predefined-model` 配置方式一致,只需要配置统一的供应商凭据即可,模型通过凭据信息从供应商获取。
如OpenAI我们可以基于gpt-turbo-3.5来Fine Tune多个模型而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让DifyRuntime获取到开发者所有的微调模型并接入Dify。 OpenAI我们可以基于 gpt-turbo-3.5 Fine Tune 多个模型,而他们都位于同一个**api_key**下,当配置为 `fetch-from-remote` 时,开发者只需要配置统一的**api_key**即可让 DifyRuntime 获取到开发者所有的微调模型并接入 Dify。
这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model``predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。 这三种配置方式**支持共存**,即存在供应商支持 `predefined-model` + `customizable-model``predefined-model` + `fetch-from-remote` 等,也就是配置了供应商统一凭据可以使用预定义模型和从远程获取的模型,若新增了模型,则可以在此基础上额外使用自定义的模型。
@ -23,16 +23,16 @@
### 介绍 ### 介绍
#### 名词解释 #### 名词解释
- `module`: 一个`module`即为一个Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。 - `module`: 一个`module`即为一个 Python Package或者通俗一点称为一个文件夹里面包含了一个`__init__.py`文件,以及其他的`.py`文件。
#### 步骤 #### 步骤
新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。 新增一个供应商主要分为几步,这里简单列出,帮助大家有一个大概的认识,具体的步骤会在下面详细介绍。
- 创建供应商yaml文件根据[ProviderSchema](./schema.md#provider)编写 - 创建供应商 yaml 文件,根据[ProviderSchema](./schema.md#provider)编写
- 创建供应商代码,实现一个`class`。 - 创建供应商代码,实现一个`class`。
- 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。 - 根据模型类型,在供应商`module`下创建对应的模型类型 `module`,如`llm`或`text_embedding`。
- 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。 - 根据模型类型,在对应的模型`module`下创建同名的代码文件,如`llm.py`,并实现一个`class`。
- 如果有预定义模型根据模型名称创建同名的yaml文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。 - 如果有预定义模型,根据模型名称创建同名的 yaml 文件在模型`module`下,如`claude-2.1.yaml`,根据[AIModelEntity](./schema.md#aimodelentity)编写。
- 编写测试代码,确保功能可用。 - 编写测试代码,确保功能可用。
### 开始吧 ### 开始吧
@ -121,11 +121,11 @@ model_credential_schema:
#### 实现供应商代码 #### 实现供应商代码
我们需要在`model_providers`下创建一个同名的python文件如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。 我们需要在`model_providers`下创建一个同名的 python 文件,如`anthropic.py`,并实现一个`class`,继承`__base.provider.Provider`基类,如`AnthropicProvider`。
##### 自定义模型供应商 ##### 自定义模型供应商
当供应商为Xinference等自定义模型供应商时可跳过该步骤仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。 当供应商为 Xinference 等自定义模型供应商时,可跳过该步骤,仅创建一个空的`XinferenceProvider`类即可,并实现一个空的`validate_provider_credentials`方法,该方法并不会被实际使用,仅用作避免抽象类无法实例化。
```python ```python
class XinferenceProvider(Provider): class XinferenceProvider(Provider):
@ -155,7 +155,7 @@ def validate_provider_credentials(self, credentials: dict) -> None:
#### 增加模型 #### 增加模型
#### [增加预定义模型 👈🏻](./predefined_model_scale_out.md) #### [增加预定义模型 👈🏻](./predefined_model_scale_out.md)
对于预定义模型我们可以通过简单定义一个yaml并通过实现调用代码来接入。 对于预定义模型,我们可以通过简单定义一个 yaml并通过实现调用代码来接入。
#### [增加自定义模型 👈🏻](./customizable_model_scale_out.md) #### [增加自定义模型 👈🏻](./customizable_model_scale_out.md)
对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。 对于自定义模型,我们只需要实现调用代码即可接入,但是它需要处理的参数可能会更加复杂。

@ -29,7 +29,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"help": { "help": {
"en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options"
" are considered.", " are considered.",
"zh_Hans": "通过核心采样控制多样性0.5表示考虑了一半的所有可能性加权选项。", "zh_Hans": "通过核心采样控制多样性0.5 表示考虑了一半的所有可能性加权选项。",
}, },
"required": False, "required": False,
"default": 1.0, "default": 1.0,
@ -111,7 +111,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"help": { "help": {
"en_US": "Set a response format, ensure the output from llm is a valid code block as possible," "en_US": "Set a response format, ensure the output from llm is a valid code block as possible,"
" such as JSON, XML, etc.", " such as JSON, XML, etc.",
"zh_Hans": "设置一个返回格式,确保llm的输出尽可能是有效的代码块如JSON、XML等", "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML ",
}, },
"required": False, "required": False,
"options": ["JSON", "XML"], "options": ["JSON", "XML"],
@ -123,7 +123,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
"type": "text", "type": "text",
"help": { "help": {
"en_US": "Set a response json schema will ensure LLM to adhere it.", "en_US": "Set a response json schema will ensure LLM to adhere it.",
"zh_Hans": "设置返回的json schemallm将按照它返回", "zh_Hans": "设置返回的 json schemallm 将按照它返回",
}, },
"required": False, "required": False,
}, },

@ -1,8 +1,9 @@
from collections.abc import Sequence
from decimal import Decimal from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo
@ -107,7 +108,7 @@ class LLMResult(BaseModel):
id: Optional[str] = None id: Optional[str] = None
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
message: AssistantPromptMessage message: AssistantPromptMessage
usage: LLMUsage usage: LLMUsage
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
@ -130,7 +131,7 @@ class LLMResultChunk(BaseModel):
""" """
model: str model: str
prompt_messages: list[PromptMessage] prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
system_fingerprint: Optional[str] = None system_fingerprint: Optional[str] = None
delta: LLMResultChunkDelta delta: LLMResultChunkDelta

@ -1,5 +1,6 @@
import logging import logging
import time import time
import uuid
from collections.abc import Generator, Sequence from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
@ -24,6 +25,58 @@ from core.plugin.manager.model import PluginModelManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _gen_tool_call_id() -> str:
return f"chatcmpl-tool-{str(uuid.uuid4().hex)}"
def _increase_tool_call(
new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall]
):
"""
Merge incremental tool call updates into existing tool calls.
:param new_tool_calls: List of new tool call deltas to be merged.
:param existing_tools_calls: List of existing tool calls to be modified IN-PLACE.
"""
def get_tool_call(tool_call_id: str):
"""
Get or create a tool call by ID
:param tool_call_id: tool call ID
:return: existing or new tool call
"""
if not tool_call_id:
return existing_tools_calls[-1]
_tool_call = next((_tool_call for _tool_call in existing_tools_calls if _tool_call.id == tool_call_id), None)
if _tool_call is None:
_tool_call = AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""),
)
existing_tools_calls.append(_tool_call)
return _tool_call
for new_tool_call in new_tool_calls:
# generate ID for tool calls with function name but no ID to track them
if new_tool_call.function.name and not new_tool_call.id:
new_tool_call.id = _gen_tool_call_id()
# get tool call
tool_call = get_tool_call(new_tool_call.id)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
class LargeLanguageModel(AIModel): class LargeLanguageModel(AIModel):
""" """
Model class for large language model. Model class for large language model.
@ -45,7 +98,7 @@ class LargeLanguageModel(AIModel):
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
""" """
Invoke large language model Invoke large language model
@ -109,44 +162,13 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None system_fingerprint = None
tools_calls: list[AssistantPromptMessage.ToolCall] = [] tools_calls: list[AssistantPromptMessage.ToolCall] = []
def increase_tool_call(new_tool_calls: list[AssistantPromptMessage.ToolCall]):
def get_tool_call(tool_name: str):
if not tool_name:
return tools_calls[-1]
tool_call = next(
(tool_call for tool_call in tools_calls if tool_call.function.name == tool_name), None
)
if tool_call is None:
tool_call = AssistantPromptMessage.ToolCall(
id="",
type="",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tool_name, arguments=""),
)
tools_calls.append(tool_call)
return tool_call
for new_tool_call in new_tool_calls:
# get tool call
tool_call = get_tool_call(new_tool_call.function.name)
# update tool call
if new_tool_call.id:
tool_call.id = new_tool_call.id
if new_tool_call.type:
tool_call.type = new_tool_call.type
if new_tool_call.function.name:
tool_call.function.name = new_tool_call.function.name
if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments
for chunk in result: for chunk in result:
if isinstance(chunk.delta.message.content, str): if isinstance(chunk.delta.message.content, str):
content += chunk.delta.message.content content += chunk.delta.message.content
elif isinstance(chunk.delta.message.content, list): elif isinstance(chunk.delta.message.content, list):
content_list.extend(chunk.delta.message.content) content_list.extend(chunk.delta.message.content)
if chunk.delta.message.tool_calls: if chunk.delta.message.tool_calls:
increase_tool_call(chunk.delta.message.tool_calls) _increase_tool_call(chunk.delta.message.tool_calls, tools_calls)
usage = chunk.delta.usage or LLMUsage.empty_usage() usage = chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = chunk.system_fingerprint system_fingerprint = chunk.system_fingerprint
@ -205,22 +227,26 @@ class LargeLanguageModel(AIModel):
user=user, user=user,
callbacks=callbacks, callbacks=callbacks,
) )
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
result.prompt_messages = prompt_messages
return result return result
raise NotImplementedError("unsupported invoke result type", type(result))
def _invoke_result_generator( def _invoke_result_generator(
self, self,
model: str, model: str,
result: Generator, result: Generator,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
) -> Generator: ) -> Generator[LLMResultChunk, None, None]:
""" """
Invoke result generator Invoke result generator
@ -235,6 +261,10 @@ class LargeLanguageModel(AIModel):
try: try:
for chunk in result: for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
# we removed the prompt_messages from the chunk on the plugin daemon side.
# To ensure compatibility, we add the prompt_messages back here.
chunk.prompt_messages = prompt_messages
yield chunk yield chunk
self._trigger_new_chunk_callbacks( self._trigger_new_chunk_callbacks(
@ -295,6 +325,7 @@ class LargeLanguageModel(AIModel):
:param tools: tools for tool calling :param tools: tools for tool calling
:return: :return:
""" """
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
plugin_model_manager = PluginModelManager() plugin_model_manager = PluginModelManager()
return plugin_model_manager.get_llm_num_tokens( return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -307,6 +338,7 @@ class LargeLanguageModel(AIModel):
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=tools, tools=tools,
) )
return 0
def _calc_response_usage( def _calc_response_usage(
self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
@ -401,7 +433,7 @@ class LargeLanguageModel(AIModel):
chunk: LLMResultChunk, chunk: LLMResultChunk,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
@ -448,7 +480,7 @@ class LargeLanguageModel(AIModel):
model: str, model: str,
result: LLMResult, result: LLMResult,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from langfuse import Langfuse # type: ignore from langfuse import Langfuse # type: ignore
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.config_entity import LangfuseConfig
@ -28,9 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum, UnitEnum,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser from models.model import EndUser
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -110,36 +111,18 @@ class LangFuseDataTrace(BaseTraceInstance):
) )
self.add_trace(langfuse_trace_data=trace_data) self.add_trace(langfuse_trace_data=trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
.all()
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
) )
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
)
if not node_execution:
continue
for node_execution in workflow_node_executions:
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id
@ -213,9 +196,24 @@ class LangFuseDataTrace(BaseTraceInstance):
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0) total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
# add generation # add generation
generation_usage = GenerationUsage( generation_usage = GenerationUsage(
input=prompt_tokens,
output=completion_tokens,
total=total_token, total=total_token,
unit=UnitEnum.TOKENS,
) )
node_generation_data = LangfuseGeneration( node_generation_data = LangfuseGeneration(

@ -7,6 +7,7 @@ from typing import Optional, cast
from langsmith import Client from langsmith import Client
from langsmith.schemas import RunBase from langsmith.schemas import RunBase
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.config_entity import LangSmithConfig
@ -27,9 +28,9 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -134,36 +135,22 @@ class LangSmithDataTrace(BaseTraceInstance):
self.add_run(langsmith_run) self.add_run(langsmith_run)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
) )
if not node_execution: for node_execution in workflow_node_executions:
continue
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id
@ -199,6 +186,7 @@ class LangSmithDataTrace(BaseTraceInstance):
) )
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
if process_data and process_data.get("model_mode") == "chat": if process_data and process_data.get("model_mode") == "chat":
run_type = LangSmithRunType.llm run_type = LangSmithRunType.llm
metadata.update( metadata.update(
@ -212,9 +200,23 @@ class LangSmithDataTrace(BaseTraceInstance):
else: else:
run_type = LangSmithRunType.tool run_type = LangSmithRunType.tool
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order) node_dotted_order = generate_dotted_order(node_execution_id, created_at, workflow_dotted_order)
langsmith_run = LangSmithRunModel( langsmith_run = LangSmithRunModel(
total_tokens=node_total_tokens, total_tokens=node_total_tokens,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
name=node_type, name=node_type,
inputs=inputs, inputs=inputs,
run_type=run_type, run_type=run_type,

@ -7,6 +7,7 @@ from typing import Optional, cast
from opik import Opik, Trace from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7 from opik.id_helpers import uuid4_to_uuid7
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig from core.ops.entities.config_entity import OpikConfig
@ -21,9 +22,9 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repository.repository_factory import RepositoryFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.model import EndUser, MessageFile from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -147,36 +148,22 @@ class OpikDataTrace(BaseTraceInstance):
} }
self.add_trace(trace_data) self.add_trace(trace_data)
# through workflow_run_id get all_nodes_execution # through workflow_run_id get all_nodes_execution using repository
workflow_nodes_execution_id_records = ( session_factory = sessionmaker(bind=db.engine)
db.session.query(WorkflowNodeExecution.id) workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
.filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) params={
.all() "tenant_id": trace_info.tenant_id,
"app_id": trace_info.metadata.get("app_id"),
"session_factory": session_factory,
},
) )
for node_execution_id_record in workflow_nodes_execution_id_records: # Get all executions for this workflow run
node_execution = ( workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
db.session.query( workflow_run_id=trace_info.workflow_run_id
WorkflowNodeExecution.id,
WorkflowNodeExecution.tenant_id,
WorkflowNodeExecution.app_id,
WorkflowNodeExecution.title,
WorkflowNodeExecution.node_type,
WorkflowNodeExecution.status,
WorkflowNodeExecution.inputs,
WorkflowNodeExecution.outputs,
WorkflowNodeExecution.created_at,
WorkflowNodeExecution.elapsed_time,
WorkflowNodeExecution.process_data,
WorkflowNodeExecution.execution_metadata,
)
.filter(WorkflowNodeExecution.id == node_execution_id_record.id)
.first()
) )
if not node_execution: for node_execution in workflow_node_executions:
continue
node_execution_id = node_execution.id node_execution_id = node_execution.id
tenant_id = node_execution.tenant_id tenant_id = node_execution.tenant_id
app_id = node_execution.app_id app_id = node_execution.app_id

@ -460,7 +460,7 @@ class TraceTask:
"version": workflow_run_version, "version": workflow_run_version,
"total_tokens": total_tokens, "total_tokens": total_tokens,
"file_list": file_list, "file_list": file_list,
"triggered_form": workflow_run.triggered_from, "triggered_from": workflow_run.triggered_from,
"user_id": user_id, "user_id": user_id,
} }

@ -2,6 +2,7 @@ from collections.abc import Generator, Mapping
from typing import Optional, Union from typing import Optional, Union
from controllers.service_api.wraps import create_or_update_end_user_for_user_id from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator
from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_generator import ChatAppGenerator
@ -15,6 +16,34 @@ from models.model import App, AppMode, EndUser
class PluginAppBackwardsInvocation(BaseBackwardsInvocation): class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def fetch_app_info(cls, app_id: str, tenant_id: str) -> Mapping:
"""
Fetch app info
"""
app = cls._get_app(app_id, tenant_id)
"""Retrieve app parameters."""
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise ValueError("unexpected app type")
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise ValueError("unexpected app type")
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
return {
"data": get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form),
}
@classmethod @classmethod
def invoke_app( def invoke_app(
cls, cls,

@ -131,7 +131,7 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
raise ValueError("The selector must be a dictionary.") raise ValueError("The selector must be a dictionary.")
return value return value
case PluginParameterType.TOOLS_SELECTOR: case PluginParameterType.TOOLS_SELECTOR:
if not isinstance(value, list): if value and not isinstance(value, list):
raise ValueError("The tools selector must be a list.") raise ValueError("The tools selector must be a list.")
return value return value
case _: case _:
@ -147,7 +147,7 @@ def init_frontend_parameter(rule: PluginParameter, type: enum.StrEnum, value: An
init frontend parameter by rule init frontend parameter by rule
""" """
parameter_value = value parameter_value = value
if not parameter_value and parameter_value != 0 and type != PluginParameterType.TOOLS_SELECTOR: if not parameter_value and parameter_value != 0:
# get default value # get default value
parameter_value = rule.default parameter_value = rule.default
if not parameter_value and rule.required: if not parameter_value and rule.required:

@ -70,6 +70,9 @@ class PluginDeclaration(BaseModel):
models: Optional[list[str]] = Field(default_factory=list) models: Optional[list[str]] = Field(default_factory=list)
endpoints: Optional[list[str]] = Field(default_factory=list) endpoints: Optional[list[str]] = Field(default_factory=list)
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$") author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$") name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
@ -86,6 +89,7 @@ class PluginDeclaration(BaseModel):
model: Optional[ProviderEntity] = None model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None agent_strategy: Optional[AgentStrategyProviderEntity] = None
meta: Meta
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -120,8 +124,6 @@ class PluginEntity(PluginInstallation):
name: str name: str
installation_id: str installation_id: str
version: str version: str
latest_version: Optional[str] = None
latest_unique_identifier: Optional[str] = None
@model_validator(mode="after") @model_validator(mode="after")
def set_plugin_id(self): def set_plugin_id(self):

@ -204,3 +204,11 @@ class RequestRequestUploadFile(BaseModel):
filename: str filename: str
mimetype: str mimetype: str
class RequestFetchAppInfo(BaseModel):
"""
Request to fetch app info
"""
app_id: str

@ -82,7 +82,7 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API Make a stream request to the plugin daemon inner API
""" """
response = self._request(method, path, headers, data, params, files, stream=True) response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines(): for line in response.iter_lines(chunk_size=1024 * 8):
line = line.decode("utf-8").strip() line = line.decode("utf-8").strip()
if line.startswith("data:"): if line.startswith("data:"):
line = line[5:].strip() line = line[5:].strip()
@ -168,16 +168,18 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API and yield the response as a model. Make a stream request to the plugin daemon inner API and yield the response as a model.
""" """
for line in self._stream_request(method, path, params, headers, data, files): for line in self._stream_request(method, path, params, headers, data, files):
line_data = None
try: try:
line_data = json.loads(line) rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore except (ValueError, TypeError):
except Exception:
# TODO modify this when line_data has code and message # TODO modify this when line_data has code and message
if line_data and "error" in line_data: try:
raise ValueError(line_data["error"]) line_data = json.loads(line)
else: except (ValueError, TypeError):
raise ValueError(line) raise ValueError(line)
# If the dictionary contains the `error` key, use its value as the argument
# for `ValueError`.
# Otherwise, use the `line` to provide better contextual information about the error.
raise ValueError(line_data.get("error", line))
if rep.code != 0: if rep.code != 0:
if rep.code == -500: if rep.code == -500:

@ -110,7 +110,62 @@ class PluginToolManager(BasePluginManager):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
return response
class FileChunk:
"""
Only used for internal processing.
"""
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
assert isinstance(resp.message, ToolInvokeMessage.BlobChunkMessage)
# Get blob chunk information
chunk_id = resp.message.id
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
# If this is the final chunk, yield a complete blob message
if is_end:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
meta=resp.meta,
)
else:
# Check if file is too large (30MB limit)
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
else:
yield resp
def validate_provider_credentials( def validate_provider_credentials(
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any] self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]

@ -28,7 +28,7 @@ BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
}, },
"conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"}, "conversation_histories_role": {"user_prefix": "用户", "assistant_prefix": "助手"},
}, },
"stop": ["用户:"], "stop": ["用户"],
} }
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
@ -41,5 +41,5 @@ BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}}, "completion_prompt_config": {"prompt": {"text": "{{#pre_prompt#}}"}},
"stop": ["用户:"], "stop": ["用户"],
} }

@ -124,6 +124,15 @@ class ProviderManager:
# Get All preferred provider types of the workspace # Get All preferred provider types of the workspace
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
# Ensure that both the original provider name and its ModelProviderID string representation
# are present in the dictionary to handle cases where either form might be used
for provider_name in list(provider_name_to_preferred_model_provider_records_dict.keys()):
provider_id = ModelProviderID(provider_name)
if str(provider_id) not in provider_name_to_preferred_model_provider_records_dict:
# Add the ModelProviderID string representation if it's not already present
provider_name_to_preferred_model_provider_records_dict[str(provider_id)] = (
provider_name_to_preferred_model_provider_records_dict[provider_name]
)
# Get All provider model settings # Get All provider model settings
provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id)
@ -497,8 +506,8 @@ class ProviderManager:
@staticmethod @staticmethod
def _init_trial_provider_records( def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
) -> dict[str, list]: ) -> dict[str, list[Provider]]:
""" """
Initialize trial provider records if not exists. Initialize trial provider records if not exists.
@ -532,7 +541,7 @@ class ProviderManager:
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
try: try:
# FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic # FIXME ignore the type errork, onyl TrialHostingQuota has limit need to change the logic
provider_record = Provider( new_provider_record = Provider(
tenant_id=tenant_id, tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration. # TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name, provider_name=ModelProviderID(provider_name).provider_name,
@ -542,11 +551,12 @@ class ProviderManager:
quota_used=0, quota_used=0,
is_valid=True, is_valid=True,
) )
db.session.add(provider_record) db.session.add(new_provider_record)
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError: except IntegrityError:
db.session.rollback() db.session.rollback()
provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .filter(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
@ -556,11 +566,14 @@ class ProviderManager:
) )
.first() .first()
) )
if provider_record and not provider_record.is_valid: if not existed_provider_record:
provider_record.is_valid = True continue
if not existed_provider_record.is_valid:
existed_provider_record.is_valid = True
db.session.commit() db.session.commit()
provider_name_to_provider_records_dict[provider_name].append(provider_record) provider_name_to_provider_records_dict[provider_name].append(existed_provider_record)
return provider_name_to_provider_records_dict return provider_name_to_provider_records_dict

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save