support bedrock kb: retrieve and generate (#13027)
parent
59b3e672aa
commit
b2bbc28580
@ -0,0 +1,114 @@
|
|||||||
|
"""
|
||||||
|
Configuration classes for AWS Bedrock retrieve and generate API
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextInferenceConfig:
|
||||||
|
"""Text inference configuration"""
|
||||||
|
|
||||||
|
maxTokens: Optional[int] = None
|
||||||
|
stopSequences: Optional[list[str]] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
topP: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PerformanceConfig:
|
||||||
|
"""Performance configuration"""
|
||||||
|
|
||||||
|
latency: Literal["standard", "optimized"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptTemplate:
|
||||||
|
"""Prompt template configuration"""
|
||||||
|
|
||||||
|
textPromptTemplate: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardrailConfig:
|
||||||
|
"""Guardrail configuration"""
|
||||||
|
|
||||||
|
guardrailId: str
|
||||||
|
guardrailVersion: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GenerationConfig:
|
||||||
|
"""Generation configuration"""
|
||||||
|
|
||||||
|
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||||
|
guardrailConfiguration: Optional[GuardrailConfig] = None
|
||||||
|
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||||
|
performanceConfig: Optional[PerformanceConfig] = None
|
||||||
|
promptTemplate: Optional[PromptTemplate] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VectorSearchConfig:
|
||||||
|
"""Vector search configuration"""
|
||||||
|
|
||||||
|
filter: Optional[dict[str, Any]] = None
|
||||||
|
numberOfResults: Optional[int] = None
|
||||||
|
overrideSearchType: Optional[Literal["HYBRID", "SEMANTIC"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalConfig:
|
||||||
|
"""Retrieval configuration"""
|
||||||
|
|
||||||
|
vectorSearchConfiguration: VectorSearchConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrchestrationConfig:
|
||||||
|
"""Orchestration configuration"""
|
||||||
|
|
||||||
|
additionalModelRequestFields: Optional[dict[str, Any]] = None
|
||||||
|
inferenceConfig: Optional[dict[str, TextInferenceConfig]] = None
|
||||||
|
performanceConfig: Optional[PerformanceConfig] = None
|
||||||
|
promptTemplate: Optional[PromptTemplate] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KnowledgeBaseConfig:
|
||||||
|
"""Knowledge base configuration"""
|
||||||
|
|
||||||
|
generationConfiguration: GenerationConfig
|
||||||
|
knowledgeBaseId: str
|
||||||
|
modelArn: str
|
||||||
|
orchestrationConfiguration: Optional[OrchestrationConfig] = None
|
||||||
|
retrievalConfiguration: Optional[RetrievalConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionConfig:
|
||||||
|
"""Session configuration"""
|
||||||
|
|
||||||
|
kmsKeyArn: Optional[str] = None
|
||||||
|
sessionId: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrieveAndGenerateConfiguration:
|
||||||
|
"""Retrieve and generate configuration
|
||||||
|
The use of knowledgeBaseConfiguration or externalSourcesConfiguration depends on the type value
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "KNOWLEDGE_BASE"
|
||||||
|
knowledgeBaseConfiguration: Optional[KnowledgeBaseConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrieveAndGenerateConfig:
|
||||||
|
"""Retrieve and generate main configuration"""
|
||||||
|
|
||||||
|
input: dict[str, str]
|
||||||
|
retrieveAndGenerateConfiguration: RetrieveAndGenerateConfiguration
|
||||||
|
sessionConfiguration: Optional[SessionConfig] = None
|
||||||
|
sessionId: Optional[str] = None
|
||||||
@ -0,0 +1,324 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockRetrieveAndGenerateTool(BuiltinTool):
|
||||||
|
bedrock_client: Any = None
|
||||||
|
|
||||||
|
def _create_text_inference_config(
|
||||||
|
self,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Create text inference configuration"""
|
||||||
|
if any([max_tokens, stop_sequences, temperature, top_p]):
|
||||||
|
config = {}
|
||||||
|
if max_tokens is not None:
|
||||||
|
config["maxTokens"] = max_tokens
|
||||||
|
if stop_sequences:
|
||||||
|
try:
|
||||||
|
config["stopSequences"] = json.loads(stop_sequences)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
config["stopSequences"] = []
|
||||||
|
if temperature is not None:
|
||||||
|
config["temperature"] = temperature
|
||||||
|
if top_p is not None:
|
||||||
|
config["topP"] = top_p
|
||||||
|
return config
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_guardrail_config(
|
||||||
|
self,
|
||||||
|
guardrail_id: Optional[str] = None,
|
||||||
|
guardrail_version: Optional[str] = None,
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Create guardrail configuration"""
|
||||||
|
if guardrail_id and guardrail_version:
|
||||||
|
return {"guardrailId": guardrail_id, "guardrailVersion": guardrail_version}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_generation_config(
|
||||||
|
self,
|
||||||
|
additional_model_fields: Optional[str] = None,
|
||||||
|
guardrail_config: Optional[dict] = None,
|
||||||
|
text_inference_config: Optional[dict] = None,
|
||||||
|
performance_mode: Optional[str] = None,
|
||||||
|
prompt_template: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create generation configuration"""
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
if additional_model_fields:
|
||||||
|
try:
|
||||||
|
config["additionalModelRequestFields"] = json.loads(additional_model_fields)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if guardrail_config:
|
||||||
|
config["guardrailConfiguration"] = guardrail_config
|
||||||
|
|
||||||
|
if text_inference_config:
|
||||||
|
config["inferenceConfig"] = {"textInferenceConfig": text_inference_config}
|
||||||
|
|
||||||
|
if performance_mode:
|
||||||
|
config["performanceConfig"] = {"latency": performance_mode}
|
||||||
|
|
||||||
|
if prompt_template:
|
||||||
|
config["promptTemplate"] = {"textPromptTemplate": prompt_template}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _create_orchestration_config(
|
||||||
|
self,
|
||||||
|
orchestration_additional_model_fields: Optional[str] = None,
|
||||||
|
orchestration_text_inference_config: Optional[dict] = None,
|
||||||
|
orchestration_performance_mode: Optional[str] = None,
|
||||||
|
orchestration_prompt_template: Optional[str] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create orchestration configuration"""
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
if orchestration_additional_model_fields:
|
||||||
|
try:
|
||||||
|
config["additionalModelRequestFields"] = json.loads(orchestration_additional_model_fields)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if orchestration_text_inference_config:
|
||||||
|
config["inferenceConfig"] = {"textInferenceConfig": orchestration_text_inference_config}
|
||||||
|
|
||||||
|
if orchestration_performance_mode:
|
||||||
|
config["performanceConfig"] = {"latency": orchestration_performance_mode}
|
||||||
|
|
||||||
|
if orchestration_prompt_template:
|
||||||
|
config["promptTemplate"] = {"textPromptTemplate": orchestration_prompt_template}
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _create_vector_search_config(
|
||||||
|
self,
|
||||||
|
number_of_results: int = 5,
|
||||||
|
search_type: str = "SEMANTIC",
|
||||||
|
metadata_filter: Optional[dict] = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create vector search configuration"""
|
||||||
|
config = {
|
||||||
|
"numberOfResults": number_of_results,
|
||||||
|
"overrideSearchType": search_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only add filter if metadata_filter is not empty
|
||||||
|
if metadata_filter:
|
||||||
|
config["filter"] = metadata_filter
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _bedrock_retrieve_and_generate(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
knowledge_base_id: str,
|
||||||
|
model_arn: str,
|
||||||
|
# Generation Configuration
|
||||||
|
additional_model_fields: Optional[str] = None,
|
||||||
|
guardrail_id: Optional[str] = None,
|
||||||
|
guardrail_version: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
stop_sequences: Optional[str] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
performance_mode: str = "standard",
|
||||||
|
prompt_template: Optional[str] = None,
|
||||||
|
# Orchestration Configuration
|
||||||
|
orchestration_additional_model_fields: Optional[str] = None,
|
||||||
|
orchestration_max_tokens: Optional[int] = None,
|
||||||
|
orchestration_stop_sequences: Optional[str] = None,
|
||||||
|
orchestration_temperature: Optional[float] = None,
|
||||||
|
orchestration_top_p: Optional[float] = None,
|
||||||
|
orchestration_performance_mode: Optional[str] = None,
|
||||||
|
orchestration_prompt_template: Optional[str] = None,
|
||||||
|
# Retrieval Configuration
|
||||||
|
number_of_results: int = 5,
|
||||||
|
search_type: str = "SEMANTIC",
|
||||||
|
metadata_filter: Optional[dict] = None,
|
||||||
|
# Additional Configuration
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
try:
|
||||||
|
# Create text inference configurations
|
||||||
|
text_inference_config = self._create_text_inference_config(max_tokens, stop_sequences, temperature, top_p)
|
||||||
|
orchestration_text_inference_config = self._create_text_inference_config(
|
||||||
|
orchestration_max_tokens, orchestration_stop_sequences, orchestration_temperature, orchestration_top_p
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create guardrail configuration
|
||||||
|
guardrail_config = self._create_guardrail_config(guardrail_id, guardrail_version)
|
||||||
|
|
||||||
|
# Create vector search configuration
|
||||||
|
vector_search_config = self._create_vector_search_config(number_of_results, search_type, metadata_filter)
|
||||||
|
|
||||||
|
# Create generation configuration
|
||||||
|
generation_config = self._create_generation_config(
|
||||||
|
additional_model_fields, guardrail_config, text_inference_config, performance_mode, prompt_template
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create orchestration configuration
|
||||||
|
orchestration_config = self._create_orchestration_config(
|
||||||
|
orchestration_additional_model_fields,
|
||||||
|
orchestration_text_inference_config,
|
||||||
|
orchestration_performance_mode,
|
||||||
|
orchestration_prompt_template,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create knowledge base configuration
|
||||||
|
knowledge_base_config = {
|
||||||
|
"knowledgeBaseId": knowledge_base_id,
|
||||||
|
"modelArn": model_arn,
|
||||||
|
"generationConfiguration": generation_config,
|
||||||
|
"orchestrationConfiguration": orchestration_config,
|
||||||
|
"retrievalConfiguration": {"vectorSearchConfiguration": vector_search_config},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create request configuration
|
||||||
|
request_config = {
|
||||||
|
"input": {"text": query},
|
||||||
|
"retrieveAndGenerateConfiguration": {
|
||||||
|
"type": "KNOWLEDGE_BASE",
|
||||||
|
"knowledgeBaseConfiguration": knowledge_base_config,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add session configuration if provided
|
||||||
|
if session_id and len(session_id) >= 2:
|
||||||
|
request_config["sessionConfiguration"] = {"sessionId": session_id}
|
||||||
|
request_config["sessionId"] = session_id
|
||||||
|
|
||||||
|
# Send request
|
||||||
|
response = self.bedrock_client.retrieve_and_generate(**request_config)
|
||||||
|
|
||||||
|
# Process response
|
||||||
|
result = {"output": response.get("output", {}).get("text", ""), "citations": []}
|
||||||
|
|
||||||
|
# Process citations
|
||||||
|
for citation in response.get("citations", []):
|
||||||
|
citation_info = {
|
||||||
|
"text": citation.get("generatedResponsePart", {}).get("textResponsePart", {}).get("text", ""),
|
||||||
|
"references": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
for ref in citation.get("retrievedReferences", []):
|
||||||
|
reference = {
|
||||||
|
"content": ref.get("content", {}).get("text", ""),
|
||||||
|
"metadata": ref.get("metadata", {}),
|
||||||
|
"location": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
location = ref.get("location", {})
|
||||||
|
if location.get("type") == "S3":
|
||||||
|
reference["location"] = location.get("s3Location", {}).get("uri")
|
||||||
|
|
||||||
|
citation_info["references"].append(reference)
|
||||||
|
|
||||||
|
result["citations"].append(citation_info)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error calling Bedrock service: {str(e)}")
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any],
|
||||||
|
) -> ToolInvokeMessage:
|
||||||
|
try:
|
||||||
|
# Initialize Bedrock client if not already initialized
|
||||||
|
if not self.bedrock_client:
|
||||||
|
aws_region = tool_parameters.get("aws_region")
|
||||||
|
aws_access_key_id = tool_parameters.get("aws_access_key_id")
|
||||||
|
aws_secret_access_key = tool_parameters.get("aws_secret_access_key")
|
||||||
|
|
||||||
|
client_kwargs = {
|
||||||
|
"service_name": "bedrock-agent-runtime",
|
||||||
|
}
|
||||||
|
if aws_region:
|
||||||
|
client_kwargs["region_name"] = aws_region
|
||||||
|
# Only add credentials if both access key and secret key are provided
|
||||||
|
if aws_access_key_id and aws_secret_access_key:
|
||||||
|
client_kwargs.update(
|
||||||
|
{"aws_access_key_id": aws_access_key_id, "aws_secret_access_key": aws_secret_access_key}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.bedrock_client = boto3.client(**client_kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return self.create_text_message(f"Failed to initialize Bedrock client: {str(e)}")
|
||||||
|
|
||||||
|
# Parse metadata filter if provided
|
||||||
|
metadata_filter = None
|
||||||
|
if metadata_filter_str := tool_parameters.get("metadata_filter"):
|
||||||
|
try:
|
||||||
|
parsed_filter = json.loads(metadata_filter_str)
|
||||||
|
if parsed_filter: # Only set if not empty
|
||||||
|
metadata_filter = parsed_filter
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return self.create_text_message("metadata_filter must be a valid JSON string")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self._bedrock_retrieve_and_generate(
|
||||||
|
query=tool_parameters["query"],
|
||||||
|
knowledge_base_id=tool_parameters["knowledge_base_id"],
|
||||||
|
model_arn=tool_parameters["model_arn"],
|
||||||
|
# Generation Configuration
|
||||||
|
additional_model_fields=tool_parameters.get("additional_model_fields"),
|
||||||
|
guardrail_id=tool_parameters.get("guardrail_id"),
|
||||||
|
guardrail_version=tool_parameters.get("guardrail_version"),
|
||||||
|
max_tokens=tool_parameters.get("max_tokens"),
|
||||||
|
stop_sequences=tool_parameters.get("stop_sequences"),
|
||||||
|
temperature=tool_parameters.get("temperature"),
|
||||||
|
top_p=tool_parameters.get("top_p"),
|
||||||
|
performance_mode=tool_parameters.get("performance_mode", "standard"),
|
||||||
|
prompt_template=tool_parameters.get("prompt_template"),
|
||||||
|
# Orchestration Configuration
|
||||||
|
orchestration_additional_model_fields=tool_parameters.get("orchestration_additional_model_fields"),
|
||||||
|
orchestration_max_tokens=tool_parameters.get("orchestration_max_tokens"),
|
||||||
|
orchestration_stop_sequences=tool_parameters.get("orchestration_stop_sequences"),
|
||||||
|
orchestration_temperature=tool_parameters.get("orchestration_temperature"),
|
||||||
|
orchestration_top_p=tool_parameters.get("orchestration_top_p"),
|
||||||
|
orchestration_performance_mode=tool_parameters.get("orchestration_performance_mode"),
|
||||||
|
orchestration_prompt_template=tool_parameters.get("orchestration_prompt_template"),
|
||||||
|
# Retrieval Configuration
|
||||||
|
number_of_results=tool_parameters.get("number_of_results", 5),
|
||||||
|
search_type=tool_parameters.get("search_type", "SEMANTIC"),
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
# Additional Configuration
|
||||||
|
session_id=tool_parameters.get("session_id"),
|
||||||
|
)
|
||||||
|
return self.create_json_message(response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.create_text_message(f"Tool invocation error: {str(e)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return self.create_text_message(f"Tool execution error: {str(e)}")
|
||||||
|
|
||||||
|
def validate_parameters(self, parameters: dict[str, Any]) -> None:
|
||||||
|
"""Validate the parameters"""
|
||||||
|
required_params = ["query", "model_arn", "knowledge_base_id"]
|
||||||
|
for param in required_params:
|
||||||
|
if not parameters.get(param):
|
||||||
|
raise ValueError(f"{param} is required")
|
||||||
|
|
||||||
|
# Validate metadata filter if provided
|
||||||
|
if metadata_filter_str := parameters.get("metadata_filter"):
|
||||||
|
try:
|
||||||
|
if not isinstance(json.loads(metadata_filter_str), dict):
|
||||||
|
raise ValueError("metadata_filter must be a valid JSON object")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError("metadata_filter must be a valid JSON string")
|
||||||
Loading…
Reference in New Issue