Compare commits

...

29 Commits

Author SHA1 Message Date
chenyanqun 8b5c46b2df fix:修复广船嵌入导航时,意外重定向到检测。 6 months ago
chenyanqun 3535e51fcd feat:更新web的next.config.js通用嵌入 6 months ago
陈炎群 f9bb9dae05 revert 5f8d916852
revert feat:更新next.config.js文件嵌入来源
6 months ago
chenyanqun 5f8d916852 feat:更新next.config.js文件嵌入来源 6 months ago
陈炎群 f8ad71384b feat:新增广船嵌入操作 7 months ago
陈志荣 c0507ac1b7 更新 'README.md' 7 months ago
GuanMu bd43ca6275
fix: rounded (#22909) 7 months ago
Yeuoly 9237976988
fix: refine handling of constant and mixed input types in ToolManager and ToolNodeData (#22903) 7 months ago
zyssyz123 6ac06486e3
Feat/change user email freezes limit (#22900) 7 months ago
HyaCinth 061d4c8ea0
fix(plugins_select): Adjust z-index, fix issue where options cannot be displayed (#22873) (#22893) 7 months ago
NFish aca8b83669
fix: support authorization using session and user_id in URL. (#22898) 7 months ago
Wu Tianwei a8f09ad43f
refactor(i18next): streamline fallback translation handling and initi… (#22894) 7 months ago
KVOJJJin de611ab344
Feat: add notification for change email completed (#22812)
Co-authored-by: Yansong Zhang <916125788@qq.com>
7 months ago
呆萌闷油瓶 371fe7a700
fix: type error in list-operator (#22803) 7 months ago
Nite Knite c6d7328e15
feat: revamp tool list page (#22879) 7 months ago
Will a327d024e9
fix: improved conversation name (#22840) 7 months ago
HyaCinth b8504ac7d0
refactor(dayjs): Refactor internationalized time formatting feature (#22870) (#22872) 7 months ago
Asuka Minato bb33335dd4
add autofix (#22785) 7 months ago
jiangbo721 5a02e599e1
chore: code format model-selector use enum (#22787)
Co-authored-by: 刘江波 <jiangbo721@163.com>
7 months ago
croatialu d1572f47a0
feat: Add user variable processing function to chat history (#22863) 7 months ago
Asuka Minato ef51678c73
orm filter -> where (#22801)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
7 months ago
Boris Polonsky e64e7563f6
feat(k8s): Add pure migration option for `api` component (#22750) 7 months ago
Song Kylin 0731db8c22
fix: private.pem keyPath error in windows (#22814)
Co-authored-by: songkunling <songkunling@cabrtech.com>
7 months ago
Jason Young 8c3e390172
test: add comprehensive integration tests for API key authentication system (#22856) 7 months ago
wanttobeamaster 8278b39f85
fix tablestore full text search bug (#22853) 7 months ago
wanttobeamaster 1c3c40db69
fix: tablestore TypeError when vector is missing (#22843)
Co-authored-by: xiaozhiqing.xzq <xiaozhiqing.xzq@alibaba-inc.com>
7 months ago
Novice 7ec94eb83c
chore(version): bump to 1.7.0 (#22830)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
7 months ago
-LAN- 79ab8b205f
fix: improve max active requests calculation logic (#22847)
Signed-off-by: -LAN- <laipz8200@outlook.com>
7 months ago
Wu Tianwei 882f8bdd2c
fix: update @headlessui/react to version 2.2.1 (#22839) 7 months ago

@ -0,0 +1,27 @@
name: autofix.ci
on:
workflow_call:
pull_request:
push:
branches: [ "main" ]
permissions:
contents: read
jobs:
autofix:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
# Use uv to ensure we have the same ruff version in CI and locally.
- uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
- run: |
cd api
uv sync --dev
# Fix lint errors
uv run ruff check --fix-only .
# Format code
uv run ruff format .
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27

@ -1,7 +1,7 @@
![cover-v5-optimized](./images/GitHub_README_if.png) ![cover-v5-optimized](./images/GitHub_README_if.png)
<p align="center"> <p align="center">
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast</a> 📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast111</a>
</p> </p>
<p align="center"> <p align="center">

@ -50,7 +50,7 @@ def reset_password(email, new_password, password_confirm):
click.echo(click.style("Passwords do not match.", fg="red")) click.echo(click.style("Passwords do not match.", fg="red"))
return return
account = db.session.query(Account).filter(Account.email == email).one_or_none() account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account: if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red")) click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@ -89,7 +89,7 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style("New emails do not match.", fg="red")) click.echo(click.style("New emails do not match.", fg="red"))
return return
account = db.session.query(Account).filter(Account.email == email).one_or_none() account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account: if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red")) click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@ -136,8 +136,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id) tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit() db.session.commit()
click.echo( click.echo(
@ -172,7 +172,7 @@ def migrate_annotation_vector_database():
per_page = 50 per_page = 50
apps = ( apps = (
db.session.query(App) db.session.query(App)
.filter(App.status == "normal") .where(App.status == "normal")
.order_by(App.created_at.desc()) .order_by(App.created_at.desc())
.limit(per_page) .limit(per_page)
.offset((page - 1) * per_page) .offset((page - 1) * per_page)
@ -192,7 +192,7 @@ def migrate_annotation_vector_database():
try: try:
click.echo("Creating app annotation index: {}".format(app.id)) click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = ( app_annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
) )
if not app_annotation_setting: if not app_annotation_setting:
@ -202,13 +202,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first() .first()
) )
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id)) click.echo("App annotation collection binding not found: {}".format(app.id))
continue continue
annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -305,7 +305,7 @@ def migrate_knowledge_vector_database():
while True: while True:
try: try:
stmt = ( stmt = (
select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
) )
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -332,7 +332,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id) .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:
@ -367,7 +367,7 @@ def migrate_knowledge_vector_database():
dataset_documents = ( dataset_documents = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.dataset_id == dataset.id, DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -381,7 +381,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@ -468,7 +468,7 @@ def convert_to_agent_apps():
app_id = str(i.id) app_id = str(i.id)
if app_id not in proceeded_app_ids: if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id) proceeded_app_ids.append(app_id)
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).where(App.id == app_id).first()
if app is not None: if app is not None:
apps.append(app) apps.append(app)
@ -483,7 +483,7 @@ def convert_to_agent_apps():
db.session.commit() db.session.commit()
# update conversation mode to agent # update conversation mode to agent
db.session.query(Conversation).filter(Conversation.app_id == app.id).update( db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value} {Conversation.mode: AppMode.AGENT_CHAT.value}
) )
@ -560,7 +560,7 @@ def old_metadata_migration():
try: try:
stmt = ( stmt = (
select(DatasetDocument) select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None)) .where(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc()) .order_by(DatasetDocument.created_at.desc())
) )
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@ -578,7 +578,7 @@ def old_metadata_migration():
else: else:
dataset_metadata = ( dataset_metadata = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
.filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first() .first()
) )
if not dataset_metadata: if not dataset_metadata:
@ -602,7 +602,7 @@ def old_metadata_migration():
else: else:
dataset_metadata_binding = ( dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore db.session.query(DatasetMetadataBinding) # type: ignore
.filter( .where(
DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id,
@ -717,7 +717,7 @@ where sites.id is null limit 1000"""
continue continue
try: try:
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).where(App.id == app_id).first()
if not app: if not app:
print(f"App {app_id} not found") print(f"App {app_id} not found")
continue continue

@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{args['app_id']}' is not found")
@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource):
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id): def delete(self, app_id):
with Session(db.engine) as session: with Session(db.engine) as session:
recommended_app = session.execute( recommended_app = session.execute(
select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id)) select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
return {"result": "success"}, 204 return {"result": "success"}, 204
with Session(db.engine) as session: with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none() app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app: if app:
app.is_public = False app.is_public = False
with Session(db.engine) as session: with Session(db.engine) as session:
installed_apps = session.execute( installed_apps = session.execute(
select(InstalledApp).filter( select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id, InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
) )

@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model) _get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count() .count()
) )
@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter( .where(
getattr(ApiToken, self.resource_id_field) == resource_id, getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource):
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {"result": "success"}, 204

@ -49,7 +49,7 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args["keyword"]: if args["keyword"]:
query = query.join(Message, Message.conversation_id == Conversation.id).filter( query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_( or_(
Message.query.ilike("%{}%".format(args["keyword"])), Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])), Message.answer.ilike("%{}%".format(args["keyword"])),
@ -121,7 +121,7 @@ class CompletionConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -181,7 +181,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id, Message.conversation_id == Conversation.id,
) )
.join(subquery, subquery.c.conversation_id == Conversation.id) .join(subquery, subquery.c.conversation_id == Conversation.id)
.filter( .where(
or_( or_(
Message.query.ilike(keyword_filter), Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter), Message.answer.ilike(keyword_filter),
@ -286,7 +286,7 @@ class ChatConversationDetailApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )
@ -308,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps/<uuid:app_id>/chat-conversati
def _get_conversation(app_model, conversation_id): def _get_conversation(app_model, conversation_id):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first() .first()
) )

@ -26,7 +26,7 @@ class AppMCPServerController(Resource):
@get_app_model @get_app_model
@marshal_with(app_server_fields) @marshal_with(app_server_fields)
def get(self, app_model): def get(self, app_model):
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first() server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server return server
@setup_required @setup_required
@ -73,7 +73,7 @@ class AppMCPServerController(Resource):
parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json") parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server: if not server:
raise NotFound() raise NotFound()
@ -104,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound() raise NotFound()
server = ( server = (
db.session.query(AppMCPServer) db.session.query(AppMCPServer)
.filter(AppMCPServer.id == server_id) .where(AppMCPServer.id == server_id)
.filter(AppMCPServer.tenant_id == current_user.current_tenant_id) .where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not server: if not server:

@ -56,7 +56,7 @@ class ChatMessageListApi(Resource):
conversation = ( conversation = (
db.session.query(Conversation) db.session.query(Conversation)
.filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first() .first()
) )
@ -66,7 +66,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]: if args["first_id"]:
first_message = ( first_message = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) .where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first() .first()
) )
@ -75,7 +75,7 @@ class ChatMessageListApi(Resource):
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at, Message.created_at < first_message.created_at,
Message.id != first_message.id, Message.id != first_message.id,
@ -87,7 +87,7 @@ class ChatMessageListApi(Resource):
else: else:
history_messages = ( history_messages = (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == conversation.id) .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc()) .order_by(Message.created_at.desc())
.limit(args["limit"]) .limit(args["limit"])
.all() .all()
@ -98,7 +98,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1] current_page_first_message = history_messages[-1]
rest_count = ( rest_count = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.conversation_id == conversation.id, Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at, Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id, Message.id != current_page_first_message.id,
@ -167,7 +167,7 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required @account_initialization_required
@get_app_model @get_app_model
def get(self, app_model): def get(self, app_model):
count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
return {"count": count} return {"count": count}
@ -214,7 +214,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id): def get(self, app_model, message_id):
message_id = str(message_id) message_id = str(message_id)
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message: if not message:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -42,7 +42,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config # get original app model config
original_app_model_config = ( original_app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
) )
if original_app_model_config is None: if original_app_model_config is None:
raise ValueError("Original app model config not found") raise ValueError("Original app model config not found")

@ -49,7 +49,7 @@ class AppSite(Resource):
if not current_user.is_editor: if not current_user.is_editor:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound
@ -93,7 +93,7 @@ class AppSiteAccessTokenReset(Resource):
if not current_user.is_admin_or_owner: if not current_user.is_admin_or_owner:
raise Forbidden() raise Forbidden()
site = db.session.query(Site).filter(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise NotFound raise NotFound

@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]: def _load_app_model(app_id: str) -> Optional[App]:
app_model = ( app_model = (
db.session.query(App) db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first() .first()
) )
return app_model return app_model

@ -113,3 +113,9 @@ class MemberNotInTenantError(BaseHTTPException):
error_code = "member_not_in_tenant" error_code = "member_not_in_tenant"
description = "The member is not in the workspace." description = "The member is not in the workspace."
code = 400 code = 400
class AccountInFreezeError(BaseHTTPException):
error_code = "account_in_freeze"
description = "This email is temporarily unavailable."
code = 400

@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates # get workspace data source integrates
data_source_integrates = ( data_source_integrates = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource):
page_id = str(page_id) page_id = str(page_id)
with Session(db.engine) as session: with Session(db.engine) as session:
data_source_binding = session.execute( data_source_binding = session.execute(
select(DataSourceOauthBinding).filter( select(DataSourceOauthBinding).where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",

@ -412,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"] file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = ( file_details = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all() .all()
) )
@ -517,14 +517,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
documents = ( documents = (
db.session.query(Document) db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -533,7 +533,7 @@ class DatasetIndexingStatusApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -568,7 +568,7 @@ class DatasetApiKeyApi(Resource):
def get(self): def get(self):
keys = ( keys = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all() .all()
) )
return {"items": keys} return {"items": keys}
@ -584,7 +584,7 @@ class DatasetApiKeyApi(Resource):
current_key_count = ( current_key_count = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count() .count()
) )
@ -620,7 +620,7 @@ class DatasetApiDeleteApi(Resource):
key = ( key = (
db.session.query(ApiToken) db.session.query(ApiToken)
.filter( .where(
ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type, ApiToken.type == self.resource_type,
ApiToken.id == api_key_id, ApiToken.id == api_key_id,
@ -631,7 +631,7 @@ class DatasetApiDeleteApi(Resource):
if key is None: if key is None:
flask_restful.abort(404, message="API key not found") flask_restful.abort(404, message="API key not found")
db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit() db.session.commit()
return {"result": "success"}, 204 return {"result": "success"}, 204

@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule # get the latest process rule
dataset_process_rule = ( dataset_process_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == document.dataset_id) .where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.limit(1) .limit(1)
.one_or_none() .one_or_none()
@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.where(Document.name.like(search))
if sort.startswith("-"): if sort.startswith("-"):
sort_logic = desc sort_logic = desc
@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
document.completed_segments = completed_segments document.completed_segments = completed_segments
@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
file = ( file = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first() .first()
) )
@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource):
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count() .count()
) )

@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = ( query = (
select(DocumentSegment) select(DocumentSegment)
.filter( .where(
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id, DocumentSegment.tenant_id == current_user.current_tenant_id,
) )
@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
) )
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.where(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None: if hit_count_gte is not None:
query = query.filter(DocumentSegment.hit_count >= hit_count_gte) query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword: if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all": if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true": if args["enabled"].lower() == "true":
query = query.filter(DocumentSegment.enabled == True) query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False) query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:
@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id) segment_id = str(segment_id)
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not segment: if not segment:
@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first() .first()
) )
if not child_chunk: if not child_chunk:

@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource):
if app_id: if app_id:
installed_apps = ( installed_apps = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)) .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all() .all()
) )
else: else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [ installed_app_list: list[dict[str, Any]] = [
@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
current_tenant_id = current_user.current_tenant_id current_tenant_id = current_user.current_tenant_id
app = db.session.query(App).filter(App.id == args["app_id"]).first() app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None: if app is None:
raise NotFound("App not found") raise NotFound("App not found")
@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first() .first()
) )

@ -28,7 +28,7 @@ def installed_app_required(view=None):
installed_app = ( installed_app = (
db.session.query(InstalledApp) db.session.query(InstalledApp)
.filter( .where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
) )
.first() .first()

@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session: with Session(db.engine) as session:
permission = ( permission = (
session.query(TenantPluginPermission) session.query(TenantPluginPermission)
.filter( .where(
TenantPluginPermission.tenant_id == tenant_id, TenantPluginPermission.tenant_id == tenant_id,
) )
.first() .first()

@ -9,6 +9,7 @@ from configs import dify_config
from constants.languages import supported_language from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.auth.error import ( from controllers.console.auth.error import (
AccountInFreezeError,
EmailAlreadyInUseError, EmailAlreadyInUseError,
EmailChangeLimitError, EmailChangeLimitError,
EmailCodeError, EmailCodeError,
@ -68,7 +69,7 @@ class AccountInitApi(Resource):
# check invitation code # check invitation code
invitation_code = ( invitation_code = (
db.session.query(InvitationCode) db.session.query(InvitationCode)
.filter( .where(
InvitationCode.code == args["invitation_code"], InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused", InvitationCode.status == "unused",
) )
@ -228,7 +229,7 @@ class AccountIntegrateApi(Resource):
def get(self): def get(self):
account = current_user account = current_user
account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip("/") base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login" oauth_base_path = "/console/api/oauth/login"
@ -479,21 +480,28 @@ class ChangeEmailResetApi(Resource):
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
if AccountService.is_account_in_freeze(args["new_email"]):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args["token"]) reset_data = AccountService.get_change_email_data(args["token"])
if not reset_data: if not reset_data:
raise InvalidTokenError() raise InvalidTokenError()
AccountService.revoke_change_email_token(args["token"]) AccountService.revoke_change_email_token(args["token"])
if not AccountService.check_email_unique(args["new_email"]):
raise EmailAlreadyInUseError()
old_email = reset_data.get("old_email", "") old_email = reset_data.get("old_email", "")
if current_user.email != old_email: if current_user.email != old_email:
raise AccountNotFound() raise AccountNotFound()
updated_account = AccountService.update_account(current_user, email=args["new_email"]) updated_account = AccountService.update_account(current_user, email=args["new_email"])
AccountService.send_change_email_completed_notify_email(
email=args["new_email"],
)
return updated_account return updated_account

@ -108,7 +108,7 @@ class MemberCancelInviteApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def delete(self, member_id): def delete(self, member_id):
member = db.session.query(Account).filter(Account.id == str(member_id)).first() member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None: if member is None:
abort(404) abort(404)
else: else:

@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
user_id = "DEFAULT-USER" user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER": if user_id == "DEFAULT-USER":
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
if not user_model: if not user_model:
user_model = EndUser( user_model = EndUser(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
else: else:
user_model = AccountService.load_user(user_id) user_model = AccountService.load_user(user_id)
if not user_model: if not user_model:
user_model = session.query(EndUser).filter(EndUser.id == user_id).first() user_model = session.query(EndUser).where(EndUser.id == user_id).first()
if not user_model: if not user_model:
raise ValueError("user not found") raise ValueError("user not found")
except Exception: except Exception:
@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try: try:
tenant_model = ( tenant_model = (
db.session.query(Tenant) db.session.query(Tenant)
.filter( .where(
Tenant.id == tenant_id, Tenant.id == tenant_id,
) )
.first() .first()

@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view):
if signature_base64 != token: if signature_base64 != token:
return view(*args, **kwargs) return view(*args, **kwargs)
kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
return view(*args, **kwargs) return view(*args, **kwargs)

@ -30,7 +30,7 @@ class MCPAppApi(Resource):
request_id = args.get("id") request_id = args.get("id")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server: if not server:
return helper.compact_generate_response( return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found") create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
@ -41,7 +41,7 @@ class MCPAppApi(Resource):
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active") create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
) )
app = db.session.query(App).filter(App.id == server.app_id).first() app = db.session.query(App).where(App.id == server.app_id).first()
if not app: if not app:
return helper.compact_generate_response( return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found") create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")

@ -16,7 +16,7 @@ class AppSiteApi(Resource):
@marshal_with(fields.site_fields) @marshal_with(fields.site_fields)
def get(self, app_model: App): def get(self, app_model: App):
"""Retrieve app site info.""" """Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise Forbidden() raise Forbidden()

@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args = parser.parse_args() args = parser.parse_args()
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource):
# get dataset info # get dataset info
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# get dataset info # get dataset info
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource):
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
# get dataset info # get dataset info
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset does not exist.") raise ValueError("Dataset does not exist.")
@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str) search = request.args.get("keyword", default=None, type=str)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):
if search: if search:
search = f"%{search}%" search = f"%{search}%"
query = query.filter(Document.name.like(search)) query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position)) query = query.order_by(desc(Document.created_at), desc(Document.position))
@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch) batch = str(batch)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
# get dataset # get dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# get documents # get documents
@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count() .count()
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields

@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource):
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int) page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int) limit = request.args.get("limit", default=20, type=int)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check user's model setting # check user's model setting
@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
tenant_id = str(tenant_id) tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
# check document # check document
@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource):
data_source_info = document.data_source_info_dict data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info: if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file: if not upload_file:
raise NotFound("UploadFile not found.") raise NotFound("UploadFile not found.")
else: else:

@ -44,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def decorated_view(*args, **kwargs): def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app") api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).filter(App.id == api_token.app_id).first() app_model = db.session.query(App).where(App.id == api_token.app_id).first()
if not app_model: if not app_model:
raise Forbidden("The app no longer exists.") raise Forbidden("The app no longer exists.")
@ -54,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api: if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.") raise Forbidden("The app's API service has been disabled.")
tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
if tenant is None: if tenant is None:
raise ValueError("Tenant does not exist.") raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE: if tenant.status == TenantStatus.ARCHIVE:
@ -62,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id) .where(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"])) .where(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL) .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first() account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin # Login admin
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
@ -213,15 +213,15 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset") api_token = validate_and_get_api_token("dataset")
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == api_token.tenant_id) .where(Tenant.id == api_token.tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role.in_(["owner"])) .where(TenantAccountJoin.role.in_(["owner"]))
.filter(Tenant.status == TenantStatus.NORMAL) .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none() .one_or_none()
) # TODO: only owner information is required, so only one is returned. ) # TODO: only owner information is required, so only one is returned.
if tenant_account_join: if tenant_account_join:
tenant, ta = tenant_account_join tenant, ta = tenant_account_join
account = db.session.query(Account).filter(Account.id == ta.account_id).first() account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin # Login admin
if account: if account:
account.current_tenant = tenant account.current_tenant = tenant
@ -293,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
end_user = ( end_user = (
db.session.query(EndUser) db.session.query(EndUser)
.filter( .where(
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
EndUser.session_id == user_id, EndUser.session_id == user_id,
@ -320,7 +320,7 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token] method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")

@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config from configs import dify_config
@ -42,17 +43,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError() raise WebAppAuthRequiredError()
# get site from db and check if it is normal # get site from db and check if it is normal
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site: if not site:
raise NotFound() raise NotFound()
# get app from db and check if it is normal and enable_site # get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first() app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
if user_id: if user_id:
end_user = ( end_user = db.session.scalar(
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
) )
if end_user: if end_user:
@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type: if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.") raise Unauthorized("Missing auth_type in the token.")
site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site: if not site:
raise NotFound() raise NotFound()
app_model = db.session.query(App).filter(App.id == site.app_id).first() app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site: if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound() raise NotFound()
@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None end_user = None
if end_user_id: if end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if session_id: if session_id:
end_user = ( end_user = db.session.scalar(
db.session.query(EndUser) select(EndUser).where(
.filter(
EndUser.session_id == session_id, EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id, EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id, EndUser.app_id == app_model.id,
) )
.first()
) )
if not end_user: if not end_user:
if not session_id: if not session_id:
@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id") user_id = token_decoded.get("user_id")
end_user = None end_user = None
if user_id: if user_id:
end_user = ( end_user = db.session.scalar(
db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
) )
if not end_user: if not end_user:
@ -224,6 +223,8 @@ def generate_session_id():
""" """
while True: while True:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() existing_count = db.session.scalar(
select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
)
if existing_count == 0: if existing_count == 0:
return session_id return session_id

@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user): def get(self, app_model, end_user):
"""Retrieve app site info.""" """Retrieve app site info."""
# get site # get site
site = db.session.query(Site).filter(Site.app_id == app_model.id).first() site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site: if not site:
raise Forbidden() raise Forbidden()

@ -3,6 +3,7 @@ from functools import wraps
from flask import request from flask import request
from flask_restful import Resource from flask_restful import Resource
from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk) decoded = PassportService().verify(tk)
app_code = decoded.get("app_code") app_code = decoded.get("app_code")
app_id = decoded.get("app_id") app_id = decoded.get("app_id")
app_model = db.session.query(App).filter(App.id == app_id).first() app_model = db.session.scalar(select(App).where(App.id == app_id))
site = db.session.query(Site).filter(Site.code == app_code).first() site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model: if not app_model:
raise NotFound() raise NotFound()
if not app_code or not site: if not app_code or not site:
@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False: if app_model.enable_site is False:
raise BadRequest("Site is disabled.") raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id") end_user_id = decoded.get("end_user_id")
end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user: if not end_user:
raise NotFound() raise NotFound()

@ -99,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.filter( .where(
MessageAgentThought.message_id == self.message.id, MessageAgentThought.message_id == self.message.id,
) )
.count() .count()
@ -336,7 +336,7 @@ class BaseAgentRunner(AppRunner):
Save agent thought Save agent thought
""" """
updated_agent_thought = ( updated_agent_thought = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first() db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
) )
if not updated_agent_thought: if not updated_agent_thought:
raise ValueError("agent thought not found") raise ValueError("agent thought not found")
@ -496,7 +496,7 @@ class BaseAgentRunner(AppRunner):
return result return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if not files: if not files:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
if message.app_model_config: if message.app_model_config:

@ -72,7 +72,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config) app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")
@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []): if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first() conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
if conversation_result is None: if conversation_result is None:
raise ValueError("Conversation not found") raise ValueError("Conversation not found")
message_result = db.session.query(Message).filter(Message.id == message.id).first() message_result = db.session.query(Message).where(Message.id == message.id).first()
if message_result is None: if message_result is None:
raise ValueError("Message not found") raise ValueError("Message not found")
db.session.close() db.session.close()

@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config) app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

@ -248,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
""" """
message = ( message = (
db.session.query(Message) db.session.query(Message)
.filter( .where(
Message.id == message_id, Message.id == message_id,
Message.app_id == app_model.id, Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"), Message.from_source == ("api" if isinstance(user, EndUser) else "console"),

@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config) app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first() app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record: if not app_record:
raise ValueError("App not found") raise ValueError("App not found")

@ -85,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation: if conversation:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id) .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first() .first()
) )
@ -151,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
introduction = self._get_conversation_introduction(application_generate_entity) introduction = self._get_conversation_introduction(application_generate_entity)
# get conversation name # get conversation name
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
query = application_generate_entity.query or "New conversation" query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
if isinstance(query, int):
query = str(query)
query = query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query conversation_name = (query[:20] + "") if len(query) > 20 else query
if not conversation: if not conversation:
@ -259,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id :param conversation_id: conversation id
:return: conversation :return: conversation
""" """
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
raise ConversationNotExistsError("Conversation not exists") raise ConversationNotExistsError("Conversation not exists")
@ -272,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id :param message_id: message id
:return: message :return: message
""" """
message = db.session.query(Message).filter(Message.id == message_id).first() message = db.session.query(Message).where(Message.id == message_id).first()
if message is None: if message is None:
raise MessageNotExistsError("Message not exists") raise MessageNotExistsError("Message not exists")

@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return: :return:
""" """
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
) )
if not annotation_setting: if not annotation_setting:

@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return: :return:
""" """
agent_thought: Optional[MessageAgentThought] = ( agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
) )
if agent_thought: if agent_thought:

@ -81,7 +81,7 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str): def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context(): with flask_app.app_context():
# get conversation and message # get conversation and message
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
return return
@ -140,7 +140,7 @@ class MessageCycleManager:
:param event: event :param event: event
:return: :return:
""" """
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first() message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
if message_file and message_file.url is not None: if message_file and message_file.url is not None:
# get tool file id # get tool file id

@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
document_id = document.metadata["document_id"] document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document: if not dataset_document:
_logger.warning( _logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s", "Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id) .where(DocumentSegment.id == child_chunk.segment_id)
.update( .update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
) )
) )
else: else:
query = db.session.query(DocumentSegment).filter( query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"] DocumentSegment.index_node_id == document.metadata["doc_id"]
) )
if "dataset_id" in document.metadata: if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
provider_record = ( provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .where(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names), Provider.provider_name.in_(provider_names),
@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
provider_model_record = ( provider_model_record = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.filter( .where(
ProviderModel.tenant_id == self.tenant_id, ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names), ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model, ProviderModel.model_name == model,
@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
return ( return (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
load_balancing_config_count = ( load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig) db.session.query(LoadBalancingModelConfig)
.filter( .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names), LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ( model_setting = (
db.session.query(ProviderModelSetting) db.session.query(ProviderModelSetting)
.filter( .where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = ( preferred_model_provider = (
db.session.query(TenantPreferredModelProvider) db.session.query(TenantPreferredModelProvider)
.filter( .where(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names), TenantPreferredModelProvider.provider_name.in_(provider_names),
) )

@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )
@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension # get api_based_extension
api_based_extension = ( api_based_extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -15,7 +15,7 @@ def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant from models.account import Tenant
from models.engine import db from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()): if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found") raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode() return base64.b64encode(encrypted_token).decode()

@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -119,12 +119,12 @@ class IndexingRunner:
db.session.delete(document_segment) db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks # delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
if not processing_rule: if not processing_rule:
@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first() .first()
) )
@ -316,7 +316,7 @@ class IndexingRunner:
# delete image files and related db records # delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids: for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None: if image_file is None:
continue continue
try: try:
@ -346,7 +346,7 @@ class IndexingRunner:
raise ValueError("no upload file found") raise ValueError("no upload file found")
file_detail = ( file_detail = (
db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none() db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
) )
if file_detail: if file_detail:
@ -599,7 +599,7 @@ class IndexingRunner:
keyword.create(documents) keyword.create(documents)
if dataset.indexing_technique != "high_quality": if dataset.indexing_technique != "high_quality":
document_ids = [document.metadata["doc_id"] for document in documents] document_ids = [document.metadata["doc_id"] for document in documents]
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id, DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id, DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.index_node_id.in_(document_ids),
@ -630,7 +630,7 @@ class IndexingRunner:
index_processor.load(dataset, chunk_documents, with_keywords=False) index_processor.load(dataset, chunk_documents, with_keywords=False)
document_ids = [document.metadata["doc_id"] for document in chunk_documents] document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.query(DocumentSegment).filter( db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids), DocumentSegment.index_node_id.in_(document_ids),

@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
): ):
self.app = app self.app = app
self.request = request self.request = request
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first() mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server: if not mcp_server:
raise ValueError("MCP server not found") raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server self.mcp_server: AppMCPServer = mcp_server
@ -192,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self): def retrieve_end_user(self):
return ( return (
db.session.query(EndUser) db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp") .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )

@ -67,7 +67,7 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = [] prompt_messages: list[PromptMessage] = []
for message in messages: for message in messages:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files: if files:
file_extra_config = None file_extra_config = None
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:

@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]: def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = ( extension = (
db.session.query(APIBasedExtension) db.session.query(APIBasedExtension)
.filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first() .first()
) )

@ -120,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id user_id = message_data.from_account_id
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
user_id = end_user_data.session_id user_id = end_user_data.session_id
@ -244,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance):
if not app_id: if not app_id:
raise ValueError("No app_id found in trace_info metadata") raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first() app = session.query(App).where(App.id == app_id).first()
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first() service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = (

@ -297,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Add end user data if available # Add end user data if available
if trace_info.message_data.from_end_user_id: if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id message_metadata["end_user_id"] = end_user_data.session_id
@ -703,7 +703,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data, WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata, WorkflowNodeExecutionModel.execution_metadata,
) )
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all() .all()
) )
return workflow_nodes return workflow_nodes

