|
|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
from collections.abc import Callable
|
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import wraps
|
|
|
|
|
from typing import Optional
|
|
|
|
|
@ -8,6 +8,8 @@ from flask import current_app, request
|
|
|
|
|
from flask_login import user_logged_in # type: ignore
|
|
|
|
|
from flask_restful import Resource # type: ignore
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import select, update
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
from werkzeug.exceptions import Forbidden, Unauthorized
|
|
|
|
|
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
@ -174,7 +176,7 @@ def validate_dataset_token(view=None):
|
|
|
|
|
return decorator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_and_get_api_token(scope=None):
|
|
|
|
|
def validate_and_get_api_token(scope: str | None = None):
|
|
|
|
|
"""
|
|
|
|
|
Validate and get API token.
|
|
|
|
|
"""
|
|
|
|
|
@ -188,20 +190,25 @@ def validate_and_get_api_token(scope=None):
|
|
|
|
|
if auth_scheme != "bearer":
|
|
|
|
|
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
|
|
|
|
|
|
|
|
|
api_token = (
|
|
|
|
|
db.session.query(ApiToken)
|
|
|
|
|
.filter(
|
|
|
|
|
ApiToken.token == auth_token,
|
|
|
|
|
ApiToken.type == scope,
|
|
|
|
|
current_time = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
|
cutoff_time = current_time - timedelta(minutes=1)
|
|
|
|
|
with Session(db.engine, expire_on_commit=False) as session:
|
|
|
|
|
update_stmt = (
|
|
|
|
|
update(ApiToken)
|
|
|
|
|
.where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
|
|
|
|
|
.values(last_used_at=current_time)
|
|
|
|
|
.returning(ApiToken)
|
|
|
|
|
)
|
|
|
|
|
.first()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not api_token:
|
|
|
|
|
raise Unauthorized("Access token is invalid")
|
|
|
|
|
|
|
|
|
|
api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
|
db.session.commit()
|
|
|
|
|
result = session.execute(update_stmt)
|
|
|
|
|
api_token = result.scalar_one_or_none()
|
|
|
|
|
|
|
|
|
|
if not api_token:
|
|
|
|
|
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
|
|
|
|
api_token = session.scalar(stmt)
|
|
|
|
|
if not api_token:
|
|
|
|
|
raise Unauthorized("Access token is invalid")
|
|
|
|
|
else:
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
return api_token
|
|
|
|
|
|
|
|
|
|
|