fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss
Root cause: when the LLM received the full 34k-char JRXML, it would regenerate
the report from scratch instead of modifying coordinates in place, shrinking
the output to ~3k chars.
Solution (programmatic node control, not prompt engineering):
- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
LLM) + individual bands. Split bands >4000 chars at element boundaries.
Reassemble with element count validation (>10% change = rollback).
- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
each). LLM cannot "reimagine" the entire report.
- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
replacement. Zero LLM calls, zero content loss.
- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
valid JRXML identifiers.
- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
Full suite 385 tests, zero regressions.
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""OCR 字段 → KB 字段匹配模块。
|
||||
|
||||
两阶段匹配:
|
||||
1. Embedding 粗筛(相似度 top-3)
|
||||
2. LLM 精确确认
|
||||
|
||||
返回映射: {"工单号": "billNo", "客户名称": "customerName", ...}
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.logger import get_logger
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_match_log = get_logger("field_matcher")
|
||||
|
||||
|
||||
def _embed(text: str) -> list:
    """Return the embedding of ``text`` as a plain Python list.

    The vector is produced by the shared searcher obtained from
    ``backend.rag_adapter._get_searcher`` and is requested with
    ``normalize_embeddings=True``.
    """
    from backend.rag_adapter import _get_searcher

    rag = _get_searcher()
    if rag._model is None:
        # Touch the property once before encoding
        # (presumably this triggers the lazy model load -- confirm in rag_adapter).
        _ = rag.model
    vector = rag.model.encode(
        text, show_progress_bar=False, normalize_embeddings=True)
    return vector.tolist()
|
||||
|
||||
|
||||
def _cosine_similarity(a: list, b: list) -> float:
|
||||
"""余弦相似度(假设向量已归一化,点积即相似度)。"""
|
||||
return sum(x * y for x, y in zip(a, b))
|
||||
|
||||
|
||||
def match_ocr_to_kb(ocr_fields: list[str], kb_fields: list[dict],
                    llm=None) -> dict[str, str]:
    """Match OCR-extracted field names to KB field definitions.

    Two stages: an embedding pass keeps the three most similar KB fields for
    every OCR field, then (when ``llm`` is given) the LLM confirms the final
    mapping. Without an LLM, the best candidate is accepted when its
    similarity is above 0.5.

    Args:
        ocr_fields: Field names extracted by OCR.
        kb_fields: KB field definitions, e.g.
            ``[{"name": "billNo", "description": "工单号", ...}]``.
        llm: Optional LLM object exposing ``invoke(prompt)``.

    Returns:
        Mapping of OCR field name to KB field name, e.g.
        ``{"工单号": "billNo"}``. Empty when either input is empty.
    """
    if not ocr_fields or not kb_fields:
        return {}

    # Stage 1: coarse filtering by embedding similarity.
    try:
        ocr_embs = {name: _embed(name) for name in ocr_fields}
        # Fall back to the field name when the description is missing *or
        # empty*; embedding an empty string carries no signal.
        kb_embs = {
            f["name"]: _embed(f.get("description") or f["name"])
            for f in kb_fields
        }
    except Exception as e:
        _match_log.warning("Embedding 匹配失败,回退到 LLM: %s", e)
        return _match_via_llm(ocr_fields, kb_fields, llm)

    candidates = {}
    for ocr_name, ocr_emb in ocr_embs.items():
        scored = [
            (kb_name, _cosine_similarity(ocr_emb, kb_emb))
            for kb_name, kb_emb in kb_embs.items()
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        candidates[ocr_name] = scored[:3]

    # Stage 2: LLM confirmation, or a similarity threshold without an LLM.
    if llm:
        return dict(_match_via_llm(ocr_fields, kb_fields, llm, candidates))

    return {
        ocr_name: cands[0][0]
        for ocr_name, cands in candidates.items()
        if cands and cands[0][1] > 0.5
    }
|
||||
|
||||
|
||||
def _match_via_llm(ocr_fields: list[str], kb_fields: list[dict],
|
||||
llm, candidates: Optional[dict] = None) -> dict[str, str]:
|
||||
"""使用 LLM 精确确认字段映射。"""
|
||||
kb_desc = "\n".join(
|
||||
f"- {f['name']}: {f.get('description', '')} ({f.get('type', 'java.lang.String')})"
|
||||
for f in kb_fields
|
||||
)
|
||||
|
||||
candidates_hint = ""
|
||||
if candidates:
|
||||
cand_lines = []
|
||||
for ocr_name, cands in candidates.items():
|
||||
cand_str = ", ".join(f"{n}({s:.2f})" for n, s in cands)
|
||||
cand_lines.append(f" {ocr_name} -> 候选: {cand_str}")
|
||||
candidates_hint = (
|
||||
"向量相似度候选(仅供参考,请根据语义确认):\n"
|
||||
+ "\n".join(cand_lines)
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"请将以下 OCR 识别的字段名匹配到知识库定义的字段。\n\n"
|
||||
f"OCR 字段: {json.dumps(ocr_fields, ensure_ascii=False)}\n\n"
|
||||
f"知识库字段:\n{kb_desc}\n\n"
|
||||
f"{candidates_hint}\n\n"
|
||||
"请以 JSON 对象格式输出映射关系,键为 OCR 字段名,值为 KB 字段名:\n"
|
||||
'{"工单号": "billNo", "客户名称": "customerName"}'
|
||||
)
|
||||
|
||||
try:
|
||||
response = llm.invoke(prompt)
|
||||
content = response.content if hasattr(response, "content") else str(response)
|
||||
start = content.find("{")
|
||||
end = content.rfind("}") + 1
|
||||
if start >= 0 and end > start:
|
||||
return json.loads(content[start:end])
|
||||
except Exception as e:
|
||||
_match_log.warning("LLM 字段匹配失败: %s", e)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def format_field_mapping_context(mapping: dict[str, str]) -> str:
    """Render the field mapping as a markdown table for use in a prompt.

    Returns an empty string for an empty mapping; otherwise a title line, an
    instruction line, a two-column table header and one row per entry with
    the KB name wrapped as ``$P{name}``.
    """
    if not mapping:
        return ""

    preamble = [
        "[字段映射 — OCR -> KB]",
        "请在 JRXML 中使用以下参数名:",
        "| OCR 字段 | JRXML 参数 |",
        "|---|---|",
    ]
    rows = [
        f"| {ocr_name} | $P{{{kb_name}}} |"
        for ocr_name, kb_name in mapping.items()
    ]
    return "\n".join(preamble + rows)
|
||||
@@ -0,0 +1,227 @@
|
||||
"""多租户知识库管理模块。
|
||||
|
||||
用户 + 知识库 CRUD,持久化到 kb_data/ 目录。
|
||||
每个 KB 拥有独立的 JSON 元数据文件和文件存储目录。
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import tempfile
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from backend.logger import get_logger
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_kb_log = get_logger("kb_manager")
|
||||
|
||||
KB_DATA_DIR = Path(os.getenv("KB_DATA_DIR", "./kb_data"))
|
||||
_USERS_FILE = KB_DATA_DIR / "users.json"
|
||||
|
||||
_VALID_ID_RE = re.compile(r'^[a-fA-F0-9]{12,}$')
|
||||
|
||||
|
||||
def _validate_id(id_str: str, label: str = "id") -> None:
|
||||
if not _VALID_ID_RE.match(id_str):
|
||||
raise ValueError(f"Invalid {label}: {id_str!r}")
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _ensure_dir(path: Path) -> None:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _read_json(fp: Path) -> dict:
|
||||
with open(fp, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _write_json_atomic(fp: Path, data: dict) -> None:
|
||||
_ensure_dir(fp.parent)
|
||||
tmp = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False,
|
||||
dir=fp.parent, encoding="utf-8",
|
||||
)
|
||||
try:
|
||||
json.dump(data, tmp, ensure_ascii=False, indent=2)
|
||||
tmp.flush()
|
||||
os.fsync(tmp.fileno())
|
||||
tmp.close()
|
||||
os.replace(tmp.name, str(fp))
|
||||
except Exception:
|
||||
tmp.close()
|
||||
Path(tmp.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
# ── User CRUD ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_users() -> list[dict]:
    """Return the user records from the users file, or ``[]`` if it is absent.

    The data directory is created first if necessary.
    """
    _ensure_dir(KB_DATA_DIR)
    if not _USERS_FILE.exists():
        return []
    return _read_json(_USERS_FILE)
|
||||
|
||||
|
||||
def _save_users(users: list[dict]) -> None:
    """Persist the full user list to the users file atomically."""
    _write_json_atomic(fp=_USERS_FILE, data=users)
|
||||
|
||||
|
||||
def create_user(name: str, user_id: Optional[str] = None) -> dict:
    """Create a user, register it in the users file and set up its directory.

    Args:
        name: Display name.
        user_id: Optional explicit ID; a random ``uuid4`` hex is generated
            when omitted.

    Returns:
        The new user record (``user_id``, ``name``, ``created_at``).

    Raises:
        ValueError: If the ID is not a valid hex identifier or already exists.
    """
    uid = user_id or uuid.uuid4().hex
    # The ID becomes a directory name below and every other user function
    # validates it, so reject bad IDs before anything is written.
    _validate_id(uid, "user_id")

    users = _load_users()
    if any(u["user_id"] == uid for u in users):
        raise ValueError(f"User {uid} already exists")

    user = {"user_id": uid, "name": name, "created_at": _now_iso()}
    users.append(user)
    _save_users(users)

    _ensure_dir(KB_DATA_DIR / uid)
    _write_json_atomic(KB_DATA_DIR / uid / "profile.json", user)
    _kb_log.info("创建用户", extra={"user_id": uid, "user_name": name})
    return user
|
||||
|
||||
|
||||
def list_users() -> list[dict]:
    """Return every stored user record."""
    users = _load_users()
    return users
|
||||
|
||||
|
||||
def get_user(user_id: str) -> Optional[dict]:
    """Return the user record with ``user_id``, or ``None`` if there is none.

    Raises:
        ValueError: If ``user_id`` fails ID validation.
    """
    _validate_id(user_id, "user_id")
    matches = (u for u in _load_users() if u["user_id"] == user_id)
    return next(matches, None)
|
||||
|
||||
|
||||
def delete_user(user_id: str) -> bool:
    """Remove a user from the users file and delete its directory tree.

    Returns:
        ``True`` if the user existed and was removed, ``False`` otherwise.

    Raises:
        ValueError: If ``user_id`` fails ID validation.
    """
    _validate_id(user_id, "user_id")

    existing = _load_users()
    if not any(u["user_id"] == user_id for u in existing):
        return False

    _save_users([u for u in existing if u["user_id"] != user_id])

    home = KB_DATA_DIR / user_id
    if home.exists():
        shutil.rmtree(home)
    _kb_log.info("删除用户", extra={"user_id": user_id})
    return True
|
||||
|
||||
|
||||
# ── KB CRUD ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _kb_dir(kb_id: str) -> Optional[Path]:
    """Locate the directory of knowledge base ``kb_id`` across all users.

    Scans every non-hidden directory under ``KB_DATA_DIR`` for a
    sub-directory named ``kb_id``.

    Returns:
        The KB directory, or ``None`` when it is not found or the data
        directory does not exist yet.

    Raises:
        ValueError: If ``kb_id`` fails ID validation.
    """
    _validate_id(kb_id, "kb_id")
    if not KB_DATA_DIR.is_dir():
        # Nothing has been created yet; iterdir() would raise here.
        return None
    for user_dir in KB_DATA_DIR.iterdir():
        if user_dir.is_dir() and not user_dir.name.startswith("."):
            candidate = user_dir / kb_id
            if candidate.is_dir():
                return candidate
    return None
|
||||
|
||||
|
||||
def _ensure_user_dir(user_id: str) -> Path:
    """Validate ``user_id`` and return its directory, creating it if needed."""
    _validate_id(user_id, "user_id")
    home = KB_DATA_DIR / user_id
    _ensure_dir(home)
    return home
|
||||
|
||||
|
||||
def create_kb(user_id: str, name: str, description: str = "",
              kb_id: Optional[str] = None) -> dict:
    """Create a knowledge base for ``user_id`` and write its ``meta.json``.

    Args:
        user_id: Owner; its directory is created if missing.
        name: KB display name.
        description: Optional free-text description.
        kb_id: Optional explicit ID; a random ``uuid4`` hex is generated when
            omitted.

    Returns:
        The initial metadata dict (empty fields/templates, zero counts,
        ``parse_status`` of ``"empty"``).

    Raises:
        ValueError: If ``user_id`` or the KB ID is not a valid hex identifier.
    """
    user_dir = _ensure_user_dir(user_id)
    kid = kb_id or uuid.uuid4().hex
    # The ID becomes a directory name and every lookup validates it, so
    # reject a bad caller-supplied value before creating anything.
    _validate_id(kid, "kb_id")

    kb_dir = user_dir / kid
    _ensure_dir(kb_dir)
    _ensure_dir(kb_dir / "raw")

    now = _now_iso()
    meta = {
        "kb_id": kid, "user_id": user_id, "name": name,
        "description": description, "created_at": now, "updated_at": now,
        "fields": [], "templates": [], "file_count": 0,
        "chunk_count": 0, "parse_status": "empty",
    }
    _write_json_atomic(kb_dir / "meta.json", meta)
    _kb_log.info("创建知识库", extra={"kb_id": kid, "user_id": user_id, "kb_name": name})
    return meta
|
||||
|
||||
|
||||
def list_kbs(user_id: str) -> list[dict]:
    """Return summary records for the user's knowledge bases.

    Directories are visited most recently modified first. Hidden directories
    and directories without a ``meta.json`` are skipped. Each summary carries
    counts (``field_count``, ``template_count``) rather than the full lists.
    """
    user_dir = _ensure_user_dir(user_id)
    newest_first = sorted(user_dir.iterdir(), key=os.path.getmtime, reverse=True)

    summaries = []
    for entry in newest_first:
        if not entry.is_dir() or entry.name.startswith("."):
            continue
        meta_path = entry / "meta.json"
        if not meta_path.exists():
            continue
        meta = _read_json(meta_path)
        summary = {
            "kb_id": meta.get("kb_id", entry.name),
            "name": meta.get("name", entry.name),
            "description": meta.get("description", ""),
            "created_at": meta.get("created_at", ""),
            "updated_at": meta.get("updated_at", ""),
            "field_count": len(meta.get("fields", [])),
            "template_count": len(meta.get("templates", [])),
            "file_count": meta.get("file_count", 0),
            "chunk_count": meta.get("chunk_count", 0),
            "parse_status": meta.get("parse_status", "empty"),
        }
        summaries.append(summary)
    return summaries
|
||||
|
||||
|
||||
def get_kb(kb_id: str) -> Optional[dict]:
    """Return the metadata of ``kb_id``, or ``None`` if it cannot be found.

    ``None`` covers both a missing KB directory and a missing ``meta.json``.

    Raises:
        ValueError: If ``kb_id`` fails ID validation.
    """
    _validate_id(kb_id, "kb_id")
    location = _kb_dir(kb_id)
    if location is None:
        return None
    meta_path = location / "meta.json"
    if not meta_path.exists():
        return None
    return _read_json(meta_path)
|
||||
|
||||
|
||||
def update_kb_meta(kb_id: str, updates: dict) -> Optional[dict]:
    """Merge ``updates`` into the KB's ``meta.json`` and refresh ``updated_at``.

    Returns:
        The updated metadata, or ``None`` when the KB directory or its
        ``meta.json`` does not exist (the same "not found" result ``get_kb``
        gives, instead of a ``FileNotFoundError``).

    Raises:
        ValueError: If ``kb_id`` fails ID validation.
    """
    kb_dir = _kb_dir(kb_id)
    if kb_dir is None:
        return None
    meta_path = kb_dir / "meta.json"
    if not meta_path.exists():
        return None

    meta = _read_json(meta_path)
    meta.update(updates)
    meta["updated_at"] = _now_iso()
    _write_json_atomic(meta_path, meta)
    return meta
|
||||
|
||||
|
||||
def delete_kb(kb_id: str) -> bool:
    """Delete the KB directory tree; return whether a KB was removed."""
    location = _kb_dir(kb_id)
    if location is not None:
        shutil.rmtree(location)
        _kb_log.info("删除知识库", extra={"kb_id": kb_id})
        return True
    return False
|
||||
|
||||
|
||||
def get_kb_raw_dir(kb_id: str) -> Optional[Path]:
    """Return the KB's ``raw`` directory path, or ``None`` if the KB is unknown."""
    location = _kb_dir(kb_id)
    if not location:
        return None
    return location / "raw"
|
||||
|
||||
|
||||
def get_kb_chunks_path(kb_id: str) -> Optional[Path]:
    """Return the path of the KB's ``chunks.json``, or ``None`` if the KB is unknown."""
    location = _kb_dir(kb_id)
    if not location:
        return None
    return location / "chunks.json"
|
||||
|
||||
|
||||
def get_kb_chroma_path(kb_id: str) -> Optional[Path]:
    """Return the KB's ``chroma`` directory, creating it on demand.

    Returns ``None`` (and creates nothing) when the KB is unknown.
    """
    location = _kb_dir(kb_id)
    if location is not None:
        store = location / "chroma"
        _ensure_dir(store)
        return store
    return None
|
||||
@@ -0,0 +1,336 @@
|
||||
"""KB 解析管道 — 文件提取→字段解析→chunk 切割→向量嵌入。
|
||||
|
||||
调用者: api_server.py (upload endpoint), scripts/init_default_kb.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import zipfile
|
||||
import tarfile
|
||||
import tempfile
|
||||
import defusedxml.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from backend.logger import get_logger
|
||||
from backend.file_parser import parse_file
|
||||
|
||||
load_dotenv()
|
||||
|
||||
_kb_parse_log = get_logger("kb_parser")
|
||||
|
||||
|
||||
def _find_tag(elem, tag):
|
||||
for el in elem.iter():
|
||||
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
|
||||
if local == tag:
|
||||
return el
|
||||
return None
|
||||
|
||||
|
||||
def _find_all_tags(elem, tag):
|
||||
results = []
|
||||
for el in elem.iter():
|
||||
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
|
||||
if local == tag:
|
||||
results.append(el)
|
||||
return results
|
||||
|
||||
|
||||
def parse_jrxml_fields(jrxml_path: str) -> dict:
    """Parse a JRXML file and extract its parameter and field definitions.

    Returns:
        A dict that always has the keys ``report_name``, ``page_width``,
        ``page_height``, ``parameters``, ``fields``, ``query`` and ``error``.
        On success ``error`` is ``None``. When the XML cannot be parsed, or is
        rejected by the hardened parser, ``error`` holds a message and the
        other values are empty, so callers can index the result either way.
        Each parameter/field entry is ``{"name", "type", "description"}``.
    """
    try:
        root = ET.parse(jrxml_path).getroot()
    except (ET.ParseError, ValueError) as e:
        # ValueError covers defusedxml's security exceptions (forbidden
        # DTD / entities), which derive from it.
        return {"report_name": "", "page_width": "", "page_height": "",
                "parameters": [], "fields": [], "query": "",
                "error": f"JRXML 解析失败: {e}"}

    def _description(node, child_tag):
        """Return the stripped text of ``child_tag`` under ``node``, or ''."""
        desc = _find_tag(node, child_tag)
        if desc is not None and desc.text:
            return desc.text.strip()
        return ""

    parameters = [
        {"name": p.attrib.get("name", ""),
         "type": p.attrib.get("class", "java.lang.String"),
         "description": _description(p, "parameterDescription")}
        for p in _find_all_tags(root, "parameter")
    ]

    # <field> may carry a <fieldDescription> child, the counterpart of
    # <parameterDescription>; use it when present.
    fields = [
        {"name": f.attrib.get("name", ""),
         "type": f.attrib.get("class", "java.lang.String"),
         "description": _description(f, "fieldDescription")}
        for f in _find_all_tags(root, "field")
    ]

    query_text = ""
    query = _find_tag(root, "queryString")
    if query is not None and query.text:
        query_text = query.text.strip()

    return {"report_name": root.attrib.get("name", ""),
            "page_width": root.attrib.get("pageWidth", ""),
            "page_height": root.attrib.get("pageHeight", ""),
            "parameters": parameters,
            "fields": fields, "query": query_text, "error": None}
|
||||
|
||||
|
||||
def _extract_archive(file_path: str, dest_dir: str) -> list[str]:
|
||||
extracted = []
|
||||
resolved_dest = os.path.realpath(dest_dir)
|
||||
|
||||
if zipfile.is_zipfile(file_path):
|
||||
with zipfile.ZipFile(file_path, "r") as zf:
|
||||
for member in zf.namelist():
|
||||
member_path = os.path.realpath(os.path.join(dest_dir, member))
|
||||
if not member_path.startswith(resolved_dest + os.sep):
|
||||
continue
|
||||
zf.extract(member, dest_dir)
|
||||
if not member.endswith("/"):
|
||||
extracted.append(member_path)
|
||||
elif tarfile.is_tarfile(file_path):
|
||||
with tarfile.open(file_path, "r:*") as tf:
|
||||
for member in tf.getmembers():
|
||||
member_path = os.path.realpath(os.path.join(dest_dir, member.name))
|
||||
if not member_path.startswith(resolved_dest + os.sep):
|
||||
continue
|
||||
tf.extract(member, dest_dir)
|
||||
if not member.name.endswith("/"):
|
||||
extracted.append(member_path)
|
||||
return extracted
|
||||
|
||||
|
||||
def process_file_for_kb(kb_id: str, file_path: str,
                        source_name: str = "") -> dict:
    """Copy one file into the KB's raw directory and extract its content.

    Dispatch is by file suffix:
      * ``.jrxml`` -- parsed for parameters/fields; the result also carries
        the raw XML and the parsed ``jrxml_info``.
      * ``.zip`` / ``.tar`` / ``.gz`` / ``.tgz`` -- extracted to a temporary
        directory and each member processed recursively
        (``archive_contents``).
      * anything else -- handed to ``parse_file``.

    Args:
        kb_id: Target knowledge base.
        file_path: File to ingest.
        source_name: Optional name to store the file under; only its final
            path component is used.

    Returns:
        A result dict with at least ``filename``, ``type``, ``text`` and
        ``error``, or ``{"error": ...}`` alone when the KB does not exist.
    """
    from backend.kb_manager import get_kb_raw_dir

    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        return {"error": "KB 不存在"}

    # Keep only the final component so a caller-supplied name can never
    # place the copy outside the raw directory.
    fname = Path(source_name or file_path).name
    raw_dir.mkdir(parents=True, exist_ok=True)
    dest = raw_dir / fname
    shutil.copy2(file_path, dest)

    suffix = Path(fname).suffix.lower()

    if suffix == ".jrxml":
        jrxml_info = parse_jrxml_fields(file_path)
        # .get() throughout: a failed parse may not populate every key.
        parameters = jrxml_info.get("parameters", [])
        fields = jrxml_info.get("fields", [])
        text = f"[JRXML 模板: {jrxml_info.get('report_name', '')}]\n"
        text += (f"页面: {jrxml_info.get('page_width', '')}"
                 f"x{jrxml_info.get('page_height', '')}\n")
        text += "参数:\n" + "\n".join(
            f"  {p['name']} ({p['type']})" for p in parameters)
        text += "\n字段:\n" + "\n".join(
            f"  {f['name']} ({f['type']})" for f in fields)
        if jrxml_info.get("query"):
            text += f"\n查询:\n{jrxml_info['query']}"
        try:
            raw_xml = Path(file_path).read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            raw_xml = ""
        return {"filename": fname, "type": "jrxml", "text": text,
                "raw_xml": raw_xml, "jrxml_info": jrxml_info,
                # Surface a parse failure instead of reporting success.
                "error": jrxml_info.get("error")}

    if suffix in (".zip", ".tar", ".gz", ".tgz"):
        tmpdir = tempfile.mkdtemp(prefix="kb_extract_")
        try:
            sub_results = [
                process_file_for_kb(kb_id, ep, source_name=os.path.basename(ep))
                for ep in _extract_archive(file_path, tmpdir)
            ]
            return {"filename": fname, "type": "archive", "text": "",
                    "archive_contents": sub_results, "error": None}
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    parse_result = parse_file(str(dest))
    return {"filename": fname, "type": suffix.lstrip("."),
            "text": parse_result.get("text", ""),
            "error": parse_result.get("error")}
|
||||
|
||||
|
||||
def chunk_file_results(results: list[dict], kb_name: str = "") -> list[dict]:
    """Turn file processing results into chunks ready for embedding.

    Archive results are flattened (at any nesting depth) and their members
    chunked in order. A JRXML result with raw XML becomes one
    ``jrxml_template`` chunk; every other result is split on blank lines, and
    paragraphs shorter than 10 characters are dropped.

    Chunk ids are ``chunk_<n>`` with a single counter over the whole call, so
    ids are unique even when archives are involved.

    Args:
        results: Result dicts as produced by ``process_file_for_kb``.
        kb_name: Stored in each chunk's metadata.

    Returns:
        List of ``{"id", "content", "metadata"}`` dicts.
    """
    chunks = []
    # Explicit stack gives an order-preserving depth-first walk.
    pending = list(reversed(results))
    while pending:
        r = pending.pop()
        if r.get("type") == "archive":
            pending.extend(reversed(r.get("archive_contents") or []))
            continue

        fname = r.get("filename", "")
        ftype = r.get("type", "")
        text = r.get("text") or ""
        if not text.strip():
            continue

        if ftype == "jrxml" and r.get("raw_xml"):
            jinfo = r.get("jrxml_info") or {}
            report_name = jinfo.get("report_name", "")
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": f"[JRXML 模板: {report_name}]\n{r['text']}\n\n"
                           f"<xml>\n{r['raw_xml']}\n</xml>",
                "metadata": {"chunk_type": "jrxml_template",
                             "source_file": fname,
                             "report_name": report_name,
                             "kb_name": kb_name,
                             "param_count": len(jinfo.get("parameters", [])),
                             "field_count": len(jinfo.get("fields", []))},
            })
            continue

        chunk_type = "md_section" if ftype in ("md", "") else f"{ftype}_text"
        for para in text.split("\n\n"):
            para = para.strip()
            if len(para) < 10:
                continue
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": para,
                "metadata": {"chunk_type": chunk_type,
                             "source_file": fname, "kb_name": kb_name},
            })
    return chunks
|
||||
|
||||
|
||||
def extract_fields_with_llm(text: str, llm=None) -> list[dict]:
    """Extract field definitions from interface documentation text.

    With an ``llm`` (an object exposing ``invoke(prompt)``), the model is
    asked for a JSON array of field objects. The reply is parsed from its
    first ``[`` to its last ``]`` and only dict items with a non-empty string
    ``name`` are kept, since callers index ``item["name"]``. Without an LLM,
    when the reply contains no array, or on any failure, the markdown-table
    parser is used instead.

    Returns:
        List of field dicts (``name`` guaranteed; other keys as produced by
        the LLM or the table parser).
    """
    if llm is None:
        return _extract_fields_from_table(text)

    prompt = (
        "请分析以下接口文档内容,提取所有字段定义。\n"
        "对每个字段,输出: 字段名, 含义, 类型, 是否必需。\n"
        "以 JSON 数组格式输出,每个元素为 {\"name\": \"...\", "
        "\"description\": \"...\", \"type\": \"...\", \"required\": false}。\n\n"
        f"{text}"
    )
    try:
        response = llm.invoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        start = content.find("[")
        end = content.rfind("]") + 1
        if start >= 0 and end > start:
            parsed = json.loads(content[start:end])
            if isinstance(parsed, list):
                return [
                    item for item in parsed
                    if isinstance(item, dict)
                    and isinstance(item.get("name"), str)
                    and item["name"].strip()
                ]
    except Exception as e:
        _kb_parse_log.warning("LLM 字段提取失败,使用表格回退: %s", e)
    return _extract_fields_from_table(text)
|
||||
|
||||
|
||||
# Fallback field extractor: scans ``text`` for pipe-delimited table rows
# (lines starting with "|") and builds one field dict per data row.
# Column meaning as read by the code: cells[0] = name, cells[1] = description,
# cells[2] = required flag, cells[3] = type (defaults to java.lang.String).
# NOTE(review): indentation is lost in this copy of the file, so the nesting
# of the `continue` that follows `header_found = True` (inner vs. outer `if`)
# cannot be determined here -- confirm against the real source before editing.
def _extract_fields_from_table(text: str) -> list[dict]:
|
||||
fields = []
|
||||
lines = text.split("\n")
|
||||
header_found = False
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
# Only table rows are of interest.
if not line.startswith("|"):
|
||||
continue
|
||||
# Drop the empty pieces before the first and after the last "|".
cells = [c.strip() for c in line.split("|")[1:-1]]
|
||||
if not cells:
|
||||
continue
|
||||
# A row counts as the header when any cell contains one of these keywords.
if not header_found:
|
||||
if any(h in str(c) for c in cells
|
||||
for h in ["字段", "名称", "含义", "说明", "类型"]):
|
||||
header_found = True
|
||||
continue
|
||||
# Skip separator rows made only of "-", ":" and spaces.
if all(c.replace("-", "").replace(":", "").replace(" ", "") == ""
|
||||
for c in cells):
|
||||
continue
|
||||
if len(cells) >= 2:
|
||||
# Strip bold markers, "L " prefixes and backslashes from the name cell.
name = cells[0].replace("**", "").replace("L ", "").replace("\\", "").strip()
|
||||
if not name or name in ("", "---"):
|
||||
continue
|
||||
field = {"name": name, "description": "", "type": "java.lang.String",
|
||||
"required": False}
|
||||
if len(cells) >= 2 and cells[1]:
|
||||
field["description"] = cells[1].replace("<br/>", " ").strip()
|
||||
# The third column is treated as the "required" flag.
if len(cells) >= 3 and cells[2]:
|
||||
field["required"] = cells[2].strip() in ("是", "Y", "y", "yes", "Yes", "必填")
|
||||
if len(cells) >= 4 and cells[3]:
|
||||
field["type"] = cells[3].strip()
|
||||
fields.append(field)
|
||||
return fields
|
||||
|
||||
|
||||
def build_kb_from_files(kb_id: str, file_paths: list[str],
                        llm=None) -> dict:
    """Run the full ingestion pipeline for a knowledge base.

    Steps: process every file, chunk the results, write ``chunks.json``,
    embed the chunks into the KB's vector store, collect field and template
    definitions, and update the KB metadata.

    Files contained in archives are treated like top-level files for field
    and template extraction and for error reporting.

    Args:
        kb_id: Target knowledge base.
        file_paths: Files to ingest.
        llm: Optional LLM used for field extraction from non-JRXML text.

    Returns:
        ``{"status", "field_count", "chunk_count", "template_count",
        "errors"}``; ``status`` is ``"ready"`` when no error was recorded,
        else ``"partial"``.
    """
    from backend.kb_manager import update_kb_meta, get_kb_chunks_path
    from backend.kb_searcher import get_kb_searcher

    def _leaves(result: dict) -> list[dict]:
        """Flatten one result into its non-archive leaves, in order."""
        if result.get("type") != "archive":
            return [result]
        flat = []
        for sub in result.get("archive_contents") or []:
            flat.extend(_leaves(sub))
        return flat

    all_results = []
    leaf_results = []
    errors = []
    for fp in file_paths:
        try:
            r = process_file_for_kb(kb_id, fp)
        except Exception as e:
            errors.append({"file": os.path.basename(fp), "error": str(e)})
            continue
        all_results.append(r)
        if r.get("error"):
            errors.append({"file": os.path.basename(fp), "error": r["error"]})
        leaves = _leaves(r)
        if r.get("type") == "archive":
            # Failures of archive members would otherwise go unreported.
            errors.extend(
                {"file": leaf.get("filename", ""), "error": leaf["error"]}
                for leaf in leaves if leaf.get("error")
            )
        leaf_results.extend(leaves)

    chunks = chunk_file_results(all_results)

    chunks_path = get_kb_chunks_path(kb_id)
    if chunks_path:
        chunks_path.parent.mkdir(parents=True, exist_ok=True)
        with open(chunks_path, "w", encoding="utf-8") as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)

    searcher = get_kb_searcher(kb_id)
    if searcher and chunks:
        try:
            searcher.add_chunks(chunks)
        except Exception as e:
            errors.append({"file": "embedding", "error": str(e)})

    all_fields = []
    template_names = []
    for r in leaf_results:
        _collect_from_result(r, all_fields, template_names)

    known_names = {f["name"] for f in all_fields}
    for r in leaf_results:
        if r.get("type") == "jrxml":
            continue
        text = r.get("text") or ""
        if not text.strip():
            continue
        for ef in extract_fields_with_llm(text, llm):
            # LLM output is untrusted: skip entries without a usable name.
            name = ef.get("name") if isinstance(ef, dict) else None
            if not isinstance(name, str) or not name or name in known_names:
                continue
            known_names.add(name)
            all_fields.append(ef)

    status = "ready" if not errors else "partial"
    update_kb_meta(kb_id, {
        "fields": all_fields, "templates": template_names,
        "file_count": len(file_paths), "chunk_count": len(chunks),
        "parse_status": status,
    })

    _kb_parse_log.info("KB 构建完成", extra={
        "kb_id": kb_id, "fields": len(all_fields),
        "templates": len(template_names), "chunks": len(chunks),
    })

    return {"status": status,
            "field_count": len(all_fields), "chunk_count": len(chunks),
            "template_count": len(template_names), "errors": errors}
|
||||
|
||||
|
||||
def _collect_from_result(r: dict, all_fields: list, template_names: list) -> None:
|
||||
jinfo = r.get("jrxml_info")
|
||||
if jinfo and jinfo.get("report_name"):
|
||||
template_names.append({"name": jinfo["report_name"],
|
||||
"file": r.get("filename", "")})
|
||||
for p in jinfo.get("parameters", []):
|
||||
field = {"name": p["name"], "description": p.get("description", ""),
|
||||
"type": p.get("type", "java.lang.String"), "required": False}
|
||||
if not any(f["name"] == field["name"] for f in all_fields):
|
||||
all_fields.append(field)
|
||||
for f in jinfo.get("fields", []):
|
||||
field = {"name": f["name"], "description": f.get("description", ""),
|
||||
"type": f.get("type", "java.lang.String"), "required": False}
|
||||
if not any(fi["name"] == field["name"] for fi in all_fields):
|
||||
all_fields.append(field)
|
||||
@@ -0,0 +1,170 @@
|
||||
"""KB 隔离的 ChromaDB 语义搜索适配器。
|
||||
|
||||
每个知识库拥有独立的 ChromaDB collection。
|
||||
调用者: backend/rag_adapter.py, agent/nodes.py, api_server.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def _resolve(path: str) -> Path:
|
||||
p = Path(path)
|
||||
return p if p.is_absolute() else _PROJECT_ROOT / p
|
||||
|
||||
|
||||
class KBChromaSearcher:
    """Semantic search over the ChromaDB store of one knowledge base.

    The embedding model, the Chroma client and the collection are all created
    on first use and then kept on the instance.
    """

    def __init__(self, chroma_path: str, collection_name: str = "kb_chunks",
                 model_name: Optional[str] = None, use_gpu: Optional[bool] = None,
                 use_fp16: Optional[bool] = None):
        """Store configuration only; nothing is loaded here.

        Unset options fall back to the environment: ``RAG_EMBED_MODEL`` for
        the model, ``RAG_USE_GPU`` and ``RAG_USE_FP16`` for the flags (both
        default to ``"true"``). A model path that exists relative to the
        project root is used as a local path; otherwise the configured value
        is kept as given.
        """
        truthy = ("true", "1")

        self.chroma_path = str(_resolve(chroma_path))
        self.collection_name = collection_name

        configured_model = model_name or os.getenv(
            "RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2")
        local_model = _resolve(configured_model)
        if local_model.exists():
            self.model_name = str(local_model)
        else:
            self.model_name = configured_model

        if use_gpu is None:
            use_gpu = os.getenv("RAG_USE_GPU", "true").lower() in truthy
        self.use_gpu = use_gpu

        if use_fp16 is None:
            use_fp16 = os.getenv("RAG_USE_FP16", "true").lower() in truthy
        self.use_fp16 = use_fp16

        self._model = None
        self._client = None
        self._collection = None

    @property
    def model(self):
        """The sentence-transformers model, loaded on first access.

        CUDA is used when requested and available; on CUDA the model is
        converted to half precision when ``use_fp16`` is set.
        """
        if self._model is not None:
            return self._model

        import torch
        from sentence_transformers import SentenceTransformer

        wants_cuda = self.use_gpu and torch.cuda.is_available()
        device = "cuda" if wants_cuda else "cpu"
        logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
        loaded = SentenceTransformer(self.model_name, device=device)
        if device == "cuda" and self.use_fp16:
            loaded = loaded.half()
        self._model = loaded
        return self._model

    @property
    def client(self):
        """The persistent Chroma client for ``chroma_path``, created lazily."""
        if self._client is not None:
            return self._client

        import chromadb

        self._client = chromadb.PersistentClient(path=self.chroma_path)
        return self._client

    @property
    def collection(self):
        """The KB collection; created with cosine space if it cannot be fetched."""
        if self._collection is not None:
            return self._collection
        try:
            self._collection = self.client.get_collection(self.collection_name)
        except Exception:
            self._collection = self.client.create_collection(
                self.collection_name, metadata={"hnsw:space": "cosine"})
        return self._collection

    def is_ready(self) -> bool:
        """Return whether the collection can currently be fetched."""
        try:
            self.client.get_collection(self.collection_name)
        except Exception:
            return False
        return True

    def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]:
        """Return up to ``k`` hits for ``query``.

        Each hit is ``{"id", "content", "metadata", "distance"}``. When
        ``threshold`` is given, hits with a larger distance are dropped.
        Returns ``[]`` when the collection is not ready or nothing matches.
        """
        if not self.is_ready():
            return []

        vector = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False)
        raw = self.collection.query(
            query_embeddings=[vector.tolist()],
            n_results=k, include=["documents", "metadatas", "distances"])

        hits = []
        if not raw["ids"] or not raw["ids"][0]:
            return hits

        for position, doc_id in enumerate(raw["ids"][0]):
            distance = raw["distances"][0][position]
            too_far = threshold is not None and distance > threshold
            if too_far:
                continue
            hits.append({
                "id": doc_id,
                "content": raw["documents"][0][position],
                "metadata": raw["metadatas"][0][position] or {},
                "distance": distance,
            })
        return hits

    def search_templates(self, query: str, k: int = 3) -> list[dict]:
        """Return up to ``k`` hits that look like JRXML templates.

        Searches ``2 * k`` results and keeps those whose ``chunk_type``
        contains "jrxml" (case-insensitive) or that carry a ``report_name``.
        """
        selected = []
        for hit in self.search(query, k=k * 2):
            meta = hit.get("metadata", {})
            kind = meta.get("chunk_type", "")
            is_template = "jrxml" in kind.lower() or meta.get("report_name")
            if not is_template:
                continue
            selected.append(hit)
            if len(selected) >= k:
                break
        return selected

    def search_as_context(self, query: str, k: int = 5) -> str:
        """Return the search hits joined into one prompt-context string.

        Every hit is rendered as a header line (chunk type, plus report name
        when present) followed by its content; hits are separated by
        ``---``. Returns ``""`` when there are no hits.
        """
        hits = self.search(query, k=k)
        if not hits:
            return ""

        def _render(hit: dict) -> str:
            meta = hit.get("metadata", {})
            header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
            if meta.get("report_name"):
                header += f" [报表:{meta['report_name']}]"
            return f"{header}\n{hit['content']}"

        return "\n\n---\n\n".join(_render(hit) for hit in hits)

    def add_chunks(self, chunks: list[dict]) -> None:
        """Embed ``chunks`` and upsert them into the collection.

        Each chunk needs ``id`` and ``content``; ``metadata`` defaults to an
        empty dict. An empty list is a no-op.
        """
        if not chunks:
            return

        ids, docs, metas = [], [], []
        for chunk in chunks:
            ids.append(chunk["id"])
        for chunk in chunks:
            docs.append(chunk["content"])
        for chunk in chunks:
            metas.append(chunk.get("metadata", {}))

        vectors = self.model.encode(
            docs, normalize_embeddings=True, show_progress_bar=True)
        self.collection.upsert(
            ids=ids, documents=docs, metadatas=metas,
            embeddings=vectors.tolist())