@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator # Get the app to find its creator
app = session.query(App).filter(App.id == app_id).first() app = session.query(App).where(App.id == app_id).first()
if not app: if not app:
raise ValueError(f"App with id {app_id} not found") raise ValueError(f"App with id {app_id} not found")
if not app.created_by: if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)") raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first() service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

@ -244,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id user_id = message_data.from_account_id
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
user_id = end_user_data.session_id user_id = end_user_data.session_id

@ -262,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -284,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -218,7 +218,7 @@ class OpsTraceManager:
""" """
trace_config_data: Optional[TraceAppConfig] = ( trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig) db.session.query(TraceAppConfig)
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first() .first()
) )
@ -226,7 +226,7 @@ class OpsTraceManager:
return None return None
# decrypt_token # decrypt_token
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).where(App.id == app_id).first()
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
@ -253,7 +253,7 @@ class OpsTraceManager:
if app_id is None: if app_id is None:
return None return None
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if app is None: if app is None:
return None return None
@ -293,18 +293,18 @@ class OpsTraceManager:
@classmethod @classmethod
def get_app_config_through_message_id(cls, message_id: str): def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None app_model_config = None
message_data = db.session.query(Message).filter(Message.id == message_id).first() message_data = db.session.query(Message).where(Message.id == message_id).first()
if not message_data: if not message_data:
return None return None
conversation_id = message_data.conversation_id conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation_data: if not conversation_data:
return None return None
if conversation_data.app_model_config_id: if conversation_data.app_model_config_id:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig) db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation_data.app_model_config_id) .where(AppModelConfig.id == conversation_data.app_model_config_id)
.first() .first()
) )
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs: elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
@ -331,7 +331,7 @@ class OpsTraceManager:
if tracing_provider is not None: if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}") raise ValueError(f"Invalid tracing provider: {tracing_provider}")
app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app_config: if not app_config:
raise ValueError("App not found") raise ValueError("App not found")
app_config.tracing = json.dumps( app_config.tracing = json.dumps(
@ -349,7 +349,7 @@ class OpsTraceManager:
:param app_id: app id :param app_id: app id
:return: :return:
""" """
app: Optional[App] = db.session.query(App).filter(App.id == app_id).first() app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app: if not app:
raise ValueError("App not found") raise ValueError("App not found")
if not app.tracing: if not app.tracing:

@ -3,6 +3,8 @@ from datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db from extensions.ext_database import db
from models.model import Message from models.model import Message
@ -20,7 +22,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str): def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first() return db.session.scalar(select(Message).where(Message.id == message_id))
@contextmanager @contextmanager

@ -235,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id: if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = ( end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first() db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
) )
if end_user_data is not None: if end_user_data is not None:
end_user_id = end_user_data.session_id end_user_id = end_user_data.session_id

