fix: band-level windowed refine_layout + programmatic map_fields to prevent 91.5% content loss

Root cause: when given the full 34k-char JRXML, the LLM regenerated the
report from scratch instead of modifying coordinates in place, shrinking
the output to ~3k chars.

Solution (programmatic node control, not prompt engineering):

- New agent/jrxml_windower.py: decompose JRXML into header (never sent to
  LLM) + individual bands. Split bands >4000 chars at element boundaries.
  Reassemble with element count validation (>10% change = rollback).

- Rewrite refine_layout: per-band windowed LLM processing (~2-4k chars
  each). LLM cannot "reimagine" the entire report.

- Rewrite map_fields: 100% programmatic regex $F{field_N} -> real name
  replacement. Zero LLM calls, zero content loss.

- _sanitize_field_name: non-ASCII chars escaped to _uXXXX_ format for
  valid JRXML identifiers.

- Tests: 48 new unit tests (windower 28 + map_fields 20). All passing.
  Full suite 385 tests, zero regressions.
This commit is contained in:
2026-05-24 08:55:38 +08:00
parent bb6cc6e241
commit bd5bfbac2d
80 changed files with 39463 additions and 108 deletions
+136
View File
@@ -0,0 +1,136 @@
"""OCR field -> KB field matching module.
Two-stage matching:
1. Embedding coarse filter (top-3 by similarity)
2. Exact confirmation by the LLM
Returns a mapping: {"工单号": "billNo", "客户名称": "customerName", ...}
"""
import json
import os
from typing import Optional
from dotenv import load_dotenv
from backend.logger import get_logger
load_dotenv()
_match_log = get_logger("field_matcher")
def _embed(text: str) -> list:
    """Return the embedding of ``text`` as a plain Python list.

    The model comes from the shared searcher returned by
    ``backend.rag_adapter._get_searcher``; the vector is requested with
    ``normalize_embeddings=True``.
    """
    from backend.rag_adapter import _get_searcher

    shared = _get_searcher()
    # Touch ``model`` once when nothing is cached on the searcher yet.
    if shared._model is None:
        _ = shared.model
    vector = shared.model.encode(
        text, normalize_embeddings=True, show_progress_bar=False
    )
    return vector.tolist()
