Merge branch 'main' into feat/r2
commit
82be119fec
@ -1,5 +1,4 @@
|
|||||||
FROM mcr.microsoft.com/devcontainers/python:3.12
|
FROM mcr.microsoft.com/devcontainers/python:3.12
|
||||||
|
|
||||||
# [Optional] Uncomment this section to install additional OS packages.
|
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||||
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
&& apt-get -y install libgmp-dev libmpfr-dev libmpc-dev
|
||||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
|
||||||
|
|||||||
@ -0,0 +1,27 @@
|
|||||||
|
from flask_restful import (
|
||||||
|
Resource, # type: ignore
|
||||||
|
reqparse,
|
||||||
|
)
|
||||||
|
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from controllers.inner_api import api
|
||||||
|
from controllers.inner_api.wraps import enterprise_inner_api_only
|
||||||
|
from services.enterprise.mail_service import DifyMail, EnterpriseMailService
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseMail(Resource):
|
||||||
|
@setup_required
|
||||||
|
@enterprise_inner_api_only
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("to", type=str, action="append", required=True)
|
||||||
|
parser.add_argument("subject", type=str, required=True)
|
||||||
|
parser.add_argument("body", type=str, required=True)
|
||||||
|
parser.add_argument("substitutions", type=dict, required=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
EnterpriseMailService.send_mail(DifyMail(**args))
|
||||||
|
return {"message": "success"}, 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(EnterpriseMail, "/enterprise/mail")
|
||||||
@ -0,0 +1,120 @@
|
|||||||
|
from flask import request
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
from jwt import InvalidTokenError # type: ignore
|
||||||
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
|
||||||
|
from controllers.console.error import AccountBannedError, AccountNotFound
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from libs.helper import email
|
||||||
|
from libs.password import valid_password
|
||||||
|
from services.account_service import AccountService
|
||||||
|
from services.webapp_auth_service import WebAppAuthService
|
||||||
|
|
||||||
|
|
||||||
|
class LoginApi(Resource):
|
||||||
|
"""Resource for web app email/password login."""
|
||||||
|
|
||||||
|
def post(self):
|
||||||
|
"""Authenticate user and login."""
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise BadRequest("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
account = WebAppAuthService.authenticate(args["email"], args["password"])
|
||||||
|
except services.errors.account.AccountLoginError:
|
||||||
|
raise AccountBannedError()
|
||||||
|
except services.errors.account.AccountPasswordError:
|
||||||
|
raise EmailOrPasswordMismatchError()
|
||||||
|
except services.errors.account.AccountNotFoundError:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||||
|
|
||||||
|
end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
# class LogoutApi(Resource):
|
||||||
|
# @setup_required
|
||||||
|
# def get(self):
|
||||||
|
# account = cast(Account, flask_login.current_user)
|
||||||
|
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||||
|
# return {"result": "success"}
|
||||||
|
# flask_login.logout_user()
|
||||||
|
# return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginSendEmailApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=email, required=True, location="json")
|
||||||
|
parser.add_argument("language", type=str, required=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||||
|
language = "zh-Hans"
|
||||||
|
else:
|
||||||
|
language = "en-US"
|
||||||
|
|
||||||
|
account = WebAppAuthService.get_user_through_email(args["email"])
|
||||||
|
if account is None:
|
||||||
|
raise AccountNotFound()
|
||||||
|
else:
|
||||||
|
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||||
|
|
||||||
|
return {"result": "success", "data": token}
|
||||||
|
|
||||||
|
|
||||||
|
class EmailCodeLoginApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("email", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("code", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("token", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
user_email = args["email"]
|
||||||
|
app_code = request.headers.get("X-App-Code")
|
||||||
|
if app_code is None:
|
||||||
|
raise BadRequest("X-App-Code header is missing.")
|
||||||
|
|
||||||
|
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
|
||||||
|
if token_data is None:
|
||||||
|
raise InvalidTokenError()
|
||||||
|
|
||||||
|
if token_data["email"] != args["email"]:
|
||||||
|
raise InvalidEmailError()
|
||||||
|
|
||||||
|
if token_data["code"] != args["code"]:
|
||||||
|
raise EmailCodeError()
|
||||||
|
|
||||||
|
WebAppAuthService.revoke_email_code_login_token(args["token"])
|
||||||
|
account = WebAppAuthService.get_user_through_email(user_email)
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFound()
|
||||||
|
|
||||||
|
WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)
|
||||||
|
|
||||||
|
end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)
|
||||||
|
|
||||||
|
token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
|
||||||
|
AccountService.reset_login_error_rate_limit(args["email"])
|
||||||
|
return {"result": "success", "token": token}
|
||||||
|
|
||||||
|
|
||||||
|
# api.add_resource(LoginApi, "/login")
|
||||||
|
# api.add_resource(LogoutApi, "/logout")
|
||||||
|
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||||
|
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class WaterCrawlError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WaterCrawlBadRequestError(WaterCrawlError):
|
||||||
|
def __init__(self, response):
|
||||||
|
self.status_code = response.status_code
|
||||||
|
self.response = response
|
||||||
|
data = response.json()
|
||||||
|
self.message = data.get("message", "Unknown error occurred")
|
||||||
|
self.errors = data.get("errors", {})
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flat_errors(self):
|
||||||
|
return json.dumps(self.errors)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"WaterCrawlBadRequestError: {self.message} \n {self.flat_errors}"
|
||||||
|
|
||||||
|
|
||||||
|
class WaterCrawlPermissionError(WaterCrawlBadRequestError):
|
||||||
|
def __str__(self):
|
||||||
|
return f"You are exceeding your WaterCrawl API limits. {self.message}"
|
||||||
|
|
||||||
|
|
||||||
|
class WaterCrawlAuthenticationError(WaterCrawlBadRequestError):
|
||||||
|
def __str__(self):
|
||||||
|
return "WaterCrawl API key is invalid or expired. Please check your API key and try again."
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
# The minimal selector length for valid variables.
|
||||||
|
#
|
||||||
|
# The first element of the selector is the node id, and the second element is the variable name.
|
||||||
|
#
|
||||||
|
# If the selector length is more than 2, the remaining parts are the keys / indexes paths used
|
||||||
|
# to extract part of the variable value.
|
||||||
|
MIN_SELECTORS_LENGTH = 2
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
from collections.abc import Iterable, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||||
|
selectors = [node_id, name]
|
||||||
|
if paths:
|
||||||
|
selectors.extend(paths)
|
||||||
|
return selectors
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Domain entities for workflow node execution.
|
||||||
|
|
||||||
|
This module contains the domain model for workflow node execution, which is used
|
||||||
|
by the core workflow module. These models are independent of the storage mechanism
|
||||||
|
and don't contain implementation details like tenant_id, app_id, etc.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecutionStatus(StrEnum):
|
||||||
|
"""
|
||||||
|
Node Execution Status Enum.
|
||||||
|
"""
|
||||||
|
|
||||||
|
RUNNING = "running"
|
||||||
|
SUCCEEDED = "succeeded"
|
||||||
|
FAILED = "failed"
|
||||||
|
EXCEPTION = "exception"
|
||||||
|
RETRY = "retry"
|
||||||
|
|
||||||
|
|
||||||
|
class NodeExecution(BaseModel):
|
||||||
|
"""
|
||||||
|
Domain model for workflow node execution.
|
||||||
|
|
||||||
|
This model represents the core business entity of a node execution,
|
||||||
|
without implementation details like tenant_id, app_id, etc.
|
||||||
|
|
||||||
|
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
|
||||||
|
have been moved to the repository implementation to keep the domain model clean.
|
||||||
|
These fields are still accepted in the constructor for backward compatibility,
|
||||||
|
but they are not stored in the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core identification fields
|
||||||
|
id: str # Unique identifier for this execution record
|
||||||
|
node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
|
||||||
|
workflow_id: str # ID of the workflow this node belongs to
|
||||||
|
workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
|
||||||
|
|
||||||
|
# Execution positioning and flow
|
||||||
|
index: int # Sequence number for ordering in trace visualization
|
||||||
|
predecessor_node_id: Optional[str] = None # ID of the node that executed before this one
|
||||||
|
node_id: str # ID of the node being executed
|
||||||
|
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
|
||||||
|
title: str # Display title of the node
|
||||||
|
|
||||||
|
# Execution data
|
||||||
|
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
|
||||||
|
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
|
||||||
|
|
||||||
|
# Execution state
|
||||||
|
status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status
|
||||||
|
error: Optional[str] = None # Error message if execution failed
|
||||||
|
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
|
||||||
|
|
||||||
|
# Additional metadata
|
||||||
|
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.)
|
||||||
|
|
||||||
|
# Timing information
|
||||||
|
created_at: datetime # When execution started
|
||||||
|
finished_at: Optional[datetime] = None # When execution completed
|
||||||
|
|
||||||
|
def update_from_mapping(
|
||||||
|
self,
|
||||||
|
inputs: Optional[Mapping[str, Any]] = None,
|
||||||
|
process_data: Optional[Mapping[str, Any]] = None,
|
||||||
|
outputs: Optional[Mapping[str, Any]] = None,
|
||||||
|
metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update the model from mappings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: The inputs to update
|
||||||
|
process_data: The process data to update
|
||||||
|
outputs: The outputs to update
|
||||||
|
metadata: The metadata to update
|
||||||
|
"""
|
||||||
|
if inputs is not None:
|
||||||
|
self.inputs = dict(inputs)
|
||||||
|
if process_data is not None:
|
||||||
|
self.process_data = dict(process_data)
|
||||||
|
if outputs is not None:
|
||||||
|
self.outputs = dict(outputs)
|
||||||
|
if metadata is not None:
|
||||||
|
self.metadata = dict(metadata)
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import flask
|
||||||
|
import werkzeug.http
|
||||||
|
from flask import Flask
|
||||||
|
from flask.signals import request_finished, request_started
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_content_type_json(content_type: str) -> bool:
|
||||||
|
if not content_type:
|
||||||
|
return False
|
||||||
|
content_type_no_option, _ = werkzeug.http.parse_options_header(content_type)
|
||||||
|
return content_type_no_option.lower() == "application/json"
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request_started(_sender, **_extra):
|
||||||
|
"""Log the start of a request."""
|
||||||
|
if not _logger.isEnabledFor(logging.DEBUG):
|
||||||
|
return
|
||||||
|
|
||||||
|
request = flask.request
|
||||||
|
if not (_is_content_type_json(request.content_type) and request.data):
|
||||||
|
_logger.debug("Received Request %s -> %s", request.method, request.path)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
json_data = json.loads(request.data)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_logger.exception("Failed to parse JSON request")
|
||||||
|
return
|
||||||
|
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||||
|
_logger.debug(
|
||||||
|
"Received Request %s -> %s, Request Body:\n%s",
|
||||||
|
request.method,
|
||||||
|
request.path,
|
||||||
|
formatted_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_request_finished(_sender, response, **_extra):
|
||||||
|
"""Log the end of a request."""
|
||||||
|
if not _logger.isEnabledFor(logging.DEBUG) or response is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _is_content_type_json(response.content_type):
|
||||||
|
_logger.debug("Response %s %s", response.status, response.content_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
response_data = response.get_data(as_text=True)
|
||||||
|
try:
|
||||||
|
json_data = json.loads(response_data)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_logger.exception("Failed to parse JSON response")
|
||||||
|
return
|
||||||
|
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||||
|
_logger.debug(
|
||||||
|
"Response %s %s, Response Body:\n%s",
|
||||||
|
response.status,
|
||||||
|
response.content_type,
|
||||||
|
formatted_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(app: Flask):
|
||||||
|
"""Initialize the request logging extension."""
|
||||||
|
if not dify_config.ENABLE_REQUEST_LOGGING:
|
||||||
|
return
|
||||||
|
request_started.connect(_log_request_started, app)
|
||||||
|
request_finished.connect(_log_request_finished, app)
|
||||||
@ -0,0 +1,51 @@
|
|||||||
|
"""add WorkflowDraftVariable model
|
||||||
|
|
||||||
|
Revision ID: 2adcbe1f5dfb
|
||||||
|
Revises: d28f2004b072
|
||||||
|
Create Date: 2025-05-15 15:31:03.128680
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
import models as models
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "2adcbe1f5dfb"
|
||||||
|
down_revision = "d28f2004b072"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"workflow_draft_variables",
|
||||||
|
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||||
|
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column("last_edited_at", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("node_id", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("description", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("selector", sa.String(length=255), nullable=False),
|
||||||
|
sa.Column("value_type", sa.String(length=20), nullable=False),
|
||||||
|
sa.Column("value", sa.Text(), nullable=False),
|
||||||
|
sa.Column("visible", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("editable", sa.Boolean(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")),
|
||||||
|
sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
|
||||||
|
# Dropping `workflow_draft_variables` also drops any index associated with it.
|
||||||
|
op.drop_table("workflow_draft_variables")
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -1,11 +1,90 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from services.enterprise.base import EnterpriseRequest
|
from services.enterprise.base import EnterpriseRequest
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppSettings(BaseModel):
|
||||||
|
access_mode: str = Field(
|
||||||
|
description="Access mode for the web app. Can be 'public' or 'private'",
|
||||||
|
default="private",
|
||||||
|
alias="accessMode",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EnterpriseService:
|
class EnterpriseService:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_info(cls):
|
def get_info(cls):
|
||||||
return EnterpriseRequest.send_request("GET", "/info")
|
return EnterpriseRequest.send_request("GET", "/info")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_app_web_sso_enabled(cls, app_code):
|
def get_workspace_info(cls, tenant_id: str):
|
||||||
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
|
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||||
|
|
||||||
|
class WebAppAuth:
|
||||||
|
@classmethod
|
||||||
|
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
|
||||||
|
params = {"userId": user_id, "appCode": app_code}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
|
||||||
|
|
||||||
|
return data.get("result", False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings:
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
params = {"appId": app_id}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
|
||||||
|
if not app_ids:
|
||||||
|
return {}
|
||||||
|
body = {"appIds": app_ids}
|
||||||
|
data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
|
||||||
|
if not isinstance(data["accessModes"], dict):
|
||||||
|
raise ValueError("Invalid data format.")
|
||||||
|
|
||||||
|
ret = {}
|
||||||
|
for key, value in data["accessModes"].items():
|
||||||
|
curr = WebAppSettings()
|
||||||
|
curr.access_mode = value
|
||||||
|
ret[key] = curr
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings:
|
||||||
|
if not app_code:
|
||||||
|
raise ValueError("app_code must be provided.")
|
||||||
|
params = {"appCode": app_code}
|
||||||
|
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
|
||||||
|
if not data:
|
||||||
|
raise ValueError("No data found.")
|
||||||
|
return WebAppSettings(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_app_access_mode(cls, app_id: str, access_mode: str):
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
if access_mode not in ["public", "private", "private_all"]:
|
||||||
|
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
|
||||||
|
|
||||||
|
data = {"appId": app_id, "accessMode": access_mode}
|
||||||
|
|
||||||
|
response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data)
|
||||||
|
|
||||||
|
return response.get("result", False)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cleanup_webapp(cls, app_id: str):
|
||||||
|
if not app_id:
|
||||||
|
raise ValueError("app_id must be provided.")
|
||||||
|
|
||||||
|
body = {"appId": app_id}
|
||||||
|
EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body)
|
||||||
|
|||||||
@ -0,0 +1,18 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from tasks.mail_enterprise_task import send_enterprise_email_task
|
||||||
|
|
||||||
|
|
||||||
|
class DifyMail(BaseModel):
|
||||||
|
to: list[str]
|
||||||
|
subject: str
|
||||||
|
body: str
|
||||||
|
substitutions: dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class EnterpriseMailService:
|
||||||
|
@classmethod
|
||||||
|
def send_mail(cls, mail: DifyMail):
|
||||||
|
send_enterprise_email_task.delay(
|
||||||
|
to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions
|
||||||
|
)
|
||||||
@ -0,0 +1,141 @@
|
|||||||
|
import random
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.web.error import WebAppAuthAccessDeniedError
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import TokenManager
|
||||||
|
from libs.passport import PassportService
|
||||||
|
from libs.password import compare_password
|
||||||
|
from models.account import Account, AccountStatus
|
||||||
|
from models.model import App, EndUser, Site
|
||||||
|
from services.enterprise.enterprise_service import EnterpriseService
|
||||||
|
from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError
|
||||||
|
from services.feature_service import FeatureService
|
||||||
|
from tasks.mail_email_code_login import send_email_code_login_mail_task
|
||||||
|
|
||||||
|
|
||||||
|
class WebAppAuthService:
|
||||||
|
"""Service for web app authentication."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def authenticate(email: str, password: str) -> Account:
|
||||||
|
"""authenticate account with email and password"""
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=email).first()
|
||||||
|
if not account:
|
||||||
|
raise AccountNotFoundError()
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise AccountLoginError("Account is banned.")
|
||||||
|
|
||||||
|
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||||
|
raise AccountPasswordError("Invalid email or password.")
|
||||||
|
|
||||||
|
return cast(Account, account)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def login(cls, account: Account, app_code: str, end_user_id: str) -> str:
|
||||||
|
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||||
|
if not site:
|
||||||
|
raise NotFound("Site not found.")
|
||||||
|
|
||||||
|
access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id)
|
||||||
|
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_user_through_email(cls, email: str):
|
||||||
|
account = db.session.query(Account).filter(Account.email == email).first()
|
||||||
|
if not account:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if account.status == AccountStatus.BANNED.value:
|
||||||
|
raise Unauthorized("Account is banned.")
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def send_email_code_login_email(
|
||||||
|
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
|
||||||
|
):
|
||||||
|
email = account.email if account else email
|
||||||
|
if email is None:
|
||||||
|
raise ValueError("Email must be provided.")
|
||||||
|
|
||||||
|
code = "".join([str(random.randint(0, 9)) for _ in range(6)])
|
||||||
|
token = TokenManager.generate_token(
|
||||||
|
account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code}
|
||||||
|
)
|
||||||
|
send_email_code_login_mail_task.delay(
|
||||||
|
language=language,
|
||||||
|
to=account.email if account else email,
|
||||||
|
code=code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||||
|
return TokenManager.get_token_data(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def revoke_email_code_login_token(cls, token: str):
|
||||||
|
TokenManager.revoke_token(token, "webapp_email_code_login")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_end_user(cls, app_code, email) -> EndUser:
|
||||||
|
site = db.session.query(Site).filter(Site.code == app_code).first()
|
||||||
|
if not site:
|
||||||
|
raise NotFound("Site not found.")
|
||||||
|
app_model = db.session.query(App).filter(App.id == site.app_id).first()
|
||||||
|
if not app_model:
|
||||||
|
raise NotFound("App not found.")
|
||||||
|
end_user = EndUser(
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
app_id=app_model.id,
|
||||||
|
type="browser",
|
||||||
|
is_anonymous=False,
|
||||||
|
session_id=email,
|
||||||
|
name="enterpriseuser",
|
||||||
|
external_user_id="enterpriseuser",
|
||||||
|
)
|
||||||
|
db.session.add(end_user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_user_accessibility(cls, account: Account, app_code: str):
|
||||||
|
"""Check if the user is allowed to access the app."""
|
||||||
|
system_features = FeatureService.get_system_features()
|
||||||
|
if system_features.webapp_auth.enabled:
|
||||||
|
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||||
|
|
||||||
|
if (
|
||||||
|
app_settings.access_mode != "public"
|
||||||
|
and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code)
|
||||||
|
):
|
||||||
|
raise WebAppAuthAccessDeniedError()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str:
|
||||||
|
exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24)
|
||||||
|
exp = int(exp_dt.timestamp())
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"iss": site.id,
|
||||||
|
"sub": "Web API Passport",
|
||||||
|
"app_id": site.app_id,
|
||||||
|
"app_code": site.code,
|
||||||
|
"user_id": account.id,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"token_source": "webapp",
|
||||||
|
"exp": exp,
|
||||||
|
}
|
||||||
|
|
||||||
|
token: str = PassportService().issue(payload)
|
||||||
|
return token
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import click
|
||||||
|
from celery import shared_task # type: ignore
|
||||||
|
from flask import render_template_string
|
||||||
|
|
||||||
|
from extensions.ext_mail import mail
|
||||||
|
|
||||||
|
|
||||||
|
@shared_task(queue="mail")
|
||||||
|
def send_enterprise_email_task(to, subject, body, substitutions):
|
||||||
|
if not mail.is_inited():
|
||||||
|
return
|
||||||
|
|
||||||
|
logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green"))
|
||||||
|
start_at = time.perf_counter()
|
||||||
|
|
||||||
|
try:
|
||||||
|
html_content = render_template_string(body, **substitutions)
|
||||||
|
|
||||||
|
if isinstance(to, list):
|
||||||
|
for t in to:
|
||||||
|
mail.send(to=t, subject=subject, html=html_content)
|
||||||
|
else:
|
||||||
|
mail.send(to=to, subject=subject, html=html_content)
|
||||||
|
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
logging.info(
|
||||||
|
click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green")
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Send enterprise mail to {} failed".format(to))
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
line-height: 16pt;
|
||||||
|
color: #101828;
|
||||||
|
background-color: #e9ebf0;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
width: 600px;
|
||||||
|
height: 360px;
|
||||||
|
margin: 40px auto;
|
||||||
|
padding: 36px 48px;
|
||||||
|
background-color: #fcfcfd;
|
||||||
|
border-radius: 16px;
|
||||||
|
border: 1px solid #ffffff;
|
||||||
|
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
margin-bottom: 24px;
|
||||||
|
}
|
||||||
|
.header img {
|
||||||
|
max-width: 100px;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
.title {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 24px;
|
||||||
|
line-height: 28.8px;
|
||||||
|
}
|
||||||
|
.description {
|
||||||
|
font-size: 13px;
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
margin-top: 12px;
|
||||||
|
}
|
||||||
|
.code-content {
|
||||||
|
padding: 16px 32px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 16px;
|
||||||
|
background-color: #f2f4f7;
|
||||||
|
margin: 16px auto;
|
||||||
|
}
|
||||||
|
.code {
|
||||||
|
line-height: 36px;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 30px;
|
||||||
|
}
|
||||||
|
.tips {
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<p class="title">Your login code for {{application_title}}</p>
|
||||||
|
<p class="description">Copy and paste this code, this code will only be valid for the next 5 minutes.</p>
|
||||||
|
<div class="code-content">
|
||||||
|
<span class="code">{{code}}</span>
|
||||||
|
</div>
|
||||||
|
<p class="tips">If you didn't request a login, don't worry. You can safely ignore this email.</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
line-height: 16pt;
|
||||||
|
color: #101828;
|
||||||
|
background-color: #e9ebf0;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
width: 600px;
|
||||||
|
height: 360px;
|
||||||
|
margin: 40px auto;
|
||||||
|
padding: 36px 48px;
|
||||||
|
background-color: #fcfcfd;
|
||||||
|
border-radius: 16px;
|
||||||
|
border: 1px solid #ffffff;
|
||||||
|
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
margin-bottom: 24px;
|
||||||
|
}
|
||||||
|
.header img {
|
||||||
|
max-width: 100px;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
.title {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 24px;
|
||||||
|
line-height: 28.8px;
|
||||||
|
}
|
||||||
|
.description {
|
||||||
|
font-size: 13px;
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
margin-top: 12px;
|
||||||
|
}
|
||||||
|
.code-content {
|
||||||
|
padding: 16px 32px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 16px;
|
||||||
|
background-color: #f2f4f7;
|
||||||
|
margin: 16px auto;
|
||||||
|
}
|
||||||
|
.code {
|
||||||
|
line-height: 36px;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 30px;
|
||||||
|
}
|
||||||
|
.tips {
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<p class="title">{{application_title}} 的登录验证码</p>
|
||||||
|
<p class="description">复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。</p>
|
||||||
|
<div class="code-content">
|
||||||
|
<span class="code">{{code}}</span>
|
||||||
|
</div>
|
||||||
|
<p class="tips">如果您没有请求登录,请不要担心。您可以安全地忽略此电子邮件。</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
line-height: 16pt;
|
||||||
|
color: #374151;
|
||||||
|
background-color: #E5E7EB;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
width: 100%;
|
||||||
|
max-width: 560px;
|
||||||
|
margin: 40px auto;
|
||||||
|
padding: 20px;
|
||||||
|
background-color: #F3F4F6;
|
||||||
|
border-radius: 8px;
|
||||||
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
.header img {
|
||||||
|
max-width: 100px;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
.button {
|
||||||
|
display: inline-block;
|
||||||
|
padding: 12px 24px;
|
||||||
|
background-color: #2970FF;
|
||||||
|
color: white;
|
||||||
|
text-decoration: none;
|
||||||
|
border-radius: 4px;
|
||||||
|
text-align: center;
|
||||||
|
transition: background-color 0.3s ease;
|
||||||
|
}
|
||||||
|
.button:hover {
|
||||||
|
background-color: #265DD4;
|
||||||
|
}
|
||||||
|
.footer {
|
||||||
|
font-size: 0.9em;
|
||||||
|
color: #777777;
|
||||||
|
margin-top: 30px;
|
||||||
|
}
|
||||||
|
.content {
|
||||||
|
margin-top: 20px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="content">
|
||||||
|
<p>Dear {{ to }},</p>
|
||||||
|
<p>{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.</p>
|
||||||
|
<p>Click the button below to log in to {{application_title}} and join the workspace.</p>
|
||||||
|
<p style="text-align: center;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
|
||||||
|
</div>
|
||||||
|
<div class="footer">
|
||||||
|
<p>Best regards,</p>
|
||||||
|
<p>{{application_title}} Team</p>
|
||||||
|
<p>Please do not reply directly to this email; it is automatically sent by the system.</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
line-height: 16pt;
|
||||||
|
color: #101828;
|
||||||
|
background-color: #e9ebf0;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
width: 600px;
|
||||||
|
height: 360px;
|
||||||
|
margin: 40px auto;
|
||||||
|
padding: 36px 48px;
|
||||||
|
background-color: #fcfcfd;
|
||||||
|
border-radius: 16px;
|
||||||
|
border: 1px solid #ffffff;
|
||||||
|
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
margin-bottom: 24px;
|
||||||
|
}
|
||||||
|
.header img {
|
||||||
|
max-width: 100px;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
.title {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 24px;
|
||||||
|
line-height: 28.8px;
|
||||||
|
}
|
||||||
|
.description {
|
||||||
|
font-size: 13px;
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
margin-top: 12px;
|
||||||
|
}
|
||||||
|
.code-content {
|
||||||
|
padding: 16px 32px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 16px;
|
||||||
|
background-color: #f2f4f7;
|
||||||
|
margin: 16px auto;
|
||||||
|
}
|
||||||
|
.code {
|
||||||
|
line-height: 36px;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 30px;
|
||||||
|
}
|
||||||
|
.tips {
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<p class="title">Set your {{application_title}} password</p>
|
||||||
|
<p class="description">Copy and paste this code, this code will only be valid for the next 5 minutes.</p>
|
||||||
|
<div class="code-content">
|
||||||
|
<span class="code">{{code}}</span>
|
||||||
|
</div>
|
||||||
|
<p class="tips">If you didn't request, don't worry. You can safely ignore this email.</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: 'Arial', sans-serif;
|
||||||
|
line-height: 16pt;
|
||||||
|
color: #101828;
|
||||||
|
background-color: #e9ebf0;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
width: 600px;
|
||||||
|
height: 360px;
|
||||||
|
margin: 40px auto;
|
||||||
|
padding: 36px 48px;
|
||||||
|
background-color: #fcfcfd;
|
||||||
|
border-radius: 16px;
|
||||||
|
border: 1px solid #ffffff;
|
||||||
|
box-shadow: 0 2px 4px -2px rgba(9, 9, 11, 0.08);
|
||||||
|
}
|
||||||
|
.header {
|
||||||
|
margin-bottom: 24px;
|
||||||
|
}
|
||||||
|
.header img {
|
||||||
|
max-width: 100px;
|
||||||
|
height: auto;
|
||||||
|
}
|
||||||
|
.title {
|
||||||
|
font-weight: 600;
|
||||||
|
font-size: 24px;
|
||||||
|
line-height: 28.8px;
|
||||||
|
}
|
||||||
|
.description {
|
||||||
|
font-size: 13px;
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
margin-top: 12px;
|
||||||
|
}
|
||||||
|
.code-content {
|
||||||
|
padding: 16px 32px;
|
||||||
|
text-align: center;
|
||||||
|
border-radius: 16px;
|
||||||
|
background-color: #f2f4f7;
|
||||||
|
margin: 16px auto;
|
||||||
|
}
|
||||||
|
.code {
|
||||||
|
line-height: 36px;
|
||||||
|
font-weight: 700;
|
||||||
|
font-size: 30px;
|
||||||
|
}
|
||||||
|
.tips {
|
||||||
|
line-height: 16px;
|
||||||
|
color: #676f83;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<p class="title">设置您的 {{application_title}} 账户密码</p>
|
||||||
|
<p class="description">复制并粘贴此验证码,注意验证码仅在接下来的 5 分钟内有效。</p>
|
||||||
|
<div class="code-content">
|
||||||
|
<span class="code">{{code}}</span>
|
||||||
|
</div>
|
||||||
|
<p class="tips">如果您没有请求,请不要担心。您可以安全地忽略此电子邮件。</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@ -0,0 +1,265 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask, Response
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from extensions import ext_request_logging
|
||||||
|
from extensions.ext_request_logging import _is_content_type_json, _log_request_finished, init_app
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_content_type_json():
|
||||||
|
"""
|
||||||
|
Test the _is_content_type_json function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert _is_content_type_json("application/json") is True
|
||||||
|
# content type header with charset option.
|
||||||
|
assert _is_content_type_json("application/json; charset=utf-8") is True
|
||||||
|
# content type header with charset option, in uppercase.
|
||||||
|
assert _is_content_type_json("APPLICATION/JSON; CHARSET=UTF-8") is True
|
||||||
|
assert _is_content_type_json("text/html") is False
|
||||||
|
assert _is_content_type_json("") is False
|
||||||
|
|
||||||
|
|
||||||
|
_KEY_NEEDLE = "needle"
|
||||||
|
_VALUE_NEEDLE = _KEY_NEEDLE[::-1]
|
||||||
|
_RESPONSE_NEEDLE = "response"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_test_app():
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
@app.route("/", methods=["GET", "POST"])
|
||||||
|
def handler():
|
||||||
|
return _RESPONSE_NEEDLE
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(QuantumGhost): Due to the design of Flask, we need to use monkey patch to write tests.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_request_receiver(monkeypatch) -> mock.Mock:
|
||||||
|
mock_log_request_started = mock.Mock()
|
||||||
|
monkeypatch.setattr(ext_request_logging, "_log_request_started", mock_log_request_started)
|
||||||
|
return mock_log_request_started
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_response_receiver(monkeypatch) -> mock.Mock:
|
||||||
|
mock_log_request_finished = mock.Mock()
|
||||||
|
monkeypatch.setattr(ext_request_logging, "_log_request_finished", mock_log_request_finished)
|
||||||
|
return mock_log_request_finished
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_logger(monkeypatch) -> logging.Logger:
|
||||||
|
_logger = mock.MagicMock(spec=logging.Logger)
|
||||||
|
monkeypatch.setattr(ext_request_logging, "_logger", _logger)
|
||||||
|
return _logger
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def enable_request_logging(monkeypatch):
|
||||||
|
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestLoggingExtension:
|
||||||
|
def test_receiver_should_not_be_invoked_if_configuration_is_disabled(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
mock_request_receiver,
|
||||||
|
mock_response_receiver,
|
||||||
|
):
|
||||||
|
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", False)
|
||||||
|
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.get("/")
|
||||||
|
|
||||||
|
mock_request_receiver.assert_not_called()
|
||||||
|
mock_response_receiver.assert_not_called()
|
||||||
|
|
||||||
|
def test_receiver_should_be_called_if_enabled(
|
||||||
|
self,
|
||||||
|
enable_request_logging,
|
||||||
|
mock_request_receiver,
|
||||||
|
mock_response_receiver,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test the request logging extension with JSON data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||||
|
|
||||||
|
mock_request_receiver.assert_called_once()
|
||||||
|
mock_response_receiver.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoggingLevel:
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger):
|
||||||
|
mock_logger.isEnabledFor.return_value = False
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||||
|
mock_logger.debug.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestReceiverLogging:
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post("/", data="plain text")
|
||||||
|
assert mock_logger.debug.call_count == 1
|
||||||
|
call_args = mock_logger.debug.call_args[0]
|
||||||
|
assert "Received Request" in call_args[0]
|
||||||
|
assert call_args[1] == "POST"
|
||||||
|
assert call_args[2] == "/"
|
||||||
|
assert "Request Body" not in call_args[0]
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||||
|
assert mock_logger.debug.call_count == 1
|
||||||
|
call_args = mock_logger.debug.call_args[0]
|
||||||
|
assert "Received Request" in call_args[0]
|
||||||
|
assert "Request Body" in call_args[0]
|
||||||
|
assert call_args[1] == "POST"
|
||||||
|
assert call_args[2] == "/"
|
||||||
|
assert _KEY_NEEDLE in call_args[3]
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post("/", headers={"Content-Type": "application/json"})
|
||||||
|
|
||||||
|
assert mock_logger.debug.call_count == 1
|
||||||
|
call_args = mock_logger.debug.call_args[0]
|
||||||
|
assert "Received Request" in call_args[0]
|
||||||
|
assert "Request Body" not in call_args[0]
|
||||||
|
assert call_args[1] == "POST"
|
||||||
|
assert call_args[2] == "/"
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
client.post(
|
||||||
|
"/",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
data="{",
|
||||||
|
)
|
||||||
|
assert mock_logger.debug.call_count == 0
|
||||||
|
assert mock_logger.exception.call_count == 1
|
||||||
|
|
||||||
|
exception_call_args = mock_logger.exception.call_args[0]
|
||||||
|
assert exception_call_args[0] == "Failed to parse JSON request"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponseReceiverLogging:
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_non_json_response(self, enable_request_logging, mock_logger):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
response = Response(
|
||||||
|
"OK",
|
||||||
|
headers={"Content-Type": "text/plain"},
|
||||||
|
)
|
||||||
|
_log_request_finished(app, response)
|
||||||
|
assert mock_logger.debug.call_count == 1
|
||||||
|
call_args = mock_logger.debug.call_args[0]
|
||||||
|
assert "Response" in call_args[0]
|
||||||
|
assert "200" in call_args[1]
|
||||||
|
assert call_args[2] == "text/plain"
|
||||||
|
assert "Response Body" not in call_args[0]
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
response = Response(
|
||||||
|
json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}),
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
_log_request_finished(app, response)
|
||||||
|
assert mock_logger.debug.call_count == 1
|
||||||
|
call_args = mock_logger.debug.call_args[0]
|
||||||
|
assert "Response" in call_args[0]
|
||||||
|
assert "Response Body" in call_args[0]
|
||||||
|
assert "200" in call_args[1]
|
||||||
|
assert call_args[2] == "application/json"
|
||||||
|
assert _KEY_NEEDLE in call_args[3]
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||||
|
mock_logger.isEnabledFor.return_value = True
|
||||||
|
app = _get_test_app()
|
||||||
|
|
||||||
|
response = Response(
|
||||||
|
"{",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
_log_request_finished(app, response)
|
||||||
|
assert mock_logger.debug.call_count == 0
|
||||||
|
assert mock_logger.exception.call_count == 1
|
||||||
|
|
||||||
|
exception_call_args = mock_logger.exception.call_args[0]
|
||||||
|
assert exception_call_args[0] == "Failed to parse JSON response"
|
||||||
|
|
||||||
|
|
||||||
|
class TestResponseUnmodified:
|
||||||
|
def test_when_request_logging_disabled(self):
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
response = client.post(
|
||||||
|
"/",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
data="{",
|
||||||
|
)
|
||||||
|
assert response.text == _RESPONSE_NEEDLE
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("enable_request_logging")
|
||||||
|
def test_when_request_logging_enabled(self, enable_request_logging):
|
||||||
|
app = _get_test_app()
|
||||||
|
init_app(app)
|
||||||
|
|
||||||
|
with app.test_client() as client:
|
||||||
|
response = client.post(
|
||||||
|
"/",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
data="{",
|
||||||
|
)
|
||||||
|
assert response.text == _RESPONSE_NEEDLE
|
||||||
|
assert response.status_code == 200
|
||||||
@ -0,0 +1,187 @@
|
|||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from enum import StrEnum
|
||||||
|
from typing import Any, NamedTuple, TypeVar
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import exc as sa_exc
|
||||||
|
from sqlalchemy import insert
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
|
||||||
|
from sqlalchemy.sql.sqltypes import VARCHAR
|
||||||
|
|
||||||
|
from models.types import EnumText
|
||||||
|
|
||||||
|
_user_type_admin = "admin"
|
||||||
|
_user_type_normal = "normal"
|
||||||
|
|
||||||
|
|
||||||
|
class _Base(DeclarativeBase):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _UserType(StrEnum):
|
||||||
|
admin = _user_type_admin
|
||||||
|
normal = _user_type_normal
|
||||||
|
|
||||||
|
|
||||||
|
class _EnumWithLongValue(StrEnum):
|
||||||
|
unknown = "unknown"
|
||||||
|
a_really_long_enum_values = "a_really_long_enum_values"
|
||||||
|
|
||||||
|
|
||||||
|
class _User(_Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name: Mapped[str] = sa.Column(sa.String(length=255), nullable=False)
|
||||||
|
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
|
||||||
|
user_type_nullable: Mapped[_UserType | None] = sa.Column(EnumText(enum_class=_UserType), nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _ColumnTest(_Base):
|
||||||
|
__tablename__ = "column_test"
|
||||||
|
|
||||||
|
id: Mapped[int] = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
user_type: Mapped[_UserType] = sa.Column(EnumText(enum_class=_UserType), nullable=False, default=_UserType.normal)
|
||||||
|
explicit_length: Mapped[_UserType | None] = sa.Column(
|
||||||
|
EnumText(_UserType, length=50), nullable=True, default=_UserType.normal
|
||||||
|
)
|
||||||
|
long_value: Mapped[_EnumWithLongValue] = sa.Column(EnumText(enum_class=_EnumWithLongValue), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
|
def _first(it: Iterable[_T]) -> _T:
|
||||||
|
ls = list(it)
|
||||||
|
if not ls:
|
||||||
|
raise ValueError("List is empty")
|
||||||
|
return ls[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnumText:
|
||||||
|
def test_column_impl(self):
|
||||||
|
engine = sa.create_engine("sqlite://", echo=False)
|
||||||
|
_Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
inspector = sa.inspect(engine)
|
||||||
|
columns = inspector.get_columns(_ColumnTest.__tablename__)
|
||||||
|
|
||||||
|
user_type_column = _first(c for c in columns if c["name"] == "user_type")
|
||||||
|
sql_type = user_type_column["type"]
|
||||||
|
assert isinstance(user_type_column["type"], VARCHAR)
|
||||||
|
assert sql_type.length == 20
|
||||||
|
assert user_type_column["nullable"] is False
|
||||||
|
|
||||||
|
explicit_length_column = _first(c for c in columns if c["name"] == "explicit_length")
|
||||||
|
sql_type = explicit_length_column["type"]
|
||||||
|
assert isinstance(sql_type, VARCHAR)
|
||||||
|
assert sql_type.length == 50
|
||||||
|
assert explicit_length_column["nullable"] is True
|
||||||
|
|
||||||
|
long_value_column = _first(c for c in columns if c["name"] == "long_value")
|
||||||
|
sql_type = long_value_column["type"]
|
||||||
|
assert isinstance(sql_type, VARCHAR)
|
||||||
|
assert sql_type.length == len(_EnumWithLongValue.a_really_long_enum_values)
|
||||||
|
|
||||||
|
def test_insert_and_select(self):
|
||||||
|
engine = sa.create_engine("sqlite://", echo=False)
|
||||||
|
_Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
admin_user = _User(
|
||||||
|
name="admin",
|
||||||
|
user_type=_UserType.admin,
|
||||||
|
user_type_nullable=None,
|
||||||
|
)
|
||||||
|
session.add(admin_user)
|
||||||
|
session.flush()
|
||||||
|
admin_user_id = admin_user.id
|
||||||
|
|
||||||
|
normal_user = _User(
|
||||||
|
name="normal",
|
||||||
|
user_type=_UserType.normal.value,
|
||||||
|
user_type_nullable=_UserType.normal.value,
|
||||||
|
)
|
||||||
|
session.add(normal_user)
|
||||||
|
session.flush()
|
||||||
|
normal_user_id = normal_user.id
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
user = session.query(_User).filter(_User.id == admin_user_id).first()
|
||||||
|
assert user.user_type == _UserType.admin
|
||||||
|
assert user.user_type_nullable is None
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
user = session.query(_User).filter(_User.id == normal_user_id).first()
|
||||||
|
assert user.user_type == _UserType.normal
|
||||||
|
assert user.user_type_nullable == _UserType.normal
|
||||||
|
|
||||||
|
def test_insert_invalid_values(self):
|
||||||
|
def _session_insert_with_value(sess: Session, user_type: Any):
|
||||||
|
user = _User(name="test_user", user_type=user_type)
|
||||||
|
sess.add(user)
|
||||||
|
sess.flush()
|
||||||
|
|
||||||
|
def _insert_with_user(sess: Session, user_type: Any):
|
||||||
|
stmt = insert(_User).values(
|
||||||
|
{
|
||||||
|
"name": "test_user",
|
||||||
|
"user_type": user_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
sess.execute(stmt)
|
||||||
|
|
||||||
|
class TestCase(NamedTuple):
|
||||||
|
name: str
|
||||||
|
action: Callable[[Session], None]
|
||||||
|
exc_type: type[Exception]
|
||||||
|
|
||||||
|
engine = sa.create_engine("sqlite://", echo=False)
|
||||||
|
_Base.metadata.create_all(engine)
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="session insert with invalid value",
|
||||||
|
action=lambda s: _session_insert_with_value(s, "invalid"),
|
||||||
|
exc_type=ValueError,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="session insert with invalid type",
|
||||||
|
action=lambda s: _session_insert_with_value(s, 1),
|
||||||
|
exc_type=TypeError,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="insert with invalid value",
|
||||||
|
action=lambda s: _insert_with_user(s, "invalid"),
|
||||||
|
exc_type=ValueError,
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="insert with invalid type",
|
||||||
|
action=lambda s: _insert_with_user(s, 1),
|
||||||
|
exc_type=TypeError,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for idx, c in enumerate(cases, 1):
|
||||||
|
with pytest.raises(sa_exc.StatementError) as exc:
|
||||||
|
with Session(engine) as session:
|
||||||
|
c.action(session)
|
||||||
|
|
||||||
|
assert isinstance(exc.value.orig, c.exc_type), f"test case {idx} failed, name={c.name}"
|
||||||
|
|
||||||
|
def test_select_invalid_values(self):
|
||||||
|
engine = sa.create_engine("sqlite://", echo=False)
|
||||||
|
_Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
insertion_sql = """
|
||||||
|
INSERT INTO users (id, name, user_type) VALUES
|
||||||
|
(1, 'invalid_value', 'invalid');
|
||||||
|
"""
|
||||||
|
with Session(engine) as session:
|
||||||
|
session.execute(sa.text(insertion_sql))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc:
|
||||||
|
with Session(engine) as session:
|
||||||
|
_user = session.query(_User).filter(_User.id == 1).first()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue