parent
7a00798027
commit
de8800f41a
@ -0,0 +1,115 @@
|
||||
import json
|
||||
import operator
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class BedrockRetrieveTool(BuiltinTool):
    """Tool that queries an Amazon Bedrock knowledge base and returns scored passages."""

    # Lazily created "bedrock-agent-runtime" boto3 client, reused by later invocations.
    bedrock_client: Any = None
    # Optional instance-level defaults, used only when the invocation parameters
    # do not supply a value themselves.
    knowledge_base_id: Optional[str] = None
    topk: Optional[int] = None

    @staticmethod
    def _parse_metadata_filter(raw_filter: Any) -> Optional[dict]:
        """
        Convert the ``metadata_filter`` tool parameter into a dict.

        Args:
            raw_filter: A JSON string, an already-decoded dict, or a falsy value.

        Returns:
            The decoded filter, or ``None`` when no filter was supplied.

        Raises:
            ValueError: If the value is not valid JSON or does not decode to a JSON object.
        """
        if not raw_filter:
            return None
        if isinstance(raw_filter, dict):
            return raw_filter
        try:
            parsed = json.loads(raw_filter)
        except (TypeError, ValueError) as e:
            # json.JSONDecodeError is a ValueError; TypeError covers non-string input.
            raise ValueError("metadata_filter must be a valid JSON object") from e
        if not isinstance(parsed, dict):
            raise ValueError("metadata_filter must be a valid JSON object")
        return parsed

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ):
        """
        Run a vector-search retrieval against a knowledge base.

        Args:
            query_input: Text of the query.
            knowledge_base_id: Identifier passed as ``knowledgeBaseId``.
            num_results: Value passed as ``numberOfResults``.
            metadata_filter: Optional filter placed under ``vectorSearchConfiguration.filter``.

        Returns:
            A list of dicts with the keys ``content``, ``score`` and ``metadata``,
            in the order the service returned them.

        Raises:
            Exception: Wraps any failure, with the original error chained as the cause.
        """
        try:
            vector_search_config: dict[str, Any] = {"numberOfResults": num_results}
            # Only attach the metadata filter when one was provided.
            if metadata_filter:
                vector_search_config["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery={"text": query_input},
                retrievalConfiguration={"vectorSearchConfiguration": vector_search_config},
            )

            return [
                {
                    "content": item.get("content", {}).get("text", ""),
                    "score": item.get("score", 0.0),
                    "metadata": item.get("metadata", {}),
                }
                for item in response.get("retrievalResults", [])
            ]
        except Exception as e:
            # Keep the historical exception type and message; chain the cause for debugging.
            raise Exception(f"Error retrieving from knowledge base: {str(e)}") from e

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Retrieve documents for ``tool_parameters["query"]`` and return them best-first.

        Parameters read from ``tool_parameters``: ``aws_region``, ``knowledge_base_id``,
        ``topk`` (defaults to 5), ``query`` and ``metadata_filter``.

        Returns:
            One JSON message per retrieved document, sorted by descending score, or a
            single text message describing a missing parameter or an error.
        """
        # Stage marker that is included in the error message to show how far we got.
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            # Per-call parameters win over instance defaults and are NOT cached on the
            # instance, so a reused tool object never answers with a stale knowledge base.
            knowledge_base_id = tool_parameters.get("knowledge_base_id") or self.knowledge_base_id
            if not knowledge_base_id:
                return self.create_text_message("Please provide knowledge_base_id")

            line = 2
            topk = int(tool_parameters.get("topk") or self.topk or 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Decode the optional metadata filter, if one was supplied.
            metadata_filter = self._parse_metadata_filter(tool_parameters.get("metadata_filter"))

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=knowledge_base_id,
                num_results=topk,
                metadata_filter=metadata_filter,
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """
        Validate the parameters.

        Raises:
            ValueError: If ``knowledge_base_id`` or ``query`` is missing, or if
                ``metadata_filter`` is supplied but is not a JSON object.
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: when a metadata filter is given it must decode to a JSON object.
        self._parse_metadata_filter(parameters.get("metadata_filter"))
|
||||
@ -0,0 +1,80 @@
|
||||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class S3Operator(BuiltinTool):
    """Tool that reads or writes UTF-8 text objects in S3 and can return presigned URLs."""

    # Lazily created boto3 S3 client, reused by later invocations of this instance.
    s3_client: Any = None

    @staticmethod
    def _aws_error_code(error: Exception) -> str:
        """
        Return the AWS error code carried by ``error``, or ``""`` if there is none.

        The code is looked up under ``error.response["Error"]["Code"]``; any
        exception without that structure yields an empty string.
        """
        response = getattr(error, "response", None)
        if not isinstance(response, dict):
            return ""
        details = response.get("Error")
        if not isinstance(details, dict):
            return ""
        return str(details.get("Code") or "")

    def _presigned_get_url(self, bucket: str, key: str, expires_in: int) -> str:
        """Create a presigned ``get_object`` URL for ``bucket``/``key`` valid for ``expires_in`` seconds."""
        return self.s3_client.generate_presigned_url(
            "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=expires_in
        )

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Read from or write to the object named by ``tool_parameters["s3_uri"]``.

        Parameters read from ``tool_parameters``: ``aws_region``, ``s3_uri``,
        ``operation_type`` (``"write"``; anything else is treated as a read),
        ``text_content`` (required for writes), ``generate_presign_url`` and
        ``presign_expiry`` (seconds, defaults to 3600).

        Returns:
            A text message holding the object content, the object's ``s3://`` URI,
            a presigned URL, or a description of what went wrong.
        """
        # Bound up front so the error handler can always format its messages.
        bucket = ""
        key = ""
        try:
            # Initialize S3 client if not already done
            if not self.s3_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.s3_client = boto3.client("s3")

            # Parse S3 URI
            s3_uri = tool_parameters.get("s3_uri")
            if not s3_uri:
                return self.create_text_message("s3_uri parameter is required")

            if urlparse(s3_uri).scheme != "s3":
                return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")

            # Split the raw string rather than using the parsed path: urlparse treats
            # "?" and "#" as query/fragment delimiters, which would truncate object
            # keys that legitimately contain those characters.
            _, _, remainder = s3_uri.partition("://")
            bucket, _, raw_key = remainder.partition("/")
            # Remove leading slashes from key
            key = raw_key.lstrip("/")
            if not bucket or not key:
                return self.create_text_message("Invalid S3 URI format. Expected 's3://<bucket>/<key>'")

            operation_type = tool_parameters.get("operation_type", "read")
            generate_presign_url = tool_parameters.get("generate_presign_url", False)
            # "or" also covers a key that is present but null/empty; default 1 hour.
            presign_expiry = int(tool_parameters.get("presign_expiry") or 3600)

            if operation_type == "write":
                text_content = tool_parameters.get("text_content")
                if not text_content:
                    return self.create_text_message("text_content parameter is required for write operation")

                # Write content to S3
                self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))

                # Return a presigned URL for the written object if requested,
                # otherwise the object's own URI.
                if generate_presign_url:
                    result = self._presigned_get_url(bucket, key, presign_expiry)
                else:
                    result = f"s3://{bucket}/{key}"

            else:  # read operation
                # Fetching the object also confirms that the bucket and key exist.
                response = self.s3_client.get_object(Bucket=bucket, Key=key)
                body = response["Body"]
                try:
                    if generate_presign_url:
                        # Only the URL is returned, so the payload is neither
                        # downloaded nor decoded (it may not be UTF-8 text).
                        result = self._presigned_get_url(bucket, key, presign_expiry)
                    else:
                        result = body.read().decode("utf-8")
                finally:
                    body.close()

            return self.create_text_message(text=result)

        except Exception as e:
            # Dispatch on the AWS error code instead of `self.s3_client.exceptions.*`:
            # the client may be None here if its creation is what failed.
            error_code = self._aws_error_code(e)
            if error_code == "NoSuchBucket":
                return self.create_text_message(f"Bucket '{bucket}' does not exist")
            if error_code == "NoSuchKey":
                return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
            return self.create_text_message(f"Exception: {str(e)}")
|
||||
Loading…
Reference in New Issue