def _cosine_similarity(a: list, b: list) -> float:
"""余弦相似度(假设向量已归一化,点积即相似度)。"""
return sum(x * y for x, y in zip(a, b))
def match_ocr_to_kb(ocr_fields: list[str], kb_fields: list[dict],
                    llm=None) -> dict[str, str]:
    """Map OCR-extracted field names onto KB field definitions.

    Stage 1 embeds every OCR name and every KB field and keeps the three most
    similar KB fields per OCR name. Stage 2 either asks the LLM to confirm the
    mapping (when ``llm`` is given) or accepts the best candidate whose
    similarity is above 0.5.

    Args:
        ocr_fields: Field names extracted by OCR.
        kb_fields: KB field definitions, e.g.
            ``[{"name": "billNo", "description": "工单号", ...}]``.
        llm: Optional LLM object exposing ``invoke(prompt)``.

    Returns:
        ``{ocr_name: kb_field_name}``; empty when either input is empty or no
        match could be established.
    """
    if not ocr_fields or not kb_fields:
        return {}

    # Stage 1: embedding-based coarse filter.
    try:
        ocr_embs = {name: _embed(name) for name in ocr_fields}
        # A blank description carries no signal (fields parsed from JRXML have
        # description ""), so fall back to the field name in that case too,
        # not only when the key is missing.
        kb_embs = {
            f["name"]: _embed(f.get("description") or f["name"])
            for f in kb_fields
        }
    except Exception as e:
        _match_log.warning("Embedding 匹配失败,回退到 LLM: %s", e)
        if llm is None:
            # Nothing to fall back to; avoid a pointless failing LLM call.
            return {}
        return _match_via_llm(ocr_fields, kb_fields, llm)

    candidates = {}
    for ocr_name, ocr_emb in ocr_embs.items():
        scored = sorted(
            ((kb_name, _cosine_similarity(ocr_emb, kb_emb))
             for kb_name, kb_emb in kb_embs.items()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        candidates[ocr_name] = scored[:3]

    # Stage 2: LLM confirmation, or a similarity threshold without an LLM.
    if llm:
        return dict(_match_via_llm(ocr_fields, kb_fields, llm, candidates))
    return {
        ocr_name: cands[0][0]
        for ocr_name, cands in candidates.items()
        if cands and cands[0][1] > 0.5
    }
def _match_via_llm(ocr_fields: list[str], kb_fields: list[dict],
llm, candidates: Optional[dict] = None) -> dict[str, str]:
"""使用 LLM 精确确认字段映射。"""
kb_desc = "\n".join(
f"- {f['name']}: {f.get('description', '')} ({f.get('type', 'java.lang.String')})"
for f in kb_fields
)
candidates_hint = ""
if candidates:
cand_lines = []
for ocr_name, cands in candidates.items():
cand_str = ", ".join(f"{n}({s:.2f})" for n, s in cands)
cand_lines.append(f" {ocr_name} -> 候选: {cand_str}")
candidates_hint = (
"向量相似度候选(仅供参考,请根据语义确认):\n"
+ "\n".join(cand_lines)
)
prompt = (
"请将以下 OCR 识别的字段名匹配到知识库定义的字段。\n\n"
f"OCR 字段: {json.dumps(ocr_fields, ensure_ascii=False)}\n\n"
f"知识库字段:\n{kb_desc}\n\n"
f"{candidates_hint}\n\n"
"请以 JSON 对象格式输出映射关系,键为 OCR 字段名,值为 KB 字段名:\n"
'{"工单号": "billNo", "客户名称": "customerName"}'
)
try:
response = llm.invoke(prompt)
content = response.content if hasattr(response, "content") else str(response)
start = content.find("{")
end = content.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(content[start:end])
except Exception as e:
_match_log.warning("LLM 字段匹配失败: %s", e)
return {}
def format_field_mapping_context(mapping: dict[str, str]) -> str:
    """Render an OCR -> KB field mapping as a Markdown table for a prompt.

    Returns an empty string for an empty mapping; otherwise four fixed header
    lines followed by one ``| ocr | $P{kb} |`` row per entry.
    """
    if not mapping:
        return ""
    header = [
        "[字段映射 — OCR -> KB]",
        "请在 JRXML 中使用以下参数名:",
        "| OCR 字段 | JRXML 参数 |",
        "|---|---|",
    ]
    rows = [
        f"| {ocr_name} | $P{{{kb_name}}} |"
        for ocr_name, kb_name in mapping.items()
    ]
    return "\n".join(header + rows)
+227
View File
@@ -0,0 +1,227 @@
"""多租户知识库管理模块。
用户 + 知识库 CRUD,持久化到 kb_data/ 目录。
每个 KB 拥有独立的 JSON 元数据文件和文件存储目录。
"""
import json
import os
import re
import uuid
import tempfile
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from backend.logger import get_logger
load_dotenv()
_kb_log = get_logger("kb_manager")
KB_DATA_DIR = Path(os.getenv("KB_DATA_DIR", "./kb_data"))
_USERS_FILE = KB_DATA_DIR / "users.json"
_VALID_ID_RE = re.compile(r'^[a-fA-F0-9]{12,}$')
def _validate_id(id_str: str, label: str = "id") -> None:
if not _VALID_ID_RE.match(id_str):
raise ValueError(f"Invalid {label}: {id_str!r}")
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _ensure_dir(path: Path) -> None:
path.mkdir(parents=True, exist_ok=True)
def _read_json(fp: Path) -> dict:
with open(fp, "r", encoding="utf-8") as f:
return json.load(f)
def _write_json_atomic(fp: Path, data: dict) -> None:
    """Write ``data`` as JSON to ``fp`` by way of a temporary sibling file.

    The JSON is written to a temporary ``.json`` file in ``fp``'s directory,
    flushed and fsynced, and then moved over ``fp`` with ``os.replace``. If
    any step fails the temporary file is removed and the error re-raised.
    """
    _ensure_dir(fp.parent)
    scratch = tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False,
        dir=fp.parent, encoding="utf-8",
    )
    scratch_name = scratch.name
    try:
        # Leaving the ``with`` block closes the handle on success and on error.
        with scratch:
            json.dump(data, scratch, ensure_ascii=False, indent=2)
            scratch.flush()
            os.fsync(scratch.fileno())
        os.replace(scratch_name, str(fp))
    except Exception:
        Path(scratch_name).unlink(missing_ok=True)
        raise
# ── User CRUD ──────────────────────────────────────────────────────────────
def _load_users() -> list[dict]:
    """Return the stored user list, or ``[]`` when no users file exists.

    The KB data directory is created first if it is missing.
    """
    _ensure_dir(KB_DATA_DIR)
    if not _USERS_FILE.exists():
        return []
    return _read_json(_USERS_FILE)
def _save_users(users: list[dict]) -> None:
    """Persist ``users`` to the users file through the atomic JSON writer."""
    _write_json_atomic(fp=_USERS_FILE, data=users)
def create_user(name: str, user_id: Optional[str] = None) -> dict:
    """Create a user record and its data directory.

    Args:
        name: Display name of the user.
        user_id: Optional explicit id; a random ``uuid4`` hex id is generated
            when omitted.

    Returns:
        The new user record (``user_id``, ``name``, ``created_at``).

    Raises:
        ValueError: If the id is not a valid hex id or already exists.
    """
    uid = user_id or uuid.uuid4().hex
    # The id becomes a directory name under KB_DATA_DIR, and get_user /
    # delete_user reject ids that fail this check, so enforce it at creation.
    _validate_id(uid, "user_id")
    users = _load_users()
    if any(u["user_id"] == uid for u in users):
        raise ValueError(f"User {uid} already exists")
    user = {"user_id": uid, "name": name, "created_at": _now_iso()}
    users.append(user)
    _save_users(users)
    _ensure_dir(KB_DATA_DIR / uid)
    _write_json_atomic(KB_DATA_DIR / uid / "profile.json", user)
    _kb_log.info("创建用户", extra={"user_id": uid, "user_name": name})
    return user
def list_users() -> list[dict]:
    """Return every stored user record."""
    records = _load_users()
    return records
def get_user(user_id: str) -> Optional[dict]:
    """Return the user record with ``user_id``, or ``None`` if absent.

    Raises:
        ValueError: If ``user_id`` fails id validation.
    """
    _validate_id(user_id, "user_id")
    matching = (rec for rec in _load_users() if rec["user_id"] == user_id)
    return next(matching, None)
def delete_user(user_id: str) -> bool:
    """Remove a user record and delete the user's data directory.

    Returns:
        ``True`` when a record was removed, ``False`` when no record has
        that id.

    Raises:
        ValueError: If ``user_id`` fails id validation.
    """
    _validate_id(user_id, "user_id")
    existing = _load_users()
    remaining = [rec for rec in existing if rec["user_id"] != user_id]
    if len(remaining) == len(existing):
        return False
    _save_users(remaining)
    home = KB_DATA_DIR / user_id
    if home.exists():
        shutil.rmtree(home)
    _kb_log.info("删除用户", extra={"user_id": user_id})
    return True
# ── KB CRUD ────────────────────────────────────────────────────────────────
def _kb_dir(kb_id: str) -> Optional[Path]:
    """Locate the directory of knowledge base ``kb_id``.

    Every non-hidden user directory under ``KB_DATA_DIR`` is checked for a
    sub-directory named ``kb_id``.

    Returns:
        The KB directory, or ``None`` when no user owns such a KB or the
        data root does not exist yet.

    Raises:
        ValueError: If ``kb_id`` fails id validation.
    """
    _validate_id(kb_id, "kb_id")
    # iterdir() raises FileNotFoundError on a missing root; a lookup before
    # any user has been created should simply report "not found".
    if not KB_DATA_DIR.is_dir():
        return None
    for user_dir in KB_DATA_DIR.iterdir():
        if user_dir.is_dir() and not user_dir.name.startswith("."):
            candidate = user_dir / kb_id
            if candidate.is_dir():
                return candidate
    return None
def _ensure_user_dir(user_id: str) -> Path:
    """Validate ``user_id`` and return its data directory, creating it.

    Raises:
        ValueError: If ``user_id`` fails id validation.
    """
    _validate_id(user_id, "user_id")
    home = KB_DATA_DIR / user_id
    _ensure_dir(home)
    return home
def create_kb(user_id: str, name: str, description: str = "",
              kb_id: Optional[str] = None) -> dict:
    """Create a knowledge base for ``user_id`` and write its ``meta.json``.

    Args:
        user_id: Owner of the knowledge base.
        name: Display name.
        description: Optional free-text description.
        kb_id: Optional explicit id; a random ``uuid4`` hex id is generated
            when omitted.

    Returns:
        The metadata dict that was written.

    Raises:
        ValueError: If ``user_id`` or the KB id fails id validation.
    """
    user_dir = _ensure_user_dir(user_id)
    kid = kb_id or uuid.uuid4().hex
    # The id becomes a path component and every later lookup validates it,
    # so reject unusable ids before anything is written to disk.
    _validate_id(kid, "kb_id")
    kb_dir = user_dir / kid
    _ensure_dir(kb_dir)
    _ensure_dir(kb_dir / "raw")
    now = _now_iso()
    meta = {
        "kb_id": kid, "user_id": user_id, "name": name,
        "description": description, "created_at": now, "updated_at": now,
        "fields": [], "templates": [], "file_count": 0,
        "chunk_count": 0, "parse_status": "empty",
    }
    _write_json_atomic(kb_dir / "meta.json", meta)
    _kb_log.info("创建知识库", extra={"kb_id": kid, "user_id": user_id, "kb_name": name})
    return meta
def list_kbs(user_id: str) -> list[dict]:
    """Return summary records for every knowledge base of ``user_id``.

    Directories are visited newest-modified first. Hidden directories and
    directories without a ``meta.json`` are skipped. Each summary carries
    counts (fields, templates, files, chunks) rather than the full lists.
    """
    user_dir = _ensure_user_dir(user_id)
    entries = list(user_dir.iterdir())
    entries.sort(key=os.path.getmtime, reverse=True)

    summaries = []
    for entry in entries:
        if not entry.is_dir() or entry.name.startswith("."):
            continue
        meta_path = entry / "meta.json"
        if not meta_path.exists():
            continue
        meta = _read_json(meta_path)
        summary = {
            "kb_id": meta.get("kb_id", entry.name),
            "name": meta.get("name", entry.name),
            "description": meta.get("description", ""),
            "created_at": meta.get("created_at", ""),
            "updated_at": meta.get("updated_at", ""),
            "field_count": len(meta.get("fields", [])),
            "template_count": len(meta.get("templates", [])),
            "file_count": meta.get("file_count", 0),
            "chunk_count": meta.get("chunk_count", 0),
            "parse_status": meta.get("parse_status", "empty"),
        }
        summaries.append(summary)
    return summaries
def get_kb(kb_id: str) -> Optional[dict]:
    """Return the metadata of knowledge base ``kb_id``.

    Returns ``None`` when the KB directory or its ``meta.json`` is missing.

    Raises:
        ValueError: If ``kb_id`` fails id validation.
    """
    _validate_id(kb_id, "kb_id")
    location = _kb_dir(kb_id)
    if location is None:
        return None
    meta_path = location / "meta.json"
    if not meta_path.exists():
        return None
    return _read_json(meta_path)
def update_kb_meta(kb_id: str, updates: dict) -> Optional[dict]:
    """Merge ``updates`` into a KB's metadata and persist the result.

    ``updated_at`` is always refreshed after the merge.

    Returns:
        The merged metadata, or ``None`` when the KB does not exist.
    """
    location = _kb_dir(kb_id)
    if location is None:
        return None
    meta_path = location / "meta.json"
    current = _read_json(meta_path)
    current.update(updates)
    current["updated_at"] = _now_iso()
    _write_json_atomic(meta_path, current)
    return current
def delete_kb(kb_id: str) -> bool:
    """Delete the directory tree of knowledge base ``kb_id``.

    Returns:
        ``True`` when the KB existed and was removed, otherwise ``False``.
    """
    location = _kb_dir(kb_id)
    if location is None:
        return False
    shutil.rmtree(location)
    _kb_log.info("删除知识库", extra={"kb_id": kb_id})
    return True
def get_kb_raw_dir(kb_id: str) -> Optional[Path]:
    """Return the ``raw`` directory path of a KB, or ``None`` if not found."""
    location = _kb_dir(kb_id)
    if not location:
        return None
    return location / "raw"
def get_kb_chunks_path(kb_id: str) -> Optional[Path]:
    """Return the ``chunks.json`` path of a KB, or ``None`` if not found."""
    location = _kb_dir(kb_id)
    if not location:
        return None
    return location / "chunks.json"
def get_kb_chroma_path(kb_id: str) -> Optional[Path]:
    """Return the ``chroma`` directory of a KB, creating it when missing.

    Returns ``None`` when the KB itself does not exist.
    """
    location = _kb_dir(kb_id)
    if location is not None:
        chroma_dir = location / "chroma"
        _ensure_dir(chroma_dir)
        return chroma_dir
    return None
+336
View File
@@ -0,0 +1,336 @@
"""KB 解析管道 — 文件提取→字段解析→chunk 切割→向量嵌入。
调用者: api_server.py (upload endpoint), scripts/init_default_kb.py
"""
import os
import json
import shutil
import zipfile
import tarfile
import tempfile
import defusedxml.ElementTree as ET
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from backend.logger import get_logger
from backend.file_parser import parse_file
load_dotenv()
_kb_parse_log = get_logger("kb_parser")
def _find_tag(elem, tag):
for el in elem.iter():
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
if local == tag:
return el
return None
def _find_all_tags(elem, tag):
results = []
for el in elem.iter():
local = el.tag.split("}")[-1] if "}" in el.tag else el.tag
if local == tag:
results.append(el)
return results
def parse_jrxml_fields(jrxml_path: str) -> dict:
    """Parse a JRXML file and extract its parameter and field definitions.

    Args:
        jrxml_path: Path of the JRXML file.

    Returns:
        A dict with keys ``report_name``, ``page_width``, ``page_height``,
        ``parameters``, ``fields``, ``query`` and ``error``. When the file
        cannot be parsed, ``error`` holds a message and every other key is
        present with an empty value, so callers can index the result the same
        way in both cases.
    """
    try:
        root = ET.parse(jrxml_path).getroot()
    except (ET.ParseError, ValueError) as e:
        # ValueError is included because defusedxml's rejections (forbidden
        # DTD / entities) derive from it rather than from ParseError.
        return {"error": f"JRXML 解析失败: {e}", "report_name": "",
                "page_width": "", "page_height": "",
                "parameters": [], "fields": [], "query": ""}

    parameters = []
    for p in _find_all_tags(root, "parameter"):
        param = {"name": p.attrib.get("name", ""),
                 "type": p.attrib.get("class", "java.lang.String"),
                 "description": ""}
        desc = _find_tag(p, "parameterDescription")
        if desc is not None and desc.text:
            param["description"] = desc.text.strip()
        parameters.append(param)

    fields = []
    for f in _find_all_tags(root, "field"):
        field = {"name": f.attrib.get("name", ""),
                 "type": f.attrib.get("class", "java.lang.String"),
                 "description": ""}
        # JRXML fields may carry a <fieldDescription> child, the counterpart
        # of <parameterDescription>.
        desc = _find_tag(f, "fieldDescription")
        if desc is not None and desc.text:
            field["description"] = desc.text.strip()
        fields.append(field)

    query_text = ""
    query = _find_tag(root, "queryString")
    if query is not None and query.text:
        query_text = query.text.strip()

    return {"report_name": root.attrib.get("name", ""),
            "page_width": root.attrib.get("pageWidth", ""),
            "page_height": root.attrib.get("pageHeight", ""),
            "parameters": parameters,
            "fields": fields, "query": query_text, "error": None}
def _extract_archive(file_path: str, dest_dir: str) -> list[str]:
extracted = []
resolved_dest = os.path.realpath(dest_dir)
if zipfile.is_zipfile(file_path):
with zipfile.ZipFile(file_path, "r") as zf:
for member in zf.namelist():
member_path = os.path.realpath(os.path.join(dest_dir, member))
if not member_path.startswith(resolved_dest + os.sep):
continue
zf.extract(member, dest_dir)
if not member.endswith("/"):
extracted.append(member_path)
elif tarfile.is_tarfile(file_path):
with tarfile.open(file_path, "r:*") as tf:
for member in tf.getmembers():
member_path = os.path.realpath(os.path.join(dest_dir, member.name))
if not member_path.startswith(resolved_dest + os.sep):
continue
tf.extract(member, dest_dir)
if not member.name.endswith("/"):
extracted.append(member_path)
return extracted
def process_file_for_kb(kb_id: str, file_path: str,
                        source_name: str = "") -> dict:
    """Store one uploaded file in a KB's raw directory and extract its text.

    JRXML files are summarised (report name, page size, parameters, fields,
    query) and returned together with their raw XML. Archives (``.zip``,
    ``.tar``, ``.gz``, ``.tgz``) are unpacked into a temporary directory and
    each member is processed recursively. Any other file is handed to
    ``parse_file``.

    Args:
        kb_id: Target knowledge base.
        file_path: Path of the file on disk.
        source_name: Optional original file name; only its final path
            component is used.

    Returns:
        A result dict with at least ``filename``, ``type``, ``text`` and
        ``error``; ``{"error": ...}`` alone when the KB does not exist.
    """
    from backend.kb_manager import get_kb_raw_dir

    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        return {"error": "KB 不存在"}

    # Keep only the final path component: the name may be client-supplied and
    # must not be able to point outside the raw directory.
    fname = (os.path.basename(source_name.replace("\\", "/"))
             or os.path.basename(file_path))
    raw_dir.mkdir(parents=True, exist_ok=True)
    dest = raw_dir / fname
    shutil.copy2(file_path, dest)
    suffix = Path(fname).suffix.lower()

    if suffix == ".jrxml":
        jrxml_info = parse_jrxml_fields(file_path)
        # .get(): a parse-error result is not guaranteed to carry every key.
        parameters = jrxml_info.get("parameters", [])
        fields = jrxml_info.get("fields", [])
        text = f"[JRXML 模板: {jrxml_info.get('report_name', '')}]\n"
        text += (f"页面: {jrxml_info.get('page_width', '')}"
                 f"x{jrxml_info.get('page_height', '')}\n")
        text += "参数:\n" + "\n".join(
            f" {p['name']} ({p['type']})" for p in parameters)
        text += "\n字段:\n" + "\n".join(
            f" {f['name']} ({f['type']})" for f in fields)
        if jrxml_info.get("query"):
            text += f"\n查询:\n{jrxml_info['query']}"
        try:
            raw_xml = Path(file_path).read_text(encoding="utf-8")
        except Exception:
            # Best effort: the summary text is still usable without raw XML.
            raw_xml = ""
        # Report the parser's error instead of always claiming success.
        return {"filename": fname, "type": "jrxml", "text": text,
                "raw_xml": raw_xml, "jrxml_info": jrxml_info,
                "error": jrxml_info.get("error")}

    if suffix in (".zip", ".tar", ".gz", ".tgz"):
        tmpdir = tempfile.mkdtemp(prefix="kb_extract_")
        try:
            sub_results = [
                process_file_for_kb(kb_id, ep,
                                    source_name=os.path.basename(ep))
                for ep in _extract_archive(file_path, tmpdir)
            ]
            return {"filename": fname, "type": "archive", "text": "",
                    "archive_contents": sub_results, "error": None}
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    parse_result = parse_file(str(dest))
    return {"filename": fname, "type": suffix.lstrip("."),
            "text": parse_result.get("text", ""),
            "error": parse_result.get("error")}
def chunk_file_results(results: list[dict], kb_name: str = "") -> list[dict]:
    """Turn file-processing results into chunks for embedding.

    Archive results are expanded depth-first into their members. A JRXML
    result with raw XML becomes one ``jrxml_template`` chunk; any other result
    is split on blank lines and every paragraph of at least 10 characters
    becomes one chunk. Results without text are skipped.

    Chunk ids are ``chunk_<n>`` with ``n`` counting continuously over the
    whole call, including archive members, so ids are unique per call.

    Args:
        results: Result dicts as produced by ``process_file_for_kb``.
        kb_name: Name recorded in every chunk's metadata.

    Returns:
        A list of ``{"id", "content", "metadata"}`` dicts.
    """
    chunks = []
    # Explicit stack instead of recursion so numbering never restarts; items
    # are pushed reversed to keep the original order when popping.
    pending = list(reversed(results))
    while pending:
        r = pending.pop()
        if r.get("type") == "archive":
            pending.extend(reversed(r.get("archive_contents") or []))
            continue

        fname = r.get("filename", "")
        ftype = r.get("type", "")
        text = r.get("text") or ""
        if not text.strip():
            continue

        if ftype == "jrxml" and r.get("raw_xml"):
            jinfo = r.get("jrxml_info", {})
            report_name = jinfo.get("report_name", "")
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": f"[JRXML 模板: {report_name}]\n{r['text']}\n\n"
                           f"<xml>\n{r['raw_xml']}\n</xml>",
                "metadata": {"chunk_type": "jrxml_template",
                             "source_file": fname,
                             "report_name": report_name,
                             "kb_name": kb_name,
                             "param_count": len(jinfo.get("parameters", [])),
                             "field_count": len(jinfo.get("fields", []))},
            })
            continue

        chunk_type = "md_section" if ftype in ("md", "") else f"{ftype}_text"
        for para in (p.strip() for p in text.split("\n\n")):
            if len(para) < 10:
                continue
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": para,
                "metadata": {"chunk_type": chunk_type,
                             "source_file": fname, "kb_name": kb_name},
            })
    return chunks
