|
|
|
|
@ -16,9 +16,13 @@ import websocket
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparkLLMClient:
|
|
|
|
|
def __init__(self, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
|
|
|
|
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
|
|
|
|
|
|
|
|
|
|
self.api_base = "wss://spark-api.xf-yun.com/v1.1/chat" if not api_domain else ('wss://' + api_domain + '/v1.1/chat')
|
|
|
|
|
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
|
|
|
|
|
api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
|
|
|
|
|
|
|
|
|
|
self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
|
|
|
|
|
self.api_base = f"wss://{domain}/{api_version}/chat"
|
|
|
|
|
self.app_id = app_id
|
|
|
|
|
self.ws_url = self.create_url(
|
|
|
|
|
urlparse(self.api_base).netloc,
|
|
|
|
|
@ -76,7 +80,10 @@ class SparkLLMClient:
|
|
|
|
|
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
|
|
|
|
|
|
|
|
|
def on_error(self, ws, error):
|
|
|
|
|
self.queue.put({'error': error})
|
|
|
|
|
self.queue.put({
|
|
|
|
|
'status_code': error.status_code,
|
|
|
|
|
'error': error.resp_body.decode('utf-8')
|
|
|
|
|
})
|
|
|
|
|
ws.close()
|
|
|
|
|
|
|
|
|
|
def on_close(self, ws, close_status_code, close_reason):
|
|
|
|
|
@ -120,7 +127,7 @@ class SparkLLMClient:
|
|
|
|
|
},
|
|
|
|
|
"parameter": {
|
|
|
|
|
"chat": {
|
|
|
|
|
"domain": "general"
|
|
|
|
|
"domain": self.chat_domain
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
"payload": {
|
|
|
|
|
@ -139,7 +146,14 @@ class SparkLLMClient:
|
|
|
|
|
while True:
|
|
|
|
|
content = self.queue.get()
|
|
|
|
|
if 'error' in content:
|
|
|
|
|
raise SparkError(content['error'])
|
|
|
|
|
if content['status_code'] == 401:
|
|
|
|
|
raise SparkError('[Spark] The credentials you provided are incorrect. '
|
|
|
|
|
'Please double-check and fill them in again.')
|
|
|
|
|
elif content['status_code'] == 403:
|
|
|
|
|
raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
|
|
|
|
|
"Please try again after obtaining the necessary permissions.")
|
|
|
|
|
else:
|
|
|
|
|
raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
|
|
|
|
|
|
|
|
|
|
if 'data' not in content:
|
|
|
|
|
break
|
|
|
|
|
|