Merge branch 'main' into feat/21072-default-value-select-field

pull/21192/head
antonko 11 months ago
commit 3b92b42284

@ -47,15 +47,17 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
- name: Coverage Summary
run: |
set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@ -83,9 +85,15 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
redis
sandbox
ssrf_proxy
- name: setup test config
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh

@ -84,6 +84,12 @@ jobs:
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

2
.gitignore vendored

@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf
@ -213,3 +214,4 @@ mise.toml
# AI Assistant
.roo/
api/.env.backup

@ -226,6 +226,15 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Using Alibaba Cloud Computing Nest
Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Using Alibaba Cloud Data Management
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -209,6 +209,14 @@ docker compose up -d
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### استخدام Alibaba Cloud للنشر
[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### استخدام Alibaba Cloud Data Management للنشر
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## المساهمة
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.

@ -225,6 +225,15 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)।

@ -221,6 +221,15 @@ docker compose up -d
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢 部署
使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
#### 使用 阿里云数据管理DMS 部署
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

@ -221,6 +221,15 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.

@ -221,6 +221,15 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuir
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -219,6 +219,15 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuer
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。
- **Dify Community Editionのセルフホスティング</br>**
この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。
この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。
詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。
- **企業/組織向けのDify</br>**
@ -220,6 +220,13 @@ docker compose up -d
##### AWS
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
## 貢献
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。

@ -219,6 +219,15 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
##### AWS
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -213,6 +213,15 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
##### AWS
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.

@ -218,6 +218,15 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -219,6 +219,15 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Prispevam
Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.

@ -212,6 +212,15 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
##### AWS
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
## Katkıda Bulunma
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.

@ -224,6 +224,15 @@ Dify 的所有功能都提供相應的 API因此您可以輕鬆地將 Dify
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢進行部署
[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### 使用 阿里雲數據管理DMS 進行部署
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
## 貢獻
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。

@ -214,6 +214,16 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Đóng góp
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.

@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
VECTOR_STORE=weaviate
# Weaviate configuration
@ -294,6 +294,13 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Matrixone configration
MATRIXONE_HOST=127.0.0.1
MATRIXONE_PORT=6001
MATRIXONE_USER=dump
MATRIXONE_PASSWORD=111
MATRIXONE_DATABASE=dify
# Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
@ -332,9 +339,11 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp
# Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
# resend configuration
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
@ -344,7 +353,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
SENTRY_DSN=

@ -1,6 +1,4 @@
exclude = [
"migrations/*",
]
exclude = ["migrations/*"]
line-length = 120
[format]
@ -77,6 +75,7 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]

@ -281,6 +281,7 @@ def migrate_knowledge_vector_database():
VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS,
VectorType.TABLESTORE,
VectorType.MATRIXONE,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,

@ -1,8 +1,11 @@
import logging
from pathlib import Path
from typing import Any
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
@ -99,4 +102,12 @@ class DifyConfig(
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
TomlConfigSettingsSource(
settings_cls=settings_cls,
toml_file=search_file_upwards(
base_dir_path=Path(__file__).parent,
target_file_name="pyproject.toml",
max_search_parent_depth=2,
),
),
)