def extract_fields_with_llm(text: str, llm=None) -> list[dict]:
    """Extract field definitions from interface-document text.

    Without an LLM the Markdown-table parser is used. With an LLM, the reply
    is searched for a JSON array (first ``[`` to last ``]``); if none is
    found or the call fails, the table parser is used as fallback.

    Args:
        text: Document text.
        llm: Optional object exposing ``invoke(prompt)``.

    Returns:
        A list of field dicts. Entries of the LLM reply that are not dicts or
        have no non-empty string ``"name"`` are dropped, because callers
        index every entry by ``"name"``.
    """
    if llm is None:
        return _extract_fields_from_table(text)
    prompt = (
        "请分析以下接口文档内容,提取所有字段定义。\n"
        "对每个字段,输出: 字段名, 含义, 类型, 是否必需。\n"
        "以 JSON 数组格式输出,每个元素为 {\"name\": \"...\", "
        "\"description\": \"...\", \"type\": \"...\", \"required\": false}。\n\n"
        f"{text}"
    )
    try:
        response = llm.invoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        start = content.find("[")
        end = content.rfind("]") + 1
        if start >= 0 and end > start:
            parsed = json.loads(content[start:end])
            return [
                entry for entry in parsed
                if isinstance(entry, dict)
                and isinstance(entry.get("name"), str)
                and entry["name"]
            ]
    except Exception as e:
        _kb_parse_log.warning("LLM 字段提取失败,使用表格回退: %s", e)
    return _extract_fields_from_table(text)