@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get the user by user id get the user by user id
""" """
user = db.session.query(EndUser).filter(EndUser.id == user_id).first() user = db.session.query(EndUser).where(EndUser.id == user_id).first()
if not user: if not user:
user = db.session.query(Account).filter(Account.id == user_id).first() user = db.session.query(Account).where(Account.id == user_id).first()
if not user: if not user:
raise ValueError("user not found") raise ValueError("user not found")
@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app get app
""" """
try: try:
app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first() app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
except Exception: except Exception:
raise ValueError("app not found") raise ValueError("app not found")

@ -275,7 +275,7 @@ class ProviderManager:
# Get the corresponding TenantDefaultModel record # Get the corresponding TenantDefaultModel record
default_model = ( default_model = (
db.session.query(TenantDefaultModel) db.session.query(TenantDefaultModel)
.filter( .where(
TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(), TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
@ -367,7 +367,7 @@ class ProviderManager:
# Get the list of available models from get_configurations and check if it is LLM # Get the list of available models from get_configurations and check if it is LLM
default_model = ( default_model = (
db.session.query(TenantDefaultModel) db.session.query(TenantDefaultModel)
.filter( .where(
TenantDefaultModel.tenant_id == tenant_id, TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(), TenantDefaultModel.model_type == model_type.to_origin_model_type(),
) )
@ -541,7 +541,7 @@ class ProviderManager:
db.session.rollback() db.session.rollback()
existed_provider_record = ( existed_provider_record = (
db.session.query(Provider) db.session.query(Provider)
.filter( .where(
Provider.tenant_id == tenant_id, Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name, Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value, Provider.provider_type == ProviderType.SYSTEM.value,

@ -93,11 +93,11 @@ class Jieba(BaseKeyword):
documents = [] documents = []
for chunk_index in sorted_chunk_indices: for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).filter( segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
) )
if document_ids_filter: if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter)) segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first() segment = segment_query.first()
if segment: if segment:
@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = ( document_segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id) .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first() .first()
) )
if document_segment: if document_segment:

@ -127,7 +127,7 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None, external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None, metadata_filtering_conditions: Optional[dict] = None,
): ):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
return [] return []
metadata_condition = ( metadata_condition = (
@ -145,7 +145,7 @@ class RetrievalService:
@classmethod @classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
with Session(db.engine) as session: with Session(db.engine) as session:
return session.query(Dataset).filter(Dataset.id == dataset_id).first() return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod @classmethod
def keyword_search( def keyword_search(
@ -294,7 +294,7 @@ class RetrievalService:
dataset_documents = { dataset_documents = {
doc.id: doc doc.id: doc
for doc in db.session.query(DatasetDocument) for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids)) .where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id)) .options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all() .all()
} }
@ -318,7 +318,7 @@ class RetrievalService:
child_index_node_id = document.metadata.get("doc_id") child_index_node_id = document.metadata.get("doc_id")
child_chunk = ( child_chunk = (
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first() db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
) )
if not child_chunk: if not child_chunk:
@ -326,7 +326,7 @@ class RetrievalService:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
@ -381,7 +381,7 @@ class RetrievalService:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id == dataset_document.dataset_id, DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",

@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id) .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:

@ -296,12 +296,22 @@ class TableStoreVector(BaseVector):
documents = [] documents = []
for search_hit in search_response.search_hits: for search_hit in search_response.search_hits:
if search_hit.score > score_threshold: if search_hit.score > score_threshold:
metadata = json.loads(search_hit.row[1][0][1]) ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
metadata["score"] = search_hit.score metadata["score"] = search_hit.score
documents.append( documents.append(
Document( Document(
page_content=search_hit.row[1][2][1], page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=json.loads(search_hit.row[1][3][1]), vector=vector,
metadata=metadata, metadata=metadata,
) )
) )
@ -309,7 +319,7 @@ class TableStoreVector(BaseVector):
return documents return documents
def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]: def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
bool_query = tablestore.BoolQuery() bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value)) bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
if document_ids_filter: if document_ids_filter:
@ -329,11 +339,20 @@ class TableStoreVector(BaseVector):
documents = [] documents = []
for search_hit in search_response.search_hits: for search_hit in search_response.search_hits:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
vector_str = ots_column_map.get(Field.VECTOR.value)
metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
vector = json.loads(vector_str) if vector_str else None
metadata = json.loads(metadata_str) if metadata_str else {}
documents.append( documents.append(
Document( Document(
page_content=search_hit.row[1][2][1], page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
vector=json.loads(search_hit.row[1][3][1]), vector=vector,
metadata=json.loads(search_hit.row[1][0][1]), metadata=metadata,
) )
) )
return documents return documents

