Migrate to DeclarativeBaseModel

pull/12372/head
Yeuoly 2 years ago
parent 53e1b45d40
commit 11270a7ef2
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -3,6 +3,8 @@ from functools import wraps
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from constants.languages import supported_language from constants.languages import supported_language
@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
app = App.query.filter(App.id == args["app_id"]).first() with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f'App \'{args["app_id"]}\' is not found') raise NotFound(f'App \'{args["app_id"]}\' is not found')
@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
privacy_policy = site.privacy_policy or args["privacy_policy"] or "" privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or "" custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app: if not recommended_app:
recommended_app = RecommendedApp( recommended_app = RecommendedApp(
@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
@only_edition_cloud @only_edition_cloud
@admin_required @admin_required
def delete(self, app_id): def delete(self, app_id):
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() with Session(db.engine) as session:
recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {"result": "success"}, 204
app = App.query.filter(App.id == recommended_app.app_id).first() with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
if app: if app:
app.is_public = False app.is_public = False
installed_apps = InstalledApp.query.filter( with Session(db.engine) as session:
InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id installed_apps = session.execute(
).all() select(InstalledApp).filter(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
).all()
for installed_app in installed_apps: for installed_app in installed_apps:
db.session.delete(installed_app) db.session.delete(installed_app)

@ -33,7 +33,10 @@ def _get_resource(resource_id, tenant_id, resource_model):
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none() ).scalar_one_or_none()
else: else:
resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
if resource is None: if resource is None:
flask_restful.abort(404, message=f"{resource_model.__name__} not found.") flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

@ -3,6 +3,8 @@ import secrets
from flask import request from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.languages import languages from constants.languages import languages
from controllers.console import api from controllers.console import api
@ -41,7 +43,8 @@ class ForgotPasswordSendEmailApi(Resource):
else: else:
language = "en-US" language = "en-US"
account = Account.query.filter_by(email=args["email"]).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
token = None token = None
if account is None: if account is None:
if FeatureService.get_system_features().is_allow_register: if FeatureService.get_system_features().is_allow_register:
@ -108,7 +111,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(new_password, salt) password_hashed = hash_password(new_password, salt)
base64_password_hashed = base64.b64encode(password_hashed).decode() base64_password_hashed = base64.b64encode(password_hashed).decode()
account = Account.query.filter_by(email=reset_data.get("email")).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
if account: if account:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt

@ -5,6 +5,8 @@ from typing import Optional
import requests import requests
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
from configs import dify_config from configs import dify_config
@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
account = Account.get_by_openid(provider, user_info.id) account = Account.get_by_openid(provider, user_info.id)
if not account: if not account:
account = Account.query.filter_by(email=user_info.email).first() with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
return account return account

@ -4,6 +4,8 @@ import json
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.console import api from controllers.console import api
@ -77,7 +79,10 @@ class DataSourceApi(Resource):
def patch(self, binding_id, action): def patch(self, binding_id, action):
binding_id = str(binding_id) binding_id = str(binding_id)
action = str(action) action = str(action)
data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
if data_source_binding is None: if data_source_binding is None:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")
# enable binding # enable binding
@ -109,47 +114,53 @@ class DataSourceNotionListApi(Resource):
def get(self): def get(self):
dataset_id = request.args.get("dataset_id", default=None, type=str) dataset_id = request.args.get("dataset_id", default=None, type=str)
exist_page_ids = [] exist_page_ids = []
# import notion in the exist dataset with Session(db.engine) as session:
if dataset_id: # import notion in the exist dataset
dataset = DatasetService.get_dataset(dataset_id) if dataset_id:
if not dataset: dataset = DatasetService.get_dataset(dataset_id)
raise NotFound("Dataset not found.") if not dataset:
if dataset.data_source_type != "notion_import": raise NotFound("Dataset not found.")
raise ValueError("Dataset is not notion type.") if dataset.data_source_type != "notion_import":
documents = Document.query.filter_by( raise ValueError("Dataset is not notion type.")
dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id, documents = session.execute(
data_source_type="notion_import", select(Document).filter_by(
enabled=True, dataset_id=dataset_id,
tenant_id=current_user.current_tenant_id,
data_source_type="notion_import",
enabled=True,
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
data_source_bindings = session.execute(
select(DataSourceOauthBinding).filter_by(
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
)
).all() ).all()
if documents: if not data_source_bindings:
for document in documents: return {"notion_info": []}, 200
data_source_info = json.loads(document.data_source_info) pre_import_info_list = []
exist_page_ids.append(data_source_info["notion_page_id"]) for data_source_binding in data_source_bindings:
# get all authorized pages source_info = data_source_binding.source_info
data_source_bindings = DataSourceOauthBinding.query.filter_by( pages = source_info["pages"]
tenant_id=current_user.current_tenant_id, provider="notion", disabled=False # Filter out already bound pages
).all() for page in pages:
if not data_source_bindings: if page["page_id"] in exist_page_ids:
return {"notion_info": []}, 200 page["is_bound"] = True
pre_import_info_list = [] else:
for data_source_binding in data_source_bindings: page["is_bound"] = False
source_info = data_source_binding.source_info pre_import_info = {
pages = source_info["pages"] "workspace_name": source_info["workspace_name"],
# Filter out already bound pages "workspace_icon": source_info["workspace_icon"],
for page in pages: "workspace_id": source_info["workspace_id"],
if page["page_id"] in exist_page_ids: "pages": pages,
page["is_bound"] = True }
else: pre_import_info_list.append(pre_import_info)
page["is_bound"] = False return {"notion_info": pre_import_info_list}, 200
pre_import_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
}
pre_import_info_list.append(pre_import_info)
return {"notion_info": pre_import_info_list}, 200
class DataSourceNotionApi(Resource): class DataSourceNotionApi(Resource):
@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource):
def get(self, workspace_id, page_id, page_type): def get(self, workspace_id, page_id, page_type):
workspace_id = str(workspace_id) workspace_id = str(workspace_id)
page_id = str(page_id) page_id = str(page_id)
data_source_binding = DataSourceOauthBinding.query.filter( with Session(db.engine) as session:
db.and_( data_source_binding = session.execute(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, select(DataSourceOauthBinding).filter(
DataSourceOauthBinding.provider == "notion", db.and_(
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.provider == "notion",
) DataSourceOauthBinding.disabled == False,
).first() DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
).scalar_one_or_none()
if not data_source_binding: if not data_source_binding:
raise NotFound("Data source binding not found.") raise NotFound("Data source binding not found.")

@ -5,7 +5,8 @@ from datetime import datetime, timezone
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc from sqlalchemy import asc, desc, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource):
rules = DocumentService.DEFAULT_RULES["rules"] rules = DocumentService.DEFAULT_RULES["rules"]
if document_id: if document_id:
# get the latest process rule # get the latest process rule
document = Document.query.get_or_404(document_id) with Session(db.engine) as session:
document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
dataset = DatasetService.get_dataset(document.dataset_id) dataset = DatasetService.get_dataset(document.dataset_id)
@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) with Session(db.engine) as session:
query = session.execute(
select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
).all()
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -204,18 +209,25 @@ class DatasetDocumentListApi(Resource):
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items
if fetch: if fetch:
for document in documents: with Session(db.engine) as session:
completed_segments = DocumentSegment.query.filter( for document in documents:
DocumentSegment.completed_at.isnot(None), completed_segments = (
DocumentSegment.document_id == str(document.id), session.query(DocumentSegment)
DocumentSegment.status != "re_segment", .filter(
).count() DocumentSegment.completed_at.isnot(None),
total_segments = DocumentSegment.query.filter( DocumentSegment.document_id == str(document.id),
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" DocumentSegment.status != "re_segment",
).count() )
document.completed_segments = completed_segments .count()
document.total_segments = total_segments )
data = marshal(documents, document_with_segments_fields) total_segments = (
session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields)
else: else:
data = marshal(documents, document_fields) data = marshal(documents, document_fields)
response = { response = {

@ -2,8 +2,11 @@ import os
from flask import session from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from extensions.ext_database import db
from libs.helper import StrLen from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
def get_init_validate_status(): def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED": if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"): if os.environ.get("INIT_PASSWORD"):
return session.get("is_init_validated") or DifySetup.query.first() if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True return True

@ -4,6 +4,7 @@ import json
from flask_login import UserMixin from flask_login import UserMixin
from extensions.ext_database import db from extensions.ext_database import db
from models.base import Base
from .types import StringUUID from .types import StringUUID
@ -16,7 +17,7 @@ class AccountStatus(str, enum.Enum):
CLOSED = "closed" CLOSED = "closed"
class Account(UserMixin, db.Model): class Account(UserMixin, Base):
__tablename__ = "accounts" __tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))

@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel):
number_limits: int = Field(default=0, gt=0, le=10) number_limits: int = Field(default=0, gt=0, le=10)
class DifySetup(db.Model): class DifySetup(BaseModel):
__tablename__ = "dify_setups" __tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)

Loading…
Cancel
Save