Dev to Main: Refactor LLM Interface to Byte Stream for Improved Stability

Dev to Main: Refactor LLM Interface to Byte Stream for Improved Stability
This commit is contained in:
BaiFu
2025-11-07 15:52:28 +08:00
committed by GitHub
19 changed files with 315 additions and 40 deletions
+72 -1
View File
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Insight Engine, with retry support.
import os
import sys
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Iterator, Generator
from loguru import logger
from openai import OpenAI
@@ -82,6 +83,76 @@ class LLMClient:
return self.validate_response(response.choices[0].message.content)
return ""
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Call the LLM in streaming mode and yield the response text piece by piece.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; a line stating the current local time is
            prepended to it (or used alone when the prompt is empty).
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty; None values are dropped)
            and an optional ``timeout`` overriding ``self.timeout``.

    Yields:
        Non-empty response text fragments (str), in arrival order.

    Raises:
        Exception: Any error raised while creating or reading the stream is
            logged and re-raised unchanged.
    """
    # NOTE(review): this is a generator function, so calling it only creates
    # the generator object; the request below runs on the first next().
    # Unless with_retry is generator-aware, it never observes request
    # failures -- confirm against with_retry's implementation.

    # "%d日" separates the day from the hour; without it they run together
    # (e.g. "0715时") and the date becomes ambiguous for the model.
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Forward only the whitelisted sampling parameters that were actually set.
    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
    # Streaming is always forced on.
    extra_params["stream"] = True
    timeout = kwargs.get("timeout", self.timeout)

    stream = None
    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
        for chunk in stream:
            # Some chunks carry no choices; skip them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta and delta.content:
                yield delta.content
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release the underlying connection even when the consumer stops
        # iterating early or an error occurs. Guarded with getattr because
        # the client may return a plain iterator without close().
        close = getattr(stream, "close", None)
        if callable(close):
            close()
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Call the LLM in streaming mode and return the whole response as one string.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Extra parameters forwarded to ``stream_invoke``
            (temperature, top_p, etc.).

    Returns:
        The complete response text; an empty string when nothing was streamed.
    """
    # Chunks are already str, so they are joined as text first.
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    # A character outside the BMP can arrive as a surrogate pair split across
    # two chunks. Encoding each chunk to UTF-8 separately raises
    # UnicodeEncodeError on such halves. A UTF-16 round trip over the joined
    # text merges adjacent halves back into one character and replaces any
    # truly unpaired surrogate with U+FFFD instead of crashing.
    return text.encode("utf-16-le", errors="surrogatepass").decode("utf-16-le", errors="replace")
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
+2 -2
View File
@@ -70,8 +70,8 @@ class ReportFormattingNode(BaseNode):
logger.info("正在格式化最终报告")
# 调用LLM
response = self.llm_client.invoke(
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
+2 -2
View File
@@ -51,8 +51,8 @@ class ReportStructureNode(StateMutationNode):
try:
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
# 处理响应
processed_response = self.process_output(response)
+4 -4
View File
@@ -65,8 +65,8 @@ class FirstSearchNode(BaseNode):
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
# 处理响应
processed_response = self.process_output(response)
@@ -200,8 +200,8 @@ class ReflectionNode(BaseNode):
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
# 处理响应
processed_response = self.process_output(response)
+4 -4
View File
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
logger.info("正在生成首次段落总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SUMMARY, message)
# 处理响应
processed_response = self.process_output(response)
@@ -264,8 +264,8 @@ class ReflectionSummaryNode(StateMutationNode):
logger.info("正在生成反思总结")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
# 调用LLM(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
# 处理响应
processed_response = self.process_output(response)
+72 -1
View File
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Media Engine, with retry support.
import os
import sys
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Generator
from loguru import logger
from openai import OpenAI
@@ -85,6 +86,76 @@ class LLMClient:
return self.validate_response(response.choices[0].message.content)
return ""
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Call the LLM in streaming mode and yield the response text piece by piece.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; a line stating the current local time is
            prepended to it (or used alone when the prompt is empty).
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty; None values are dropped)
            and an optional ``timeout`` overriding ``self.timeout``.

    Yields:
        Non-empty response text fragments (str), in arrival order.

    Raises:
        Exception: Any error raised while creating or reading the stream is
            logged and re-raised unchanged.
    """
    # NOTE(review): this is a generator function, so calling it only creates
    # the generator object; the request below runs on the first next().
    # Unless with_retry is generator-aware, it never observes request
    # failures -- confirm against with_retry's implementation.

    # "%d日" separates the day from the hour; without it they run together
    # (e.g. "0715时") and the date becomes ambiguous for the model.
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Forward only the whitelisted sampling parameters that were actually set.
    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
    # Streaming is always forced on.
    extra_params["stream"] = True
    timeout = kwargs.get("timeout", self.timeout)

    stream = None
    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
        for chunk in stream:
            # Some chunks carry no choices; skip them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta and delta.content:
                yield delta.content
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release the underlying connection even when the consumer stops
        # iterating early or an error occurs. Guarded with getattr because
        # the client may return a plain iterator without close().
        close = getattr(stream, "close", None)
        if callable(close):
            close()
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Call the LLM in streaming mode and return the whole response as one string.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Extra parameters forwarded to ``stream_invoke``
            (temperature, top_p, etc.).

    Returns:
        The complete response text; an empty string when nothing was streamed.
    """
    # Chunks are already str, so they are joined as text first.
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    # A character outside the BMP can arrive as a surrogate pair split across
    # two chunks. Encoding each chunk to UTF-8 separately raises
    # UnicodeEncodeError on such halves. A UTF-16 round trip over the joined
    # text merges adjacent halves back into one character and replaces any
    # truly unpaired surrogate with U+FFFD instead of crashing.
    return text.encode("utf-16-le", errors="surrogatepass").decode("utf-16-le", errors="replace")
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
+2 -2
View File
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
# 调用LLM生成Markdown格式(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
+1 -1
View File
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
# 处理响应
processed_response = self.process_output(response)
+2 -2
View File
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
# 处理响应
processed_response = self.process_output(response)
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
# 处理响应
processed_response = self.process_output(response)
+4 -4
View File
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
logger.info("正在生成首次段落总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
# 调用LLM生成总结(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_FIRST_SUMMARY,
message,
)
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):
logger.info("正在生成反思总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
# 调用LLM生成总结(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REFLECTION_SUMMARY,
message,
)
+72 -1
View File
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Query Engine, with retry support.
import os
import sys
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Generator
from loguru import logger
from openai import OpenAI
@@ -82,6 +83,76 @@ class LLMClient:
return self.validate_response(response.choices[0].message.content)
return ""
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Call the LLM in streaming mode and yield the response text piece by piece.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; a line stating the current local time is
            prepended to it (or used alone when the prompt is empty).
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty; None values are dropped)
            and an optional ``timeout`` overriding ``self.timeout``.

    Yields:
        Non-empty response text fragments (str), in arrival order.

    Raises:
        Exception: Any error raised while creating or reading the stream is
            logged and re-raised unchanged.
    """
    # NOTE(review): this is a generator function, so calling it only creates
    # the generator object; the request below runs on the first next().
    # Unless with_retry is generator-aware, it never observes request
    # failures -- confirm against with_retry's implementation.

    # "%d日" separates the day from the hour; without it they run together
    # (e.g. "0715时") and the date becomes ambiguous for the model.
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Forward only the whitelisted sampling parameters that were actually set.
    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
    # Streaming is always forced on.
    extra_params["stream"] = True
    timeout = kwargs.get("timeout", self.timeout)

    stream = None
    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
        for chunk in stream:
            # Some chunks carry no choices; skip them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta and delta.content:
                yield delta.content
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release the underlying connection even when the consumer stops
        # iterating early or an error occurs. Guarded with getattr because
        # the client may return a plain iterator without close().
        close = getattr(stream, "close", None)
        if callable(close):
            close()
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Call the LLM in streaming mode and return the whole response as one string.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Extra parameters forwarded to ``stream_invoke``
            (temperature, top_p, etc.).

    Returns:
        The complete response text; an empty string when nothing was streamed.
    """
    # Chunks are already str, so they are joined as text first.
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    # A character outside the BMP can arrive as a surrogate pair split across
    # two chunks. Encoding each chunk to UTF-8 separately raises
    # UnicodeEncodeError on such halves. A UTF-16 round trip over the joined
    # text merges adjacent halves back into one character and replaces any
    # truly unpaired surrogate with U+FFFD instead of crashing.
    return text.encode("utf-16-le", errors="surrogatepass").decode("utf-16-le", errors="replace")
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
+2 -2
View File
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):
logger.info("正在格式化最终报告")
# 调用LLM生成Markdown格式
response = self.llm_client.invoke(
# 调用LLM生成Markdown格式(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REPORT_FORMATTING,
message,
)
+1 -1
View File
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
logger.info(f"正在为查询生成报告结构: {self.query}")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
# 处理响应
processed_response = self.process_output(response)
+2 -2
View File
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
logger.info("正在生成首次搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
# 处理响应
processed_response = self.process_output(response)
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
logger.info("正在进行反思并生成新搜索查询")
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
# 处理响应
processed_response = self.process_output(response)
+4 -4
View File
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
logger.info("正在生成首次段落总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
# 调用LLM生成总结(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_FIRST_SUMMARY,
message,
)
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):
logger.info("正在生成反思总结")
# 调用LLM生成总结
response = self.llm_client.invoke(
# 调用LLM生成总结(流式,安全拼接UTF-8
response = self.llm_client.stream_invoke_to_string(
SYSTEM_PROMPT_REFLECTION_SUMMARY,
message,
)
+66 -1
View File
@@ -4,7 +4,8 @@ Unified OpenAI-compatible LLM client for the Report Engine, with retry support.
import os
import sys
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Generator
from loguru import logger
from openai import OpenAI
@@ -75,6 +76,70 @@ class LLMClient:
return self.validate_response(response.choices[0].message.content)
return ""
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Call the LLM in streaming mode and yield the response text piece by piece.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt, sent as-is.
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty; None values are dropped)
            and an optional ``timeout`` overriding ``self.timeout``.

    Yields:
        Non-empty response text fragments (str), in arrival order.

    Raises:
        Exception: Any error raised while creating or reading the stream is
            logged and re-raised unchanged.
    """
    # NOTE(review): this is a generator function, so calling it only creates
    # the generator object; the request below runs on the first next().
    # Unless with_retry is generator-aware, it never observes request
    # failures -- confirm against with_retry's implementation.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Forward only the whitelisted sampling parameters that were actually set.
    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {key: value for key, value in kwargs.items() if key in allowed_keys and value is not None}
    # Streaming is always forced on.
    extra_params["stream"] = True
    timeout = kwargs.get("timeout", self.timeout)

    stream = None
    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
        for chunk in stream:
            # Some chunks carry no choices; skip them.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if delta and delta.content:
                yield delta.content
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        # Bare raise keeps the original traceback intact.
        raise
    finally:
        # Release the underlying connection even when the consumer stops
        # iterating early or an error occurs. Guarded with getattr because
        # the client may return a plain iterator without close().
        close = getattr(stream, "close", None)
        if callable(close):
            close()
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Call the LLM in streaming mode and return the whole response as one string.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Extra parameters forwarded to ``stream_invoke``
            (temperature, top_p, etc.).

    Returns:
        The complete response text; an empty string when nothing was streamed.
    """
    # Chunks are already str, so they are joined as text first.
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    # A character outside the BMP can arrive as a surrogate pair split across
    # two chunks. Encoding each chunk to UTF-8 separately raises
    # UnicodeEncodeError on such halves. A UTF-16 round trip over the joined
    # text merges adjacent halves back into one character and replaces any
    # truly unpaired surrogate with U+FFFD instead of crashing.
    return text.encode("utf-16-le", errors="surrogatepass").decode("utf-16-le", errors="replace")
@staticmethod
def validate_response(response: Optional[str]) -> str:
if response is None:
+1 -1
View File
@@ -60,7 +60,7 @@ class HTMLGenerationNode(StateMutationNode):
message = json.dumps(llm_input, ensure_ascii=False, indent=2)
# 调用LLM生成HTML
response = self.llm_client.invoke(SYSTEM_PROMPT_HTML_GENERATION, message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_HTML_GENERATION, message)
# 处理响应(简化版)
processed_response = self.process_output(response)
@@ -115,7 +115,7 @@ class TemplateSelectionNode(BaseNode):
请根据查询内容报告内容和论坛日志的具体情况选择最合适的模板"""
# 调用LLM
response = self.llm_client.invoke(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
# 检查响应是否为空
if not response or not response.strip():
+1 -4
View File
@@ -6,10 +6,7 @@ Forum日志读取工具
import re
from pathlib import Path
from typing import Optional, List, Dict
import logging
logger = logging.getLogger(__name__)
from loguru import logger
def get_latest_host_speech(log_dir: str = "logs") -> Optional[str]:
"""