def _extract_fields_from_table(text: str) -> list[dict]:
fields = []
lines = text.split("\n")
header_found = False
for line in lines:
line = line.strip()
if not line.startswith("|"):
continue
cells = [c.strip() for c in line.split("|")[1:-1]]
if not cells:
continue
if not header_found:
if any(h in str(c) for c in cells
for h in ["字段", "名称", "含义", "说明", "类型"]):
header_found = True
continue
if all(c.replace("-", "").replace(":", "").replace(" ", "") == ""
for c in cells):
continue
if len(cells) >= 2:
name = cells[0].replace("**", "").replace("L ", "").replace("\\", "").strip()
if not name or name in ("", "---"):
continue
field = {"name": name, "description": "", "type": "java.lang.String",
"required": False}
if len(cells) >= 2 and cells[1]:
field["description"] = cells[1].replace("<br/>", " ").strip()
if len(cells) >= 3 and cells[2]:
field["required"] = cells[2].strip() in ("", "Y", "y", "yes", "Yes", "必填")
if len(cells) >= 4 and cells[3]:
field["type"] = cells[3].strip()
fields.append(field)
return fields
def build_kb_from_files(kb_id: str, file_paths: list[str],
                        llm=None) -> dict:
    """Run the full KB build: process files, chunk, embed, collect fields.

    Steps: process every file, write the chunks to the KB's chunk file, add
    them to the KB's searcher, collect JRXML parameters/fields and template
    names, extract fields from the remaining text documents, and store the
    summary in the KB metadata. Files contained in archives take part in
    field/template collection and error reporting like top-level files.

    Args:
        kb_id: Target knowledge base.
        file_paths: Paths of the uploaded files.
        llm: Optional LLM passed on to field extraction.

    Returns:
        ``{"status", "field_count", "chunk_count", "template_count",
        "errors"}``; ``status`` is ``"ready"`` without errors, else
        ``"partial"``.
    """
    from backend.kb_manager import update_kb_meta, get_kb_chunks_path
    from backend.kb_searcher import get_kb_searcher

    all_results = []
    errors = []
    for fp in file_paths:
        try:
            r = process_file_for_kb(kb_id, fp)
        except Exception as e:
            errors.append({"file": os.path.basename(fp), "error": str(e)})
            continue
        all_results.append(r)
        if r.get("error"):
            errors.append({"file": os.path.basename(fp), "error": r["error"]})

    # Expand archives into their members (depth-first, order preserved) so
    # nested files are not silently ignored below.
    leaf_results = []
    pending = [(r, False) for r in reversed(all_results)]
    while pending:
        r, nested = pending.pop()
        if r.get("type") == "archive":
            pending.extend(
                (sub, True)
                for sub in reversed(r.get("archive_contents") or []))
            continue
        leaf_results.append(r)
        # Top-level errors were recorded above; add those of archive members.
        if nested and r.get("error"):
            errors.append({"file": r.get("filename", ""), "error": r["error"]})

    chunks = chunk_file_results(all_results)
    chunks_path = get_kb_chunks_path(kb_id)
    if chunks_path:
        chunks_path.parent.mkdir(parents=True, exist_ok=True)
        with open(chunks_path, "w", encoding="utf-8") as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)

    searcher = get_kb_searcher(kb_id)
    if searcher and chunks:
        try:
            searcher.add_chunks(chunks)
        except Exception as e:
            errors.append({"file": "embedding", "error": str(e)})

    all_fields = []
    template_names = []
    for r in leaf_results:
        _collect_from_result(r, all_fields, template_names)

    # Names already collected from JRXML; a set keeps de-duplication linear.
    seen_names = {f["name"] for f in all_fields}
    for r in leaf_results:
        if r.get("type") == "jrxml":
            continue
        text = r.get("text") or ""
        if not text.strip():
            continue
        for ef in extract_fields_with_llm(text, llm):
            # LLM output is not guaranteed to be well-formed.
            name = ef.get("name") if isinstance(ef, dict) else None
            if not isinstance(name, str) or not name or name in seen_names:
                continue
            seen_names.add(name)
            all_fields.append(ef)

    status = "ready" if not errors else "partial"
    update_kb_meta(kb_id, {
        "fields": all_fields, "templates": template_names,
        "file_count": len(file_paths), "chunk_count": len(chunks),
        "parse_status": status,
    })
    _kb_parse_log.info("KB 构建完成", extra={
        "kb_id": kb_id, "fields": len(all_fields),
        "templates": len(template_names), "chunks": len(chunks),
    })
    return {"status": status,
            "field_count": len(all_fields), "chunk_count": len(chunks),
            "template_count": len(template_names), "errors": errors}
def _collect_from_result(r: dict, all_fields: list, template_names: list) -> None:
jinfo = r.get("jrxml_info")
if jinfo and jinfo.get("report_name"):
template_names.append({"name": jinfo["report_name"],
"file": r.get("filename", "")})
for p in jinfo.get("parameters", []):
field = {"name": p["name"], "description": p.get("description", ""),
"type": p.get("type", "java.lang.String"), "required": False}
if not any(f["name"] == field["name"] for f in all_fields):
all_fields.append(field)
for f in jinfo.get("fields", []):
field = {"name": f["name"], "description": f.get("description", ""),
"type": f.get("type", "java.lang.String"), "required": False}
if not any(fi["name"] == field["name"] for fi in all_fields):
all_fields.append(field)
+170
View File
@@ -0,0 +1,170 @@
"""KB 隔离的 ChromaDB 语义搜索适配器。
每个知识库拥有独立的 ChromaDB collection。
调用者: backend/rag_adapter.py, agent/nodes.py, api_server.py
"""
import os
import logging
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
def _resolve(path: str) -> Path:
p = Path(path)
return p if p.is_absolute() else _PROJECT_ROOT / p
class KBChromaSearcher:
    """Semantic search over the ChromaDB store of a single knowledge base.

    The embedding model, the Chroma client and the collection are all created
    lazily on first use and cached on the instance.
    """

    def __init__(self, chroma_path: str, collection_name: str = "kb_chunks",
                 model_name: Optional[str] = None, use_gpu: Optional[bool] = None,
                 use_fp16: Optional[bool] = None):
        """Store configuration; nothing is loaded or opened here.

        Args:
            chroma_path: Directory of the persistent Chroma store; relative
                paths are resolved against the project root.
            collection_name: Name of the Chroma collection.
            model_name: Embedding model name or path. Defaults to the
                ``RAG_EMBED_MODEL`` environment variable.
            use_gpu: Whether to use CUDA when available. Defaults to the
                ``RAG_USE_GPU`` environment variable.
            use_fp16: Whether to run the model in half precision on CUDA.
                Defaults to the ``RAG_USE_FP16`` environment variable.
        """
        self.chroma_path = str(_resolve(chroma_path))
        self.collection_name = collection_name
        model_path = model_name or os.getenv(
            "RAG_EMBED_MODEL", "./rag/models/paraphrase-multilingual-MiniLM-L12-v2")
        resolved = _resolve(model_path)
        # Use the local path when it exists; otherwise pass the value through
        # unchanged.
        self.model_name = str(resolved) if resolved.exists() else model_path
        self.use_gpu = (use_gpu if use_gpu is not None
                        else os.getenv("RAG_USE_GPU", "true").lower() in ("true", "1"))
        self.use_fp16 = (use_fp16 if use_fp16 is not None
                         else os.getenv("RAG_USE_FP16", "true").lower() in ("true", "1"))
        self._model = None
        self._client = None
        self._collection = None

    @property
    def model(self):
        """The SentenceTransformer model, loaded on first access."""
        if self._model is None:
            # Imported here so the module can be imported without these
            # packages being loaded.
            import torch
            from sentence_transformers import SentenceTransformer
            device = "cuda" if (self.use_gpu and torch.cuda.is_available()) else "cpu"
            logger.info("加载嵌入模型: %s (device=%s)", self.model_name, device)
            model = SentenceTransformer(self.model_name, device=device)
            if device == "cuda" and self.use_fp16:
                model = model.half()
            self._model = model
        return self._model

    @property
    def client(self):
        """The persistent Chroma client, opened on first access."""
        if self._client is None:
            import chromadb
            self._client = chromadb.PersistentClient(path=self.chroma_path)
        return self._client

    @property
    def collection(self):
        """The Chroma collection, fetched or created (cosine space) on first access."""
        if self._collection is None:
            try:
                self._collection = self.client.get_collection(self.collection_name)
            except Exception:
                self._collection = self.client.create_collection(
                    self.collection_name, metadata={"hnsw:space": "cosine"})
        return self._collection

    def is_ready(self) -> bool:
        """Return whether the collection can be fetched from the client."""
        try:
            self.client.get_collection(self.collection_name)
            return True
        except Exception:
            return False

    def search(self, query: str, k: int = 5, threshold: Optional[float] = None) -> list[dict]:
        """Return up to ``k`` chunks most similar to ``query``.

        Args:
            query: Search text.
            k: Number of results requested from Chroma.
            threshold: Optional maximum distance; results with a larger
                distance are dropped.

        Returns:
            Dicts with ``id``, ``content``, ``metadata`` and ``distance``;
            empty when the collection is not available or nothing matches.
        """
        if not self.is_ready():
            return []
        query_embedding = self.model.encode(
            query, normalize_embeddings=True, show_progress_bar=False)
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=k, include=["documents", "metadatas", "distances"])
        output = []
        if not results["ids"] or not results["ids"][0]:
            return output
        for i, doc_id in enumerate(results["ids"][0]):
            dist = results["distances"][0][i]
            if threshold is not None and dist > threshold:
                continue
            output.append({
                "id": doc_id,
                "content": results["documents"][0][i],
                "metadata": results["metadatas"][0][i] or {},
                "distance": dist,
            })
        return output

    def search_templates(self, query: str, k: int = 3) -> list[dict]:
        """Return up to ``k`` results that look like JRXML templates.

        Twice as many results are fetched and then filtered to those whose
        ``chunk_type`` contains "jrxml" or that carry a ``report_name``.
        """
        results = self.search(query, k=k * 2)
        templates = []
        for r in results:
            meta = r.get("metadata", {})
            chunk_type = meta.get("chunk_type", "")
            if "jrxml" in chunk_type.lower() or meta.get("report_name"):
                templates.append(r)
            if len(templates) >= k:
                break
        return templates

    def search_as_context(self, query: str, k: int = 5) -> str:
        """Return the search results joined into one prompt-context string.

        Each result is prefixed with a header naming its chunk type (and
        report name, if any); results are separated by ``---`` lines. Returns
        an empty string when there are no results.
        """
        results = self.search(query, k=k)
        if not results:
            return ""
        parts = []
        for r in results:
            meta = r.get("metadata", {})
            header = f"[类型:{meta.get('chunk_type', 'N/A')}]"
            if meta.get("report_name"):
                header += f" [报表:{meta['report_name']}]"
            parts.append(f"{header}\n{r['content']}")
        return "\n\n---\n\n".join(parts)

    def add_chunks(self, chunks: list[dict], batch_size: int = 512) -> None:
        """Embed ``chunks`` and upsert them into the collection.

        The work is done in batches so that a large upload neither exceeds
        the store's per-call batch limit nor embeds everything in one go.

        Args:
            chunks: Dicts with ``id`` and ``content`` and optionally
                ``metadata``.
            batch_size: Maximum number of chunks per encode/upsert call;
                values below 1 are treated as 1.
        """
        if not chunks:
            return
        step = max(1, int(batch_size))
        for offset in range(0, len(chunks), step):
            batch = chunks[offset:offset + step]
            docs = [c["content"] for c in batch]
            embeddings = self.model.encode(
                docs, normalize_embeddings=True, show_progress_bar=True)
            self.collection.upsert(
                ids=[c["id"] for c in batch],
                documents=docs,
                metadatas=[c.get("metadata", {}) for c in batch],
                embeddings=embeddings.tolist())
