commit 99ffe43e91 ("merge main")
@@ -0,0 +1,38 @@
model: gemini-exp-1114
label:
  en_US: Gemini exp 1114
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 2097152
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,309 @@
import json
from typing import Any

from pydantic import BaseModel, model_validator

_import_err_msg = (
    "`alibabacloud_gpdb20160503` and `alibabacloud_tea_openapi` packages not found, "
    "please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorOpenAPIConfig(BaseModel):
    access_key_id: str
    access_key_secret: str
    region_id: str
    instance_id: str
    account: str
    account_password: str
    namespace: str = "dify"
    namespace_password: str | None = None
    metrics: str = "cosine"
    read_timeout: int = 60000

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["access_key_id"]:
            raise ValueError("config ANALYTICDB_KEY_ID is required")
        if not values["access_key_secret"]:
            raise ValueError("config ANALYTICDB_KEY_SECRET is required")
        if not values["region_id"]:
            raise ValueError("config ANALYTICDB_REGION_ID is required")
        if not values["instance_id"]:
            raise ValueError("config ANALYTICDB_INSTANCE_ID is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["namespace_password"]:
            raise ValueError("config ANALYTICDB_NAMESPACE_PASSWORD is required")
        return values

    def to_analyticdb_client_params(self):
        return {
            "access_key_id": self.access_key_id,
            "access_key_secret": self.access_key_secret,
            "region_id": self.region_id,
            "read_timeout": self.read_timeout,
        }


class AnalyticdbVectorOpenAPI:
    def __init__(self, collection_name: str, config: AnalyticdbVectorOpenAPIConfig):
        try:
            from alibabacloud_gpdb20160503.client import Client
            from alibabacloud_tea_openapi import models as open_api_models
        except ImportError:
            raise ImportError(_import_err_msg)
        self._collection_name = collection_name.lower()
        self.config = config
        self._client_config = open_api_models.Config(user_agent="dify", **config.to_analyticdb_client_params())
        self._client = Client(self._client_config)
        self._initialize()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.instance_id}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.instance_id}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            self._create_namespace_if_not_exists()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _initialize_vector_database(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.InitVectorDatabaseRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            manager_account=self.config.account,
            manager_account_password=self.config.account_password,
        )
        self._client.init_vector_database(request)

    def _create_namespace_if_not_exists(self) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        try:
            request = gpdb_20160503_models.DescribeNamespaceRequest(
                dbinstance_id=self.config.instance_id,
                region_id=self.config.region_id,
                namespace=self.config.namespace,
                manager_account=self.config.account,
                manager_account_password=self.config.account_password,
            )
            self._client.describe_namespace(request)
        except TeaException as e:
            if e.statusCode == 404:
                request = gpdb_20160503_models.CreateNamespaceRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    manager_account=self.config.account,
                    manager_account_password=self.config.account_password,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                )
                self._client.create_namespace(request)
            else:
                raise ValueError(f"failed to create namespace {self.config.namespace}: {e}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
        from Tea.exceptions import TeaException

        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            try:
                request = gpdb_20160503_models.DescribeCollectionRequest(
                    dbinstance_id=self.config.instance_id,
                    region_id=self.config.region_id,
                    namespace=self.config.namespace,
                    namespace_password=self.config.namespace_password,
                    collection=self._collection_name,
                )
                self._client.describe_collection(request)
            except TeaException as e:
                if e.statusCode == 404:
                    metadata = '{"ref_doc_id":"text","page_content":"text","metadata_":"jsonb"}'
                    full_text_retrieval_fields = "page_content"
                    request = gpdb_20160503_models.CreateCollectionRequest(
                        dbinstance_id=self.config.instance_id,
                        region_id=self.config.region_id,
                        manager_account=self.config.account,
                        manager_account_password=self.config.account_password,
                        namespace=self.config.namespace,
                        collection=self._collection_name,
                        dimension=embedding_dimension,
                        metrics=self.config.metrics,
                        metadata=metadata,
                        full_text_retrieval_fields=full_text_retrieval_fields,
                    )
                    self._client.create_collection(request)
                else:
                    raise ValueError(f"failed to create collection {self._collection_name}: {e}")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        rows: list[gpdb_20160503_models.UpsertCollectionDataRequestRows] = []
        for doc, embedding in zip(documents, embeddings, strict=True):
            metadata = {
                "ref_doc_id": doc.metadata["doc_id"],
                "page_content": doc.page_content,
                "metadata_": json.dumps(doc.metadata),
            }
            rows.append(
                gpdb_20160503_models.UpsertCollectionDataRequestRows(
                    vector=embedding,
                    metadata=metadata,
                )
            )
        request = gpdb_20160503_models.UpsertCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            rows=rows,
        )
        self._client.upsert_collection_data(request)

    def text_exists(self, id: str) -> bool:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            metrics=self.config.metrics,
            include_values=True,
            vector=None,
            content=None,
            top_k=1,
            filter=f"ref_doc_id='{id}'",
        )
        response = self._client.query_collection_data(request)
        return len(response.body.matches.match) > 0

    def delete_by_ids(self, ids: list[str]) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        ids_str = ",".join(f"'{id}'" for id in ids)
        ids_str = f"({ids_str})"
        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"ref_doc_id IN {ids_str}",
        )
        self._client.delete_collection_data(request)

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        request = gpdb_20160503_models.DeleteCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            collection_data=None,
            collection_data_filter=f"metadata_ ->> '{key}' = '{value}'",
        )
        self._client.delete_collection_data(request)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = kwargs.get("score_threshold") or 0.0
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=query_vector,
            content=None,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        request = gpdb_20160503_models.QueryCollectionDataRequest(
            dbinstance_id=self.config.instance_id,
            region_id=self.config.region_id,
            namespace=self.config.namespace,
            namespace_password=self.config.namespace_password,
            collection=self._collection_name,
            include_values=kwargs.pop("include_values", True),
            metrics=self.config.metrics,
            vector=None,
            content=query,
            top_k=kwargs.get("top_k", 4),
            filter=None,
        )
        response = self._client.query_collection_data(request)
        documents = []
        for match in response.body.matches.match:
            if match.score > score_threshold:
                metadata = json.loads(match.metadata.get("metadata_"))
                metadata["score"] = match.score
                doc = Document(
                    page_content=match.metadata.get("page_content"),
                    vector=match.values.value,
                    metadata=metadata,
                )
                documents.append(doc)
        documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
        return documents

    def delete(self) -> None:
        try:
            from alibabacloud_gpdb20160503 import models as gpdb_20160503_models

            request = gpdb_20160503_models.DeleteCollectionRequest(
                collection=self._collection_name,
                dbinstance_id=self.config.instance_id,
                namespace=self.config.namespace,
                namespace_password=self.config.namespace_password,
                region_id=self.config.region_id,
            )
            self._client.delete_collection(request)
        except Exception as e:
            raise e
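A minimal usage sketch for the OpenAPI-backed store above (credential values, region, collection name, and the 768-dimensional embedding are placeholders, not taken from the diff; a real AnalyticDB instance and Redis are required):

    config = AnalyticdbVectorOpenAPIConfig(
        access_key_id="<access-key-id>",
        access_key_secret="<access-key-secret>",
        region_id="cn-hangzhou",
        instance_id="gp-xxxxxxxx",
        account="admin",
        account_password="<account-password>",
        namespace_password="<namespace-password>",
    )
    store = AnalyticdbVectorOpenAPI("my_collection", config)          # initializes the vector DB and namespace
    store._create_collection_if_not_exists(embedding_dimension=768)   # collection is created lazily
    store.add_texts(
        [Document(page_content="hello analyticdb", metadata={"doc_id": "doc-1"})],
        embeddings=[[0.1] * 768],
    )
    hits = store.search_by_vector([0.1] * 768, top_k=4)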
@@ -0,0 +1,245 @@
import json
import uuid
from contextlib import contextmanager
from typing import Any

import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, model_validator

from core.rag.models.document import Document
from extensions.ext_redis import redis_client


class AnalyticdbVectorBySqlConfig(BaseModel):
    host: str
    port: int
    account: str
    account_password: str
    min_connection: int
    max_connection: int
    namespace: str = "dify"
    metrics: str = "cosine"

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["host"]:
            raise ValueError("config ANALYTICDB_HOST is required")
        if not values["port"]:
            raise ValueError("config ANALYTICDB_PORT is required")
        if not values["account"]:
            raise ValueError("config ANALYTICDB_ACCOUNT is required")
        if not values["account_password"]:
            raise ValueError("config ANALYTICDB_PASSWORD is required")
        if not values["min_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION is required")
        if not values["max_connection"]:
            raise ValueError("config ANALYTICDB_MAX_CONNECTION is required")
        if values["min_connection"] > values["max_connection"]:
            raise ValueError("config ANALYTICDB_MIN_CONNECTION should be less than ANALYTICDB_MAX_CONNECTION")
        return values


class AnalyticdbVectorBySql:
    def __init__(self, collection_name: str, config: AnalyticdbVectorBySqlConfig):
        self._collection_name = collection_name.lower()
        self.databaseName = "knowledgebase"
        self.config = config
        self.table_name = f"{self.config.namespace}.{self._collection_name}"
        self.pool = None
        self._initialize()
        if not self.pool:
            self.pool = self._create_connection_pool()

    def _initialize(self) -> None:
        cache_key = f"vector_initialize_{self.config.host}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            database_exist_cache_key = f"vector_initialize_{self.config.host}"
            if redis_client.get(database_exist_cache_key):
                return
            self._initialize_vector_database()
            redis_client.set(database_exist_cache_key, 1, ex=3600)

    def _create_connection_pool(self):
        return psycopg2.pool.SimpleConnectionPool(
            self.config.min_connection,
            self.config.max_connection,
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database=self.databaseName,
        )

    @contextmanager
    def _get_cursor(self):
        conn = self.pool.getconn()
        cur = conn.cursor()
        try:
            yield cur
        finally:
            cur.close()
            conn.commit()
            self.pool.putconn(conn)

    def _initialize_vector_database(self) -> None:
        conn = psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            user=self.config.account,
            password=self.config.account_password,
            database="postgres",
        )
        conn.autocommit = True
        cur = conn.cursor()
        try:
            cur.execute(f"CREATE DATABASE {self.databaseName}")
        except Exception as e:
            if "already exists" in str(e):
                return
            raise e
        finally:
            cur.close()
            conn.close()
        self.pool = self._create_connection_pool()
        with self._get_cursor() as cur:
            try:
                cur.execute("CREATE TEXT SEARCH CONFIGURATION zh_cn (PARSER = zhparser)")
                cur.execute("ALTER TEXT SEARCH CONFIGURATION zh_cn ADD MAPPING FOR n,v,a,i,e,l,x WITH simple")
            except Exception as e:
                if "already exists" not in str(e):
                    raise e
            cur.execute(
                "CREATE OR REPLACE FUNCTION "
                "public.to_tsquery_from_text(txt text, lang regconfig DEFAULT 'english'::regconfig) "
                "RETURNS tsquery LANGUAGE sql IMMUTABLE STRICT AS $function$ "
                "SELECT to_tsquery(lang, COALESCE(string_agg(split_part(word, ':', 1), ' | '), '')) "
                "FROM (SELECT unnest(string_to_array(to_tsvector(lang, txt)::text, ' ')) AS word) "
                "AS words_only;$function$"
            )
            cur.execute(f"CREATE SCHEMA IF NOT EXISTS {self.config.namespace}")

    def _create_collection_if_not_exists(self, embedding_dimension: int):
        cache_key = f"vector_indexing_{self._collection_name}"
        lock_name = f"{cache_key}_lock"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return
            with self._get_cursor() as cur:
                cur.execute(
                    f"CREATE TABLE IF NOT EXISTS {self.table_name}("
                    f"id text PRIMARY KEY,"
                    f"vector real[], ref_doc_id text, page_content text, metadata_ jsonb, "
                    f"to_tsvector TSVECTOR"
                    f") WITH (fillfactor=70) DISTRIBUTED BY (id);"
                )
                if embedding_dimension is not None:
                    index_name = f"{self._collection_name}_embedding_idx"
                    cur.execute(f"ALTER TABLE {self.table_name} ALTER COLUMN vector SET STORAGE PLAIN")
                    cur.execute(
                        f"CREATE INDEX {index_name} ON {self.table_name} USING ann(vector) "
                        f"WITH(dim='{embedding_dimension}', distancemeasure='{self.config.metrics}', "
                        f"pq_enable=0, external_storage=0)"
                    )
                    cur.execute(f"CREATE INDEX ON {self.table_name} USING gin(to_tsvector)")
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        values = []
        id_prefix = str(uuid.uuid4()) + "_"
        sql = f"""
        INSERT INTO {self.table_name}
        (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
        VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
        """
        for i, doc in enumerate(documents):
            values.append(
                (
                    id_prefix + str(i),
                    doc.metadata.get("doc_id", str(uuid.uuid4())),
                    embeddings[i],
                    doc.page_content,
                    json.dumps(doc.metadata),
                    doc.page_content,
                )
            )
        with self._get_cursor() as cur:
            psycopg2.extras.execute_batch(cur, sql, values)

    def text_exists(self, id: str) -> bool:
        with self._get_cursor() as cur:
            cur.execute(f"SELECT id FROM {self.table_name} WHERE ref_doc_id = %s", (id,))
            return cur.fetchone() is not None

    def delete_by_ids(self, ids: list[str]) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE ref_doc_id IN %s", (tuple(ids),))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        with self._get_cursor() as cur:
            try:
                cur.execute(f"DELETE FROM {self.table_name} WHERE metadata_->>%s = %s", (key, value))
            except Exception as e:
                if "does not exist" not in str(e):
                    raise e

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        with self._get_cursor() as cur:
            query_vector_str = json.dumps(query_vector)
            query_vector_str = "{" + query_vector_str[1:-1] + "}"
            cur.execute(
                f"SELECT t.id AS id, t.vector AS vector, (1.0 - t.score) AS score, "
                f"t.page_content as page_content, t.metadata_ AS metadata_ "
                f"FROM (SELECT id, vector, page_content, metadata_, vector <=> %s AS score "
                f"FROM {self.table_name} ORDER BY score LIMIT {top_k} ) t",
                (query_vector_str,),
            )
            documents = []
            for record in cur:
                id, vector, score, page_content, metadata = record
                if score > score_threshold:
                    metadata["score"] = score
                    doc = Document(
                        page_content=page_content,
                        vector=vector,
                        metadata=metadata,
                    )
                    documents.append(doc)
        return documents

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        top_k = kwargs.get("top_k", 4)
        with self._get_cursor() as cur:
            cur.execute(
                f"""SELECT id, vector, page_content, metadata_,
                ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
                FROM {self.table_name}
                WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn')
                ORDER BY score DESC
                LIMIT {top_k}""",
                (f"'{query}'", f"'{query}'"),
            )
            documents = []
            for record in cur:
                id, vector, page_content, metadata, score = record
                metadata["score"] = score
                doc = Document(
                    page_content=page_content,
                    vector=vector,
                    metadata=metadata,
                )
                documents.append(doc)
        return documents

    def delete(self) -> None:
        with self._get_cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
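A parallel sketch for the SQL-backed variant (host, port, credentials, and the embedding dimension are placeholders; the class connects via psycopg2 and creates the "knowledgebase" database and the configured schema on first use, so a reachable AnalyticDB/PostgreSQL endpoint and Redis are assumed):

    config = AnalyticdbVectorBySqlConfig(
        host="127.0.0.1",
        port=5432,
        account="<account>",
        account_password="<password>",
        min_connection=1,
        max_connection=5,
    )
    store = AnalyticdbVectorBySql("my_collection", config)
    store._create_collection_if_not_exists(embedding_dimension=768)
    store.add_texts(
        [Document(page_content="你好,AnalyticDB", metadata={"doc_id": "doc-1"})],
        [[0.1] * 768],
    )
    print(store.text_exists("doc-1"))
    results = store.search_by_full_text("AnalyticDB", top_k=4)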
@@ -0,0 +1,87 @@
from typing import Any

from duckduckgo_search import DDGS

from core.model_runtime.entities.message_entities import SystemPromptMessage
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

SUMMARY_PROMPT = """
User's query:
{query}

Here are the news results:
{content}

Please summarize the news in a few sentences.
"""


class DuckDuckGoNewsSearchTool(BuiltinTool):
    """
    Tool for performing a news search using DuckDuckGo search engine.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "timelimit": tool_parameters.get("timelimit"),
            "max_results": tool_parameters.get("max_results"),
            "safesearch": "moderate",
            "region": "wt-wt",
        }
        try:
            response = list(DDGS().news(**query_dict))
            if not response:
                return [self.create_text_message("No news found matching your criteria.")]
        except Exception as e:
            return [self.create_text_message(f"Error searching news: {str(e)}")]

        require_summary = tool_parameters.get("require_summary", False)

        if require_summary:
            results = "\n".join([f"{res.get('title')}: {res.get('body')}" for res in response])
            results = self.summary_results(user_id=user_id, content=results, query=query_dict["keywords"])
            return self.create_text_message(text=results)

        # Create rich markdown content for each news item
        markdown_result = "\n\n"
        json_result = []

        for res in response:
            markdown_result += f"### {res.get('title', 'Untitled')}\n\n"
            if res.get("date"):
                markdown_result += f"**Date:** {res.get('date')}\n\n"
            if res.get("body"):
                markdown_result += f"{res.get('body')}\n\n"
            if res.get("source"):
                markdown_result += f"*Source: {res.get('source')}*\n\n"
            if res.get("image"):
                markdown_result += f"![{res.get('title', '')}]({res.get('image')})\n\n"
            markdown_result += f"[Read more]({res.get('url', '')})\n\n---\n\n"

            json_result.append(
                self.create_json_message(
                    {
                        "title": res.get("title", ""),
                        "date": res.get("date", ""),
                        "body": res.get("body", ""),
                        "url": res.get("url", ""),
                        "image": res.get("image", ""),
                        "source": res.get("source", ""),
                    }
                )
            )

        return [self.create_text_message(markdown_result)] + json_result

    def summary_results(self, user_id: str, content: str, query: str) -> str:
        prompt = SUMMARY_PROMPT.format(query=query, content=content)
        summary = self.invoke_model(
            user_id=user_id,
            prompt_messages=[
                SystemPromptMessage(content=prompt),
            ],
            stop=[],
        )
        return summary.message.content
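The tool builds its query_dict and hands it straight to DDGS().news(); the same call can be sketched standalone (the query string and result count are illustrative, not taken from the diff):

    from duckduckgo_search import DDGS

    results = list(
        DDGS().news(
            keywords="open source LLMs",   # example query
            timelimit=None,                # passed through from the tool's timelimit parameter
            max_results=5,
            safesearch="moderate",
            region="wt-wt",
        )
    )
    for res in results:
        print(res.get("date"), res.get("title"), res.get("url"))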
@@ -0,0 +1,71 @@
identity:
  name: ddgo_news
  author: Assistant
  label:
    en_US: DuckDuckGo News Search
    zh_Hans: DuckDuckGo 新闻搜索
description:
  human:
    en_US: Perform news searches on DuckDuckGo and get results.
    zh_Hans: 在 DuckDuckGo 上进行新闻搜索并获取结果。
  llm: Perform news searches on DuckDuckGo and get results.
parameters:
  - name: query
    type: string
    required: true
    label:
      en_US: Query String
      zh_Hans: 查询语句
    human_description:
      en_US: Search Query.
      zh_Hans: 搜索查询语句。
    llm_description: Key words for searching
    form: llm
  - name: max_results
    type: number
    required: true
    default: 5
    label:
      en_US: Max Results
      zh_Hans: 最大结果数量
    human_description:
      en_US: The maximum number of results to return.
      zh_Hans: 最大结果数量
    form: form
  - name: timelimit
    type: select
    required: false
    options:
      - value: Day
        label:
          en_US: Current Day
          zh_Hans: 当天
      - value: Week
        label:
          en_US: Current Week
          zh_Hans: 本周
      - value: Month
        label:
          en_US: Current Month
          zh_Hans: 当月
      - value: Year
        label:
          en_US: Current Year
          zh_Hans: 今年
    label:
      en_US: Result Time Limit
      zh_Hans: 结果时间限制
    human_description:
      en_US: Use when querying results within a specific time range only.
      zh_Hans: 只查询一定时间范围内的结果时使用
    form: form
  - name: require_summary
    type: boolean
    default: false
    label:
      en_US: Require Summary
      zh_Hans: 是否总结
    human_description:
      en_US: Whether to pass the news results to llm for summarization.
      zh_Hans: 是否需要将新闻结果传给大模型总结
    form: form
@@ -0,0 +1,75 @@
from typing import Any, ClassVar

from duckduckgo_search import DDGS

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class DuckDuckGoVideoSearchTool(BuiltinTool):
    """
    Tool for performing a video search using DuckDuckGo search engine.
    """

    IFRAME_TEMPLATE: ClassVar[str] = """
    <div style="position: relative; padding-bottom: 56.25%; height: 0; overflow: hidden; \
max-width: 100%; border-radius: 8px;">
        <iframe
            style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;"
            src="{src}"
            frameborder="0"
            allowfullscreen>
        </iframe>
    </div>"""

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        query_dict = {
            "keywords": tool_parameters.get("query"),
            "region": tool_parameters.get("region", "wt-wt"),
            "safesearch": tool_parameters.get("safesearch", "moderate"),
            "timelimit": tool_parameters.get("timelimit"),
            "resolution": tool_parameters.get("resolution"),
            "duration": tool_parameters.get("duration"),
            "license_videos": tool_parameters.get("license_videos"),
            "max_results": tool_parameters.get("max_results"),
        }

        # Remove None values to use API defaults
        query_dict = {k: v for k, v in query_dict.items() if v is not None}

        # Get proxy URL from parameters
        proxy_url = tool_parameters.get("proxy_url", "").strip()

        response = DDGS().videos(**query_dict)

        # Create HTML result with embedded iframes
        markdown_result = "\n\n"
        json_result = []

        for res in response:
            title = res.get("title", "")
            embed_html = res.get("embed_html", "")
            description = res.get("description", "")
            content_url = res.get("content", "")

            # Handle TED.com videos
            if not embed_html and "ted.com/talks" in content_url:
                embed_url = content_url.replace("www.ted.com", "embed.ted.com")
                if proxy_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            # Original YouTube/other platform handling
            elif embed_html:
                embed_url = res.get("embed_url", "")
                if proxy_url and embed_url:
                    embed_url = f"{proxy_url}{embed_url}"
                embed_html = self.IFRAME_TEMPLATE.format(src=embed_url)

            markdown_result += f"{title}\n\n"
            markdown_result += f"{embed_html}\n\n"
            markdown_result += "---\n\n"

            json_result.append(self.create_json_message(res))

        return [self.create_text_message(markdown_result)] + json_result
@@ -0,0 +1,25 @@
from typing import Any, Union

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class GiteeAIToolEmbedding(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {self.runtime.credentials['api_key']}",
        }

        payload = {"inputs": tool_parameters.get("inputs")}
        model = tool_parameters.get("model", "bge-m3")
        url = f"https://ai.gitee.com/api/serverless/{model}/embeddings"
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            return self.create_text_message(f"Got Error Response:{response.text}")

        return [self.create_text_message(response.content.decode("utf-8"))]
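The tool above is a thin wrapper around a single HTTP call; an equivalent standalone request looks like this (the API key and input text are placeholders, and the endpoint is the one built by the tool):

    import requests

    model = "bge-m3"
    response = requests.post(
        f"https://ai.gitee.com/api/serverless/{model}/embeddings",
        headers={
            "content-type": "application/json",
            "authorization": "Bearer <gitee-ai-api-key>",
        },
        json={"inputs": "text to embed"},
    )
    print(response.status_code, response.text)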
@@ -0,0 +1,37 @@
identity:
  name: embedding
  author: gitee_ai
  label:
    en_US: embedding
  icon: icon.svg
description:
  human:
    en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI)
  llm: This tool is used to generate word embeddings from text input.
parameters:
  - name: model
    type: string
    required: true
    in: path
    description:
      en_US: Supported Embedding (compatible with OpenAI) interface models
    enum:
      - bge-m3
      - bge-large-zh-v1.5
      - bge-small-zh-v1.5
    label:
      en_US: Service Model
      zh_Hans: 服务模型
    default: bge-m3
    form: form
  - name: inputs
    type: string
    required: true
    label:
      en_US: Input Text
      zh_Hans: 输入文本
    human_description:
      en_US: The text input used to generate embeddings.
      zh_Hans: 用于生成词向量的输入文本。
    llm_description: This text input will be used to generate embeddings.
    form: llm
@@ -0,0 +1,145 @@
from typing import Any

import requests

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

TAVILY_API_URL = "https://api.tavily.com"


class TavilyExtract:
    """
    A class for extracting content from web pages using the Tavily Extract API.

    Args:
        api_key (str): The API key for accessing the Tavily Extract API.

    Methods:
        extract_content: Retrieves extracted content from the Tavily Extract API.
    """

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key

    def extract_content(self, params: dict[str, Any]) -> dict:
        """
        Retrieves extracted content from the Tavily Extract API.

        Args:
            params (Dict[str, Any]): The extraction parameters.

        Returns:
            dict: The extracted content.
        """
        # Ensure required parameters are set
        if "api_key" not in params:
            params["api_key"] = self.api_key

        # Process parameters
        processed_params = self._process_params(params)

        response = requests.post(f"{TAVILY_API_URL}/extract", json=processed_params)
        response.raise_for_status()
        return response.json()

    def _process_params(self, params: dict[str, Any]) -> dict:
        """
        Processes and validates the extraction parameters.

        Args:
            params (Dict[str, Any]): The extraction parameters.

        Returns:
            dict: The processed parameters.
        """
        processed_params = {}

        # Process 'urls'
        if "urls" in params:
            urls = params["urls"]
            if isinstance(urls, str):
                processed_params["urls"] = [url.strip() for url in urls.replace(",", " ").split()]
            elif isinstance(urls, list):
                processed_params["urls"] = urls
        else:
            raise ValueError("The 'urls' parameter is required.")

        # Only include 'api_key'
        processed_params["api_key"] = params.get("api_key", self.api_key)

        return processed_params


class TavilyExtractTool(BuiltinTool):
    """
    A tool for extracting content from web pages using Tavily Extract.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Invokes the Tavily Extract tool with the given user ID and tool parameters.

        Args:
            user_id (str): The ID of the user invoking the tool.
            tool_parameters (Dict[str, Any]): The parameters for the Tavily Extract tool.

        Returns:
            ToolInvokeMessage | list[ToolInvokeMessage]: The result of the Tavily Extract tool invocation.
        """
        urls = tool_parameters.get("urls", "")
        api_key = self.runtime.credentials.get("tavily_api_key")
        if not api_key:
            return self.create_text_message(
                "Tavily API key is missing. Please set the 'tavily_api_key' in credentials."
            )
        if not urls:
            return self.create_text_message("Please input at least one URL to extract.")

        tavily_extract = TavilyExtract(api_key)
        try:
            raw_results = tavily_extract.extract_content(tool_parameters)
        except requests.HTTPError as e:
            return self.create_text_message(f"Error occurred while extracting content: {str(e)}")

        if not raw_results.get("results"):
            return self.create_text_message("No content could be extracted from the provided URLs.")
        else:
            # Always return JSON message with all data
            json_message = self.create_json_message(raw_results)

            # Create text message based on user-selected parameters
            text_message_content = self._format_results_as_text(raw_results)
            text_message = self.create_text_message(text=text_message_content)

            return [json_message, text_message]

    def _format_results_as_text(self, raw_results: dict) -> str:
        """
        Formats the raw extraction results into a markdown text based on user-selected parameters.

        Args:
            raw_results (dict): The raw extraction results.

        Returns:
            str: The formatted markdown text.
        """
        output_lines = []

        for idx, result in enumerate(raw_results.get("results", []), 1):
            url = result.get("url", "")
            raw_content = result.get("raw_content", "")

            output_lines.append(f"## Extracted Content {idx}: {url}\n")
            output_lines.append(f"**Raw Content:**\n{raw_content}\n")
            output_lines.append("---\n")

        if raw_results.get("failed_results"):
            output_lines.append("## Failed URLs:\n")
            for failed in raw_results["failed_results"]:
                url = failed.get("url", "")
                error = failed.get("error", "Unknown error")
                output_lines.append(f"- {url}: {error}\n")

        return "\n".join(output_lines)
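A short usage sketch for the TavilyExtract helper on its own (the API key and URLs are placeholders); _process_params accepts either a comma or space separated string or a list of URLs:

    extractor = TavilyExtract(api_key="<tavily-api-key>")
    raw_results = extractor.extract_content({"urls": "https://example.com, https://example.org"})
    for item in raw_results.get("results", []):
        print(item.get("url"), len(item.get("raw_content", "")))
    for failed in raw_results.get("failed_results", []):
        print("failed:", failed.get("url"), failed.get("error"))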
@@ -0,0 +1,23 @@
identity:
  name: tavily_extract
  author: Kalo Chin
  label:
    en_US: Tavily Extract
    zh_Hans: Tavily Extract
description:
  human:
    en_US: A web extraction tool built specifically for AI agents (LLMs), delivering raw content from web pages.
    zh_Hans: 专为人工智能代理 (LLM) 构建的网页提取工具,提供网页的原始内容。
  llm: A tool for extracting raw content from web pages, designed for AI agents (LLMs).
parameters:
  - name: urls
    type: string
    required: true
    label:
      en_US: URLs
      zh_Hans: URLs
    human_description:
      en_US: A comma-separated list of URLs to extract content from.
      zh_Hans: 要从中提取内容的 URL 的逗号分隔列表。
    llm_description: A comma-separated list of URLs to extract content from.
    form: llm
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="800px" height="800px" viewBox="0 -38 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg"
     xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid">
    <g>
        <path d="M250.346231,28.0746923 C247.358133,17.0320558 238.732098,8.40602109 227.689461,5.41792308 C207.823743,0 127.868333,0 127.868333,0 C127.868333,0 47.9129229,0.164179487 28.0472049,5.58210256 C17.0045684,8.57020058 8.37853373,17.1962353 5.39043571,28.2388718 C-0.618533519,63.5374615 -2.94988224,117.322662 5.5546152,151.209308 C8.54271322,162.251944 17.1687479,170.877979 28.2113844,173.866077 C48.0771024,179.284 128.032513,179.284 128.032513,179.284 C128.032513,179.284 207.987923,179.284 227.853641,173.866077 C238.896277,170.877979 247.522312,162.251944 250.51041,151.209308 C256.847738,115.861464 258.801474,62.1091 250.346231,28.0746923 Z"
            fill="#FF0000">
        </path>
        <polygon fill="#FFFFFF" points="102.420513 128.06 168.749025 89.642 102.420513 51.224">
        </polygon>
    </g>
</svg>
@@ -0,0 +1,81 @@
from typing import Any, Union
from urllib.parse import parse_qs, urlparse

from youtube_transcript_api import YouTubeTranscriptApi

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class YouTubeTranscriptTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the YouTube transcript tool
        """
        try:
            # Extract parameters with defaults
            video_input = tool_parameters["video_id"]
            language = tool_parameters.get("language")
            output_format = tool_parameters.get("format", "text")
            preserve_formatting = tool_parameters.get("preserve_formatting", False)
            proxy = tool_parameters.get("proxy")
            cookies = tool_parameters.get("cookies")

            # Extract video ID from URL if needed
            video_id = self._extract_video_id(video_input)

            # Common kwargs for API calls
            kwargs = {"proxies": {"https": proxy} if proxy else None, "cookies": cookies}

            try:
                if language:
                    transcript_list = YouTubeTranscriptApi.list_transcripts(video_id, **kwargs)
                    try:
                        transcript = transcript_list.find_transcript([language])
                    except Exception:
                        # If requested language not found, try translating from English
                        transcript = transcript_list.find_transcript(["en"]).translate(language)
                    transcript_data = transcript.fetch()
                else:
                    transcript_data = YouTubeTranscriptApi.get_transcript(
                        video_id, preserve_formatting=preserve_formatting, **kwargs
                    )

                # Format output
                formatter_class = {
                    "json": "JSONFormatter",
                    "pretty": "PrettyPrintFormatter",
                    "srt": "SRTFormatter",
                    "vtt": "WebVTTFormatter",
                }.get(output_format)

                if formatter_class:
                    from youtube_transcript_api import formatters

                    formatter = getattr(formatters, formatter_class)()
                    formatted_transcript = formatter.format_transcript(transcript_data)
                else:
                    formatted_transcript = " ".join(entry["text"] for entry in transcript_data)

                return self.create_text_message(text=formatted_transcript)

            except Exception as e:
                return self.create_text_message(text=f"Error getting transcript: {str(e)}")

        except Exception as e:
            return self.create_text_message(text=f"Error processing request: {str(e)}")

    def _extract_video_id(self, video_input: str) -> str:
        """
        Extract video ID from URL or return as-is if already an ID
        """
        if "youtube.com" in video_input or "youtu.be" in video_input:
            # Parse URL
            parsed_url = urlparse(video_input)
            if "youtube.com" in parsed_url.netloc:
                return parse_qs(parsed_url.query)["v"][0]
            else:  # youtu.be
                return parsed_url.path[1:]
        return video_input  # Assume it's already a video ID
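The URL handling in _extract_video_id can be exercised on its own; this standalone sketch repeats the same logic with example URLs (the video ID is arbitrary):

    from urllib.parse import parse_qs, urlparse

    def extract_video_id(video_input: str) -> str:
        # Mirrors YouTubeTranscriptTool._extract_video_id above.
        if "youtube.com" in video_input or "youtu.be" in video_input:
            parsed_url = urlparse(video_input)
            if "youtube.com" in parsed_url.netloc:
                return parse_qs(parsed_url.query)["v"][0]
            return parsed_url.path[1:]
        return video_input

    assert extract_video_id("https://www.youtube.com/watch?v=abc123XYZ90") == "abc123XYZ90"
    assert extract_video_id("https://youtu.be/abc123XYZ90") == "abc123XYZ90"
    assert extract_video_id("abc123XYZ90") == "abc123XYZ90"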