Merge branch 'fix/chore-fix' into dev/plugin-deploy

pull/12372/head
Yeuoly 1 year ago
commit f0178bd603

@ -191,7 +191,7 @@ class ModelInstance:
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast( return cast(
int, list[int],
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,
@ -240,7 +240,7 @@ class ModelInstance:
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast( return cast(
int, list[int],
self._round_robin_invoke( self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens, function=self.model_type_instance.get_num_tokens,
model=self.model, model=self.model,

@ -1,7 +1,7 @@
import json import json
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional, cast from typing import Any, Optional, cast
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -350,7 +350,7 @@ class ProviderManager:
:param tenant_id: workspace id :param tenant_id: workspace id
:return: :return:
""" """
providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid is True).all() providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all()
provider_name_to_provider_records_dict = defaultdict(list) provider_name_to_provider_records_dict = defaultdict(list)
for provider in providers: for provider in providers:
@ -369,7 +369,7 @@ class ProviderManager:
# Get all provider model records of the workspace # Get all provider model records of the workspace
provider_models = ( provider_models = (
db.session.query(ProviderModel) db.session.query(ProviderModel)
.filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid is True) .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True)
.all() .all()
) )
@ -735,13 +735,14 @@ class ProviderManager:
) )
# Get cached provider credentials # Get cached provider credentials
# error occurs
cached_provider_credentials = provider_credentials_cache.get() cached_provider_credentials = provider_credentials_cache.get()
if not cached_provider_credentials: if not cached_provider_credentials:
try: try:
provider_credentials = json.loads(provider_record.encrypted_config) provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config)
except JSONDecodeError: except JSONDecodeError:
provider_credentials = {} provider_credentials: dict[str, Any] = {}
# Get provider credential secret variables # Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables( provider_credential_secret_variables = self._extract_secret_variables(
@ -758,7 +759,9 @@ class ProviderManager:
if variable in provider_credentials: if variable in provider_credentials:
try: try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa provider_credentials.get(variable, ""),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
) )
except ValueError: except ValueError:
pass pass

@ -88,7 +88,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled is True, DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids), DocumentSegment.index_node_id.in_(index_node_ids),
).all() ).all()
@ -109,8 +109,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = Document.query.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled is True, Document.enabled == True,
Document.archived is False, Document.archived == False,
).first() ).first()
if dataset and document: if dataset and document:
source = { source = {

@ -7,7 +7,6 @@ from configs import dify_config
from core.entities.model_entities import ( from core.entities.model_entities import (
ModelWithProviderEntity, ModelWithProviderEntity,
ProviderModelWithStatusEntity, ProviderModelWithStatusEntity,
SimpleModelProviderEntity,
) )
from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration from core.entities.provider_entities import ProviderQuotaType, QuotaConfiguration
from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.common_entities import I18nObject
@ -162,7 +161,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
Model with provider entity. Model with provider entity.
""" """
provider: SimpleModelProviderEntity provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None: def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
dump_model = model.model_dump() dump_model = model.model_dump()

Loading…
Cancel
Save