_searchers: dict = {}
def get_kb_searcher(kb_id: str) -> Optional[KBChromaSearcher]:
    """Return the cached searcher for ``kb_id``, creating it on first use.

    Returns ``None`` when the KB has no Chroma path (the KB does not exist).
    """
    from backend.kb_manager import get_kb_chroma_path

    try:
        return _searchers[kb_id]
    except KeyError:
        pass

    chroma_path = get_kb_chroma_path(kb_id)
    if chroma_path is None:
        return None
    created = KBChromaSearcher(str(chroma_path))
    _searchers[kb_id] = created
    return created
def search_kb(kb_id: str, query: str, k: int = 5) -> str:
    """Search one KB and return the results as a context string.

    Returns an empty string when no searcher is available for ``kb_id``.
    """
    kb_searcher = get_kb_searcher(kb_id)
    if kb_searcher is not None:
        return kb_searcher.search_as_context(query, k=k)
    return ""
def search_templates_in_kb(kb_id: str, query: str, k: int = 3) -> list[dict]:
    """Search one KB for JRXML template chunks.

    Returns an empty list when no searcher is available for ``kb_id``.
    """
    kb_searcher = get_kb_searcher(kb_id)
    if kb_searcher is not None:
        return kb_searcher.search_templates(query, k=k)
    return []
+8 -2
View File
@@ -150,6 +150,12 @@ def _get_searcher() -> RAGSearcher:
return _searcher
def search_chunks(query: str, k: int = 5) -> str:
"""搜索 JRXML 知识库并返回拼接后的上下文文本(便捷函数)。"""
def search_chunks(query: str, k: int = 5, kb_id: str = "") -> str:
    """Search the knowledge base and return the concatenated context text.

    With a non-empty ``kb_id`` the query is sent to that KB's own ChromaDB
    through ``search_kb``; otherwise the global default searcher is used.
    """
    if not kb_id:
        return _get_searcher().search_as_context(query, k=k)
    from backend.kb_searcher import search_kb
    return search_kb(kb_id, query, k=k)
+1
View File
@@ -56,6 +56,7 @@ def create_session(name: str = "", agent_state: Optional[dict] = None,
"session_name": name or f"新建报表 {now[:10]}",
"created_at": now,
"updated_at": now,
"kb_id": agent_state.get("kb_id", "") if agent_state else "",
"agent_state": agent_state,
}
with open(_session_path(sid), "w", encoding="utf-8") as f: