Compare commits

..

No commits in common. 'main' and 'feat/human-input' have entirely different histories.

@ -1,6 +1,6 @@
#!/bin/bash #!/bin/bash
npm add -g pnpm@10.13.1 npm add -g pnpm@10.11.1
cd web && pnpm install cd web && pnpm install
pipx install uv pipx install uv
@ -12,4 +12,3 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc source /home/vscode/.bashrc

@ -1,27 +0,0 @@
name: autofix.ci
on:
workflow_call:
pull_request:
push:
branches: [ "main" ]
permissions:
contents: read
jobs:
autofix:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
- run: |
cd api
uv sync --dev
# Fix lint errors
uv run ruff check --fix-only .
# Format code
uv run ruff format .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

@ -28,7 +28,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v45
with: with:
files: | files: |
api/** api/**
@ -75,7 +75,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v45
with: with:
files: web/** files: web/**
@ -113,7 +113,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v45
with: with:
files: | files: |
docker/generate_docker_compose docker/generate_docker_compose
@ -144,7 +144,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v45
with: with:
files: | files: |
**.sh **.sh
@ -152,15 +152,13 @@ jobs:
**.yml **.yml
**Dockerfile **Dockerfile
dev/** dev/**
.editorconfig
- name: Super-linter - name: Super-linter
uses: super-linter/super-linter/slim@v8 uses: super-linter/super-linter/slim@v7
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
env: env:
BASH_SEVERITY: warning BASH_SEVERITY: warning
DEFAULT_BRANCH: origin/main DEFAULT_BRANCH: main
EDITORCONFIG_FILE_NAME: editorconfig-checker.json
FILTER_REGEX_INCLUDE: pnpm-lock.yaml FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true IGNORE_GENERATED_FILES: true
@ -170,6 +168,16 @@ jobs:
# FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck # FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
# VALIDATE_GITHUB_ACTIONS: true # VALIDATE_GITHUB_ACTIONS: true
VALIDATE_DOCKERFILE_HADOLINT: true VALIDATE_DOCKERFILE_HADOLINT: true
VALIDATE_EDITORCONFIG: true
VALIDATE_XML: true VALIDATE_XML: true
VALIDATE_YAML: true VALIDATE_YAML: true
- name: EditorConfig checks
uses: super-linter/super-linter/slim@v7
env:
DEFAULT_BRANCH: main
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true
# EditorConfig validation
VALIDATE_EDITORCONFIG: true
EDITORCONFIG_FILE_NAME: editorconfig-checker.json

@ -27,7 +27,7 @@ jobs:
- name: Check changed files - name: Check changed files
id: changed-files id: changed-files
uses: tj-actions/changed-files@v46 uses: tj-actions/changed-files@v45
with: with:
files: web/** files: web/**

@ -1,7 +1,7 @@
![cover-v5-optimized](./images/GitHub_README_if.png) ![cover-v5-optimized](./images/GitHub_README_if.png)
<p align="center"> <p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast111</a> 📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a>
</p> </p>
<p align="center"> <p align="center">

@ -54,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration # celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
CELERY_BACKEND=redis
# PostgreSQL database configuration # PostgreSQL database configuration
DB_USERNAME=postgres DB_USERNAME=postgres
DB_PASSWORD=difyai123456 DB_PASSWORD=difyai123456
@ -142,10 +142,8 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration # Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. # support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration # Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080 WEAVIATE_ENDPOINT=http://localhost:8080
@ -471,16 +469,6 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration # Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1 CELERY_BEAT_SCHEDULER_TIME=1
# Celery schedule tasks configuration
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
# Position configuration # Position configuration
POSITION_TOOL_PINS= POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES= POSITION_TOOL_INCLUDES=

@ -47,8 +47,6 @@ RUN \
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \ curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security # For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \ expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
# install fonts to support the use of tools like pypdfium2
fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension # install a package to improve the accuracy of guessing mime type and file extension
media-types \ media-types \
# install libmagic to support the use of python-magic guess MIMETYPE # install libmagic to support the use of python-magic guess MIMETYPE

@ -74,12 +74,7 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service. 10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash ```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
```bash
uv run celery -A app.celery beat
``` ```
## Testing ## Testing

@ -50,7 +50,7 @@ def reset_password(email, new_password, password_confirm):
click.echo(click.style("Passwords do not match.", fg="red")) click.echo(click.style("Passwords do not match.", fg="red"))
return return
account = db.session.query(Account).where(Account.email == email).one_or_none() account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account: if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red")) click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@ -89,7 +89,7 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style("New emails do not match.", fg="red")) click.echo(click.style("New emails do not match.", fg="red"))
return return
account = db.session.query(Account).where(Account.email == email).one_or_none() account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account: if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red")) click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@ -136,8 +136,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id) tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit() db.session.commit()
click.echo( click.echo(
@ -172,7 +172,7 @@ def migrate_annotation_vector_database():
per_page = 50 per_page = 50
apps = ( apps = (
db.session.query(App) db.session.query(App)
.where(App.status == "normal") .filter(App.status == "normal")
.order_by(App.created_at.desc()) .order_by(App.created_at.desc())
.limit(per_page) .limit(per_page)
.offset((page - 1) * per_page) .offset((page - 1) * per_page)
@ -192,7 +192,7 @@ def migrate_annotation_vector_database():
try: try:
click.echo("Creating app annotation index: {}".format(app.id)) click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = ( app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
) )
if not app_annotation_setting: if not app_annotation_setting:
@ -202,13 +202,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first() .first()
) )
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id)) click.echo("App annotation collection binding not found: {}".format(app.id))
continue continue
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all() annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -305,7 +305,7 @@ def migrate_knowledge_vector_database():
while True: while True:
try: try:
stmt = ( stmt = (
select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
) )
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -332,7 +332,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id) .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:
@ -367,7 +367,7 @@ def migrate_knowledge_vector_database():
dataset_documents = ( dataset_documents = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.where( .filter(
DatasetDocument.dataset_id == dataset.id, DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -381,7 +381,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@ -468,7 +468,7 @@ def convert_to_agent_apps():
app_id = str(i.id) app_id = str(i.id)
if app_id not in proceeded_app_ids: if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id) proceeded_app_ids.append(app_id)
app = db.session.query(App).where(App.id == app_id).first() app = db.session.query(App).filter(App.id == app_id).first()
if app is not None: if app is not None:
apps.append(app) apps.append(app)
@ -483,7 +483,7 @@ def convert_to_agent_apps():
db.session.commit() db.session.commit()
# update conversation mode to agent # update conversation mode to agent
db.session.query(Conversation).where(Conversation.app_id == app.id).update( db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value} {Conversation.mode: AppMode.AGENT_CHAT.value}
) )
@ -560,7 +560,7 @@ def old_metadata_migration():
try: try:
stmt = ( stmt = (
select(DatasetDocument) select(DatasetDocument)
.where(DatasetDocument.doc_metadata.is_not(None)) .filter(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc()) .order_by(DatasetDocument.created_at.desc())
) )
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -578,7 +578,7 @@ def old_metadata_migration():
else: else:
dataset_metadata = ( dataset_metadata = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
.where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) .filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first() .first()
) )
if not dataset_metadata: if not dataset_metadata:
@ -602,7 +602,7 @@ def old_metadata_migration():
else: else:
dataset_metadata_binding = ( dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore db.session.query(DatasetMetadataBinding) # type: ignore
.where( .filter(
DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id,
@ -717,7 +717,7 @@ where sites.id is null limit 1000"""
continue continue
try: try:
app = db.session.query(App).where(App.id == app_id).first() app = db.session.query(App).filter(App.id == app_id).first()
if not app: if not app:
print(f"App {app_id} not found") print(f"App {app_id} not found")
continue continue

@ -832,41 +832,6 @@ class CeleryBeatConfig(BaseSettings):
) )
class CeleryScheduleTasksConfig(BaseSettings):
ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field(
description="Enable clean embedding cache task",
default=False,
)
ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field(
description="Enable clean unused datasets task",
default=False,
)
ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field(
description="Enable create tidb service job task",
default=False,
)
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field(
description="Enable update tidb service job status task",
default=False,
)
ENABLE_CLEAN_MESSAGES: bool = Field(
description="Enable clean messages task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
)
ENABLE_DATASETS_QUEUE_MONITOR: bool = Field(
description="Enable queue monitor task",
default=False,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,
)
class PositionConfig(BaseSettings): class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field( POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers", description="Comma-separated list of pinned model providers",
@ -996,6 +961,5 @@ class FeatureConfig(
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,
CeleryScheduleTasksConfig,
): ):
pass pass

@ -85,11 +85,6 @@ class VectorStoreConfig(BaseSettings):
default=False, default=False,
) )
VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
description="Prefix used to create collection name in vector database",
default="Vector_index",
)
class KeywordStoreConfig(BaseSettings): class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field( KEYWORD_STORE: str = Field(
@ -216,7 +211,7 @@ class DatabaseConfig(BaseSettings):
class CeleryConfig(DatabaseConfig): class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field( CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.", description="Backend for Celery task results. Options: 'database', 'redis'.",
default="redis", default="database",
) )
CELERY_BROKER_URL: Optional[str] = Field( CELERY_BROKER_URL: Optional[str] = Field(

@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{args['app_id']}' is not found")
@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id): def delete(self, app_id):
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).where(RecommendedApp.app_id == str(app_id)) select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {"result": "success"}, 204
with Session(db.engine) as session: with Session(db.engine) as session:
app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none() app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app: if app:
app.is_public = False app.is_public = False
with Session(db.engine) as session: with Session(db.engine) as session:
installed_apps = session.execute( installed_apps = session.execute(
select(InstalledApp).where( select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
) )

@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count() .count()
) )
@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .filter(
getattr(ApiToken, self.resource_id_field) == resource_id, getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource):
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {"result": "success"}, 204

@ -1,4 +1,4 @@
from datetime import datetime from datetime import UTC, datetime
import pytz # pip install pytz import pytz # pip install pytz
from flask_login import current_user from flask_login import current_user
@ -19,7 +19,6 @@ from fields.conversation_fields import (
conversation_pagination_fields, conversation_pagination_fields,
conversation_with_summary_pagination_fields, conversation_with_summary_pagination_fields,
) )
from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation from models import Conversation, EndUser, Message, MessageAnnotation
@ -49,7 +48,7 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args["keyword"]: if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).where( query = query.join(Message, Message.conversation_id == Conversation.id).filter(
or_( or_(
Message.query.ilike("%{}%".format(args["keyword"])), Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])), Message.answer.ilike("%{}%".format(args["keyword"])),
@ -121,7 +120,7 @@ class CompletionConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -181,7 +180,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id, Message.conversation_id == Conversation.id,
) )
.join(subquery, subquery.c.conversation_id == Conversation.id) .join(subquery, subquery.c.conversation_id == Conversation.id)
.where( .filter(
or_( or_(
Message.query.ilike(keyword_filter), Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter), Message.answer.ilike(keyword_filter),
@ -286,7 +285,7 @@ class ChatConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -308,7 +307,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -316,7 +315,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if not conversation.read_at: if not conversation.read_at:
conversation.read_at = naive_utc_now() conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_account_id = current_user.id conversation.read_account_id = current_user.id
db.session.commit() db.session.commit()

@ -26,7 +26,7 @@ class AppMCPServerController(Resource):
@get_app_model @get_app_model
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
return server return server
@setup_required @setup_required
@ -73,7 +73,7 @@ class AppMCPServerController(Resource):
parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json") parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise NotFound() raise NotFound()
@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound() raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.where(AppMCPServer.id == server_id) .filter(AppMCPServer.id == server_id)
.where(AppMCPServer.tenant_id == current_user.current_tenant_id) .filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not server: if not server:

@ -5,7 +5,6 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
@ -28,7 +27,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required from libs.login import login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.annotation_service import AppAnnotationService from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@ -56,7 +55,7 @@ class ChatMessageListApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first() .first()
) )
@ -66,7 +65,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]: if args["first_id"]:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first() .first()
) )
@ -75,7 +74,7 @@ class ChatMessageListApi(Resource):
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.where( .filter(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at, Message.created_at < first_message.created_at,
Message.id != first_message.id, Message.id != first_message.id,
@ -87,7 +86,7 @@ class ChatMessageListApi(Resource):
else: else:
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.where(Message.conversation_id == conversation.id) .filter(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args["limit"]) .limit(args["limit"])
.all() .all()
@ -98,7 +97,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
rest_count = ( rest_count = (
db.session.query(Message) db.session.query(Message)
.where( .filter(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at, Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id, Message.id != current_page_first_message.id,
@ -125,17 +124,34 @@ class MessageFeedbackApi(Resource):
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args() args = parser.parse_args()
try: message_id = str(args["message_id"])
MessageService.create_feedback(
app_model=app_model, message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
message_id=str(args["message_id"]),
user=current_user, if not message:
rating=args.get("rating"),
content=None,
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
feedback = message.admin_feedback
if not args["rating"] and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
from_source="admin",
from_account_id=current_user.id,
)
db.session.add(feedback)
db.session.commit()
return {"result": "success"} return {"result": "success"}
@ -167,7 +183,7 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
return {"count": count} return {"count": count}
@ -214,7 +230,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id): def get(self, app_model, message_id):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -42,7 +42,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config # get original app model config
original_app_model_config = ( original_app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first() db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
) )
if original_app_model_config is None: if original_app_model_config is None:
raise ValueError("Original app model config not found") raise ValueError("Original app model config not found")

@ -1,3 +1,5 @@
from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -8,7 +10,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db from extensions.ext_database import db
from fields.app_fields import app_site_fields from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Site from models import Site
@ -49,7 +50,7 @@ class AppSite(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound
@ -76,7 +77,7 @@ class AppSite(Resource):
setattr(site, attr_name, value) setattr(site, attr_name, value)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
@ -93,14 +94,14 @@ class AppSiteAccessTokenReset(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = naive_utc_now() site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site

@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model

@ -1,3 +1,5 @@
import datetime
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
@ -5,7 +7,6 @@ from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
@ -64,7 +65,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = naive_utc_now() account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))

@ -113,9 +113,3 @@ class MemberNotInTenantError(BaseHTTPException):
error_code = "member_not_in_tenant" error_code = "member_not_in_tenant"
description = "The member is not in the workspace." description = "The member is not in the workspace."
code = 400 code = 400
class AccountInFreezeError(BaseHTTPException):
error_code = "account_in_freeze"
description = "This email is temporarily unavailable."
code = 400

@ -1,4 +1,5 @@
import logging import logging
from datetime import UTC, datetime
from typing import Optional from typing import Optional
import requests import requests
@ -12,7 +13,6 @@ from configs import dify_config
from constants.languages import languages from constants.languages import languages
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account from models import Account
@ -110,7 +110,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = naive_utc_now() account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -1,3 +1,4 @@
import datetime
import json import json
from flask import request from flask import request
@ -14,7 +15,6 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import DataSourceOauthBinding, Document from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = ( data_source_integrates = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.where( .filter(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable": if action == "enable":
if data_source_binding.disabled: if data_source_binding.disabled:
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now() data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable": if action == "disable":
if not data_source_binding.disabled: if not data_source_binding.disabled:
data_source_binding.disabled = True data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now() data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource):
page_id = str(page_id) page_id = str(page_id)
with Session(db.engine) as session: with Session(db.engine) as session:
data_source_binding = session.execute( data_source_binding = session.execute(
select(DataSourceOauthBinding).where( select(DataSourceOauthBinding).filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",

@ -211,6 +211,10 @@ class DatasetApi(Resource):
else: else:
data["embedding_available"] = True data["embedding_available"] = True
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
data.update({"partial_member_list": part_users_list})
return data, 200 return data, 200
@setup_required @setup_required
@ -412,7 +416,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = ( file_details = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all() .all()
) )
@ -517,14 +521,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -533,7 +537,7 @@ class DatasetIndexingStatusApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -568,7 +572,7 @@ class DatasetApiKeyApi(Resource):
def get(self): def get(self):
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -584,7 +588,7 @@ class DatasetApiKeyApi(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count() .count()
) )
@ -620,7 +624,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.where( .filter(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
@ -631,7 +635,7 @@ class DatasetApiDeleteApi(Resource):
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {"result": "success"}, 204

@ -1,5 +1,6 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from datetime import UTC, datetime
from typing import cast from typing import cast
from flask import request from flask import request
@ -48,7 +49,6 @@ from fields.document_fields import (
document_status_fields, document_status_fields,
document_with_segments_fields, document_with_segments_fields,
) )
from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule # get the latest process rule
dataset_process_rule = ( dataset_process_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == document.dataset_id) .filter(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.limit(1) .limit(1)
.one_or_none() .one_or_none()
@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.filter(Document.name.like(search))
if sort.startswith("-"): if sort.startswith("-"):
sort_logic = desc sort_logic = desc
@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
document.completed_segments = completed_segments document.completed_segments = completed_segments
@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource):
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count() .count()
) )
@ -750,7 +750,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.") raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = naive_utc_now() document.paused_at = datetime.now(UTC).replace(tzinfo=None)
document.is_paused = True document.is_paused = True
db.session.commit() db.session.commit()
@ -830,7 +830,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value document.doc_metadata[key] = value
document.doc_type = doc_type document.doc_type = doc_type
document.updated_at = naive_utc_now() document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200 return {"result": "success", "message": "Document metadata updated."}, 200

@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = ( query = (
select(DocumentSegment) select(DocumentSegment)
.where( .filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_user.current_tenant_id,
) )
@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
) )
if status_list: if status_list:
query = query.where(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None: if hit_count_gte is not None:
query = query.where(DocumentSegment.hit_count >= hit_count_gte) query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all": if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true": if args["enabled"].lower() == "true":
query = query.where(DocumentSegment.enabled == True) query = query.filter(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.where(DocumentSegment.enabled == False) query = query.filter(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:
@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:

@ -4,7 +4,7 @@ from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required from libs.login import login_required
from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService from services.website_service import WebsiteService
class WebsiteCrawlApi(Resource): class WebsiteCrawlApi(Resource):
@ -24,16 +24,10 @@ class WebsiteCrawlApi(Resource):
parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args() args = parser.parse_args()
WebsiteService.document_create_args_validate(args)
# Create typed request and validate # crawl url
try:
api_request = WebsiteCrawlApiRequest.from_args(args)
except ValueError as e:
raise WebsiteCrawlError(str(e))
# Crawl URL using typed request
try: try:
result = WebsiteService.crawl_url(api_request) result = WebsiteService.crawl_url(args)
except Exception as e: except Exception as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
return result, 200 return result, 200
@ -49,16 +43,9 @@ class WebsiteCrawlStatusApi(Resource):
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
) )
args = parser.parse_args() args = parser.parse_args()
# get crawl status
# Create typed request and validate
try:
api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
except ValueError as e:
raise WebsiteCrawlError(str(e))
# Get crawl status using typed request
try: try:
result = WebsiteService.get_crawl_status_typed(api_request) result = WebsiteService.get_crawl_status(job_id, args["provider"])
except Exception as e: except Exception as e:
raise WebsiteCrawlError(str(e)) raise WebsiteCrawlError(str(e))
return result, 200 return result, 200

@ -1,4 +1,5 @@
import logging import logging
from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import reqparse from flask_restful import reqparse
@ -26,7 +27,6 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value from libs.helper import uuid_value
from models.model import AppMode from models.model import AppMode
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now() installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = naive_utc_now() installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -1,4 +1,5 @@
import logging import logging
from datetime import UTC, datetime
from typing import Any from typing import Any
from flask import request from flask import request
@ -12,7 +13,6 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields from fields.installed_app_fields import installed_app_list_fields
from libs.datetime_utils import naive_utc_now
from libs.login import login_required from libs.login import login_required
from models import App, InstalledApp, RecommendedApp from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService from services.account_service import TenantService
@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource):
if app_id: if app_id:
installed_apps = ( installed_apps = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all() .all()
) )
else: else:
installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all() installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [ installed_app_list: list[dict[str, Any]] = [
@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).where(App.id == args["app_id"]).first() app = db.session.query(App).filter(App.id == args["app_id"]).first()
if app is None: if app is None:
raise NotFound("App not found") raise NotFound("App not found")
@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first() .first()
) )
@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
last_used_at=naive_utc_now(), last_used_at=datetime.now(UTC).replace(tzinfo=None),
) )
db.session.add(new_installed_app) db.session.add(new_installed_app)
db.session.commit() db.session.commit()

@ -28,7 +28,7 @@ def installed_app_required(view=None):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.where( .filter(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
) )
.first() .first()

@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session: with Session(db.engine) as session:
permission = ( permission = (
session.query(TenantPluginPermission) session.query(TenantPluginPermission)
.where( .filter(
TenantPluginPermission.tenant_id == tenant_id, TenantPluginPermission.tenant_id == tenant_id,
) )
.first() .first()

@ -1,3 +1,5 @@
import datetime
import pytz import pytz
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@ -9,7 +11,6 @@ from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AccountInFreezeError,
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
EmailCodeError, EmailCodeError,
@ -34,7 +35,6 @@ from controllers.console.wraps import (
) )
from extensions.ext_database import db from extensions.ext_database import db
from fields.member_fields import account_fields from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required from libs.login import login_required
from models import AccountIntegrate, InvitationCode from models import AccountIntegrate, InvitationCode
@ -69,7 +69,7 @@ class AccountInitApi(Resource):
# check invitation code # check invitation code
invitation_code = ( invitation_code = (
db.session.query(InvitationCode) db.session.query(InvitationCode)
.where( .filter(
InvitationCode.code == args["invitation_code"], InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused", InvitationCode.status == "unused",
) )
@ -80,7 +80,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError() raise InvalidInvitationCodeError()
invitation_code.status = "used" invitation_code.status = "used"
invitation_code.used_at = naive_utc_now() invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id invitation_code.used_by_account_id = account.id
@ -88,7 +88,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = "active" account.status = "active"
account.initialized_at = naive_utc_now() account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} return {"result": "success"}
@ -229,7 +229,7 @@ class AccountIntegrateApi(Resource):
def get(self): def get(self):
account = current_user account = current_user
account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all() account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip("/") base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login" oauth_base_path = "/console/api/oauth/login"
@ -480,28 +480,21 @@ class ChangeEmailResetApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args["token"]) reset_data = AccountService.get_change_email_data(args["token"])
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()
updated_account = AccountService.update_account(current_user, email=args["new_email"]) updated_account = AccountService.update_account(current_user, email=args["new_email"])
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
)
return updated_account return updated_account

@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
member = db.session.query(Account).where(Account.id == str(member_id)).first() member = db.session.query(Account).filter(Account.id == str(member_id)).first()
if member is None: if member is None:
abort(404) abort(404)
else: else:

@ -12,8 +12,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required from libs.login import login_required
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from models.account import TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService from services.plugin.plugin_service import PluginService
@ -535,114 +534,6 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options}) return jsonable_encoder({"options": options})
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("permission", type=dict, required=True, location="json")
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
auto_upgrade = args["auto_upgrade"]
strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
auto_upgrade.get("strategy_setting", "fix_only")
)
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])
# set permission
set_permission_result = PluginPermissionService.change_permission(
tenant_id,
install_permission,
debug_permission,
)
if not set_permission_result:
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
# set auto upgrade strategy
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
tenant_id,
strategy_setting,
upgrade_time_of_day,
upgrade_mode,
exclude_plugins,
include_plugins,
)
if not set_auto_upgrade_strategy_result:
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
return jsonable_encoder({"success": True})
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
if permission:
permission_dict["install_permission"] = permission.install_permission
permission_dict["debug_permission"] = permission.debug_permission
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
auto_upgrade_dict = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
if auto_upgrade:
auto_upgrade_dict = {
"strategy_setting": auto_upgrade.strategy_setting,
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
"upgrade_mode": auto_upgrade.upgrade_mode,
"exclude_plugins": auto_upgrade.exclude_plugins,
"include_plugins": auto_upgrade.include_plugins,
}
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key") api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list") api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions") api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@ -669,7 +560,3 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch") api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options") api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

@ -29,7 +29,7 @@ from libs.login import login_required
from services.plugin.oauth_service import OAuthProxyService from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService from services.tools.tools_transform_service import ToolTransformService
@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider") raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials_response = oauth_handler.get_credentials( credentials = oauth_handler.get_credentials(
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id, user_id=user_id,
plugin_id=plugin_id, plugin_id=plugin_id,
@ -747,10 +747,7 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri, redirect_uri=redirect_uri,
system_credentials=oauth_client_params, system_credentials=oauth_client_params,
request=request, request=request,
) ).credentials
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials: if not credentials:
raise Exception("the plugin credentials failed") raise Exception("the plugin credentials failed")
@ -761,7 +758,6 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id, tenant_id=tenant_id,
provider=provider, provider=provider,
credentials=dict(credentials), credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2, api_type=CredentialType.OAUTH2,
) )
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
user_id = "DEFAULT-USER" user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER": if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first() user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
if not user_model: if not user_model:
user_model = EndUser( user_model = EndUser(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
else: else:
user_model = AccountService.load_user(user_id) user_model = AccountService.load_user(user_id)
if not user_model: if not user_model:
user_model = session.query(EndUser).where(EndUser.id == user_id).first() user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
if not user_model: if not user_model:
raise ValueError("user not found") raise ValueError("user not found")
except Exception: except Exception:
@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try: try:
tenant_model = ( tenant_model = (
db.session.query(Tenant) db.session.query(Tenant)
.where( .filter(
Tenant.id == tenant_id, Tenant.id == tenant_id,
) )
.first() .first()

@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view):
if signature_base64 != token: if signature_base64 != token:
return view(*args, **kwargs) return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
return view(*args, **kwargs) return view(*args, **kwargs)

@ -30,7 +30,7 @@ class MCPAppApi(Resource):
request_id = args.get("id") request_id = args.get("id")
server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server: if not server:
return helper.compact_generate_response( return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
@ -41,7 +41,7 @@ class MCPAppApi(Resource):
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
) )
app = db.session.query(App).where(App.id == server.app_id).first() app = db.session.query(App).filter(App.id == server.app_id).first()
if not app: if not app:
return helper.compact_generate_response( return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")

@ -1,6 +1,5 @@
import logging import logging
from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
@ -24,7 +23,6 @@ from core.errors.error import (
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
) )
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from libs import helper from libs import helper
from libs.helper import uuid_value from libs.helper import uuid_value
@ -113,10 +111,6 @@ class ChatApi(Resource):
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
try: try:

@ -16,7 +16,7 @@ class AppSiteApi(Resource):
@marshal_with(fields.site_fields) @marshal_with(fields.site_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app site info.""" """Retrieve app site info."""
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site: if not site:
raise Forbidden() raise Forbidden()

@ -1,7 +1,6 @@
import logging import logging
from dateutil.parser import isoparse from dateutil.parser import isoparse
from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range from flask_restful.inputs import int_range
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
@ -24,7 +23,6 @@ from core.errors.error import (
ProviderTokenNotInitError, ProviderTokenNotInitError,
QuotaExceededError, QuotaExceededError,
) )
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db from extensions.ext_database import db
@ -92,9 +90,7 @@ class WorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json") parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args() args = parser.parse_args()
external_trace_id = get_external_trace_id(request)
if external_trace_id:
args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming" streaming = args.get("response_mode") == "streaming"
try: try:

@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource):
# get dataset info # get dataset info
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# get dataset info # get dataset info
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource):
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
# get dataset info # get dataset info
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.where(Document.name.like(search)) query = query.filter(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position)) query = query.order_by(desc(Document.created_at), desc(Document.position))
@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch) batch = str(batch)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
# get dataset # get dataset
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# get documents # get documents
@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where( .filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields

@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource):
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info: if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
if not upload_file: if not upload_file:
raise NotFound("UploadFile not found.") raise NotFound("UploadFile not found.")
else: else:

@ -1,6 +1,6 @@
import time import time
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import UTC, datetime, timedelta
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -15,7 +15,6 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from libs.login import _get_user from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog from models.dataset import Dataset, RateLimitLog
@ -44,7 +43,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app") api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).where(App.id == api_token.app_id).first() app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
if not app_model: if not app_model:
raise Forbidden("The app no longer exists.") raise Forbidden("The app no longer exists.")
@ -54,7 +53,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api: if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.") raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant is None: if tenant is None:
raise ValueError("Tenant does not exist.") raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
@ -62,15 +61,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id) .filter(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"])) .filter(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL) .filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first() account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin # Login admin
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
@ -213,15 +212,15 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset") api_token = validate_and_get_api_token("dataset")
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == api_token.tenant_id) .filter(Tenant.id == api_token.tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.role.in_(["owner"])) .filter(TenantAccountJoin.role.in_(["owner"]))
.where(Tenant.status == TenantStatus.NORMAL) .filter(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join
account = db.session.query(Account).where(Account.id == ta.account_id).first() account = db.session.query(Account).filter(Account.id == ta.account_id).first()
# Login admin # Login admin
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
@ -257,7 +256,7 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'") raise Unauthorized("Authorization scheme must be 'Bearer'")
current_time = naive_utc_now() current_time = datetime.now(UTC).replace(tzinfo=None)
cutoff_time = current_time - timedelta(minutes=1) cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
update_stmt = ( update_stmt = (
@ -293,7 +292,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
end_user = ( end_user = (
db.session.query(EndUser) db.session.query(EndUser)
.where( .filter(
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
EndUser.session_id == user_id, EndUser.session_id == user_id,
@ -320,7 +319,7 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token] method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -3,7 +3,6 @@ from datetime import UTC, datetime, timedelta
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
@ -43,17 +42,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site: if not site:
raise NotFound() raise NotFound()
# get app from db and check if it is normal and enable_site # get app from db and check if it is normal and enable_site
app_model = db.session.scalar(select(App).where(App.id == site.app_id)) app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
if user_id: if user_id:
end_user = db.session.scalar( end_user = (
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
) )
if end_user: if end_user:
@ -122,11 +121,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal")) site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
if not site: if not site:
raise NotFound() raise NotFound()
app_model = db.session.scalar(select(App).where(App.id == site.app_id)) app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
@ -141,14 +140,16 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None end_user = None
if end_user_id: if end_user_id:
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if session_id: if session_id:
end_user = db.session.scalar( end_user = (
select(EndUser).where( db.session.query(EndUser)
.filter(
EndUser.session_id == session_id, EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
) )
.first()
) )
if not end_user: if not end_user:
if not session_id: if not session_id:
@ -186,8 +187,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id") user_id = token_decoded.get("user_id")
end_user = None end_user = None
if user_id: if user_id:
end_user = db.session.scalar( end_user = (
select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id) db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
) )
if not end_user: if not end_user:
@ -223,8 +224,6 @@ def generate_session_id():
""" """
while True: while True:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
existing_count = db.session.scalar( existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)
if existing_count == 0: if existing_count == 0:
return session_id return session_id

@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user): def get(self, app_model, end_user):
"""Retrieve app site info.""" """Retrieve app site info."""
# get site # get site
site = db.session.query(Site).where(Site.app_id == app_model.id).first() site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site: if not site:
raise Forbidden() raise Forbidden()

@ -3,7 +3,6 @@ from functools import wraps
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -49,8 +48,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
app_model = db.session.scalar(select(App).where(App.id == app_id)) app_model = db.session.query(App).filter(App.id == app_id).first()
site = db.session.scalar(select(Site).where(Site.code == app_code)) site = db.session.query(Site).filter(Site.code == app_code).first()
if not app_model: if not app_model:
raise NotFound() raise NotFound()
if not app_code or not site: if not app_code or not site:
@ -58,7 +57,7 @@ def decode_jwt_token():
if app_model.enable_site is False: if app_model.enable_site is False:
raise BadRequest("Site is disabled.") raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id") end_user_id = decoded.get("end_user_id")
end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user: if not end_user:
raise NotFound() raise NotFound()

@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.where( .filter(
MessageAgentThought.message_id == self.message.id, MessageAgentThought.message_id == self.message.id,
) )
.count() .count()
@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner):
Save agent thought Save agent thought
""" """
updated_agent_thought = ( updated_agent_thought = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first() db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
) )
if not updated_agent_thought: if not updated_agent_thought:
raise ValueError("agent thought not found") raise ValueError("agent thought not found")
@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner):
return result return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if not files: if not files:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
if message.app_model_config: if message.app_model_config:

@ -0,0 +1,48 @@
## Guidelines for Database Connection Management in App Runner and Task Pipeline
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
Examples:
1. Creating a new record:
```python
app = App(id=1)
db.session.add(app)
db.session.commit()
db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
# Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
db.session.close()
return app.id
```
2. Fetching a record from the table:
```python
app = db.session.query(App).filter(App.id == app_id).first()
created_at = app.created_at
db.session.close()
# Handle tasks (include long-running).
```
3. Updating a table field:
```python
app = db.session.query(App).filter(App.id == app_id).first()
app.updated_at = time.utcnow()
db.session.commit()
db.session.close()
return app_id
```

@ -7,8 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
@ -18,13 +17,11 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
@ -114,10 +111,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "") query = query.replace("\x00", "")
inputs = args["inputs"] inputs = args["inputs"]
extras = { extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
"auto_generate_conversation_name": args.get("auto_generate_name", False),
**extract_external_trace_id_from_args(args),
}
# get conversation # get conversation
conversation = None conversation = None
@ -487,52 +481,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
""" """
with preserve_flask_contexts(flask_app, context_vars=context): with preserve_flask_contexts(flask_app, context_vars=context):
# get conversation and message try:
conversation = self._get_conversation(conversation_id) # get conversation and message
message = self._get_message(message_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar( # chatbot app
select(Workflow).where( runner = AdvancedChatAppRunner(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id, application_generate_entity=application_generate_entity,
Workflow.app_id == application_generate_entity.app_config.app_id, queue_manager=queue_manager,
Workflow.id == application_generate_entity.app_config.workflow_id, conversation=conversation,
) message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
) )
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
if app is None:
raise ValueError("App not found")
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
app=app,
)
try:
runner.run() runner.run()
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
pass pass

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional, cast from typing import Any, cast
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -9,19 +9,13 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
AdvancedChatAppGenerateEntity,
AppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent, QueueAnnotationReplyEvent,
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -29,9 +23,8 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,38 +37,42 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def __init__( def __init__(
self, self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity, application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
dialogue_count: int, dialogue_count: int,
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
app: App,
) -> None: ) -> None:
super().__init__( super().__init__(queue_manager, variable_loader)
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
self.message = message self.message = message
self._dialogue_count = dialogue_count self._dialogue_count = dialogue_count
self._workflow = workflow
self.system_user_id = system_user_id def _get_app_id(self) -> str:
self._app = app return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
user_id: str | None = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG: if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
@ -83,14 +80,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested # if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs), user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
) )
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow, workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id, node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs), user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
) )
@ -101,7 +98,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation # moderation
if self.handle_input_moderation( if self.handle_input_moderation(
app_record=self._app, app_record=app_record,
app_generate_entity=self.application_generate_entity, app_generate_entity=self.application_generate_entity,
inputs=inputs, inputs=inputs,
query=query, query=query,
@ -111,7 +108,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# annotation reply # annotation reply
if self.handle_annotation_reply( if self.handle_annotation_reply(
app_record=self._app, app_record=app_record,
message=self.message, message=self.message,
query=query, query=query,
app_generate_entity=self.application_generate_entity, app_generate_entity=self.application_generate_entity,
@ -131,7 +128,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable( ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
) )
for variable in self._workflow.conversation_variables for variable in workflow.conversation_variables
] ]
session.add_all(db_conversation_variables) session.add_all(db_conversation_variables)
# Convert database entities to variables. # Convert database entities to variables.
@ -144,7 +141,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query, query=query,
files=files, files=files,
conversation_id=self.conversation.id, conversation_id=self.conversation.id,
user_id=self.system_user_id, user_id=user_id,
dialogue_count=self._dialogue_count, dialogue_count=self._dialogue_count,
app_id=app_config.app_id, app_id=app_config.app_id,
workflow_id=app_config.workflow_id, workflow_id=app_config.workflow_id,
@ -155,25 +152,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=self._workflow.environment_variables, environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`, # Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), conversation_variables=cast(list[VariableUnion], conversation_variables),
) )
# init graph # init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict) graph = self._init_graph(graph_config=workflow.graph_dict)
db.session.close() db.session.close()
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id, tenant_id=workflow.tenant_id,
app_id=self._workflow.app_id, app_id=workflow.app_id,
workflow_id=self._workflow.id, workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type), workflow_type=WorkflowType.value_of(workflow.type),
graph=graph, graph=graph,
graph_config=self._workflow.graph_dict, graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT
@ -244,51 +241,3 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._publish_event(QueueTextChunkEvent(text=text)) self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(QueueStopEvent(stopped_by=stopped_by)) self._publish_event(QueueStopEvent(stopped_by=stopped_by))
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)

File diff suppressed because it is too large Load Diff

@ -15,8 +15,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom

@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config) app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first() conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
if conversation_result is None: if conversation_result is None:
raise ValueError("Conversation not found") raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first() message_result = db.session.query(Message).filter(Message.id == message.id).first()
if message_result is None: if message_result is None:
raise ValueError("Message not found") raise ValueError("Message not found")
db.session.close() db.session.close()

@ -169,3 +169,7 @@ class AppQueueManager:
raise TypeError( raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed." "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
) )
class GenerateTaskStoppedError(Exception):
pass

@ -118,7 +118,7 @@ class AppRunner:
else: else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode(model_config.mode) model_mode = ModelMode.value_of(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION: if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template

@ -11,11 +11,10 @@ from configs import dify_config
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom

@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config) app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

@ -10,11 +10,10 @@ from pydantic import ValidationError
from configs import dify_config from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
@ -248,7 +247,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
""" """
message = ( message = (
db.session.query(Message) db.session.query(Message)
.where( .filter(
Message.id == message_id, Message.id == message_id,
Message.app_id == app_model.id, Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),

@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config) app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first() app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

@ -1,2 +0,0 @@
class GenerateTaskStoppedError(Exception):
pass

@ -1,12 +1,12 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from datetime import UTC, datetime
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
@ -24,7 +24,6 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account from models import Account
from models.enums import CreatorUserRole from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
@ -85,7 +84,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation: if conversation:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first() .first()
) )
@ -151,7 +150,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
introduction = self._get_conversation_introduction(application_generate_entity) introduction = self._get_conversation_introduction(application_generate_entity)
# get conversation name # get conversation name
query = application_generate_entity.query or "New conversation" if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
if isinstance(query, int):
query = str(query)
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: if not conversation:
@ -178,7 +183,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit() db.session.commit()
db.session.refresh(conversation) db.session.refresh(conversation)
else: else:
conversation.updated_at = naive_utc_now() conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
message = Message( message = Message(
@ -253,7 +258,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id :param conversation_id: conversation id
:return: conversation :return: conversation
""" """
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
raise ConversationNotExistsError("Conversation not exists") raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +271,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id :param message_id: message id
:return: message :return: message
""" """
message = db.session.query(Message).where(Message.id == message_id).first() message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None: if message is None:
raise MessageNotExistsError("Message not exists") raise MessageNotExistsError("Message not exists")

@ -1,5 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,

@ -7,15 +7,13 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from sqlalchemy import select from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
import contexts import contexts
from configs import dify_config from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.app_runner import WorkflowAppRunner
@ -23,7 +21,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
@ -125,10 +122,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
) )
inputs: Mapping[str, Any] = args["inputs"] inputs: Mapping[str, Any] = args["inputs"]
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4()) workflow_run_id = str(uuid.uuid4())
# init application generate entity # init application generate entity
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
@ -148,7 +141,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth=call_depth, call_depth=call_depth,
trace_manager=trace_manager, trace_manager=trace_manager,
workflow_execution_id=workflow_run_id, workflow_execution_id=workflow_run_id,
extras=extras,
) )
contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers.set({})
@ -446,44 +438,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
""" """
with preserve_flask_contexts(flask_app, context_vars=context): with preserve_flask_contexts(flask_app, context_vars=context):
with Session(db.engine, expire_on_commit=False) as session: try:
workflow = session.scalar( # workflow app
select(Workflow).where( runner = WorkflowAppRunner(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id, application_generate_entity=application_generate_entity,
Workflow.app_id == application_generate_entity.app_config.app_id, queue_manager=queue_manager,
Workflow.id == application_generate_entity.app_config.workflow_id, workflow_thread_pool_id=workflow_thread_pool_id,
) variable_loader=variable_loader,
) )
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
try:
runner.run() runner.run()
except GenerateTaskStoppedError as e: except GenerateTaskStoppedError:
logger.warning(f"Task stopped: {str(e)}")
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
@ -499,6 +464,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response( def _handle_response(
self, self,

@ -1,5 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,

@ -14,8 +14,10 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import Workflow, WorkflowType from models.model import App, EndUser
from models.workflow import WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,23 +29,22 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__( def __init__(
self, self,
*,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader, variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None: ) -> None:
super().__init__( """
queue_manager=queue_manager, :param application_generate_entity: application generate entity
variable_loader=variable_loader, :param queue_manager: application queue manager
app_id=application_generate_entity.app_config.app_id, :param workflow_thread_pool_id: workflow thread pool id
) """
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
""" """
@ -52,6 +53,24 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config) app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [] workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG: if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback()) workflow_callbacks.append(WorkflowLoggingCallback())
@ -60,14 +79,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run: if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested # if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration( graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=self._workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=self.application_generate_entity.single_iteration_run.inputs,
) )
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop( graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=self._workflow, workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id, node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs, user_inputs=self.application_generate_entity.single_loop_run.inputs,
) )
@ -79,7 +98,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
system_inputs = SystemVariable( system_inputs = SystemVariable(
files=files, files=files,
user_id=self._sys_user_id, user_id=user_id,
app_id=app_config.app_id, app_id=app_config.app_id,
workflow_id=app_config.workflow_id, workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id,
@ -88,21 +107,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=self._workflow.environment_variables, environment_variables=workflow.environment_variables,
conversation_variables=[], conversation_variables=[],
) )
# init graph # init graph
graph = self._init_graph(graph_config=self._workflow.graph_dict) graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW # RUN WORKFLOW
workflow_entry = WorkflowEntry( workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id, tenant_id=workflow.tenant_id,
app_id=self._workflow.app_id, app_id=workflow.app_id,
workflow_id=self._workflow.id, workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type), workflow_type=WorkflowType.value_of(workflow.type),
graph=graph, graph=graph,
graph_config=self._workflow.graph_dict, graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id, user_id=self.application_generate_entity.user_id,
user_from=( user_from=(
UserFrom.ACCOUNT UserFrom.ACCOUNT

@ -1,8 +1,7 @@
import logging import logging
import time import time
from collections.abc import Callable, Generator from collections.abc import Generator
from contextlib import contextmanager from typing import Optional, Union
from typing import Any, Optional, Union
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -14,7 +13,6 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity, WorkflowAppGenerateEntity,
) )
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent, QueueAgentLogEvent,
QueueErrorEvent, QueueErrorEvent,
QueueIterationCompletedEvent, QueueIterationCompletedEvent,
@ -40,13 +38,11 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent, QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent, QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent, QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
) )
from core.app.entities.task_entities import ( from core.app.entities.task_entities import (
ErrorStreamResponse, ErrorStreamResponse,
MessageAudioEndStreamResponse, MessageAudioEndStreamResponse,
MessageAudioStreamResponse, MessageAudioStreamResponse,
PingStreamResponse,
StreamResponse, StreamResponse,
TextChunkStreamResponse, TextChunkStreamResponse,
WorkflowAppBlockingResponse, WorkflowAppBlockingResponse,
@ -58,7 +54,6 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -251,495 +246,315 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher: if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id) yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
@contextmanager def _process_stream_response(
def _database_session(self):
"""Context manager for database sessions."""
with Session(db.engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
def _ensure_workflow_initialized(self) -> None:
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline._ping_stream_response()
def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
"""Handle error events."""
err = self._base_task_pipeline._handle_error(event=event)
yield self._base_task_pipeline._error_to_stream_response(err)
def _handle_workflow_started_event(
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield start_resp
def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle node retry events."""
self._ensure_workflow_initialized()
with self._database_session() as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
def _handle_node_started_event(
self, event: QueueNodeStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
yield node_start_response
def _handle_node_succeeded_event(
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
if node_success_response:
yield node_success_response
def _handle_node_failed_events(
self, self,
event: Union[ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent trace_manager: Optional[TraceQueueManager] = None,
],
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
if node_failed_response:
yield node_failed_response
def _handle_parallel_branch_started_event(
self, event: QueueParallelBranchRunStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]: ) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch started events.""" """
self._ensure_workflow_initialized() Process stream response.
:return:
parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( """
task_id=self._application_generate_entity.task_id, graph_runtime_state = None
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield parallel_start_resp
def _handle_parallel_branch_finished_events( for queue_message in self._base_task_pipeline._queue_manager.listen():
self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs event = queue_message.event
) -> Generator[StreamResponse, None, None]:
"""Handle parallel branch finished events."""
self._ensure_workflow_initialized()
parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( if isinstance(event, QueuePingEvent):
task_id=self._application_generate_entity.task_id, yield self._base_task_pipeline._ping_stream_response()
workflow_execution_id=self._workflow_run_id, elif isinstance(event, QueueErrorEvent):
event=event, err = self._base_task_pipeline._handle_error(event=event)
) yield self._base_task_pipeline._error_to_stream_response(err)
yield parallel_finish_resp break
elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
def _handle_iteration_start_event( yield start_resp
self, event: QueueIterationStartEvent, **kwargs elif isinstance(
) -> Generator[StreamResponse, None, None]: event,
"""Handle iteration start events.""" QueueNodeRetryEvent,
self._ensure_workflow_initialized() ):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( if response:
task_id=self._application_generate_entity.task_id, yield response
workflow_execution_id=self._workflow_run_id, elif isinstance(event, QueueNodeStartedEvent):
event=event, if not self._workflow_run_id:
) raise ValueError("workflow run not initialized.")
yield iter_start_resp
def _handle_iteration_next_event( workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
self, event: QueueIterationNextEvent, **kwargs workflow_execution_id=self._workflow_run_id, event=event
) -> Generator[StreamResponse, None, None]: )
"""Handle iteration next events.""" node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
self._ensure_workflow_initialized() event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( if node_start_response:
task_id=self._application_generate_entity.task_id, yield node_start_response
workflow_execution_id=self._workflow_run_id, elif isinstance(event, QueueNodeSucceededEvent):
event=event, workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
) event=event
yield iter_next_resp )
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
def _handle_iteration_completed_event( self._save_output_for_event(event, workflow_node_execution.id)
self, event: QueueIterationCompletedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle iteration completed events."""
self._ensure_workflow_initialized()
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( if node_success_response:
task_id=self._application_generate_entity.task_id, yield node_success_response
workflow_execution_id=self._workflow_run_id, elif isinstance(
event=event, event,
) QueueNodeFailedEvent
yield iter_finish_resp | QueueNodeInIterationFailedEvent
| QueueNodeInLoopFailedEvent
| QueueNodeExceptionEvent,
):
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]: if node_failed_response:
"""Handle loop start events.""" yield node_failed_response
self._ensure_workflow_initialized()
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( elif isinstance(event, QueueParallelBranchRunStartedEvent):
task_id=self._application_generate_entity.task_id, if not self._workflow_run_id:
workflow_execution_id=self._workflow_run_id, raise ValueError("workflow run not initialized.")
event=event,
)
yield loop_start_resp
def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]: parallel_start_resp = (
"""Handle loop next events.""" self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
self._ensure_workflow_initialized() task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
)
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( yield parallel_start_resp
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
yield loop_next_resp
def _handle_loop_completed_event( elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
self, event: QueueLoopCompletedEvent, **kwargs if not self._workflow_run_id:
) -> Generator[StreamResponse, None, None]: raise ValueError("workflow run not initialized.")
"""Handle loop completed events."""
self._ensure_workflow_initialized()
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( parallel_finish_resp = (
task_id=self._application_generate_entity.task_id, self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
workflow_execution_id=self._workflow_run_id, task_id=self._application_generate_entity.task_id,
event=event, workflow_execution_id=self._workflow_run_id,
) event=event,
yield loop_finish_resp )
)
def _handle_workflow_succeeded_event( yield parallel_finish_resp
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log elif isinstance(event, QueueIterationStartEvent):
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id,
task_id=self._application_generate_entity.task_id, workflow_execution_id=self._workflow_run_id,
workflow_execution=workflow_execution, event=event,
) )
yield workflow_finish_resp yield iter_start_resp
def _handle_workflow_partial_success_event( elif isinstance(event, QueueIterationNextEvent):
self, if not self._workflow_run_id:
event: QueueWorkflowPartialSuccessEvent, raise ValueError("workflow run not initialized.")
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( yield iter_next_resp
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
yield workflow_finish_resp elif isinstance(event, QueueIterationCompletedEvent):
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _handle_workflow_failed_and_stop_events( iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
self, task_id=self._application_generate_entity.task_id,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent], workflow_execution_id=self._workflow_run_id,
*, event=event,
graph_runtime_state: Optional[GraphRuntimeState] = None, )
trace_manager: Optional[TraceQueueManager] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log yield iter_finish_resp
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( elif isinstance(event, QueueLoopStartEvent):
session=session, if not self._workflow_run_id:
task_id=self._application_generate_entity.task_id, raise ValueError("workflow run not initialized.")
workflow_execution=workflow_execution,
)
yield workflow_finish_resp loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
def _handle_text_chunk_event( yield loop_start_resp
self,
event: QueueTextChunkEvent,
*,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
# only publish tts message at text chunk streaming elif isinstance(event, QueueLoopNextEvent):
if tts_publisher and queue_message: if not self._workflow_run_id:
tts_publisher.publish(queue_message) raise ValueError("workflow run not initialized.")
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
event=event,
)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: yield loop_next_resp
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
task_id=self._application_generate_entity.task_id, event=event
)
def _get_event_handlers(self) -> dict[type, Callable]: elif isinstance(event, QueueLoopCompletedEvent):
"""Get mapping of event types to their handlers using fluent pattern.""" if not self._workflow_run_id:
return { raise ValueError("workflow run not initialized.")
# Basic events
QueuePingEvent: self._handle_ping_event,
QueueErrorEvent: self._handle_error_event,
QueueTextChunkEvent: self._handle_text_chunk_event,
# Workflow events
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
QueueNodeSucceededEvent: self._handle_node_succeeded_event,
# Parallel branch events
QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
# Iteration events
QueueIterationStartEvent: self._handle_iteration_start_event,
QueueIterationNextEvent: self._handle_iteration_next_event,
QueueIterationCompletedEvent: self._handle_iteration_completed_event,
# Loop events
QueueLoopStartEvent: self._handle_loop_start_event,
QueueLoopNextEvent: self._handle_loop_next_event,
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
}
def _dispatch_event(
self,
event: Any,
*,
graph_runtime_state: Optional[GraphRuntimeState] = None,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
) -> Generator[StreamResponse, None, None]:
"""Dispatch events using elegant pattern matching."""
handlers = self._get_event_handlers()
event_type = type(event)
# Direct handler lookup loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
if handler := handlers.get(event_type): task_id=self._application_generate_entity.task_id,
yield from handler( workflow_execution_id=self._workflow_run_id,
event, event=event,
graph_runtime_state=graph_runtime_state, )
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle node failure events with isinstance check yield loop_finish_resp
if isinstance(
event, elif isinstance(event, QueueWorkflowSucceededEvent):
( if not self._workflow_run_id:
QueueNodeFailedEvent, raise ValueError("workflow run not initialized.")
QueueNodeInIterationFailedEvent, if not graph_runtime_state:
QueueNodeInLoopFailedEvent, raise ValueError("graph runtime state not initialized.")
QueueNodeExceptionEvent,
), with Session(db.engine, expire_on_commit=False) as session:
): workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
yield from self._handle_node_failed_events( workflow_run_id=self._workflow_run_id,
event, total_tokens=graph_runtime_state.total_tokens,
graph_runtime_state=graph_runtime_state, total_steps=graph_runtime_state.node_run_steps,
tts_publisher=tts_publisher, outputs=event.outputs,
trace_manager=trace_manager, conversation_id=None,
queue_message=queue_message, trace_manager=trace_manager,
) )
return
# Handle parallel branch finished events with isinstance check # save workflow app log
if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
yield from self._handle_parallel_branch_finished_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
)
return
# Handle workflow failed and stop events with isinstance check workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)): session=session,
yield from self._handle_workflow_failed_and_stop_events( task_id=self._application_generate_entity.task_id,
event, workflow_execution=workflow_execution,
graph_runtime_state=graph_runtime_state, )
tts_publisher=tts_publisher, session.commit()
trace_manager=trace_manager,
queue_message=queue_message, yield workflow_finish_resp
) elif isinstance(event, QueueWorkflowPartialSuccessEvent):
return if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# For unhandled events, we continue (original behavior) # save workflow app log
return self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
def _process_stream_response( workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
self, session=session,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None, task_id=self._application_generate_entity.task_id,
trace_manager: Optional[TraceQueueManager] = None, workflow_execution=workflow_execution,
) -> Generator[StreamResponse, None, None]: )
""" session.commit()
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version. yield workflow_finish_resp
""" elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
# Initialize graph runtime state if not self._workflow_run_id:
graph_runtime_state = None raise ValueError("workflow run not initialized.")
if not graph_runtime_state:
for queue_message in self._base_task_pipeline._queue_manager.listen(): raise ValueError("graph runtime state not initialized.")
event = queue_message.event
with Session(db.engine, expire_on_commit=False) as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
match event: # save workflow app log
case QueueWorkflowStartedEvent(): self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
graph_runtime_state = event.graph_runtime_state
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent(): workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
yield from self._handle_text_chunk_event( session=session,
event, tts_publisher=tts_publisher, queue_message=queue_message task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
) )
session.commit()
case QueueErrorEvent(): yield workflow_finish_resp
yield from self._handle_error_event(event) elif isinstance(event, QueueTextChunkEvent):
break delta_text = event.text
if delta_text is None:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(queue_message)
# Handle all other events through elegant dispatch yield self._text_chunk_to_stream_response(
case _: delta_text, from_variable_selector=event.from_variable_selector
if responses := list( )
self._dispatch_event( elif isinstance(event, QueueAgentLogEvent):
event, yield self._workflow_response_converter.handle_agent_log(
graph_runtime_state=graph_runtime_state, task_id=self._application_generate_entity.task_id, event=event
tts_publisher=tts_publisher, )
trace_manager=trace_manager, else:
queue_message=queue_message, continue
)
):
yield from responses
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)

@ -1,7 +1,8 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
QueueAgentLogEvent, QueueAgentLogEvent,
@ -64,20 +65,18 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow from models.workflow import Workflow
class WorkflowBasedAppRunner: class WorkflowBasedAppRunner(AppRunner):
def __init__( def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self, self.queue_manager = queue_manager
*,
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
self._queue_manager = queue_manager
self._variable_loader = variable_loader self._variable_loader = variable_loader
self._app_id = app_id
def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
""" """
@ -694,5 +693,21 @@ class WorkflowBasedAppRunner:
) )
) )
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return: :return:
""" """
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first() db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
) )
if not annotation_setting: if not annotation_setting:

@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return: :return:
""" """
agent_thought: Optional[MessageAgentThought] = ( agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
) )
if agent_thought: if agent_thought:

@ -81,7 +81,7 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context(): with flask_app.app_context():
# get conversation and message # get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
return return
@ -140,7 +140,7 @@ class MessageCycleManager:
:param event: event :param event: event
:return: :return:
""" """
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first() message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file and message_file.url is not None: if message_file and message_file.url is not None:
# get tool file id # get tool file id

@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
document_id = document.metadata["document_id"] document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first() dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document: if not dataset_document:
_logger.warning( _logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s", "Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.where( .filter(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id) .filter(DocumentSegment.id == child_chunk.segment_id)
.update( .update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
) )
) )
else: else:
query = db.session.query(DocumentSegment).where( query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"] DocumentSegment.index_node_id == document.metadata["doc_id"]
) )
if "dataset_id" in document.metadata: if "dataset_id" in document.metadata:
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
provider_record = ( provider_record = (
db.session.query(Provider) db.session.query(Provider)
.where( .filter(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names), Provider.provider_name.in_(provider_names),
@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
provider_model_record = ( provider_model_record = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.where( .filter(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names), ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model, ProviderModel.model_name == model,
@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.where( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.where( .filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
load_balancing_config_count = ( load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.where( .filter(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ( model_setting = (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.where( .filter(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = ( preferred_model_provider = (
db.session.query(TenantPreferredModelProvider) db.session.query(TenantPreferredModelProvider)
.where( .filter(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names), TenantPreferredModelProvider.provider_name.in_(provider_names),
) )

@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )
@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -7,7 +7,6 @@ from core.model_runtime.entities import (
AudioPromptMessageContent, AudioPromptMessageContent,
DocumentPromptMessageContent, DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
TextPromptMessageContent,
VideoPromptMessageContent, VideoPromptMessageContent,
) )
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
@ -45,44 +44,11 @@ def to_prompt_message_content(
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes: ) -> PromptMessageContentUnionTypes:
"""
Convert a file to prompt message content.
This function converts files to their appropriate prompt message content types.
For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
corresponding message content with proper encoding/URL.
For unsupported file types, instead of raising an error, it returns a
TextPromptMessageContent with a descriptive message about the file.
Args:
f: The file to convert
image_detail_config: Optional detail configuration for image files
Returns:
PromptMessageContentUnionTypes: The appropriate message content type
Raises:
ValueError: If file extension or mime_type is missing
"""
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
if f.mime_type is None: if f.mime_type is None:
raise ValueError("Missing file mime_type") raise ValueError("Missing file mime_type")
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
# Check if file type is supported
if f.type not in prompt_class_map:
# For unsupported file types, return a text description
return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
# Process supported file types
params = { params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
@ -92,7 +58,17 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE: if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
return prompt_class_map[f.type].model_validate(params) prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
FileType.DOCUMENT: DocumentPromptMessageContent,
}
try:
return prompt_class_map[f.type].model_validate(params)
except KeyError:
raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /): def download(f: File, /):

@ -15,13 +15,13 @@ def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant from models.account import Tenant
from models.engine import db from models.engine import db
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found") raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode() return base64.b64encode(encrypted_token).decode()
def decrypt_token(tenant_id: str, token: str) -> str: def decrypt_token(tenant_id: str, token: str):
return rsa.decrypt(base64.b64decode(token), tenant_id) return rsa.decrypt(base64.b64decode(token), tenant_id)

@ -25,29 +25,9 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
url = str(marketplace_api_url / "api/v1/plugins/batch") url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids}) response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status() response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
pass
return result
def record_install_plugin_event(plugin_unique_identifier: str): def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count") url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})

@ -1,42 +0,0 @@
import re
from collections.abc import Mapping
from typing import Any, Optional
def is_valid_trace_id(trace_id: str) -> bool:
"""
Check if the trace_id is valid.
Requirements: 1-128 characters, only letters, numbers, '-', and '_'.
"""
return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id))
def get_external_trace_id(request: Any) -> Optional[str]:
"""
Retrieve the trace_id from the request.
Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
"""
trace_id = request.headers.get("X-Trace-Id")
if not trace_id:
trace_id = request.args.get("trace_id")
if not trace_id and getattr(request, "is_json", False):
json_data = getattr(request, "json", None)
if json_data:
trace_id = json_data.get("trace_id")
if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
return trace_id
return None
def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
"""
Extract 'external_trace_id' from args.
Returns a dict suitable for use in extras. Returns an empty dict if not found.
"""
trace_id = args.get("external_trace_id")
if trace_id:
return {"external_trace_id": trace_id}
return {}

@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -119,12 +119,12 @@ class IndexingRunner:
db.session.delete(document_segment) db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks # delete child chunks
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete() db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
@ -316,7 +316,7 @@ class IndexingRunner:
# delete image files and related db records # delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids: for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None: if image_file is None:
continue continue
try: try:
@ -346,7 +346,7 @@ class IndexingRunner:
raise ValueError("no upload file found") raise ValueError("no upload file found")
file_detail = ( file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
) )
if file_detail: if file_detail:
@ -599,7 +599,7 @@ class IndexingRunner:
keyword.create(documents) keyword.create(documents)
if dataset.indexing_technique != "high_quality": if dataset.indexing_technique != "high_quality":
document_ids = [document.metadata["doc_id"] for document in documents] document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).where( db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id, DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.index_node_id.in_(document_ids),
@ -630,7 +630,7 @@ class IndexingRunner:
index_processor.load(dataset, chunk_documents, with_keywords=False) index_processor.load(dataset, chunk_documents, with_keywords=False)
document_ids = [document.metadata["doc_id"] for document in chunk_documents] document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).where( db.session.query(DocumentSegment).filter(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.index_node_id.in_(document_ids),
@ -672,7 +672,8 @@ class IndexingRunner:
if extra_update_params: if extra_update_params:
update_params.update(extra_update_params) update_params.update(extra_update_params)
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.commit() db.session.commit()
@staticmethod @staticmethod

@ -114,8 +114,7 @@ class LLMGenerator:
), ),
) )
text_content = response.message.get_text_content() questions = output_parser.parse(cast(str, response.message.content))
questions = output_parser.parse(text_content) if text_content else []
except InvokeError: except InvokeError:
questions = [] questions = []
except Exception: except Exception:

@ -15,4 +15,5 @@ class SuggestedQuestionsAfterAnswerOutputParser:
json_obj = json.loads(action_match.group(0).strip()) json_obj = json.loads(action_match.group(0).strip())
else: else:
json_obj = [] json_obj = []
return json_obj return json_obj

@ -8,7 +8,7 @@ from core.mcp.types import (
OAuthTokens, OAuthTokens,
) )
from models.tools import MCPToolProvider from models.tools import MCPToolProvider
from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0" LATEST_PROTOCOL_VERSION = "1.0"

@ -68,17 +68,15 @@ class MCPClient:
} }
parsed_url = urlparse(self.server_url) parsed_url = urlparse(self.server_url)
path = parsed_url.path or "" path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else "" method_name = path.rstrip("/").split("/")[-1] if path else ""
if method_name in connection_methods: try:
client_factory = connection_methods[method_name] client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name) self.connect_server(client_factory, method_name)
else: except KeyError:
try: try:
logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
self.connect_server(sse_client, "sse") self.connect_server(sse_client, "sse")
except MCPConnectionError: except MCPConnectionError:
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp") self.connect_server(streamablehttp_client, "mcp")
def connect_server( def connect_server(
@ -93,7 +91,7 @@ class MCPClient:
else {} else {}
) )
self._streams_context = client_factory(url=self.server_url, headers=headers) self._streams_context = client_factory(url=self.server_url, headers=headers)
if not self._streams_context: if self._streams_context is None:
raise MCPConnectionError("Failed to create connection context") raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly # Use exit_stack to manage context managers properly
@ -143,11 +141,10 @@ class MCPClient:
try: try:
# ExitStack will handle proper cleanup of all managed context managers # ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close() self.exit_stack.close()
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")
finally:
self._session = None self._session = None
self._session_context = None self._session_context = None
self._streams_context = None self._streams_context = None
self._initialized = False self._initialized = False
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
): ):
self.app = app self.app = app
self.request = request self.request = request
mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first() mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
if not mcp_server: if not mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server self.mcp_server: AppMCPServer = mcp_server
@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self): def retrieve_end_user(self):
return ( return (
db.session.query(EndUser) db.session.query(EndUser)
.where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )

@ -67,7 +67,7 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
for message in messages: for message in messages:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files: if files:
file_extra_config = None file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:

@ -156,23 +156,6 @@ class PromptMessage(ABC, BaseModel):
""" """
return not self.content return not self.content
def get_text_content(self) -> str:
"""
Get text content from prompt message.
:return: Text content as string, empty string if no text content
"""
if isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
text_parts = []
for item in self.content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return "".join(text_parts)
else:
return ""
@field_validator("content", mode="before") @field_validator("content", mode="before")
@classmethod @classmethod
def validate_content(cls, v): def validate_content(cls, v):

@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = ( extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -101,8 +101,7 @@ class AliyunDataTrace(BaseTraceInstance):
raise ValueError(f"Aliyun get run url failed: {str(e)}") raise ValueError(f"Aliyun get run url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
external_trace_id = trace_info.metadata.get("external_trace_id") trace_id = convert_to_trace_id(trace_info.workflow_run_id)
trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow") workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info) self.add_workflow_span(trace_id, workflow_span_id, trace_info)
@ -120,7 +119,7 @@ class AliyunDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id user_id = message_data.from_account_id
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
user_id = end_user_data.session_id user_id = end_user_data.session_id
@ -244,14 +243,14 @@ class AliyunDataTrace(BaseTraceInstance):
if not app_id: if not app_id:
raise ValueError("No app_id found in trace_info metadata") raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first() app = session.query(App).filter(App.id == app_id).first()
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first() service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = (

@ -3,7 +3,7 @@ import json
import logging import logging
import os import os
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Optional, Union, cast from typing import Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace from opentelemetry import trace
@ -142,8 +142,11 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise raise
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
if trace_info.message_data is None:
return
workflow_metadata = { workflow_metadata = {
"workflow_run_id": trace_info.workflow_run_id or "", "workflow_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "", "message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "", "workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "", "status": trace_info.workflow_run_status or "",
@ -153,8 +156,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
} }
workflow_metadata.update(trace_info.metadata) workflow_metadata.update(trace_info.metadata)
external_trace_id = trace_info.metadata.get("external_trace_id") trace_id = uuid_to_trace_id(trace_info.message_id)
trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id() span_id = RandomIdGenerator().generate_span_id()
context = SpanContext( context = SpanContext(
trace_id=trace_id, trace_id=trace_id,
@ -211,7 +213,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model: if model:
node_metadata["ls_model_name"] = model node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data: if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
@ -234,34 +236,31 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "", SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
}, },
start_time=datetime_to_nanos(created_at), start_time=datetime_to_nanos(created_at),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
) )
try: try:
if node_execution.node_type == "llm": if node_execution.node_type == "llm":
llm_attributes: dict[str, Any] = {
SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
}
provider = process_data.get("model_provider") provider = process_data.get("model_provider")
model = process_data.get("model_name") model = process_data.get("model_name")
if provider: if provider:
llm_attributes[SpanAttributes.LLM_PROVIDER] = provider node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
if model: if model:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = (
json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} outputs = json.loads(node_execution.outputs).get("usage", {})
)
usage_data = ( usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
) )
if usage_data: if usage_data:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0) node_span.set_attribute(
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0) SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get( )
"completion_tokens", 0 node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
) )
llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
node_span.set_attributes(llm_attributes)
finally: finally:
node_span.end(end_time=datetime_to_nanos(finished_at)) node_span.end(end_time=datetime_to_nanos(finished_at))
finally: finally:
@ -297,7 +296,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Add end user data if available # Add end user data if available
if trace_info.message_data.from_end_user_id: if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id message_metadata["end_user_id"] = end_user_data.session_id
@ -353,7 +352,25 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
} }
llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
if trace_info.total_tokens is not None and trace_info.total_tokens > 0: if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0: if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
@ -703,28 +720,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata, WorkflowNodeExecutionModel.execution_metadata,
) )
.where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all() .all()
) )
return workflow_nodes return workflow_nodes
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Helper method to construct LLM attributes with passed prompts."""
attributes = {}
if isinstance(prompts, list):
for i, msg in enumerate(prompts):
if isinstance(msg, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(prompts, dict):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(prompts, str):
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
return attributes

@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator # Get the app to find its creator
app = session.query(App).where(App.id == app_id).first() app = session.query(App).filter(App.id == app_id).first()
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first() service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

@ -67,14 +67,13 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
external_trace_id = trace_info.metadata.get("external_trace_id") trace_id = trace_info.workflow_run_id
trace_id = external_trace_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id") user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id: if trace_info.message_id:
trace_id = external_trace_id or trace_info.message_id trace_id = trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value name = TraceTaskName.MESSAGE_TRACE.value
trace_data = LangfuseTrace( trace_data = LangfuseTrace(
id=trace_id, id=trace_id,
@ -244,7 +243,7 @@ class LangFuseDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id user_id = message_data.from_account_id
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
user_id = end_user_data.session_id user_id = end_user_data.session_id

@ -65,8 +65,7 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
external_trace_id = trace_info.metadata.get("external_trace_id") trace_id = trace_info.message_id or trace_info.workflow_run_id
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None: if trace_info.start_time is None:
trace_info.start_time = datetime.now() trace_info.start_time = datetime.now()
message_dotted_order = ( message_dotted_order = (
@ -262,7 +261,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -96,8 +96,7 @@ class OpikDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
external_trace_id = trace_info.metadata.get("external_trace_id") dify_trace_id = trace_info.workflow_run_id
dify_trace_id = external_trace_id or trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata( workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
@ -105,7 +104,7 @@ class OpikDataTrace(BaseTraceInstance):
root_span_id = None root_span_id = None
if trace_info.message_id: if trace_info.message_id:
dify_trace_id = external_trace_id or trace_info.message_id dify_trace_id = trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
trace_data = { trace_data = {
@ -284,7 +283,7 @@ class OpikDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -218,7 +218,7 @@ class OpsTraceManager:
""" """
trace_config_data: Optional[TraceAppConfig] = ( trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -226,7 +226,7 @@ class OpsTraceManager:
return None return None
# decrypt_token # decrypt_token
app = db.session.query(App).where(App.id == app_id).first() app = db.session.query(App).filter(App.id == app_id).first()
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
@ -253,7 +253,7 @@ class OpsTraceManager:
if app_id is None: if app_id is None:
return None return None
app: Optional[App] = db.session.query(App).where(App.id == app_id).first() app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if app is None: if app is None:
return None return None
@ -293,18 +293,18 @@ class OpsTraceManager:
@classmethod @classmethod
def get_app_config_through_message_id(cls, message_id: str): def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first() message_data = db.session.query(Message).filter(Message.id == message_id).first()
if not message_data: if not message_data:
return None return None
conversation_id = message_data.conversation_id conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first() conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation_data: if not conversation_data:
return None return None
if conversation_data.app_model_config_id: if conversation_data.app_model_config_id:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id) .filter(AppModelConfig.id == conversation_data.app_model_config_id)
.first() .first()
) )
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
@ -331,7 +331,7 @@ class OpsTraceManager:
if tracing_provider is not None: if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}") raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first() app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if not app_config: if not app_config:
raise ValueError("App not found") raise ValueError("App not found")
app_config.tracing = json.dumps( app_config.tracing = json.dumps(
@ -349,7 +349,7 @@ class OpsTraceManager:
:param app_id: app id :param app_id: app id
:return: :return:
""" """
app: Optional[App] = db.session.query(App).where(App.id == app_id).first() app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
if not app.tracing: if not app.tracing:
@ -520,10 +520,6 @@ class TraceTask:
"app_id": workflow_run.app_id, "app_id": workflow_run.app_id,
} }
external_trace_id = self.kwargs.get("external_trace_id")
if external_trace_id:
metadata["external_trace_id"] = external_trace_id
workflow_trace_info = WorkflowTraceInfo( workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(), workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id, conversation_id=conversation_id,

@ -3,8 +3,6 @@ from datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message from models.model import Message
@ -22,7 +20,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str): def get_message_data(message_id: str):
return db.session.scalar(select(Message).where(Message.id == message_id)) return db.session.query(Message).filter(Message.id == message_id).first()
@contextmanager @contextmanager

@ -87,8 +87,7 @@ class WeaveDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info) self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo): def workflow_trace(self, trace_info: WorkflowTraceInfo):
external_trace_id = trace_info.metadata.get("external_trace_id") trace_id = trace_info.message_id or trace_info.workflow_run_id
trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None: if trace_info.start_time is None:
trace_info.start_time = datetime.now() trace_info.start_time = datetime.now()
@ -235,7 +234,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get the user by user id get the user by user id
""" """
user = db.session.query(EndUser).where(EndUser.id == user_id).first() user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
if not user: if not user:
user = db.session.query(Account).where(Account.id == user_id).first() user = db.session.query(Account).filter(Account.id == user_id).first()
if not user: if not user:
raise ValueError("user not found") raise ValueError("user not found")
@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app get app
""" """
try: try:
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first()
except Exception: except Exception:
raise ValueError("app not found") raise ValueError("app not found")

@ -32,13 +32,6 @@ class MarketplacePluginDeclaration(BaseModel):
latest_package_identifier: str = Field( latest_package_identifier: str = Field(
..., description="Unique identifier for the latest package release of the plugin" ..., description="Unique identifier for the latest package release of the plugin"
) )
status: str = Field(..., description="Indicate the status of marketplace plugin, enum from `active` `deleted`")
deprecated_reason: str = Field(
..., description="Not empty when status='deleted', indicates the reason why this plugin is deleted(deprecated)"
)
alternative_plugin_id: str = Field(
..., description="Optional, indicates the alternative plugin for user to switch to"
)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

@ -182,10 +182,6 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
class PluginOAuthCredentialsResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save