feat: Add Vanna.AI as a builtin tool (#4878)
Co-authored-by: Yeuoly <admin@srmxy.cn>pull/4914/head^2
parent
7133a16511
commit
2d9f55b632
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
@ -0,0 +1,119 @@
|
|||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from vanna.remote import VannaDefault
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class VannaTool(BuiltinTool):
|
||||||
|
def _invoke(
|
||||||
|
self, user_id: str, tool_parameters: dict[str, Any]
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
|
"""
|
||||||
|
invoke tools
|
||||||
|
"""
|
||||||
|
api_key = self.runtime.credentials.get("api_key", None)
|
||||||
|
if not api_key:
|
||||||
|
raise ToolProviderCredentialValidationError("Please input api key")
|
||||||
|
|
||||||
|
model = tool_parameters.get("model", "")
|
||||||
|
if not model:
|
||||||
|
return self.create_text_message("Please input RAG model")
|
||||||
|
|
||||||
|
prompt = tool_parameters.get("prompt", "")
|
||||||
|
if not prompt:
|
||||||
|
return self.create_text_message("Please input prompt")
|
||||||
|
|
||||||
|
url = tool_parameters.get("url", "")
|
||||||
|
if not url:
|
||||||
|
return self.create_text_message("Please input URL/Host/DSN")
|
||||||
|
|
||||||
|
db_name = tool_parameters.get("db_name", "")
|
||||||
|
username = tool_parameters.get("username", "")
|
||||||
|
password = tool_parameters.get("password", "")
|
||||||
|
port = tool_parameters.get("port", 0)
|
||||||
|
|
||||||
|
vn = VannaDefault(model=model, api_key=api_key)
|
||||||
|
|
||||||
|
db_type = tool_parameters.get("db_type", "")
|
||||||
|
if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]:
|
||||||
|
if not db_name:
|
||||||
|
return self.create_text_message("Please input database name")
|
||||||
|
if not username:
|
||||||
|
return self.create_text_message("Please input username")
|
||||||
|
if port < 1:
|
||||||
|
return self.create_text_message("Please input port")
|
||||||
|
|
||||||
|
schema_sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS"
|
||||||
|
match db_type:
|
||||||
|
case "SQLite":
|
||||||
|
schema_sql = "SELECT type, sql FROM sqlite_master WHERE sql is not null"
|
||||||
|
vn.connect_to_sqlite(url)
|
||||||
|
case "Postgres":
|
||||||
|
vn.connect_to_postgres(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "DuckDB":
|
||||||
|
vn.connect_to_duckdb(url=url)
|
||||||
|
case "SQLServer":
|
||||||
|
vn.connect_to_mssql(url)
|
||||||
|
case "MySQL":
|
||||||
|
vn.connect_to_mysql(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "Oracle":
|
||||||
|
vn.connect_to_oracle(user=username, password=password, dsn=url)
|
||||||
|
case "Hive":
|
||||||
|
vn.connect_to_hive(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
case "ClickHouse":
|
||||||
|
vn.connect_to_clickhouse(host=url, dbname=db_name, user=username, password=password, port=port)
|
||||||
|
|
||||||
|
enable_training = tool_parameters.get("enable_training", False)
|
||||||
|
reset_training_data = tool_parameters.get("reset_training_data", False)
|
||||||
|
if enable_training:
|
||||||
|
if reset_training_data:
|
||||||
|
existing_training_data = vn.get_training_data()
|
||||||
|
if len(existing_training_data) > 0:
|
||||||
|
for _, training_data in existing_training_data.iterrows():
|
||||||
|
vn.remove_training_data(training_data["id"])
|
||||||
|
|
||||||
|
ddl = tool_parameters.get("ddl", "")
|
||||||
|
question = tool_parameters.get("question", "")
|
||||||
|
sql = tool_parameters.get("sql", "")
|
||||||
|
memos = tool_parameters.get("memos", "")
|
||||||
|
training_metadata = tool_parameters.get("training_metadata", False)
|
||||||
|
|
||||||
|
if training_metadata:
|
||||||
|
if db_type == "SQLite":
|
||||||
|
df_ddl = vn.run_sql(schema_sql)
|
||||||
|
for ddl in df_ddl["sql"].to_list():
|
||||||
|
vn.train(ddl=ddl)
|
||||||
|
else:
|
||||||
|
df_information_schema = vn.run_sql(schema_sql)
|
||||||
|
plan = vn.get_training_plan_generic(df_information_schema)
|
||||||
|
vn.train(plan=plan)
|
||||||
|
|
||||||
|
if ddl:
|
||||||
|
vn.train(ddl=ddl)
|
||||||
|
|
||||||
|
if sql:
|
||||||
|
if question:
|
||||||
|
vn.train(question=question, sql=sql)
|
||||||
|
else:
|
||||||
|
vn.train(sql=sql)
|
||||||
|
if memos:
|
||||||
|
vn.train(documentation=memos)
|
||||||
|
|
||||||
|
generate_chart = tool_parameters.get("generate_chart", True)
|
||||||
|
res = vn.ask(prompt, False, True, generate_chart)
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
if res is not None:
|
||||||
|
result.append(self.create_text_message(res[0]))
|
||||||
|
if len(res) > 1 and res[1] is not None:
|
||||||
|
result.append(self.create_text_message(res[1].to_markdown()))
|
||||||
|
if len(res) > 2 and res[2] is not None:
|
||||||
|
result.append(
|
||||||
|
self.create_blob_message(blob=res[2].to_image(format="svg"), meta={"mime_type": "image/svg+xml"})
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
from core.tools.provider.builtin.vanna.tools.vanna import VannaTool
|
||||||
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
|
||||||
|
|
||||||
|
class VannaProvider(BuiltinToolProviderController):
|
||||||
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
try:
|
||||||
|
VannaTool().fork_tool_runtime(
|
||||||
|
runtime={
|
||||||
|
"credentials": credentials,
|
||||||
|
}
|
||||||
|
).invoke(
|
||||||
|
user_id='',
|
||||||
|
tool_parameters={
|
||||||
|
"model": "chinook",
|
||||||
|
"db_type": "SQLite",
|
||||||
|
"url": "https://vanna.ai/Chinook.sqlite",
|
||||||
|
"query": "What are the top 10 customers by sales?"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise ToolProviderCredentialValidationError(str(e))
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
identity:
|
||||||
|
author: QCTC
|
||||||
|
name: vanna
|
||||||
|
label:
|
||||||
|
en_US: Vanna.AI
|
||||||
|
zh_Hans: Vanna.AI
|
||||||
|
description:
|
||||||
|
en_US: The fastest way to get actionable insights from your database just by asking questions.
|
||||||
|
zh_Hans: 一个基于大模型和RAG的Text2SQL工具。
|
||||||
|
icon: icon.png
|
||||||
|
credentials_for_provider:
|
||||||
|
api_key:
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
label:
|
||||||
|
en_US: API key
|
||||||
|
zh_Hans: API key
|
||||||
|
placeholder:
|
||||||
|
en_US: Please input your API key
|
||||||
|
zh_Hans: 请输入你的 API key
|
||||||
|
pt_BR: Please input your API key
|
||||||
|
help:
|
||||||
|
en_US: Get your API key from Vanna.AI
|
||||||
|
zh_Hans: 从 Vanna.AI 获取你的 API key
|
||||||
|
url: https://vanna.ai/account/profile
|
||||||
Loading…
Reference in New Issue