|
|
|
|
@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
|
|
|
|
from openai.types.chat.chat_completion_message import FunctionCall
|
|
|
|
|
|
|
|
|
|
from core.model_runtime.callbacks.base_callback import Callback
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
|
|
from core.model_runtime.entities.common_entities import I18nObject
|
|
|
|
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
|
|
|
|
from core.model_runtime.entities.message_entities import (
|
|
|
|
|
AssistantPromptMessage,
|
|
|
|
|
ImagePromptMessageContent,
|
|
|
|
|
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
|
|
|
|
|
ToolPromptMessage,
|
|
|
|
|
UserPromptMessage,
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.entities.model_entities import (
|
|
|
|
|
AIModelEntity,
|
|
|
|
|
FetchFrom,
|
|
|
|
|
ModelFeature,
|
|
|
|
|
ModelPropertyKey,
|
|
|
|
|
ModelType,
|
|
|
|
|
ParameterRule,
|
|
|
|
|
ParameterType,
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
|
|
|
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
|
|
|
|
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
|
|
|
|
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
|
|
|
|
|
|
|
|
|
return num_tokens
|
|
|
|
|
|
|
|
|
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
|
|
|
|
return AIModelEntity(
|
|
|
|
|
model=model,
|
|
|
|
|
label=I18nObject(
|
|
|
|
|
en_US=credentials.get("model_label_en_US", model),
|
|
|
|
|
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
|
|
|
|
),
|
|
|
|
|
model_type=ModelType.LLM,
|
|
|
|
|
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
|
|
|
|
if credentials.get("function_calling_type") == "function_call"
|
|
|
|
|
else [],
|
|
|
|
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
|
|
|
|
model_properties={
|
|
|
|
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
|
|
|
|
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
|
|
|
|
},
|
|
|
|
|
parameter_rules=[
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name="temperature",
|
|
|
|
|
use_template="temperature",
|
|
|
|
|
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
|
|
|
|
type=ParameterType.FLOAT,
|
|
|
|
|
),
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name="max_tokens",
|
|
|
|
|
use_template="max_tokens",
|
|
|
|
|
default=512,
|
|
|
|
|
min=1,
|
|
|
|
|
max=int(credentials.get("max_tokens", 4096)),
|
|
|
|
|
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
|
|
|
|
type=ParameterType.INT,
|
|
|
|
|
),
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name="top_p",
|
|
|
|
|
use_template="top_p",
|
|
|
|
|
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
|
|
|
|
type=ParameterType.FLOAT,
|
|
|
|
|
),
|
|
|
|
|
ParameterRule(
|
|
|
|
|
name="top_k",
|
|
|
|
|
use_template="top_k",
|
|
|
|
|
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
|
|
|
|
type=ParameterType.FLOAT,
|
|
|
|
|
),
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|