From 5d7f41763f7efe624531c46233e5f604f7593cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=88=92=E9=85=92=E7=9A=84=E6=9D=8E=E7=99=BD?= <670939375@qq.com> Date: Fri, 22 Aug 2025 19:14:32 +0800 Subject: [PATCH] JSON parsing fix. --- src/nodes/report_structure_node.py | 74 ++++++++++++++++---- src/nodes/search_node.py | 87 +++++++++++++++++++---- src/nodes/summary_node.py | 51 +++++++++++--- src/utils/text_processing.py | 108 +++++++++++++++++++++++++++++ streamlit_app.py | 2 +- 5 files changed, 283 insertions(+), 39 deletions(-) diff --git a/src/nodes/report_structure_node.py b/src/nodes/report_structure_node.py index 5751a69..87632da 100644 --- a/src/nodes/report_structure_node.py +++ b/src/nodes/report_structure_node.py @@ -13,7 +13,8 @@ from ..prompts import SYSTEM_PROMPT_REPORT_STRUCTURE from ..utils.text_processing import ( remove_reasoning_from_output, clean_json_tags, - extract_clean_response + extract_clean_response, + fix_incomplete_json ) @@ -77,48 +78,91 @@ class ReportStructureNode(StateMutationNode): cleaned_output = remove_reasoning_from_output(output) cleaned_output = clean_json_tags(cleaned_output) + # 记录清理后的输出用于调试 + self.log_info(f"清理后的输出: {cleaned_output[:200]}...") + # 解析JSON try: report_structure = json.loads(cleaned_output) - except JSONDecodeError: + self.log_info("JSON解析成功") + except JSONDecodeError as e: + self.log_info(f"JSON解析失败: {str(e)}") # 使用更强大的提取方法 report_structure = extract_clean_response(cleaned_output) if "error" in report_structure: - raise ValueError("JSON解析失败") + self.log_error("JSON解析失败,尝试修复...") + # 尝试修复JSON + fixed_json = fix_incomplete_json(cleaned_output) + if fixed_json: + try: + report_structure = json.loads(fixed_json) + self.log_info("JSON修复成功") + except JSONDecodeError: + self.log_error("JSON修复失败") + # 返回默认结构 + return self._generate_default_structure() + else: + self.log_error("无法修复JSON,使用默认结构") + return self._generate_default_structure() # 验证结构 if not isinstance(report_structure, list): - raise ValueError("报告结构应该是一个列表") + self.log_info("报告结构不是列表,尝试转换...") + if isinstance(report_structure, dict): + # 如果是单个对象,包装成列表 + report_structure = [report_structure] + else: + self.log_error("报告结构格式无效,使用默认结构") + return self._generate_default_structure() # 验证每个段落 validated_structure = [] for i, paragraph in enumerate(report_structure): if not isinstance(paragraph, dict): + self.log_warning(f"段落 {i+1} 不是字典格式,跳过") continue title = paragraph.get("title", f"段落 {i+1}") content = paragraph.get("content", "") + if not title or not content: + self.log_warning(f"段落 {i+1} 缺少标题或内容,跳过") + continue + validated_structure.append({ "title": title, "content": content }) + if not validated_structure: + self.log_warning("没有有效的段落结构,使用默认结构") + return self._generate_default_structure() + + self.log_info(f"成功验证 {len(validated_structure)} 个段落结构") return validated_structure except Exception as e: self.log_error(f"处理输出失败: {str(e)}") - # 返回默认结构 - return [ - { - "title": "概述", - "content": f"对'{self.query}'的总体概述和背景介绍" - }, - { - "title": "详细分析", - "content": f"深入分析'{self.query}'的相关内容" - } - ] + return self._generate_default_structure() + + def _generate_default_structure(self) -> List[Dict[str, str]]: + """ + 生成默认的报告结构 + + Returns: + 默认的报告结构列表 + """ + self.log_info("生成默认报告结构") + return [ + { + "title": "研究概述", + "content": "对查询主题进行总体概述和分析" + }, + { + "title": "深度分析", + "content": "深入分析查询主题的各个方面" + } + ] def mutate_state(self, input_data: Any = None, state: State = None, **kwargs) -> State: """ diff --git a/src/nodes/search_node.py b/src/nodes/search_node.py index 2bfa29c..52cc17d 100644 --- a/src/nodes/search_node.py +++ b/src/nodes/search_node.py @@ -12,7 +12,8 @@ from ..prompts import SYSTEM_PROMPT_FIRST_SEARCH, SYSTEM_PROMPT_REFLECTION from ..utils.text_processing import ( remove_reasoning_from_output, clean_json_tags, - extract_clean_response + extract_clean_response, + fix_incomplete_json ) @@ -91,21 +92,40 @@ class FirstSearchNode(BaseNode): cleaned_output = remove_reasoning_from_output(output) cleaned_output = clean_json_tags(cleaned_output) + # 记录清理后的输出用于调试 + self.log_info(f"清理后的输出: {cleaned_output[:200]}...") + # 解析JSON try: result = json.loads(cleaned_output) - except JSONDecodeError: + self.log_info("JSON解析成功") + except JSONDecodeError as e: + self.log_info(f"JSON解析失败: {str(e)}") # 使用更强大的提取方法 result = extract_clean_response(cleaned_output) if "error" in result: - raise ValueError("JSON解析失败") + self.log_error("JSON解析失败,尝试修复...") + # 尝试修复JSON + fixed_json = fix_incomplete_json(cleaned_output) + if fixed_json: + try: + result = json.loads(fixed_json) + self.log_info("JSON修复成功") + except JSONDecodeError: + self.log_error("JSON修复失败") + # 返回默认查询 + return self._get_default_search_query() + else: + self.log_error("无法修复JSON,使用默认查询") + return self._get_default_search_query() # 验证和清理结果 search_query = result.get("search_query", "") reasoning = result.get("reasoning", "") if not search_query: - raise ValueError("未找到搜索查询") + self.log_warning("未找到搜索查询,使用默认查询") + return self._get_default_search_query() return { "search_query": search_query, @@ -115,10 +135,19 @@ class FirstSearchNode(BaseNode): except Exception as e: self.log_error(f"处理输出失败: {str(e)}") # 返回默认查询 - return { - "search_query": "相关主题研究", - "reasoning": "由于解析失败,使用默认搜索查询" - } + return self._get_default_search_query() + + def _get_default_search_query(self) -> Dict[str, str]: + """ + 获取默认搜索查询 + + Returns: + 默认的搜索查询字典 + """ + return { + "search_query": "相关主题研究", + "reasoning": "由于解析失败,使用默认搜索查询" + } class ReflectionNode(BaseNode): @@ -198,21 +227,40 @@ class ReflectionNode(BaseNode): cleaned_output = remove_reasoning_from_output(output) cleaned_output = clean_json_tags(cleaned_output) + # 记录清理后的输出用于调试 + self.log_info(f"清理后的输出: {cleaned_output[:200]}...") + # 解析JSON try: result = json.loads(cleaned_output) - except JSONDecodeError: + self.log_info("JSON解析成功") + except JSONDecodeError as e: + self.log_info(f"JSON解析失败: {str(e)}") # 使用更强大的提取方法 result = extract_clean_response(cleaned_output) if "error" in result: - raise ValueError("JSON解析失败") + self.log_error("JSON解析失败,尝试修复...") + # 尝试修复JSON + fixed_json = fix_incomplete_json(cleaned_output) + if fixed_json: + try: + result = json.loads(fixed_json) + self.log_info("JSON修复成功") + except JSONDecodeError: + self.log_error("JSON修复失败") + # 返回默认查询 + return self._get_default_reflection_query() + else: + self.log_error("无法修复JSON,使用默认查询") + return self._get_default_reflection_query() # 验证和清理结果 search_query = result.get("search_query", "") reasoning = result.get("reasoning", "") if not search_query: - raise ValueError("未找到搜索查询") + self.log_warning("未找到搜索查询,使用默认查询") + return self._get_default_reflection_query() return { "search_query": search_query, @@ -222,7 +270,16 @@ class ReflectionNode(BaseNode): except Exception as e: self.log_error(f"处理输出失败: {str(e)}") # 返回默认查询 - return { - "search_query": "深度研究补充信息", - "reasoning": "由于解析失败,使用默认反思搜索查询" - } + return self._get_default_reflection_query() + + def _get_default_reflection_query(self) -> Dict[str, str]: + """ + 获取默认反思搜索查询 + + Returns: + 默认的反思搜索查询字典 + """ + return { + "search_query": "深度研究补充信息", + "reasoning": "由于解析失败,使用默认反思搜索查询" + } diff --git a/src/nodes/summary_node.py b/src/nodes/summary_node.py index 2b9d24a..e44fb1e 100644 --- a/src/nodes/summary_node.py +++ b/src/nodes/summary_node.py @@ -14,6 +14,7 @@ from ..utils.text_processing import ( remove_reasoning_from_output, clean_json_tags, extract_clean_response, + fix_incomplete_json, format_search_results_for_prompt ) @@ -82,25 +83,42 @@ class FirstSummaryNode(StateMutationNode): def process_output(self, output: str) -> str: """ - 处理LLM输出,提取段落总结 + 处理LLM输出,提取段落内容 Args: output: LLM原始输出 Returns: - 段落总结内容 + 段落内容 """ try: # 清理响应文本 cleaned_output = remove_reasoning_from_output(output) cleaned_output = clean_json_tags(cleaned_output) + # 记录清理后的输出用于调试 + self.log_info(f"清理后的输出: {cleaned_output[:200]}...") + # 解析JSON try: result = json.loads(cleaned_output) - except JSONDecodeError: - # 如果不是JSON格式,直接返回清理后的文本 - return cleaned_output + self.log_info("JSON解析成功") + except JSONDecodeError as e: + self.log_info(f"JSON解析失败: {str(e)}") + # 尝试修复JSON + fixed_json = fix_incomplete_json(cleaned_output) + if fixed_json: + try: + result = json.loads(fixed_json) + self.log_info("JSON修复成功") + except JSONDecodeError: + self.log_info("JSON修复失败,直接使用清理后的文本") + # 如果不是JSON格式,直接返回清理后的文本 + return cleaned_output + else: + self.log_info("无法修复JSON,直接使用清理后的文本") + # 如果不是JSON格式,直接返回清理后的文本 + return cleaned_output # 提取段落内容 if isinstance(result, dict): @@ -224,12 +242,29 @@ class ReflectionSummaryNode(StateMutationNode): cleaned_output = remove_reasoning_from_output(output) cleaned_output = clean_json_tags(cleaned_output) + # 记录清理后的输出用于调试 + self.log_info(f"清理后的输出: {cleaned_output[:200]}...") + # 解析JSON try: result = json.loads(cleaned_output) - except JSONDecodeError: - # 如果不是JSON格式,直接返回清理后的文本 - return cleaned_output + self.log_info("JSON解析成功") + except JSONDecodeError as e: + self.log_info(f"JSON解析失败: {str(e)}") + # 尝试修复JSON + fixed_json = fix_incomplete_json(cleaned_output) + if fixed_json: + try: + result = json.loads(fixed_json) + self.log_info("JSON修复成功") + except JSONDecodeError: + self.log_info("JSON修复失败,直接使用清理后的文本") + # 如果不是JSON格式,直接返回清理后的文本 + return cleaned_output + else: + self.log_info("无法修复JSON,直接使用清理后的文本") + # 如果不是JSON格式,直接返回清理后的文本 + return cleaned_output # 提取更新后的段落内容 if isinstance(result, dict): diff --git a/src/utils/text_processing.py b/src/utils/text_processing.py index b76baa7..471650c 100644 --- a/src/utils/text_processing.py +++ b/src/utils/text_processing.py @@ -55,6 +55,20 @@ def remove_reasoning_from_output(text: str) -> str: Returns: 清理后的文本 """ + # 查找JSON开始位置 + json_start = -1 + + # 尝试找到第一个 { 或 [ + for i, char in enumerate(text): + if char in '{[': + json_start = i + break + + if json_start != -1: + # 从JSON开始位置截取 + return text[json_start:].strip() + + # 如果没有找到JSON标记,尝试其他方法 # 移除常见的推理标识 patterns = [ r'(?:reasoning|推理|思考|分析)[::]\s*.*?(?=\{|\[)', # 移除推理部分 @@ -88,6 +102,14 @@ def extract_clean_response(text: str) -> Dict[str, Any]: except JSONDecodeError: pass + # 尝试修复不完整的JSON + fixed_text = fix_incomplete_json(cleaned_text) + if fixed_text: + try: + return json.loads(fixed_text) + except JSONDecodeError: + pass + # 尝试查找JSON对象 json_pattern = r'\{.*\}' match = re.search(json_pattern, cleaned_text, re.DOTALL) @@ -111,6 +133,92 @@ def extract_clean_response(text: str) -> Dict[str, Any]: return {"error": "JSON解析失败", "raw_text": cleaned_text} +def fix_incomplete_json(text: str) -> str: + """ + 修复不完整的JSON响应 + + Args: + text: 原始文本 + + Returns: + 修复后的JSON文本,如果无法修复则返回空字符串 + """ + # 移除多余的逗号和空白 + text = re.sub(r',\s*}', '}', text) + text = re.sub(r',\s*]', ']', text) + + # 检查是否已经是有效的JSON + try: + json.loads(text) + return text + except JSONDecodeError: + pass + + # 检查是否缺少开头的数组符号 + if text.strip().startswith('{') and not text.strip().startswith('['): + # 如果以对象开始,尝试包装成数组 + if text.count('{') > 1: + # 多个对象,包装成数组 + text = '[' + text + ']' + else: + # 单个对象,包装成数组 + text = '[' + text + ']' + + # 检查是否缺少结尾的数组符号 + if text.strip().endswith('}') and not text.strip().endswith(']'): + # 如果以对象结束,尝试包装成数组 + if text.count('}') > 1: + # 多个对象,包装成数组 + text = '[' + text + ']' + else: + # 单个对象,包装成数组 + text = '[' + text + ']' + + # 检查括号是否匹配 + open_braces = text.count('{') + close_braces = text.count('}') + open_brackets = text.count('[') + close_brackets = text.count(']') + + # 修复不匹配的括号 + if open_braces > close_braces: + text += '}' * (open_braces - close_braces) + if open_brackets > close_brackets: + text += ']' * (open_brackets - close_brackets) + + # 验证修复后的JSON是否有效 + try: + json.loads(text) + return text + except JSONDecodeError: + # 如果仍然无效,尝试更激进的修复 + return fix_aggressive_json(text) + + +def fix_aggressive_json(text: str) -> str: + """ + 更激进的JSON修复方法 + + Args: + text: 原始文本 + + Returns: + 修复后的JSON文本 + """ + # 查找所有可能的JSON对象 + objects = re.findall(r'\{[^{}]*\}', text) + + if len(objects) >= 2: + # 如果有多个对象,包装成数组 + return '[' + ','.join(objects) + ']' + elif len(objects) == 1: + # 如果只有一个对象,包装成数组 + return '[' + objects[0] + ']' + else: + # 如果没有找到对象,返回空数组 + return '[]' + + def update_state_with_search_results(search_results: List[Dict[str, Any]], paragraph_index: int, state: Any) -> Any: """ diff --git a/streamlit_app.py b/streamlit_app.py index e15466c..63ed36b 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -13,7 +13,7 @@ import json sys.path.insert(0, os.path.join(os.path.dirname(__file__), '.')) from src import DeepSearchAgent, Config -from config import DEEPSEEK_API_KEY, DEEPSEEK_API_KEY_2, TAVILY_API_KEY +from config import DEEPSEEK_API_KEY, TAVILY_API_KEY def main():