Merge branch 'main' into feat/plugins
commit
9a242bcac9
@ -0,0 +1,213 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||||
|
from gevent import monkey
|
||||||
|
|
||||||
|
monkey.patch_all()
|
||||||
|
|
||||||
|
import grpc.experimental.gevent
|
||||||
|
|
||||||
|
grpc.experimental.gevent.init_gevent()
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
from flask import Flask, Response, request
|
||||||
|
from flask_cors import CORS
|
||||||
|
from werkzeug.exceptions import Unauthorized
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from commands import register_commands
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions import (
|
||||||
|
ext_celery,
|
||||||
|
ext_code_based_extension,
|
||||||
|
ext_compress,
|
||||||
|
ext_database,
|
||||||
|
ext_hosting_provider,
|
||||||
|
ext_login,
|
||||||
|
ext_mail,
|
||||||
|
ext_migrate,
|
||||||
|
ext_proxy_fix,
|
||||||
|
ext_redis,
|
||||||
|
ext_sentry,
|
||||||
|
ext_storage,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_login import login_manager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
|
class DifyApp(Flask):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Application Factory Function
|
||||||
|
# ----------------------------
|
||||||
|
def create_flask_app_with_configs() -> Flask:
|
||||||
|
"""
|
||||||
|
create a raw flask app
|
||||||
|
with configs loaded from .env file
|
||||||
|
"""
|
||||||
|
dify_app = DifyApp(__name__)
|
||||||
|
dify_app.config.from_mapping(dify_config.model_dump())
|
||||||
|
|
||||||
|
# populate configs into system environment variables
|
||||||
|
for key, value in dify_app.config.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
os.environ[key] = value
|
||||||
|
elif isinstance(value, int | float | bool):
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
elif value is None:
|
||||||
|
os.environ[key] = ""
|
||||||
|
|
||||||
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> Flask:
|
||||||
|
app = create_flask_app_with_configs()
|
||||||
|
|
||||||
|
app.secret_key = app.config["SECRET_KEY"]
|
||||||
|
|
||||||
|
log_handlers = None
|
||||||
|
log_file = app.config.get("LOG_FILE")
|
||||||
|
if log_file:
|
||||||
|
log_dir = os.path.dirname(log_file)
|
||||||
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
|
log_handlers = [
|
||||||
|
RotatingFileHandler(
|
||||||
|
filename=log_file,
|
||||||
|
maxBytes=1024 * 1024 * 1024,
|
||||||
|
backupCount=5,
|
||||||
|
),
|
||||||
|
logging.StreamHandler(sys.stdout),
|
||||||
|
]
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=app.config.get("LOG_LEVEL"),
|
||||||
|
format=app.config.get("LOG_FORMAT"),
|
||||||
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
|
handlers=log_handlers,
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
log_tz = app.config.get("LOG_TZ")
|
||||||
|
if log_tz:
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
timezone = pytz.timezone(log_tz)
|
||||||
|
|
||||||
|
def time_converter(seconds):
|
||||||
|
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||||
|
|
||||||
|
for handler in logging.root.handlers:
|
||||||
|
handler.formatter.converter = time_converter
|
||||||
|
initialize_extensions(app)
|
||||||
|
register_blueprints(app)
|
||||||
|
register_commands(app)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_extensions(app):
|
||||||
|
# Since the application instance is now created, pass it to each Flask
|
||||||
|
# extension instance to bind it to the Flask application instance (app)
|
||||||
|
ext_compress.init_app(app)
|
||||||
|
ext_code_based_extension.init()
|
||||||
|
ext_database.init_app(app)
|
||||||
|
ext_migrate.init(app, db)
|
||||||
|
ext_redis.init_app(app)
|
||||||
|
ext_storage.init_app(app)
|
||||||
|
ext_celery.init_app(app)
|
||||||
|
ext_login.init_app(app)
|
||||||
|
ext_mail.init_app(app)
|
||||||
|
ext_hosting_provider.init_app(app)
|
||||||
|
ext_sentry.init_app(app)
|
||||||
|
ext_proxy_fix.init_app(app)
|
||||||
|
|
||||||
|
|
||||||
|
# Flask-Login configuration
|
||||||
|
@login_manager.request_loader
|
||||||
|
def load_user_from_request(request_from_flask_login):
|
||||||
|
"""Load user based on the request."""
|
||||||
|
if request.blueprint not in {"console", "inner_api"}:
|
||||||
|
return None
|
||||||
|
# Check if the user_id contains a dot, indicating the old format
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header:
|
||||||
|
auth_token = request.args.get("_token")
|
||||||
|
if not auth_token:
|
||||||
|
raise Unauthorized("Invalid Authorization token.")
|
||||||
|
else:
|
||||||
|
if " " not in auth_header:
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||||
|
auth_scheme = auth_scheme.lower()
|
||||||
|
if auth_scheme != "bearer":
|
||||||
|
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||||
|
|
||||||
|
decoded = PassportService().verify(auth_token)
|
||||||
|
user_id = decoded.get("user_id")
|
||||||
|
|
||||||
|
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||||
|
if logged_in_account:
|
||||||
|
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||||
|
return logged_in_account
|
||||||
|
|
||||||
|
|
||||||
|
@login_manager.unauthorized_handler
|
||||||
|
def unauthorized_handler():
|
||||||
|
"""Handle unauthorized requests."""
|
||||||
|
return Response(
|
||||||
|
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||||
|
status=401,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# register blueprint routers
|
||||||
|
def register_blueprints(app):
|
||||||
|
from controllers.console import bp as console_app_bp
|
||||||
|
from controllers.files import bp as files_bp
|
||||||
|
from controllers.inner_api import bp as inner_api_bp
|
||||||
|
from controllers.service_api import bp as service_api_bp
|
||||||
|
from controllers.web import bp as web_bp
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
service_api_bp,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
)
|
||||||
|
app.register_blueprint(service_api_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
web_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(web_bp)
|
||||||
|
|
||||||
|
CORS(
|
||||||
|
console_app_bp,
|
||||||
|
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||||
|
supports_credentials=True,
|
||||||
|
allow_headers=["Content-Type", "Authorization"],
|
||||||
|
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||||
|
expose_headers=["X-Version", "X-Env"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.register_blueprint(console_app_bp)
|
||||||
|
|
||||||
|
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||||
|
app.register_blueprint(files_bp)
|
||||||
|
|
||||||
|
app.register_blueprint(inner_api_bp)
|
||||||
@ -1,88 +1,24 @@
|
|||||||
import logging
|
from flask_restful import Resource
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restful import Resource, marshal, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|
||||||
|
|
||||||
import services
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import (
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
CompletionRequestError,
|
|
||||||
ProviderModelCurrentlyNotSupportError,
|
|
||||||
ProviderNotInitializeError,
|
|
||||||
ProviderQuotaExceededError,
|
|
||||||
)
|
|
||||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.errors.error import (
|
|
||||||
LLMBadRequestError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
|
||||||
ProviderTokenNotInitError,
|
|
||||||
QuotaExceededError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
|
||||||
from fields.hit_testing_fields import hit_testing_record_fields
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.dataset_service import DatasetService
|
|
||||||
from services.hit_testing_service import HitTestingService
|
|
||||||
|
|
||||||
|
|
||||||
class HitTestingApi(Resource):
|
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
args = self.parse_args()
|
||||||
raise NotFound("Dataset not found.")
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
try:
|
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
|
||||||
except services.errors.account.NoPermissionError as e:
|
|
||||||
raise Forbidden(str(e))
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("query", type=str, location="json")
|
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = HitTestingService.retrieve(
|
|
||||||
dataset=dataset,
|
|
||||||
query=args["query"],
|
|
||||||
account=current_user,
|
|
||||||
retrieval_model=args["retrieval_model"],
|
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
return self.perform_hit_testing(dataset, args)
|
||||||
except services.errors.index.IndexNotInitializedError:
|
|
||||||
raise DatasetNotInitializedError()
|
|
||||||
except ProviderTokenNotInitError as ex:
|
|
||||||
raise ProviderNotInitializeError(ex.description)
|
|
||||||
except QuotaExceededError:
|
|
||||||
raise ProviderQuotaExceededError()
|
|
||||||
except ModelCurrentlyNotSupportError:
|
|
||||||
raise ProviderModelCurrentlyNotSupportError()
|
|
||||||
except LLMBadRequestError:
|
|
||||||
raise ProviderNotInitializeError(
|
|
||||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
|
||||||
"in the Settings -> Model Provider."
|
|
||||||
)
|
|
||||||
except InvokeError as e:
|
|
||||||
raise CompletionRequestError(e.description)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("Hit testing failed.")
|
|
||||||
raise InternalServerError(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
|
|||||||
@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import marshal, reqparse
|
||||||
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
import services.dataset_service
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
CompletionRequestError,
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
)
|
||||||
|
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||||
|
from core.errors.error import (
|
||||||
|
LLMBadRequestError,
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetsHitTestingBase:
|
||||||
|
@staticmethod
|
||||||
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
if dataset is None:
|
||||||
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
except services.errors.account.NoPermissionError as e:
|
||||||
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hit_testing_args_check(args):
|
||||||
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_args():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
|
parser.add_argument("query", type=str, location="json")
|
||||||
|
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
|
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def perform_hit_testing(dataset, args):
|
||||||
|
try:
|
||||||
|
response = HitTestingService.retrieve(
|
||||||
|
dataset=dataset,
|
||||||
|
query=args["query"],
|
||||||
|
account=current_user,
|
||||||
|
retrieval_model=args["retrieval_model"],
|
||||||
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
|
except services.errors.index.IndexNotInitializedError:
|
||||||
|
raise DatasetNotInitializedError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||||
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Hit testing failed.")
|
||||||
|
raise InternalServerError(str(e))
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
|
from controllers.service_api import api
|
||||||
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
|
||||||
|
|
||||||
|
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||||
|
def post(self, tenant_id, dataset_id):
|
||||||
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
|
args = self.parse_args()
|
||||||
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
return self.perform_hit_testing(dataset, args)
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
model: accounts/fireworks/models/qwen2p5-72b-instruct
|
||||||
|
label:
|
||||||
|
zh_Hans: Qwen2.5 72B Instruct
|
||||||
|
en_US: Qwen2.5 72B Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 32768
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
- name: context_length_exceeded_behavior
|
||||||
|
default: None
|
||||||
|
label:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
help:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
type: string
|
||||||
|
options:
|
||||||
|
- None
|
||||||
|
- truncate
|
||||||
|
- error
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
pricing:
|
||||||
|
input: '0.9'
|
||||||
|
output: '0.9'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue