fix: better gard nan value from numpy for issue #11827 (#11864)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
pull/11879/head
yihong 1 year ago committed by GitHub
parent 95a7e50137
commit 463fbe2680
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

@ -116,6 +116,8 @@ class CacheEmbedding(Embeddings):
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")

Loading…
Cancel
Save