Merge branch 'fix/chore-fix' into dev/plugin-deploy

pull/12372/head
Yeuoly 1 year ago
commit 69b61ef57b
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -1,5 +1,5 @@
FROM mcr.microsoft.com/devcontainers/python:3.10 FROM mcr.microsoft.com/devcontainers/python:3.12
# [Optional] Uncomment this section to install additional OS packages. # [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here> # && apt-get -y install --no-install-recommends <your-package-list-here>

@ -1,7 +1,7 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the // For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{ {
"name": "Python 3.10", "name": "Python 3.12",
"build": { "build": {
"context": "..", "context": "..",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"

@ -4,7 +4,7 @@ inputs:
python-version: python-version:
description: Python version to use and the Poetry installed with description: Python version to use and the Poetry installed with
required: true required: true
default: '3.10' default: '3.11'
poetry-version: poetry-version:
description: Poetry version to set up description: Poetry version to set up
required: true required: true

@ -20,7 +20,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

@ -20,7 +20,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

@ -1,6 +1,8 @@
# CONTRIBUTING
So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly. So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part. We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve. This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests: ### Feature requests
* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try. * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
* If you want to pick one up from the existing issues, simply drop a comment below it saying so. * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes. A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment: Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment:
@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-core features and minor enhancements | Low Priority | | Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature | | Valuable but not immediate | Future-Feature |
### Anything else (e.g. bug report, performance optimization, typo correction): ### Anything else (e.g. bug report, performance optimization, typo correction)
* Start coding right away. * Start coding right away.
@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-critical bugs, performance boosts | Medium Priority | | Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority | | Minor fixes (typos, confusing but working UI) | Low Priority |
## Installing ## Installing
Here are the steps to set up Dify for development: Here are the steps to set up Dify for development:
@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
Clone the forked repository from your terminal: Clone the forked repository from your terminal:
``` ```shell
git clone git@github.com:<github_username>/dify.git git clone git@github.com:<github_username>/dify.git
``` ```
@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
Dify requires the following dependencies to build, make sure they're installed on your system: Dify requires the following dependencies to build, make sure they're installed on your system:
- [Docker](https://www.docker.com/) * [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/) * [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) * [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) * [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x * [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. Installations ### 4. Installations
@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
### 5. Visit dify in your browser ### 5. Visit dify in your browser
To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running. To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
## Developing ## Developing
@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
### Backend ### Backend
Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login. Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
``` ```text
[api/] [api/]
├── constants // Constant settings used throughout code base. ├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic. ├── controllers // API route definitions and request handling logic.
@ -121,7 +120,7 @@ Difys backend is written in Python using [Flask](https://flask.palletsproject
The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization. The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
``` ```text
[web/] [web/]
├── app // layouts, pages, and components ├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app │ ├── (commonLayout) // common layout used throughout the app
@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
## Submitting your PR ## Submitting your PR
At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md). And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
## Getting Help ## Getting Help
If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

@ -71,7 +71,7 @@ Dify 依赖以下工具和库:
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. 安装 ### 4. 安装

@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. インストール ### 4. インストール

@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) phiên bản 3.10.x - [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
### 4. Cài đặt ### 4. Cài đặt
@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
## Nhận trợ giúp ## Nhận trợ giúp
Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.

@ -1,5 +1,5 @@
# base image # base image
FROM python:3.10-slim-bookworm AS base FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api

@ -42,7 +42,7 @@
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.10 poetry env use 3.12
poetry install poetry install
``` ```
@ -81,5 +81,3 @@
```bash ```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh poetry run -C api bash dev/pytest/pytest_all_tests.sh
``` ```

@ -1,6 +1,11 @@
import os import os
import sys import sys
python_version = sys.version_info
if not ((3, 11) <= python_version < (3, 13)):
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
raise SystemExit(1)
from configs import dify_config from configs import dify_config
if not dify_config.DEBUG: if not dify_config.DEBUG:
@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web # no
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE
if sys.version_info[:2] == (3, 10):
print("Warning: Python 3.10 will not be supported in the next version.")
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file # read from dotenv format config file
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes # ignore extra attributes
extra="ignore", extra="ignore",
) )

@ -2,6 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload") api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version

@ -1,7 +1,10 @@
import uuid import uuid
from typing import cast
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api from controllers.console import api
@ -12,15 +15,16 @@ from controllers.console.wraps import (
enterprise_license_required, enterprise_license_required,
setup_required, setup_required,
) )
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import (
app_detail_fields, app_detail_fields,
app_detail_fields_with_site, app_detail_fields_with_site,
app_pagination_fields, app_pagination_fields,
) )
from libs.login import login_required from libs.login import login_required
from services.app_dsl_service import AppDslService from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -93,99 +97,6 @@ class AppListApi(Resource):
return app, 201 return app, 201
class AppImportDependenciesCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Check dependencies"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
leaked_dependencies = AppDslService.check_dependencies(
tenant_id=current_user.current_tenant_id, data=args["data"], account=current_user
)
return jsonable_encoder({"leaked": leaked_dependencies}), 200
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlDependenciesCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
leaked_dependencies = AppDslService.check_dependencies_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], account=current_user
)
return jsonable_encoder({"leaked": leaked_dependencies}), 200
class AppApi(Resource): class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -263,10 +174,24 @@ class AppCopyApi(Resource):
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True) with Session(db.engine) as session:
app = AppDslService.import_and_create_new_app( import_service = AppDslService(session)
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
) account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
stmt = select(App).where(App.id == result.app.id)
app = session.scalar(stmt)
return app, 201 return app, 201
@ -407,10 +332,6 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, "/apps") api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportDependenciesCheckApi, "/apps/import/dependencies/check")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppImportFromUrlDependenciesCheckApi, "/apps/import/url/dependencies/check")
api.add_resource(AppApi, "/apps/<uuid:app_id>") api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")

@ -0,0 +1,90 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from services.app_dsl_service import AppDslService, ImportStatus
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
import pytz import pytz
from flask_login import current_user from flask_login import current_user
@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if not conversation.read_at: if not conversation.read_at:
conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_account_id = current_user.id conversation.read_account_id = current_user.id
db.session.commit() db.session.commit()

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
@ -75,7 +75,7 @@ class AppSite(Resource):
setattr(site, attr_name, value) setattr(site, attr_name, value)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site