@ -609,7 +609,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.",
description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
default=None,
)
@ -663,6 +663,11 @@ class MailConfig(BaseSettings):
default=50,
)
SENDGRID_API_KEY: Optional[str] = Field(
description="API key for SendGrid service",
default=None,
)
class RagEtlConfig(BaseSettings):
"""

@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
@ -222,6 +223,10 @@ class CeleryConfig(DatabaseConfig):
default=None,
)
CELERY_SENTINEL_PASSWORD: Optional[str] = Field(
description="Password of the Redis Sentinel master.",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
@ -323,5 +328,6 @@ class MiddlewareConfig(
OpenGaussConfig,
TableStoreConfig,
DatasetQueueMonitorConfig,
MatrixoneConfig,
):
pass

@ -0,0 +1,14 @@
from pydantic import BaseModel, Field
class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
MATRIXONE_METRIC: str = Field(
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
)

@ -1,17 +1,13 @@
from pydantic import Field
from pydantic_settings import BaseSettings
from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
class PackagingInfo(BaseSettings):
class PackagingInfo(PyProjectTomlConfig):
"""
Packaging build information
"""
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.4.3",
)
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default="",

@ -0,0 +1,17 @@
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
class PyProjectConfig(BaseModel):
version: str = Field(description="Dify version", default="")
class PyProjectTomlConfig(BaseSettings):
"""
configs in api/pyproject.toml
"""
project: PyProjectConfig = Field(
description="configs in the project section of pyproject.toml",
default=PyProjectConfig(),
)

@ -63,6 +63,7 @@ from .app import (
statistic,
workflow,
workflow_app_log,
workflow_draft_variable,
workflow_run,
workflow_statistic,
)

@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")

@ -17,6 +17,8 @@ from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class AppImportApi(Resource):
@ -60,7 +62,9 @@ class AppImportApi(Resource):
app_id=args.get("app_id"),
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:

@ -90,23 +90,11 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")

@ -1,5 +1,6 @@
import json
import logging
from collections.abc import Sequence
from typing import cast
from flask import abort, request
@ -18,10 +19,12 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from extensions.ext_database import db
from factories import variable_factory
from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@ -30,6 +33,7 @@ from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
file_objs: Sequence[File] = []
if file_extra_config is None:
return file_objs
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)
return file_objs
class DraftWorkflowApi(Resource):
@setup_required
@login_required
@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args()
inputs = args.get("inputs")
if inputs == None:
user_inputs = args.get("inputs")
if user_inputs is None:
raise ValueError("missing inputs")
workflow_srv = WorkflowService()
# fetch draft workflow by app_model
draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=args.get("query", ""),
files=files,
)
return workflow_node_execution
@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource):
return None, 204
class DraftWorkflowNodeLastRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model)
if not workflow:
raise NotFound("Workflow not found")
node_exec = srv.get_node_last_run(
app_model=app_model,
workflow=workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("last run not found")
return node_exec
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
@ -795,3 +853,7 @@ api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)

@ -34,6 +34,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination

@ -0,0 +1,421 @@
import logging
from typing import Any, NoReturn
from flask import Response
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode, db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda _: "env"),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
}
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
- Dify has been property setup.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
class WorkflowVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
if not workflow_exist:
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=app_model.id,
page=args.page,
limit=args.limit,
)
return workflow_vars
@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(app_model.id)
db.session.commit()
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
# with specific `node_id` in database, we still want to make the API separated. By disallowing
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk that user of the API depending on the implementation detail of the API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
class NodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
return node_vars
@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(app_model.id, node_id)
db.session.commit()
return Response("", 204)
class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
# Local File:
#
# {
# "type": "image",
# "transfer_method": "local_file",
# "url": "",
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
# }
#
# Remote File:
#
#
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
class VariableResetApi(Resource):
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
return draft_vars
class ConversationVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
class SystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
class EnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, app_model: App):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

@ -8,6 +8,15 @@ from libs.login import current_user
from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
@wraps(view_func)
@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
del kwargs["app_id"]
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
app_model = _load_app_model(app_id)
if not app_model:
raise AppNotFoundError()

@ -41,7 +41,7 @@ class OAuthDataSource(Resource):
if not internal_secret:
return ({"error": "Internal secret is not set"},)
oauth_provider.save_internal_access_token(internal_secret)
return {"data": ""}
return {"data": "internal"}
else:
auth_url = oauth_provider.get_authorization_url()
return {"data": auth_url}, 200

@ -686,6 +686,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
| VectorType.MATRIXONE
):
return {
"retrieval_method": [
@ -733,6 +734,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
):
return {
"retrieval_method": [

@ -5,7 +5,7 @@ from typing import cast
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from flask_restful import Resource, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@ -43,7 +43,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
@ -54,8 +53,6 @@ from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
class DocumentResource(Resource):
@ -242,12 +239,10 @@ class DatasetDocumentListApi(Resource):
return response
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
@setup_required
@login_required
@account_initialization_required
@marshal_with(documents_and_batch_fields)
@marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
@ -293,6 +288,8 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
dataset = DatasetService.get_dataset(dataset_id)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@ -300,7 +297,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
return {"documents": documents, "batch": batch}
return {"dataset": dataset, "documents": documents, "batch": batch}
@setup_required
@login_required
@ -862,77 +859,16 @@ class DocumentStatusApi(DocumentResource):
DatasetService.check_dataset_permission(dataset, current_user)
document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)
indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable":
if document.enabled:
continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue
document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
elif action == "archive":
if document.archived:
continue
document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
elif action == "un_archive":
if not document.archived:
continue
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
try:
DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
except services.errors.document.DocumentIndexingError as e:
raise InvalidActionError(str(e))
except ValueError as e:
raise InvalidActionError(str(e))
except NotFound as e:
raise NotFound(str(e))
else:
raise InvalidActionError()
return {"result": "success"}, 200

@ -18,7 +18,6 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -79,19 +78,9 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
voice = args.get("voice", None)
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")

@ -18,7 +18,7 @@ class VersionApi(Resource):
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
"version": dify_config.CURRENT_VERSION,
"version": dify_config.project.version,
"release_date": "",
"release_notes": "",
"can_auto_update": False,

@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id

@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource):
return {
"result": "success",
"invitation_results": invitation_results,
"tenant_id": str(current_user.current_tenant.id),
}, 201
@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource):
except Exception as e:
raise ValueError(str(e))
return {"result": "success"}, 204
return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
class MemberUpdateRoleApi(Resource):

@ -13,6 +13,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@ -497,6 +498,42 @@ class PluginFetchPermissionApi(Resource):
)
class PluginFetchDynamicSelectOptionsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
# check if the user is admin or owner
if not current_user.is_admin_or_owner:
raise Forbidden()
tenant_id = current_user.current_tenant_id
user_id = current_user.id
parser = reqparse.RequestParser()
parser.add_argument("plugin_id", type=str, required=True, location="args")
parser.add_argument("provider", type=str, required=True, location="args")
parser.add_argument("action", type=str, required=True, location="args")
parser.add_argument("parameter", type=str, required=True, location="args")
parser.add_argument("provider_type", type=str, required=True, location="args")
args = parser.parse_args()
try:
options = PluginParameterService.get_dynamic_select_options(
tenant_id,
user_id,
args["plugin_id"],
args["provider"],
args["action"],
args["parameter"],
args["provider_type"],
)
except PluginDaemonClientSideError as e:
raise ValueError(e)
return jsonable_encoder({"options": options})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@ -521,3 +558,5 @@ api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marke
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")

@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return tool_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

@ -17,6 +17,7 @@ from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
RequestInvokeLLMWithStructuredOutput,
RequestInvokeModeration,
RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
return length_prefixed_response(0xF, generator())
class PluginInvokeLLMWithStructuredOutputApi(Resource):
@setup_required
@plugin_inner_api_only
@get_user_tenant
@plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
user_model.id, tenant_model, payload
)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return length_prefixed_response(0xF, generator())
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@ -291,6 +307,7 @@ class PluginFetchAppInfoApi(Resource):
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")

@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
tenant_was_created.send(tenant)
return {"message": "enterprise workspace created."}
resp = {
"id": tenant.id,
"name": tenant.name,
"plan": tenant.plan,
"status": tenant.status,
"created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
"updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
}
return {
"message": "enterprise workspace created.",
"tenant": resp,
}
class EnterpriseWorkspaceNoOwnerEmail(Resource):

@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode, EndUser
from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -78,20 +78,9 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response

@ -135,6 +135,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
@ -158,6 +172,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination

@ -4,7 +4,7 @@ from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
@ -17,7 +17,7 @@ from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import tag_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource):
data = request.get_json()
# check embedding model setting
if data.get("indexing_technique") == "high_quality":
if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
if (
data.get("retrieval_model")
and data.get("retrieval_model").get("reranking_model")
and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
dataset.tenant_id,
data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
@ -329,6 +355,56 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError()
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
def patch(self, tenant_id, dataset_id, action):
"""
Batch update document status.
Args:
tenant_id: tenant id
dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive)
Returns:
dict: A dictionary with a key 'result' and a value 'success'
int: HTTP status code 200 indicating that the operation was successful.
Raises:
NotFound: If the dataset with the given ID does not exist.
Forbidden: If the user does not have permission.
InvalidActionError: If the action is invalid or cannot be performed.
"""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
# Check user's permission
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Check dataset model setting
DatasetService.check_dataset_model_setting(dataset)
# Get document IDs from request body
data = request.get_json()
document_ids = data.get("document_ids", [])
try:
DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
except services.errors.document.DocumentIndexingError as e:
raise InvalidActionError(str(e))
except ValueError as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token
@marshal_with(tag_fields)
@ -457,6 +533,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
api.add_resource(DatasetTagsApi, "/datasets/tags")
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")

@ -3,7 +3,7 @@ import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError
@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
DocumentIndexingError,
InvalidMetadataError,
)
from controllers.service_api.wraps import (
DatasetApiResource,
@ -29,7 +30,7 @@ from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
@ -59,6 +60,7 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@ -74,6 +76,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
if args.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
)
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
@ -124,6 +141,17 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if (
args.get("retrieval_model")
and args.get("retrieval_model").get("reranking_model")
and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@ -183,11 +211,29 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique
if "embedding_model_provider" in args:
DatasetService.check_embedding_model_setting(
tenant_id, args["embedding_model_provider"], args["embedding_model"]
)
if (
"retrieval_model" in args
and args["retrieval_model"].get("reranking_model")
and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
):
DatasetService.check_reranking_model_setting(
tenant_id,
args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
)
# save file info
file = request.files["file"]
# check file
@ -258,6 +304,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@ -424,6 +473,101 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
class DocumentDetailApi(DatasetApiResource):
METADATA_CHOICES = {"all", "only", "without"}
def get(self, tenant_id, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = self.get_dataset(dataset_id, tenant_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.tenant_id != str(tenant_id):
raise Forbidden("No permission.")
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
"indexing_latency": document.indexing_latency,
"error": document.error,
"enabled": document.enabled,
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
"disabled_by": document.disabled_by,
"archived": document.archived,
"segment_count": document.segment_count,
"average_segment_length": document.average_segment_length,
"hit_count": document.hit_count,
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
"created_at": document.created_at.timestamp(),
"tokens": document.tokens,
"indexing_status": document.indexing_status,
"completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
"updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
"indexing_latency": document.indexing_latency,
"error": document.error,
"enabled": document.enabled,
"disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
"disabled_by": document.disabled_by,
"archived": document.archived,
"doc_type": document.doc_type,
"doc_metadata": document.doc_metadata_details,
"segment_count": document.segment_count,
"average_segment_length": document.average_segment_length,
"hit_count": document.hit_count,
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
}
return response
api.add_resource(
DocumentAddByTextApi,
"/datasets/<uuid:dataset_id>/document/create_by_text",
@ -447,3 +591,4 @@ api.add_resource(
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")

@ -9,7 +9,7 @@ class IndexApi(Resource):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
"server_version": dify_config.CURRENT_VERSION,
"server_version": dify_config.project.version,
}

@ -11,13 +11,13 @@ from flask_restful import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, Unauthorized
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import RateLimitLog
from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
@ -317,3 +317,11 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
if not dataset:
raise NotFound("Dataset not found.")
return dataset

@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@ -77,21 +77,9 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response

@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NotFoundError(BaseHTTPException):
error_code = "not_found"
code = 404
class InvalidArgumentError(BaseHTTPException):
error_code = "invalid_param"
code = 400

@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""

@ -27,15 +27,22 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import (
DraftVarLoader,
WorkflowDraftVariableService,
)
logger = logging.getLogger(__name__)
@ -116,6 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
@ -261,6 +273,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
@ -271,6 +290,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)
def single_loop_generate(
@ -336,6 +356,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
@ -346,6 +373,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
variable_loader=var_loader,
)
def _generate(
@ -359,6 +387,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@ -410,6 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
},
)
@ -426,6 +456,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@ -438,6 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id: str,
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
) -> None:
"""
Generate worker in a new thread.
@ -454,8 +486,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(
@ -464,6 +494,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
)
runner.run()
@ -497,6 +528,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -523,6 +555,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)
try:

@ -19,6 +19,7 @@ from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
) -> None:
super().__init__(queue_manager)
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)

@ -64,6 +64,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, W
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
@ -94,6 +95,7 @@ class AdvancedChatAppGenerateTaskPipeline:
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -153,6 +155,7 @@ class AdvancedChatAppGenerateTaskPipeline:
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -371,6 +374,7 @@ class AdvancedChatAppGenerateTaskPipeline:
workflow_node_execution=workflow_node_execution,
)
session.commit()
self._save_output_for_event(event, workflow_node_execution.id)
if node_finish_resp:
yield node_finish_resp
@ -390,6 +394,8 @@ class AdvancedChatAppGenerateTaskPipeline:
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_finish_resp:
yield node_finish_resp
@ -759,3 +765,15 @@ class AdvancedChatAppGenerateTaskPipeline:
if not message:
raise ValueError(f"Message not found: {self._message_id}")
return message
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

@ -26,7 +26,6 @@ from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@ -124,6 +123,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@ -233,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()

@ -1,10 +1,20 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaver,
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from factories import file_factory
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
@ -159,3 +169,38 @@ class BaseAppGenerator:
yield f"event: {message}\n\n"
return gen()
@final
@staticmethod
def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return DraftVariableSaverImpl(
session=session,
app_id=app_id,
node_id=node_id,
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
)
else:
def draft_var_saver_factory(
session: Session,
app_id: str,
node_id: str,
node_type: NodeType,
node_execution_id: str,
enclosing_node_id: str | None = None,
) -> DraftVariableSaver:
return NoopDraftVariableSaver()
return draft_var_saver_factory

