merge main
commit
009cb2a650
@ -0,0 +1,7 @@
|
|||||||
|
# Ensure that .sh scripts use LF as line separator, even if they are checked out
|
||||||
|
# to Windows(NTFS) file-system, by a user of Docker for Window.
|
||||||
|
# These .sh scripts will be run from the Container after `docker compose up -d`.
|
||||||
|
# If they appear to be CRLF style, Dash from the Container will fail to execute
|
||||||
|
# them.
|
||||||
|
|
||||||
|
*.sh text eol=lf
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
from .app_config import DifyConfig
|
||||||
|
|
||||||
|
dify_config = DifyConfig()
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class OCIStorageConfig(BaseSettings):
|
||||||
|
"""
|
||||||
|
OCI storage configs
|
||||||
|
"""
|
||||||
|
|
||||||
|
OCI_ENDPOINT: Optional[str] = Field(
|
||||||
|
description='OCI storage endpoint',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
OCI_REGION: Optional[str] = Field(
|
||||||
|
description='OCI storage region',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
OCI_BUCKET_NAME: Optional[str] = Field(
|
||||||
|
description='OCI storage bucket name',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
OCI_ACCESS_KEY: Optional[str] = Field(
|
||||||
|
description='OCI storage access key',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
OCI_SECRET_KEY: Optional[str] = Field(
|
||||||
|
description='OCI storage secret key',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class AnalyticdbConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for connecting to AnalyticDB.
|
||||||
|
Refer to the following documentation for details on obtaining credentials:
|
||||||
|
https://www.alibabacloud.com/help/en/analyticdb-for-postgresql/getting-started/create-an-instance-instances-with-vector-engine-optimization-enabled
|
||||||
|
"""
|
||||||
|
|
||||||
|
ANALYTICDB_KEY_ID : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The Access Key ID provided by Alibaba Cloud for authentication."
|
||||||
|
)
|
||||||
|
ANALYTICDB_KEY_SECRET : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The Secret Access Key corresponding to the Access Key ID for secure access."
|
||||||
|
)
|
||||||
|
ANALYTICDB_REGION_ID : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The region where the AnalyticDB instance is deployed (e.g., 'cn-hangzhou')."
|
||||||
|
)
|
||||||
|
ANALYTICDB_INSTANCE_ID : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The unique identifier of the AnalyticDB instance you want to connect to (e.g., 'gp-ab123456').."
|
||||||
|
)
|
||||||
|
ANALYTICDB_ACCOUNT : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The account name used to log in to the AnalyticDB instance."
|
||||||
|
)
|
||||||
|
ANALYTICDB_PASSWORD : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The password associated with the AnalyticDB account for authentication."
|
||||||
|
)
|
||||||
|
ANALYTICDB_NAMESPACE : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The namespace within AnalyticDB for schema isolation."
|
||||||
|
)
|
||||||
|
ANALYTICDB_NAMESPACE_PASSWORD : Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The password for accessing the specified namespace within the AnalyticDB instance."
|
||||||
|
)
|
||||||
File diff suppressed because one or more lines are too long
@ -0,0 +1,4 @@
|
|||||||
|
TTS_AUTO_PLAY_TIMEOUT = 5
|
||||||
|
|
||||||
|
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||||
|
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02
|
||||||
@ -0,0 +1,107 @@
|
|||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.auth.error import (
|
||||||
|
InvalidEmailError,
|
||||||
|
InvalidTokenError,
|
||||||
|
PasswordMismatchError,
|
||||||
|
PasswordResetRateLimitExceededError,
|
||||||
|
)
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.helper import email as email_validate
|
||||||
|
from libs.password import hash_password, valid_password
|
||||||
|
from models.account import Account
|
||||||
|
from services.account_service import AccountService
|
||||||
|
from services.errors.account import RateLimitExceededError
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordSendEmailApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('email', type=str, required=True, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
email = args['email']
|
||||||
|
|
||||||
|
if not email_validate(email):
|
||||||
|
raise InvalidEmailError()
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=email).first()
|
||||||
|
|
||||||
|
if account:
|
||||||
|
try:
|
||||||
|
AccountService.send_reset_password_email(account=account)
|
||||||
|
except RateLimitExceededError:
|
||||||
|
logging.warning(f"Rate limit exceeded for email: {account.email}")
|
||||||
|
raise PasswordResetRateLimitExceededError()
|
||||||
|
else:
|
||||||
|
# Return success to avoid revealing email registration status
|
||||||
|
logging.warning(f"Attempt to reset password for unregistered email: {email}")
|
||||||
|
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordCheckApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
token = args['token']
|
||||||
|
|
||||||
|
reset_data = AccountService.get_reset_password_data(token)
|
||||||
|
|
||||||
|
if reset_data is None:
|
||||||
|
return {'is_valid': False, 'email': None}
|
||||||
|
return {'is_valid': True, 'email': reset_data.get('email')}
|
||||||
|
|
||||||
|
|
||||||
|
class ForgotPasswordResetApi(Resource):
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument('token', type=str, required=True, nullable=False, location='json')
|
||||||
|
parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json')
|
||||||
|
parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
new_password = args['new_password']
|
||||||
|
password_confirm = args['password_confirm']
|
||||||
|
|
||||||
|
if str(new_password).strip() != str(password_confirm).strip():
|
||||||
|
raise PasswordMismatchError()
|
||||||
|
|
||||||
|
token = args['token']
|
||||||
|
reset_data = AccountService.get_reset_password_data(token)
|
||||||
|
|
||||||
|
if reset_data is None:
|
||||||
|
raise InvalidTokenError()
|
||||||
|
|
||||||
|
AccountService.revoke_reset_password_token(token)
|
||||||
|
|
||||||
|
salt = secrets.token_bytes(16)
|
||||||
|
base64_salt = base64.b64encode(salt).decode()
|
||||||
|
|
||||||
|
password_hashed = hash_password(new_password, salt)
|
||||||
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
|
|
||||||
|
account = Account.query.filter_by(email=reset_data.get('email')).first()
|
||||||
|
account.password = base64_password_hashed
|
||||||
|
account.password_salt = base64_salt
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {'result': 'success'}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password')
|
||||||
|
api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity')
|
||||||
|
api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets')
|
||||||
@ -0,0 +1,135 @@
|
|||||||
|
import base64
|
||||||
|
import concurrent.futures
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import re
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueTextChunkEvent
|
||||||
|
from core.model_manager import ModelManager
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class AudioTrunk:
|
||||||
|
def __init__(self, status: str, audio):
|
||||||
|
self.audio = audio
|
||||||
|
self.status = status
|
||||||
|
|
||||||
|
|
||||||
|
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||||
|
if not text_content or text_content.isspace():
|
||||||
|
return
|
||||||
|
return model_instance.invoke_tts(
|
||||||
|
content_text=text_content.strip(),
|
||||||
|
user="responding_tts",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
voice=voice
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_future(future_queue, audio_queue):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
future = future_queue.get()
|
||||||
|
if future is None:
|
||||||
|
break
|
||||||
|
for audio in future.result():
|
||||||
|
audio_base64 = base64.b64encode(bytes(audio))
|
||||||
|
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||||
|
except Exception as e:
|
||||||
|
logging.getLogger(__name__).warning(e)
|
||||||
|
break
|
||||||
|
audio_queue.put(AudioTrunk("finish", b''))
|
||||||
|
|
||||||
|
|
||||||
|
class AppGeneratorTTSPublisher:
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, voice: str):
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.msg_text = ''
|
||||||
|
self._audio_queue = queue.Queue()
|
||||||
|
self._msg_queue = queue.Queue()
|
||||||
|
self.match = re.compile(r'[。.!?]')
|
||||||
|
self.model_manager = ModelManager()
|
||||||
|
self.model_instance = self.model_manager.get_default_model_instance(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
model_type=ModelType.TTS
|
||||||
|
)
|
||||||
|
self.voices = self.model_instance.get_tts_voices()
|
||||||
|
values = [voice.get('value') for voice in self.voices]
|
||||||
|
self.voice = voice
|
||||||
|
if not voice or voice not in values:
|
||||||
|
self.voice = self.voices[0].get('value')
|
||||||
|
self.MAX_SENTENCE = 2
|
||||||
|
self._last_audio_event = None
|
||||||
|
self._runtime_thread = threading.Thread(target=self._runtime).start()
|
||||||
|
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||||
|
|
||||||
|
def publish(self, message):
|
||||||
|
try:
|
||||||
|
self._msg_queue.put(message)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(e)
|
||||||
|
|
||||||
|
def _runtime(self):
|
||||||
|
future_queue = queue.Queue()
|
||||||
|
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
message = self._msg_queue.get()
|
||||||
|
if message is None:
|
||||||
|
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||||
|
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
|
||||||
|
self.model_instance, self.tenant_id, self.voice)
|
||||||
|
future_queue.put(futures_result)
|
||||||
|
break
|
||||||
|
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||||
|
self.msg_text += message.event.chunk.delta.message.content
|
||||||
|
elif isinstance(message.event, QueueTextChunkEvent):
|
||||||
|
self.msg_text += message.event.text
|
||||||
|
self.last_message = message
|
||||||
|
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||||
|
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||||
|
self.MAX_SENTENCE += 1
|
||||||
|
text_content = ''.join(sentence_arr)
|
||||||
|
futures_result = self.executor.submit(_invoiceTTS, text_content,
|
||||||
|
self.model_instance,
|
||||||
|
self.tenant_id,
|
||||||
|
self.voice)
|
||||||
|
future_queue.put(futures_result)
|
||||||
|
if text_tmp:
|
||||||
|
self.msg_text = text_tmp
|
||||||
|
else:
|
||||||
|
self.msg_text = ''
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(e)
|
||||||
|
break
|
||||||
|
future_queue.put(None)
|
||||||
|
|
||||||
|
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||||
|
try:
|
||||||
|
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||||
|
if self.executor:
|
||||||
|
self.executor.shutdown(wait=False)
|
||||||
|
return self.last_message
|
||||||
|
audio = self._audio_queue.get_nowait()
|
||||||
|
if audio and audio.status == "finish":
|
||||||
|
self.executor.shutdown(wait=False)
|
||||||
|
self._runtime_thread = None
|
||||||
|
if audio:
|
||||||
|
self._last_audio_event = audio
|
||||||
|
return audio
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_sentence(self, org_text):
|
||||||
|
tx = self.match.finditer(org_text)
|
||||||
|
start = 0
|
||||||
|
result = []
|
||||||
|
for i in tx:
|
||||||
|
end = i.regs[0][1]
|
||||||
|
result.append(org_text[start:end])
|
||||||
|
start = end
|
||||||
|
return result, org_text[start:]
|
||||||
@ -1,52 +1,56 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
def _get_cleaned_inputs(self, user_inputs: dict, app_config: AppConfig):
|
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
||||||
if user_inputs is None:
|
user_inputs = user_inputs or {}
|
||||||
user_inputs = {}
|
|
||||||
|
|
||||||
filtered_inputs = {}
|
|
||||||
|
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
variables = app_config.variables
|
variables = app_config.variables
|
||||||
for variable_config in variables:
|
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||||
variable = variable_config.variable
|
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||||
|
|
||||||
if (variable not in user_inputs
|
|
||||||
or user_inputs[variable] is None
|
|
||||||
or (isinstance(user_inputs[variable], str) and user_inputs[variable] == '')):
|
|
||||||
if variable_config.required:
|
|
||||||
raise ValueError(f"{variable} is required in input form")
|
|
||||||
else:
|
|
||||||
filtered_inputs[variable] = variable_config.default if variable_config.default is not None else ""
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = user_inputs[variable]
|
|
||||||
|
|
||||||
if value is not None:
|
|
||||||
if variable_config.type != VariableEntity.Type.NUMBER and not isinstance(value, str):
|
|
||||||
raise ValueError(f"{variable} in input form must be a string")
|
|
||||||
elif variable_config.type == VariableEntity.Type.NUMBER and isinstance(value, str):
|
|
||||||
if '.' in value:
|
|
||||||
value = float(value)
|
|
||||||
else:
|
|
||||||
value = int(value)
|
|
||||||
|
|
||||||
if variable_config.type == VariableEntity.Type.SELECT:
|
|
||||||
options = variable_config.options if variable_config.options is not None else []
|
|
||||||
if value not in options:
|
|
||||||
raise ValueError(f"{variable} in input form must be one of the following: {options}")
|
|
||||||
elif variable_config.type in [VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH]:
|
|
||||||
if variable_config.max_length is not None:
|
|
||||||
max_length = variable_config.max_length
|
|
||||||
if len(value) > max_length:
|
|
||||||
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
|
|
||||||
|
|
||||||
if value and isinstance(value, str):
|
|
||||||
filtered_inputs[variable] = value.replace('\x00', '')
|
|
||||||
else:
|
|
||||||
filtered_inputs[variable] = value if value is not None else None
|
|
||||||
|
|
||||||
return filtered_inputs
|
return filtered_inputs
|
||||||
|
|
||||||
|
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||||
|
user_input_value = inputs.get(var.name)
|
||||||
|
if var.required and not user_input_value:
|
||||||
|
raise ValueError(f'{var.name} is required in input form')
|
||||||
|
if not var.required and not user_input_value:
|
||||||
|
# TODO: should we return None here if the default value is None?
|
||||||
|
return var.default or ''
|
||||||
|
if (
|
||||||
|
var.type
|
||||||
|
in (
|
||||||
|
VariableEntity.Type.TEXT_INPUT,
|
||||||
|
VariableEntity.Type.SELECT,
|
||||||
|
VariableEntity.Type.PARAGRAPH,
|
||||||
|
)
|
||||||
|
and user_input_value
|
||||||
|
and not isinstance(user_input_value, str)
|
||||||
|
):
|
||||||
|
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
||||||
|
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
||||||
|
# may raise ValueError if user_input_value is not a valid number
|
||||||
|
try:
|
||||||
|
if '.' in user_input_value:
|
||||||
|
return float(user_input_value)
|
||||||
|
else:
|
||||||
|
return int(user_input_value)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"{var.name} in input form must be a valid number")
|
||||||
|
if var.type == VariableEntity.Type.SELECT:
|
||||||
|
options = var.options or []
|
||||||
|
if user_input_value not in options:
|
||||||
|
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
||||||
|
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
||||||
|
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||||
|
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
||||||
|
|
||||||
|
return user_input_value
|
||||||
|
|
||||||
|
def _sanitize_value(self, value: Any) -> Any:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.replace('\x00', '')
|
||||||
|
return value
|
||||||
|
|||||||
@ -0,0 +1 @@
|
|||||||
|
from .rate_limit import RateLimit
|
||||||
@ -0,0 +1,120 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from core.errors.error import AppInvokeQuotaExceededError
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimit:
|
||||||
|
_MAX_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:max_active_requests"
|
||||||
|
_ACTIVE_REQUESTS_KEY = "dify:rate_limit:{}:active_requests"
|
||||||
|
_UNLIMITED_REQUEST_ID = "unlimited_request_id"
|
||||||
|
_REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes
|
||||||
|
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
|
||||||
|
_instance_dict = {}
|
||||||
|
|
||||||
|
def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
|
||||||
|
if client_id not in cls._instance_dict:
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
cls._instance_dict[client_id] = instance
|
||||||
|
return cls._instance_dict[client_id]
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, max_active_requests: int):
|
||||||
|
self.max_active_requests = max_active_requests
|
||||||
|
if hasattr(self, 'initialized'):
|
||||||
|
return
|
||||||
|
self.initialized = True
|
||||||
|
self.client_id = client_id
|
||||||
|
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
|
||||||
|
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
|
||||||
|
self.last_recalculate_time = float('-inf')
|
||||||
|
self.flush_cache(use_local_value=True)
|
||||||
|
|
||||||
|
def flush_cache(self, use_local_value=False):
|
||||||
|
self.last_recalculate_time = time.time()
|
||||||
|
# flush max active requests
|
||||||
|
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
||||||
|
with redis_client.pipeline() as pipe:
|
||||||
|
pipe.set(self.max_active_requests_key, self.max_active_requests)
|
||||||
|
pipe.expire(self.max_active_requests_key, timedelta(days=1))
|
||||||
|
pipe.execute()
|
||||||
|
else:
|
||||||
|
with redis_client.pipeline() as pipe:
|
||||||
|
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
|
||||||
|
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
|
||||||
|
|
||||||
|
# flush max active requests (in-transit request list)
|
||||||
|
if not redis_client.exists(self.active_requests_key):
|
||||||
|
return
|
||||||
|
request_details = redis_client.hgetall(self.active_requests_key)
|
||||||
|
redis_client.expire(self.active_requests_key, timedelta(days=1))
|
||||||
|
timeout_requests = [k for k, v in request_details.items() if
|
||||||
|
time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
|
||||||
|
if timeout_requests:
|
||||||
|
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
||||||
|
|
||||||
|
def enter(self, request_id: Optional[str] = None) -> str:
|
||||||
|
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
||||||
|
self.flush_cache()
|
||||||
|
if self.max_active_requests <= 0:
|
||||||
|
return RateLimit._UNLIMITED_REQUEST_ID
|
||||||
|
if not request_id:
|
||||||
|
request_id = RateLimit.gen_request_key()
|
||||||
|
|
||||||
|
active_requests_count = redis_client.hlen(self.active_requests_key)
|
||||||
|
if active_requests_count >= self.max_active_requests:
|
||||||
|
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
|
||||||
|
"concurrent requests allowed is {}.".format(self.max_active_requests))
|
||||||
|
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
||||||
|
return request_id
|
||||||
|
|
||||||
|
def exit(self, request_id: str):
|
||||||
|
if request_id == RateLimit._UNLIMITED_REQUEST_ID:
|
||||||
|
return
|
||||||
|
redis_client.hdel(self.active_requests_key, request_id)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def gen_request_key() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
def generate(self, generator: Union[Generator, callable, dict], request_id: str):
|
||||||
|
if isinstance(generator, dict):
|
||||||
|
return generator
|
||||||
|
else:
|
||||||
|
return RateLimitGenerator(self, generator, request_id)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitGenerator:
|
||||||
|
def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
|
||||||
|
self.rate_limit = rate_limit
|
||||||
|
if callable(generator):
|
||||||
|
self.generator = generator()
|
||||||
|
else:
|
||||||
|
self.generator = generator
|
||||||
|
self.request_id = request_id
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.closed:
|
||||||
|
raise StopIteration
|
||||||
|
try:
|
||||||
|
return next(self.generator)
|
||||||
|
except StopIteration:
|
||||||
|
self.close()
|
||||||
|
raise
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
if not self.closed:
|
||||||
|
self.closed = True
|
||||||
|
self.rate_limit.exit(self.request_id)
|
||||||
|
if self.generator is not None and hasattr(self.generator, 'close'):
|
||||||
|
self.generator.close()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue