|
|
|
|
@ -1,20 +1,38 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Generator
|
|
|
|
|
from collections.abc import Generator, Iterator
|
|
|
|
|
from typing import Optional, Union, cast
|
|
|
|
|
|
|
|
|
|
import cohere
|
|
|
|
|
from cohere.responses import Chat, Generations
|
|
|
|
|
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
|
|
|
|
|
from cohere.responses.generation import StreamingGenerations, StreamingText
|
|
|
|
|
from cohere import (
|
|
|
|
|
ChatMessage,
|
|
|
|
|
ChatStreamRequestToolResultsItem,
|
|
|
|
|
GenerateStreamedResponse,
|
|
|
|
|
GenerateStreamedResponse_StreamEnd,
|
|
|
|
|
GenerateStreamedResponse_StreamError,
|
|
|
|
|
GenerateStreamedResponse_TextGeneration,
|
|
|
|
|
Generation,
|
|
|
|
|
NonStreamedChatResponse,
|
|
|
|
|
StreamedChatResponse,
|
|
|
|
|
StreamedChatResponse_StreamEnd,
|
|
|
|
|
StreamedChatResponse_TextGeneration,
|
|
|
|
|
StreamedChatResponse_ToolCallsGeneration,
|
|
|
|
|
Tool,
|
|
|
|
|
ToolCall,
|
|
|
|
|
ToolParameterDefinitionsValue,
|
|
|
|
|
)
|
|
|
|
|
from cohere.core import RequestOptions
|
|
|
|
|
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
|
|
AssistantPromptMessage,
|
|
|
|
|
PromptMessage,
|
|
|
|
|
PromptMessageContentType,
|
|
|
|
|
PromptMessageRole,
|
|
|
|
|
PromptMessageTool,
|
|
|
|
|
SystemPromptMessage,
|
|
|
|
|
TextPromptMessageContent,
|
|
|
|
|
ToolPromptMessage,
|
|
|
|
|
UserPromptMessage,
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
|
|
|
|
@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
credentials=credentials,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
model_parameters=model_parameters,
|
|
|
|
|
tools=tools,
|
|
|
|
|
stop=stop,
|
|
|
|
|
stream=stream,
|
|
|
|
|
user=user
|
|
|
|
|
@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
if stop:
|
|
|
|
|
model_parameters['end_sequences'] = stop
|
|
|
|
|
|
|
|
|
|
response = client.generate(
|
|
|
|
|
prompt=prompt_messages[0].content,
|
|
|
|
|
model=model,
|
|
|
|
|
stream=stream,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stream:
|
|
|
|
|
response = client.generate_stream(
|
|
|
|
|
prompt=prompt_messages[0].content,
|
|
|
|
|
model=model,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
request_options=RequestOptions(max_retries=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
|
|
|
|
else:
|
|
|
|
|
response = client.generate(
|
|
|
|
|
prompt=prompt_messages[0].content,
|
|
|
|
|
model=model,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
request_options=RequestOptions(max_retries=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
|
|
|
|
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
|
|
|
|
|
|
|
|
|
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
|
|
|
|
|
def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
|
|
|
|
|
prompt_messages: list[PromptMessage]) \
|
|
|
|
|
-> LLMResult:
|
|
|
|
|
"""
|
|
|
|
|
@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# calculate num tokens
|
|
|
|
|
prompt_tokens = response.meta['billed_units']['input_tokens']
|
|
|
|
|
completion_tokens = response.meta['billed_units']['output_tokens']
|
|
|
|
|
prompt_tokens = int(response.meta.billed_units.input_tokens)
|
|
|
|
|
completion_tokens = int(response.meta.billed_units.output_tokens)
|
|
|
|
|
|
|
|
|
|
# transform usage
|
|
|
|
|
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
|
|
|
@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
|
|
|
|
|
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
|
|
|
|
|
prompt_messages: list[PromptMessage]) -> Generator:
|
|
|
|
|
"""
|
|
|
|
|
Handle llm stream response
|
|
|
|
|
@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
index = 1
|
|
|
|
|
full_assistant_content = ''
|
|
|
|
|
for chunk in response:
|
|
|
|
|
if isinstance(chunk, StreamingText):
|
|
|
|
|
chunk = cast(StreamingText, chunk)
|
|
|
|
|
if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
|
|
|
|
|
chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
|
|
|
|
|
text = chunk.text
|
|
|
|
|
|
|
|
|
|
if text is None:
|
|
|
|
|
@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
index += 1
|
|
|
|
|
elif chunk is None:
|
|
|
|
|
elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
|
|
|
|
|
chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
|
|
|
|
|
|
|
|
|
|
# calculate num tokens
|
|
|
|
|
prompt_tokens = response.meta['billed_units']['input_tokens']
|
|
|
|
|
completion_tokens = response.meta['billed_units']['output_tokens']
|
|
|
|
|
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
|
|
|
|
completion_tokens = self._num_tokens_from_messages(
|
|
|
|
|
model,
|
|
|
|
|
credentials,
|
|
|
|
|
[AssistantPromptMessage(content=full_assistant_content)]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# transform usage
|
|
|
|
|
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
|
|
|
@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
delta=LLMResultChunkDelta(
|
|
|
|
|
index=index,
|
|
|
|
|
message=AssistantPromptMessage(content=''),
|
|
|
|
|
finish_reason=response.finish_reason,
|
|
|
|
|
finish_reason=chunk.finish_reason,
|
|
|
|
|
usage=usage
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
break
|
|
|
|
|
elif isinstance(chunk, GenerateStreamedResponse_StreamError):
|
|
|
|
|
chunk = cast(GenerateStreamedResponse_StreamError, chunk)
|
|
|
|
|
raise InvokeBadRequestError(chunk.err)
|
|
|
|
|
|
|
|
|
|
def _chat_generate(self, model: str, credentials: dict,
|
|
|
|
|
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
|
|
|
|
prompt_messages: list[PromptMessage], model_parameters: dict,
|
|
|
|
|
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
|
|
|
|
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
|
|
|
|
"""
|
|
|
|
|
Invoke llm chat model
|
|
|
|
|
@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
:param credentials: credentials
|
|
|
|
|
:param prompt_messages: prompt messages
|
|
|
|
|
:param model_parameters: model parameters
|
|
|
|
|
:param tools: tools for tool calling
|
|
|
|
|
:param stop: stop words
|
|
|
|
|
:param stream: is stream response
|
|
|
|
|
:param user: unique user id
|
|
|
|
|
@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
# initialize client
|
|
|
|
|
client = cohere.Client(credentials.get('api_key'))
|
|
|
|
|
|
|
|
|
|
if user:
|
|
|
|
|
model_parameters['user_name'] = user
|
|
|
|
|
if stop:
|
|
|
|
|
model_parameters['stop_sequences'] = stop
|
|
|
|
|
|
|
|
|
|
if tools:
|
|
|
|
|
model_parameters['tools'] = self._convert_tools(tools)
|
|
|
|
|
|
|
|
|
|
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
|
|
|
|
message, chat_histories, tool_results \
|
|
|
|
|
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
|
|
|
|
|
|
|
|
|
if tool_results:
|
|
|
|
|
model_parameters['tool_results'] = tool_results
|
|
|
|
|
|
|
|
|
|
# chat model
|
|
|
|
|
real_model = model
|
|
|
|
|
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
|
|
|
|
real_model = model.removesuffix('-chat')
|
|
|
|
|
|
|
|
|
|
response = client.chat(
|
|
|
|
|
message=message,
|
|
|
|
|
chat_history=chat_histories,
|
|
|
|
|
model=real_model,
|
|
|
|
|
stream=stream,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stream:
|
|
|
|
|
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
|
|
|
|
|
response = client.chat_stream(
|
|
|
|
|
message=message,
|
|
|
|
|
chat_history=chat_histories,
|
|
|
|
|
model=real_model,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
request_options=RequestOptions(max_retries=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
|
|
|
|
|
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
|
|
|
|
|
else:
|
|
|
|
|
response = client.chat(
|
|
|
|
|
message=message,
|
|
|
|
|
chat_history=chat_histories,
|
|
|
|
|
model=real_model,
|
|
|
|
|
**model_parameters,
|
|
|
|
|
request_options=RequestOptions(max_retries=0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
|
|
|
|
|
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
|
|
|
|
|
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
|
|
|
|
|
|
|
|
|
|
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
|
|
|
|
|
prompt_messages: list[PromptMessage]) \
|
|
|
|
|
-> LLMResult:
|
|
|
|
|
"""
|
|
|
|
|
Handle llm chat response
|
|
|
|
|
@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
:param credentials: credentials
|
|
|
|
|
:param response: response
|
|
|
|
|
:param prompt_messages: prompt messages
|
|
|
|
|
:param stop: stop words
|
|
|
|
|
:return: llm response
|
|
|
|
|
"""
|
|
|
|
|
assistant_text = response.text
|
|
|
|
|
|
|
|
|
|
tool_calls = []
|
|
|
|
|
if response.tool_calls:
|
|
|
|
|
for cohere_tool_call in response.tool_calls:
|
|
|
|
|
tool_call = AssistantPromptMessage.ToolCall(
|
|
|
|
|
id=cohere_tool_call.name,
|
|
|
|
|
type='function',
|
|
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
|
|
name=cohere_tool_call.name,
|
|
|
|
|
arguments=json.dumps(cohere_tool_call.parameters)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
tool_calls.append(tool_call)
|
|
|
|
|
|
|
|
|
|
# transform assistant message to prompt message
|
|
|
|
|
assistant_prompt_message = AssistantPromptMessage(
|
|
|
|
|
content=assistant_text
|
|
|
|
|
content=assistant_text,
|
|
|
|
|
tool_calls=tool_calls
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# calculate num tokens
|
|
|
|
|
@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
# transform usage
|
|
|
|
|
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
|
|
|
|
|
|
|
|
|
if stop:
|
|
|
|
|
# enforce stop tokens
|
|
|
|
|
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
|
|
|
|
|
assistant_prompt_message = AssistantPromptMessage(
|
|
|
|
|
content=assistant_text
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# transform response
|
|
|
|
|
response = LLMResult(
|
|
|
|
|
model=model,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
message=assistant_prompt_message,
|
|
|
|
|
usage=usage,
|
|
|
|
|
system_fingerprint=response.preamble
|
|
|
|
|
usage=usage
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
|
|
|
|
|
prompt_messages: list[PromptMessage],
|
|
|
|
|
stop: Optional[list[str]] = None) -> Generator:
|
|
|
|
|
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
|
|
|
|
|
response: Iterator[StreamedChatResponse],
|
|
|
|
|
prompt_messages: list[PromptMessage]) -> Generator:
|
|
|
|
|
"""
|
|
|
|
|
Handle llm chat stream response
|
|
|
|
|
|
|
|
|
|
:param model: model name
|
|
|
|
|
:param response: response
|
|
|
|
|
:param prompt_messages: prompt messages
|
|
|
|
|
:param stop: stop words
|
|
|
|
|
:return: llm response chunk generator
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
|
|
|
|
|
preamble: Optional[str] = None) -> LLMResultChunk:
|
|
|
|
|
def final_response(full_text: str,
|
|
|
|
|
tool_calls: list[AssistantPromptMessage.ToolCall],
|
|
|
|
|
index: int,
|
|
|
|
|
finish_reason: Optional[str] = None) -> LLMResultChunk:
|
|
|
|
|
# calculate num tokens
|
|
|
|
|
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
|
|
|
|
|
|
|
|
|
full_assistant_prompt_message = AssistantPromptMessage(
|
|
|
|
|
content=full_text
|
|
|
|
|
content=full_text,
|
|
|
|
|
tool_calls=tool_calls
|
|
|
|
|
)
|
|
|
|
|
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
|
|
|
|
|
|
|
|
|
|
@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
return LLMResultChunk(
|
|
|
|
|
model=model,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
system_fingerprint=preamble,
|
|
|
|
|
delta=LLMResultChunkDelta(
|
|
|
|
|
index=index,
|
|
|
|
|
message=AssistantPromptMessage(content=''),
|
|
|
|
|
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
|
|
|
|
|
finish_reason=finish_reason,
|
|
|
|
|
usage=usage
|
|
|
|
|
)
|
|
|
|
|
@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
|
|
|
|
|
index = 1
|
|
|
|
|
full_assistant_content = ''
|
|
|
|
|
tool_calls = []
|
|
|
|
|
for chunk in response:
|
|
|
|
|
if isinstance(chunk, StreamTextGeneration):
|
|
|
|
|
chunk = cast(StreamTextGeneration, chunk)
|
|
|
|
|
if isinstance(chunk, StreamedChatResponse_TextGeneration):
|
|
|
|
|
chunk = cast(StreamedChatResponse_TextGeneration, chunk)
|
|
|
|
|
text = chunk.text
|
|
|
|
|
|
|
|
|
|
if text is None:
|
|
|
|
|
@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
content=text
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# stop
|
|
|
|
|
# notice: This logic can only cover few stop scenarios
|
|
|
|
|
if stop and text in stop:
|
|
|
|
|
yield final_response(full_assistant_content, index, 'stop')
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
full_assistant_content += text
|
|
|
|
|
|
|
|
|
|
yield LLMResultChunk(
|
|
|
|
|
@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
index += 1
|
|
|
|
|
elif isinstance(chunk, StreamEnd):
|
|
|
|
|
chunk = cast(StreamEnd, chunk)
|
|
|
|
|
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
|
|
|
|
|
elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
|
|
|
|
|
chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
|
|
|
|
|
|
|
|
|
|
tool_calls = []
|
|
|
|
|
if chunk.tool_calls:
|
|
|
|
|
for cohere_tool_call in chunk.tool_calls:
|
|
|
|
|
tool_call = AssistantPromptMessage.ToolCall(
|
|
|
|
|
id=cohere_tool_call.name,
|
|
|
|
|
type='function',
|
|
|
|
|
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
|
|
|
|
name=cohere_tool_call.name,
|
|
|
|
|
arguments=json.dumps(cohere_tool_call.parameters)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
tool_calls.append(tool_call)
|
|
|
|
|
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
|
|
|
|
|
chunk = cast(StreamedChatResponse_StreamEnd, chunk)
|
|
|
|
|
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
|
|
|
|
|
index += 1
|
|
|
|
|
|
|
|
|
|
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
|
|
|
|
|
-> tuple[str, list[dict]]:
|
|
|
|
|
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
|
|
|
|
|
"""
|
|
|
|
|
Convert prompt messages to message and chat histories
|
|
|
|
|
:param prompt_messages: prompt messages
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
chat_histories = []
|
|
|
|
|
latest_tool_call_n_outputs = []
|
|
|
|
|
for prompt_message in prompt_messages:
|
|
|
|
|
chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
|
|
|
|
|
if prompt_message.role == PromptMessageRole.ASSISTANT:
|
|
|
|
|
prompt_message = cast(AssistantPromptMessage, prompt_message)
|
|
|
|
|
if prompt_message.tool_calls:
|
|
|
|
|
for tool_call in prompt_message.tool_calls:
|
|
|
|
|
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
|
|
|
|
|
call=ToolCall(
|
|
|
|
|
name=tool_call.function.name,
|
|
|
|
|
parameters=json.loads(tool_call.function.arguments)
|
|
|
|
|
),
|
|
|
|
|
outputs=[]
|
|
|
|
|
))
|
|
|
|
|
else:
|
|
|
|
|
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
|
|
|
|
|
if cohere_prompt_message:
|
|
|
|
|
chat_histories.append(cohere_prompt_message)
|
|
|
|
|
elif prompt_message.role == PromptMessageRole.TOOL:
|
|
|
|
|
prompt_message = cast(ToolPromptMessage, prompt_message)
|
|
|
|
|
if latest_tool_call_n_outputs:
|
|
|
|
|
i = 0
|
|
|
|
|
for tool_call_n_outputs in latest_tool_call_n_outputs:
|
|
|
|
|
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
|
|
|
|
|
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
|
|
|
|
|
call=ToolCall(
|
|
|
|
|
name=tool_call_n_outputs.call.name,
|
|
|
|
|
parameters=tool_call_n_outputs.call.parameters
|
|
|
|
|
),
|
|
|
|
|
outputs=[{
|
|
|
|
|
"result": prompt_message.content
|
|
|
|
|
}]
|
|
|
|
|
)
|
|
|
|
|
break
|
|
|
|
|
i += 1
|
|
|
|
|
else:
|
|
|
|
|
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
|
|
|
|
|
if cohere_prompt_message:
|
|
|
|
|
chat_histories.append(cohere_prompt_message)
|
|
|
|
|
|
|
|
|
|
if latest_tool_call_n_outputs:
|
|
|
|
|
new_latest_tool_call_n_outputs = []
|
|
|
|
|
for tool_call_n_outputs in latest_tool_call_n_outputs:
|
|
|
|
|
if tool_call_n_outputs.outputs:
|
|
|
|
|
new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
|
|
|
|
|
|
|
|
|
|
latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
|
|
|
|
|
|
|
|
|
|
# get latest message from chat histories and pop it
|
|
|
|
|
if len(chat_histories) > 0:
|
|
|
|
|
latest_message = chat_histories.pop()
|
|
|
|
|
message = latest_message['message']
|
|
|
|
|
message = latest_message.message
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError('Prompt messages is empty')
|
|
|
|
|
|
|
|
|
|
return message, chat_histories
|
|
|
|
|
return message, chat_histories, latest_tool_call_n_outputs
|
|
|
|
|
|
|
|
|
|
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
|
|
|
|
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
|
|
|
|
|
"""
|
|
|
|
|
Convert PromptMessage to dict for Cohere model
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(message, UserPromptMessage):
|
|
|
|
|
message = cast(UserPromptMessage, message)
|
|
|
|
|
if isinstance(message.content, str):
|
|
|
|
|
message_dict = {"role": "USER", "message": message.content}
|
|
|
|
|
chat_message = ChatMessage(role="USER", message=message.content)
|
|
|
|
|
else:
|
|
|
|
|
sub_message_text = ''
|
|
|
|
|
for message_content in message.content:
|
|
|
|
|
@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
message_content = cast(TextPromptMessageContent, message_content)
|
|
|
|
|
sub_message_text += message_content.data
|
|
|
|
|
|
|
|
|
|
message_dict = {"role": "USER", "message": sub_message_text}
|
|
|
|
|
chat_message = ChatMessage(role="USER", message=sub_message_text)
|
|
|
|
|
elif isinstance(message, AssistantPromptMessage):
|
|
|
|
|
message = cast(AssistantPromptMessage, message)
|
|
|
|
|
message_dict = {"role": "CHATBOT", "message": message.content}
|
|
|
|
|
if not message.content:
|
|
|
|
|
return None
|
|
|
|
|
chat_message = ChatMessage(role="CHATBOT", message=message.content)
|
|
|
|
|
elif isinstance(message, SystemPromptMessage):
|
|
|
|
|
message = cast(SystemPromptMessage, message)
|
|
|
|
|
message_dict = {"role": "USER", "message": message.content}
|
|
|
|
|
chat_message = ChatMessage(role="USER", message=message.content)
|
|
|
|
|
elif isinstance(message, ToolPromptMessage):
|
|
|
|
|
return None
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Got unknown type {message}")
|
|
|
|
|
|
|
|
|
|
if message.name:
|
|
|
|
|
message_dict["user_name"] = message.name
|
|
|
|
|
return chat_message
|
|
|
|
|
|
|
|
|
|
def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
|
|
|
|
|
"""
|
|
|
|
|
Convert tools to Cohere model
|
|
|
|
|
"""
|
|
|
|
|
cohere_tools = []
|
|
|
|
|
for tool in tools:
|
|
|
|
|
properties = tool.parameters['properties']
|
|
|
|
|
required_properties = tool.parameters['required']
|
|
|
|
|
|
|
|
|
|
parameter_definitions = {}
|
|
|
|
|
for p_key, p_val in properties.items():
|
|
|
|
|
required = False
|
|
|
|
|
if property in required_properties:
|
|
|
|
|
required = True
|
|
|
|
|
|
|
|
|
|
desc = p_val['description']
|
|
|
|
|
if 'enum' in p_val:
|
|
|
|
|
desc += (f"; Only accepts one of the following predefined options: "
|
|
|
|
|
f"[{', '.join(p_val['enum'])}]")
|
|
|
|
|
|
|
|
|
|
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
|
|
|
|
|
description=desc,
|
|
|
|
|
type=p_val['type'],
|
|
|
|
|
required=required
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return message_dict
|
|
|
|
|
cohere_tool = Tool(
|
|
|
|
|
name=tool.name,
|
|
|
|
|
description=tool.description,
|
|
|
|
|
parameter_definitions=parameter_definitions
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
cohere_tools.append(cohere_tool)
|
|
|
|
|
|
|
|
|
|
return cohere_tools
|
|
|
|
|
|
|
|
|
|
def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
|
|
|
|
|
"""
|
|
|
|
|
@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
model=model
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return response.length
|
|
|
|
|
return len(response.tokens)
|
|
|
|
|
|
|
|
|
|
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
|
|
|
|
|
"""Calculate num tokens Cohere model."""
|
|
|
|
|
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
|
|
|
|
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
|
|
|
|
|
calc_messages = []
|
|
|
|
|
for message in messages:
|
|
|
|
|
cohere_message = self._convert_prompt_message_to_dict(message)
|
|
|
|
|
if cohere_message:
|
|
|
|
|
calc_messages.append(cohere_message)
|
|
|
|
|
message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
|
|
|
|
|
message_str = "\n".join(message_strs)
|
|
|
|
|
|
|
|
|
|
real_model = model
|
|
|
|
|
@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
|
|
|
|
|
"""
|
|
|
|
|
return {
|
|
|
|
|
InvokeConnectionError: [
|
|
|
|
|
cohere.CohereConnectionError
|
|
|
|
|
cohere.errors.service_unavailable_error.ServiceUnavailableError
|
|
|
|
|
],
|
|
|
|
|
InvokeServerUnavailableError: [
|
|
|
|
|
cohere.errors.internal_server_error.InternalServerError
|
|
|
|
|
],
|
|
|
|
|
InvokeRateLimitError: [
|
|
|
|
|
cohere.errors.too_many_requests_error.TooManyRequestsError
|
|
|
|
|
],
|
|
|
|
|
InvokeAuthorizationError: [
|
|
|
|
|
cohere.errors.unauthorized_error.UnauthorizedError,
|
|
|
|
|
cohere.errors.forbidden_error.ForbiddenError
|
|
|
|
|
],
|
|
|
|
|
InvokeServerUnavailableError: [],
|
|
|
|
|
InvokeRateLimitError: [],
|
|
|
|
|
InvokeAuthorizationError: [],
|
|
|
|
|
InvokeBadRequestError: [
|
|
|
|
|
cohere.CohereAPIError,
|
|
|
|
|
cohere.CohereError,
|
|
|
|
|
cohere.core.api_error.ApiError,
|
|
|
|
|
cohere.errors.bad_request_error.BadRequestError,
|
|
|
|
|
cohere.errors.not_found_error.NotFoundError,
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
|