Improved Rendering
This commit is contained in:
@@ -10,12 +10,12 @@ from __future__ import annotations
|
||||
import json
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..core import TemplateSection, ChapterStorage
|
||||
from ..ir import ALLOWED_BLOCK_TYPES, IRValidator
|
||||
from ..ir import ALLOWED_BLOCK_TYPES, ALLOWED_INLINE_MARKS, IRValidator
|
||||
from ..prompts import (
|
||||
SYSTEM_PROMPT_CHAPTER_JSON,
|
||||
build_chapter_user_prompt,
|
||||
@@ -28,10 +28,41 @@ except ImportError: # pragma: no cover - optional dependency
|
||||
_json_repair_fn = None
|
||||
|
||||
|
||||
class ChapterJsonParseError(ValueError):
|
||||
"""Raised when the LLM output for a chapter cannot be parsed as valid JSON."""
|
||||
|
||||
def __init__(self, message: str, raw_text: Optional[str] = None):
|
||||
super().__init__(message)
|
||||
self.raw_text = raw_text
|
||||
|
||||
|
||||
class ChapterGenerationNode(BaseNode):
|
||||
"""负责按章节调用LLM并校验JSON结构"""
|
||||
|
||||
_COLON_EQUALS_PATTERN = re.compile(r'(":\s*)=')
|
||||
_LINE_BREAK_SENTINEL = "__LINE_BREAK__"
|
||||
_INLINE_MARK_ALIASES = {
|
||||
"strong": "bold",
|
||||
"b": "bold",
|
||||
"em": "italic",
|
||||
"emphasis": "italic",
|
||||
"i": "italic",
|
||||
"u": "underline",
|
||||
"strike-through": "strike",
|
||||
"strikethrough": "strike",
|
||||
"s": "strike",
|
||||
"codeblock": "code",
|
||||
"monospace": "code",
|
||||
"hyperlink": "link",
|
||||
"url": "link",
|
||||
"colour": "color",
|
||||
"textcolor": "color",
|
||||
"bgcolor": "highlight",
|
||||
"background": "highlight",
|
||||
"highlightcolor": "highlight",
|
||||
"sub": "subscript",
|
||||
"sup": "superscript",
|
||||
}
|
||||
|
||||
def __init__(self, llm_client, validator: IRValidator, storage: ChapterStorage):
|
||||
"""
|
||||
@@ -51,6 +82,7 @@ class ChapterGenerationNode(BaseNode):
|
||||
section: TemplateSection,
|
||||
context: Dict[str, Any],
|
||||
run_dir: Path,
|
||||
stream_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""针对单个章节调用LLM,校验/落盘章节JSON并返回结构化结果"""
|
||||
@@ -64,7 +96,13 @@ class ChapterGenerationNode(BaseNode):
|
||||
llm_payload = self._build_payload(section, context)
|
||||
user_message = build_chapter_user_prompt(llm_payload)
|
||||
|
||||
raw_text = self._stream_llm(user_message, chapter_dir, **kwargs)
|
||||
raw_text = self._stream_llm(
|
||||
user_message,
|
||||
chapter_dir,
|
||||
stream_callback=stream_callback,
|
||||
section_meta=chapter_meta,
|
||||
**kwargs,
|
||||
)
|
||||
chapter_json = self._parse_chapter(raw_text)
|
||||
|
||||
# 自动补全关键字段后再校验
|
||||
@@ -150,8 +188,15 @@ class ChapterGenerationNode(BaseNode):
|
||||
payload["globalContext"]["sectionBudgets"] = chapter_plan["sections"]
|
||||
return payload
|
||||
|
||||
def _stream_llm(self, user_message: str, chapter_dir: Path, **kwargs) -> str:
|
||||
"""流式调用LLM并实时写入raw文件"""
|
||||
def _stream_llm(
|
||||
self,
|
||||
user_message: str,
|
||||
chapter_dir: Path,
|
||||
stream_callback: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
||||
section_meta: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""流式调用LLM并实时写入raw文件,同时通过回调将delta抛出。"""
|
||||
chunks: List[str] = []
|
||||
with self.storage.capture_stream(chapter_dir) as stream_fp:
|
||||
stream = self.llm_client.stream_invoke(
|
||||
@@ -163,6 +208,12 @@ class ChapterGenerationNode(BaseNode):
|
||||
for delta in stream:
|
||||
stream_fp.write(delta)
|
||||
chunks.append(delta)
|
||||
if stream_callback:
|
||||
meta = section_meta or {}
|
||||
try:
|
||||
stream_callback(delta, meta)
|
||||
except Exception as callback_error: # pragma: no cover - 仅记录,不阻断主流程
|
||||
logger.warning(f"章节流式回调失败: {callback_error}")
|
||||
return "".join(chunks)
|
||||
|
||||
def _parse_chapter(self, raw_text: str) -> Dict[str, Any]:
|
||||
@@ -192,9 +243,13 @@ class ChapterGenerationNode(BaseNode):
|
||||
try:
|
||||
data = self._parse_with_candidates(candidate_payloads[-1:])
|
||||
except json.JSONDecodeError as inner_exc:
|
||||
raise ValueError(f"章节JSON解析失败: {inner_exc}") from inner_exc
|
||||
raise ChapterJsonParseError(
|
||||
f"章节JSON解析失败: {inner_exc}", raw_text=cleaned
|
||||
) from inner_exc
|
||||
else:
|
||||
raise ValueError(f"章节JSON解析失败: {exc}") from exc
|
||||
raise ChapterJsonParseError(
|
||||
f"章节JSON解析失败: {exc}", raw_text=cleaned
|
||||
) from exc
|
||||
|
||||
if "chapter" in data and isinstance(data["chapter"], dict):
|
||||
return data["chapter"]
|
||||
@@ -400,6 +455,7 @@ class ChapterGenerationNode(BaseNode):
|
||||
if not isinstance(block, dict):
|
||||
continue
|
||||
self._ensure_block_type(block)
|
||||
self._sanitize_block_content(block)
|
||||
block_type = block.get("type")
|
||||
if block_type == "list":
|
||||
items = block.get("items")
|
||||
@@ -424,6 +480,98 @@ class ChapterGenerationNode(BaseNode):
|
||||
|
||||
walk(chapter.get("blocks"))
|
||||
|
||||
def _sanitize_block_content(self, block: Dict[str, Any]):
|
||||
"""根据类型做精细化修复,例如清理paragraph内的非法inline mark"""
|
||||
block_type = block.get("type")
|
||||
if block_type == "paragraph":
|
||||
self._normalize_paragraph_block(block)
|
||||
|
||||
def _normalize_paragraph_block(self, block: Dict[str, Any]):
|
||||
"""将paragraph的inlines统一规整,剔除非法marks"""
|
||||
inlines = block.get("inlines")
|
||||
normalized_runs: List[Dict[str, Any]] = []
|
||||
if isinstance(inlines, list) and inlines:
|
||||
for run in inlines:
|
||||
normalized_runs.extend(self._coerce_inline_run(run))
|
||||
else:
|
||||
normalized_runs = [self._as_inline_run(self._extract_block_text(block))]
|
||||
if not normalized_runs:
|
||||
normalized_runs = [self._as_inline_run("")]
|
||||
block["inlines"] = normalized_runs
|
||||
|
||||
def _coerce_inline_run(self, run: Any) -> List[Dict[str, Any]]:
|
||||
"""将任意inline写法规整为合法run"""
|
||||
if isinstance(run, dict):
|
||||
normalized_run = dict(run)
|
||||
text = normalized_run.get("text")
|
||||
if not isinstance(text, str):
|
||||
text = "" if text is None else str(text)
|
||||
marks = normalized_run.get("marks")
|
||||
sanitized_marks, extra_text = self._sanitize_inline_marks(marks)
|
||||
normalized_run["marks"] = sanitized_marks
|
||||
normalized_run["text"] = (text or "") + extra_text
|
||||
return [normalized_run]
|
||||
if isinstance(run, str):
|
||||
return [self._as_inline_run(run)]
|
||||
if isinstance(run, (int, float)):
|
||||
return [self._as_inline_run(str(run))]
|
||||
if isinstance(run, list):
|
||||
normalized: List[Dict[str, Any]] = []
|
||||
for item in run:
|
||||
normalized.extend(self._coerce_inline_run(item))
|
||||
return normalized
|
||||
return [self._as_inline_run("" if run is None else str(run))]
|
||||
|
||||
def _sanitize_inline_marks(self, marks: Any) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""过滤非法marks并将break类控制符转成文本"""
|
||||
text_suffix = ""
|
||||
if marks is None:
|
||||
return [], text_suffix
|
||||
mark_list = marks if isinstance(marks, list) else [marks]
|
||||
sanitized: List[Dict[str, Any]] = []
|
||||
for mark in mark_list:
|
||||
normalized_mark, extra_text = self._normalize_inline_mark(mark)
|
||||
if normalized_mark:
|
||||
sanitized.append(normalized_mark)
|
||||
if extra_text:
|
||||
text_suffix += extra_text
|
||||
return sanitized, text_suffix
|
||||
|
||||
def _normalize_inline_mark(self, mark: Any) -> Tuple[Dict[str, Any] | None, str]:
|
||||
"""对单个mark做兼容映射,或者在必要时转换为文本"""
|
||||
if not isinstance(mark, dict):
|
||||
return None, ""
|
||||
canonical_type = self._canonical_inline_mark_type(mark.get("type"))
|
||||
if canonical_type == self._LINE_BREAK_SENTINEL:
|
||||
return None, "\n"
|
||||
if canonical_type in ALLOWED_INLINE_MARKS:
|
||||
normalized = dict(mark)
|
||||
normalized["type"] = canonical_type
|
||||
return normalized, ""
|
||||
return None, ""
|
||||
|
||||
def _canonical_inline_mark_type(self, mark_type: Any) -> str | None:
|
||||
"""将mark type映射为Schema所支持的取值"""
|
||||
if not isinstance(mark_type, str):
|
||||
return None
|
||||
normalized = mark_type.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
lowered = normalized.lower()
|
||||
if lowered in {"break", "linebreak", "br"}:
|
||||
return self._LINE_BREAK_SENTINEL
|
||||
return self._INLINE_MARK_ALIASES.get(lowered, lowered)
|
||||
|
||||
def _extract_block_text(self, block: Dict[str, Any]) -> str:
|
||||
"""优先从text/content等字段提取fallback文本"""
|
||||
for key in ("text", "content", "value", "title"):
|
||||
value = block.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if value is not None:
|
||||
return str(value)
|
||||
return ""
|
||||
|
||||
def _normalize_list_items(self, items: Any) -> List[List[Dict[str, Any]]]:
|
||||
"""确保list block的items为[[block, block], ...]结构"""
|
||||
if not isinstance(items, list):
|
||||
@@ -490,16 +638,21 @@ class ChapterGenerationNode(BaseNode):
|
||||
text = str(block)
|
||||
block.clear()
|
||||
block["type"] = "paragraph"
|
||||
block["inlines"] = [{"text": text}]
|
||||
block["inlines"] = [self._as_inline_run(text)]
|
||||
|
||||
@staticmethod
|
||||
def _as_paragraph_block(text: str) -> Dict[str, Any]:
|
||||
"""将字符串快速包装成paragraph block,方便统一处理"""
|
||||
return {
|
||||
"type": "paragraph",
|
||||
"inlines": [{"text": text or ""}],
|
||||
"inlines": [ChapterGenerationNode._as_inline_run(text)],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _as_inline_run(text: str) -> Dict[str, Any]:
|
||||
"""构造基础inline run,保证marks字段存在"""
|
||||
return {"text": text or "", "marks": []}
|
||||
|
||||
@staticmethod
|
||||
def _parse_with_candidates(payloads: List[str]) -> Dict[str, Any]:
|
||||
"""按顺序尝试多个payload,直到解析成功"""
|
||||
@@ -513,4 +666,4 @@ class ChapterGenerationNode(BaseNode):
|
||||
raise last_exc
|
||||
|
||||
|
||||
__all__ = ["ChapterGenerationNode"]
|
||||
__all__ = ["ChapterGenerationNode", "ChapterJsonParseError"]
|
||||
|
||||
Reference in New Issue
Block a user