You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gcgj-dify-1.7.0/api/extensions/ext_vanna.py

229 lines
7.4 KiB
Python

from vanna.ollama import Ollama
from dotenv import load_dotenv
load_dotenv()
from functools import wraps
from flask import Flask, jsonify, Response, request
import flask
from extensions.storage.cache import MemoryCache
from dify_app import DifyApp
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient
from configs import dify_config
# --- Module-level setup -------------------------------------------------------
# Milvus client built from Dify's configured URI, switched to the "test"
# database right after it is created.
# NOTE(review): the database name is hard-coded; consider reading it from config.
# NOTE(review): no token/user/password is passed to MilvusClient; if the server
# requires authentication this connection will fail -- confirm.
milvus_uri = dify_config.MILVUS_URI
milvus_client = MilvusClient(uri=milvus_uri)
milvus_client.use_database("test")

# Cache shared by the Vanna HTTP endpoints that init_app() registers.
cache = MemoryCache()
class MyVanna(Milvus_VectorStore, Ollama):
    """Vanna agent that pairs a Milvus vector store with an Ollama-served LLM.

    Rather than relying on ``super()``, each base initialiser is called
    explicitly with the same ``config`` -- the vector store first, then the
    LLM backend.
    """

    def __init__(self, config=None):
        for base in (Milvus_VectorStore, Ollama):
            base.__init__(self, config=config)
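# NOTE(review): init_app's routes reference a module-level `vn`, but it is only
# constructed in the commented-out code below; define it before serving requests.
# vn = MyVanna(config={
# 'model': 'qwen2:7b', # local Ollama model name
# 'ollama_host':'http://wsd.wisdomidata.com:19042', # local Ollama service URL
# 'milvus_client': milvus_client, # Milvus client used as the vector store
# "n_results": 12,
# })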
# vn.connect_to_postgres(
# host=dify_config.DB_HOST,
# dbname='vanna_demo',
# user=dify_config.DB_USERNAME,
# password=dify_config.DB_PASSWORD,
# port=dify_config.DB_PORT
# )
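# NOTE(review): do not keep real host/user/password literals in source; read them
# from configuration, as the Postgres example above does with dify_config.
# vn.connect_to_mysql(
# host='122.51.104.137',
# port=33306,
# dbname='demo',
# user='sws',
# password='123456'
# )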
def init_app(app: DifyApp):
    """Register the Vanna text-to-SQL HTTP endpoints under ``/api/v0/`` on *app*.

    Every endpoint answers with a JSON object whose ``type`` field tells the
    client how to read it (``"sql"``, ``"df"``, ``"question_list"``, ...).
    Failures are reported in-band as ``{"type": "error", "error": <message>}``
    with the default 200 status rather than as an HTTP error status.

    The artefacts of one question -- its text, the generated SQL, the result
    DataFrame, the Plotly figure JSON and the follow-up questions -- are stored
    in the module-level ``cache`` under the id returned by
    ``/api/v0/generate_sql``; later endpoints look them up through the ``id``
    query parameter.

    NOTE(review): the routes call a module-level ``vn`` (a ``MyVanna``
    instance) whose construction is commented out in this file. Unless ``vn``
    is defined elsewhere, every endpoint that uses it answers with the error
    ``name 'vn' is not defined`` -- confirm where ``vn`` should be created.
    """
    import logging

    logger = logging.getLogger(__name__)

    def error_response(message):
        """Build the JSON error envelope shared by all endpoints."""
        return jsonify({"type": "error", "error": message})

    def json_errors(f):
        """Turn any exception escaping view *f* into a logged JSON error envelope."""

        @wraps(f)  # keeps f.__name__, which Flask uses as the endpoint name
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except Exception as e:
                # Endpoint boundary: keep the traceback server-side, send the
                # message to the client in the usual envelope.
                logger.exception("Vanna endpoint %s failed", f.__name__)
                return error_response(str(e))

        return wrapper

    def json_body():
        """Return the request's JSON object body, or ``{}`` if absent or invalid."""
        # silent=True: a missing/malformed body or non-JSON content type yields
        # None instead of raising, so callers can answer "No ... provided".
        payload = request.get_json(silent=True)
        return payload if isinstance(payload, dict) else {}

    def requires_cache(fields):
        """Inject ``question_id`` and the cached values of ``fields`` into a view.

        The id comes from the ``id`` query parameter. The view is not called
        (an error envelope is returned instead) if the id is missing or any
        requested field is absent from the cache.
        """

        def decorator(f):
            @wraps(f)
            def decorated(*args, **kwargs):
                question_id = request.args.get('id')
                if question_id is None:
                    return error_response("No id provided")
                field_values = {}
                for field in fields:
                    # One lookup per field; the first missing one is reported.
                    value = cache.get(id=question_id, field=field)
                    if value is None:
                        return error_response(f"No {field} found")
                    field_values[field] = value
                return f(*args, question_id=question_id, **field_values, **kwargs)

            return decorated

        return decorator

    @app.route('/api/v0/generate_questions', methods=['GET'])
    @json_errors
    def generate_questions():
        """Return the starter questions produced by ``vn.generate_questions()``."""
        return jsonify({
            "type": "question_list",
            "questions": vn.generate_questions(),
            "header": "Here are some questions you can ask:",
        })

    @app.route('/api/v0/generate_sql', methods=['GET'])
    @json_errors
    def generate_sql():
        """Generate SQL for the ``question`` query parameter and cache both."""
        question = request.args.get('question')
        if not question:
            return error_response("No question provided")
        question_id = cache.generate_id(question=question)
        sql = vn.generate_sql(question=question)
        cache.set(id=question_id, field='question', value=question)
        cache.set(id=question_id, field='sql', value=sql)
        return jsonify({"type": "sql", "id": question_id, "text": sql})

    @app.route('/api/v0/run_sql', methods=['GET'])
    @json_errors
    @requires_cache(['sql'])
    def run_sql(question_id: str, sql: str):
        """Run the cached SQL, cache the DataFrame and return its first 10 rows."""
        # NOTE(review): this executes LLM-generated SQL as-is; consider running
        # it with a read-only database account.
        df = vn.run_sql(sql=sql)
        cache.set(id=question_id, field='df', value=df)
        return jsonify({
            "type": "df",
            "id": question_id,
            "df": df.head(10).to_json(orient='records'),
        })

    @app.route('/api/v0/download_csv', methods=['GET'])
    @json_errors
    @requires_cache(['df'])
    def download_csv(question_id: str, df):
        """Return the cached DataFrame as a CSV attachment named ``<id>.csv``."""
        return Response(
            df.to_csv(),
            mimetype="text/csv",
            headers={"Content-disposition": f"attachment; filename={question_id}.csv"},
        )

    @app.route('/api/v0/generate_plotly_figure', methods=['GET'])
    @json_errors
    @requires_cache(['df', 'question', 'sql'])
    def generate_plotly_figure(question_id: str, df, question, sql):
        """Build a Plotly figure for the cached result and cache its JSON."""
        code = vn.generate_plotly_code(
            question=question,
            sql=sql,
            df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
        )
        # NOTE(review): get_plotly_figure presumably executes the generated
        # Python code -- confirm, and treat it as untrusted.
        fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
        fig_json = fig.to_json()
        cache.set(id=question_id, field='fig_json', value=fig_json)
        return jsonify({"type": "plotly_figure", "id": question_id, "fig": fig_json})

    @app.route('/api/v0/get_training_data', methods=['GET'])
    @json_errors
    def get_training_data():
        """Return the first 25 rows of the stored training data."""
        df = vn.get_training_data()
        return jsonify({
            "type": "df",
            "id": "training_data",
            "df": df.head(25).to_json(orient='records'),
        })

    @app.route('/api/v0/remove_training_data', methods=['POST'])
    @json_errors
    def remove_training_data():
        """Delete the training entry whose ``id`` is given in the JSON body."""
        training_id = json_body().get('id')
        if training_id is None:
            return error_response("No id provided")
        if vn.remove_training_data(id=training_id):
            return jsonify({"success": True})
        return error_response("Couldn't remove training data")

    @app.route('/api/v0/train', methods=['POST'])
    @json_errors
    def add_training_data():
        """Add training data from the JSON body (question/sql/ddl/documentation)."""
        body = json_body()
        training_id = vn.train(
            question=body.get('question'),
            sql=body.get('sql'),
            ddl=body.get('ddl'),
            documentation=body.get('documentation'),
        )
        return jsonify({"id": training_id})

    @app.route('/api/v0/generate_followup_questions', methods=['GET'])
    @json_errors
    @requires_cache(['df', 'question', 'sql'])
    def generate_followup_questions(question_id: str, df, question, sql):
        """Generate and cache follow-up questions for the cached result."""
        followup_questions = vn.generate_followup_questions(question=question, sql=sql, df=df)
        cache.set(id=question_id, field='followup_questions', value=followup_questions)
        return jsonify({
            "type": "question_list",
            "id": question_id,
            "questions": followup_questions,
            "header": "Here are some followup questions you can ask:",
        })

    @app.route('/api/v0/load_question', methods=['GET'])
    @json_errors
    @requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
    def load_question(question_id: str, question, sql, df, fig_json, followup_questions):
        """Return everything cached for one question in a single response."""
        return jsonify({
            "type": "question_cache",
            "id": question_id,
            "question": question,
            "sql": sql,
            "df": df.head(10).to_json(orient='records'),
            "fig": fig_json,
            "followup_questions": followup_questions,
        })

    @app.route('/api/v0/get_question_history', methods=['GET'])
    @json_errors
    def get_question_history():
        """Return the ``question`` field of every cached entry."""
        return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question'])})