@ -25,7 +25,6 @@ from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@ -115,6 +114,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@ -219,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()

@ -44,10 +44,12 @@ from core.app.entities.task_entities import (
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@ -125,7 +127,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
outputs=workflow_execution.outputs,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
@ -202,6 +204,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -214,7 +218,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -245,6 +249,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -257,7 +263,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs,
outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@ -376,6 +382,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@ -384,7 +391,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -463,7 +470,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@ -500,7 +507,8 @@ class WorkflowResponseConverter:
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
@classmethod
def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
@ -509,20 +517,30 @@ class WorkflowResponseConverter:
if not value:
return []
files = []
if isinstance(value, list):
files: list[Mapping[str, Any]] = []
if isinstance(value, FileSegment):
files.append(value.value.to_dict())
elif isinstance(value, ArrayFileSegment):
files.extend([i.to_dict() for i in value.value])
elif isinstance(value, File):
files.append(value.to_dict())
elif isinstance(value, list):
for item in value:
file = self._get_file_var_from_value(item)
file = cls._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file = self._get_file_var_from_value(value)
elif isinstance(
value,
dict,
):
file = cls._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
@classmethod
def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value

@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@ -196,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()

@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction or ""
def _get_conversation(self, conversation_id: str):
def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError()
raise ConversationNotExistsError("Conversation not exists")
return conversation
def _get_message(self, message_id: str) -> Optional[Message]:
def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")
return message

@ -25,13 +25,16 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
@ -94,6 +97,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
@ -186,6 +194,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -211,6 +220,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()
# release database connection, because the following new thread operations may take a long time
db.session.close()
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
@ -219,11 +231,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
worker_thread.start()
draft_var_saver_factory = self._get_draft_var_saver_factory(
invoke_from,
)
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
@ -232,6 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@ -303,6 +321,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,
@ -313,6 +338,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def single_loop_generate(
@ -379,7 +405,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,
workflow=workflow,
@ -389,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)
def _generate_worker(
@ -397,6 +430,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@ -415,6 +449,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
)
runner.run()
@ -445,6 +480,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -465,6 +501,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application

@ -56,6 +56,7 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
@ -87,6 +88,7 @@ class WorkflowAppGenerateTaskPipeline:
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@ -131,6 +133,8 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._invoke_from = queue_manager._invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -322,6 +326,8 @@ class WorkflowAppGenerateTaskPipeline:
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
if node_success_response:
yield node_success_response
elif isinstance(
@ -339,6 +345,8 @@ class WorkflowAppGenerateTaskPipeline:
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_failed_response:
yield node_failed_response
@ -593,3 +601,15 @@ class WorkflowAppGenerateTaskPipeline:
)
return response
def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
with Session(db.engine) as session, session.begin():
saver = self._draft_var_saver_factory(
session=session,
app_id=self._application_generate_entity.app_config.app_id,
node_id=event.node_id,
node_type=event.node_type,
node_execution_id=node_execution_id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id,
)
saver.save(event.process_data, event.outputs)

@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@ -69,8 +70,12 @@ from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager
self._variable_loader = variable_loader
def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@ -173,6 +178,13 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@ -262,6 +274,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
@ -376,6 +394,7 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
@ -438,6 +457,7 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(

@ -17,9 +17,24 @@ class InvokeFrom(Enum):
Invoke From.
"""
# SERVICE_API indicates that this invocation is from an API call to Dify app.
#
# Description of service api in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api"
# WEB_APP indicates that this invocation is from
# the web app of the workflow (or chatflow).
#
# Description of web app in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"
# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
@classmethod

@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from models.enums import MessageStatus
from models.model import Message
logger = logging.getLogger(__name__)
@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
return err
err_desc = self._error_to_desc(err)
message.status = "error"
message.status = MessageStatus.ERROR
message.error = err_desc
return err

@ -395,6 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
self._task_state.llm_result.usage.latency = message.provider_response_latency
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:

@ -15,6 +15,11 @@ class CommonParameterType(StrEnum):
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
# eg: Select a Slack channel from a Slack workspace
DYNAMIC_SELECT = "dynamic-select"
# TOOL_SELECTOR = "tool-selector"

@ -1 +1,11 @@
from typing import Any
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"
def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY

@ -51,7 +51,7 @@ class File(BaseModel):
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
mime_type: Optional[str] = None
size: int = -1

@ -1,67 +0,0 @@
import base64
import logging
import time
from typing import Optional
from configs import dify_config
from constants import IMAGE_EXTENSIONS
from core.helper.url_signer import UrlSigner
from extensions.ext_storage import storage
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.exception(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file_id: the id of UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
result = UrlSigner.verify(
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
)
# verify signature
if not result:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

@ -28,7 +28,7 @@ class TemplateTransformer(ABC):
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
return result.group(1)
@classmethod
@ -38,16 +38,53 @@ class TemplateTransformer(ABC):
:param response: response
:return:
"""
try:
result = json.loads(cls.extract_result_str_from_response(response))
except json.JSONDecodeError:
raise ValueError("failed to parse response")
result_str = cls.extract_result_str_from_response(response)
result = json.loads(result_str)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
except ValueError as e:
# Re-raise ValueError from extract_result_str_from_response
raise e
except Exception as e:
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
# Check if the result contains an error
if isinstance(result, dict) and "error" in result:
raise ValueError(f"JavaScript execution error: {result['error']}")
if not isinstance(result, dict):
raise ValueError("result must be a dict")
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
if not all(isinstance(k, str) for k in result):
raise ValueError("result keys must be strings")
raise ValueError("Result keys must be strings")
# Post-process the result to convert scientific notation strings back to numbers
result = cls._post_process_result(result)
return result
@classmethod
def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
"""
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
try:
return float(value)
except ValueError:
pass
elif isinstance(value, dict):
return {k: convert_scientific_notation(v) for k, v in value.items()}
elif isinstance(value, list):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
@classmethod
@abstractmethod
def get_runner_script(cls) -> str:

@ -1,22 +0,0 @@
from collections import OrderedDict
from typing import Any
class LRUCache:
def __init__(self, capacity: int):
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
else:
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
return self.cache[key]
def put(self, key: Any, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False) # pop the first item

@ -317,8 +317,9 @@ class IndexingRunner:
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
if image_file:
storage.delete(image_file.key)
except Exception:
logging.exception(
@ -534,7 +535,7 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@ -572,7 +573,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
create_keyword_thread.join()
indexing_end_at = time.perf_counter()

@ -0,0 +1,380 @@
import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import Any, Literal, Optional, cast, overload
import json_repair
from pydantic import TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
class ResponseFormat(StrEnum):
"""Constants for model response formats"""
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
JSON = "JSON" # model's json mode. some model like claude support this mode.
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
class SpecialModelType(StrEnum):
"""Constants for identifying model types"""
GEMINI = "gemini"
OLLAMA = "ollama"
@overload
def invoke_llm_with_structured_output(
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[True] = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: Literal[False] = False,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
json_schema: Mapping[str, Any],
model_parameters: Optional[Mapping] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
Invoke large language model with structured output
1. This method invokes model_instance.invoke_llm with json_schema
2. Try to parse the result as structured output
:param prompt_messages: prompt messages
:param json_schema: json schema
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:return: full response or stream response chunk generator result
"""
# handle native json schema
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
}
if model_schema.support_structure_output:
model_parameters = _handle_native_json_schema(
provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
)
else:
# Set appropriate response format based on model capabilities
_set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
# handle prompt based schema
prompt_messages = _handle_prompt_based_schema(
prompt_messages=prompt_messages,
structured_output_schema=json_schema,
)
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=model_parameters_with_json_schema,
tools=tools,
stop=stop,
stream=stream,
user=user,
callbacks=callbacks,
)
if isinstance(llm_result, LLMResult):
if not isinstance(llm_result.message.content, str):
raise OutputParserError(
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
)
return LLMResultWithStructuredOutput(
structured_output=_parse_structured_output(llm_result.message.content),
model=llm_result.model,
message=llm_result.message,
usage=llm_result.usage,
system_fingerprint=llm_result.system_fingerprint,
prompt_messages=llm_result.prompt_messages,
)
else:
def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
result_text: str = ""
prompt_messages: Sequence[PromptMessage] = []
system_fingerprint: Optional[str] = None
for event in llm_result:
if isinstance(event, LLMResultChunk):
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content
elif isinstance(event.delta.message.content, list):
for item in event.delta.message.content:
if isinstance(item, TextPromptMessageContent):
result_text += item.data
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=event.delta,
)
yield LLMResultChunkWithStructuredOutput(
structured_output=_parse_structured_output(result_text),
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=""),
usage=None,
finish_reason=None,
),
)
return generator()
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict,
rules: list[ParameterRule],
) -> dict:
"""
Handle structured output for models with native JSON schema support.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
:return: Updated model parameters with JSON schema configuration
"""
# Process schema according to model requirements
schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
# Set JSON schema in parameters
model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
# Set appropriate response format if required by the model
for rule in rules:
if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
return model_parameters
def _set_response_format(model_parameters: dict, rules: list) -> None:
"""
Set the appropriate response format parameter based on model rules.
:param model_parameters: Model parameters to update
:param rules: Model parameter rules
"""
for rule in rules:
if rule.name == "response_format":
if ResponseFormat.JSON.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON.value
elif ResponseFormat.JSON_OBJECT.value in rule.options:
model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
def _handle_prompt_based_schema(
prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
) -> list[PromptMessage]:
"""
Handle structured output for models without native JSON schema support.
This function modifies the prompt messages to include schema-based output requirements.
Args:
prompt_messages: Original sequence of prompt messages
Returns:
list[PromptMessage]: Updated prompt messages with structured output requirements
"""
# Convert schema to string format
schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
# Find existing system prompt with schema placeholder
system_prompt = next(
(prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
None,
)
structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
# Prepare system prompt content
system_prompt_content = (
structured_output_prompt + "\n\n" + system_prompt.content
if system_prompt and isinstance(system_prompt.content, str)
else structured_output_prompt
)
system_prompt = SystemPromptMessage(content=system_prompt_content)
# Extract content from the last user message
filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
updated_prompt = [system_prompt] + filtered_prompts
return updated_prompt
def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
structured_output: Mapping[str, Any] = {}
parsed: Mapping[str, Any] = {}
try:
parsed = TypeAdapter(Mapping).validate_json(result_text)
if not isinstance(parsed, dict):
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except ValidationError:
# if the result_text is not a valid json, try to repair it
temp_parsed = json_repair.loads(result_text)
if not isinstance(temp_parsed, dict):
# handle reasoning model like deepseek-r1 got '<think>\n\n</think>\n' prefix
if isinstance(temp_parsed, list):
temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
else:
raise OutputParserError(f"Failed to parse structured output: {result_text}")
structured_output = cast(dict, temp_parsed)
return structured_output
def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
"""
Prepare JSON schema based on model requirements.
Different models have different requirements for JSON schema formatting.
This function handles these differences.
:param schema: The original JSON schema
:return: Processed schema compatible with the current model
"""
# Deep copy to avoid modifying the original schema
processed_schema = dict(deepcopy(schema))
# Convert boolean types to string types (common requirement)
convert_boolean_to_string(processed_schema)
# Apply model-specific transformations
if SpecialModelType.GEMINI in model_schema.model:
remove_additional_properties(processed_schema)
return processed_schema
elif SpecialModelType.OLLAMA in provider:
return processed_schema
else:
# Default format with name field
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict) -> None:
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Remove additionalProperties at current level
schema.pop("additionalProperties", None)
# Process nested structures recursively
for value in schema.values():
if isinstance(value, dict):
remove_additional_properties(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict) -> None:
"""
Convert boolean type specifications to string in JSON schema.
:param schema: JSON schema to modify in-place
"""
if not isinstance(schema, dict):
return
# Check for boolean type at current level
if schema.get("type") == "boolean":
schema["type"] = "string"
# Process nested dictionaries and lists recursively
for value in schema.values():
if isinstance(value, dict):
convert_boolean_to_string(value)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
convert_boolean_to_string(item)

