|
|
|
|
@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
Model class for Cohere large language model.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
sagemaker_client: Any = None
|
|
|
|
|
sagemaker_session: Any = None
|
|
|
|
|
predictor: Any = None
|
|
|
|
|
sagemaker_endpoint: str = None
|
|
|
|
|
|
|
|
|
|
def _handle_chat_generate_response(
|
|
|
|
|
self,
|
|
|
|
|
@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
:param user: unique user id
|
|
|
|
|
:return: full response or stream response chunk generator result
|
|
|
|
|
"""
|
|
|
|
|
if not self.sagemaker_client:
|
|
|
|
|
if not self.sagemaker_session:
|
|
|
|
|
access_key = credentials.get("aws_access_key_id")
|
|
|
|
|
secret_key = credentials.get("aws_secret_access_key")
|
|
|
|
|
aws_region = credentials.get("aws_region")
|
|
|
|
|
@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
else:
|
|
|
|
|
boto_session = boto3.Session()
|
|
|
|
|
|
|
|
|
|
self.sagemaker_client = boto_session.client("sagemaker")
|
|
|
|
|
sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
|
|
|
|
|
sagemaker_client = boto_session.client("sagemaker")
|
|
|
|
|
self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
|
|
|
|
|
|
|
|
|
|
if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
|
|
|
|
|
self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
|
|
|
|
|
self.predictor = Predictor(
|
|
|
|
|
endpoint_name=credentials.get("sagemaker_endpoint"),
|
|
|
|
|
sagemaker_session=sagemaker_session,
|
|
|
|
|
endpoint_name=self.sagemaker_endpoint,
|
|
|
|
|
sagemaker_session=self.sagemaker_session,
|
|
|
|
|
serializer=serializers.JSONSerializer(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|