@ -21,7 +21,6 @@ from libs.login import current_user, login_required
from models import App from models import App
from models.account import Account from models.account import Account
from models.model import AppMode from models.model import AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -130,34 +129,6 @@ class DraftWorkflowApi(Resource):
} }
class DraftWorkflowImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def post(self, app_model: App):
"""
Import draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user
)
return workflow
class AdvancedChatDraftWorkflowRunApi(Resource): class AdvancedChatDraftWorkflowRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -490,7 +461,6 @@ class ConvertToWorkflowApi(Resource):
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")

@ -65,7 +65,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional from typing import Optional
import requests import requests
@ -108,7 +108,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable": if action == "enable":
if data_source_binding.disabled: if data_source_binding.disabled:
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable": if action == "disable":
if not data_source_binding.disabled: if not data_source_binding.disabled:
data_source_binding.disabled = True data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:

@ -1,6 +1,6 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from datetime import datetime, timezone from datetime import UTC, datetime
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@ -681,7 +681,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.") raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) document.paused_at = datetime.now(UTC).replace(tzinfo=None)
document.is_paused = True document.is_paused = True
db.session.commit() db.session.commit()
@ -761,7 +761,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value document.doc_metadata[key] = value
document.doc_type = doc_type document.doc_type = doc_type
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200 return {"result": "success", "message": "Document metadata updated."}, 200
@ -803,7 +803,7 @@ class DocumentStatusApi(DocumentResource):
document.enabled = True document.enabled = True
document.disabled_at = None document.disabled_at = None
document.disabled_by = None document.disabled_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -820,9 +820,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already disabled.") raise InvalidActionError("Document already disabled.")
document.enabled = False document.enabled = False
document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id document.disabled_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -837,9 +837,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already archived.") raise InvalidActionError("Document already archived.")
document.archived = True document.archived = True
document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id document.archived_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
if document.enabled: if document.enabled:
@ -856,7 +856,7 @@ class DocumentStatusApi(DocumentResource):
document.archived = False document.archived = False
document.archived_at = None document.archived_at = None
document.archived_by = None document.archived_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times

@ -1,5 +1,5 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
import pandas as pd import pandas as pd
from flask import request from flask import request
@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError("Segment is already disabled.") raise InvalidActionError("Segment is already disabled.")
segment.enabled = False segment.enabled = False
segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id segment.disabled_by = current_user.id
db.session.commit() db.session.commit()

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import reqparse from flask_restful import reqparse
@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse from flask_restful import Resource, inputs, marshal_with, reqparse
@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), last_used_at=datetime.now(UTC).replace(tzinfo=None),
) )
db.session.add(new_installed_app) db.session.add(new_installed_app)
db.session.commit() db.session.commit()

@ -60,7 +60,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError() raise InvalidInvitationCodeError()
invitation_code.status = "used" invitation_code.status = "used"
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id invitation_code.used_by_account_id = account.id
@ -68,7 +68,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = "active" account.status = "active"
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} return {"result": "success"}

@ -305,6 +305,20 @@ class PluginDeleteInstallTaskApi(Resource):
raise ValueError(e) raise ValueError(e)
class PluginDeleteAllInstallTaskItemsApi(Resource):
@setup_required
@login_required
@account_initialization_required
@plugin_permission_required(debug_required=True)
def post(self):
tenant_id = current_user.current_tenant_id
try:
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
except PluginDaemonBadRequestError as e:
raise ValueError(e)
class PluginDeleteInstallTaskItemApi(Resource): class PluginDeleteInstallTaskItemApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -453,6 +467,7 @@ api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manif
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks") api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>") api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete") api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>") api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall") api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return api_token return api_token

@ -1,3 +1,4 @@
import uuid
from typing import Optional from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity

@ -12,7 +12,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
@ -39,27 +39,23 @@ class ModelConfigConverter:
) )
if model_credentials is None: if model_credentials is None:
if not skip_check: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else: # check model
model_credentials = {} provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
if not skip_check: )
# check model
provider_model = provider_model_bundle.configuration.get_provider_model( if provider_model is None:
model=model_config.model, model_type=ModelType.LLM model_name = model_config.model
) raise ValueError(f"Model {model_name} not exist.")
if provider_model is None: if provider_model.status == ModelStatus.NO_CONFIGURE:
model_name = model_config.model raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ValueError(f"Model {model_name} not exist.") elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
if provider_model.status == ModelStatus.NO_CONFIGURE: elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters
@ -77,7 +73,7 @@ class ModelConfigConverter:
if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value
if not skip_check and not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity( return ModelConfigWithCredentialsEntity(

@ -1,4 +1,5 @@
from core.app.app_config.entities import ( from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity, AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity, PromptTemplateEntity,
@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
chat_prompt_messages = [] chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []): for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append( chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
) )
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

@ -1,5 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum): class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input" TEXT_INPUT = "text-input"
SELECT = "select" SELECT = "select"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"

@ -138,7 +138,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -139,7 +139,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -1,5 +1,5 @@
import json import json
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
@ -7,7 +7,7 @@ from core.file import File, FileUploadConfig
from factories import file_factory from factories import file_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import VariableEntity
class BaseAppGenerator: class BaseAppGenerator:
@ -15,23 +15,23 @@ class BaseAppGenerator:
self, self,
*, *,
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig", variables: Sequence["VariableEntity"],
tenant_id: str,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = { user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables for var in variables
} }
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File # Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables} entity_dictionary = {item.variable: item for item in variables}
# Convert single file to File # Convert single file to File
files_inputs = { files_inputs = {
k: file_factory.build_from_mapping( k: file_factory.build_from_mapping(
mapping=v, mapping=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@ -45,7 +45,7 @@ class BaseAppGenerator:
file_list_inputs = { file_list_inputs = {
k: file_factory.build_from_mappings( k: file_factory.build_from_mappings(
mappings=v, mappings=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,

@ -142,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -123,7 +123,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.id
),
query=query, query=query,
files=file_objs, files=file_objs,
user_id=user.id, user_id=user.id,

@ -1,7 +1,7 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import and_ from sqlalchemy import and_
@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit() db.session.commit()
db.session.refresh(conversation) db.session.refresh(conversation)
else: else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
message = Message( message = Message(

@ -108,7 +108,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files, files=system_files,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,

@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
user_inputs=user_inputs, user_inputs=user_inputs,
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
) )
return graph, variable_pool return graph, variable_pool

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum): class QueueEvent(StrEnum):
""" """
QueueEvent enum QueueEvent enum
""" """

@ -1,7 +1,7 @@
import json import json
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -144,7 +144,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
@ -191,7 +191,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -211,15 +211,18 @@ class WorkflowCycleManage:
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds() ).total_seconds()
db.session.commit() db.session.commit()
db.session.refresh(workflow_run)
db.session.close() db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
@ -259,7 +262,7 @@ class WorkflowCycleManage:
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
} }
) )
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
session.commit() session.commit()
@ -282,7 +285,7 @@ class WorkflowCycleManage:
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
) )
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@ -326,7 +329,7 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -654,7 +657,7 @@ class WorkflowCycleManage:
if event.error is None if event.error is None
else WorkflowNodeExecutionStatus.FAILED, else WorkflowNodeExecutionStatus.FAILED,
error=None, error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata, execution_metadata=event.metadata,
finished_at=int(time.time()), finished_at=int(time.time()),

@ -246,7 +246,7 @@ class ProviderConfiguration(BaseModel):
if provider_record: if provider_record:
provider_record.encrypted_config = json.dumps(credentials) provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_record = Provider() provider_record = Provider()
@ -401,7 +401,7 @@ class ProviderConfiguration(BaseModel):
if provider_model_record: if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_model_record = ProviderModel() provider_model_record = ProviderModel()
@ -474,7 +474,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = True model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -508,7 +508,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = False model_setting.enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -574,7 +574,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = True model_setting.load_balancing_enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()
@ -608,7 +608,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = False model_setting.load_balancing_enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting() model_setting = ProviderModelSetting()

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class FileType(str, Enum): class FileType(StrEnum):
IMAGE = "image" IMAGE = "image"
DOCUMENT = "document" DOCUMENT = "document"
AUDIO = "audio" AUDIO = "audio"
@ -16,7 +16,7 @@ class FileType(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(str, Enum): class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url" REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file" LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file" TOOL_FILE = "tool_file"
@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(str, Enum): class FileBelongsTo(StrEnum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(str, Enum): class FileAttribute(StrEnum):
TYPE = "type" TYPE = "type"
SIZE = "size" SIZE = "size"
NAME = "name" NAME = "name"
@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
EXTENSION = "extension" EXTENSION = "extension"
class ArrayFileAttribute(str, Enum): class ArrayFileAttribute(StrEnum):
LENGTH = "length" LENGTH = "length"

@ -3,7 +3,12 @@ import base64
from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
return file.remote_url return file.remote_url
case FileAttribute.EXTENSION: case FileAttribute.EXTENSION:
return file.extension return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content( def to_prompt_message_content(
f: File, f: File,
/, /,
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
): ):
"""
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.
Args:
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns:
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises:
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f) data = _to_url(f)
else: else:
@ -65,7 +52,7 @@ def to_prompt_message_content(
return ImagePromptMessageContent(data=data, detail=image_detail_config) return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f) encoded_string = _get_encoded_string(f)
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@ -74,9 +61,20 @@ def to_prompt_message_content(
data = _to_url(f) data = _to_url(f)
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _: case _:
raise ValueError("file type f.type is not supported") raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /): def download(f: File, /):
@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content data = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
case FileTransferMethod.LOCAL_FILE: case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f) upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key) data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f) tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key) data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string encoded_string = base64.b64encode(data).decode("utf-8")
case _: return encoded_string
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _to_base64_data_string(f: File, /): def _to_base64_data_string(f: File, /):
@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
return f"data:{f.mime_type};base64,{encoded_string}" return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /): def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None: if f.remote_url is None:

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import StrEnum
from threading import Lock from threading import Lock
from typing import Any, Optional from typing import Any, Optional
@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel):
data: Data data: Data
class CodeLanguage(str, Enum): class CodeLanguage(StrEnum):
PYTHON3 = "python3" PYTHON3 = "python3"
JINJA2 = "jinja2" JINJA2 = "jinja2"
JAVASCRIPT = "javascript" JAVASCRIPT = "javascript"

