Merge branch 'main' into feat/parent-child-retrieval

feat/parent-child-retrieval-api
AkaraChen 1 year ago
commit c4aa98e609

@ -1,5 +1,5 @@
FROM mcr.microsoft.com/devcontainers/python:3.10 FROM mcr.microsoft.com/devcontainers/python:3.12
# [Optional] Uncomment this section to install additional OS packages. # [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here> # && apt-get -y install --no-install-recommends <your-package-list-here>

@ -1,7 +1,7 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the // For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{ {
"name": "Python 3.10", "name": "Python 3.12",
"build": { "build": {
"context": "..", "context": "..",
"dockerfile": "Dockerfile" "dockerfile": "Dockerfile"

@ -4,7 +4,7 @@ inputs:
python-version: python-version:
description: Python version to use and the Poetry installed with description: Python version to use and the Poetry installed with
required: true required: true
default: '3.10' default: '3.11'
poetry-version: poetry-version:
description: Poetry version to set up description: Poetry version to set up
required: true required: true

@ -20,7 +20,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

@ -8,6 +8,8 @@ on:
- api/core/rag/datasource/** - api/core/rag/datasource/**
- docker/** - docker/**
- .github/workflows/vdb-tests.yml - .github/workflows/vdb-tests.yml
- api/poetry.lock
- api/pyproject.toml
concurrency: concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }} group: vdb-tests-${{ github.head_ref || github.run_id }}
@ -20,7 +22,6 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
- "3.10"
- "3.11" - "3.11"
- "3.12" - "3.12"

@ -1,6 +1,8 @@
# CONTRIBUTING
So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly. So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part. We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve. This guide, like Dify itself, is a constant work in progress. We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
[Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types: [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
### Feature requests: ### Feature requests
* If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try. * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
* If you want to pick one up from the existing issues, simply drop a comment below it saying so. * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes. A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment: Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment:
@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-core features and minor enhancements | Low Priority | | Non-core features and minor enhancements | Low Priority |
| Valuable but not immediate | Future-Feature | | Valuable but not immediate | Future-Feature |
### Anything else (e.g. bug report, performance optimization, typo correction): ### Anything else (e.g. bug report, performance optimization, typo correction)
* Start coding right away. * Start coding right away.
@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
| Non-critical bugs, performance boosts | Medium Priority | | Non-critical bugs, performance boosts | Medium Priority |
| Minor fixes (typos, confusing but working UI) | Low Priority | | Minor fixes (typos, confusing but working UI) | Low Priority |
## Installing ## Installing
Here are the steps to set up Dify for development: Here are the steps to set up Dify for development:
@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
Clone the forked repository from your terminal: Clone the forked repository from your terminal:
``` ```shell
git clone git@github.com:<github_username>/dify.git git clone git@github.com:<github_username>/dify.git
``` ```
@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
Dify requires the following dependencies to build, make sure they're installed on your system: Dify requires the following dependencies to build, make sure they're installed on your system:
- [Docker](https://www.docker.com/) * [Docker](https://www.docker.com/)
- [Docker Compose](https://docs.docker.com/compose/install/) * [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) * [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) * [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x * [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. Installations ### 4. Installations
@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
### 5. Visit dify in your browser ### 5. Visit dify in your browser
To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running. To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
## Developing ## Developing
@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
### Backend ### Backend
Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login. Difys backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
``` ```text
[api/] [api/]
├── constants // Constant settings used throughout code base. ├── constants // Constant settings used throughout code base.
├── controllers // API route definitions and request handling logic. ├── controllers // API route definitions and request handling logic.
@ -121,7 +120,7 @@ Difys backend is written in Python using [Flask](https://flask.palletsproject
The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization. The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
``` ```text
[web/] [web/]
├── app // layouts, pages, and components ├── app // layouts, pages, and components
│ ├── (commonLayout) // common layout used throughout the app │ ├── (commonLayout) // common layout used throughout the app
@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
## Submitting your PR ## Submitting your PR
At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests). At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md). And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
## Getting Help ## Getting Help
If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.

@ -71,7 +71,7 @@ Dify 依赖以下工具和库:
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. 安装 ### 4. 安装

@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) version 3.10.x - [Python](https://www.python.org/) version 3.11.x or 3.12.x
### 4. インストール ### 4. インストール

@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
- [Docker Compose](https://docs.docker.com/compose/install/) - [Docker Compose](https://docs.docker.com/compose/install/)
- [Node.js v18.x (LTS)](http://nodejs.org) - [Node.js v18.x (LTS)](http://nodejs.org)
- [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
- [Python](https://www.python.org/) phiên bản 3.10.x - [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
### 4. Cài đặt ### 4. Cài đặt
@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
## Nhận trợ giúp ## Nhận trợ giúp
Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.

@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
REDIS_SENTINEL_PASSWORD= REDIS_SENTINEL_PASSWORD=
REDIS_SENTINEL_SOCKET_TIMEOUT=0.1 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
# redis Cluster configuration.
REDIS_USE_CLUSTERS=false
REDIS_CLUSTERS=
REDIS_CLUSTERS_PASSWORD=
# PostgreSQL database configuration # PostgreSQL database configuration
DB_USERNAME=postgres DB_USERNAME=postgres
DB_PASSWORD=difyai123456 DB_PASSWORD=difyai123456

@ -1,5 +1,5 @@
# base image # base image
FROM python:3.10-slim-bookworm AS base FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api WORKDIR /app/api

@ -18,12 +18,17 @@
``` ```
2. Copy `.env.example` to `.env` 2. Copy `.env.example` to `.env`
```cli
cp .env.example .env
```
3. Generate a `SECRET_KEY` in the `.env` file. 3. Generate a `SECRET_KEY` in the `.env` file.
bash for Linux
```bash for Linux ```bash for Linux
sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
``` ```
bash for Mac
```bash for Mac ```bash for Mac
secret_key=$(openssl rand -base64 42) secret_key=$(openssl rand -base64 42)
sed -i '' "/^SECRET_KEY=/c\\ sed -i '' "/^SECRET_KEY=/c\\
@ -37,18 +42,10 @@
5. Install dependencies 5. Install dependencies
```bash ```bash
poetry env use 3.10 poetry env use 3.12
poetry install poetry install
``` ```
In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
```bash
poetry shell # activate current environment
poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
```
6. Run migrate 6. Run migrate
Before the first launch, migrate the database to the latest version. Before the first launch, migrate the database to the latest version.
@ -84,5 +81,3 @@
```bash ```bash
poetry run -C api bash dev/pytest/pytest_all_tests.sh poetry run -C api bash dev/pytest/pytest_all_tests.sh
``` ```

@ -1,6 +1,11 @@
import os import os
import sys import sys
python_version = sys.version_info
if not ((3, 11) <= python_version < (3, 13)):
print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
raise SystemExit(1)
from configs import dify_config from configs import dify_config
if not dify_config.DEBUG: if not dify_config.DEBUG:
@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web # no
# DO NOT REMOVE ABOVE # DO NOT REMOVE ABOVE
if sys.version_info[:2] == (3, 10):
print("Warning: Python 3.10 will not be supported in the next version.")
warnings.simplefilter("ignore", ResourceWarning) warnings.simplefilter("ignore", ResourceWarning)

@ -27,7 +27,6 @@ class DifyConfig(
# read from dotenv format config file # read from dotenv format config file
env_file=".env", env_file=".env",
env_file_encoding="utf-8", env_file_encoding="utf-8",
frozen=True,
# ignore extra attributes # ignore extra attributes
extra="ignore", extra="ignore",
) )

@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
description="Socket timeout in seconds for Redis Sentinel connections", description="Socket timeout in seconds for Redis Sentinel connections",
default=0.1, default=0.1,
) )
REDIS_USE_CLUSTERS: bool = Field(
description="Enable Redis Clusters mode for high availability",
default=False,
)
REDIS_CLUSTERS: Optional[str] = Field(
description="Comma-separated list of Redis Clusters nodes (host:port)",
default=None,
)
REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
description="Password for Redis Clusters authentication (if required)",
default=None,
)

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="0.11.2", default="0.12.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -2,6 +2,7 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi
from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>") api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
api.add_resource(RemoteFileUploadApi, "/remote-files/upload") api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
# Import App
api.add_resource(AppImportApi, "/apps/imports")
api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
# Import other controllers # Import other controllers
from . import admin, apikey, extension, feature, ping, setup, version from . import admin, apikey, extension, feature, ping, setup, version

@ -1,7 +1,10 @@
import uuid import uuid
from typing import cast
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api from controllers.console import api
@ -13,13 +16,15 @@ from controllers.console.wraps import (
setup_required, setup_required,
) )
from core.ops.ops_trace_manager import OpsTraceManager from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import ( from fields.app_fields import (
app_detail_fields, app_detail_fields,
app_detail_fields_with_site, app_detail_fields_with_site,
app_pagination_fields, app_pagination_fields,
) )
from libs.login import login_required from libs.login import login_required
from services.app_dsl_service import AppDslService from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService from services.app_service import AppService
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@ -92,61 +97,6 @@ class AppListApi(Resource):
return app, 201 return app, 201
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app(
tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
)
return app, 201
class AppImportFromUrlApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_detail_fields_with_site)
@cloud_edition_billing_resource_check("apps")
def post(self):
"""Import app from url"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args()
app = AppDslService.import_and_create_new_app_from_url(
tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
)
return app, 201
class AppApi(Resource): class AppApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -224,10 +174,24 @@ class AppCopyApi(Resource):
parser.add_argument("icon_background", type=str, location="json") parser.add_argument("icon_background", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
data = AppDslService.export_dsl(app_model=app_model, include_secret=True) with Session(db.engine) as session:
app = AppDslService.import_and_create_new_app( import_service = AppDslService(session)
tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
) account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=ImportMode.YAML_CONTENT.value,
yaml_content=yaml_content,
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
)
session.commit()
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)
return app, 201 return app, 201
@ -368,8 +332,6 @@ class AppTraceApi(Resource):
api.add_resource(AppListApi, "/apps") api.add_resource(AppListApi, "/apps")
api.add_resource(AppImportApi, "/apps/import")
api.add_resource(AppImportFromUrlApi, "/apps/import/url")
api.add_resource(AppApi, "/apps/<uuid:app_id>") api.add_resource(AppApi, "/apps/<uuid:app_id>")
api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy") api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export") api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")

@ -0,0 +1,90 @@
from typing import cast
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import app_import_fields
from libs.login import login_required
from models import Account
from services.app_dsl_service import AppDslService, ImportStatus
class AppImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("app_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_app(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
name=args.get("name"),
description=args.get("description"),
icon_type=args.get("icon_type"),
icon=args.get("icon"),
icon_background=args.get("icon_background"),
app_id=args.get("app_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class AppImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
import pytz import pytz
from flask_login import current_user from flask_login import current_user
@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.") raise NotFound("Conversation Not Exists.")
if not conversation.read_at: if not conversation.read_at:
conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
conversation.read_account_id = current_user.id conversation.read_account_id = current_user.id
db.session.commit() db.session.commit()

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
@ -75,7 +75,7 @@ class AppSite(Resource):
setattr(site, attr_name, value) setattr(site, attr_name, value)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site
@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
site.code = Site.generate_code(16) site.code = Site.generate_code(16)
site.updated_by = current_user.id site.updated_by = current_user.id
site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) site.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return site return site

@ -20,7 +20,6 @@ from libs.helper import TimestampField, uuid_value
from libs.login import current_user, login_required from libs.login import current_user, login_required
from models import App from models import App
from models.model import AppMode from models.model import AppMode
from services.app_dsl_service import AppDslService
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.workflow_service import WorkflowService from services.workflow_service import WorkflowService
@ -126,31 +125,6 @@ class DraftWorkflowApi(Resource):
} }
class DraftWorkflowImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_fields)
def post(self, app_model: App):
"""
Import draft workflow
"""
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("data", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
workflow = AppDslService.import_and_overwrite_workflow(
app_model=app_model, data=args["data"], account=current_user
)
return workflow
class AdvancedChatDraftWorkflowRunApi(Resource): class AdvancedChatDraftWorkflowRunApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -453,7 +427,6 @@ class ConvertToWorkflowApi(Resource):
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft") api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run") api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run") api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop") api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")

@ -65,7 +65,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional from typing import Optional
import requests import requests
@ -106,7 +106,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value account.status = AccountStatus.ACTIVE.value
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -83,7 +83,7 @@ class DataSourceApi(Resource):
if action == "enable": if action == "enable":
if data_source_binding.disabled: if data_source_binding.disabled:
data_source_binding.disabled = False data_source_binding.disabled = False
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:
@ -92,7 +92,7 @@ class DataSourceApi(Resource):
if action == "disable": if action == "disable":
if not data_source_binding.disabled: if not data_source_binding.disabled:
data_source_binding.disabled = True data_source_binding.disabled = True
data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.add(data_source_binding) db.session.add(data_source_binding)
db.session.commit() db.session.commit()
else: else:

@ -1,6 +1,6 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from datetime import datetime, timezone from datetime import UTC, datetime
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.") raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id document.paused_by = current_user.id
document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) document.paused_at = datetime.now(UTC).replace(tzinfo=None)
document.is_paused = True document.is_paused = True
db.session.commit() db.session.commit()
@ -745,7 +745,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value document.doc_metadata[key] = value
document.doc_type = doc_type document.doc_type = doc_type
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200 return {"result": "success", "message": "Document metadata updated."}, 200
@ -787,7 +787,7 @@ class DocumentStatusApi(DocumentResource):
document.enabled = True document.enabled = True
document.disabled_at = None document.disabled_at = None
document.disabled_by = None document.disabled_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -804,9 +804,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already disabled.") raise InvalidActionError("Document already disabled.")
document.enabled = False document.enabled = False
document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id document.disabled_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
@ -821,9 +821,9 @@ class DocumentStatusApi(DocumentResource):
raise InvalidActionError("Document already archived.") raise InvalidActionError("Document already archived.")
document.archived = True document.archived = True
document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id document.archived_by = current_user.id
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
if document.enabled: if document.enabled:
@ -840,7 +840,7 @@ class DocumentStatusApi(DocumentResource):
document.archived = False document.archived = False
document.archived_at = None document.archived_at = None
document.archived_by = None document.archived_by = None
document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times

@ -1,5 +1,5 @@
import uuid import uuid
from datetime import datetime, timezone from datetime import UTC, datetime
import pandas as pd import pandas as pd
from flask import request from flask import request
@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
raise InvalidActionError("Segment is already disabled.") raise InvalidActionError("Segment is already disabled.")
segment.enabled = False segment.enabled = False
segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id segment.disabled_by = current_user.id
db.session.commit() db.session.commit()

@ -1,5 +1,5 @@
import logging import logging
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import reqparse from flask_restful import reqparse
@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming" streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:
@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False args["auto_generate_name"] = False
installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
try: try:

@ -1,4 +1,4 @@
from datetime import datetime, timezone from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse from flask_restful import Resource, inputs, marshal_with, reqparse
@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id, tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id, app_owner_tenant_id=app.tenant_id,
is_pinned=False, is_pinned=False,
last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), last_used_at=datetime.now(UTC).replace(tzinfo=None),
) )
db.session.add(new_installed_app) db.session.add(new_installed_app)
db.session.commit() db.session.commit()

@ -60,7 +60,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError() raise InvalidInvitationCodeError()
invitation_code.status = "used" invitation_code.status = "used"
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id invitation_code.used_by_account_id = account.id
@ -68,7 +68,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"] account.timezone = args["timezone"]
account.interface_theme = "light" account.interface_theme = "light"
account.status = "active" account.status = "active"
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return {"result": "success"} return {"result": "success"}

@ -1,5 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timezone from datetime import UTC, datetime
from enum import Enum from enum import Enum
from functools import wraps from functools import wraps
from typing import Optional from typing import Optional
@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
if not api_token: if not api_token:
raise Unauthorized("Access token is invalid") raise Unauthorized("Access token is invalid")
api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
return api_token return api_token

@ -2,7 +2,7 @@ import json
import logging import logging
import uuid import uuid
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional, Union, cast from typing import Optional, Union, cast
from core.agent.entities import AgentEntity, AgentToolEntity from core.agent.entities import AgentEntity, AgentToolEntity
@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call # check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []): features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = True self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
else: self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.stream_tool_call = False
# check if model supports vision
if model_schema and ModelFeature.VISION in (model_schema.features or []):
self.files = application_generate_entity.files
else:
self.files = []
self.query = None self.query = None
self._current_thoughts: list[PromptMessage] = [] self._current_thoughts: list[PromptMessage] = []
@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
update prompt message tool update prompt message tool
""" """
# try to get tool runtime parameters # try to get tool runtime parameters
tool_runtime_parameters = tool.get_runtime_parameters() or [] tool_runtime_parameters = tool.get_runtime_parameters()
for parameter in tool_runtime_parameters: for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM: if parameter.form != ToolParameter.ToolParameterForm.LLM:
@ -419,7 +412,7 @@ class BaseAgentRunner(AppRunner):
.first() .first()
) )
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
db.session.commit() db.session.commit()
db.session.close() db.session.close()

@ -1,3 +1,4 @@
import uuid
from typing import Optional from typing import Optional
from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity

@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter: class ModelConfigConverter:
@classmethod @classmethod
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity: def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
""" """
Convert app model config dict to entity. Convert app model config dict to entity.
:param app_config: app config :param app_config: app config
@ -38,27 +38,23 @@ class ModelConfigConverter:
) )
if model_credentials is None: if model_credentials is None:
if not skip_check: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
else: # check model
model_credentials = {} provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model, model_type=ModelType.LLM
if not skip_check: )
# check model
provider_model = provider_model_bundle.configuration.get_provider_model( if provider_model is None:
model=model_config.model, model_type=ModelType.LLM model_name = model_config.model
) raise ValueError(f"Model {model_name} not exist.")
if provider_model is None: if provider_model.status == ModelStatus.NO_CONFIGURE:
model_name = model_config.model raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
raise ValueError(f"Model {model_name} not exist.") elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
if provider_model.status == ModelStatus.NO_CONFIGURE: elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config # model config
completion_params = model_config.parameters completion_params = model_config.parameters
@ -76,7 +72,7 @@ class ModelConfigConverter:
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials) model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema: if not model_schema:
raise ValueError(f"Model {model_name} not exist.") raise ValueError(f"Model {model_name} not exist.")
return ModelConfigWithCredentialsEntity( return ModelConfigWithCredentialsEntity(

@ -1,4 +1,5 @@
from core.app.app_config.entities import ( from core.app.app_config.entities import (
AdvancedChatMessageEntity,
AdvancedChatPromptTemplateEntity, AdvancedChatPromptTemplateEntity,
AdvancedCompletionPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity,
PromptTemplateEntity, PromptTemplateEntity,
@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
chat_prompt_messages = [] chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []): for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append( chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} AdvancedChatMessageEntity(
**{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
) )
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

@ -1,5 +1,5 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum): class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input" TEXT_INPUT = "text-input"
SELECT = "select" SELECT = "select"
PARAGRAPH = "paragraph" PARAGRAPH = "paragraph"

@ -127,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -134,7 +134,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -1,4 +1,4 @@
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
@ -6,7 +6,7 @@ from core.file import File, FileUploadConfig
from factories import file_factory from factories import file_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import VariableEntity
class BaseAppGenerator: class BaseAppGenerator:
@ -14,23 +14,23 @@ class BaseAppGenerator:
self, self,
*, *,
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig", variables: Sequence["VariableEntity"],
tenant_id: str,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
user_inputs = { user_inputs = {
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
for var in variables for var in variables
} }
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
# Convert files in inputs to File # Convert files in inputs to File
entity_dictionary = {item.variable: item for item in app_config.variables} entity_dictionary = {item.variable: item for item in variables}
# Convert single file to File # Convert single file to File
files_inputs = { files_inputs = {
k: file_factory.build_from_mapping( k: file_factory.build_from_mapping(
mapping=v, mapping=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@ -44,7 +44,7 @@ class BaseAppGenerator:
file_list_inputs = { file_list_inputs = {
k: file_factory.build_from_mappings( k: file_factory.build_from_mappings(
mappings=v, mappings=v,
tenant_id=app_config.tenant_id, tenant_id=tenant_id,
config=FileUploadConfig( config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,

@ -132,7 +132,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), else self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

@ -113,7 +113,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
query=query, query=query,
files=file_objs, files=file_objs,
user_id=user.id, user_id=user.id,

@ -1,7 +1,7 @@
import json import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import and_ from sqlalchemy import and_
@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit() db.session.commit()
db.session.refresh(conversation) db.session.refresh(conversation)
else: else:
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
message = Message( message = Message(

@ -96,7 +96,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config, file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), inputs=self._prepare_user_inputs(
user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
),
files=system_files, files=system_files,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,

@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.iteration import IterationNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.nodes.node_mapping import node_type_classes_mapping
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
user_inputs=user_inputs, user_inputs=user_inputs,
variable_pool=variable_pool, variable_pool=variable_pool,
tenant_id=workflow.tenant_id, tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
) )
return graph, variable_pool return graph, variable_pool

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, field_validator
@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData from core.workflow.nodes.base import BaseNodeData
class QueueEvent(str, Enum): class QueueEvent(StrEnum):
""" """
QueueEvent enum QueueEvent enum
""" """

@ -1,8 +1,9 @@
import json import json
import time import time
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from datetime import datetime, timezone from datetime import UTC, datetime
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -80,38 +81,38 @@ class WorkflowCycleManage:
inputs[f"sys.{key.value}"] = value inputs[f"sys.{key.value}"] = value
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from = ( triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN else WorkflowRunTriggeredFrom.APP_RUN
) )
# handle special values
inputs = WorkflowEntry.handle_special_values(inputs)
# init workflow run # init workflow run
workflow_run = WorkflowRun() with Session(db.engine, expire_on_commit=False) as session:
workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] workflow_run = WorkflowRun()
if workflow_run_id: system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
workflow_run.id = workflow_run_id workflow_run.id = system_id or str(uuid4())
workflow_run.tenant_id = self._workflow.tenant_id workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs) workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING.value workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = ( workflow_run.created_by_role = (
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
) )
workflow_run.created_by = self._user.id workflow_run.created_by = self._user.id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(workflow_run) session.add(workflow_run)
db.session.commit() session.commit()
db.session.refresh(workflow_run)
db.session.close()
return workflow_run return workflow_run
@ -144,7 +145,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
db.session.refresh(workflow_run) db.session.refresh(workflow_run)
@ -191,7 +192,7 @@ class WorkflowCycleManage:
workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -211,15 +212,18 @@ class WorkflowCycleManage:
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds() ).total_seconds()
db.session.commit() db.session.commit()
db.session.refresh(workflow_run)
db.session.close() db.session.close()
with Session(db.engine, expire_on_commit=False) as session:
session.add(workflow_run)
session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
@ -259,7 +263,7 @@ class WorkflowCycleManage:
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
} }
) )
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution) session.add(workflow_node_execution)
session.commit() session.commit()
@ -282,7 +286,7 @@ class WorkflowCycleManage:
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
) )
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@ -326,7 +330,7 @@ class WorkflowCycleManage:
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata = ( execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@ -381,7 +385,7 @@ class WorkflowCycleManage:
id=workflow_run.id, id=workflow_run.id,
workflow_id=workflow_run.workflow_id, workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number, sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict or {}, inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp()), created_at=int(workflow_run.created_at.timestamp()),
), ),
) )
@ -428,7 +432,7 @@ class WorkflowCycleManage:
created_by=created_by, created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()), created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}), files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
), ),
) )
@ -654,7 +658,7 @@ class WorkflowCycleManage:
if event.error is None if event.error is None
else WorkflowNodeExecutionStatus.FAILED, else WorkflowNodeExecutionStatus.FAILED,
error=None, error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata, execution_metadata=event.metadata,
finished_at=int(time.time()), finished_at=int(time.time()),

@ -240,7 +240,7 @@ class ProviderConfiguration(BaseModel):
if provider_record: if provider_record:
provider_record.encrypted_config = json.dumps(credentials) provider_record.encrypted_config = json.dumps(credentials)
provider_record.is_valid = True provider_record.is_valid = True
provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_record = Provider( provider_record = Provider(
@ -394,7 +394,7 @@ class ProviderConfiguration(BaseModel):
if provider_model_record: if provider_model_record:
provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.encrypted_config = json.dumps(credentials)
provider_model_record.is_valid = True provider_model_record.is_valid = True
provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
provider_model_record = ProviderModel( provider_model_record = ProviderModel(
@ -468,7 +468,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = True model_setting.enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting(
@ -503,7 +503,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.enabled = False model_setting.enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting(
@ -570,7 +570,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = True model_setting.load_balancing_enabled = True
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting(
@ -605,7 +605,7 @@ class ProviderConfiguration(BaseModel):
if model_setting: if model_setting:
model_setting.load_balancing_enabled = False model_setting.load_balancing_enabled = False
model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
else: else:
model_setting = ProviderModelSetting( model_setting = ProviderModelSetting(

@ -1,7 +1,7 @@
from enum import Enum from enum import StrEnum
class FileType(str, Enum): class FileType(StrEnum):
IMAGE = "image" IMAGE = "image"
DOCUMENT = "document" DOCUMENT = "document"
AUDIO = "audio" AUDIO = "audio"
@ -16,7 +16,7 @@ class FileType(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(str, Enum): class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url" REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file" LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file" TOOL_FILE = "tool_file"
@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileBelongsTo(str, Enum): class FileBelongsTo(StrEnum):
USER = "user" USER = "user"
ASSISTANT = "assistant" ASSISTANT = "assistant"
@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
raise ValueError(f"No matching enum found for value '{value}'") raise ValueError(f"No matching enum found for value '{value}'")
class FileAttribute(str, Enum): class FileAttribute(StrEnum):
TYPE = "type" TYPE = "type"
SIZE = "size" SIZE = "size"
NAME = "name" NAME = "name"
@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
EXTENSION = "extension" EXTENSION = "extension"
class ArrayFileAttribute(str, Enum): class ArrayFileAttribute(StrEnum):
LENGTH = "length" LENGTH = "length"

@ -3,7 +3,12 @@ import base64
from configs import dify_config from configs import dify_config
from core.file import file_repository from core.file import file_repository
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
return file.remote_url return file.remote_url
case FileAttribute.EXTENSION: case FileAttribute.EXTENSION:
return file.extension return file.extension
case _:
raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content( def to_prompt_message_content(
f: File, f: File,
/, /,
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
): ):
"""
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image or audio-based AI models.
Args:
f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns:
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises:
ValueError: If the file type is not supported or if required data is missing.
"""
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
data = _to_url(f) data = _to_url(f)
else: else:
@ -65,7 +52,7 @@ def to_prompt_message_content(
return ImagePromptMessageContent(data=data, detail=image_detail_config) return ImagePromptMessageContent(data=data, detail=image_detail_config)
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f) encoded_string = _get_encoded_string(f)
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
@ -74,9 +61,20 @@ def to_prompt_message_content(
data = _to_url(f) data = _to_url(f)
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
case _: case _:
raise ValueError("file type f.type is not supported") raise ValueError(f"file type {f.type} is not supported")
def download(f: File, /): def download(f: File, /):
@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /):
case FileTransferMethod.REMOTE_URL: case FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status() response.raise_for_status()
content = response.content data = response.content
encoded_string = base64.b64encode(content).decode("utf-8")
return encoded_string
case FileTransferMethod.LOCAL_FILE: case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f) upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key) data = _download_file_content(upload_file.key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f) tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key) data = _download_file_content(tool_file.file_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string encoded_string = base64.b64encode(data).decode("utf-8")
case _: return encoded_string
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def _to_base64_data_string(f: File, /): def _to_base64_data_string(f: File, /):
@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /):
return f"data:{f.mime_type};base64,{encoded_string}" return f"data:{f.mime_type};base64,{encoded_string}"
def _file_to_encoded_string(f: File, /):
match f.type:
case FileType.IMAGE:
return _to_base64_data_string(f)
case FileType.VIDEO:
return _to_base64_data_string(f)
case FileType.AUDIO:
return _get_encoded_string(f)
case _:
raise ValueError(f"file type {f.type} is not supported")
def _to_url(f: File, /): def _to_url(f: File, /):
if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.transfer_method == FileTransferMethod.REMOTE_URL:
if f.remote_url is None: if f.remote_url is None:

@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import StrEnum
from threading import Lock from threading import Lock
from typing import Any, Optional from typing import Any, Optional
@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel):
data: Data data: Data
class CodeLanguage(str, Enum): class CodeLanguage(StrEnum):
PYTHON3 = "python3" PYTHON3 = "python3"
JINJA2 = "jinja2" JINJA2 = "jinja2"
JAVASCRIPT = "javascript" JAVASCRIPT = "javascript"

@ -30,6 +30,7 @@ from core.rag.splitter.fixed_text_splitter import (
) )
from core.rag.splitter.text_splitter import TextSplitter from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
@ -85,7 +86,7 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except ObjectDeletedError: except ObjectDeletedError:
logging.warning("Document deleted, document id: {}".format(dataset_document.id)) logging.warning("Document deleted, document id: {}".format(dataset_document.id))
@ -93,7 +94,7 @@ class IndexingRunner:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_splitting_status(self, dataset_document: DatasetDocument): def run_in_splitting_status(self, dataset_document: DatasetDocument):
@ -141,13 +142,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def run_in_indexing_status(self, dataset_document: DatasetDocument): def run_in_indexing_status(self, dataset_document: DatasetDocument):
@ -199,13 +200,13 @@ class IndexingRunner:
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
except Exception as e: except Exception as e:
logging.exception("consume document failed") logging.exception("consume document failed")
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e) dataset_document.error = str(e)
dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
def indexing_estimate( def indexing_estimate(
@ -279,6 +280,19 @@ class IndexingRunner:
if len(preview_texts) < 5: if len(preview_texts) < 5:
preview_texts.append(document.page_content) preview_texts.append(document.page_content)
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
image_upload_file_is: {}".format(upload_file_id)
)
db.session.delete(image_file)
if doc_form and doc_form == "qa_model": if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0: if len(preview_texts) > 0:
# qa model document # qa model document
@ -358,7 +372,7 @@ class IndexingRunner:
after_indexing_status="splitting", after_indexing_status="splitting",
extra_update_params={ extra_update_params={
DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -450,7 +464,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -465,7 +479,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
@ -666,7 +680,7 @@ class IndexingRunner:
after_indexing_status="completed", after_indexing_status="completed",
extra_update_params={ extra_update_params={
DatasetDocument.tokens: tokens, DatasetDocument.tokens: tokens,
DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
DatasetDocument.error: None, DatasetDocument.error: None,
}, },
@ -691,7 +705,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -724,7 +738,7 @@ class IndexingRunner:
{ {
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.enabled: True, DocumentSegment.enabled: True,
DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
) )
@ -835,7 +849,7 @@ class IndexingRunner:
doc_store.add_documents(documents) doc_store.add_documents(documents)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status( self._update_document_index_status(
document_id=dataset_document.id, document_id=dataset_document.id,
after_indexing_status="indexing", after_indexing_status="indexing",
@ -850,7 +864,7 @@ class IndexingRunner:
dataset_document_id=dataset_document.id, dataset_document_id=dataset_document.id,
update_params={ update_params={
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
}, },
) )
pass pass

@ -1,8 +1,8 @@
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager from core.file import file_manager
from core.file.models import FileType
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -27,7 +27,7 @@ class TokenBufferMemory:
def get_history_prompt_messages( def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: Optional[int] = None self, max_token_limit: int = 2000, message_limit: Optional[int] = None
) -> list[PromptMessage]: ) -> Sequence[PromptMessage]:
""" """
Get history prompt messages. Get history prompt messages.
:param max_token_limit: max token limit :param max_token_limit: max token limit
@ -102,12 +102,11 @@ class TokenBufferMemory:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs: for file in file_objs:
if file.type in {FileType.IMAGE, FileType.AUDIO}: prompt_message = file_manager.to_prompt_message_content(
prompt_message = file_manager.to_prompt_message_content( file,
file, image_detail_config=detail,
image_detail_config=detail, )
) prompt_message_contents.append(prompt_message)
prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))

@ -100,10 +100,10 @@ class ModelInstance:
def invoke_llm( def invoke_llm(
self, self,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None, tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Optional from typing import Optional
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@ -31,7 +32,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -60,7 +61,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
): ):
@ -90,7 +91,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:
@ -120,7 +121,7 @@ class Callback(ABC):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> None: ) -> None:

@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
from .message_entities import ( from .message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
AudioPromptMessageContent, AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContent,
@ -37,4 +38,5 @@ __all__ = [
"LLMResultChunk", "LLMResultChunk",
"LLMResultChunkDelta", "LLMResultChunkDelta",
"AudioPromptMessageContent", "AudioPromptMessageContent",
"DocumentPromptMessageContent",
] ]

@ -1,6 +1,7 @@
from abc import ABC from abc import ABC
from enum import Enum from collections.abc import Sequence
from typing import Optional from enum import Enum, StrEnum
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
function: PromptMessageTool function: PromptMessageTool
class PromptMessageContentType(Enum): class PromptMessageContentType(StrEnum):
""" """
Enum class for prompt message content type. Enum class for prompt message content type.
""" """
@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
IMAGE = "image" IMAGE = "image"
AUDIO = "audio" AUDIO = "audio"
VIDEO = "video" VIDEO = "video"
DOCUMENT = "document"
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
Model class for image prompt message content. Model class for image prompt message content.
""" """
class DETAIL(str, Enum): class DETAIL(StrEnum):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"
@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
class PromptMessage(ABC, BaseModel): class PromptMessage(ABC, BaseModel):
""" """
Model class for prompt message. Model class for prompt message.
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | list[PromptMessageContent]] = None content: Optional[str | Sequence[PromptMessageContent]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

@ -1,5 +1,5 @@
from decimal import Decimal from decimal import Decimal
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
@ -87,9 +87,12 @@ class ModelFeature(Enum):
AGENT_THOUGHT = "agent-thought" AGENT_THOUGHT = "agent-thought"
VISION = "vision" VISION = "vision"
STREAM_TOOL_CALL = "stream-tool-call" STREAM_TOOL_CALL = "stream-tool-call"
DOCUMENT = "document"
VIDEO = "video"
AUDIO = "audio"
class DefaultParameterName(str, Enum): class DefaultParameterName(StrEnum):
""" """
Enum class for parameter template variable. Enum class for parameter template variable.
""" """

@ -2,7 +2,7 @@ import logging
import re import re
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping, Sequence
from typing import Optional, Union from typing import Optional, Union
from pydantic import ConfigDict from pydantic import ConfigDict
@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: Optional[dict] = None, model_parameters: Optional[dict] = None,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -212,7 +212,7 @@ if you are not sure about the structure.
) )
model_parameters.pop("response_format") model_parameters.pop("response_format")
stop = stop or [] stop = list(stop) if stop is not None else []
stop.extend(["\n```", "```\n"]) stop.extend(["\n```", "```\n"])
block_prompts = block_prompts.replace("{{block}}", code_block) block_prompts = block_prompts.replace("{{block}}", code_block)
@ -408,7 +408,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -479,7 +479,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
@ -601,7 +601,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -647,7 +647,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -694,7 +694,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,
@ -742,7 +742,7 @@ if you are not sure about the structure.
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None, callbacks: Optional[list[Callback]] = None,

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 200000 context_size: 200000

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 200000 context_size: 200000

@ -1,7 +1,7 @@
import base64 import base64
import io import io
import json import json
from collections.abc import Generator from collections.abc import Generator, Sequence
from typing import Optional, Union, cast from typing import Optional, Union, cast
import anthropic import anthropic
@ -21,9 +21,9 @@ from httpx import Timeout
from PIL import Image from PIL import Image
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities import (
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.invoke import (
InvokeAuthorizationError, InvokeAuthorizationError,
InvokeBadRequestError, InvokeBadRequestError,
@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
self, self,
model: str, model: str,
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: Sequence[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[Sequence[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
# Add the new header for claude-3-5-sonnet-20240620 model # Add the new header for claude-3-5-sonnet-20240620 model
extra_headers = {} extra_headers = {}
if model == "claude-3-5-sonnet-20240620": if model == "claude-3-5-sonnet-20240620":
if model_parameters.get("max_tokens") > 4096: if model_parameters.get("max_tokens", 0) > 4096:
extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15" extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
if any(
isinstance(content, DocumentPromptMessageContent)
for prompt_message in prompt_messages
if isinstance(prompt_message.content, list)
for content in prompt_message.content
):
extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
if tools: if tools:
extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools] extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
response = client.beta.tools.messages.create( response = client.beta.tools.messages.create(
@ -444,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
return credentials_kwargs return credentials_kwargs
def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
""" """
Convert prompt messages to dict list and system Convert prompt messages to dict list and system
""" """
@ -452,7 +461,15 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
first_loop = True first_loop = True
for message in prompt_messages: for message in prompt_messages:
if isinstance(message, SystemPromptMessage): if isinstance(message, SystemPromptMessage):
message.content = message.content.strip() if isinstance(message.content, str):
message.content = message.content.strip()
elif isinstance(message.content, list):
# System prompt only support text
message.content = "".join(
c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
)
else:
raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
if first_loop: if first_loop:
system = message.content system = message.content
first_loop = False first_loop = False
@ -504,6 +521,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"source": {"type": "base64", "media_type": mime_type, "data": base64_data}, "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
} }
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
},
}
sub_messages.append(sub_message_dict)
prompt_message_dicts.append({"role": "user", "content": sub_messages}) prompt_message_dicts.append({"role": "user", "content": sub_messages})
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message) message = cast(AssistantPromptMessage, message)

@ -779,7 +779,7 @@ LLM_BASE_MODELS = [
name="frequency_penalty", name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
), ),
_get_max_tokens(default=512, min_val=1, max_val=4096), _get_max_tokens(default=512, min_val=1, max_val=16384),
ParameterRule( ParameterRule(
name="seed", name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"), label=I18nObject(zh_Hans="种子", en_US="Seed"),

@ -2,13 +2,11 @@
import base64 import base64
import json import json
import logging import logging
import mimetypes
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union, cast from typing import Optional, Union, cast
# 3rd import # 3rd import
import boto3 import boto3
import requests
from botocore.config import Config from botocore.config import Config
from botocore.exceptions import ( from botocore.exceptions import (
ClientError, ClientError,
@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
sub_messages.append(sub_message_dict) sub_messages.append(sub_message_dict)
elif message_content.type == PromptMessageContentType.IMAGE: elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content) message_content = cast(ImagePromptMessageContent, message_content)
if not message_content.data.startswith("data:"): data_split = message_content.data.split(";base64,")
# fetch image data from url mime_type = data_split[0].replace("data:", "")
try: base64_data = data_split[1]
url = message_content.data image_content = base64.b64decode(base64_data)
image_content = requests.get(url).content
if "?" in url:
url = url.split("?")[0]
mime_type, _ = mimetypes.guess_type(url)
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
else:
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
image_content = base64.b64decode(base64_data)
if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
raise ValueError( raise ValueError(

@ -15,9 +15,9 @@ parameter_rules:
use_template: max_tokens use_template: max_tokens
required: true required: true
type: int type: int
default: 4096 default: 8192
min: 1 min: 1
max: 4096 max: 8192
help: help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.

@ -16,9 +16,9 @@ parameter_rules:
use_template: max_tokens use_template: max_tokens
required: true required: true
type: int type: int
default: 4096 default: 8192
min: 1 min: 1
max: 4096 max: 8192
help: help:
zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 zh_Hans: 停止前生成的最大令牌数。请注意Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.

@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
base_model_schema = cast(AIModelEntity, base_model_schema) base_model_schema = cast(AIModelEntity, base_model_schema)
base_model_schema_features = base_model_schema.features or [] base_model_schema_features = base_model_schema.features or []
base_model_schema_model_properties = base_model_schema.model_properties or {} base_model_schema_model_properties = base_model_schema.model_properties
base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] base_model_schema_parameters_rules = base_model_schema.parameter_rules
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,

@ -5,6 +5,7 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- multi-tool-call - multi-tool-call
- stream-tool-call - stream-tool-call
model_properties: model_properties:
@ -72,7 +73,7 @@ parameter_rules:
- text - text
- json_object - json_object
pricing: pricing:
input: '1' input: "1"
output: '2' output: "2"
unit: '0.000001' unit: "0.000001"
currency: RMB currency: RMB

@ -5,6 +5,7 @@ label:
model_type: llm model_type: llm
features: features:
- agent-thought - agent-thought
- tool-call
- multi-tool-call - multi-tool-call
- stream-tool-call - stream-tool-call
model_properties: model_properties:

@ -1,18 +1,17 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Union from typing import Optional, Union
from urllib.parse import urlparse
import tiktoken from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
PromptMessage, PromptMessage,
PromptMessageTool, PromptMessageTool,
) )
from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
def _invoke( def _invoke(
self, self,
model: str, model: str,
@ -25,92 +24,15 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel):
user: Optional[str] = None, user: Optional[str] = None,
) -> Union[LLMResult, Generator]: ) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials) self._add_custom_parameters(credentials)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
def validate_credentials(self, model: str, credentials: dict) -> None: def validate_credentials(self, model: str, credentials: dict) -> None:
self._add_custom_parameters(credentials) self._add_custom_parameters(credentials)
super().validate_credentials(model, credentials) super().validate_credentials(model, credentials)
# refactored from openai model runtime, use cl100k_base for calculate token number
def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Calculate num tokens for text completion model with tiktoken package.
:param model: model name
:param text: prompt text
:param tools: tools for tool calling
:return: number of tokens
"""
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(text))
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
# refactored from openai model runtime, use cl100k_base for calculate token number
def _num_tokens_from_messages(
self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
encoding = tiktoken.get_encoding("cl100k_base")
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ""
for item in value:
if isinstance(item, dict) and item["type"] == "text":
text += item["text"]
value = text
if key == "tool_calls":
for tool_call in value:
for t_key, t_value in tool_call.items():
num_tokens += len(encoding.encode(t_key))
if t_key == "function":
for f_key, f_value in t_value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(t_key))
num_tokens += len(encoding.encode(t_value))
else:
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if tools:
num_tokens += self._num_tokens_for_tools(encoding, tools)
return num_tokens
@staticmethod @staticmethod
def _add_custom_parameters(credentials: dict) -> None: def _add_custom_parameters(credentials) -> None:
credentials["mode"] = "chat" credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com")))
credentials["openai_api_key"] = credentials["api_key"] credentials["mode"] = LLMMode.CHAT.value
if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": credentials["function_calling_type"] = "tool_call"
credentials["openai_api_base"] = "https://api.deepseek.com" credentials["stream_function_calling"] = "support"
else:
parsed_url = urlparse(credentials["endpoint_url"])
credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}"

@ -18,7 +18,8 @@ class FishAudioProvider(ModelProvider):
""" """
try: try:
model_instance = self.get_model_instance(ModelType.TTS) model_instance = self.get_model_instance(ModelType.TTS)
model_instance.validate_credentials(credentials=credentials) # FIXME fish tts do not have model for now, so set it to empty string instead
model_instance.validate_credentials(model="", credentials=credentials)
except CredentialsValidateFailedError as ex: except CredentialsValidateFailedError as ex:
raise ex raise ex
except Exception as ex: except Exception as ex:

@ -66,7 +66,7 @@ class FishAudioText2SpeechModel(TTSModel):
voice=voice, voice=voice,
) )
def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None:
""" """
Validate credentials for text2speech model Validate credentials for text2speech model
@ -76,7 +76,7 @@ class FishAudioText2SpeechModel(TTSModel):
try: try:
self.get_tts_model_voices( self.get_tts_model_voices(
None, "",
credentials={ credentials={
"api_key": credentials["api_key"], "api_key": credentials["api_key"],
"api_base": credentials["api_base"], "api_base": credentials["api_base"],

@ -122,7 +122,7 @@ class GiteeAIRerankModel(RerankModel):
label=I18nObject(en_US=model), label=I18nObject(en_US=model),
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
) )
return entity return entity

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 1048576 context_size: 1048576

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,6 +7,7 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 2097152

@ -7,9 +7,10 @@ features:
- vision - vision
- tool-call - tool-call
- stream-tool-call - stream-tool-call
- document
model_properties: model_properties:
mode: chat mode: chat
context_size: 2097152 context_size: 32767
parameter_rules: parameter_rules:
- name: temperature - name: temperature
use_template: temperature use_template: temperature

@ -0,0 +1,38 @@
model: gemini-exp-1121
label:
en_US: Gemini exp 1121
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,38 @@
model: learnlm-1.5-pro-experimental
label:
en_US: LearnLM 1.5 Pro Experimental
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -16,6 +16,7 @@ from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContentType, PromptMessageContentType,
@ -35,6 +36,21 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
class GoogleLargeLanguageModel(LargeLanguageModel): class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke( def _invoke(
@ -370,6 +386,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob) glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
return glm_content return glm_content
elif isinstance(message, AssistantPromptMessage): elif isinstance(message, AssistantPromptMessage):

@ -140,7 +140,7 @@ class GPUStackRerankModel(RerankModel):
label=I18nObject(en_US=model), label=I18nObject(en_US=model),
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
) )
return entity return entity

@ -34,3 +34,11 @@ model_credential_schema:
placeholder: placeholder:
zh_Hans: 在此输入Text Embedding Inference的服务器地址如 http://192.168.1.100:8080 zh_Hans: 在此输入Text Embedding Inference的服务器地址如 http://192.168.1.100:8080
en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080
- variable: api_key
label:
en_US: API Key
type: secret-input
required: false
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel):
server_url = server_url.removesuffix("/") server_url = server_url.removesuffix("/")
headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
try: try:
results = TeiHelper.invoke_rerank(server_url, query, docs) results = TeiHelper.invoke_rerank(server_url, query, docs, headers)
rerank_documents = [] rerank_documents = []
for result in results: for result in results:
@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel):
""" """
try: try:
server_url = credentials["server_url"] server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
if extra_args.model_type != "reranker": if extra_args.model_type != "reranker":
raise CredentialsValidateFailedError("Current model is not a rerank model") raise CredentialsValidateFailedError("Current model is not a rerank model")

@ -26,13 +26,15 @@ cache_lock = Lock()
class TeiHelper: class TeiHelper:
@staticmethod @staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: def get_tei_extra_parameter(
server_url: str, model_name: str, headers: Optional[dict] = None
) -> TeiModelExtraParameter:
TeiHelper._clean_cache() TeiHelper._clean_cache()
with cache_lock: with cache_lock:
if model_name not in cache: if model_name not in cache:
cache[model_name] = { cache[model_name] = {
"expires": time() + 300, "expires": time() + 300,
"value": TeiHelper._get_tei_extra_parameter(server_url), "value": TeiHelper._get_tei_extra_parameter(server_url, headers),
} }
return cache[model_name]["value"] return cache[model_name]["value"]
@ -47,7 +49,7 @@ class TeiHelper:
pass pass
@staticmethod @staticmethod
def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> TeiModelExtraParameter:
""" """
get tei model extra parameter like model_type, max_input_length, max_batch_requests get tei model extra parameter like model_type, max_input_length, max_batch_requests
""" """
@ -61,7 +63,7 @@ class TeiHelper:
session.mount("https://", HTTPAdapter(max_retries=3)) session.mount("https://", HTTPAdapter(max_retries=3))
try: try:
response = session.get(url, timeout=10) response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e: except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}") raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
if response.status_code != 200: if response.status_code != 200:
@ -86,7 +88,7 @@ class TeiHelper:
) )
@staticmethod @staticmethod
def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]: def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
""" """
Invoke tokenize endpoint Invoke tokenize endpoint
@ -114,15 +116,15 @@ class TeiHelper:
:param server_url: server url :param server_url: server url
:param texts: texts to tokenize :param texts: texts to tokenize
""" """
resp = httpx.post( url = f"{server_url}/tokenize"
f"{server_url}/tokenize", json_data = {"inputs": texts}
json={"inputs": texts}, resp = httpx.post(url, json=json_data, headers=headers)
)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()
@staticmethod @staticmethod
def invoke_embeddings(server_url: str, texts: list[str]) -> dict: def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
""" """
Invoke embeddings endpoint Invoke embeddings endpoint
@ -147,15 +149,14 @@ class TeiHelper:
:param texts: texts to embed :param texts: texts to embed
""" """
# Use OpenAI compatible API here, which has usage tracking # Use OpenAI compatible API here, which has usage tracking
resp = httpx.post( url = f"{server_url}/v1/embeddings"
f"{server_url}/v1/embeddings", json_data = {"input": texts}
json={"input": texts}, resp = httpx.post(url, json=json_data, headers=headers)
)
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()
@staticmethod @staticmethod
def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]: def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
""" """
Invoke rerank endpoint Invoke rerank endpoint
@ -173,10 +174,7 @@ class TeiHelper:
:param candidates: candidates to rerank :param candidates: candidates to rerank
""" """
params = {"query": query, "texts": docs, "return_text": True} params = {"query": query, "texts": docs, "return_text": True}
url = f"{server_url}/rerank"
response = httpx.post( response = httpx.post(url, json=params, headers=headers)
server_url + "/rerank",
json=params,
)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
server_url = server_url.removesuffix("/") server_url = server_url.removesuffix("/")
headers = {"Content-Type": "application/json"}
api_key = credentials["api_key"]
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# get model properties # get model properties
context_size = self._get_context_size(model, credentials) context_size = self._get_context_size(model, credentials)
max_chunks = self._get_max_chunks(model, credentials) max_chunks = self._get_max_chunks(model, credentials)
@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
used_tokens = 0 used_tokens = 0
# get tokenized results from TEI # get tokenized results from TEI
batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts) batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)): for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
# Check if the number of tokens is larger than the context size # Check if the number of tokens is larger than the context size
@ -97,7 +101,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
used_tokens = 0 used_tokens = 0
for i in _iter: for i in _iter:
iter_texts = inputs[i : i + max_chunks] iter_texts = inputs[i : i + max_chunks]
results = TeiHelper.invoke_embeddings(server_url, iter_texts) results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers)
embeddings = results["data"] embeddings = results["data"]
embeddings = [embedding["embedding"] for embedding in embeddings] embeddings = [embedding["embedding"] for embedding in embeddings]
batched_embeddings.extend(embeddings) batched_embeddings.extend(embeddings)
@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
server_url = server_url.removesuffix("/") server_url = server_url.removesuffix("/")
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) headers = {
"Authorization": f"Bearer {credentials.get('api_key')}",
}
batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers)
num_tokens = sum(len(tokens) for tokens in batch_tokens) num_tokens = sum(len(tokens) for tokens in batch_tokens)
return num_tokens return num_tokens
@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
""" """
try: try:
server_url = credentials["server_url"] server_url = credentials["server_url"]
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) headers = {"Content-Type": "application/json"}
api_key = credentials.get("api_key")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers)
print(extra_args) print(extra_args)
if extra_args.model_type != "embedding": if extra_args.model_type != "embedding":
raise CredentialsValidateFailedError("Current model is not a embedding model") raise CredentialsValidateFailedError("Current model is not a embedding model")

@ -128,7 +128,7 @@ class JinaRerankModel(RerankModel):
label=I18nObject(en_US=model), label=I18nObject(en_US=model),
model_type=ModelType.RERANK, model_type=ModelType.RERANK,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
) )
return entity return entity

@ -193,7 +193,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel):
label=I18nObject(en_US=model), label=I18nObject(en_US=model),
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))},
) )
return entity return entity

@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ( from core.model_runtime.entities.model_entities import (
@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials=credentials, credentials=credentials,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=model_parameters, model_parameters=model_parameters,
tools=tools,
stop=stop, stop=stop,
stream=stream, stream=stream,
user=user, user=user,
@ -153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
credentials: dict, credentials: dict,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
model_parameters: dict, model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None, stop: Optional[list[str]] = None,
stream: bool = True, stream: bool = True,
user: Optional[str] = None, user: Optional[str] = None,
@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if completion_type is LLMMode.CHAT: if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, "api/chat") endpoint_url = urljoin(endpoint_url, "api/chat")
data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
if tools:
data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools]
else: else:
endpoint_url = urljoin(endpoint_url, "api/generate") endpoint_url = urljoin(endpoint_url, "api/generate")
first_prompt_message = prompt_messages[0] first_prompt_message = prompt_messages[0]
@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
if stream: if stream:
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools)
def _handle_generate_response( def _handle_generate_response(
self, self,
@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
completion_type: LLMMode, completion_type: LLMMode,
response: requests.Response, response: requests.Response,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]],
) -> LLMResult: ) -> LLMResult:
""" """
Handle llm completion response Handle llm completion response
@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: llm result :return: llm result
""" """
response_json = response.json() response_json = response.json()
tool_calls = []
if completion_type is LLMMode.CHAT: if completion_type is LLMMode.CHAT:
message = response_json.get("message", {}) message = response_json.get("message", {})
response_content = message.get("content", "") response_content = message.get("content", "")
response_tool_calls = message.get("tool_calls", [])
tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls]
else: else:
response_content = response_json["response"] response_content = response_json["response"]
assistant_message = AssistantPromptMessage(content=response_content) assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls)
if "prompt_eval_count" in response_json and "eval_count" in response_json: if "prompt_eval_count" in response_json and "eval_count" in response_json:
# transform usage # transform usage
@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
chunk_index += 1 chunk_index += 1
def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict:
"""
Convert PromptMessageTool to dict for Ollama API
:param tool: tool
:return: tool dict
"""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
},
}
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
""" """
Convert PromptMessage to dict for Ollama API Convert PromptMessage to dict for Ollama API
:param message: prompt message
:return: message dict
""" """
if isinstance(message, UserPromptMessage): if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message) message = cast(UserPromptMessage, message)
@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
elif isinstance(message, SystemPromptMessage): elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message) message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content} message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {"role": "tool", "content": message.content}
else: else:
raise ValueError(f"Got unknown type {message}") raise ValueError(f"Got unknown type {message}")
@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
return num_tokens return num_tokens
def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall:
"""
Extract response tool call
"""
tool_call = None
if response_tool_call and "function" in response_tool_call:
# Convert arguments to JSON string if it's a dict
arguments = response_tool_call.get("function").get("arguments")
if isinstance(arguments, dict):
arguments = json.dumps(arguments)
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_tool_call.get("function").get("name"),
arguments=arguments,
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get("function").get("name"),
type="function",
function=function,
)
return tool_call
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
""" """
Get customizable model schema. Get customizable model schema.
@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
:return: model schema :return: model schema
""" """
extras = {} extras = {
"features": [],
}
if "vision_support" in credentials and credentials["vision_support"] == "true": if "vision_support" in credentials and credentials["vision_support"] == "true":
extras["features"] = [ModelFeature.VISION] extras["features"].append(ModelFeature.VISION)
if "function_call_support" in credentials and credentials["function_call_support"] == "true":
extras["features"].append(ModelFeature.TOOL_CALL)
extras["features"].append(ModelFeature.MULTI_TOOL_CALL)
entity = AIModelEntity( entity = AIModelEntity(
model=model, model=model,

@ -96,3 +96,22 @@ model_credential_schema:
label: label:
en_US: 'No' en_US: 'No'
zh_Hans: zh_Hans:
- variable: function_call_support
label:
zh_Hans: 是否支持函数调用
en_US: Function call support
show_on:
- variable: __model_type
value: llm
default: 'false'
type: radio
required: false
options:
- value: 'true'
label:
en_US: 'Yes'
zh_Hans:
- value: 'false'
label:
en_US: 'No'
zh_Hans:

@ -139,7 +139,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel):
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={ model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)),
ModelPropertyKey.MAX_CHUNKS: 1, ModelPropertyKey.MAX_CHUNKS: 1,
}, },
parameter_rules=[], parameter_rules=[],

@ -3,6 +3,7 @@
- gpt-4o - gpt-4o
- gpt-4o-2024-05-13 - gpt-4o-2024-05-13
- gpt-4o-2024-08-06 - gpt-4o-2024-08-06
- gpt-4o-2024-11-20
- chatgpt-4o-latest - chatgpt-4o-latest
- gpt-4o-mini - gpt-4o-mini
- gpt-4o-mini-2024-07-18 - gpt-4o-mini-2024-07-18

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save