@ -291,3 +291,21 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
Now, generate a JSON Schema based on my description
""" # noqa: E501
STRUCTURED_OUTPUT_PROMPT = """Youre a helpful AI assistant. You could answer questions and output in JSON format.
constraints:
- You must output in JSON format.
- Do not output boolean value, use string type instead.
- Do not output integer or float value, use number type instead.
eg:
Here is the JSON schema:
{"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
Here is the user's question:
My name is John Doe and I am 30 years old.
output:
{"name": "John Doe", "age": 30}
Here is the JSON schema:
{{schema}}
""" # noqa: E501

@ -1,7 +1,7 @@
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
from typing import Optional
from typing import Any, Optional
from pydantic import BaseModel, Field
@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
latency=0.0,
)
@classmethod
def from_metadata(cls, metadata: dict) -> "LLMUsage":
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: Dictionary containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
total_tokens = metadata.get("total_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
)
def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.
@ -101,6 +132,20 @@ class LLMResult(BaseModel):
system_fingerprint: Optional[str] = None
class LLMStructuredOutput(BaseModel):
"""
Model class for llm structured output.
"""
structured_output: Optional[Mapping[str, Any]] = None
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
"""
Model class for llm result with structured output.
"""
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
@ -123,6 +168,12 @@ class LLMResultChunk(BaseModel):
delta: LLMResultChunkDelta
class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
"""
Model class for llm result chunk with structured output.
"""
class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.

@ -0,0 +1,487 @@
import json
import logging
from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
convert_datetime_to_nanoseconds,
convert_to_span_id,
convert_to_trace_id,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_SYSTEM,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes import NodeType
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
logger = logging.getLogger(__name__)
class AliyunDataTrace(BaseTraceInstance):
def __init__(
self,
aliyun_config: AliyunConfig,
):
super().__init__(aliyun_config)
base_url = aliyun_config.endpoint.rstrip("/")
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
pass
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self):
return self.trace_client.api_check()
def get_project_url(self):
try:
return self.trace_client.get_project_url()
except Exception as e:
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"Aliyun get run url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
self.trace_client.add_span(node_span)
def message_trace(self, trace_info: MessageTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
message_id = trace_info.message_id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(message_span)
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
pre_prompt = getattr(app_model_config, "pre_prompt", "")
inputs_data = getattr(trace_info.message_data, "inputs", {})
llm_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=convert_to_span_id(message_id, "llm"),
name="llm",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: str(trace_info.outputs),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(llm_span)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name="dataset_retrieval",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
)
self.trace_client.add_span(dataset_retrieval_span)
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
tool_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name=trace_info.tool_name,
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
self.trace_client.add_span(tool_span)
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
return workflow_node_executions
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
):
try:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
except Exception:
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
span_status: Status = Status(StatusCode.UNSET)
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
span_status = Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
span_status = Status(StatusCode.ERROR, str(node_execution.error))
return span_status
def build_workflow_task_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_tool_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_retrieval_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_llm_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=self.get_workflow_node_status(node_execution),
)
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
user_id = trace_info.metadata.get("user_id")
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id: # chatflow
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(message_span)
workflow_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(workflow_span)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
suggested_question_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=convert_to_span_id(message_id, "suggested_question"),
name="suggested_question",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(suggested_question_span)
def extract_retrieval_documents(documents: list[Document]):
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

@ -0,0 +1,200 @@
import hashlib
import logging
import random
import socket
import threading
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Optional
import requests
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
logger = logging.getLogger(__name__)
class TraceClient:
def __init__(
self,
service_name: str,
endpoint: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
):
self.endpoint = endpoint
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
self.span_builder = SpanBuilder(self.resource)
self.exporter = OTLPSpanExporter(endpoint=endpoint)
self.max_queue_size = max_queue_size
self.schedule_delay_sec = schedule_delay_sec
self.max_export_batch_size = max_export_batch_size
self.queue: deque = deque(maxlen=max_queue_size)
self.condition = threading.Condition(threading.Lock())
self.done = False
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._spans_dropped = False
def export(self, spans: Sequence[ReadableSpan]):
self.exporter.export(spans)
def api_check(self):
try:
response = requests.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.debug(f"AliyunTrace API check failed: {str(e)}")
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self):
return "https://arms.console.aliyun.com/#/llm"
def add_span(self, span_data: SpanData):
if span_data is None:
return
span: ReadableSpan = self.span_builder.build_span(span_data)
with self.condition:
if len(self.queue) == self.max_queue_size:
if not self._spans_dropped:
logger.warning("Queue is full, likely spans will be dropped.")
self._spans_dropped = True
self.queue.appendleft(span)
if len(self.queue) >= self.max_export_batch_size:
self.condition.notify()
def _worker(self):
while not self.done:
with self.condition:
if len(self.queue) < self.max_export_batch_size and not self.done:
self.condition.wait(timeout=self.schedule_delay_sec)
self._export_batch()
def _export_batch(self):
spans_to_export: list[ReadableSpan] = []
with self.condition:
while len(spans_to_export) < self.max_export_batch_size and self.queue:
spans_to_export.append(self.queue.pop())
if spans_to_export:
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug(f"Error exporting spans: {e}")
def shutdown(self):
with self.condition:
self.done = True
self.condition.notify_all()
self.worker_thread.join()
self._export_batch()
self.exporter.shutdown()
class SpanBuilder:
def __init__(self, resource):
self.resource = resource
self.instrumentation_scope = InstrumentationScope(
__name__,
"",
None,
None,
)
def build_span(self, span_data: SpanData) -> ReadableSpan:
span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
parent_span_context = None
if span_data.parent_span_id is not None:
parent_span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.parent_span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
span = ReadableSpan(
name=span_data.name,
context=span_context,
parent=parent_span_context,
resource=self.resource,
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
instrumentation_scope=self.instrumentation_scope,
)
return span
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
return span_id
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
if start_time_a is None:
return None
timestamp_in_seconds = start_time_a.timestamp()
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
return timestamp_in_nanoseconds

@ -0,0 +1,21 @@
from collections.abc import Sequence
from typing import Optional
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")

@ -0,0 +1,64 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
# Chain
INPUT_VALUE = "input.value"
OUTPUT_VALUE = "output.value"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# LLM
GEN_AI_MODEL_NAME = "gen_ai.model_name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
class GenAISpanKind(Enum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"
LLM = "LLM"
EMBEDDING = "EMBEDDING"
TOOL = "TOOL"
AGENT = "AGENT"
TASK = "TASK"

@ -0,0 +1,726 @@
import hashlib
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
try:
# Choose the appropriate exporter based on config type
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
if isinstance(arize_phoenix_config, ArizeConfig):
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
arize_headers = {
"api_key": arize_phoenix_config.api_key or "",
"space_id": arize_phoenix_config.space_id or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = GrpcOTLPSpanExporter(
endpoint=arize_endpoint,
headers=arize_headers,
timeout=30,
)
else:
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
phoenix_headers = {
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=30,
)
attributes = {
"openinference.project.name": arize_phoenix_config.project or "",
"model_id": arize_phoenix_config.project or "",
}
resource = Resource(attributes=attributes)
provider = trace_sdk.TracerProvider(resource=resource)
processor = SimpleSpanProcessor(
exporter,
)
provider.add_span_processor(processor)
# Create a named tracer instead of setting the global provider
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
except Exception as e:
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
raise
def datetime_to_nanos(dt: Optional[datetime]) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def uuid_to_trace_id(string: Optional[str]) -> int:
"""Convert UUID string to a valid trace ID (16-byte integer)."""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
# Take the first 16 bytes (128 bits) of the hash
digest = hash_object.digest()[:16]
# Convert to integer (128 bits)
return int.from_bytes(digest, byteorder="big")
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
if trace_info.message_data is None:
return
workflow_metadata = {
"workflow_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_metadata["ls_provider"] = provider
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
)
try:
if node_execution.node_type == "llm":
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
if model:
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
return
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
trace_id = uuid_to_trace_id(trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
if trace_info.message_data.model_id is not None:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
if trace_info.message_data.model_provider is not None:
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
trace_id = uuid_to_trace_id(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
span.set_attribute("test", "true")
return True
except Exception as e:
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
except Exception as e:
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes

@ -2,20 +2,92 @@ from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum):
ARIZE = "arize"
PHOENIX = "phoenix"
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
class BaseTracingConfig(BaseModel):
"""
Base model class for tracing
Base model class for tracing configurations
"""
@classmethod
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
"""
Common endpoint URL validation logic
Args:
v: URL value to validate
default_url: Default URL to use if input is None or empty
Returns:
Validated and normalized URL
"""
return validate_url(v, default_url)
@classmethod
def validate_project_field(cls, v: str, default_name: str) -> str:
"""
Common project name validation logic
Args:
v: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Validated project name
"""
return validate_project_name(v, default_name)
...
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
@field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.langfuse.com"
if not v.startswith("https://") and not v.startswith("http://"):
raise ValueError("host must start with https:// or http://")
return v
def host_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.smith.langchain.com"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "Default Project"
return v
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://www.comet.com/opik/api/"
if not v.startswith(("https://", "http://")):
raise ValueError("url must start with https:// or http://")
if not v.endswith("/api/"):
raise ValueError("url should ends with /api/")
return v
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://trace.wandb.ai"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
@field_validator("host")
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def validate_host(cls, v, info: ValidationInfo):
if v is not None and v != "":
if not v.startswith(("https://", "http://")):
raise ValueError("host must start with https:// or http://")
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

@ -32,6 +32,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@ -83,6 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
@ -108,6 +110,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
@ -172,48 +175,15 @@ class LangFuseDataTrace(BaseTraceInstance):
}
)
# add span
if trace_info.message_id:
span_data = LangfuseSpan(
id=node_execution_id,
name=node_type,
input=inputs,
output=outputs,
trace_id=trace_id,
start_time=created_at,
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id,
)
else:
span_data = LangfuseSpan(
id=node_execution_id,
name=node_type,
input=inputs,
output=outputs,
trace_id=trace_id,
start_time=created_at,
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
)
self.add_span(langfuse_span_data=span_data)
# add generation span
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@ -226,10 +196,10 @@ class LangFuseDataTrace(BaseTraceInstance):
)
node_generation_data = LangfuseGeneration(
name="llm",
id=node_execution_id,
name=node_name,
trace_id=trace_id,
model=process_data.get("model_name"),
parent_observation_id=node_execution_id,
start_time=created_at,
end_time=finished_at,
input=inputs,
@ -237,11 +207,30 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
usage=generation_usage,
)
self.add_generation(langfuse_generation_data=node_generation_data)
# add normal span
else:
span_data = LangfuseSpan(
id=node_execution_id,
name=node_name,
input=inputs,
output=outputs,
trace_id=trace_id,
start_time=created_at,
end_time=finished_at,
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
)
self.add_span(langfuse_span_data=span_data)
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
# get message file data
file_list = trace_info.file_list
@ -284,7 +273,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_trace(langfuse_trace_data=trace_data)
# start add span
# add generation
generation_usage = GenerationUsage(
input=trace_info.message_tokens,
output=trace_info.answer_tokens,
@ -302,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
input=trace_info.inputs,
output=message_data.answer,
metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
@ -348,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)

@ -206,12 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

@ -222,10 +222,10 @@ class OpikDataTrace(BaseTraceInstance):
)
try:
if outputs.get("usage"):
total_tokens = outputs["usage"].get("total_tokens", 0)
prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
completion_tokens = outputs["usage"].get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
total_tokens = usage_data.get("total_tokens", 0)
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)

@ -84,6 +84,36 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
case TracingProviderEnum.ARIZE:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import ArizeConfig
return {
"config_class": ArizeConfig,
"secret_keys": ["api_key", "space_id"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.PHOENIX:
from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
from core.ops.entities.config_entity import PhoenixConfig
return {
"config_class": PhoenixConfig,
"secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"],
"trace_instance": ArizePhoenixDataTrace,
}
case TracingProviderEnum.ALIYUN:
from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
from core.ops.entities.config_entity import AliyunConfig
return {
"config_class": AliyunConfig,
"secret_keys": ["license_key"],
"other_keys": ["endpoint", "app_name"],
"trace_instance": AliyunDataTrace,
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")

@ -1,6 +1,7 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
from extensions.ext_database import db
from models.model import Message
@ -60,3 +61,83 @@ def generate_dotted_order(
return current_segment
return f"{parent_dotted_order}.{current_segment}"
def validate_url(url: str, default_url: str, allowed_schemes: tuple = ("https", "http")) -> str:
"""
Validate and normalize URL with proper error handling
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
allowed_schemes: Tuple of allowed URL schemes (default: https, http)
Returns:
Normalized URL string
Raises:
ValueError: If URL format is invalid or scheme not allowed
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in allowed_schemes:
raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}")
# Reconstruct URL with only scheme, netloc (removing path, query, fragment)
normalized_url = f"{parsed.scheme}://{parsed.netloc}"
return normalized_url
def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str:
"""
Validate URL that may include path components
Args:
url: The URL to validate
default_url: Default URL to use if input is None or empty
required_suffix: Optional suffix that URL must end with
Returns:
Validated URL string
Raises:
ValueError: If URL format is invalid or doesn't match required suffix
"""
if not url or url.strip() == "":
return default_url
# Parse URL to validate format
parsed = urlparse(url)
# Check if scheme is allowed
if parsed.scheme not in ("https", "http"):
raise ValueError("URL must start with https:// or http://")
# Check required suffix if specified
if required_suffix and not url.endswith(required_suffix):
raise ValueError(f"URL should end with {required_suffix}")
return url
def validate_project_name(project: str, default_name: str) -> str:
"""
Validate and normalize project name
Args:
project: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Normalized project name
"""
if not project or project.strip() == "":
return default_name
return project.strip()

