You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
229 lines
7.4 KiB
Python
229 lines
7.4 KiB
Python
from vanna.ollama import Ollama
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
from functools import wraps
|
|
from flask import Flask, jsonify, Response, request
|
|
import flask
|
|
from extensions.storage.cache import MemoryCache
|
|
from dify_app import DifyApp
|
|
from vanna.milvus import Milvus_VectorStore
|
|
from pymilvus import MilvusClient
|
|
from configs import dify_config
|
|
# SETUP
# Cache shared by the route handlers; values are stored and read per (id, field).
cache = MemoryCache()

# The Milvus endpoint is taken from the Dify configuration object.
milvus_uri = dify_config.MILVUS_URI

# Module-level Milvus client, switched to the "test" database at import time.
milvus_client = MilvusClient(uri=milvus_uri)
milvus_client.use_database("test")
|
|
class MyVanna(Milvus_VectorStore, Ollama):
    """Vanna class built from the ``Milvus_VectorStore`` and ``Ollama`` bases."""

    def __init__(self, config=None):
        """Initialise both bases with the same ``config`` object.

        Args:
            config: Configuration passed unchanged to each base initialiser.
        """
        # Each base initialiser is invoked explicitly, vector store first.
        for base in (Milvus_VectorStore, Ollama):
            base.__init__(self, config=config)
|
|
|
|
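# NOTE(review): the route handlers in init_app call a module-level `vn`, but the
# only definition visible in this section is this commented-out block — confirm
# where `vn` is actually created before these routes are served.
# vn = MyVanna(config={
#     'model': 'qwen2:7b',  # name of the local Ollama model
#     'ollama_host': 'http://wsd.wisdomidata.com:19042',  # address of the local Ollama service
#     'milvus_client': milvus_client,  # client for the local Milvus vector database service
#     "n_results": 12,
# })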
|
|
# vn.connect_to_postgres(
|
|
# host=dify_config.DB_HOST,
|
|
# dbname='vanna_demo',
|
|
# user=dify_config.DB_USERNAME,
|
|
# password=dify_config.DB_PASSWORD,
|
|
# port=dify_config.DB_PORT
|
|
# )
|
|
# vn.connect_to_mysql(
|
|
# host='122.51.104.137',
|
|
# port=33306,
|
|
# dbname='demo',
|
|
# user='sws',
|
|
# password='123456'
|
|
# )
|
|
|
|
|
|
|
|
def init_app(app: DifyApp):
|
|
|
|
def requires_cache(fields):
    """Build a decorator that injects cached values into a route handler.

    The wrapped handler is called with one keyword argument per name in
    ``fields`` (the value read from ``cache`` for the request's ``id`` query
    parameter) plus ``id`` itself.

    Args:
        fields: Iterable of cache field names the handler needs.

    Returns:
        A decorator. The decorated view returns a JSON error payload instead
        of calling the handler when ``id`` is missing from the query string or
        when any requested field has no cached value.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            cache_id = request.args.get('id')

            if cache_id is None:
                return jsonify({"type": "error", "error": "No id provided"})

            # Read every field exactly once: the value that passed the
            # "is it cached?" check is the same value handed to the handler,
            # so a field can never reach the handler as None.
            field_values = {}
            for field in fields:
                value = cache.get(id=cache_id, field=field)
                if value is None:
                    return jsonify({"type": "error", "error": f"No {field} found"})
                field_values[field] = value

            # Handlers receive the id under the keyword name ``id``.
            field_values['id'] = cache_id

            return f(*args, **field_values, **kwargs)

        return decorated

    return decorator
|
|
|
|
@app.route('/api/v0/generate_questions', methods=['GET'])
def generate_questions():
    """Return the list produced by ``vn.generate_questions()`` as JSON."""
    suggestions = vn.generate_questions()
    payload = {
        "type": "question_list",
        "questions": suggestions,
        "header": "Here are some questions you can ask:",
    }
    return jsonify(payload)
|
|
|
|
@app.route('/api/v0/generate_sql', methods=['GET'])
def generate_sql():
    """Generate SQL for the ``question`` query parameter and cache the pair.

    Returns:
        A JSON payload of type ``sql`` carrying the cache id and the generated
        SQL text, or a JSON error payload when no question was supplied.
    """
    question = flask.request.args.get('question')
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})

    # The cache id is obtained before the SQL is generated.
    question_id = cache.generate_id(question=question)
    generated_sql = vn.generate_sql(question=question)

    # Store the question first, then the SQL, under the same id.
    for field, value in (('question', question), ('sql', generated_sql)):
        cache.set(id=question_id, field=field, value=value)

    payload = {
        "type": "sql",
        "id": question_id,
        "text": generated_sql,
    }
    return jsonify(payload)
|
|
|
|
@app.route('/api/v0/run_sql', methods=['GET'])
@requires_cache(['sql'])
def run_sql(id: str, sql: str):
    """Run the cached SQL through ``vn.run_sql`` and return a preview.

    Args:
        id: Cache id supplied by ``requires_cache``.
        sql: Cached SQL text for that id.

    Returns:
        A JSON payload of type ``df`` with the first 10 rows serialised as
        records, or a JSON error payload carrying the exception text.
    """
    try:
        result = vn.run_sql(sql=sql)

        # The full result is cached; only a 10-row preview is returned.
        cache.set(id=id, field='df', value=result)
        preview = result.head(10).to_json(orient='records')

        return jsonify({"type": "df", "id": id, "df": preview})
    except Exception as e:
        return jsonify({"type": "error", "error": str(e)})
|
|
|
|
@app.route('/api/v0/download_csv', methods=['GET'])
@requires_cache(['df'])
def download_csv(id: str, df):
    """Serve the cached result for ``id`` as a CSV attachment.

    Args:
        id: Cache id supplied by ``requires_cache``; used in the file name.
        df: Cached object for that id; its ``to_csv()`` output is the body.
    """
    csv_text = df.to_csv()
    disposition = f"attachment; filename={id}.csv"
    return Response(
        csv_text,
        mimetype="text/csv",
        headers={"Content-disposition": disposition},
    )
|
|
|
|
@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_plotly_figure(id: str, df, question, sql):
    """Generate a Plotly figure for the cached question and cache its JSON.

    Args:
        id: Cache id supplied by ``requires_cache``.
        df: Cached result for that id.
        question: Cached question text.
        sql: Cached SQL text.

    Returns:
        A JSON payload of type ``plotly_figure`` carrying the figure JSON, or
        a JSON error payload carrying the exception text.
    """
    try:
        dtype_summary = f"Running df.dtypes gives:\n {df.dtypes}"
        plotly_code = vn.generate_plotly_code(
            question=question,
            sql=sql,
            df_metadata=dtype_summary,
        )
        figure = vn.get_plotly_figure(plotly_code=plotly_code, df=df, dark_mode=False)
        fig_json = figure.to_json()

        cache.set(id=id, field='fig_json', value=fig_json)

        payload = {
            "type": "plotly_figure",
            "id": id,
            "fig": fig_json,
        }
        return jsonify(payload)
    except Exception as e:
        # Dump the stack trace to stderr before reporting the error to the client.
        import traceback
        traceback.print_exc()

        return jsonify({"type": "error", "error": str(e)})
|
|
|
|
@app.route('/api/v0/get_training_data', methods=['GET'])
def get_training_data():
    """Return the first 25 rows of ``vn.get_training_data()`` as JSON records."""
    training_df = vn.get_training_data()
    preview = training_df.head(25).to_json(orient='records')
    payload = {
        "type": "df",
        "id": "training_data",
        "df": preview,
    }
    return jsonify(payload)
|
|
|
|
@app.route('/api/v0/remove_training_data', methods=['POST'])
def remove_training_data():
    """Remove one training-data entry identified by ``id`` in the JSON body.

    Returns:
        ``{"success": true}`` when ``vn.remove_training_data`` reports success;
        otherwise a JSON error payload. A missing, malformed, or non-object
        JSON body is reported as "No id provided" rather than failing before
        the handler's own validation runs.
    """
    # silent=True yields None instead of raising when the body is absent or
    # is not valid JSON, so the check below can answer in the API's own format.
    body = flask.request.get_json(silent=True)
    training_id = body.get('id') if isinstance(body, dict) else None

    if training_id is None:
        return jsonify({"type": "error", "error": "No id provided"})

    if vn.remove_training_data(id=training_id):
        return jsonify({"success": True})

    return jsonify({"type": "error", "error": "Couldn't remove training data"})
|
|
|
|
@app.route('/api/v0/train', methods=['POST'])
def add_training_data():
    """Pass the training fields from the JSON body to ``vn.train``.

    The body keys ``question``, ``sql``, ``ddl`` and ``documentation`` are
    read (absent keys become ``None``) and passed as keyword arguments.

    Returns:
        A JSON payload with the id returned by ``vn.train``, or a JSON error
        payload carrying the exception text when training raises.
    """
    body = flask.request.json
    training_kwargs = {
        name: body.get(name)
        for name in ('question', 'sql', 'ddl', 'documentation')
    }

    try:
        new_id = vn.train(**training_kwargs)
        return jsonify({"id": new_id})
    except Exception as e:
        print("TRAINING ERROR", e)
        return jsonify({"type": "error", "error": str(e)})
|
|
|
|
@app.route('/api/v0/generate_followup_questions', methods=['GET'])
@requires_cache(['df', 'question', 'sql'])
def generate_followup_questions(id: str, df, question, sql):
    """Generate follow-up questions for a cached question and cache them.

    Args:
        id: Cache id supplied by ``requires_cache``.
        df: Cached result for that id.
        question: Cached question text.
        sql: Cached SQL text.

    Returns:
        A JSON payload of type ``question_list`` with the follow-up questions.
    """
    suggestions = vn.generate_followup_questions(question=question, sql=sql, df=df)

    cache.set(id=id, field='followup_questions', value=suggestions)

    payload = {
        "type": "question_list",
        "id": id,
        "questions": suggestions,
        "header": "Here are some followup questions you can ask:",
    }
    return jsonify(payload)
|
|
|
|
@app.route('/api/v0/load_question', methods=['GET'])
@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
def load_question(id: str, question, sql, df, fig_json, followup_questions):
    """Return everything cached for ``id`` in a single JSON payload.

    All arguments are supplied by ``requires_cache``. The result is reduced to
    a 10-row preview serialised as records; the other values are returned as
    cached.

    Returns:
        A JSON payload of type ``question_cache``, or a JSON error payload
        carrying the exception text if building the response raises.
    """
    try:
        preview = df.head(10).to_json(orient='records')
        payload = {
            "type": "question_cache",
            "id": id,
            "question": question,
            "sql": sql,
            "df": preview,
            "fig": fig_json,
            "followup_questions": followup_questions,
        }
        return jsonify(payload)
    except Exception as e:
        return jsonify({"type": "error", "error": str(e)})
|
|
|
|
@app.route('/api/v0/get_question_history', methods=['GET'])
def get_question_history():
    """Return the ``question`` field of every cache entry as JSON."""
    history = cache.get_all(field_list=['question'])
    payload = {"type": "question_history", "questions": history}
    return jsonify(payload)
|