You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gcgj-dify-1.7.0/api/services/ext/account_ext_service.py

349 lines
15 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import yaml
from typing import Optional
import flask_login
from pathlib import Path
from constants.languages import languages
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.manager.exc import PluginDaemonClientSideError
from extensions.ext_database import db
from models.account import (
Account,
Tenant,
)
from services.account_service import AccountService, TenantService
from services.dataset_service import DatasetService
from services.errors.account import (
AccountRegisterError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.ext.dataset_ext_service import DatasetExtService
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
from services.plugin.plugin_service import PluginService
from configs import dify_config
from configs.ext_config import get_ext_config
import os
class AccountInfo:
    """One account record: ``email``, display ``name``, ``user_id`` and ``tenant_id``.

    Besides attribute access, the record supports read-only item access
    (``info["email"]``) with the same keys as ``to_dict()``. This lets it be
    passed where account records are subscripted like dicts.
    """

    def __init__(self, email, name, user_id, tenant_id):
        self.email = email
        self.name = name
        self.user_id = user_id
        self.tenant_id = tenant_id

    def to_dict(self):
        """Return the record as a dict keyed by attribute name.

        Returns:
            ``{"email", "name", "user_id", "tenant_id"}`` mapped to the
            corresponding attribute values.
        """
        # Keys mirror the attributes one-to-one. The earlier mapping was shifted:
        # the email was emitted as "tenant_id", the name as "tenant_name", the
        # tenant id as "api_key", and user_id was dropped.
        return {
            "email": self.email,
            "name": self.name,
            "user_id": self.user_id,
            "tenant_id": self.tenant_id,
        }

    def __getitem__(self, key):
        """Dict-style read access; raises KeyError for keys not produced by ``to_dict()``."""
        return self.to_dict()[key]
class TenantAccountInfo:
    """Result of enabling a tenant: the tenant's id and name plus admin credentials."""

    def __init__(self, tenant_id: str,
                 tenant_name: str,
                 admin_account: str,
                 admin_account_password: str,
                 ):
        # Values are stored exactly as passed in; no validation or copying.
        self.tenant_id, self.tenant_name = tenant_id, tenant_name
        self.admin_account, self.admin_account_password = admin_account, admin_account_password

    def to_dict(self):
        """Return the four attributes as a dict keyed by attribute name.

        Note that the admin password is included in plain text.
        """
        return dict(
            tenant_id=self.tenant_id,
            tenant_name=self.tenant_name,
            admin_account=self.admin_account,
            admin_account_password=self.admin_account_password,
        )
class TenantData:
    """Tenant initialisation payload: a dataset API key and the ids of the created datasets."""

    def __init__(self,
                 api_key: str,
                 dataset_ids: list[str]
                 ):
        # The list is kept by reference (not copied), so to_dict() returns the same object.
        self.api_key, self.dataset_ids = api_key, dataset_ids

    def to_dict(self):
        """Return ``{"api_key": ..., "dataset_ids": ...}`` built from the stored attributes."""
        return dict(api_key=self.api_key, dataset_ids=self.dataset_ids)
class AccountExtService:
    """Account helpers that link Dify accounts to an external tenant id (``target_tenant_id``)."""

    @staticmethod
    def create_account_and_tenant(
        email: str,
        name: str,
        tenant_name: str,
        target_tenant_id: str,
        interface_language: Optional[str] = None,
        password: Optional[str] = None
    ) -> Account:
        """Create an account plus its owner tenant, tagging both with ``target_tenant_id``.

        Args:
            email: Login email of the new account.
            name: Display name of the new account.
            tenant_name: Name passed to ``TenantService.create_owner_tenant_if_not_exist``.
            target_tenant_id: External tenant id stored on the account and on its current tenant.
            interface_language: Optional UI language forwarded to ``AccountService.create_account``.
            password: Optional password forwarded to ``AccountService.create_account``.

        Returns:
            The created account. The session is committed before returning.
        """
        account = AccountService.create_account(
            email=email, name=name, interface_language=interface_language, password=password, is_setup=True
        )
        account.target_tenant_id = target_tenant_id
        TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True)
        # NOTE(review): relies on create_owner_tenant_if_not_exist having set
        # account.current_tenant to the owner tenant — confirm in TenantService.
        account.current_tenant.target_tenant_id = target_tenant_id
        db.session.commit()
        return account

    @staticmethod
    def get_admin_account() -> Optional[Account]:
        """Return the first account whose ``target_tenant_id`` is ``"100"``, or None.

        NOTE(review): ``"100"`` is a hard-coded marker; presumably it identifies
        the platform admin — confirm with the integration owner.
        """
        return db.session.query(Account).filter(Account.target_tenant_id == "100").first()

    @staticmethod
    def update_account_list(
        accounts: list[AccountInfo],
        target_tenant_id: str,
        interface_language: Optional[str] = None,
    ):
        """Synchronise a batch of accounts into the tenant tagged with ``target_tenant_id``.

        Accounts already tagged with ``target_tenant_id`` are matched by email and
        get their ``name`` and ``user_id`` updated. Unknown emails are created with
        the default password and added as members of the tenant. All changes are
        committed together.

        Args:
            accounts: Records supporting item access with the keys ``"email"``,
                ``"name"`` and ``"user_id"`` (plain dicts work).
            target_tenant_id: External tenant id identifying the tenant.
            interface_language: UI language for newly created accounts; falls
                back to ``languages[0]``.

        Raises:
            AccountRegisterError: on any failure. A missing tenant is raised as
                ``TenantNotFoundError`` internally and is therefore wrapped too.
            ``WorkSpaceNotAllowedCreateError`` is deliberately non-fatal: the
            session is rolled back, a warning is logged, and None is returned.
        """
        db.session.begin_nested()
        try:
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id=target_tenant_id)
            if tenant is None:
                raise TenantNotFoundError("企业未初始,请联系管理员!")
            # Accounts already tagged with this tenant, keyed by email.
            existing_by_email = {
                existing.email: existing
                for existing in db.session.query(Account).filter(Account.target_tenant_id == target_tenant_id).all()
            }
            for item in accounts:
                email = item["email"]
                existing = existing_by_email.get(email)
                if existing is not None:
                    existing.name = item["name"]
                    existing.user_id = item["user_id"]
                    continue
                # NOTE(review): every newly created account shares this hard-coded password.
                new_account = AccountService.create_account(email=email,
                                                            name=item["name"],
                                                            interface_language=interface_language or languages[0],
                                                            password="wisdom@123",
                                                            is_setup=True)
                new_account.user_id = item["user_id"]
                new_account.target_tenant_id = target_tenant_id
                # Add the new account as a member of the tenant.
                TenantService.create_tenant_member(tenant, new_account)
                # Remember it, so a repeated email later in the same batch is updated
                # instead of triggering a second (failing) create_account.
                existing_by_email[email] = new_account
            db.session.commit()
        except WorkSpaceNotAllowedCreateError:
            db.session.rollback()
            logging.warning(
                "Account sync for target tenant %s skipped: workspace creation not allowed", target_tenant_id
            )
        except AccountRegisterError:
            db.session.rollback()
            logging.exception("Register failed")
            raise
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e
class TenantExtService:
    """Tenant helpers: lookup by external id, default plugin/model provisioning, and tenant enablement."""

    @staticmethod
    def get_tenant() -> Optional[Tenant]:
        """Return the first Tenant row the query yields (used as the default tenant), or None."""
        # NOTE(review): there is no ORDER BY, so which row counts as "first" is up to the database.
        return db.session.query(Tenant).first()

    @staticmethod
    def get_tenant_by_target_tenant_id(target_tenant_id: str) -> Optional[Tenant]:
        """Return the tenant whose ``target_tenant_id`` matches, or None if there is none."""
        return db.session.query(Tenant).filter(Tenant.target_tenant_id == target_tenant_id).first()

    @staticmethod
    def setModeConfig(tenant_id: str, args: dict[str, object], provider: str) -> None:
        """Apply one model configuration (load balancing and credentials) to a tenant.

        Args:
            tenant_id: Tenant to configure.
            args: Model settings. Reads ``model`` and ``model_type``, and optionally
                ``load_balancing`` (a mapping with ``enabled`` / ``configs``),
                ``config_from`` and ``credentials``.
            provider: Model provider name.

        Raises:
            ValueError: if load balancing is enabled without ``configs``, or if
                saving the credentials raises ``CredentialsValidateFailedError``.
        """
        model_load_balancing_service = ModelLoadBalancingService()
        load_balancing = args.get("load_balancing") or {}
        if load_balancing.get("enabled"):
            if "configs" not in load_balancing:
                raise ValueError("invalid load balancing configs")
            # Save the load-balancing configs, then switch load balancing on.
            model_load_balancing_service.update_load_balancing_configs(
                tenant_id=tenant_id,
                provider=provider,
                model=args["model"],
                model_type=args["model_type"],
                configs=load_balancing["configs"],
            )
            model_load_balancing_service.enable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )
        else:
            model_load_balancing_service.disable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )

        # Credentials are saved for every model not marked "predefined-model".
        if args.get("config_from", "") == "predefined-model":
            return
        model_provider_service = ModelProviderService()
        try:
            model_provider_service.save_model_credentials(
                tenant_id=tenant_id,
                provider=provider,
                model=args["model"],
                model_type=args["model_type"],
                credentials=args["credentials"],
            )
        except CredentialsValidateFailedError as ex:
            logging.exception(
                "Failed to save model credentials, tenant_id: %s, model: %s, model_type: %s",
                tenant_id, args.get("model"), args.get("model_type"),
            )
            raise ValueError(str(ex)) from ex

    @staticmethod
    def install_plugin(tenant_id: str):
        """Install the bundled plugin packages, then configure the tenant's default models."""
        TenantExtService.install_langgenius(tenant_id=tenant_id)
        TenantExtService.install_model(tenant_id=tenant_id)

    @staticmethod
    def _apply_ext_model_config(tenant_id: str, file_name: str, params: dict[str, object], provider: str) -> None:
        """Load ``file_name`` through ``get_ext_config`` with ``params`` and apply it via ``setModeConfig``."""
        # NOTE(review): params are presumably substituted into the YAML template — confirm in get_ext_config.
        model_config = get_ext_config(file_name=file_name, params=params)
        TenantExtService.setModeConfig(tenant_id=tenant_id, args=model_config, provider=provider)

    @staticmethod
    def install_model(tenant_id: str):
        """Configure the tenant's default LLM, text-embedding and rerank models from ``dify_config``."""
        TenantExtService._apply_ext_model_config(
            tenant_id=tenant_id,
            file_name="plugin_llm_config.yml",
            params={
                "INIT_MODEL_LLM_NAME": dify_config.INIT_MODEL_LLM_NAME,
                "INIT_MODEL_LLM_CONTEXT_SIZE": dify_config.INIT_MODEL_LLM_CONTEXT_SIZE,
                "INIT_MODEL_LLM_MAX_TOKENS": dify_config.INIT_MODEL_LLM_MAX_TOKENS,
                "INIT_MODEL_LLM_BASE_URL": dify_config.INIT_MODEL_LLM_BASE_URL,
            },
            provider=dify_config.INIT_MODEL_LLM_PROVIDER,
        )
        TenantExtService._apply_ext_model_config(
            tenant_id=tenant_id,
            file_name="plugin_embedding_config.yml",
            params={
                "INIT_MODEL_TEXT_EMBEDDING_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_NAME,
                "INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE": dify_config.INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE,
                "INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS": dify_config.INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS,
                "INIT_MODEL_TEXT_EMBEDDING_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_BASE_URL,
            },
            provider=dify_config.INIT_MODEL_TEXT_EMBEDDING_PROVIDER,
        )
        TenantExtService._apply_ext_model_config(
            tenant_id=tenant_id,
            file_name="plugin_embedding_rerank_config.yml",
            params={
                "INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME,
                "INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL,
            },
            provider=dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_PROVIDER,
        )

    @staticmethod
    def install_langgenius(tenant_id: str):
        """Upload the bundled plugin packages and install those the tenant does not have yet."""
        uploaded_identifiers = TenantExtService.upload_langgenius(tenant_id=tenant_id)
        # Identifiers of plugins already installed for the tenant (a set, for O(1) membership tests).
        installed_identifiers = {plugin.plugin_unique_identifier for plugin in PluginService.list(tenant_id)}
        pending_identifiers = [uid for uid in uploaded_identifiers if uid not in installed_identifiers]
        if not pending_identifiers:
            # Everything is installed already; skip an empty install request.
            return
        PluginService.install_from_marketplace_pkg(tenant_id, pending_identifiers)

    @staticmethod
    def upload_langgenius(tenant_id: str) -> list[str]:
        """Upload every file in the ``plugins/langgenius`` directory three levels above this module.

        Args:
            tenant_id: Tenant the packages are uploaded for.

        Returns:
            The ``unique_identifier`` reported for each uploaded file, in file-name order.

        Raises:
            ValueError: wrapping ``PluginDaemonClientSideError`` from the upload.
            FileNotFoundError: if the package directory does not exist.
        """
        directory = Path(__file__).parent.parent.parent / 'plugins' / 'langgenius'
        unique_identifiers = []
        # Sorted so uploads (and the returned list) have a deterministic order;
        # directory listing order is otherwise arbitrary.
        for file_path in sorted(directory.iterdir()):
            if not file_path.is_file():
                continue
            logging.info("Uploading plugin package: %s", file_path)
            content = file_path.read_bytes()
            try:
                response = PluginService.upload_pkg(tenant_id=tenant_id, pkg=content)
            except PluginDaemonClientSideError as e:
                raise ValueError(e) from e
            unique_identifiers.append(response.unique_identifier)
        return unique_identifiers

    @staticmethod
    def enable_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantAccountInfo:
        """Ensure a tenant and its admin account exist for an external tenant id.

        The admin's email is ``admin@<target_tenant_id>.com``. If a tenant tagged with
        ``target_tenant_id`` already exists, the admin account is created and added
        as owner only when missing. Otherwise the account and its tenant are created
        together.

        Returns:
            TenantAccountInfo with the tenant id and name, the admin login email and
            the default password.

        Raises:
            AccountRegisterError: wrapping any failure; the session is rolled back.
        """
        db.session.begin_nested()
        # NOTE(review): every tenant admin gets this same hard-coded password. If the
        # admin account already existed, its real password may differ from this value.
        password = "wisdom@123"
        try:
            email = f"admin@{target_tenant_id}.com"
            admin_name = f"{target_tenant_name}-管理员"
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id)
            if tenant is not None:
                # The tenant exists already; only make sure its admin account exists.
                account = AccountService.get_user_through_email(email)
                if account is None:
                    account = AccountService.create_account(
                        email=email, name=admin_name, password=password, is_setup=True, interface_language="zh-Hans"
                    )
                    TenantService.create_tenant_member(tenant, account, role="owner")
                account.target_tenant_id = target_tenant_id
            else:
                account = AccountExtService.create_account_and_tenant(email=email,
                                                                      name=admin_name,
                                                                      tenant_name=target_tenant_name,
                                                                      target_tenant_id=target_tenant_id,
                                                                      interface_language="zh-Hans",
                                                                      password=password)
                # create_account_and_tenant tags account.current_tenant with the target id,
                # so that is the tenant just created for this account.
                tenant = account.current_tenant
            account_info = TenantAccountInfo(tenant_name=tenant.name,
                                             tenant_id=tenant.id,
                                             # Return the email (the identifier the account is looked up
                                             # by), not the display name, alongside the password.
                                             admin_account=email,
                                             admin_account_password=password)
            # Without this commit, the target_tenant_id assignment on the existing-tenant
            # path was never committed by this method.
            db.session.commit()
            return account_info
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e

    @staticmethod
    def init_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantData:
        """Provision the current user's tenant: plugins and models, initial datasets, and an API key.

        Args:
            target_tenant_id: External tenant id forwarded to ``DatasetExtService.init_dataset``.
            target_tenant_name: External tenant name forwarded to ``DatasetExtService.init_dataset``.

        Returns:
            TenantData with the dataset API token and the ids of the initialised datasets.

        Raises:
            AccountRegisterError: wrapping any failure; the session is rolled back.
        """
        db.session.begin_nested()
        try:
            account = flask_login.current_user
            tenant = account.current_tenant
            # Install the bundled plugins and configure the default models.
            TenantExtService.install_plugin(tenant_id=tenant.id)
            # Create the initial knowledge bases (datasets).
            datasets = DatasetExtService.init_dataset(
                tenant=tenant, target_tenant_id=target_tenant_id, target_tenant_name=target_tenant_name, account=account
            )
            # Fetch, or create, the tenant's dataset API token.
            api_token = DatasetExtService().get_or_add_datasets_api_token(tenant_id=tenant.id)
            db.session.commit()
            dataset_ids = [dataset.id for dataset in datasets]
            return TenantData(api_key=api_token.token, dataset_ids=dataset_ids)
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e