@ -2,8 +2,15 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
@ -12,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeLLM,
RequestInvokeLLMWithStructuredOutput,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
@ -81,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
return handle_non_streaming(response)
@classmethod
def invoke_llm_with_structured_output(
cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
):
"""
invoke llm with structured output
"""
model_instance = ModelManager().get_model_instance(
tenant_id=tenant.id,
provider=payload.provider,
model_type=payload.model_type,
model=payload.model,
)
model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
if not model_schema:
raise ValueError(f"Model schema not found for {payload.model}")
response = invoke_llm_with_structured_output(
provider=payload.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=payload.prompt_messages,
json_schema=payload.structured_output_schema,
tools=payload.tools,
stop=payload.stop,
stream=True if payload.stream is None else payload.stream,
user=user_id,
model_parameters=payload.completion_params,
)
if isinstance(response, Generator):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
yield LLMResultChunkWithStructuredOutput(
model=response.model,
prompt_messages=[],
system_fingerprint=response.system_fingerprint,
structured_output=response.structured_output,
delta=LLMResultChunkDelta(
index=0,
message=response.message,
usage=response.usage,
finish_reason="",
),
)
return handle_non_streaming(response)
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
"""

@ -10,6 +10,9 @@ from core.tools.entities.common_entities import I18nObject
class PluginParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
icon: Optional[str] = Field(
default=None, description="The icon of the option, can be a url or a base64 encoded image"
)
@field_validator("value", mode="before")
@classmethod
@ -35,6 +38,7 @@ class PluginParameterType(enum.StrEnum):
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value

@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import Any, Generic, Optional, TypeVar
@ -9,6 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
@ -186,3 +187,7 @@ class PluginOAuthCredentialsResponse(BaseModel):
class PluginListResponse(BaseModel):
list: list[PluginEntity]
total: int
class PluginDynamicSelectOptionsResponse(BaseModel):
options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")

@ -82,6 +82,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
return v
class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
"""
Request to invoke LLM with structured output
"""
structured_output_schema: dict[str, Any] = Field(
default_factory=dict, description="The schema of the structured output in JSON schema format"
)
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
"""
Request to invoke text embedding

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save