@ -86,7 +86,7 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except ObjectDeletedError: except ObjectDeletedError:
logging.warning("Document deleted, document id: {}".format(dataset_document.id)) logging.warning("Document deleted, document id: {}".format(dataset_document.id))
@ -94,7 +94,7 @@ class IndexingRunner:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument): def run_in_splitting_status(self, dataset_document: DatasetDocument):
@ -142,13 +142,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument): def run_in_indexing_status(self, dataset_document: DatasetDocument):
@ -200,13 +200,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def indexing_estimate( def indexing_estimate(
@ -372,7 +372,7 @@ class IndexingRunner:
after_indexing_status="splitting", after_indexing_status="splitting",
extra_update_params={ extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -464,7 +464,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -479,7 +479,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -680,7 +680,7 @@ class IndexingRunner:
after_indexing_status="completed", after_indexing_status="completed",
extra_update_params={ extra_update_params={
DatasetDocument.tokens: tokens, DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None, DatasetDocument.error: None,
}, },
@ -705,7 +705,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -738,7 +738,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -849,7 +849,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -864,7 +864,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
pass pass

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -27,7 +27,7 @@ class TokenBufferMemory:
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
:param max_token_limit: max token limit :param max_token_limit: max token limit
@ -102,12 +102,11 @@ class TokenBufferMemory:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs: for file in file_objs:
if file.type in {FileType.IMAGE, FileType.AUDIO}: prompt_message = file_manager.to_prompt_message_content(
prompt_message = file_manager.to_prompt_message_content( file,
file, image_detail_config=detail,
image_detail_config=detail, )
) prompt_message_contents.append(prompt_message)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))

@ -136,10 +136,10 @@ class ModelInstance:
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
): ):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:

@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
from .message_entities import ( from .message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
AudioPromptMessageContent, AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
@ -37,4 +38,5 @@ __all__ = [
"LLMResultChunk", "LLMResultChunk",
"LLMResultChunkDelta", "LLMResultChunkDelta",
"AudioPromptMessageContent", "AudioPromptMessageContent",
"DocumentPromptMessageContent",
] ]

@ -1,6 +1,7 @@
from abc import ABC from abc import ABC
from enum import Enum from collections.abc import Sequence
from typing import Optional from enum import Enum, StrEnum
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
function: PromptMessageTool function: PromptMessageTool
class PromptMessageContentType(Enum): class PromptMessageContentType(StrEnum):
""" """
Enum class for prompt message content type. Enum class for prompt message content type.
""" """
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video" VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
Model class for image prompt message content. Model class for image prompt message content.
""" """
class DETAIL(str, Enum): class DETAIL(StrEnum):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"
@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel): class PromptMessage(ABC, BaseModel):
""" """
Model class for prompt message. Model class for prompt message.
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

@ -1,5 +1,5 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -82,9 +82,12 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call" STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum): class DefaultParameterName(StrEnum):
""" """
Enum class for parameter template variable. Enum class for parameter template variable.
""" """

@ -1,6 +1,6 @@
import logging import logging
import time import time
from collections.abc import Generator from collections.abc import Generator, Sequence
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
@ -41,7 +41,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -96,7 +96,7 @@ class LargeLanguageModel(AIModel):
model_parameters=model_parameters, model_parameters=model_parameters,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
tools=tools, tools=tools,
stop=stop, stop=list(stop) if stop else None,
stream=stream, stream=stream,
) )
@ -176,7 +176,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -318,7 +318,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -364,7 +364,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -411,7 +411,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -459,7 +459,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
@ -122,7 +122,7 @@ trace_info_info_map = {
} }
class TraceTaskName(str, Enum): class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation" CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow" WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message" MESSAGE_TRACE = "message"

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator from pydantic import BaseModel, ConfigDict, Field, field_validator
@ -39,7 +39,7 @@ def validate_input_output(v, field_name):
return v return v
class LevelEnum(str, Enum): class LevelEnum(StrEnum):
DEBUG = "DEBUG" DEBUG = "DEBUG"
WARNING = "WARNING" WARNING = "WARNING"
ERROR = "ERROR" ERROR = "ERROR"
@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel):
return validate_input_output(v, field_name) return validate_input_output(v, field_name)
class UnitEnum(str, Enum): class UnitEnum(StrEnum):
CHARACTERS = "CHARACTERS" CHARACTERS = "CHARACTERS"
TOKENS = "TOKENS" TOKENS = "TOKENS"
SECONDS = "SECONDS" SECONDS = "SECONDS"

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo
from core.ops.utils import replace_text_with_content from core.ops.utils import replace_text_with_content
class LangSmithRunType(str, Enum): class LangSmithRunType(StrEnum):
tool = "tool" tool = "tool"
chain = "chain" chain = "chain"
llm = "llm" llm = "llm"

@ -1,4 +1,4 @@
from enum import Enum from enum import StrEnum
from pydantic import BaseModel from pydantic import BaseModel
@ -6,7 +6,7 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginInstallationSou
class PluginBundleDependency(BaseModel): class PluginBundleDependency(BaseModel):
class Type(str, Enum): class Type(StrEnum):
Github = PluginInstallationSource.Github.value Github = PluginInstallationSource.Github.value
Marketplace = PluginInstallationSource.Marketplace.value Marketplace = PluginInstallationSource.Marketplace.value
Package = PluginInstallationSource.Package.value Package = PluginInstallationSource.Package.value

