1. LLM接口改为字节级流式接口,防止超时错误,也避免utf-8长字节字符拼接错误
This commit is contained in:
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Insight Engine, with retry support.
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Iterator, Generator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -82,6 +83,76 @@ class LLMClient:
|
|||||||
return self.validate_response(response.choices[0].message.content)
|
return self.validate_response(response.choices[0].message.content)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Open a streaming chat completion and return a generator of text chunks.

    The request is issued eagerly, when this method is called, and only the
    consumption of the response is lazy. A generator *function* would defer
    the request until the first ``next()``, i.e. until after the decorated
    call has already returned, so failures opening the stream would happen
    outside the decorator.
    NOTE(review): assumes ``with_retry`` retries by re-invoking the wrapped
    callable when it raises -- confirm against its definition. Errors raised
    while iterating the returned generator are not retried here, because
    chunks may already have been handed to the consumer.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; the current local time is prepended to it.
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty) and an optional ``timeout``
            that overrides ``self.timeout``. Other keys are ignored.

    Returns:
        A generator yielding the non-empty text deltas (``str``) in order.

    Raises:
        Exception: whatever the client raises, after it has been logged.
    """
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {
        key: value
        for key, value in kwargs.items()
        if key in allowed_keys and value is not None
    }
    # Streaming is mandatory for this entry point.
    extra_params["stream"] = True

    # Read-only lookup: there is no reason to mutate the caller's kwargs.
    timeout = kwargs.get("timeout", self.timeout)

    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        raise

    def _iter_text() -> Generator[str, None, None]:
        """Yield the text of every chunk that carries content."""
        try:
            for chunk in stream:
                # Some chunks carry no choices at all; skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except Exception as e:
            logger.error(f"流式请求失败: {str(e)}")
            raise

    return _iter_text()
|
||||||
|
|
||||||
|
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Run a streaming LLM call and return the complete reply as one string.

    The chunks produced by ``stream_invoke`` are already decoded ``str``
    objects, so they are concatenated directly; encoding each chunk to bytes
    and decoding the result again cannot repair anything and raises
    ``UnicodeEncodeError`` on a lone surrogate. The one way ``str`` chunks
    can be "cut in half" is a UTF-16 surrogate pair split across two chunks
    (presumably possible when a backend emits escaped surrogates -- not
    verifiable from this file). That case is repaired after joining.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Forwarded unchanged to ``stream_invoke``.

    Returns:
        The full response text, or ``""`` when nothing was streamed.
    """
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    try:
        # Cheap validity probe: fails only if surrogate code points are present.
        text.encode("utf-8")
    except UnicodeEncodeError:
        # A UTF-16 round trip merges adjacent high/low surrogates into the
        # real character and turns unpaired ones into U+FFFD.
        text = text.encode("utf-16-le", "surrogatepass").decode("utf-16-le", "replace")
    return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_response(response: Optional[str]) -> str:
|
def validate_response(response: Optional[str]) -> str:
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ class ReportFormattingNode(BaseNode):
|
|||||||
|
|
||||||
logger.info("正在格式化最终报告")
|
logger.info("正在格式化最终报告")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_REPORT_FORMATTING,
|
SYSTEM_PROMPT_REPORT_FORMATTING,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ class ReportStructureNode(StateMutationNode):
|
|||||||
try:
|
try:
|
||||||
logger.info(f"正在为查询生成报告结构: {self.query}")
|
logger.info(f"正在为查询生成报告结构: {self.query}")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -65,8 +65,8 @@ class FirstSearchNode(BaseNode):
|
|||||||
|
|
||||||
logger.info("正在生成首次搜索查询")
|
logger.info("正在生成首次搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
@@ -200,8 +200,8 @@ class ReflectionNode(BaseNode):
|
|||||||
|
|
||||||
logger.info("正在进行反思并生成新搜索查询")
|
logger.info("正在进行反思并生成新搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成首次段落总结")
|
logger.info("正在生成首次段落总结")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SUMMARY, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SUMMARY, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
@@ -264,8 +264,8 @@ class ReflectionSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成反思总结")
|
logger.info("正在生成反思总结")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION_SUMMARY, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Media Engine, with retry support.
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Generator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -85,6 +86,76 @@ class LLMClient:
|
|||||||
return self.validate_response(response.choices[0].message.content)
|
return self.validate_response(response.choices[0].message.content)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Open a streaming chat completion and return a generator of text chunks.

    The request is issued eagerly, when this method is called, and only the
    consumption of the response is lazy. A generator *function* would defer
    the request until the first ``next()``, i.e. until after the decorated
    call has already returned, so failures opening the stream would happen
    outside the decorator.
    NOTE(review): assumes ``with_retry`` retries by re-invoking the wrapped
    callable when it raises -- confirm against its definition. Errors raised
    while iterating the returned generator are not retried here, because
    chunks may already have been handed to the consumer.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; the current local time is prepended to it.
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty) and an optional ``timeout``
            that overrides ``self.timeout``. Other keys are ignored.

    Returns:
        A generator yielding the non-empty text deltas (``str``) in order.

    Raises:
        Exception: whatever the client raises, after it has been logged.
    """
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {
        key: value
        for key, value in kwargs.items()
        if key in allowed_keys and value is not None
    }
    # Streaming is mandatory for this entry point.
    extra_params["stream"] = True

    # Read-only lookup: there is no reason to mutate the caller's kwargs.
    timeout = kwargs.get("timeout", self.timeout)

    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        raise

    def _iter_text() -> Generator[str, None, None]:
        """Yield the text of every chunk that carries content."""
        try:
            for chunk in stream:
                # Some chunks carry no choices at all; skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except Exception as e:
            logger.error(f"流式请求失败: {str(e)}")
            raise

    return _iter_text()
|
||||||
|
|
||||||
|
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Run a streaming LLM call and return the complete reply as one string.

    The chunks produced by ``stream_invoke`` are already decoded ``str``
    objects, so they are concatenated directly; encoding each chunk to bytes
    and decoding the result again cannot repair anything and raises
    ``UnicodeEncodeError`` on a lone surrogate. The one way ``str`` chunks
    can be "cut in half" is a UTF-16 surrogate pair split across two chunks
    (presumably possible when a backend emits escaped surrogates -- not
    verifiable from this file). That case is repaired after joining.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Forwarded unchanged to ``stream_invoke``.

    Returns:
        The full response text, or ``""`` when nothing was streamed.
    """
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    try:
        # Cheap validity probe: fails only if surrogate code points are present.
        text.encode("utf-8")
    except UnicodeEncodeError:
        # A UTF-16 round trip merges adjacent high/low surrogates into the
        # real character and turns unpaired ones into U+FFFD.
        text = text.encode("utf-16-le", "surrogatepass").decode("utf-16-le", "replace")
    return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_response(response: Optional[str]) -> str:
|
def validate_response(response: Optional[str]) -> str:
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|||||||
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):
|
|||||||
|
|
||||||
logger.info("正在格式化最终报告")
|
logger.info("正在格式化最终报告")
|
||||||
|
|
||||||
# 调用LLM生成Markdown格式
|
# 调用LLM生成Markdown格式(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_REPORT_FORMATTING,
|
SYSTEM_PROMPT_REPORT_FORMATTING,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
|
|||||||
logger.info(f"正在为查询生成报告结构: {self.query}")
|
logger.info(f"正在为查询生成报告结构: {self.query}")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
|
|||||||
logger.info("正在生成首次搜索查询")
|
logger.info("正在生成首次搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
|
|||||||
logger.info("正在进行反思并生成新搜索查询")
|
logger.info("正在进行反思并生成新搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成首次段落总结")
|
logger.info("正在生成首次段落总结")
|
||||||
|
|
||||||
# 调用LLM生成总结
|
# 调用LLM生成总结(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_FIRST_SUMMARY,
|
SYSTEM_PROMPT_FIRST_SUMMARY,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成反思总结")
|
logger.info("正在生成反思总结")
|
||||||
|
|
||||||
# 调用LLM生成总结
|
# 调用LLM生成总结(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_REFLECTION_SUMMARY,
|
SYSTEM_PROMPT_REFLECTION_SUMMARY,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ Unified OpenAI-compatible LLM client for the Query Engine, with retry support.
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Generator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -82,6 +83,76 @@ class LLMClient:
|
|||||||
return self.validate_response(response.choices[0].message.content)
|
return self.validate_response(response.choices[0].message.content)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Open a streaming chat completion and return a generator of text chunks.

    The request is issued eagerly, when this method is called, and only the
    consumption of the response is lazy. A generator *function* would defer
    the request until the first ``next()``, i.e. until after the decorated
    call has already returned, so failures opening the stream would happen
    outside the decorator.
    NOTE(review): assumes ``with_retry`` retries by re-invoking the wrapped
    callable when it raises -- confirm against its definition. Errors raised
    while iterating the returned generator are not retried here, because
    chunks may already have been handed to the consumer.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt; the current local time is prepended to it.
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty) and an optional ``timeout``
            that overrides ``self.timeout``. Other keys are ignored.

    Returns:
        A generator yielding the non-empty text deltas (``str``) in order.

    Raises:
        Exception: whatever the client raises, after it has been logged.
    """
    current_time = datetime.now().strftime("%Y年%m月%d日%H时%M分")
    time_prefix = f"今天的实际时间是{current_time}"
    user_prompt = f"{time_prefix}\n{user_prompt}" if user_prompt else time_prefix
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {
        key: value
        for key, value in kwargs.items()
        if key in allowed_keys and value is not None
    }
    # Streaming is mandatory for this entry point.
    extra_params["stream"] = True

    # Read-only lookup: there is no reason to mutate the caller's kwargs.
    timeout = kwargs.get("timeout", self.timeout)

    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        raise

    def _iter_text() -> Generator[str, None, None]:
        """Yield the text of every chunk that carries content."""
        try:
            for chunk in stream:
                # Some chunks carry no choices at all; skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except Exception as e:
            logger.error(f"流式请求失败: {str(e)}")
            raise

    return _iter_text()
|
||||||
|
|
||||||
|
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Run a streaming LLM call and return the complete reply as one string.

    The chunks produced by ``stream_invoke`` are already decoded ``str``
    objects, so they are concatenated directly; encoding each chunk to bytes
    and decoding the result again cannot repair anything and raises
    ``UnicodeEncodeError`` on a lone surrogate. The one way ``str`` chunks
    can be "cut in half" is a UTF-16 surrogate pair split across two chunks
    (presumably possible when a backend emits escaped surrogates -- not
    verifiable from this file). That case is repaired after joining.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Forwarded unchanged to ``stream_invoke``.

    Returns:
        The full response text, or ``""`` when nothing was streamed.
    """
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    try:
        # Cheap validity probe: fails only if surrogate code points are present.
        text.encode("utf-8")
    except UnicodeEncodeError:
        # A UTF-16 round trip merges adjacent high/low surrogates into the
        # real character and turns unpaired ones into U+FFFD.
        text = text.encode("utf-16-le", "surrogatepass").decode("utf-16-le", "replace")
    return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_response(response: Optional[str]) -> str:
|
def validate_response(response: Optional[str]) -> str:
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|||||||
@@ -68,8 +68,8 @@ class ReportFormattingNode(BaseNode):
|
|||||||
|
|
||||||
logger.info("正在格式化最终报告")
|
logger.info("正在格式化最终报告")
|
||||||
|
|
||||||
# 调用LLM生成Markdown格式
|
# 调用LLM生成Markdown格式(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_REPORT_FORMATTING,
|
SYSTEM_PROMPT_REPORT_FORMATTING,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ class ReportStructureNode(StateMutationNode):
|
|||||||
logger.info(f"正在为查询生成报告结构: {self.query}")
|
logger.info(f"正在为查询生成报告结构: {self.query}")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REPORT_STRUCTURE, self.query)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class FirstSearchNode(BaseNode):
|
|||||||
logger.info("正在生成首次搜索查询")
|
logger.info("正在生成首次搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_FIRST_SEARCH, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
@@ -201,7 +201,7 @@ class ReflectionNode(BaseNode):
|
|||||||
logger.info("正在进行反思并生成新搜索查询")
|
logger.info("正在进行反思并生成新搜索查询")
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_REFLECTION, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_REFLECTION, message)
|
||||||
|
|
||||||
# 处理响应
|
# 处理响应
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ class FirstSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成首次段落总结")
|
logger.info("正在生成首次段落总结")
|
||||||
|
|
||||||
# 调用LLM生成总结
|
# 调用LLM生成总结(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_FIRST_SUMMARY,
|
SYSTEM_PROMPT_FIRST_SUMMARY,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
@@ -267,8 +267,8 @@ class ReflectionSummaryNode(StateMutationNode):
|
|||||||
|
|
||||||
logger.info("正在生成反思总结")
|
logger.info("正在生成反思总结")
|
||||||
|
|
||||||
# 调用LLM生成总结
|
# 调用LLM生成总结(流式,安全拼接UTF-8)
|
||||||
response = self.llm_client.invoke(
|
response = self.llm_client.stream_invoke_to_string(
|
||||||
SYSTEM_PROMPT_REFLECTION_SUMMARY,
|
SYSTEM_PROMPT_REFLECTION_SUMMARY,
|
||||||
message,
|
message,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ Unified OpenAI-compatible LLM client for the Report Engine, with retry support.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional, Generator
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -75,6 +76,70 @@ class LLMClient:
|
|||||||
return self.validate_response(response.choices[0].message.content)
|
return self.validate_response(response.choices[0].message.content)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@with_retry(LLM_RETRY_CONFIG)
def stream_invoke(self, system_prompt: str, user_prompt: str, **kwargs) -> Generator[str, None, None]:
    """
    Open a streaming chat completion and return a generator of text chunks.

    The request is issued eagerly, when this method is called, and only the
    consumption of the response is lazy. A generator *function* would defer
    the request until the first ``next()``, i.e. until after the decorated
    call has already returned, so failures opening the stream would happen
    outside the decorator.
    NOTE(review): assumes ``with_retry`` retries by re-invoking the wrapped
    callable when it raises -- confirm against its definition. Errors raised
    while iterating the returned generator are not retried here, because
    chunks may already have been handed to the consumer.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt, sent as-is.
        **kwargs: Optional sampling parameters (temperature, top_p,
            presence_penalty, frequency_penalty) and an optional ``timeout``
            that overrides ``self.timeout``. Other keys are ignored.

    Returns:
        A generator yielding the non-empty text deltas (``str``) in order.

    Raises:
        Exception: whatever the client raises, after it has been logged.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    allowed_keys = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
    extra_params = {
        key: value
        for key, value in kwargs.items()
        if key in allowed_keys and value is not None
    }
    # Streaming is mandatory for this entry point.
    extra_params["stream"] = True

    # Read-only lookup: there is no reason to mutate the caller's kwargs.
    timeout = kwargs.get("timeout", self.timeout)

    try:
        stream = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            timeout=timeout,
            **extra_params,
        )
    except Exception as e:
        logger.error(f"流式请求失败: {str(e)}")
        raise

    def _iter_text() -> Generator[str, None, None]:
        """Yield the text of every chunk that carries content."""
        try:
            for chunk in stream:
                # Some chunks carry no choices at all; skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta and delta.content:
                    yield delta.content
        except Exception as e:
            logger.error(f"流式请求失败: {str(e)}")
            raise

    return _iter_text()
|
||||||
|
|
||||||
|
def stream_invoke_to_string(self, system_prompt: str, user_prompt: str, **kwargs) -> str:
    """
    Run a streaming LLM call and return the complete reply as one string.

    The chunks produced by ``stream_invoke`` are already decoded ``str``
    objects, so they are concatenated directly; encoding each chunk to bytes
    and decoding the result again cannot repair anything and raises
    ``UnicodeEncodeError`` on a lone surrogate. The one way ``str`` chunks
    can be "cut in half" is a UTF-16 surrogate pair split across two chunks
    (presumably possible when a backend emits escaped surrogates -- not
    verifiable from this file). That case is repaired after joining.

    Args:
        system_prompt: System prompt.
        user_prompt: User prompt.
        **kwargs: Forwarded unchanged to ``stream_invoke``.

    Returns:
        The full response text, or ``""`` when nothing was streamed.
    """
    text = "".join(self.stream_invoke(system_prompt, user_prompt, **kwargs))
    try:
        # Cheap validity probe: fails only if surrogate code points are present.
        text.encode("utf-8")
    except UnicodeEncodeError:
        # A UTF-16 round trip merges adjacent high/low surrogates into the
        # real character and turns unpaired ones into U+FFFD.
        text = text.encode("utf-16-le", "surrogatepass").decode("utf-16-le", "replace")
    return text
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_response(response: Optional[str]) -> str:
|
def validate_response(response: Optional[str]) -> str:
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class HTMLGenerationNode(StateMutationNode):
|
|||||||
message = json.dumps(llm_input, ensure_ascii=False, indent=2)
|
message = json.dumps(llm_input, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 调用LLM生成HTML
|
# 调用LLM生成HTML
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_HTML_GENERATION, message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_HTML_GENERATION, message)
|
||||||
|
|
||||||
# 处理响应(简化版)
|
# 处理响应(简化版)
|
||||||
processed_response = self.process_output(response)
|
processed_response = self.process_output(response)
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class TemplateSelectionNode(BaseNode):
|
|||||||
请根据查询内容、报告内容和论坛日志的具体情况,选择最合适的模板。"""
|
请根据查询内容、报告内容和论坛日志的具体情况,选择最合适的模板。"""
|
||||||
|
|
||||||
# 调用LLM
|
# 调用LLM
|
||||||
response = self.llm_client.invoke(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
|
response = self.llm_client.stream_invoke_to_string(SYSTEM_PROMPT_TEMPLATE_SELECTION, user_message)
|
||||||
|
|
||||||
# 检查响应是否为空
|
# 检查响应是否为空
|
||||||
if not response or not response.strip():
|
if not response or not response.strip():
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ Forum日志读取工具
|
|||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
import logging
|
from loguru import logger
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_latest_host_speech(log_dir: str = "logs") -> Optional[str]:
|
def get_latest_host_speech(log_dir: str = "logs") -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user