1. LLM接口改为字节级流式接口,防止超时错误,也避免utf-8长字节字符拼接错误
This commit is contained in:
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Query Engine, with retry support.
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Generator
|
||||
from loguru import logger
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -82,6 +83,76 @@ class LLMClient:
|
||||
return self.validate_response(response.choices[0].message.content)
|
||||
return ""
|
||||
|
||||
@with_retry(LLM_RETRY_CONFIG)
|
||||
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
|
||||
"""
|
||||
流式调用LLM,逐步返回响应内容
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户提示词
|
||||
**kwargs: 额外参数(temperature, top_p等)
|
||||
|
||||
Yields:
|
||||
响应文本块(str)
|
||||
"""
|
||||
current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
|
||||
time_prefix = f"今天的实际时间是{current_time}"
|
||||
if user_prompt:
|
||||
user_prompt = f"{time_prefix}\n{user_prompt}"
|
||||
else:
|
||||
user_prompt = time_prefix
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
|
||||
extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
|
||||
# 强制使用流式
|
||||
extra_params["stream"] = True
|
||||
|
||||
timeout = kwargs.pop("timeout", self.timeout)
|
||||
|
||||
try:
|
||||
stream = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
timeout=timeout,
|
||||
**extra_params,
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if chunk.choices and len(chunk.choices) > 0:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta and delta.content:
|
||||
yield delta.content
|
||||
except Exception as e:
|
||||
logger.error(f"流式请求失败: {str(e)}")
|
||||
raise e
|
||||
|
||||
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
|
||||
"""
|
||||
流式调用LLM并安全地拼接为完整字符串(避免UTF-8多字节字符截断)
|
||||
|
||||
Args:
|
||||
system_prompt: 系统提示词
|
||||
user_prompt: 用户提示词
|
||||
**kwargs: 额外参数(temperature, top_p等)
|
||||
|
||||
Returns:
|
||||
完整的响应字符串
|
||||
"""
|
||||
# 以字节形式收集所有块
|
||||
byte_chunks = []
|
||||
for chunk in self.stream_invoke(system_prompt, user_prompt, **kwargs):
|
||||
byte_chunks.append(chunk.encode('utf-8'))
|
||||
|
||||
# 拼接所有字节,然后一次性解码
|
||||
if byte_chunks:
|
||||
return b''.join(byte_chunks).decode('utf-8', errors='replace')
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def validate_response(response: Optional[str]) -> str:
|
||||
if response is None:
|
||||
|
||||
Reference in New Issue
Block a user