@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory): class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector: def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = ( tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none() db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
) )
if not tidb_auth_binding: if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900): with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = ( tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id) .where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none() .one_or_none()
) )
if tidb_auth_binding: if tidb_auth_binding:
@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else: else:
idle_tidb_auth_binding = ( idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding) db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE") .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1) .limit(1)
.one_or_none() .one_or_none()
) )

@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE: if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = ( whitelist = (
db.session.query(Whitelist) db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db") .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none() .one_or_none()
) )
if whitelist: if whitelist:

@ -42,7 +42,7 @@ class DatasetDocumentStore:
@property @property
def docs(self) -> dict[str, Document]: def docs(self) -> dict[str, Document]:
document_segments = ( document_segments = (
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all() db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
) )
output = {} output = {}
@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id) .where(DocumentSegment.document_id == self._document_id)
.scalar() .scalar()
) )
@ -147,7 +147,7 @@ class DatasetDocumentStore:
segment_document.tokens = tokens segment_document.tokens = tokens
if save_child and doc.children: if save_child and doc.children:
# delete the existing child chunks # delete the existing child chunks
db.session.query(ChildChunk).filter( db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id, ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id, ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id, ChildChunk.document_id == self._document_id,
@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]: def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = ( document_segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id) .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first() .first()
) )

@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = ( data_source_binding = (
db.session.query(DataSourceOauthBinding) db.session.query(DataSourceOauthBinding)
.filter( .where(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == tenant_id, DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",

@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = ( child_node_ids = (
db.session.query(ChildChunk.index_node_id) db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter( .where(
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids), DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id, ChildChunk.dataset_id == dataset.id,
@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids) vector.delete_by_ids(child_node_ids)
if delete_child_chunks: if delete_child_chunks:
db.session.query(ChildChunk).filter( db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete() ).delete()
db.session.commit() db.session.commit()
@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete() vector.delete()
if delete_child_chunks: if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit() db.session.commit()
def retrieve( def retrieve(

@ -135,7 +135,7 @@ class DatasetRetrieval:
available_datasets = [] available_datasets = []
for dataset_id in dataset_ids: for dataset_id in dataset_ids:
# get dataset from dataset id # get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available # pass if dataset is not available
if not dataset: if not dataset:
@ -242,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter( .where(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
@ -327,7 +327,7 @@ class DatasetRetrieval:
if dataset_id: if dataset_id:
# get retrieval model config # get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset: if dataset:
results = [] results = []
if dataset.provider == "external": if dataset.provider == "external":
@ -516,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None: if document.metadata is not None:
dataset_document = ( dataset_document = (
db.session.query(DatasetDocument) db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"]) .where(DatasetDocument.id == document.metadata["document_id"])
.first() .first()
) )
if dataset_document: if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ( child_chunk = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter( .where(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
@ -533,7 +533,7 @@ class DatasetRetrieval:
if child_chunk: if child_chunk:
segment = ( segment = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id) .where(DocumentSegment.id == child_chunk.segment_id)
.update( .update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False, synchronize_session=False,
@ -541,13 +541,13 @@ class DatasetRetrieval:
) )
db.session.commit() db.session.commit()
else: else:
query = db.session.query(DocumentSegment).filter( query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"] DocumentSegment.index_node_id == document.metadata["doc_id"]
) )
# if 'dataset_id' in document.metadata: # if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata: if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"]) query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment # add hit count to document segment
query.update( query.update(
@ -600,7 +600,7 @@ class DatasetRetrieval:
): ):
with flask_app.app_context(): with flask_app.app_context():
with Session(db.engine) as session: with Session(db.engine) as session:
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
return [] return []
@ -685,7 +685,7 @@ class DatasetRetrieval:
available_datasets = [] available_datasets = []
for dataset_id in dataset_ids: for dataset_id in dataset_ids:
# get dataset from dataset id # get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available # pass if dataset is not available
if not dataset: if not dataset:
@ -862,7 +862,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: Optional[MetadataFilteringCondition], metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict, inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter( document_query = db.session.query(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids), DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
@ -930,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode") raise ValueError("Invalid metadata filtering mode")
if filters: if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.filter(or_(*filters)) document_query = document_query.where(or_(*filters))
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@ -958,7 +958,7 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]: ) -> Optional[list[dict[str, Any]]]:
# get all metadata field # get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config # get metadata model config
if metadata_model_config is None: if metadata_model_config is None:

@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers # get tenant api providers
db_providers: list[ApiToolProvider] = ( db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all() .all()
) )

@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == id, ToolFile.id == id,
) )
.first() .first()
@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = ( message_file: MessageFile | None = (
session.query(MessageFile) session.query(MessageFile)
.filter( .where(
MessageFile.id == id, MessageFile.id == id,
) )
.first() .first()
@ -204,7 +204,7 @@ class ToolFileManager:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,
) )
.first() .first()
@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session: with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = ( tool_file: ToolFile | None = (
session.query(ToolFile) session.query(ToolFile)
.filter( .where(
ToolFile.id == tool_file_id, ToolFile.id == tool_file_id,
) )
.first() .first()

@ -29,7 +29,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type") raise ValueError("Unsupported tool type")
# delete old labels # delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels # insert new labels
for label in labels: for label in labels:
@ -57,7 +57,7 @@ class ToolLabelManager:
labels = ( labels = (
db.session.query(ToolLabelBinding.label_name) db.session.query(ToolLabelBinding.label_name)
.filter( .where(
ToolLabelBinding.tool_id == provider_id, ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value, ToolLabelBinding.tool_type == controller.provider_type.value,
) )
@ -90,7 +90,7 @@ class ToolLabelManager:
provider_ids.append(controller.provider_id) provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = ( labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
) )
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}

@ -198,7 +198,7 @@ class ToolManager:
try: try:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id, BuiltinToolProvider.id == credential_id,
) )
@ -216,7 +216,7 @@ class ToolManager:
# use the default provider # use the default provider
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter( .where(
BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == str(provider_id_entity)) (BuiltinToolProvider.provider == str(provider_id_entity))
| (BuiltinToolProvider.provider == provider_id_entity.provider_name), | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
@ -229,7 +229,7 @@ class ToolManager:
else: else:
builtin_provider = ( builtin_provider = (
db.session.query(BuiltinToolProvider) db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first() .first()
) )
@ -316,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW: elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = ( workflow_provider = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first() .first()
) )
@ -616,7 +616,7 @@ class ToolManager:
ORDER BY tenant_id, provider, is_default DESC, created_at DESC ORDER BY tenant_id, provider, is_default DESC, created_at DESC
""" """
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod @classmethod
def list_providers_from_api( def list_providers_from_api(
@ -664,7 +664,7 @@ class ToolManager:
# get db api providers # get db api providers
if "api" in filters: if "api" in filters:
db_api_providers: list[ApiToolProvider] = ( db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
) )
api_provider_controllers: list[dict[str, Any]] = [ api_provider_controllers: list[dict[str, Any]] = [
@ -687,7 +687,7 @@ class ToolManager:
if "workflow" in filters: if "workflow" in filters:
# get workflow providers # get workflow providers
workflow_providers: list[WorkflowToolProvider] = ( workflow_providers: list[WorkflowToolProvider] = (
db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
) )
workflow_provider_controllers: list[WorkflowToolProviderController] = [] workflow_provider_controllers: list[WorkflowToolProviderController] = []
@ -731,7 +731,7 @@ class ToolManager:
""" """
provider: ApiToolProvider | None = ( provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.id == provider_id, ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
) )
@ -768,7 +768,7 @@ class ToolManager:
""" """
provider: MCPToolProvider | None = ( provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter( .where(
MCPToolProvider.server_identifier == provider_id, MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.tenant_id == tenant_id,
) )
@ -793,7 +793,7 @@ class ToolManager:
provider_name = provider provider_name = provider
provider_obj: ApiToolProvider | None = ( provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider, ApiToolProvider.name == provider,
) )
@ -885,7 +885,7 @@ class ToolManager:
try: try:
workflow_provider: WorkflowToolProvider | None = ( workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first() .first()
) )
@ -902,7 +902,7 @@ class ToolManager:
try: try:
api_provider: ApiToolProvider | None = ( api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider) db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first() .first()
) )
@ -919,7 +919,7 @@ class ToolManager:
try: try:
mcp_provider: MCPToolProvider | None = ( mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider) db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id) .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first() .first()
) )
@ -1011,7 +1011,9 @@ class ToolManager:
if variable is None: if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist") raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}: elif tool_input.type == "constant":
parameter_value = tool_input.value
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value)) segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text parameter_value = segment_group.text
else: else:

@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = ( segments = (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.filter( .where(
DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
): ):
with flask_app.app_context(): with flask_app.app_context():
dataset = ( dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first() db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
) )
if not dataset: if not dataset:

@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str: def _run(self, query: str) -> str:
dataset = ( dataset = (
db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first() db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
) )
if not dataset: if not dataset:
@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = ( document = (
db.session.query(DatasetDocument) # type: ignore db.session.query(DatasetDocument) # type: ignore
.filter( .where(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,

@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
""" """
workflow: Workflow | None = ( workflow: Workflow | None = (
db.session.query(Workflow) db.session.query(Workflow)
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version) .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first() .first()
) )
@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
db_providers: WorkflowToolProvider | None = ( db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter( .where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id, WorkflowToolProvider.app_id == self.provider_id,
) )

@ -142,12 +142,12 @@ class WorkflowTool(Tool):
if not version: if not version:
workflow = ( workflow = (
db.session.query(Workflow) db.session.query(Workflow)
.filter(Workflow.app_id == app_id, Workflow.version != "draft") .where(Workflow.app_id == app_id, Workflow.version != "draft")
.order_by(Workflow.created_at.desc()) .order_by(Workflow.created_at.desc())
.first() .first()
) )
else: else:
workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first() workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
if not workflow: if not workflow:
raise ValueError("workflow not found or not published") raise ValueError("workflow not found or not published")
@ -158,7 +158,7 @@ class WorkflowTool(Tool):
""" """
get the app by app id get the app by app id
""" """
app = db.session.query(App).filter(App.id == app_id).first() app = db.session.query(App).where(App.id == app_id).first()
if not app: if not app:
raise ValueError("app not found") raise ValueError("app not found")

@ -309,7 +309,7 @@ class AgentNode(BaseNode):
} }
) )
value = tool_value value = tool_value
if parameter.type == "model-selector": if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value) value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value) model_instance, model_schema = self._fetch_model(value)
# memory config # memory config

