|
|
import logging
|
|
|
import yaml
|
|
|
from typing import Optional
|
|
|
import flask_login
|
|
|
from pathlib import Path
|
|
|
from constants.languages import languages
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
from core.plugin.manager.exc import PluginDaemonClientSideError
|
|
|
from extensions.ext_database import db
|
|
|
from models.account import (
|
|
|
Account,
|
|
|
Tenant,
|
|
|
)
|
|
|
from services.account_service import AccountService, TenantService
|
|
|
from services.dataset_service import DatasetService
|
|
|
from services.errors.account import (
|
|
|
AccountRegisterError,
|
|
|
TenantNotFoundError,
|
|
|
)
|
|
|
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
|
|
from services.ext.dataset_ext_service import DatasetExtService
|
|
|
from services.model_load_balancing_service import ModelLoadBalancingService
|
|
|
from services.model_provider_service import ModelProviderService
|
|
|
from services.plugin.plugin_service import PluginService
|
|
|
from configs import dify_config
|
|
|
from configs.ext_config import get_ext_config
|
|
|
import os
|
|
|
|
|
|
class AccountInfo:
    """Plain container describing one account entry to synchronise.

    Attributes:
        email: Login email of the account.
        name: Display name of the account.
        user_id: User identifier (copied onto ``Account.user_id`` when syncing).
        tenant_id: Tenant identifier associated with this entry.
    """

    def __init__(self, email, name, user_id, tenant_id):
        self.email = email
        self.name = name
        self.user_id = user_id
        self.tenant_id = tenant_id

    def to_dict(self):
        """Return the entry as a dict keyed by its own field names.

        The keys ``email``, ``name`` and ``user_id`` are the ones
        ``AccountExtService.update_account_list`` reads from each entry.
        The previous implementation used another class's key set: it put the
        email under ``tenant_id``, the name under ``tenant_name`` and the
        tenant id under ``api_key``, and it dropped ``user_id`` entirely.
        """
        return {
            "email": self.email,
            "name": self.name,
            "user_id": self.user_id,
            "tenant_id": self.tenant_id,
        }
|
|
|
|
|
|
class TenantAccountInfo:
    """Tenant identity plus the admin account details produced when a tenant is enabled."""

    # Serialisation order of ``to_dict``; each name is also an instance attribute.
    _FIELDS = ("tenant_id", "tenant_name", "admin_account", "admin_account_password")

    def __init__(
        self,
        tenant_id: str,
        tenant_name: str,
        admin_account: str,
        admin_account_password: str,
    ):
        self.tenant_id = tenant_id
        self.tenant_name = tenant_name
        self.admin_account = admin_account
        self.admin_account_password = admin_account_password

    def to_dict(self):
        """Return every field as a dict, keyed by attribute name."""
        return {field: getattr(self, field) for field in self._FIELDS}
|
|
|
|
|
|
class TenantData:
    """Result of tenant initialisation: a dataset API key and the created dataset ids."""

    def __init__(self, api_key: str, dataset_ids: list[str]):
        self.api_key, self.dataset_ids = api_key, dataset_ids

    def to_dict(self):
        """Return the API key and dataset ids as a dict (the id list is not copied)."""
        payload = {"api_key": self.api_key}
        payload["dataset_ids"] = self.dataset_ids
        return payload
|
|
|
|
|
|
class AccountExtService:
    """Account helpers layered on top of ``AccountService`` and ``TenantService``."""

    # NOTE(review): every account created by update_account_list receives this same
    # fixed initial password. It is a shared, guessable credential; consider
    # generating a per-account password or forcing a reset on first login.
    _DEFAULT_PASSWORD = "wisdom@123"

    @staticmethod
    def create_account_and_tenant(
        email: str,
        name: str,
        tenant_name: str,
        target_tenant_id: str,
        interface_language: Optional[str] = None,
        password: Optional[str] = None
    ) -> Account:
        """Create an account and its owner tenant, both tagged with ``target_tenant_id``.

        Args:
            email: Login email of the new account.
            name: Display name of the new account.
            tenant_name: Name given to the owner tenant if one is created.
            target_tenant_id: External tenant id stored on both the account and
                its current tenant.
            interface_language: UI language passed to ``AccountService.create_account``.
            password: Initial password passed to ``AccountService.create_account``.

        Returns:
            The created account. The session is committed before returning.
        """
        account = AccountService.create_account(
            email=email, name=name, interface_language=interface_language, password=password, is_setup=True
        )
        account.target_tenant_id = target_tenant_id
        TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True)
        account.current_tenant.target_tenant_id = target_tenant_id
        db.session.commit()
        return account

    @staticmethod
    def get_admin_account() -> Optional[Account]:
        """Return the first account whose ``target_tenant_id`` is ``"100"``, or None.

        NOTE(review): "100" looks like a reserved id for the platform admin
        account; confirm and consider moving it to configuration.
        """
        return db.session.query(Account).filter(Account.target_tenant_id == "100").first()

    @staticmethod
    def _account_field(account, key: str):
        """Read ``key`` from an account entry given either as a dict or as an object such as AccountInfo."""
        if isinstance(account, dict):
            return account[key]
        return getattr(account, key)

    @staticmethod
    def update_account_list(
        accounts: "list[AccountInfo | dict]",
        target_tenant_id: str,
        interface_language: Optional[str] = None,
    ):
        """Synchronise accounts for the tenant tagged ``target_tenant_id``.

        Each entry must provide ``email``, ``name`` and ``user_id``, either as
        dict keys or as attributes. The annotation always promised
        ``AccountInfo`` objects, but the old code could only index dicts.

        - If an account with the same email already exists among the accounts
          tagged ``target_tenant_id``, its name and user_id are updated.
        - Otherwise a new account is created with the default password and
          added as a member of the tenant. It is also recorded in the email
          lookup, so a repeated email in ``accounts`` updates the account
          instead of creating it a second time.

        Args:
            accounts: Account entries to apply.
            target_tenant_id: External tenant id used to find the tenant and
                its existing accounts.
            interface_language: UI language for new accounts. Defaults to
                ``languages[0]``.

        Raises:
            AccountRegisterError: On any failure, after rolling back. This
                includes a missing tenant, because the ``TenantNotFoundError``
                is wrapped by the generic handler.
                ``WorkSpaceNotAllowedCreateError`` is the exception: it is
                logged and swallowed after rollback.
        """
        db.session.begin_nested()
        try:
            # Find the tenant this batch belongs to.
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id=target_tenant_id)
            if tenant is None:
                raise TenantNotFoundError("企业未初始,请联系管理员!")

            # Index the accounts already tagged with this target tenant by email.
            existing = db.session.query(Account).filter(Account.target_tenant_id == target_tenant_id).all()
            accounts_by_email = {acc.email: acc for acc in existing}

            for entry in accounts:
                email = AccountExtService._account_field(entry, "email")
                name = AccountExtService._account_field(entry, "name")
                user_id = AccountExtService._account_field(entry, "user_id")

                current = accounts_by_email.get(email)
                if current is not None:
                    current.name = name
                    current.user_id = user_id
                    continue

                new_account = AccountService.create_account(
                    email=email,
                    name=name,
                    interface_language=interface_language or languages[0],
                    password=AccountExtService._DEFAULT_PASSWORD,
                    is_setup=True,
                )
                new_account.user_id = user_id
                new_account.target_tenant_id = target_tenant_id
                # Attach the new account to the tenant.
                TenantService.create_tenant_member(tenant, new_account)
                # Remember it so a repeated email in `accounts` is treated as an update.
                accounts_by_email[email] = new_account

            db.session.commit()
        except WorkSpaceNotAllowedCreateError:
            # Deliberately not re-raised (unchanged behaviour), but no longer silent.
            db.session.rollback()
            logging.exception("Register failed: workspace creation is not allowed")
        except AccountRegisterError:
            db.session.rollback()
            logging.exception("Register failed")
            raise
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e
|
|
|
|
|
|
|
|
|
class TenantExtService:
    """Tenant-level helpers: lookup, plugin/model installation, enabling and initialising tenants."""

    # NOTE(review): fixed initial password for every tenant admin account created
    # by enable_tenant. It is a shared, guessable credential; consider generating
    # one per tenant or forcing a reset on first login.
    _DEFAULT_ADMIN_PASSWORD = "wisdom@123"

    @staticmethod
    def get_tenant() -> Optional[Tenant]:
        """Return the first tenant row, treated as the default tenant, or None if none exist."""
        return db.session.query(Tenant).first()

    @staticmethod
    def get_tenant_by_target_tenant_id(target_tenant_id: str) -> Optional[Tenant]:
        """Return the tenant whose ``target_tenant_id`` matches, or None."""
        return db.session.query(Tenant).filter(Tenant.target_tenant_id == target_tenant_id).first()

    @staticmethod
    def setModeConfig(tenant_id: str, args: dict[str, object], provider: str) -> None:
        """Apply one model configuration for ``provider`` to a tenant.

        ``args`` must contain ``model`` and ``model_type``. It may also contain:

        - ``load_balancing``: a dict with ``enabled`` and ``configs``. When
          enabled, the configs are saved and load balancing is turned on;
          otherwise load balancing is turned off.
        - ``config_from``: unless this equals ``"predefined-model"``,
          ``credentials`` is required and is saved as the model credentials.

        Raises:
            ValueError: If load balancing is enabled without ``configs``, or if
                credential validation fails.
        """
        model_load_balancing_service = ModelLoadBalancingService()
        load_balancing = args.get("load_balancing")
        if load_balancing and "enabled" in load_balancing and load_balancing["enabled"]:
            if "configs" not in load_balancing:
                raise ValueError("invalid load balancing configs")

            # Save the load-balancing configs, then enable load balancing.
            model_load_balancing_service.update_load_balancing_configs(
                tenant_id=tenant_id,
                provider=provider,
                model=args["model"],
                model_type=args["model_type"],
                configs=load_balancing["configs"],
            )
            model_load_balancing_service.enable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )
        else:
            model_load_balancing_service.disable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )

        if args.get("config_from", "") != "predefined-model":
            model_provider_service = ModelProviderService()
            try:
                model_provider_service.save_model_credentials(
                    tenant_id=tenant_id,
                    provider=provider,
                    model=args["model"],
                    model_type=args["model_type"],
                    credentials=args["credentials"],
                )
            except CredentialsValidateFailedError as ex:
                logging.exception(
                    "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
                    tenant_id,
                    args.get("model"),
                    args.get("model_type"),
                )
                raise ValueError(str(ex)) from ex

    @staticmethod
    def install_plugin(tenant_id: str):
        """Upload and install the bundled langgenius plugins, then configure the default models."""
        TenantExtService.install_langgenius(tenant_id=tenant_id)
        TenantExtService.install_model(tenant_id=tenant_id)

    @staticmethod
    def _apply_model_config(tenant_id: str, file_name: str, params: dict, provider: str) -> None:
        """Load ``file_name`` through get_ext_config with ``params`` and apply it for ``provider``."""
        config = get_ext_config(file_name=file_name, params=params)
        TenantExtService.setModeConfig(tenant_id=tenant_id, args=config, provider=provider)

    @staticmethod
    def install_model(tenant_id: str):
        """Configure the LLM, text-embedding and rerank models from the bundled YAML configs.

        Each config file is loaded with ``INIT_MODEL_*`` values from ``dify_config``
        and applied via ``setModeConfig``.
        """
        TenantExtService._apply_model_config(
            tenant_id,
            "plugin_llm_config.yml",
            {
                "INIT_MODEL_LLM_NAME": dify_config.INIT_MODEL_LLM_NAME,
                "INIT_MODEL_LLM_CONTEXT_SIZE": dify_config.INIT_MODEL_LLM_CONTEXT_SIZE,
                "INIT_MODEL_LLM_MAX_TOKENS": dify_config.INIT_MODEL_LLM_MAX_TOKENS,
                "INIT_MODEL_LLM_BASE_URL": dify_config.INIT_MODEL_LLM_BASE_URL,
            },
            dify_config.INIT_MODEL_LLM_PROVIDER,
        )
        TenantExtService._apply_model_config(
            tenant_id,
            "plugin_embedding_config.yml",
            {
                "INIT_MODEL_TEXT_EMBEDDING_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_NAME,
                "INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE": dify_config.INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE,
                "INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS": dify_config.INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS,
                "INIT_MODEL_TEXT_EMBEDDING_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_BASE_URL,
            },
            dify_config.INIT_MODEL_TEXT_EMBEDDING_PROVIDER,
        )
        TenantExtService._apply_model_config(
            tenant_id,
            "plugin_embedding_rerank_config.yml",
            {
                "INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME,
                "INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL,
            },
            dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_PROVIDER,
        )

    @staticmethod
    def install_langgenius(tenant_id: str):
        """Upload the bundled plugin packages and install the ones not yet installed.

        The install call is skipped entirely when there is nothing new to install.
        """
        upload_unique_identifiers = TenantExtService.upload_langgenius(tenant_id=tenant_id)
        # Identifiers of plugins already installed for this tenant.
        installed = {item.plugin_unique_identifier for item in PluginService.list(tenant_id)}
        # Keep only the uploaded packages that are not installed yet.
        new_unique_identifiers = [uid for uid in upload_unique_identifiers if uid not in installed]
        if not new_unique_identifiers:
            return
        PluginService.install_from_marketplace_pkg(tenant_id, new_unique_identifiers)

    @staticmethod
    def upload_langgenius(tenant_id: str) -> list[str]:
        """Upload every file in ``<project>/plugins/langgenius`` as a plugin package.

        Files are processed in sorted name order, so the upload order is deterministic.

        Returns:
            The unique identifier returned for each uploaded package, in upload order.

        Raises:
            ValueError: If the plugin daemon rejects a package. The original
                ``PluginDaemonClientSideError`` is kept as the cause.
        """
        directory = Path(__file__).parent.parent.parent / "plugins" / "langgenius"
        unique_identifiers: list[str] = []
        for file_path in sorted(directory.iterdir()):
            if not file_path.is_file():
                continue
            logging.info("Uploading plugin package: %s", file_path)
            content = file_path.read_bytes()
            try:
                response = PluginService.upload_pkg(tenant_id=tenant_id, pkg=content)
            except PluginDaemonClientSideError as e:
                raise ValueError(e) from e
            unique_identifiers.append(response.unique_identifier)
        return unique_identifiers

    @staticmethod
    def enable_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantAccountInfo:
        """Ensure a tenant and its admin account exist for ``target_tenant_id``.

        The admin login is ``admin@<target_tenant_id>.com``.

        - If the tenant already exists, the admin account is created and added
          as owner only when it is missing.
        - Otherwise a new account and owner tenant are created together.

        The session is committed before returning. Previously, the
        existing-tenant path never committed, so ``target_tenant_id`` on a
        newly created admin could be lost.

        Returns:
            The tenant id and name plus the admin login email and initial password.
            ``admin_account`` is now the login email; it used to be the display
            name, which cannot be used to log in.
            NOTE(review): if the admin account already existed, its password
            may have been changed and no longer match the one returned here.

        Raises:
            AccountRegisterError: On any failure, after rolling back.
        """
        db.session.begin_nested()
        password = TenantExtService._DEFAULT_ADMIN_PASSWORD
        try:
            email = f"admin@{target_tenant_id}.com"
            admin_name = f"{target_tenant_name}-管理员"
            # Has the tenant already been created?
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id)
            if tenant is not None:
                account = AccountService.get_user_through_email(email)
                if account is None:
                    account = AccountService.create_account(
                        email=email, name=admin_name, password=password, is_setup=True, interface_language="zh-Hans"
                    )
                    TenantService.create_tenant_member(tenant, account, role="owner")
                    account.target_tenant_id = target_tenant_id
            else:
                account = AccountExtService.create_account_and_tenant(
                    email=email,
                    name=admin_name,
                    tenant_name=target_tenant_name,
                    target_tenant_id=target_tenant_id,
                    interface_language="zh-Hans",
                    password=password,
                )
                # The owner tenant just created is the account's current tenant.
                tenant = account.current_tenant

            # Persist changes made on the existing-tenant path, which previously never committed.
            db.session.commit()

            return TenantAccountInfo(
                tenant_name=tenant.name,
                tenant_id=tenant.id,
                admin_account=email,
                admin_account_password=password,
            )
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e

    @staticmethod
    def init_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantData:
        """Initialise the current user's current tenant.

        Steps: install plugins and models, create the initial datasets, and
        obtain the dataset API token. The session is committed on success.

        Returns:
            The dataset API token and the ids of the datasets returned by
            ``DatasetExtService.init_dataset``.

        Raises:
            AccountRegisterError: On any failure, after rolling back.
        """
        db.session.begin_nested()
        try:
            account = flask_login.current_user
            tenant = account.current_tenant
            # Set up model plugins and default model configs.
            TenantExtService.install_plugin(tenant_id=tenant.id)
            # Create the initial knowledge-base datasets.
            datasets = DatasetExtService.init_dataset(
                tenant=tenant, target_tenant_id=target_tenant_id, target_tenant_name=target_tenant_name, account=account
            )
            # Fetch (or create) the dataset API token.
            api_token = DatasetExtService().get_or_add_datasets_api_token(tenant_id=tenant.id)
            db.session.commit()

            return TenantData(api_key=api_token.token, dataset_ids=[dataset.id for dataset in datasets])
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e
|
|
|
|
|
|
|