feat: comprehensive v2 upgrade — streaming, error KB, file upload, layout analysis
Major changes: - Streaming: LLM统一 _BaseLLM 接口 (invoke + stream), generate/modify/correct 节点使用 get_stream_writer() 实现逐字输出, UI 节点平铺展开自动折叠 - Prompt外部化: 7个prompt拆分到 prompts/*.md, loader.py 支持热重载 - 错误自增长: backend/error_kb.py — 指纹去重 + ChromaDB持久化, correct_jrxml→validate 通过时自动入库, retrieve同时搜索错误KB - 文件上传: backend/file_parser.py — PDF/DOCX/图片/文本解析, 侧边栏多文件上传, 文本自动注入下一条消息 - A4模板识别: backend/layout_analyzer.py — 三种模式(完整A4/行片段修改/行片段新建), PaddleOCR元素提取 + 行分组 + JRXML section匹配 - 会话历史下载: jrxml_versions版本追踪 + 侧边栏历史版本下载按钮 - 预览修复: route_after_save跳过预览/导出意图的验证循环 - Ctrl+C修复: JS注入拦截Streamlit裸c键清缓存 Docs: CLAUDE.md (完整项目文档), ROADMAP.md (改进路线图) Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
调用方式:
|
||||
get_embeddings() → LangChain 兼容的 embeddings 对象
|
||||
get_st_embeddings() → 原始 SentenceTransformer 实例
|
||||
get_st_model() → 原始 SentenceTransformer 实例
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
"""错误自增长知识库 — 记录修正成功的错误案例,用于未来参考。
|
||||
|
||||
原则:
|
||||
- 仅记录"新错误"(指纹去重)
|
||||
- 必须包含完整的修正方案(prompt、工具链、前后 JRXML)
|
||||
- 存储于 ChromaDB,可被检索注入到生成 prompt 中
|
||||
|
||||
用法:
|
||||
from backend.error_kb import ErrorKB
|
||||
kb = ErrorKB()
|
||||
kb.record(error_msg, bad_jrxml, good_jrxml, correction_prompt)
|
||||
cases = kb.search("字段未声明", k=3)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
CHROMA_DIR = Path(os.getenv("CHROMA_PERSIST_DIR", "./db/chroma"))
|
||||
COLLECTION_NAME = "jrxml_error_cases"
|
||||
|
||||
|
||||
def _make_fingerprint(error_msg: str) -> str:
|
||||
"""生成错误指纹 — 标准化后取 hash,用于去重。
|
||||
|
||||
标准化规则:
|
||||
- 去除字段名、变量名等具体标识符(替换为占位符)
|
||||
- 小写化
|
||||
- 只保留错误的结构性特征
|
||||
"""
|
||||
text = error_msg.lower()
|
||||
# 替换变量名 / 字段名($F{xxx}, "name", 'value' 等)
|
||||
text = re.sub(r'\$f\{[^}]+\}', '$f{<FIELD>}', text)
|
||||
text = re.sub(r"'[^']*'", "'<VALUE>'", text)
|
||||
text = re.sub(r'"[^"]*"', '"<VALUE>"', text)
|
||||
# 替换数字
|
||||
text = re.sub(r'\b\d+\b', '<NUM>', text)
|
||||
# 压缩空白
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
return hashlib.md5(text.encode()).hexdigest()[:16]
|
||||
|
||||
|
||||
class ErrorKB:
|
||||
"""错误案例知识库 — 包装 ChromaDB 持久化。"""
|
||||
|
||||
def __init__(self):
|
||||
self._client = None
|
||||
self._collection = None
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
if self._client is None:
|
||||
import chromadb
|
||||
self._client = chromadb.PersistentClient(path=str(CHROMA_DIR))
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
if self._collection is None:
|
||||
try:
|
||||
self._collection = self.client.get_collection(COLLECTION_NAME)
|
||||
except Exception:
|
||||
self._collection = self.client.create_collection(COLLECTION_NAME)
|
||||
return self._collection
|
||||
|
||||
def exists(self, error_msg: str) -> bool:
|
||||
"""检查错误是否已存在于知识库中(按指纹去重)。"""
|
||||
fp = _make_fingerprint(error_msg)
|
||||
try:
|
||||
results = self.collection.get(ids=[fp])
|
||||
return bool(results and results["ids"])
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def record(
|
||||
self,
|
||||
error_msg: str,
|
||||
bad_jrxml: str,
|
||||
good_jrxml: str,
|
||||
correction_prompt: str,
|
||||
model: str = "",
|
||||
retry_count: int = 0,
|
||||
) -> bool:
|
||||
"""记录一个成功修正的错误案例。
|
||||
|
||||
仅当指纹不重复时写入。返回 True 表示已记录,False 表示重复。
|
||||
"""
|
||||
if self.exists(error_msg):
|
||||
return False
|
||||
|
||||
fp = _make_fingerprint(error_msg)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# 内容:结构化记录
|
||||
doc = json.dumps({
|
||||
"error": error_msg,
|
||||
"bad_jrxml_snippet": bad_jrxml[:2000],
|
||||
"good_jrxml_snippet": good_jrxml[:2000],
|
||||
"correction_prompt": correction_prompt[:1500],
|
||||
"model": model,
|
||||
"retry_count": retry_count,
|
||||
"recorded_at": now,
|
||||
"tools": ["validation_service", "llm_correction"],
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 元数据:用于检索过滤
|
||||
error_keywords = _extract_keywords(error_msg)
|
||||
metadata = {
|
||||
"fingerprint": fp,
|
||||
"error_keywords": ", ".join(error_keywords[:5]),
|
||||
"recorded_at": now,
|
||||
"retry_success": retry_count + 1, # 第几次修正成功的
|
||||
}
|
||||
|
||||
self.collection.add(
|
||||
ids=[fp],
|
||||
documents=[doc],
|
||||
metadatas=[metadata],
|
||||
)
|
||||
return True
|
||||
|
||||
def search(self, error_msg: str, k: int = 3) -> list[dict]:
|
||||
"""根据错误消息搜索相似的修正案例(ChromaDB 语义搜索)。
|
||||
|
||||
返回 [{error, fix_snippet, prompt, ...}, ...]
|
||||
"""
|
||||
keywords = _extract_keywords(error_msg)
|
||||
if not keywords:
|
||||
return []
|
||||
|
||||
query_text = " ".join(keywords)
|
||||
try:
|
||||
results = self.collection.query(
|
||||
query_texts=[query_text],
|
||||
n_results=k,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
output = []
|
||||
if not results["ids"] or not results["ids"][0]:
|
||||
return output
|
||||
|
||||
for i, doc_id in enumerate(results["ids"][0]):
|
||||
dist = results["distances"][0][i]
|
||||
try:
|
||||
data = json.loads(results["documents"][0][i])
|
||||
output.append({
|
||||
"id": doc_id,
|
||||
"error": data.get("error", ""),
|
||||
"fix_snippet": data.get("good_jrxml_snippet", ""),
|
||||
"prompt": data.get("correction_prompt", ""),
|
||||
"recorded_at": data.get("recorded_at", ""),
|
||||
"distance": dist,
|
||||
})
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return output
|
||||
|
||||
def search_as_context(self, error_msg: str, k: int = 3) -> str:
|
||||
"""搜索并返回拼接好的错误案例上下文,可直接注入 LLM prompt。"""
|
||||
results = self.search(error_msg, k=k)
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
parts = []
|
||||
for r in results:
|
||||
parts.append(
|
||||
f"[历史错误案例]\n"
|
||||
f"错误: {r['error'][:200]}\n"
|
||||
f"修正后 JRXML 片段:\n{r['fix_snippet'][:800]}\n"
|
||||
)
|
||||
return "\n---\n".join(parts)
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""返回知识库统计信息。"""
|
||||
try:
|
||||
count = self.collection.count()
|
||||
return {"total_cases": count, "collection": COLLECTION_NAME}
|
||||
except Exception:
|
||||
return {"total_cases": 0, "collection": COLLECTION_NAME}
|
||||
|
||||
|
||||
def _extract_keywords(error_msg: str) -> list[str]:
|
||||
"""从错误消息中提取关键词(中文 + 英文 token)。"""
|
||||
# 中文字符作为独立关键词
|
||||
chinese = re.findall(r'[一-鿿]{2,}', error_msg)
|
||||
# 英文 camelCase / snake_case token
|
||||
english = re.findall(r'[a-zA-Z_][a-zA-Z0-9_]{2,}', error_msg)
|
||||
# JRXML 特有模式
|
||||
jrxml_patterns = re.findall(r'\$F\{[^}]*\}', error_msg)
|
||||
return chinese + english + jrxml_patterns
|
||||
|
||||
|
||||
# 全局单例
|
||||
_kb: Optional[ErrorKB] = None
|
||||
|
||||
|
||||
def get_error_kb() -> ErrorKB:
|
||||
global _kb
|
||||
if _kb is None:
|
||||
_kb = ErrorKB()
|
||||
return _kb
|
||||
|
||||
|
||||
def record_error(error_msg: str, bad_jrxml: str, good_jrxml: str,
|
||||
correction_prompt: str, model: str = "", retry_count: int = 0) -> bool:
|
||||
"""便捷函数:记录成功修正的错误案例。"""
|
||||
return get_error_kb().record(error_msg, bad_jrxml, good_jrxml,
|
||||
correction_prompt, model, retry_count)
|
||||
|
||||
|
||||
def search_error_cases(error_msg: str, k: int = 3) -> str:
|
||||
"""便捷函数:搜索历史错误案例并返回上下文字符串。"""
|
||||
return get_error_kb().search_as_context(error_msg, k=k)
|
||||
@@ -0,0 +1,193 @@
|
||||
"""文件解析器:将上传文件转为文本,供 LLM 处理。
|
||||
|
||||
支持:
|
||||
- 图片 (.png/.jpg/.jpeg/.bmp) → OCR 提取文本
|
||||
- PDF (.pdf) → 文本提取
|
||||
- Word (.docx) → 文本提取
|
||||
- 纯文本 (.txt/.csv/.json/.xml) → 直接读取
|
||||
|
||||
策略选择:
|
||||
- 原生多模态: 模型支持图片时直接传文件(当前 MiniMax 不支持,自动退回文本转换)
|
||||
- 文本转换: 所有文件转为 UTF-8 文本后注入 prompt
|
||||
"""
|
||||
|
||||
import os
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import PIL.Image
|
||||
|
||||
MODELS_WITH_VISION = {
|
||||
"gpt-4o", "gpt-4-turbo", "gpt-4-vision-preview",
|
||||
"claude-3", "claude-3.5", "claude-4",
|
||||
"gemini-1.5", "gemini-2",
|
||||
}
|
||||
|
||||
|
||||
def can_use_vision(model: str = "") -> bool:
|
||||
"""检查当前模型是否支持原生多模态(图片直接上传)。"""
|
||||
if not model:
|
||||
model = os.getenv("LLM_MODEL", "")
|
||||
return any(v in model.lower() for v in MODELS_WITH_VISION)
|
||||
|
||||
|
||||
def parse_file(file_path: str, file_type: str = "") -> dict:
|
||||
"""解析任意文件为文本。
|
||||
|
||||
返回: {"text": str, "file_type": str, "method": str, "error": Optional[str]}
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return {"text": "", "file_type": file_type, "method": "none", "error": "文件不存在"}
|
||||
|
||||
suffix = file_type or path.suffix.lower()
|
||||
|
||||
parsers = {
|
||||
".png": _parse_image,
|
||||
".jpg": _parse_image,
|
||||
".jpeg": _parse_image,
|
||||
".bmp": _parse_image,
|
||||
".webp": _parse_image,
|
||||
".pdf": _parse_pdf,
|
||||
".docx": _parse_docx,
|
||||
}
|
||||
|
||||
parser = parsers.get(suffix)
|
||||
if parser:
|
||||
return parser(path)
|
||||
else:
|
||||
return _parse_text(path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 各类型解析器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_image(path: Path) -> dict:
|
||||
"""OCR 提取图片中的文字。"""
|
||||
try:
|
||||
img = PIL.Image.open(path)
|
||||
info = f"[图片: {img.size[0]}x{img.size[1]}, {img.mode}]"
|
||||
except Exception:
|
||||
info = "[图片: 无法读取元数据]"
|
||||
|
||||
# 尝试 PaddleOCR
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(lang="ch", use_angle_cls=False, show_log=False)
|
||||
result = ocr.ocr(str(path))
|
||||
lines = []
|
||||
if result and result[0]:
|
||||
for line in result[0]:
|
||||
text = line[1][0] if len(line) > 1 else ""
|
||||
if text.strip():
|
||||
lines.append(text.strip())
|
||||
if lines:
|
||||
return {
|
||||
"text": f"{info}\n识别文本:\n" + "\n".join(lines),
|
||||
"file_type": "image",
|
||||
"method": "paddleocr",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# OCR 不可用 → 返回图片元信息 + 安装提示
|
||||
return {
|
||||
"text": f"{info}\n(如需 OCR 文字识别,请安装: pip install paddleocr)",
|
||||
"file_type": "image",
|
||||
"method": "metadata_only",
|
||||
"error": "OCR 引擎未安装,已返回图片元信息",
|
||||
}
|
||||
|
||||
|
||||
def _parse_pdf(path: Path) -> dict:
|
||||
"""提取 PDF 中的文本。"""
|
||||
try:
|
||||
import pdfplumber
|
||||
with pdfplumber.open(path) as pdf:
|
||||
pages = []
|
||||
for page in pdf.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
pages.append(text)
|
||||
full = "\n\n".join(pages)
|
||||
return {
|
||||
"text": full,
|
||||
"file_type": "pdf",
|
||||
"method": "pdfplumber",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# Fallback: 尝试 PyMuPDF
|
||||
try:
|
||||
import fitz
|
||||
doc = fitz.open(path)
|
||||
pages = []
|
||||
for page in doc:
|
||||
pages.append(page.get_text())
|
||||
doc.close()
|
||||
return {
|
||||
"text": "\n\n".join(pages),
|
||||
"file_type": "pdf",
|
||||
"method": "pymupdf",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"text": "", "file_type": "pdf", "method": "none",
|
||||
"error": "PDF 解析需要安装 pdfplumber 或 PyMuPDF"}
|
||||
|
||||
|
||||
def _parse_docx(path: Path) -> dict:
|
||||
"""提取 Word 文档中的文本。"""
|
||||
try:
|
||||
from docx import Document
|
||||
doc = Document(path)
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
# 同时提取表格内容
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
cells = [cell.text for cell in row.cells if cell.text.strip()]
|
||||
if cells:
|
||||
paragraphs.append(" | ".join(cells))
|
||||
return {
|
||||
"text": "\n\n".join(paragraphs),
|
||||
"file_type": "docx",
|
||||
"method": "python-docx",
|
||||
"error": None,
|
||||
}
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
return {"text": "", "file_type": "docx", "method": "none",
|
||||
"error": "DOCX 解析需要安装 python-docx"}
|
||||
|
||||
|
||||
def _parse_text(path: Path) -> dict:
|
||||
"""读取纯文本文件。"""
|
||||
try:
|
||||
text = path.read_text(encoding="utf-8")
|
||||
return {"text": text, "file_type": path.suffix, "method": "direct", "error": None}
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
text = path.read_text(encoding="gbk")
|
||||
return {"text": text, "file_type": path.suffix, "method": "direct_gbk", "error": None}
|
||||
except Exception:
|
||||
return {"text": "", "file_type": path.suffix, "method": "none",
|
||||
"error": "无法解码文件"}
|
||||
except Exception:
|
||||
return {"text": "", "file_type": path.suffix, "method": "none",
|
||||
"error": "读取失败"}
|
||||
@@ -0,0 +1,494 @@
|
||||
"""A4 图片模板布局分析器。
|
||||
|
||||
检测上传图片并逐行识别每个元素的:
|
||||
- 位置 (x, y, w, h)
|
||||
- 字体大小(基于 OCR 边界框高度估算)
|
||||
- 文本内容
|
||||
|
||||
支持三种模式:
|
||||
- 完整 A4 模板:比例匹配 + OCR 元素 ≥2 → 全量布局描述
|
||||
- 行片段(非 A4 但有元素):视为 A4 中的某几行 → 部分布局描述
|
||||
- 修改匹配:将图片中的行与现有 JRXML 做匹配,定位修改位置
|
||||
|
||||
用法:
|
||||
from backend.layout_analyzer import analyze_layout, match_rows_to_jrxml
|
||||
result = analyze_layout("row_snippet.png")
|
||||
# result["template_type"] = "partial_rows"
|
||||
match = match_rows_to_jrxml(result, current_jrxml)
|
||||
# match["matched_rows"] = [{"row_index": 0, "jrxml_section": "detail_band", ...}]
|
||||
"""
|
||||
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import PIL.Image
|
||||
|
||||
# A4 标准尺寸 (mm): 210 × 297, 比例 ≈ 0.707
|
||||
A4_RATIO = 210 / 297
|
||||
A4_RATIO_EXACT_MIN, A4_RATIO_EXACT_MAX = 0.686, 0.728
|
||||
A4_RATIO_CLOSE_MIN, A4_RATIO_CLOSE_MAX = 0.650, 0.764
|
||||
|
||||
|
||||
def analyze_layout(
|
||||
file_path: str,
|
||||
row_tolerance_ratio: float = 0.02,
|
||||
) -> dict:
|
||||
"""分析图片/PDF 的报表模板布局。
|
||||
|
||||
返回:
|
||||
{
|
||||
"is_a4_template": bool, # 完整 A4 模板
|
||||
"is_partial": bool, # 行片段(非 A4 但有文字元素)
|
||||
"template_type": str, # "full_a4" | "partial_rows" | "unknown"
|
||||
"image_size": (w, h),
|
||||
"aspect_ratio": float,
|
||||
"a4_confidence": str,
|
||||
"rows": [{y_center, elements: [{x, y, w, h, font_size, text}, ...]}, ...],
|
||||
"description": str,
|
||||
"total_rows": int,
|
||||
"total_elements": int,
|
||||
}
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return _empty_result("文件不存在")
|
||||
|
||||
img = _load_image(path)
|
||||
if img is None:
|
||||
return _empty_result("无法加载图片")
|
||||
|
||||
w, h = img.size
|
||||
ratio = min(w, h) / max(w, h)
|
||||
|
||||
# A4 比例判定
|
||||
if A4_RATIO_EXACT_MIN <= ratio <= A4_RATIO_EXACT_MAX:
|
||||
a4_confidence = "exact"
|
||||
elif A4_RATIO_CLOSE_MIN <= ratio <= A4_RATIO_CLOSE_MAX:
|
||||
a4_confidence = "close"
|
||||
else:
|
||||
a4_confidence = "not_a4"
|
||||
|
||||
# OCR 提取
|
||||
elements = _ocr_elements(img, file_path)
|
||||
|
||||
if not elements:
|
||||
return {
|
||||
"is_a4_template": False,
|
||||
"is_partial": False,
|
||||
"template_type": "unknown",
|
||||
"image_size": (w, h),
|
||||
"aspect_ratio": round(ratio, 3),
|
||||
"a4_confidence": a4_confidence,
|
||||
"rows": [],
|
||||
"description": _build_description([], w, h, a4_confidence, "unknown"),
|
||||
"total_rows": 0,
|
||||
"total_elements": 0,
|
||||
}
|
||||
|
||||
# 行分组
|
||||
rows = _group_into_rows(elements, h, row_tolerance_ratio)
|
||||
|
||||
total = sum(len(r["elements"]) for r in rows)
|
||||
|
||||
# 模板类型判定
|
||||
is_full_a4 = a4_confidence != "not_a4" and total >= 2
|
||||
is_partial = not is_full_a4 and total >= 1 # 非 A4 但有文字 → 行片段
|
||||
|
||||
if is_full_a4:
|
||||
template_type = "full_a4"
|
||||
elif is_partial:
|
||||
template_type = "partial_rows"
|
||||
else:
|
||||
template_type = "unknown"
|
||||
|
||||
description = _build_description(rows, w, h, a4_confidence, template_type)
|
||||
|
||||
return {
|
||||
"is_a4_template": is_full_a4,
|
||||
"is_partial": is_partial,
|
||||
"template_type": template_type,
|
||||
"image_size": (w, h),
|
||||
"aspect_ratio": round(ratio, 3),
|
||||
"a4_confidence": a4_confidence,
|
||||
"rows": rows,
|
||||
"description": description,
|
||||
"total_rows": len(rows),
|
||||
"total_elements": total,
|
||||
}
|
||||
|
||||
|
||||
def match_rows_to_jrxml(
|
||||
layout_result: dict,
|
||||
current_jrxml: str,
|
||||
) -> dict:
|
||||
"""将图片中的行与现有 JRXML 中的 section/band 做匹配。
|
||||
|
||||
匹配策略:
|
||||
1. 从图片 OCR 文本中提取关键词
|
||||
2. 在 JRXML 中搜索这些关键词出现在哪个 band
|
||||
3. 返回匹配结果,可用于定位修改位置
|
||||
|
||||
返回:
|
||||
{
|
||||
"matched": bool,
|
||||
"matched_rows": [{row_index, row_y_center, jrxml_section, confidence}],
|
||||
"unmatched_rows": [...],
|
||||
"description": str, # 人类可读的匹配结果
|
||||
}
|
||||
"""
|
||||
rows = layout_result.get("rows", [])
|
||||
if not rows or not current_jrxml.strip():
|
||||
return {"matched": False, "matched_rows": [], "unmatched_rows": rows,
|
||||
"description": "无行数据或 JRXML 为空"}
|
||||
|
||||
# 解析 JRXML 结构
|
||||
jrxml_sections = _parse_jrxml_sections(current_jrxml)
|
||||
|
||||
matched_rows = []
|
||||
unmatched_rows = []
|
||||
|
||||
for ri, row in enumerate(rows):
|
||||
ocr_texts = [e["text"] for e in row["elements"]]
|
||||
best_section = None
|
||||
best_score = 0
|
||||
|
||||
for section in jrxml_sections:
|
||||
score = _text_similarity(ocr_texts, section["text_content"])
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_section = section
|
||||
|
||||
if best_score > 0.3 and best_section: # 最低匹配阈值
|
||||
matched_rows.append({
|
||||
"row_index": ri,
|
||||
"row_y_center": row["y_center"],
|
||||
"jrxml_section": best_section["name"],
|
||||
"jrxml_section_type": best_section["type"],
|
||||
"confidence": round(best_score, 2),
|
||||
"matched_text": best_section["text_content"][:200],
|
||||
})
|
||||
else:
|
||||
unmatched_rows.append({
|
||||
"row_index": ri,
|
||||
"row_y_center": row["y_center"],
|
||||
"ocr_texts": ocr_texts,
|
||||
})
|
||||
|
||||
# 生成描述
|
||||
desc_parts = []
|
||||
if matched_rows:
|
||||
desc_parts.append(f"图片中 {len(matched_rows)} 行匹配到当前 JRXML:")
|
||||
for m in matched_rows:
|
||||
desc_parts.append(
|
||||
f" - 图片第 {m['row_index']+1} 行 → JRXML「{m['jrxml_section']}」"
|
||||
f"({m['jrxml_section_type']},置信度 {m['confidence']})"
|
||||
)
|
||||
if unmatched_rows:
|
||||
desc_parts.append(f"图片中 {len(unmatched_rows)} 行未匹配到 JRXML 现有区域:")
|
||||
for u in unmatched_rows:
|
||||
texts = ", ".join(u["ocr_texts"][:3])
|
||||
desc_parts.append(f" - 图片第 {u['row_index']+1} 行:{texts}")
|
||||
|
||||
return {
|
||||
"matched": len(matched_rows) > 0,
|
||||
"matched_rows": matched_rows,
|
||||
"unmatched_rows": unmatched_rows,
|
||||
"description": "\n".join(desc_parts),
|
||||
}
|
||||
|
||||
|
||||
def analyze_and_inject(file_path: str, base_prompt: str,
|
||||
current_jrxml: str = "") -> str:
|
||||
"""分析布局并增强 prompt。
|
||||
|
||||
- 完整 A4 模板 → 全量布局描述
|
||||
- 行片段 + 有 JRXML → 行匹配 + 修改指引
|
||||
- 行片段 + 无 JRXML → 行片段描述(视为 A4 模板的一部分)
|
||||
"""
|
||||
result = analyze_layout(file_path)
|
||||
tt = result.get("template_type", "unknown")
|
||||
|
||||
if tt == "unknown":
|
||||
return base_prompt
|
||||
|
||||
if tt == "full_a4":
|
||||
return f"[图片模板分析 — 完整 A4 报表]\n{result['description']}\n\n---\n原始需求:\n{base_prompt}"
|
||||
|
||||
if tt == "partial_rows":
|
||||
if current_jrxml.strip():
|
||||
match = match_rows_to_jrxml(result, current_jrxml)
|
||||
if match["matched"]:
|
||||
return (
|
||||
f"[图片模板分析 — 行片段修改]\n"
|
||||
f"图片包含 {result['total_rows']} 行,视为 A4 模板的一部分。\n"
|
||||
f"{match['description']}\n\n"
|
||||
f"{result['description']}\n\n"
|
||||
f"---\n请根据以上匹配结果,修改 JRXML 中对应区域的布局:\n{base_prompt}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"[图片模板分析 — 行片段(未匹配到现有区域)]\n"
|
||||
f"图片包含 {result['total_rows']} 行。\n"
|
||||
f"{result['description']}\n\n"
|
||||
f"---\n请根据以上行结构,在 JRXML 中找到合适位置进行修改:\n{base_prompt}"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"[图片模板分析 — 行片段(无现有报表,按 A4 模板处理)]\n"
|
||||
f"图片包含 {result['total_rows']} 行,请按 A4 报表模板的需求输出整张报表。\n"
|
||||
f"{result['description']}\n\n"
|
||||
f"---\n原始需求:\n{base_prompt}"
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JRXML 结构解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _parse_jrxml_sections(jrxml: str) -> list[dict]:
|
||||
"""解析 JRXML 中的 section/band 结构。
|
||||
|
||||
直接搜索所有 band 元素,通过上下文字符串推断其所属 section。
|
||||
"""
|
||||
sections = []
|
||||
try:
|
||||
root = ET.fromstring(jrxml)
|
||||
section_tags = {
|
||||
"title", "pageHeader", "columnHeader", "detail",
|
||||
"columnFooter", "pageFooter", "summary", "background",
|
||||
"noData", "groupHeader", "groupFooter",
|
||||
}
|
||||
|
||||
for section_elem in root.iter():
|
||||
stag = _tag(section_elem)
|
||||
if stag not in section_tags:
|
||||
continue
|
||||
|
||||
for child in section_elem:
|
||||
if _tag(child) == "band":
|
||||
name = child.get("name", "")
|
||||
section_name = f"{stag}[{name}]" if name else stag
|
||||
text_content = ET.tostring(child, encoding="unicode")
|
||||
sections.append({
|
||||
"name": section_name,
|
||||
"type": stag,
|
||||
"text_content": text_content,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: 如果 structured parsing 失败,直接把整个 JRXML 按 band 分割
|
||||
if not sections:
|
||||
sections = _parse_jrxml_regex(jrxml)
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def _tag(elem) -> str:
|
||||
"""去除命名空间前缀的标签名。"""
|
||||
return elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag
|
||||
|
||||
|
||||
def _parse_jrxml_regex(jrxml: str) -> list[dict]:
|
||||
"""正则回退方案:直接在文本中搜索 band 块。"""
|
||||
sections = []
|
||||
band_pattern = re.compile(
|
||||
r'<(title|pageHeader|columnHeader|detail|columnFooter|pageFooter|summary|background|noData|groupHeader|groupFooter)>\s*'
|
||||
r'(<band[^>]*>.*?</band>)\s*'
|
||||
r'</\1>',
|
||||
re.DOTALL,
|
||||
)
|
||||
for m in band_pattern.finditer(jrxml):
|
||||
stag = m.group(1)
|
||||
band_xml = m.group(0)
|
||||
sections.append({
|
||||
"name": stag,
|
||||
"type": stag,
|
||||
"text_content": band_xml,
|
||||
})
|
||||
return sections
|
||||
|
||||
|
||||
def _text_similarity(ocr_texts: list[str], jrxml_text: str) -> float:
|
||||
"""计算 OCR 文本与 JRXML 文本的相似度(简单的词匹配)。"""
|
||||
if not ocr_texts or not jrxml_text:
|
||||
return 0.0
|
||||
|
||||
jrxml_lower = jrxml_text.lower()
|
||||
score = 0.0
|
||||
for text in ocr_texts:
|
||||
# 精确匹配
|
||||
if text.lower() in jrxml_lower:
|
||||
score += 1.0
|
||||
else:
|
||||
# 部分词匹配
|
||||
words = re.findall(r"\w+", text)
|
||||
matched = sum(1 for w in words if w.lower() in jrxml_lower)
|
||||
if words:
|
||||
score += matched / len(words) * 0.5
|
||||
|
||||
return min(score / len(ocr_texts), 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内部实现(不变)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _load_image(path: Path) -> Optional[PIL.Image.Image]:
|
||||
suffix = path.suffix.lower()
|
||||
|
||||
if suffix in (".png", ".jpg", ".jpeg", ".bmp", ".webp"):
|
||||
try:
|
||||
return PIL.Image.open(path).convert("RGB")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if suffix == ".pdf":
|
||||
try:
|
||||
import pdfplumber
|
||||
with pdfplumber.open(path) as pdf:
|
||||
if pdf.pages:
|
||||
pil_img = pdf.pages[0].to_image(resolution=150)
|
||||
return pil_img.original.convert("RGB")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import fitz
|
||||
doc = fitz.open(path)
|
||||
pix = doc[0].get_pixmap(dpi=150)
|
||||
img = PIL.Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
doc.close()
|
||||
return img
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _ocr_elements(img: PIL.Image.Image, file_path: str) -> list[dict]:
|
||||
try:
|
||||
from paddleocr import PaddleOCR
|
||||
import numpy as np
|
||||
|
||||
ocr = PaddleOCR(lang="ch", use_angle_cls=True, show_log=False)
|
||||
result = ocr.ocr(np.array(img))
|
||||
|
||||
elements = []
|
||||
if result and result[0]:
|
||||
for line in result[0]:
|
||||
if len(line) < 2:
|
||||
continue
|
||||
box = line[0]
|
||||
text_info = line[1]
|
||||
text = text_info[0] if isinstance(text_info, (list, tuple)) else str(text_info)
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
xs = [p[0] for p in box]
|
||||
ys = [p[1] for p in box]
|
||||
x_min, x_max = min(xs), max(xs)
|
||||
y_min, y_max = min(ys), max(ys)
|
||||
|
||||
elements.append({
|
||||
"x": round(x_min, 1),
|
||||
"y": round(y_min, 1),
|
||||
"w": round(x_max - x_min, 1),
|
||||
"h": round(y_max - y_min, 1),
|
||||
"font_size": round(y_max - y_min, 1),
|
||||
"text": text.strip(),
|
||||
})
|
||||
|
||||
elements.sort(key=lambda e: (e["y"], e["x"]))
|
||||
return elements
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def _group_into_rows(elements: list[dict], img_height: int,
|
||||
tolerance_ratio: float = 0.02) -> list[dict]:
|
||||
if not elements:
|
||||
return []
|
||||
|
||||
tolerance = img_height * tolerance_ratio
|
||||
rows = []
|
||||
current_row = [elements[0]]
|
||||
|
||||
for elem in elements[1:]:
|
||||
prev_cy = current_row[0]["y"] + current_row[0]["h"] / 2
|
||||
curr_cy = elem["y"] + elem["h"] / 2
|
||||
|
||||
if abs(curr_cy - prev_cy) < tolerance:
|
||||
current_row.append(elem)
|
||||
else:
|
||||
rows.append(_build_row(current_row))
|
||||
current_row = [elem]
|
||||
|
||||
if current_row:
|
||||
rows.append(_build_row(current_row))
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _build_row(elements: list[dict]) -> dict:
|
||||
elements.sort(key=lambda e: e["x"])
|
||||
ys = [e["y"] for e in elements]
|
||||
return {"y_center": round(sum(ys) / len(ys), 1), "elements": elements}
|
||||
|
||||
|
||||
def _build_description(rows: list[dict], img_w: int, img_h: int,
|
||||
a4_confidence: str, template_type: str) -> str:
|
||||
if not rows:
|
||||
if template_type == "partial_rows":
|
||||
return f"图片 {img_w}x{img_h}(非 A4 比例),未检测到文字元素。"
|
||||
return f"图片共 {img_w}x{img_h} 像素,未检测到文字元素。"
|
||||
|
||||
lines = []
|
||||
if template_type == "full_a4":
|
||||
lines.append(f"图片为完整 A4 报表模板,共 {len(rows)} 行,像素区域 {img_w}x{img_h}:")
|
||||
elif template_type == "partial_rows":
|
||||
lines.append(f"图片为报表模板行片段(非完整 A4),包含 {len(rows)} 行,"
|
||||
f"像素区域 {img_w}x{img_h},请按 A4 模板处理:")
|
||||
else:
|
||||
lines.append(f"图片共 {img_w}x{img_h} 像素,包含 {len(rows)} 行文字:")
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
elems = row["elements"]
|
||||
lines.append(f"\n第 {i+1} 行有 {len(elems)} 个元素:")
|
||||
for j, e in enumerate(elems):
|
||||
letter = chr(ord("a") + j)
|
||||
lines.append(
|
||||
f" 元素 {letter}:位置(x={e['x']}, y={e['y']}),"
|
||||
f"长 {e['w']}px,高 {e['h']}px,"
|
||||
f"字体 {e['font_size']}px,"
|
||||
f"内容「{e['text']}」"
|
||||
)
|
||||
|
||||
if template_type == "full_a4":
|
||||
lines.append(f"\n请根据以上布局生成对应的 JRXML 报表模板。")
|
||||
elif template_type == "partial_rows":
|
||||
lines.append(f"\n请将以上 {len(rows)} 行作为 A4 模板的一部分,"
|
||||
f"生成或修改对应的 JRXML 报表区域。")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _empty_result(error: str = "") -> dict:
|
||||
return {
|
||||
"is_a4_template": False,
|
||||
"is_partial": False,
|
||||
"template_type": "unknown",
|
||||
"image_size": (0, 0),
|
||||
"aspect_ratio": 0,
|
||||
"a4_confidence": "not_a4",
|
||||
"rows": [],
|
||||
"description": error,
|
||||
"total_rows": 0,
|
||||
"total_elements": 0,
|
||||
}
|
||||
+48
-4
@@ -8,13 +8,33 @@ from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class _BaseLLM:
|
||||
"""LLM 统一接口基类 — 所有后端都提供 invoke() 和 stream()。"""
|
||||
|
||||
def invoke(self, prompt: str) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def stream(self, prompt: str):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_llm():
|
||||
backend = os.getenv("LLM_BACKEND", "cloud")
|
||||
if backend == "local":
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
model = os.getenv("LOCAL_LLM_MODEL", "qwen2.5-coder:7b")
|
||||
return ChatOllama(model=model, temperature=0.1)
|
||||
raw = ChatOllama(model=model, temperature=0.1)
|
||||
|
||||
class OllamaWrapper(_BaseLLM):
|
||||
def invoke(self, prompt):
|
||||
return raw.invoke(prompt)
|
||||
|
||||
def stream(self, prompt):
|
||||
for chunk in raw.stream(prompt):
|
||||
yield chunk.content
|
||||
|
||||
return OllamaWrapper()
|
||||
|
||||
provider = os.getenv("LLM_PROVIDER", "openai")
|
||||
if provider == "anthropic":
|
||||
@@ -30,7 +50,7 @@ def get_llm():
|
||||
|
||||
client = Anthropic(api_key=api_key, base_url=base_url, timeout=120)
|
||||
|
||||
class MiniMaxLLM:
|
||||
class MiniMaxLLM(_BaseLLM):
|
||||
def invoke(self, prompt: str) -> Any:
|
||||
resp = client.messages.create(
|
||||
model=model,
|
||||
@@ -43,20 +63,44 @@ def get_llm():
|
||||
return type("Response", (), {"content": block.text})()
|
||||
return type("Response", (), {"content": ""})()
|
||||
|
||||
def stream(self, prompt: str):
|
||||
with client.messages.stream(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
messages=[{"role": "user", "content": [{"type": "text", "text": prompt}]}],
|
||||
) as s:
|
||||
for text in s.text_stream:
|
||||
yield text
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
return client.count_tokens(text)
|
||||
resp = client.messages.count_tokens(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": [{"type": "text", "text": text}]}],
|
||||
)
|
||||
return resp.input_tokens
|
||||
|
||||
return MiniMaxLLM()
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
return ChatOpenAI(
|
||||
raw = ChatOpenAI(
|
||||
model=os.getenv("LLM_MODEL", "gpt-4o"),
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
class OpenAIWrapper(_BaseLLM):
|
||||
def invoke(self, prompt):
|
||||
return raw.invoke(prompt)
|
||||
|
||||
def stream(self, prompt):
|
||||
for chunk in raw.stream(prompt):
|
||||
yield chunk.content
|
||||
|
||||
return OpenAIWrapper()
|
||||
|
||||
|
||||
def get_llm_for_correction():
|
||||
return get_llm()
|
||||
Reference in New Issue
Block a user