@ -228,7 +228,7 @@ class KnowledgeRetrievalNode(BaseNode):
# Subquery: Count the number of available documents for each dataset # Subquery: Count the number of available documents for each dataset
subquery = ( subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.filter( .where(
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -242,8 +242,8 @@ class KnowledgeRetrievalNode(BaseNode):
results = ( results = (
db.session.query(Dataset) db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all() .all()
) )
@ -370,7 +370,7 @@ class KnowledgeRetrievalNode(BaseNode):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
@ -415,7 +415,7 @@ class KnowledgeRetrievalNode(BaseNode):
def _get_metadata_filter_condition( def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]: ) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(Document).filter( document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids), Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
@ -493,9 +493,9 @@ class KnowledgeRetrievalNode(BaseNode):
node_data.metadata_filtering_conditions node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and" and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore ): # type: ignore
document_query = document_query.filter(and_(*filters)) document_query = document_query.where(and_(*filters))
else: else:
document_query = document_query.filter(or_(*filters)) document_query = document_query.where(or_(*filters))
documents = document_query.all() documents = document_query.all()
# group by dataset_id # group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(BaseNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
# get all metadata field # get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None: if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required") raise ValueError("metadata_model_config is required")

@ -184,11 +184,10 @@ class ListOperatorNode(BaseNode):
value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1: if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}") raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
if value > len(variable.value):
raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
value -= 1 value -= 1
if len(variable.value) > int(value):
result = variable.value[value] result = variable.value[value]
else:
result = ""
return variable.model_copy(update={"value": [result]}) return variable.model_copy(update={"value": [result]})

@ -54,7 +54,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
for val in value: for val in value:
if not isinstance(val, str): if not isinstance(val, str):
raise ValueError("value must be a list of strings") raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool): elif typ == "constant" and not isinstance(value, str | int | float | bool | dict):
raise ValueError("value must be a string, int, float, or bool") raise ValueError("value must be a string, int, float, or bool")
return typ return typ

@ -5,6 +5,11 @@ set -e
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
echo "Running migrations" echo "Running migrations"
flask upgrade-db flask upgrade-db
# Pure migration mode
if [[ "${MODE}" == "migration" ]]; then
echo "Migration completed, exiting normally"
exit 0
fi
fi fi
if [[ "${MODE}" == "worker" ]]; then if [[ "${MODE}" == "worker" ]]; then

@ -22,7 +22,7 @@ def handle(sender, **kwargs):
document = ( document = (
db.session.query(Document) db.session.query(Document)
.filter( .where(
Document.id == document_id, Document.id == document_id,
Document.dataset_id == dataset_id, Document.dataset_id == dataset_id,
) )

@ -13,7 +13,7 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_model_config(app_model_config) dataset_ids = get_dataset_ids_from_model_config(app_model_config)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set() removed_dataset_ids: set[str] = set()
if not app_dataset_joins: if not app_dataset_joins:
@ -27,7 +27,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids: if removed_dataset_ids:
for dataset_id in removed_dataset_ids: for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter( db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete() ).delete()

@ -15,7 +15,7 @@ def handle(sender, **kwargs):
published_workflow = cast(Workflow, published_workflow) published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow) dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all() app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set() removed_dataset_ids: set[str] = set()
if not app_dataset_joins: if not app_dataset_joins:
@ -29,7 +29,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids: if removed_dataset_ids:
for dataset_id in removed_dataset_ids: for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter( db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete() ).delete()

@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login):
if workspace_id: if workspace_id:
tenant_account_join = ( tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == workspace_id) .where(Tenant.id == workspace_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.role == "owner") .where(TenantAccountJoin.role == "owner")
.one_or_none() .one_or_none()
) )
if tenant_account_join: if tenant_account_join:
@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
end_user_id = decoded.get("end_user_id") end_user_id = decoded.get("end_user_id")
if not end_user_id: if not end_user_id:
raise Unauthorized("Invalid Authorization token.") raise Unauthorized("Invalid Authorization token.")
end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
if not end_user: if not end_user:
raise NotFound("End user not found.") raise NotFound("End user not found.")
return end_user return end_user
@ -78,12 +78,12 @@ def load_user_from_request(request_from_flask_login):
server_code = request.view_args.get("server_code") if request.view_args else None server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code: if not server_code:
raise Unauthorized("Invalid Authorization token.") raise Unauthorized("Invalid Authorization token.")
app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not app_mcp_server: if not app_mcp_server:
raise NotFound("App MCP server not found.") raise NotFound("App MCP server not found.")
end_user = ( end_user = (
db.session.query(EndUser) db.session.query(EndUser)
.filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp") .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
.first() .first()
) )
if not end_user: if not end_user:

@ -261,13 +261,11 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
strict_type_validation: bool = False, strict_type_validation: bool = False,
) -> File: ) -> File:
tool_file = ( tool_file = db.session.scalar(
db.session.query(ToolFile) select(ToolFile).where(
.filter(
ToolFile.id == mapping.get("tool_file_id"), ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id, ToolFile.tenant_id == tenant_id,
) )
.first()
) )
if tool_file is None: if tool_file is None:
@ -275,7 +273,7 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype) detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type") specified_type = mapping.get("type")

@ -25,6 +25,7 @@ class EmailType(Enum):
EMAIL_CODE_LOGIN = "email_code_login" EMAIL_CODE_LOGIN = "email_code_login"
CHANGE_EMAIL_OLD = "change_email_old" CHANGE_EMAIL_OLD = "change_email_old"
CHANGE_EMAIL_NEW = "change_email_new" CHANGE_EMAIL_NEW = "change_email_new"
CHANGE_EMAIL_COMPLETED = "change_email_completed"
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm" OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify" OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify" OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
@ -344,6 +345,18 @@ def create_default_email_config() -> EmailI18nConfig:
branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html", branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html",
), ),
}, },
EmailType.CHANGE_EMAIL_COMPLETED: {
EmailLanguage.EN_US: EmailTemplate(
subject="Your login email has been changed",
template_path="change_mail_completed_template_en-US.html",
branded_template_path="without-brand/change_mail_completed_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="您的登录邮箱已更改",
template_path="change_mail_completed_template_zh-CN.html",
branded_template_path="without-brand/change_mail_completed_template_zh-CN.html",
),
},
EmailType.OWNER_TRANSFER_CONFIRM: { EmailType.OWNER_TRANSFER_CONFIRM: {
EmailLanguage.EN_US: EmailTemplate( EmailLanguage.EN_US: EmailTemplate(
subject="Verify Your Request to Transfer Workspace Ownership", subject="Verify Your Request to Transfer Workspace Ownership",

@ -3,6 +3,7 @@ from typing import Any
import requests import requests
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select
from extensions.ext_database import db from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now from libs.datetime_utils import naive_utc_now
@ -61,17 +62,13 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalar(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding).where(
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -101,17 +98,13 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalar(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding).where(
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
) )
.first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -129,18 +122,15 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str): def sync_data_source(self, binding_id: str):
# save data source binding # save data source binding
data_source_binding = ( data_source_binding = db.session.scalar(
db.session.query(DataSourceOauthBinding) select(DataSourceOauthBinding).where(
.filter(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
) )
.first()
)
if data_source_binding: if data_source_binding:
# get all authorized pages # get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token) pages = self.get_authorized_pages(data_source_binding.access_token)

@ -1,4 +1,5 @@
import hashlib import hashlib
import os
from typing import Union from typing import Union
from Crypto.Cipher import AES from Crypto.Cipher import AES
@ -17,7 +18,7 @@ def generate_key_pair(tenant_id: str) -> str:
pem_private = private_key.export_key() pem_private = private_key.export_key()
pem_public = public_key.export_key() pem_public = public_key.export_key()
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" filepath = os.path.join("privkeys", tenant_id, "private.pem")
storage.save(filepath, pem_private) storage.save(filepath, pem_private)
@ -47,7 +48,7 @@ def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" filepath = os.path.join("privkeys", tenant_id, "private.pem")
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
private_key = redis_client.get(cache_key) private_key = redis_client.get(cache_key)

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Optional, cast from typing import Optional, cast
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base from models.base import Base
@ -119,7 +119,7 @@ class Account(UserMixin, Base):
@current_tenant.setter @current_tenant.setter
def current_tenant(self, tenant: "Tenant"): def current_tenant(self, tenant: "Tenant"):
ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta: if ta:
self.role = TenantAccountRole(ta.role) self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant self._current_tenant = tenant
@ -135,9 +135,9 @@ class Account(UserMixin, Base):
tuple[Tenant, TenantAccountJoin], tuple[Tenant, TenantAccountJoin],
( (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == tenant_id) .where(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id) .where(TenantAccountJoin.account_id == self.id)
.one_or_none() .one_or_none()
), ),
) )
@ -161,11 +161,11 @@ class Account(UserMixin, Base):
def get_by_openid(cls, provider: str, open_id: str): def get_by_openid(cls, provider: str, open_id: str):
account_integrate = ( account_integrate = (
db.session.query(AccountIntegrate) db.session.query(AccountIntegrate)
.filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none() .one_or_none()
) )
if account_integrate: if account_integrate:
return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none() return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return None return None
# check current_user.current_tenant.current_role in ['admin', 'owner'] # check current_user.current_tenant.current_role in ['admin', 'owner']
@ -211,7 +211,7 @@ class Tenant(Base):
def get_accounts(self) -> list[Account]: def get_accounts(self) -> list[Account]:
return ( return (
db.session.query(Account) db.session.query(Account)
.filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id) .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all() .all()
) )

@ -12,7 +12,7 @@ from datetime import datetime
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@ -68,7 +68,7 @@ class Dataset(Base):
@property @property
def dataset_keyword_table(self): def dataset_keyword_table(self):
dataset_keyword_table = ( dataset_keyword_table = (
db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first() db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
) )
if dataset_keyword_table: if dataset_keyword_table:
return dataset_keyword_table return dataset_keyword_table
@ -95,7 +95,7 @@ class Dataset(Base):
def latest_process_rule(self): def latest_process_rule(self):
return ( return (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id) .where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.first() .first()
) )
@ -104,19 +104,19 @@ class Dataset(Base):
def app_count(self): def app_count(self):
return ( return (
db.session.query(func.count(AppDatasetJoin.id)) db.session.query(func.count(AppDatasetJoin.id))
.filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar() .scalar()
) )
@property @property
def document_count(self): def document_count(self):
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property @property
def available_document_count(self): def available_document_count(self):
return ( return (
db.session.query(func.count(Document.id)) db.session.query(func.count(Document.id))
.filter( .where(
Document.dataset_id == self.id, Document.dataset_id == self.id,
Document.indexing_status == "completed", Document.indexing_status == "completed",
Document.enabled == True, Document.enabled == True,
@ -129,7 +129,7 @@ class Dataset(Base):
def available_segment_count(self): def available_segment_count(self):
return ( return (
db.session.query(func.count(DocumentSegment.id)) db.session.query(func.count(DocumentSegment.id))
.filter( .where(
DocumentSegment.dataset_id == self.id, DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
@ -142,13 +142,13 @@ class Dataset(Base):
return ( return (
db.session.query(Document) db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0)) .with_entities(func.coalesce(func.sum(Document.word_count), 0))
.filter(Document.dataset_id == self.id) .where(Document.dataset_id == self.id)
.scalar() .scalar()
) )
@property @property
def doc_form(self): def doc_form(self):
document = db.session.query(Document).filter(Document.dataset_id == self.id).first() document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document: if document:
return document.doc_form return document.doc_form
return None return None
@ -169,7 +169,7 @@ class Dataset(Base):
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
.filter( .where(
TagBinding.target_id == self.id, TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id, TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id,
@ -185,14 +185,14 @@ class Dataset(Base):
if self.provider != "external": if self.provider != "external":
return None return None
external_knowledge_binding = ( external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first() db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
) )
if not external_knowledge_binding: if not external_knowledge_binding:
return None return None
external_knowledge_api = ( external_knowledge_api = db.session.scalar(
db.session.query(ExternalKnowledgeApis) select(ExternalKnowledgeApis).where(
.filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id) ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
.first() )
) )
if not external_knowledge_api: if not external_knowledge_api:
return None return None
@ -205,7 +205,7 @@ class Dataset(Base):
@property @property
def doc_metadata(self): def doc_metadata(self):
dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all() dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
doc_metadata = [ doc_metadata = [
{ {
@ -408,7 +408,7 @@ class Document(Base):
data_source_info_dict = json.loads(self.data_source_info) data_source_info_dict = json.loads(self.data_source_info)
file_detail = ( file_detail = (
db.session.query(UploadFile) db.session.query(UploadFile)
.filter(UploadFile.id == data_source_info_dict["upload_file_id"]) .where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none() .one_or_none()
) )
if file_detail: if file_detail:
@ -441,24 +441,24 @@ class Document(Base):
@property @property
def dataset(self): def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none() return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
@property @property
def segment_count(self): def segment_count(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
@property @property
def hit_count(self): def hit_count(self):
return ( return (
db.session.query(DocumentSegment) db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.filter(DocumentSegment.document_id == self.id) .where(DocumentSegment.document_id == self.id)
.scalar() .scalar()
) )
@property @property
def uploader(self): def uploader(self):
user = db.session.query(Account).filter(Account.id == self.created_by).first() user = db.session.query(Account).where(Account.id == self.created_by).first()
return user.name if user else None return user.name if user else None
@property @property
@ -475,7 +475,7 @@ class Document(Base):
document_metadatas = ( document_metadatas = (
db.session.query(DatasetMetadata) db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.filter( .where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
) )
.all() .all()
@ -687,26 +687,26 @@ class DocumentSegment(Base):
@property @property
def dataset(self): def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property @property
def document(self): def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first() return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property @property
def previous_segment(self): def previous_segment(self):
return ( return db.session.scalar(
db.session.query(DocumentSegment) select(DocumentSegment).where(
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1) DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1
.first() )
) )
@property @property
def next_segment(self): def next_segment(self):
return ( return db.session.scalar(
db.session.query(DocumentSegment) select(DocumentSegment).where(
.filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1) DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1
.first() )
) )
@property @property
@ -717,7 +717,7 @@ class DocumentSegment(Base):
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id) .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc()) .order_by(ChildChunk.position.asc())
.all() .all()
) )
@ -734,7 +734,7 @@ class DocumentSegment(Base):
if rules.parent_mode: if rules.parent_mode:
child_chunks = ( child_chunks = (
db.session.query(ChildChunk) db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id) .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc()) .order_by(ChildChunk.position.asc())
.all() .all()
) )
@ -825,15 +825,15 @@ class ChildChunk(Base):
@property @property
def dataset(self): def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
@property @property
def document(self): def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first() return db.session.query(Document).where(Document.id == self.document_id).first()
@property @property
def segment(self): def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(Base): class AppDatasetJoin(Base):
@ -1044,11 +1044,11 @@ class ExternalKnowledgeApis(Base):
def dataset_bindings(self): def dataset_bindings(self):
external_knowledge_bindings = ( external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings) db.session.query(ExternalKnowledgeBindings)
.filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all() .all()
) )
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all() datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = [] dataset_bindings = []
for dataset in datasets: for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name}) dataset_bindings.append({"id": dataset.id, "name": dataset.name})