|
||||
|
||||
|
||||
_searchers: dict = {}
|
||||
|
||||
|
||||
def get_kb_searcher(kb_id: str) -> Optional[KBChromaSearcher]:
    """Return the searcher for ``kb_id``, reusing one created earlier.

    A new searcher is built from the KB's Chroma path and stored in the
    module-level ``_searchers`` dict. Returns ``None`` when the KB has no
    Chroma path (i.e. the KB is unknown).
    """
    from backend.kb_manager import get_kb_chroma_path

    try:
        return _searchers[kb_id]
    except KeyError:
        pass

    store_path = get_kb_chroma_path(kb_id)
    if store_path is None:
        return None

    _searchers[kb_id] = created = KBChromaSearcher(str(store_path))
    return created
|
||||
|
||||
|
||||
def search_kb(kb_id: str, query: str, k: int = 5) -> str:
    """Search one KB and return the hits as a context string (``""`` if the KB is unknown)."""
    kb_searcher = get_kb_searcher(kb_id)
    if kb_searcher is not None:
        return kb_searcher.search_as_context(query, k=k)
    return ""
|
||||
|
||||
|
||||
def search_templates_in_kb(kb_id: str, query: str, k: int = 3) -> list[dict]:
    """Search one KB for JRXML template hits (``[]`` if the KB is unknown)."""
    kb_searcher = get_kb_searcher(kb_id)
    if kb_searcher is not None:
        return kb_searcher.search_templates(query, k=k)
    return []
|
||||
@@ -150,6 +150,12 @@ def _get_searcher() -> RAGSearcher:
|
||||
return _searcher
|
||||
|
||||
|
||||
def search_chunks(query: str, k: int = 5) -> str:
|
||||
"""搜索 JRXML 知识库并返回拼接后的上下文文本(便捷函数)。"""
|
||||
def search_chunks(query: str, k: int = 5, kb_id: str = "") -> str:
    """Search the knowledge base and return the hits as one context string.

    With a ``kb_id`` the search runs against that KB's own ChromaDB store;
    otherwise the global default searcher is used.
    """
    if not kb_id:
        return _get_searcher().search_as_context(query, k=k)

    from backend.kb_searcher import search_kb

    return search_kb(kb_id, query, k=k)
|
||||
|
||||
@@ -56,6 +56,7 @@ def create_session(name: str = "", agent_state: Optional[dict] = None,
|
||||
"session_name": name or f"新建报表 {now[:10]}",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"kb_id": agent_state.get("kb_id", "") if agent_state else "",
|
||||
"agent_state": agent_state,
|
||||
}
|
||||
with open(_session_path(sid), "w", encoding="utf-8") as f:
|
||||
|
||||
Reference in New Issue
Block a user