@ -1,7 +1,7 @@
import datetime import datetime
import enum
import re import re
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
@ -13,7 +13,7 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntity from core.tools.entities.tool_entities import ToolProviderEntity
class PluginInstallationSource(str, Enum): class PluginInstallationSource(enum.StrEnum):
Github = "github" Github = "github"
Marketplace = "marketplace" Marketplace = "marketplace"
Package = "package" Package = "package"
@ -55,7 +55,7 @@ class PluginResourceRequirements(BaseModel):
permission: Optional[Permission] permission: Optional[Permission]
class PluginCategory(str, Enum): class PluginCategory(enum.StrEnum):
Tool = "tool" Tool = "tool"
Model = "model" Model = "model"
Extension = "extension" Extension = "extension"
@ -163,7 +163,7 @@ class GenericProviderID:
class PluginDependency(BaseModel): class PluginDependency(BaseModel):
class Type(str, Enum): class Type(enum.StrEnum):
Github = PluginInstallationSource.Github.value Github = PluginInstallationSource.Github.value
Marketplace = PluginInstallationSource.Marketplace.value Marketplace = PluginInstallationSource.Marketplace.value
Package = PluginInstallationSource.Package.value Package = PluginInstallationSource.Package.value

@ -1,3 +1,4 @@
import enum
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Generic, Optional, TypeVar from typing import Generic, Optional, TypeVar
@ -119,7 +120,7 @@ class PluginDaemonInnerError(Exception):
self.message = message self.message = message
class PluginInstallTaskStatus(str, Enum): class PluginInstallTaskStatus(enum.StrEnum):
Pending = "pending" Pending = "pending"
Running = "running" Running = "running"
Success = "success" Success = "success"

