Merge branch 'main' into feat/retry-single-step-debug

feat/retry-single-step-debug
Novice Lee 1 year ago
commit 8933dd85bf

@ -56,6 +56,12 @@ jobs:
- name: Run Tool - name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh run: poetry run -C api bash dev/pytest/pytest_tools.sh
- name: Run mypy
run: |
pushd api
poetry run python -m mypy --install-types --non-interactive .
popd
- name: Set up dotenvs - name: Set up dotenvs
run: | run: |
cp docker/.env.example docker/.env cp docker/.env.example docker/.env

@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage
# S3 Storage configuration # S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
S3_BUCKET_NAME=your-bucket-name S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key S3_SECRET_KEY=your-secret-key
@ -74,7 +74,7 @@ S3_REGION=your-region
# Azure Blob Storage configuration # Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=yout-container-name AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration # Aliyun oss Storage configuration
@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region
ALIYUN_OSS_PATH=your-path ALIYUN_OSS_PATH=your-path
# Google Storage configuration # Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# Tencent COS Storage configuration # Tencent COS Storage configuration

@ -67,7 +67,7 @@ ignore = [
"SIM105", # suppressible-exception "SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally "SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp "SIM108", # if-else-block-instead-of-if-exp
"SIM113", # eumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
] ]

@ -159,8 +159,7 @@ def migrate_annotation_vector_database():
try: try:
# get apps info # get apps info
apps = ( apps = (
db.session.query(App) App.query.filter(App.status == "normal")
.filter(App.status == "normal")
.order_by(App.created_at.desc()) .order_by(App.created_at.desc())
.paginate(page=page, per_page=50) .paginate(page=page, per_page=50)
) )
@ -285,8 +284,7 @@ def migrate_knowledge_vector_database():
while True: while True:
try: try:
datasets = ( datasets = (
db.session.query(Dataset) Dataset.query.filter(Dataset.indexing_technique == "high_quality")
.filter(Dataset.indexing_technique == "high_quality")
.order_by(Dataset.created_at.desc()) .order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50) .paginate(page=page, per_page=50)
) )
@ -450,7 +448,8 @@ def convert_to_agent_apps():
if app_id not in proceeded_app_ids: if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id) proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).filter(App.id == app_id).first()
apps.append(app) if app is not None:
apps.append(app)
if len(apps) == 0: if len(apps) == 0:
break break
@ -562,8 +561,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
new_password = secrets.token_urlsafe(16) new_password = secrets.token_urlsafe(16)
# register account # register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language,
create_workspace_required=False,
)
TenantService.create_owner_tenant_if_not_exist(account, name) TenantService.create_owner_tenant_if_not_exist(account, name)
click.echo( click.echo(
@ -583,7 +587,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green")) click.echo(click.style("Starting database migration.", fg="green"))
# run db migration # run db migration
import flask_migrate import flask_migrate # type: ignore
flask_migrate.upgrade() flask_migrate.upgrade()
@ -621,6 +625,10 @@ where sites.id is null limit 1000"""
try: try:
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).filter(App.id == app_id).first()
if not app:
print(f"App {app_id} not found")
continue
tenant = app.tenant tenant = app.tenant
if tenant: if tenant:
accounts = tenant.get_accounts() accounts = tenant.get_accounts()

@ -239,7 +239,6 @@ class HttpConfig(BaseSettings):
) )
@computed_field @computed_field
@property
def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]: def CONSOLE_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",") return self.inner_CONSOLE_CORS_ALLOW_ORIGINS.split(",")
@ -250,7 +249,6 @@ class HttpConfig(BaseSettings):
) )
@computed_field @computed_field
@property
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]: def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",") return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
@ -715,27 +713,27 @@ class PositionConfig(BaseSettings):
default="", default="",
) )
@computed_field @property
def POSITION_PROVIDER_PINS_LIST(self) -> list[str]: def POSITION_PROVIDER_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""] return [item.strip() for item in self.POSITION_PROVIDER_PINS.split(",") if item.strip() != ""]
@computed_field @property
def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]: def POSITION_PROVIDER_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_PROVIDER_INCLUDES.split(",") if item.strip() != ""}
@computed_field @property
def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]: def POSITION_PROVIDER_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_PROVIDER_EXCLUDES.split(",") if item.strip() != ""}
@computed_field @property
def POSITION_TOOL_PINS_LIST(self) -> list[str]: def POSITION_TOOL_PINS_LIST(self) -> list[str]:
return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""] return [item.strip() for item in self.POSITION_TOOL_PINS.split(",") if item.strip() != ""]
@computed_field @property
def POSITION_TOOL_INCLUDES_SET(self) -> set[str]: def POSITION_TOOL_INCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_TOOL_INCLUDES.split(",") if item.strip() != ""}
@computed_field @property
def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]:
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""} return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}

@ -130,7 +130,6 @@ class DatabaseConfig(BaseSettings):
) )
@computed_field @computed_field
@property
def SQLALCHEMY_DATABASE_URI(self) -> str: def SQLALCHEMY_DATABASE_URI(self) -> str:
db_extras = ( db_extras = (
f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS f"{self.DB_EXTRAS}&client_encoding={self.DB_CHARSET}" if self.DB_CHARSET else self.DB_EXTRAS
@ -168,7 +167,6 @@ class DatabaseConfig(BaseSettings):
) )
@computed_field @computed_field
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
return { return {
"pool_size": self.SQLALCHEMY_POOL_SIZE, "pool_size": self.SQLALCHEMY_POOL_SIZE,
@ -206,7 +204,6 @@ class CeleryConfig(DatabaseConfig):
) )
@computed_field @computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None: def CELERY_RESULT_BACKEND(self) -> str | None:
return ( return (
"db+{}".format(self.SQLALCHEMY_DATABASE_URI) "db+{}".format(self.SQLALCHEMY_DATABASE_URI)
@ -214,7 +211,6 @@ class CeleryConfig(DatabaseConfig):
else self.CELERY_BROKER_URL else self.CELERY_BROKER_URL
) )
@computed_field
@property @property
def BROKER_USE_SSL(self) -> bool: def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False

@ -4,6 +4,7 @@ import logging
import os import os
import threading import threading
import time import time
from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from .python_3x import http_request, makedirs_wrapper from .python_3x import http_request, makedirs_wrapper
@ -255,8 +256,8 @@ class ApolloClient:
logger.info("stopped, long_poll") logger.info("stopped, long_poll")
# add the need for endorsement to the header # add the need for endorsement to the header
def _sign_headers(self, url): def _sign_headers(self, url: str) -> Mapping[str, str]:
headers = {} headers: dict[str, str] = {}
if self.secret == "": if self.secret == "":
return headers return headers
uri = url[len(self.config_url) : len(url)] uri = url[len(self.config_url) : len(url)]

@ -1,8 +1,9 @@
import json import json
from collections.abc import Mapping
from models.model import AppMode from models.model import AppMode
default_app_templates = { default_app_templates: Mapping[AppMode, Mapping] = {
# workflow default mode # workflow default mode
AppMode.WORKFLOW: { AppMode.WORKFLOW: {
"app": { "app": {

@ -1,4 +1,4 @@
from flask_restful import fields from flask_restful import fields # type: ignore
parameters__system_parameters = { parameters__system_parameters = {
"image_file_size_limit": fields.Integer, "image_file_size_limit": fields.Integer,

@ -3,6 +3,25 @@ from flask import Blueprint
from libs.external_api import ExternalApi from libs.external_api import ExternalApi
from .app.app_import import AppImportApi, AppImportConfirmApi from .app.app_import import AppImportApi, AppImportConfirmApi
from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import (
ConversationApi,
ConversationListApi,
ConversationPinApi,
ConversationRenameApi,
ConversationUnPinApi,
)
from .explore.message import (
MessageFeedbackApi,
MessageListApi,
MessageMoreLikeThisApi,
MessageSuggestedQuestionApi,
)
from .explore.workflow import (
InstalledAppWorkflowRunApi,
InstalledAppWorkflowTaskStopApi,
)
from .files import FileApi, FilePreviewApi, FileSupportTypeApi from .files import FileApi, FilePreviewApi, FileSupportTypeApi
from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
@ -66,15 +85,81 @@ from .datasets import (
# Import explore controllers # Import explore controllers
from .explore import ( from .explore import (
audio,
completion,
conversation,
installed_app, installed_app,
message,
parameter, parameter,
recommended_app, recommended_app,
saved_message, saved_message,
workflow, )
# Explore Audio
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# Explore Completion
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)
# Explore Conversation
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)
# Explore Message
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)
# Explore Workflow
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
) )
# Import tag controllers # Import tag controllers

@ -1,7 +1,7 @@
from functools import wraps from functools import wraps
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config

@ -1,5 +1,7 @@
import flask_restful from typing import Any
from flask_login import current_user
import flask_restful # type: ignore
from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -35,14 +37,15 @@ def _get_resource(resource_id, tenant_id, resource_model):
class BaseApiKeyListResource(Resource): class BaseApiKeyListResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type = None resource_type: str | None = None
resource_model = None resource_model: Any = None
resource_id_field = None resource_id_field: str | None = None
token_prefix = None token_prefix: str | None = None
max_keys = 10 max_keys = 10
@marshal_with(api_key_list) @marshal_with(api_key_list)
def get(self, resource_id): def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = ( keys = (
@ -54,6 +57,7 @@ class BaseApiKeyListResource(Resource):
@marshal_with(api_key_fields) @marshal_with(api_key_fields)
def post(self, resource_id): def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
if not current_user.is_editor: if not current_user.is_editor:
@ -86,11 +90,12 @@ class BaseApiKeyListResource(Resource):
class BaseApiKeyResource(Resource): class BaseApiKeyResource(Resource):
method_decorators = [account_initialization_required, login_required, setup_required] method_decorators = [account_initialization_required, login_required, setup_required]
resource_type = None resource_type: str | None = None
resource_model = None resource_model: Any = None
resource_id_field = None resource_id_field: str | None = None
def delete(self, resource_id, api_key_id): def delete(self, resource_id, api_key_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id) resource_id = str(resource_id)
api_key_id = str(api_key_id) api_key_id = str(api_key_id)
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model

@ -1,6 +1,6 @@
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
@ -110,7 +110,7 @@ class AnnotationListApi(Resource):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
keyword = request.args.get("keyword", default=None, type=str) keyword = request.args.get("keyword", default="", type=str)
app_id = str(app_id) app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)

@ -1,8 +1,8 @@
import uuid import uuid
from typing import cast from typing import cast
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort from werkzeug.exceptions import BadRequest, Forbidden, abort

@ -1,7 +1,7 @@
from typing import cast from typing import cast
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden

@ -1,7 +1,7 @@
import logging import logging
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services

@ -1,7 +1,7 @@
import logging import logging
import flask_login import flask_login # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services

@ -1,9 +1,9 @@
from datetime import UTC, datetime from datetime import UTC, datetime
import pytz import pytz # pip install pytz
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from sqlalchemy import func, or_ from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -77,8 +77,9 @@ class CompletionConversationApi(Resource):
query = query.where(Conversation.created_at < end_datetime_utc) query = query.where(Conversation.created_at < end_datetime_utc)
# FIXME, the type ignore in this file
if args["annotation_status"] == "annotated": if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args["annotation_status"] == "not_annotated":
@ -222,7 +223,7 @@ class ChatConversationApi(Resource):
query = query.where(Conversation.created_at <= end_datetime_utc) query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated": if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
) )
elif args["annotation_status"] == "not_annotated": elif args["annotation_status"] == "not_annotated":
@ -234,7 +235,7 @@ class ChatConversationApi(Resource):
if args["message_count_gte"] and args["message_count_gte"] >= 1: if args["message_count_gte"] and args["message_count_gte"] >= 1:
query = ( query = (
query.options(joinedload(Conversation.messages)) query.options(joinedload(Conversation.messages)) # type: ignore
.join(Message, Message.conversation_id == Conversation.id) .join(Message, Message.conversation_id == Conversation.id)
.group_by(Conversation.id) .group_by(Conversation.id)
.having(func.count(Message.id) >= args["message_count_gte"]) .having(func.count(Message.id) >= args["message_count_gte"])

@ -1,4 +1,4 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session

@ -1,7 +1,7 @@
import os import os
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (

@ -1,8 +1,8 @@
import logging import logging
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api from controllers.console import api

@ -1,8 +1,9 @@
import json import json
from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
@ -26,7 +27,9 @@ class ModelConfigResource(Resource):
"""Modify app model config""" """Modify app model config"""
# validate config # validate config
model_configuration = AppModelConfigService.validate_configuration( model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) tenant_id=current_user.current_tenant_id,
config=cast(dict, request.json),
app_mode=AppMode.value_of(app_model.mode),
) )
new_app_model_config = AppModelConfig( new_app_model_config = AppModelConfig(
@ -38,9 +41,11 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config # get original app model config
original_app_model_config: AppModelConfig = ( original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
) )
if original_app_model_config is None:
raise ValueError("Original app model config not found")
agent_mode = original_app_model_config.agent_mode_dict agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input # decrypt agent tool parameters if it's secret-input
parameter_map = {} parameter_map = {}

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequest
from controllers.console import api from controllers.console import api

@ -1,7 +1,7 @@
from datetime import UTC, datetime from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language from constants.languages import supported_language
@ -50,7 +50,7 @@ class AppSite(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() site = Site.query.filter(Site.app_id == app_model.id).one_or_404()
for attr_name in [ for attr_name in [
"title", "title",

@ -3,8 +3,8 @@ from decimal import Decimal
import pytz import pytz
from flask import jsonify from flask import jsonify
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model

@ -2,7 +2,7 @@ import json
import logging import logging
from flask import abort, request from flask import abort, request
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services

@ -1,5 +1,5 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model

@ -1,5 +1,5 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model

@ -3,8 +3,8 @@ from decimal import Decimal
import pytz import pytz
from flask import jsonify from flask import jsonify
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model

@ -8,7 +8,7 @@ from libs.login import current_user
from models import App, AppMode from models import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):

@ -1,14 +1,14 @@
import datetime import datetime
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import StrLen, email, extract_remote_ip, timezone from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus, Tenant from models.account import AccountStatus
from services.account_service import AccountService, RegisterService from services.account_service import AccountService, RegisterService
@ -27,7 +27,7 @@ class ActivateCheckApi(Resource):
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
if invitation: if invitation:
data = invitation.get("data", {}) data = invitation.get("data", {})
tenant: Tenant = invitation.get("tenant", None) tenant = invitation.get("tenant", None)
workspace_name = tenant.name if tenant else None workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None invitee_email = data.get("email") if data else None

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api

@ -2,8 +2,8 @@ import logging
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from configs import dify_config from configs import dify_config
@ -17,8 +17,8 @@ from ..wraps import account_initialization_required, setup_required
def get_oauth_providers(): def get_oauth_providers():
with current_app.app_context(): with current_app.app_context():
notion_oauth = NotionOAuth( notion_oauth = NotionOAuth(
client_id=dify_config.NOTION_CLIENT_ID, client_id=dify_config.NOTION_CLIENT_ID or "",
client_secret=dify_config.NOTION_CLIENT_SECRET, client_secret=dify_config.NOTION_CLIENT_SECRET or "",
redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion",
) )

@ -2,7 +2,7 @@ import base64
import secrets import secrets
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
@ -122,8 +122,8 @@ class ForgotPasswordResetApi(Resource):
else: else:
try: try:
account = AccountService.create_account_and_tenant( account = AccountService.create_account_and_tenant(
email=reset_data.get("email"), email=reset_data.get("email", ""),
name=reset_data.get("email"), name=reset_data.get("email", ""),
password=password_confirm, password=password_confirm,
interface_language=languages[0], interface_language=languages[0],
) )

@ -1,8 +1,8 @@
from typing import cast from typing import cast
import flask_login import flask_login # type: ignore
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
import services import services
from constants.languages import languages from constants.languages import languages

@ -4,7 +4,7 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource # type: ignore
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
@ -77,7 +77,8 @@ class OAuthCallback(Resource):
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") error_text = e.response.text if e.response else str(e)
logging.exception(f"An error occurred during the OAuth process with {provider}: {error_text}")
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400
if invite_token and RegisterService.is_valid_invite_token(invite_token): if invite_token and RegisterService.is_valid_invite_token(invite_token):
@ -129,7 +130,7 @@ class OAuthCallback(Resource):
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
account = Account.get_by_openid(provider, user_info.id) account: Optional[Account] = Account.get_by_openid(provider, user_info.id)
if not account: if not account:
account = Account.query.filter_by(email=user_info.email).first() account = Account.query.filter_by(email=user_info.email).first()

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required

@ -2,8 +2,8 @@ import datetime
import json import json
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource):
args["doc_form"], args["doc_form"],
args["doc_language"], args["doc_language"],
) )
return response, 200 return response.model_dump(), 200
class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDatasetSyncApi(Resource):

@ -1,7 +1,7 @@
import flask_restful import flask_restful # type: ignore
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore # type: ignore
from flask_restful import Resource, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response, 200 return response.model_dump(), 200
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
}, 200 }, 200
class DatasetAutoDisableLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")

@ -1,12 +1,13 @@
import logging import logging
from argparse import ArgumentTypeError from argparse import ArgumentTypeError
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal, marshal_with, reqparse from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import asc, desc from sqlalchemy import asc, desc
from transformers.hf_argparser import string_to_bool from transformers.hf_argparser import string_to_bool # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -51,6 +52,7 @@ from fields.document_fields import (
from libs.login import login_required from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task
@ -254,20 +256,22 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "doc_language", type=str, default="English", required=False, nullable=False, location="json"
) )
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
@ -277,6 +281,25 @@ class DatasetDocumentListApi(Resource):
return {"documents": documents, "batch": batch} return {"documents": documents, "batch": batch}
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
try:
document_ids = request.args.getlist("document_id")
DocumentService.delete_documents(dataset, document_ids)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@setup_required @setup_required
@ -312,9 +335,9 @@ class DatasetInitApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if args["indexing_technique"] == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try: try:
model_manager = ModelManager() model_manager = ModelManager()
@ -333,11 +356,11 @@ class DatasetInitApi(Resource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, document_data=args, account=current_user tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -390,7 +413,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_user.current_tenant_id,
[extract_setting], [extract_setting],
data_process_rule_dict, data_process_rule_dict,
@ -398,6 +421,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
"English", "English",
dataset_id, dataset_id,
) )
return estimate_response.model_dump(), 200
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
@ -408,7 +432,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response return response, 200
class DocumentBatchIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingEstimateApi(DocumentResource):
@ -419,9 +443,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
batch = str(batch) batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
if not documents: if not documents:
return response return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict()
info_list = [] info_list = []
@ -499,6 +522,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"English", "English",
dataset_id, dataset_id,
) )
return response.model_dump(), 200
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
@ -508,7 +532,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response
class DocumentBatchIndexingStatusApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource):
@ -581,7 +604,8 @@ class DocumentDetailApi(DocumentResource):
if metadata == "only": if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
elif metadata == "without": elif metadata == "without":
process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -589,7 +613,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"data_source_info": data_source_info, "data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules, "dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
@ -612,7 +637,8 @@ class DocumentDetailApi(DocumentResource):
"doc_language": document.doc_language, "doc_language": document.doc_language,
} }
else: else:
process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -620,7 +646,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"data_source_info": data_source_info, "data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules, "dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
@ -733,8 +760,7 @@ class DocumentMetadataApi(DocumentResource):
if not isinstance(doc_metadata, dict): if not isinstance(doc_metadata, dict):
raise ValueError("doc_metadata must be a dictionary.") raise ValueError("doc_metadata must be a dictionary.")
metadata_schema: dict = cast(dict, DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type])
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
document.doc_metadata = {} document.doc_metadata = {}
if doc_type == "others": if doc_type == "others":
@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -774,84 +799,79 @@ class DocumentStatusApi(DocumentResource):
# check user's permission # check user's permission
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
document = self.get_document(dataset_id, document_id) document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)
indexing_cache_key = "document_{}_indexing".format(document.id) indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later") raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable": if action == "enable":
if document.enabled: if document.enabled:
raise InvalidActionError("Document already enabled.") continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
document.enabled = True # Set cache to prevent indexing the same document multiple times
document.disabled_at = None redis_client.setex(indexing_cache_key, 600, 1)
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times add_document_to_index_task.delay(document_id)
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id) elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue
return {"result": "success"}, 200 document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
elif action == "disable": # Set cache to prevent indexing the same document multiple times
if not document.completed_at or document.indexing_status != "completed": redis_client.setex(indexing_cache_key, 600, 1)
raise InvalidActionError("Document is not completed.")
if not document.enabled:
raise InvalidActionError("Document already disabled.")
document.enabled = False remove_document_from_index_task.delay(document_id)
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times elif action == "archive":
redis_client.setex(indexing_cache_key, 600, 1) if document.archived:
continue
remove_document_from_index_task.delay(document_id) document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return {"result": "success"}, 200 if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
elif action == "archive": remove_document_from_index_task.delay(document_id)
if document.archived:
raise InvalidActionError("Document already archived.")
document.archived = True elif action == "un_archive":
document.archived_at = datetime.now(UTC).replace(tzinfo=None) if not document.archived:
document.archived_by = current_user.id continue
document.updated_at = datetime.now(UTC).replace(tzinfo=None) document.archived = False
db.session.commit() document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if document.enabled:
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1) redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id) add_document_to_index_task.delay(document_id)
return {"result": "success"}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError("Document is not archived.")
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id) else:
raise InvalidActionError()
return {"result": "success"}, 200 return {"result": "success"}, 200
else:
raise InvalidActionError()
class DocumentPauseApi(DocumentResource): class DocumentPauseApi(DocumentResource):
@ -1022,7 +1042,7 @@ api.add_resource(
) )
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>") api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause") api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume") api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")

@ -1,16 +1,21 @@
import uuid import uuid
from datetime import UTC, datetime
import pandas as pd import pandas as pd
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal, reqparse from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.datasets.error import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
InvalidActionError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,
@ -20,15 +25,15 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import login_required
from models import DocumentSegment from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args() args = parser.parse_args()
last_id = args["last_id"] page = args["page"]
limit = min(args["limit"], 100) limit = min(args["limit"], 100)
status_list = args["status"] status_list = args["status"]
hit_count_gte = args["hit_count_gte"] hit_count_gte = args["hit_count_gte"]
@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = DocumentSegment.query.filter( query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
) ).order_by(DocumentSegment.position.asc())
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position)
else:
return {"data": [], "has_more": False, "limit": limit}, 200
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False) query = query.filter(DocumentSegment.enabled == False)
total = query.count() segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
has_more = False
if len(segments) > limit:
has_more = True
segments = segments[:-1]
return { response = {
"data": marshal(segments, segment_fields), "data": marshal(segments.items, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
"limit": limit, "limit": limit,
"total": total, "total": segments.total,
}, 200 "total_pages": segments.pages,
"page": page,
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource): class DatasetDocumentSegmentApi(Resource):
@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, segment_id, action): def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
segment_ids = request.args.getlist("segment_id")
segment = DocumentSegment.query.filter( document_indexing_cache_key = "document_{}_indexing".format(document.id)
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key) cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later") raise InvalidActionError("Document is being indexed, please try again later")
try:
indexing_cache_key = "segment_{}_indexing".format(segment.id) SegmentService.update_segments_status(segment_ids, action, dataset, document)
cache_result = redis_client.get(indexing_cache_key) except Exception as e:
if cache_result is not None: raise InvalidActionError(str(e))
raise InvalidActionError("Segment is being indexed, please try again later") return {"result": "success"}, 200
if action == "enable":
if segment.enabled:
raise InvalidActionError("Segment is already enabled.")
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
enable_segment_to_index_task.delay(segment.id)
return {"result": "success"}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
segment.enabled = False
segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
disable_segment_from_index_task.delay(segment.id)
return {"result": "success"}, 200
else:
raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource): class DatasetDocumentSegmentAddApi(Resource):
@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource( api.add_resource(
DatasetDocumentSegmentUpdateApi, DatasetDocumentSegmentUpdateApi,
@ -424,3 +651,11 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>", "/datasets/batch_import_status/<uuid:job_id>",
) )
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)

@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error" error_code = "indexing_estimate_error"
description = "Knowledge indexing estimate failed: {message}" description = "Knowledge indexing estimate failed: {message}"
code = 500 code = 500
class ChildChunkIndexingError(BaseHTTPException):
error_code = "child_chunk_indexing_error"
description = "Create child chunk index failed: {message}"
code = 500
class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500

@ -1,6 +1,6 @@
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal, reqparse from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services import services

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase

@ -1,7 +1,7 @@
import logging import logging
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service import services.dataset_service

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console import api from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.datasets.error import WebsiteCrawlError

@ -4,7 +4,6 @@ from flask import request
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
AudioTooLargeError, AudioTooLargeError,
@ -67,7 +66,7 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource): class ChatTextApi(InstalledAppResource):
def post(self, installed_app): def post(self, installed_app):
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
app_model = installed_app.app app_model = installed_app.app
try: try:
@ -118,9 +117,3 @@ class ChatTextApi(InstalledAppResource):
except Exception as e: except Exception as e:
logging.exception("internal server error.") logging.exception("internal server error.")
raise InternalServerError() raise InternalServerError()
api.add_resource(ChatAudioApi, "/installed-apps/<uuid:installed_app_id>/audio-to-text", endpoint="installed_app_audio")
api.add_resource(ChatTextApi, "/installed-apps/<uuid:installed_app_id>/text-to-audio", endpoint="installed_app_text")
# api.add_resource(ChatTextApiWithMessageId, '/installed-apps/<uuid:installed_app_id>/text-to-audio/message-id',
# endpoint='installed_app_text_with_message_id')

@ -1,12 +1,11 @@
import logging import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
AppUnavailableError, AppUnavailableError,
CompletionRequestError, CompletionRequestError,
@ -147,21 +146,3 @@ class ChatStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"}, 200 return {"result": "success"}, 200
api.add_resource(
CompletionApi, "/installed-apps/<uuid:installed_app_id>/completion-messages", endpoint="installed_app_completion"
)
api.add_resource(
CompletionStopApi,
"/installed-apps/<uuid:installed_app_id>/completion-messages/<string:task_id>/stop",
endpoint="installed_app_stop_completion",
)
api.add_resource(
ChatApi, "/installed-apps/<uuid:installed_app_id>/chat-messages", endpoint="installed_app_chat_completion"
)
api.add_resource(
ChatStopApi,
"/installed-apps/<uuid:installed_app_id>/chat-messages/<string:task_id>/stop",
endpoint="installed_app_stop_chat_completion",
)

@ -1,10 +1,9 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -118,28 +117,3 @@ class ConversationUnPinApi(InstalledAppResource):
WebConversationService.unpin(app_model, conversation_id, current_user) WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"} return {"result": "success"}
api.add_resource(
ConversationRenameApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/name",
endpoint="installed_app_conversation_rename",
)
api.add_resource(
ConversationListApi, "/installed-apps/<uuid:installed_app_id>/conversations", endpoint="installed_app_conversations"
)
api.add_resource(
ConversationApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation",
)
api.add_resource(
ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
endpoint="installed_app_conversation_pin",
)
api.add_resource(
ConversationUnPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/unpin",
endpoint="installed_app_conversation_unpin",
)

@ -1,8 +1,9 @@
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, inputs, marshal_with, reqparse from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy import and_ from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -34,7 +35,7 @@ class InstalledAppsListApi(Resource):
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [ installed_app_list: list[dict[str, Any]] = [
{ {
"id": installed_app.id, "id": installed_app.id,
"app": installed_app.app, "app": installed_app.app,
@ -47,7 +48,7 @@ class InstalledAppsListApi(Resource):
for installed_app in installed_apps for installed_app in installed_apps
if installed_app.app is not None if installed_app.app is not None
] ]
installed_apps.sort( installed_app_list.sort(
key=lambda app: ( key=lambda app: (
-app["is_pinned"], -app["is_pinned"],
app["last_used_at"] is None, app["last_used_at"] is None,
@ -55,7 +56,7 @@ class InstalledAppsListApi(Resource):
) )
) )
return {"installed_apps": installed_apps} return {"installed_apps": installed_app_list}
@login_required @login_required
@account_initialization_required @account_initialization_required

@ -1,12 +1,11 @@
import logging import logging
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
AppMoreLikeThisDisabledError, AppMoreLikeThisDisabledError,
CompletionRequestError, CompletionRequestError,
@ -70,7 +69,7 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) MessageService.create_feedback(app_model, message_id, current_user, args.get("rating"), args.get("content"))
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
@ -153,21 +152,3 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise InternalServerError() raise InternalServerError()
return {"data": questions} return {"data": questions}
api.add_resource(MessageListApi, "/installed-apps/<uuid:installed_app_id>/messages", endpoint="installed_app_messages")
api.add_resource(
MessageFeedbackApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/feedbacks",
endpoint="installed_app_message_feedback",
)
api.add_resource(
MessageMoreLikeThisApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/more-like-this",
endpoint="installed_app_more_like_this",
)
api.add_resource(
MessageSuggestedQuestionApi,
"/installed-apps/<uuid:installed_app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="installed_app_suggested_question",
)

@ -1,4 +1,4 @@
from flask_restful import marshal_with from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers from controllers.common import helpers as controller_helpers

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api

@ -1,6 +1,6 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import fields, marshal_with, reqparse from flask_restful import fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api

@ -1,9 +1,8 @@
import logging import logging
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.console import api
from controllers.console.app.error import ( from controllers.console.app.error import (
CompletionRequestError, CompletionRequestError,
ProviderModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError,
@ -73,9 +72,3 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource):
AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id)
return {"result": "success"} return {"result": "success"}
api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps/<uuid:installed_app_id>/workflows/run")
api.add_resource(
InstalledAppWorkflowTaskStopApi, "/installed-apps/<uuid:installed_app_id>/workflows/tasks/<string:task_id>/stop"
)

@ -1,7 +1,7 @@
from functools import wraps from functools import wraps
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from controllers.console import api from controllers.console import api

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from libs.login import login_required from libs.login import login_required
from services.feature_service import FeatureService from services.feature_service import FeatureService

@ -1,6 +1,8 @@
from typing import Literal
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
import services import services
@ -48,7 +50,8 @@ class FileApi(Resource):
@cloud_edition_billing_resource_check("documents") @cloud_edition_billing_resource_check("documents")
def post(self): def post(self):
file = request.files["file"] file = request.files["file"]
source = request.form.get("source") source_str = request.form.get("source")
source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
if "file" not in request.files: if "file" not in request.files:
raise NoFileUploadedError() raise NoFileUploadedError()

@ -1,7 +1,7 @@
import os import os
from flask import session from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config from configs import dify_config
from libs.helper import StrLen from libs.helper import StrLen

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.console import api from controllers.console import api

@ -2,8 +2,8 @@ import urllib.parse
from typing import cast from typing import cast
import httpx import httpx
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
import services import services
from controllers.common import helpers from controllers.common import helpers

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from configs import dify_config from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip from libs.helper import StrLen, email, extract_remote_ip

@ -1,6 +1,6 @@
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
@ -23,7 +23,7 @@ class TagListApi(Resource):
@account_initialization_required @account_initialization_required
@marshal_with(tag_fields) @marshal_with(tag_fields)
def get(self): def get(self):
tag_type = request.args.get("type", type=str) tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str) keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword)

@ -2,7 +2,7 @@ import json
import logging import logging
import requests import requests
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from packaging import version from packaging import version
from configs import dify_config from configs import dify_config

@ -2,8 +2,8 @@ import datetime
import pytz import pytz
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from configs import dify_config from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
@ -37,7 +37,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
result = True result = True
error = None error = ""
try: try:
model_load_balancing_service.validate_load_balancing_credentials( model_load_balancing_service.validate_load_balancing_credentials(
@ -86,7 +86,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
model_load_balancing_service = ModelLoadBalancingService() model_load_balancing_service = ModelLoadBalancingService()
result = True result = True
error = None error = ""
try: try:
model_load_balancing_service.validate_load_balancing_credentials( model_load_balancing_service.validate_load_balancing_credentials(

@ -1,7 +1,7 @@
from urllib import parse from urllib import parse
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, abort, marshal_with, reqparse from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore
import services import services
from configs import dify_config from configs import dify_config
@ -89,19 +89,19 @@ class MemberCancelInviteApi(Resource):
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
member = db.session.query(Account).filter(Account.id == str(member_id)).first() member = db.session.query(Account).filter(Account.id == str(member_id)).first()
if not member: if member is None:
abort(404) abort(404)
else:
try: try:
TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user)
except services.errors.account.CannotOperateSelfError as e: except services.errors.account.CannotOperateSelfError as e:
return {"code": "cannot-operate-self", "message": str(e)}, 400 return {"code": "cannot-operate-self", "message": str(e)}, 400
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
return {"code": "forbidden", "message": str(e)}, 403 return {"code": "forbidden", "message": str(e)}, 403
except services.errors.account.MemberNotInTenantError as e: except services.errors.account.MemberNotInTenantError as e:
return {"code": "member-not-found", "message": str(e)}, 404 return {"code": "member-not-found", "message": str(e)}, 404
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"}, 204 return {"result": "success"}, 204
@ -122,10 +122,11 @@ class MemberUpdateRoleApi(Resource):
return {"code": "invalid-role", "message": "Invalid role"}, 400 return {"code": "invalid-role", "message": "Invalid role"}, 400
member = db.session.get(Account, str(member_id)) member = db.session.get(Account, str(member_id))
if not member: if member:
abort(404) abort(404)
try: try:
assert member is not None, "Member not found"
TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user) TenantService.update_member_role(current_user.current_tenant, member, new_role, current_user)
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))

@ -1,8 +1,8 @@
import io import io
from flask import send_file from flask import send_file
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
@ -66,7 +66,7 @@ class ModelProviderValidateApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
result = True result = True
error = None error = ""
try: try:
model_provider_service.provider_credentials_validate( model_provider_service.provider_credentials_validate(
@ -132,7 +132,8 @@ class ModelProviderIconApi(Resource):
icon_type=icon_type, icon_type=icon_type,
lang=lang, lang=lang,
) )
if icon is None:
raise ValueError(f"icon not found for provider {provider}, icon_type {icon_type}, lang {lang}")
return send_file(io.BytesIO(icon), mimetype=mimetype) return send_file(io.BytesIO(icon), mimetype=mimetype)

@ -1,7 +1,7 @@
import logging import logging
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
from controllers.console import api from controllers.console import api
@ -308,7 +308,7 @@ class ModelProviderModelValidateApi(Resource):
model_provider_service = ModelProviderService() model_provider_service = ModelProviderService()
result = True result = True
error = None error = ""
try: try:
model_provider_service.model_credentials_validate( model_provider_service.model_credentials_validate(

@ -1,8 +1,8 @@
import io import io
from flask import send_file from flask import send_file
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden

@ -1,8 +1,8 @@
import logging import logging
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
import services import services
@ -82,11 +82,7 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
tenants = ( tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate(page=args["page"], per_page=args["limit"])
db.session.query(Tenant)
.order_by(Tenant.created_at.desc())
.paginate(page=args["page"], per_page=args["limit"])
)
has_more = False has_more = False
if len(tenants.items) == args["limit"]: if len(tenants.items) == args["limit"]:
@ -151,6 +147,8 @@ class SwitchWorkspaceApi(Resource):
raise AccountNotLinkTenantError("Account not link tenant") raise AccountNotLinkTenantError("Account not link tenant")
new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant
if new_tenant is None:
raise ValueError("Tenant not found")
return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)}
@ -166,7 +164,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404()
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"], "remove_webapp_brand": args["remove_webapp_brand"],

@ -3,7 +3,7 @@ import os
from functools import wraps from functools import wraps
from flask import abort, request from flask import abort, request
from flask_login import current_user from flask_login import current_user # type: ignore
from configs import dify_config from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError from controllers.console.workspace.error import AccountNotInitializedError
@ -121,8 +121,8 @@ def cloud_utm_record(view):
utm_info = request.cookies.get("utm_info") utm_info = request.cookies.get("utm_info")
if utm_info: if utm_info:
utm_info = json.loads(utm_info) utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_user.current_tenant_id, utm_info) OperationService.record_utm(current_user.current_tenant_id, utm_info_dict)
except Exception as e: except Exception as e:
pass pass
return view(*args, **kwargs) return view(*args, **kwargs)

@ -1,5 +1,5 @@
from flask import Response, request from flask import Response, request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services import services

@ -1,5 +1,5 @@
from flask import Response from flask import Response
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import api from controllers.files import api

@ -1,4 +1,4 @@
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from controllers.console.wraps import setup_required from controllers.console.wraps import setup_required
from controllers.inner_api import api from controllers.inner_api import api

@ -45,14 +45,14 @@ def inner_api_user_auth(view):
if " " in user_id: if " " in user_id:
user_id = user_id.split(" ")[1] user_id = user_id.split(" ")[1]
inner_api_key = request.headers.get("X-Inner-Api-Key") inner_api_key = request.headers.get("X-Inner-Api-Key", "")
data_to_sign = f"DIFY {user_id}" data_to_sign = f"DIFY {user_id}"
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
signature = b64encode(signature.digest()).decode("utf-8") signature_base64 = b64encode(signature.digest()).decode("utf-8")
if signature != token: if signature_base64 != token:
return view(*args, **kwargs) return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()

@ -1,4 +1,4 @@
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers from controllers.common import helpers as controller_helpers

@ -1,7 +1,7 @@
import logging import logging
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
import services import services
@ -83,7 +83,7 @@ class TextApi(Resource):
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:

@ -1,6 +1,6 @@
import logging import logging
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services

@ -1,5 +1,5 @@
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import Resource, marshal_with from flask_restful import Resource, marshal_with # type: ignore
import services import services
from controllers.common.errors import FilenameNotExistsError from controllers.common.errors import FilenameNotExistsError

@ -1,7 +1,7 @@
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services import services
@ -108,7 +108,7 @@ class MessageFeedbackApi(Resource):
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) MessageService.create_feedback(app_model, message_id, end_user, args.get("rating"), args.get("content"))
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -1,7 +1,7 @@
import logging import logging
from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from controllers.service_api import api from controllers.service_api import api

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service import services.dataset_service

@ -1,7 +1,7 @@
import json import json
from flask import request from flask import request
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse # type: ignore
from sqlalchemy import desc from sqlalchemy import desc
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -22,6 +22,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
@ -67,13 +68,14 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
} }
args["data_source"] = data_source args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=current_user, account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -122,12 +124,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args) knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=current_user, account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -186,12 +189,13 @@ class DocumentAddByFileApi(DatasetApiResource):
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
DocumentService.document_create_args_validate(args) knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -245,12 +249,14 @@ class DocumentUpdateByFileApi(DatasetApiResource):
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",

@ -1,5 +1,5 @@
from flask_login import current_user from flask_login import current_user # type: ignore
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.service_api import api from controllers.service_api import api
@ -16,6 +16,7 @@ from extensions.ext_database import db
from fields.segment_fields import segment_fields from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
class SegmentApi(DatasetApiResource): class SegmentApi(DatasetApiResource):
@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args["segment"], document) SegmentService.segment_create_args_validate(args["segment"], document)
segment = SegmentService.update_segment(args["segment"], segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from configs import dify_config from configs import dify_config
from controllers.service_api import api from controllers.service_api import api

@ -5,8 +5,8 @@ from functools import wraps
from typing import Optional from typing import Optional
from flask import current_app, request from flask import current_app, request
from flask_login import user_logged_in from flask_login import user_logged_in # type: ignore
from flask_restful import Resource from flask_restful import Resource # type: ignore
from pydantic import BaseModel from pydantic import BaseModel
from werkzeug.exceptions import Forbidden, Unauthorized from werkzeug.exceptions import Forbidden, Unauthorized
@ -49,6 +49,8 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
raise Forbidden("The app's API service has been disabled.") raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("The workspace's status is archived.") raise Forbidden("The workspace's status is archived.")
@ -154,8 +156,8 @@ def validate_dataset_token(view=None):
# Login admin # Login admin
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
current_app.login_manager._update_request_context_with_user(account) current_app.login_manager._update_request_context_with_user(account) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=_get_user()) user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
else: else:
raise Unauthorized("Tenant owner account does not exist.") raise Unauthorized("Tenant owner account does not exist.")
else: else:

@ -1,4 +1,4 @@
from flask_restful import marshal_with from flask_restful import marshal_with # type: ignore
from controllers.common import fields from controllers.common import fields
from controllers.common import helpers as controller_helpers from controllers.common import helpers as controller_helpers

@ -65,7 +65,7 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource): class TextApi(WebApiResource):
def post(self, app_model: App, end_user): def post(self, app_model: App, end_user):
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
try: try:
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -82,7 +82,7 @@ class TextApi(WebApiResource):
and app_model.workflow and app_model.workflow
and app_model.workflow.features_dict and app_model.workflow.features_dict
): ):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech") text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice") voice = args.get("voice") or text_to_speech.get("voice")
else: else:
try: try:

@ -1,6 +1,6 @@
import logging import logging
from flask_restful import reqparse from flask_restful import reqparse # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services

@ -1,5 +1,5 @@
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound

@ -1,4 +1,4 @@
from flask_restful import Resource from flask_restful import Resource # type: ignore
from controllers.web import api from controllers.web import api
from services.feature_service import FeatureService from services.feature_service import FeatureService

@ -1,5 +1,5 @@
from flask import request from flask import request
from flask_restful import marshal_with from flask_restful import marshal_with # type: ignore
import services import services
from controllers.common.errors import FilenameNotExistsError from controllers.common.errors import FilenameNotExistsError
@ -33,7 +33,7 @@ class FileApi(WebApiResource):
content=file.read(), content=file.read(),
mimetype=file.mimetype, mimetype=file.mimetype,
user=end_user, user=end_user,
source=source, source="datasets" if source == "datasets" else None,
) )
except services.errors.file.FileTooLargeError as file_too_large_error: except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description) raise FileTooLargeError(file_too_large_error.description)

@ -1,7 +1,7 @@
import logging import logging
from flask_restful import fields, marshal_with, reqparse from flask_restful import fields, marshal_with, reqparse # type: ignore
from flask_restful.inputs import int_range from flask_restful.inputs import int_range # type: ignore
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
@ -105,10 +105,17 @@ class MessageFeedbackApi(WebApiResource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json", default=None)
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
)
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -1,7 +1,7 @@
import uuid import uuid
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource # type: ignore
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api from controllers.web import api

@ -1,7 +1,7 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restful import marshal_with, reqparse from flask_restful import marshal_with, reqparse # type: ignore
import services import services
from controllers.common import helpers from controllers.common import helpers

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save