@ -113,13 +113,13 @@ class App(Base):
@property @property
def site(self): def site(self):
site = db.session.query(Site).filter(Site.app_id == self.id).first() site = db.session.query(Site).where(Site.app_id == self.id).first()
return site return site
@property @property
def app_model_config(self): def app_model_config(self):
if self.app_model_config_id: if self.app_model_config_id:
return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None return None
@ -128,7 +128,7 @@ class App(Base):
if self.workflow_id: if self.workflow_id:
from .workflow import Workflow from .workflow import Workflow
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return None return None
@ -138,7 +138,7 @@ class App(Base):
@property @property
def tenant(self): def tenant(self):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant
@property @property
@ -282,7 +282,7 @@ class App(Base):
tags = ( tags = (
db.session.query(Tag) db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id) .join(TagBinding, Tag.id == TagBinding.tag_id)
.filter( .where(
TagBinding.target_id == self.id, TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id, TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id, Tag.tenant_id == self.tenant_id,
@ -296,7 +296,7 @@ class App(Base):
@property @property
def author_name(self): def author_name(self):
if self.created_by: if self.created_by:
account = db.session.query(Account).filter(Account.id == self.created_by).first() account = db.session.query(Account).where(Account.id == self.created_by).first()
if account: if account:
return account.name return account.name
@ -338,7 +338,7 @@ class AppModelConfig(Base):
@property @property
def app(self): def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@property @property
@ -372,7 +372,7 @@ class AppModelConfig(Base):
@property @property
def annotation_reply_dict(self) -> dict: def annotation_reply_dict(self) -> dict:
annotation_setting = ( annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first() db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
) )
if annotation_setting: if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail collection_binding_detail = annotation_setting.collection_binding_detail
@ -577,7 +577,7 @@ class RecommendedApp(Base):
@property @property
def app(self): def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@ -601,12 +601,12 @@ class InstalledApp(Base):
@property @property
def app(self): def app(self):
app = db.session.query(App).filter(App.id == self.app_id).first() app = db.session.query(App).where(App.id == self.app_id).first()
return app return app
@property @property
def tenant(self): def tenant(self):
tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant return tenant
@ -714,7 +714,7 @@ class Conversation(Base):
model_config["configs"] = override_model_configs model_config["configs"] = override_model_configs
else: else:
app_model_config = ( app_model_config = (
db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
) )
if app_model_config: if app_model_config:
model_config = app_model_config.to_dict() model_config = app_model_config.to_dict()
@ -737,21 +737,21 @@ class Conversation(Base):
@property @property
def annotated(self): def annotated(self):
return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0 return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
@property @property
def annotation(self): def annotation(self):
return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first() return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
@property @property
def message_count(self): def message_count(self):
return db.session.query(Message).filter(Message.conversation_id == self.id).count() return db.session.query(Message).where(Message.conversation_id == self.id).count()
@property @property
def user_feedback_stats(self): def user_feedback_stats(self):
like = ( like = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user", MessageFeedback.from_source == "user",
MessageFeedback.rating == "like", MessageFeedback.rating == "like",
@ -761,7 +761,7 @@ class Conversation(Base):
dislike = ( dislike = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user", MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike", MessageFeedback.rating == "dislike",
@ -775,7 +775,7 @@ class Conversation(Base):
def admin_feedback_stats(self): def admin_feedback_stats(self):
like = ( like = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin", MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like", MessageFeedback.rating == "like",
@ -785,7 +785,7 @@ class Conversation(Base):
dislike = ( dislike = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter( .where(
MessageFeedback.conversation_id == self.id, MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin", MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike", MessageFeedback.rating == "dislike",
@ -797,7 +797,7 @@ class Conversation(Base):
@property @property
def status_count(self): def status_count(self):
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
status_counts = { status_counts = {
WorkflowExecutionStatus.RUNNING: 0, WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0, WorkflowExecutionStatus.SUCCEEDED: 0,
@ -824,19 +824,19 @@ class Conversation(Base):
def first_message(self): def first_message(self):
return ( return (
db.session.query(Message) db.session.query(Message)
.filter(Message.conversation_id == self.id) .where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc()) .order_by(Message.created_at.asc())
.first() .first()
) )
@property @property
def app(self): def app(self):
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).where(App.id == self.app_id).first()
@property @property
def from_end_user_session_id(self): def from_end_user_session_id(self):
if self.from_end_user_id: if self.from_end_user_id:
end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first() end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
if end_user: if end_user:
return end_user.session_id return end_user.session_id
@ -845,7 +845,7 @@ class Conversation(Base):
@property @property
def from_account_name(self): def from_account_name(self):
if self.from_account_id: if self.from_account_id:
account = db.session.query(Account).filter(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account: if account:
return account.name return account.name
@ -1040,7 +1040,7 @@ class Message(Base):
def user_feedback(self): def user_feedback(self):
feedback = ( feedback = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first() .first()
) )
return feedback return feedback
@ -1049,30 +1049,30 @@ class Message(Base):
def admin_feedback(self): def admin_feedback(self):
feedback = ( feedback = (
db.session.query(MessageFeedback) db.session.query(MessageFeedback)
.filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first() .first()
) )
return feedback return feedback
@property @property
def feedbacks(self): def feedbacks(self):
feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all() feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
return feedbacks return feedbacks
@property @property
def annotation(self): def annotation(self):
annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first() annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
return annotation return annotation
@property @property
def annotation_hit_history(self): def annotation_hit_history(self):
annotation_history = ( annotation_history = (
db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first() db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
) )
if annotation_history: if annotation_history:
annotation = ( annotation = (
db.session.query(MessageAnnotation) db.session.query(MessageAnnotation)
.filter(MessageAnnotation.id == annotation_history.annotation_id) .where(MessageAnnotation.id == annotation_history.annotation_id)
.first() .first()
) )
return annotation return annotation
@ -1080,11 +1080,9 @@ class Message(Base):
@property @property
def app_model_config(self): def app_model_config(self):
conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first() conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
if conversation: if conversation:
return ( return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first()
)
return None return None
@ -1100,7 +1098,7 @@ class Message(Base):
def agent_thoughts(self): def agent_thoughts(self):
return ( return (
db.session.query(MessageAgentThought) db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == self.id) .where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc()) .order_by(MessageAgentThought.position.asc())
.all() .all()
) )
@ -1113,8 +1111,8 @@ class Message(Base):
def message_files(self): def message_files(self):
from factories import file_factory from factories import file_factory
message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all() message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
current_app = db.session.query(App).filter(App.id == self.app_id).first() current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app: if not current_app:
raise ValueError(f"App {self.app_id} not found") raise ValueError(f"App {self.app_id} not found")
@ -1178,7 +1176,7 @@ class Message(Base):
if self.workflow_run_id: if self.workflow_run_id:
from .workflow import WorkflowRun from .workflow import WorkflowRun
return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first() return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first()
return None return None
@ -1253,7 +1251,7 @@ class MessageFeedback(Base):
@property @property
def from_account(self): def from_account(self):
account = db.session.query(Account).filter(Account.id == self.from_account_id).first() account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account return account
def to_dict(self): def to_dict(self):
@ -1335,12 +1333,12 @@ class MessageAnnotation(Base):
@property @property
def account(self): def account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first() account = db.session.query(Account).where(Account.id == self.account_id).first()
return account return account
@property @property
def annotation_create_account(self): def annotation_create_account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first() account = db.session.query(Account).where(Account.id == self.account_id).first()
return account return account
@ -1371,14 +1369,14 @@ class AppAnnotationHitHistory(Base):
account = ( account = (
db.session.query(Account) db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.filter(MessageAnnotation.id == self.annotation_id) .where(MessageAnnotation.id == self.annotation_id)
.first() .first()
) )
return account return account
@property @property
def annotation_create_account(self): def annotation_create_account(self):
account = db.session.query(Account).filter(Account.id == self.account_id).first() account = db.session.query(Account).where(Account.id == self.account_id).first()
return account return account
@ -1404,7 +1402,7 @@ class AppAnnotationSetting(Base):
collection_binding_detail = ( collection_binding_detail = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == self.collection_binding_id) .where(DatasetCollectionBinding.id == self.collection_binding_id)
.first() .first()
) )
return collection_binding_detail return collection_binding_detail
@ -1470,7 +1468,7 @@ class AppMCPServer(Base):
def generate_server_code(n): def generate_server_code(n):
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0: while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
result = generate_string(n) result = generate_string(n)
return result return result
@ -1527,7 +1525,7 @@ class Site(Base):
def generate_code(n): def generate_code(n):
while True: while True:
result = generate_string(n) result = generate_string(n)
while db.session.query(Site).filter(Site.code == result).count() > 0: while db.session.query(Site).where(Site.code == result).count() > 0:
result = generate_string(n) result = generate_string(n)
return result return result
@ -1558,7 +1556,7 @@ class ApiToken(Base):
def generate_api_key(prefix, n): def generate_api_key(prefix, n):
while True: while True:
result = prefix + generate_string(n) result = prefix + generate_string(n)
if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0: if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
continue continue
return result return result

@ -153,11 +153,11 @@ class ApiToolProvider(Base):
def user(self) -> Account | None: def user(self) -> Account | None:
if not self.user_id: if not self.user_id:
return None return None
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).where(Account.id == self.user_id).first()
@property @property
def tenant(self) -> Tenant | None: def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(Base): class ToolLabelBinding(Base):
@ -223,11 +223,11 @@ class WorkflowToolProvider(Base):
@property @property
def user(self) -> Account | None: def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).where(Account.id == self.user_id).first()
@property @property
def tenant(self) -> Tenant | None: def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property @property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@ -235,7 +235,7 @@ class WorkflowToolProvider(Base):
@property @property
def app(self) -> App | None: def app(self) -> App | None:
return db.session.query(App).filter(App.id == self.app_id).first() return db.session.query(App).where(App.id == self.app_id).first()
class MCPToolProvider(Base): class MCPToolProvider(Base):
@ -280,11 +280,11 @@ class MCPToolProvider(Base):
) )
def load_user(self) -> Account | None: def load_user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first() return db.session.query(Account).where(Account.id == self.user_id).first()
@property @property
def tenant(self) -> Tenant | None: def tenant(self) -> Tenant | None:
return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property @property
def credentials(self) -> dict: def credentials(self) -> dict:

@ -26,7 +26,7 @@ class SavedMessage(Base):
@property @property
def message(self): def message(self):
return db.session.query(Message).filter(Message.id == self.message_id).first() return db.session.query(Message).where(Message.id == self.message_id).first()
class PinnedConversation(Base): class PinnedConversation(Base):

@ -343,7 +343,7 @@ class Workflow(Base):
return ( return (
db.session.query(WorkflowToolProvider) db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id) .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count() .count()
> 0 > 0
) )
@ -549,12 +549,12 @@ class WorkflowRun(Base):
from models.model import Message from models.model import Message
return ( return (
db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
) )
@property @property
def workflow(self): def workflow(self):
return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first() return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
def to_dict(self): def to_dict(self):
return { return {

@ -1,6 +1,6 @@
[project] [project]
name = "dify-api" name = "dify-api"
version = "1.6.0" version = "1.7.0"
requires-python = ">=3.11,<3.13" requires-python = ">=3.11,<3.13"
dependencies = [ dependencies = [

@ -21,7 +21,7 @@ def clean_embedding_cache_task():
try: try:
embedding_ids = ( embedding_ids = (
db.session.query(Embedding.id) db.session.query(Embedding.id)
.filter(Embedding.created_at < thirty_days_ago) .where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc()) .order_by(Embedding.created_at.desc())
.limit(100) .limit(100)
.all() .all()

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save