@ -126,6 +126,16 @@ class PluginInstallationManager(BasePluginManager):
bool, bool,
) )
def delete_all_plugin_installation_task_items(self, tenant_id: str) -> bool:
"""
Delete all plugin installation task items.
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/management/install/tasks/delete_all",
bool,
)
def delete_plugin_installation_task_item(self, tenant_id: str, task_id: str, identifier: str) -> bool: def delete_plugin_installation_task_item(self, tenant_id: str, task_id: str, identifier: str) -> bool:
""" """
Delete a plugin installation task item. Delete a plugin installation task item.

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
class ModelMode(str, enum.Enum): class ModelMode(enum.StrEnum):
COMPLETION = "completion" COMPLETION = "completion"
CHAT = "chat" CHAT = "chat"

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import cast from typing import cast
from core.model_runtime.entities import ( from core.model_runtime.entities import (
@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode
class PromptMessageUtil: class PromptMessageUtil:
@staticmethod @staticmethod
def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]:
""" """
Prompt messages to prompt for saving. Prompt messages to prompt for saving.
:param model_mode: model mode :param model_mode: model mode

@ -12,7 +12,7 @@ class CleanProcessor:
# Unicode U+FFFE # Unicode U+FFFE
text = re.sub("\ufffe", "", text) text = re.sub("\ufffe", "", text)
rules = process_rule["rules"] if process_rule else None rules = process_rule["rules"] if process_rule else {}
if "pre_processing_rules" in rules: if "pre_processing_rules" in rules:
pre_processing_rules = rules["pre_processing_rules"] pre_processing_rules = rules["pre_processing_rules"]
for pre_processing_rule in pre_processing_rules: for pre_processing_rule in pre_processing_rules:

@ -1,5 +1,5 @@
from enum import Enum from enum import StrEnum
class KeyWordType(str, Enum): class KeyWordType(StrEnum):
JIEBA = "jieba" JIEBA = "jieba"

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class VectorType(str, Enum): class VectorType(StrEnum):
ANALYTICDB = "analyticdb" ANALYTICDB = "analyticdb"
CHROMA = "chroma" CHROMA = "chroma"
MILVUS = "milvus" MILVUS = "milvus"

@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor):
mime_type=mime_type or "", mime_type=mime_type or "",
created_by=self.user_id, created_by=self.user_id,
created_by_role=CreatedByRole.ACCOUNT, created_by_role=CreatedByRole.ACCOUNT,
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
used=True, used=True,
used_by=self.user_id, used_by=self.user_id,
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
) )
db.session.add(upload_file) db.session.add(upload_file)

@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner):
:return: :return:
""" """
docs = [] docs = []
doc_id = set() doc_ids = set()
unique_documents = [] unique_documents = []
for document in documents: for document in documents:
if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids:
doc_id.add(document.metadata["doc_id"]) doc_ids.add(document.metadata["doc_id"])
docs.append(document.page_content) docs.append(document.page_content)
unique_documents.append(document) unique_documents.append(document)
elif document.provider == "external": elif document.provider == "external":

@ -1,6 +1,6 @@
from enum import Enum from enum import StrEnum
class RerankMode(str, Enum): class RerankMode(StrEnum):
RERANKING_MODEL = "reranking_model" RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score" WEIGHTED_SCORE = "weighted_score"

@ -37,11 +37,10 @@ class WeightRerankRunner(BaseRerankRunner):
:return: :return:
""" """
unique_documents = [] unique_documents = []
doc_id = set() doc_ids = set()
for document in documents: for document in documents:
doc_id = document.metadata.get("doc_id") if document.metadata["doc_id"] not in doc_ids:
if doc_id not in doc_id: doc_ids.add(document.metadata["doc_id"])
doc_id.add(doc_id)
unique_documents.append(document) unique_documents.append(document)
documents = unique_documents documents = unique_documents

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pytz import timezone as pytz_timezone from pytz import timezone as pytz_timezone
@ -23,7 +23,7 @@ class CurrentTimeTool(BuiltinTool):
tz = tool_parameters.get("timezone", "UTC") tz = tool_parameters.get("timezone", "UTC")
fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z"
if tz == "UTC": if tz == "UTC":
return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}")
try: try:
tz = pytz_timezone(tz) tz = pytz_timezone(tz)

@ -1,4 +1,5 @@
import base64 import base64
import enum
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -33,7 +34,7 @@ class ToolLabelEnum(Enum):
OTHER = "other" OTHER = "other"
class ToolProviderType(str, Enum): class ToolProviderType(enum.StrEnum):
""" """
Enum class for tool provider Enum class for tool provider
""" """
@ -205,7 +206,7 @@ class ToolParameterOption(BaseModel):
class ToolParameter(BaseModel): class ToolParameter(BaseModel):
class ToolParameterType(str, Enum): class ToolParameterType(enum.StrEnum):
STRING = CommonParameterType.STRING.value STRING = CommonParameterType.STRING.value
NUMBER = CommonParameterType.NUMBER.value NUMBER = CommonParameterType.NUMBER.value
BOOLEAN = CommonParameterType.BOOLEAN.value BOOLEAN = CommonParameterType.BOOLEAN.value

@ -1,7 +1,7 @@
import json import json
from collections.abc import Generator, Iterable from collections.abc import Generator, Iterable
from copy import deepcopy from copy import deepcopy
from datetime import datetime, timezone from datetime import UTC, datetime
from mimetypes import guess_type from mimetypes import guess_type
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
@ -64,7 +64,12 @@ class ToolEngine:
if parameters and len(parameters) == 1: if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters} tool_parameters = {parameters[0].name: tool_parameters}
else: else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") try:
tool_parameters = json.loads(tool_parameters)
except Exception as e:
pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool # invoke the tool
try: try:
@ -195,10 +200,7 @@ class ToolEngine:
""" """
Invoke the tool with the given arguments. Invoke the tool with the given arguments.
""" """
if not tool.runtime: started_at = datetime.now(UTC)
raise ValueError("missing runtime in tool")
started_at = datetime.now(timezone.utc)
meta = ToolInvokeMeta( meta = ToolInvokeMeta(
time_cost=0.0, time_cost=0.0,
error=None, error=None,
@ -216,7 +218,7 @@ class ToolEngine:
meta.error = str(e) meta.error = str(e)
raise ToolEngineInvokeError(meta) raise ToolEngineInvokeError(meta)
finally: finally:
ended_at = datetime.now(timezone.utc) ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds() meta.time_cost = (ended_at - started_at).total_seconds()
yield meta yield meta

@ -118,11 +118,11 @@ class FileSegment(Segment):
@property @property
def log(self) -> str: def log(self) -> str:
return str(self.value) return ""
@property @property
def text(self) -> str: def text(self) -> str:
return str(self.value) return ""
class ArrayAnySegment(ArraySegment): class ArrayAnySegment(ArraySegment):
@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment):
for item in self.value: for item in self.value:
items.append(item.markdown) items.append(item.markdown)
return "\n".join(items) return "\n".join(items)
@property
def log(self) -> str:
return ""
@property
def text(self) -> str:
return ""

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class SegmentType(str, Enum): class SegmentType(StrEnum):
NONE = "none" NONE = "none"
NUMBER = "number" NUMBER = "number"
STRING = "string" STRING = "string"

@ -1,5 +1,5 @@
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
class NodeRunMetadataKey(str, Enum): class NodeRunMetadataKey(StrEnum):
""" """
Node Run Metadata Key. Node Run Metadata Key.
""" """
@ -36,7 +36,7 @@ class NodeRunResult(BaseModel):
inputs: Optional[Mapping[str, Any]] = None # node inputs inputs: Optional[Mapping[str, Any]] = None # node inputs
process_data: Optional[dict[str, Any]] = None # process data process_data: Optional[dict[str, Any]] = None # process data
outputs: Optional[dict[str, Any]] = None # node outputs outputs: Optional[Mapping[str, Any]] = None # node outputs
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
llm_usage: Optional[LLMUsage] = None # llm usage llm_usage: Optional[LLMUsage] = None # llm usage

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class SystemVariableKey(str, Enum): class SystemVariableKey(StrEnum):
""" """
System Variables. System Variables.
""" """

@ -1,5 +1,5 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
@ -63,7 +63,7 @@ class RouteNodeState(BaseModel):
raise Exception(f"Invalid route status {run_result.status}") raise Exception(f"Invalid route status {run_result.status}")
self.node_run_result = run_result self.node_run_result = run_result
self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) self.finished_at = datetime.now(UTC).replace(tzinfo=None)
class RuntimeRouteState(BaseModel): class RuntimeRouteState(BaseModel):
@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel):
:param node_id: node id :param node_id: node id
""" """
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
self.node_state_mapping[state.id] = state self.node_state_mapping[state.id] = state
return state return state

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class NodeType(str, Enum): class NodeType(StrEnum):
START = "start" START = "start"
END = "end" END = "end"
ANSWER = "answer" ANSWER = "answer"

@ -108,7 +108,7 @@ class Executor:
self.content = self.variable_pool.convert_template(data[0].value).text self.content = self.variable_pool.convert_template(data[0].value).text
case "json": case "json":
json_string = self.variable_pool.convert_template(data[0].value).text json_string = self.variable_pool.convert_template(data[0].value).text
json_object = json.loads(json_string) json_object = json.loads(json_string, strict=False)
self.json = json_object self.json = json_object
# self.json = self._parse_object_contains_variables(json_object) # self.json = self._parse_object_contains_variables(json_object)
case "binary": case "binary":

@ -1,4 +1,4 @@
from enum import Enum from enum import StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import Field from pydantic import Field
@ -6,7 +6,7 @@ from pydantic import Field
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
class ErrorHandleMode(str, Enum): class ErrorHandleMode(StrEnum):
TERMINATED = "terminated" TERMINATED = "terminated"
CONTINUE_ON_ERROR = "continue-on-error" CONTINUE_ON_ERROR = "continue-on-error"
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"

@ -2,7 +2,7 @@ import logging
import uuid import uuid
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait from concurrent.futures import Future, wait
from datetime import datetime, timezone from datetime import UTC, datetime
from queue import Empty, Queue from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
@ -135,7 +135,7 @@ class IterationNode(BaseNode[IterationNodeData]):
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
start_at = datetime.now(timezone.utc).replace(tzinfo=None) start_at = datetime.now(UTC).replace(tzinfo=None)
yield IterationRunStartedEvent( yield IterationRunStartedEvent(
iteration_id=self.id, iteration_id=self.id,
@ -367,7 +367,7 @@ class IterationNode(BaseNode[IterationNodeData]):
""" """
run single iteration run single iteration
""" """
iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None) iter_start_at = datetime.now(UTC).replace(tzinfo=None)
try: try:
rst = graph_engine.run() rst = graph_engine.run()
@ -440,7 +440,7 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "index"], next_index) variable_pool.add([self.node_id, "index"], next_index)
if next_index < len(iterator_list_value): if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
@ -461,7 +461,7 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value): if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,
@ -503,7 +503,7 @@ class IterationNode(BaseNode[IterationNodeData]):
if next_index < len(iterator_list_value): if next_index < len(iterator_list_value):
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
iter_run_map[iteration_run_id] = duration iter_run_map[iteration_run_id] = duration
yield IterationRunNextEvent( yield IterationRunNextEvent(
iteration_id=self.id, iteration_id=self.id,

@ -39,7 +39,14 @@ class VisionConfig(BaseModel):
class PromptConfig(BaseModel): class PromptConfig(BaseModel):
jinja2_variables: Optional[list[VariableSelector]] = None jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list)
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
if v is None:
return []
return v
class LLMNodeChatModelMessage(ChatModelMessage): class LLMNodeChatModelMessage(ChatModelMessage):
@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
class LLMNodeData(BaseNodeData): class LLMNodeData(BaseNodeData):
model: ModelConfig model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: Optional[PromptConfig] = None prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: Optional[MemoryConfig] = None memory: Optional[MemoryConfig] = None
context: ContextConfig context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig) vision: VisionConfig = Field(default_factory=VisionConfig)
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
if v is None:
return PromptConfig()
return v

@ -24,3 +24,17 @@ class LLMModeRequiredError(LLMNodeError):
class NoPromptFoundError(LLMNodeError): class NoPromptFoundError(LLMNodeError):
"""Raised when no prompt is found in the LLM configuration.""" """Raised when no prompt is found in the LLM configuration."""
class TemplateTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"Prompt type {type_name} is not supported.")
class MemoryRolePrefixRequiredError(LLMNodeError):
"""Raised when memory role prefix is required for completion model."""
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")

@ -1,4 +1,5 @@
import json import json
import logging
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
@ -6,21 +7,26 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.model_entities import ModelStatus from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import QuotaUnit from core.entities.provider_entities import QuotaUnit
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AudioPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
TextPromptMessageContent, TextPromptMessageContent,
VideoPromptMessageContent,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import ( from core.variables import (
@ -34,6 +40,8 @@ from core.variables import (
) )
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -58,18 +66,23 @@ from .entities import (
ModelConfig, ModelConfig,
) )
from .exc import ( from .exc import (
FileTypeNotSupportError,
InvalidContextStructureError, InvalidContextStructureError,
InvalidVariableTypeError, InvalidVariableTypeError,
LLMModeRequiredError, LLMModeRequiredError,
LLMNodeError, LLMNodeError,
MemoryRolePrefixRequiredError,
ModelNotExistError, ModelNotExistError,
NoPromptFoundError, NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError, VariableNotFoundError,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from core.file.models import File from core.file.models import File
logger = logging.getLogger(__name__)
class LLMNode(BaseNode[LLMNodeData]): class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData _node_data_cls = LLMNodeData
@ -121,19 +134,19 @@ class LLMNode(BaseNode[LLMNodeData]):
# fetch memory # fetch memory
memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
# fetch prompt messages query = None
if self.node_data.memory: if self.node_data.memory:
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) query = self.node_data.memory.query_prompt_template
if not query: if query is None and (
raise VariableNotFoundError("Query not found") query_variable := self.graph_runtime_state.variable_pool.get(
query = query.text (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
else: )
query = None ):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
system_query=query, user_query=query,
inputs=inputs, user_files=files,
files=files,
context=context, context=context,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
@ -141,6 +154,8 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config=self.node_data.memory, memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled, vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail, vision_detail=self.node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
) )
process_data = { process_data = {
@ -181,6 +196,17 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
) )
return return
except Exception as e:
logger.exception(f"Node {self.node_id} failed to run")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(e),
inputs=node_inputs,
process_data=process_data,
)
)
return
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
@ -203,8 +229,8 @@ class LLMNode(BaseNode[LLMNodeData]):
self, self,
node_data_model: ModelConfig, node_data_model: ModelConfig,
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]: ) -> Generator[NodeEvent, None, None]:
db.session.close() db.session.close()
@ -519,9 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages( def _fetch_prompt_messages(
self, self,
*, *,
system_query: str | None = None, user_query: str | None = None,
inputs: dict[str, str] | None = None, user_files: Sequence["File"],
files: Sequence["File"],
context: str | None = None, context: str | None = None,
memory: TokenBufferMemory | None = None, memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity, model_config: ModelConfigWithCredentialsEntity,
@ -529,58 +554,144 @@ class LLMNode(BaseNode[LLMNodeData]):
memory_config: MemoryConfig | None = None, memory_config: MemoryConfig | None = None,
vision_enabled: bool = False, vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL, vision_detail: ImagePromptMessageContent.DETAIL,
) -> tuple[list[PromptMessage], Optional[list[str]]]: variable_pool: VariablePool,
inputs = inputs or {} jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) prompt_messages = []
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template, if isinstance(prompt_template, list):
inputs=inputs, # For chat model
query=system_query or "", prompt_messages.extend(
files=files, _handle_list_messages(
context=context, messages=prompt_template,
memory_config=memory_config, context=context,
memory=memory, jinja2_variables=jinja2_variables,
model_config=model_config, variable_pool=variable_pool,
) vision_detail_config=vision_detail,
stop = model_config.stop )
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
message = LLMNodeChatModelMessage(
text=user_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_config=model_config,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
# Add current query to the prompt message
if user_query:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_messages[0].content = prompt_content
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
file_prompts = []
for file in user_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
filtered_prompt_messages = [] filtered_prompt_messages = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if prompt_message.is_empty(): if isinstance(prompt_message.content, list):
continue
if not isinstance(prompt_message.content, str):
prompt_message_content = [] prompt_message_content = []
for content_item in prompt_message.content or []: for content_item in prompt_message.content:
# Skip image if vision is disabled # Skip content if features are not defined
if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue continue
if isinstance(content_item, ImagePromptMessageContent): # Skip content if corresponding feature is not supported
# Override vision config if LLM node has vision config, if (
# cuz vision detail is related to the configuration from FileUpload feature. (
content_item.detail = vision_detail content_item.type == PromptMessageContentType.IMAGE
prompt_message_content.append(content_item) and ModelFeature.VISION not in model_config.model_schema.features
elif isinstance( )
content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
)
): ):
prompt_message_content.append(content_item) raise FileTypeNotSupportError(type_name=content_item.type)
prompt_message_content.append(content_item)
if len(prompt_message_content) > 1: if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content
elif (
len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT
):
prompt_message.content = prompt_message_content[0].data prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message) filtered_prompt_messages.append(prompt_message)
if not filtered_prompt_messages: if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError( raise NoPromptFoundError(
"No prompt found in the LLM configuration. " "No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding." "Please ensure a prompt is properly configured before proceeding."
) )
stop = model_config.stop
return filtered_prompt_messages, stop return filtered_prompt_messages, stop
@classmethod @classmethod
@ -715,3 +826,198 @@ class LLMNode(BaseNode[LLMNodeData]):
} }
}, },
} }
def _combine_text_message_with_role(*, text: str, role: PromptMessageRole):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)])
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=[TextPromptMessageContent(data=text)])
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinjia2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinjia2_inputs = {}
for jinja2_variable in jinjia2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinjia2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_text_message_with_role(text=result_text, role=message.role)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = UserPromptMessage(content=file_contents)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
memory_messages = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = memory.get_history_prompt_text(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER)
prompt_messages.append(prompt_message)
return prompt_messages

@ -89,12 +89,14 @@ class QuestionClassifierNode(LLMNode):
) )
prompt_messages, stop = self._fetch_prompt_messages( prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template, prompt_template=prompt_template,
system_query=query, user_query=query,
memory=memory, memory=memory,
model_config=model_config, model_config=model_config,
files=files, user_files=files,
vision_enabled=node_data.vision.enabled, vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail, vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
) )
# handle invoke result # handle invoke result

@ -1,11 +1,11 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum from enum import StrEnum
from typing import Optional from typing import Optional
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
class WriteMode(str, Enum): class WriteMode(StrEnum):
OVER_WRITE = "over-write" OVER_WRITE = "over-write"
APPEND = "append" APPEND = "append"
CLEAR = "clear" CLEAR = "clear"

@ -5,10 +5,9 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.app_config.entities import FileUploadConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, ImageConfig from core.file.models import File
from core.workflow.callbacks import WorkflowCallback from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@ -18,9 +17,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode, BaseNodeData from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory from factories import file_factory
from models.enums import UserFrom from models.enums import UserFrom
@ -115,7 +113,12 @@ class WorkflowEntry:
@classmethod @classmethod
def single_step_run( def single_step_run(
cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict cls,
*,
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: dict,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
""" """
Single step run workflow node Single step run workflow node
@ -135,13 +138,9 @@ class WorkflowEntry:
raise ValueError("nodes not found in workflow graph") raise ValueError("nodes not found in workflow graph")
# fetch node config from node id # fetch node config from node id
node_config = None try:
for node in nodes: node_config = next(filter(lambda node: node["id"] == node_id, nodes))
if node.get("id") == node_id: except StopIteration:
node_config = node
break
if not node_config:
raise ValueError("node id not found in workflow graph") raise ValueError("node id not found in workflow graph")
# Get node class # Get node class
@ -153,11 +152,7 @@ class WorkflowEntry:
raise ValueError(f"Node class not found for node type {node_type}") raise ValueError(f"Node class not found for node type {node_type}")
# init variable pool # init variable pool
variable_pool = VariablePool( variable_pool = VariablePool(environment_variables=workflow.environment_variables)
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
# init graph # init graph
graph = Graph.init(graph_config=workflow.graph_dict) graph = Graph.init(graph_config=workflow.graph_dict)
@ -183,28 +178,24 @@ class WorkflowEntry:
try: try:
# variable selector to variable mapping # variable selector to variable mapping
try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=node_config
graph_config=workflow.graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
) )
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
try:
# run node # run node
generator = node_instance.run() generator = node_instance.run()
return node_instance, generator
except Exception as e: except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@classmethod @classmethod
def run_free_node( def run_free_node(
@ -332,12 +323,11 @@ class WorkflowEntry:
@classmethod @classmethod
def mapping_user_inputs_to_variable_pool( def mapping_user_inputs_to_variable_pool(
cls, cls,
*,
variable_mapping: Mapping[str, Sequence[str]], variable_mapping: Mapping[str, Sequence[str]],
user_inputs: dict, user_inputs: dict,
variable_pool: VariablePool, variable_pool: VariablePool,
tenant_id: str, tenant_id: str,
node_type: NodeType,
node_data: BaseNodeData,
) -> None: ) -> None:
for node_variable, variable_selector in variable_mapping.items(): for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable # fetch node id and variable key from node_variable
@ -355,40 +345,21 @@ class WorkflowEntry:
# fetch variable node id from variable selector # fetch variable node id from variable selector
variable_node_id = variable_selector[0] variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:] variable_key_list = variable_selector[1:]
variable_key_list = cast(list[str], variable_key_list) variable_key_list = list(variable_key_list)
# get input value # get input value
input_value = user_inputs.get(node_variable) input_value = user_inputs.get(node_variable)
if not input_value: if not input_value:
input_value = user_inputs.get(node_variable_key) input_value = user_inputs.get(node_variable_key)
# FIXME: temp fix for image type if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
if node_type == NodeType.LLM: input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
new_value = [] if (
if isinstance(input_value, list): isinstance(input_value, list)
node_data = cast(LLMNodeData, node_data) and all(isinstance(item, dict) for item in input_value)
and all("type" in item and "transfer_method" in item for item in input_value)
detail = node_data.vision.configs.detail if node_data.vision.configs else None ):
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
for item in input_value:
if isinstance(item, dict) and "type" in item and item["type"] == "image":
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
mapping = {
"id": item.get("id"),
"transfer_method": transfer_method,
"upload_file_id": item.get("upload_file_id"),
"url": item.get("url"),
}
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
config=config,
)
new_value.append(file)
if new_value:
input_value = new_value
# append variable and value to variable pool # append variable and value to variable pool
variable_pool.add([variable_node_id] + variable_key_list, input_value) variable_pool.add([variable_node_id] + variable_key_list, input_value)

@ -33,7 +33,7 @@ def handle(sender, **kwargs):
raise NotFound("Document not found") raise NotFound("Document not found")
document.indexing_status = "parsing" document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
documents.append(document) documents.append(document)
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created from events.message_event import message_was_created
@ -17,5 +17,5 @@ def handle(sender, **kwargs):
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.tenant_id == application_generate_entity.app_config.tenant_id,
Provider.provider_name == application_generate_entity.model_conf.provider, Provider.provider_name == application_generate_entity.model_conf.provider,
).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
db.session.commit() db.session.commit()

@ -1,5 +1,5 @@
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timedelta, timezone from datetime import UTC, datetime, timedelta
from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas
@ -67,7 +67,7 @@ class AzureBlobStorage(BaseStorage):
account_key=self.account_key, account_key=self.account_key,
resource_types=ResourceTypes(service=True, container=True, object=True), resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
) )
redis_client.set(cache_key, sas_token, ex=3000) redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url, credential=sas_token) return BlobServiceClient(account_url=self.account_url, credential=sas_token)

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class StorageType(str, Enum): class StorageType(StrEnum):
ALIYUN_OSS = "aliyun-oss" ALIYUN_OSS = "aliyun-oss"
AZURE_BLOB = "azure-blob" AZURE_BLOB = "azure-blob"
BAIDU_OBS = "baidu-obs" BAIDU_OBS = "baidu-obs"

@ -1,10 +1,11 @@
import mimetypes import mimetypes
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any from typing import Any, cast
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
@ -71,7 +72,12 @@ def build_from_mapping(
transfer_method=transfer_method, transfer_method=transfer_method,
) )
if not _is_file_valid_with_config(file=file, config=config): if not _is_file_valid_with_config(
input_file_type=mapping.get("type", FileType.CUSTOM),
file_extension=file.extension,
file_transfer_method=file.transfer_method,
config=config,
):
raise ValueError(f"File validation failed for file: {file.filename}") raise ValueError(f"File validation failed for file: {file.filename}")
return file return file
@ -80,12 +86,9 @@ def build_from_mapping(
def build_from_mappings( def build_from_mappings(
*, *,
mappings: Sequence[Mapping[str, Any]], mappings: Sequence[Mapping[str, Any]],
config: FileUploadConfig | None, config: FileUploadConfig | None = None,
tenant_id: str, tenant_id: str,
) -> Sequence[File]: ) -> Sequence[File]:
if not config:
return []
files = [ files = [
build_from_mapping( build_from_mapping(
mapping=mapping, mapping=mapping,
@ -96,13 +99,14 @@ def build_from_mappings(
] ]
if ( if (
config
# If image config is set. # If image config is set.
config.image_config and config.image_config
# And the number of image files exceeds the maximum limit # And the number of image files exceeds the maximum limit
and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
): ):
raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
if config.number_limits and len(files) > config.number_limits: if config and config.number_limits and len(files) > config.number_limits:
raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
return files return files
@ -114,17 +118,18 @@ def _build_from_local_file(
tenant_id: str, tenant_id: str,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
) -> File: ) -> File:
file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where( stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"), UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id, UploadFile.tenant_id == tenant_id,
) )
row = db.session.scalar(stmt) row = db.session.scalar(stmt)
if row is None: if row is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")
file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=row.name, filename=row.name,
@ -152,11 +157,14 @@ def _build_from_remote_url(
mime_type, filename, file_size = _get_remote_file_info(url) mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=filename, filename=filename,
tenant_id=tenant_id, tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")), type=file_type,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=url, remote_url=url,
mime_type=mime_type, mime_type=mime_type,
@ -171,6 +179,7 @@ def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(filename)[0] or "" mime_type = mimetypes.guess_type(filename)[0] or ""
resp = ssrf_proxy.head(url, follow_redirects=True) resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK: if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"): if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"')) filename = str(content_disposition.split("filename=")[-1].strip('"'))
@ -180,20 +189,6 @@ def _get_remote_file_info(url: str):
return mime_type, filename, file_size return mime_type, filename, file_size
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type
def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
@ -213,7 +208,8 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) file_type = FileType(mapping.get("type"))
file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype)
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
@ -229,18 +225,72 @@ def _build_from_tool_file(
) )
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: def _is_file_valid_with_config(
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: *,
input_file_type: str,
file_extension: str,
file_transfer_method: FileTransferMethod,
config: FileUploadConfig,
) -> bool:
if (
config.allowed_file_types
and input_file_type not in config.allowed_file_types
and input_file_type != FileType.CUSTOM
):
return False return False
if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions: if (
input_file_type == FileType.CUSTOM
and config.allowed_file_extensions is not None
and file_extension not in config.allowed_file_extensions
):
return False return False
if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods: if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods:
return False return False
if file.type == FileType.IMAGE and config.image_config: if input_file_type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods:
return False return False
return True return True
def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType:
"""
If custom type, try to guess the file type by extension and mime_type.
"""
if file_type != FileType.CUSTOM:
return FileType(file_type)
guessed_type = None
if extension:
guessed_type = _get_file_type_by_extension(extension)
if guessed_type is None and mime_type:
guessed_type = _get_file_type_by_mimetype(mime_type)
return guessed_type or FileType.CUSTOM
def _get_file_type_by_extension(extension: str) -> FileType | None:
extension = extension.lstrip(".")
if extension in IMAGE_EXTENSIONS:
return FileType.IMAGE
elif extension in VIDEO_EXTENSIONS:
return FileType.VIDEO
elif extension in AUDIO_EXTENSIONS:
return FileType.AUDIO
elif extension in DOCUMENT_EXTENSIONS:
return FileType.DOCUMENT
def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type

@ -190,3 +190,12 @@ app_site_fields = {
"show_workflow_steps": fields.Boolean, "show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean,
} }
app_import_fields = {
"id": fields.String,
"status": fields.String,
"app_id": fields.String,
"current_dsl_version": fields.String,
"imported_dsl_version": fields.String,
"error": fields.String,
}

@ -41,6 +41,7 @@ dataset_retrieval_model_fields = {
external_retrieval_model_fields = { external_retrieval_model_fields = {
"top_k": fields.Integer, "top_k": fields.Integer,
"score_threshold": fields.Float, "score_threshold": fields.Float,
"score_threshold_enabled": fields.Boolean,
} }
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}

@ -31,9 +31,12 @@ class AppIconUrlField(fields.Raw):
if obj is None: if obj is None:
return None return None
from models.model import IconType from models.model import App, IconType
if obj.icon_type == IconType.IMAGE.value: if isinstance(obj, dict) and "app" in obj:
obj = obj["app"]
if isinstance(obj, App) and obj.icon_type == IconType.IMAGE.value:
return file_helpers.get_signed_file_url(obj.icon) return file_helpers.get_signed_file_url(obj.icon)
return None return None

@ -70,7 +70,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
new_data_source_binding = DataSourceOauthBinding( new_data_source_binding = DataSourceOauthBinding(
@ -106,7 +106,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
new_data_source_binding = DataSourceOauthBinding( new_data_source_binding = DataSourceOauthBinding(
@ -141,7 +141,7 @@ class NotionOAuth(OAuthDataSource):
} }
data_source_binding.source_info = new_source_info data_source_binding.source_info = new_source_info
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
raise ValueError("Data source binding not found") raise ValueError("Data source binding not found")

@ -10,7 +10,7 @@ from models.base import Base
from .types import StringUUID from .types import StringUUID
class AccountStatus(str, enum.Enum): class AccountStatus(enum.StrEnum):
PENDING = "pending" PENDING = "pending"
UNINITIALIZED = "uninitialized" UNINITIALIZED = "uninitialized"
ACTIVE = "active" ACTIVE = "active"
@ -102,15 +102,15 @@ class Account(UserMixin, Base):
return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none()
return None return None
def get_integrates(self) -> list[db.Model]:
ai = db.Model
return db.session.query(ai).filter(ai.account_id == self.id).all()
# check current_user.current_tenant.current_role in ['admin', 'owner'] # check current_user.current_tenant.current_role in ['admin', 'owner']
@property @property
def is_admin_or_owner(self): def is_admin_or_owner(self):
return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) return TenantAccountRole.is_privileged_role(self._current_tenant.current_role)
@property
def is_admin(self):
return TenantAccountRole.is_admin_role(self._current_tenant.current_role)
@property @property
def is_editor(self): def is_editor(self):
return TenantAccountRole.is_editing_role(self._current_tenant.current_role) return TenantAccountRole.is_editing_role(self._current_tenant.current_role)
@ -124,12 +124,12 @@ class Account(UserMixin, Base):
return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR
class TenantStatus(str, enum.Enum): class TenantStatus(enum.StrEnum):
NORMAL = "normal" NORMAL = "normal"
ARCHIVE = "archive" ARCHIVE = "archive"
class TenantAccountRole(str, enum.Enum): class TenantAccountRole(enum.StrEnum):
OWNER = "owner" OWNER = "owner"
ADMIN = "admin" ADMIN = "admin"
EDITOR = "editor" EDITOR = "editor"
@ -138,7 +138,9 @@ class TenantAccountRole(str, enum.Enum):
@staticmethod @staticmethod
def is_valid_role(role: str) -> bool: def is_valid_role(role: str) -> bool:
return role and role in { if not role:
return False
return role in {
TenantAccountRole.OWNER, TenantAccountRole.OWNER,
TenantAccountRole.ADMIN, TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR, TenantAccountRole.EDITOR,
@ -148,11 +150,21 @@ class TenantAccountRole(str, enum.Enum):
@staticmethod @staticmethod
def is_privileged_role(role: str) -> bool: def is_privileged_role(role: str) -> bool:
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN}
@staticmethod
def is_admin_role(role: str) -> bool:
if not role:
return False
return role == TenantAccountRole.ADMIN
@staticmethod @staticmethod
def is_non_owner_role(role: str) -> bool: def is_non_owner_role(role: str) -> bool:
return role and role in { if not role:
return False
return role in {
TenantAccountRole.ADMIN, TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR, TenantAccountRole.EDITOR,
TenantAccountRole.NORMAL, TenantAccountRole.NORMAL,
@ -161,11 +173,15 @@ class TenantAccountRole(str, enum.Enum):
@staticmethod @staticmethod
def is_editing_role(role: str) -> bool: def is_editing_role(role: str) -> bool:
return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} if not role:
return False
return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR}
@staticmethod @staticmethod
def is_dataset_edit_role(role: str) -> bool: def is_dataset_edit_role(role: str) -> bool:
return role and role in { if not role:
return False
return role in {
TenantAccountRole.OWNER, TenantAccountRole.OWNER,
TenantAccountRole.ADMIN, TenantAccountRole.ADMIN,
TenantAccountRole.EDITOR, TenantAccountRole.EDITOR,
@ -265,12 +281,12 @@ class InvitationCode(db.Model):
class TenantPluginPermission(Base): class TenantPluginPermission(Base):
class InstallPermission(str, enum.Enum): class InstallPermission(enum.StrEnum):
EVERYONE = "everyone" EVERYONE = "everyone"
ADMINS = "admins" ADMINS = "admins"
NOBODY = "noone" NOBODY = "noone"
class DebugPermission(str, enum.Enum): class DebugPermission(enum.StrEnum):
EVERYONE = "everyone" EVERYONE = "everyone"
ADMINS = "admins" ADMINS = "admins"
NOBODY = "noone" NOBODY = "noone"

@ -23,7 +23,7 @@ from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID from .types import StringUUID
class DatasetPermissionEnum(str, enum.Enum): class DatasetPermissionEnum(enum.StrEnum):
ONLY_ME = "only_me" ONLY_ME = "only_me"
ALL_TEAM = "all_team_members" ALL_TEAM = "all_team_members"
PARTIAL_TEAM = "partial_members" PARTIAL_TEAM = "partial_members"

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save