"""KB 解析管道 — 文件提取→字段解析→chunk 切割→向量嵌入。
调用者: api_server.py (upload endpoint), scripts/init_default_kb.py
"""
import os
import json
import shutil
import zipfile
import tarfile
import tempfile
import defusedxml.ElementTree as ET
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
from backend.logger import get_logger
from backend.file_parser import parse_file
# Load settings from a local .env file (python-dotenv) when this module is
# first imported.
load_dotenv()
# Module-level logger for the KB parsing pipeline, obtained from the
# project's backend.logger helper.
_kb_parse_log = get_logger("kb_parser")
def _find_tag(elem, tag):
    """Return the first element in *elem*'s subtree whose local tag is *tag*.

    The namespace prefix (``{uri}``) is ignored when comparing. The search
    runs in document order via ``Element.iter()``, so *elem* itself is
    checked first. Returns None when nothing matches.
    """
    matches = (node for node in elem.iter()
               if node.tag.rpartition("}")[2] == tag)
    return next(matches, None)
def _find_all_tags(elem, tag):
    """Return every element in *elem*'s subtree whose local tag is *tag*.

    The namespace prefix (``{uri}``) is ignored. Results are in document
    order (``Element.iter()``) and include *elem* itself if it matches.
    """
    return [node for node in elem.iter()
            if node.tag.rpartition("}")[2] == tag]
def parse_jrxml_fields(jrxml_path: str) -> dict:
    """Parse a JRXML file and extract its parameter and field definitions.

    Args:
        jrxml_path: Path of the ``.jrxml`` file to read.

    Returns:
        A dict that always has the same keys, so callers may index it
        directly:

        - ``report_name``, ``page_width``, ``page_height``: attributes of the
          root element (``""`` when absent).
        - ``parameters``: ``{"name", "type", "description"}`` for every
          ``<parameter>`` element anywhere in the document (including
          sub-datasets); ``type`` comes from the ``class`` attribute
          (default ``"java.lang.String"``), ``description`` from a nested
          ``<parameterDescription>``.
        - ``fields``: ``{"name", "type", "description"}`` for every
          ``<field>`` element; ``description`` is always ``""``.
        - ``query``: stripped text of the report's ``<queryString>`` or ``""``.
        - ``error``: None on success, otherwise a message; in that case all
          other values are empty.
    """
    try:
        root = ET.parse(jrxml_path).getroot()
    except (ET.ParseError, ValueError, OSError) as e:
        # defusedxml rejects DTDs / entity expansion / external references
        # with ValueError subclasses and an unreadable file raises OSError;
        # report both like malformed XML. The full key set is returned so
        # callers can index it without KeyError.
        return {"error": f"JRXML 解析失败: {e}", "parameters": [], "fields": [],
                "report_name": "", "page_width": "", "page_height": "",
                "query": ""}

    report_name = root.attrib.get("name", "")
    page_width = root.attrib.get("pageWidth", "")
    page_height = root.attrib.get("pageHeight", "")

    parameters = []
    for p in _find_all_tags(root, "parameter"):
        param = {"name": p.attrib.get("name", ""),
                 "type": p.attrib.get("class", "java.lang.String"),
                 "description": ""}
        desc = _find_tag(p, "parameterDescription")
        if desc is not None and desc.text:
            param["description"] = desc.text.strip()
        parameters.append(param)

    fields = [{"name": f.attrib.get("name", ""),
               "type": f.attrib.get("class", "java.lang.String"),
               "description": ""}
              for f in _find_all_tags(root, "field")]

    # Prefer the main dataset's <queryString>, which is a direct child of the
    # root. <subDataset> blocks (each with its own query) are declared before
    # it, so a plain depth-first search would pick a sub-dataset's query.
    query = next((child for child in root
                  if child.tag.rpartition("}")[2] == "queryString"), None)
    if query is None:
        query = _find_tag(root, "queryString")
    query_text = query.text.strip() if query is not None and query.text else ""

    return {"report_name": report_name, "page_width": page_width,
            "page_height": page_height, "parameters": parameters,
            "fields": fields, "query": query_text, "error": None}
def _extract_archive(file_path: str, dest_dir: str) -> list[str]:
    """Safely extract a zip or tar archive and list the regular files written.

    Args:
        file_path: Archive to extract. Zip is tried first, then any tar
            variant ``tarfile`` can auto-detect (``"r:*"``).
        dest_dir: Directory to extract into.

    Returns:
        ``realpath``-resolved paths of the extracted regular files, in
        archive order. Returns ``[]`` when the file is neither zip nor tar.

    Entries whose resolved path would land outside *dest_dir* (``../``
    traversal, absolute names) are skipped. For tar archives only regular
    files are extracted. Symlink and hardlink members could point at
    arbitrary host files that callers later copy into the KB, and device or
    FIFO members are never documents. Directory members are not listed: tarfile
    strips their trailing "/", so a name check cannot spot them, and parent
    directories are created on demand when a file is extracted.
    """
    extracted = []
    resolved_dest = os.path.realpath(dest_dir)

    def _target_inside(name):
        # Resolved destination of *name*, or None if it escapes dest_dir.
        target = os.path.realpath(os.path.join(dest_dir, name))
        return target if target.startswith(resolved_dest + os.sep) else None

    if zipfile.is_zipfile(file_path):
        with zipfile.ZipFile(file_path, "r") as zf:
            for info in zf.infolist():
                member_path = _target_inside(info.filename)
                if member_path is None:
                    continue
                zf.extract(info, dest_dir)
                if not info.is_dir():
                    extracted.append(member_path)
    elif tarfile.is_tarfile(file_path):
        # Interpreters with extraction filters (3.12+ and security backports)
        # also get the "data" filter as a second line of defence; passing it
        # avoids the 3.12+ DeprecationWarning for unfiltered extraction.
        extract_kwargs = ({"filter": "data"}
                          if hasattr(tarfile, "data_filter") else {})
        with tarfile.open(file_path, "r:*") as tf:
            for member in tf.getmembers():
                if not member.isfile():
                    continue
                member_path = _target_inside(member.name)
                if member_path is None:
                    continue
                tf.extract(member, dest_dir, **extract_kwargs)
                extracted.append(member_path)
    return extracted
def process_file_for_kb(kb_id: str, file_path: str,
                        source_name: str = "") -> dict:
    """Store one file in a KB's raw directory and extract its content.

    Args:
        kb_id: Knowledge-base id, resolved with ``get_kb_raw_dir``.
        file_path: Path of the file to ingest (copied, not moved).
        source_name: Optional name to store/report the file under (e.g. the
            original upload name). Only its final path component is used, so
            a value like ``"../../x"`` cannot write outside the raw directory.
            Defaults to the basename of *file_path*.

    Returns:
        ``{"error": "KB 不存在"}`` when the KB is unknown. Otherwise a dict
        with ``filename``, ``type``, ``text`` and ``error`` plus:

        - ``.jrxml``: ``raw_xml`` (file text, ``""`` if not readable as
          UTF-8) and ``jrxml_info`` (from ``parse_jrxml_fields``); ``error``
          is that parser's error.
        - ``.zip``/``.tar``/``.gz``/``.tgz``: ``archive_contents``, one
          result per extracted regular file (processed recursively);
          ``error`` lists the entries that failed, or is None.
        - anything else: ``text``/``error`` from ``parse_file`` on the copy.
    """
    from backend.kb_manager import get_kb_raw_dir
    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        return {"error": "KB 不存在"}

    # Upload names are untrusted: keep only the last path component (either
    # separator) and refuse "."/"..", which would resolve to a directory.
    base = os.path.basename(source_name.replace("\\", "/")) if source_name else ""
    fname = base if base not in ("", ".", "..") else os.path.basename(file_path)
    dest = raw_dir / fname
    shutil.copy2(file_path, dest)
    suffix = Path(fname).suffix.lower()

    if suffix == ".jrxml":
        jrxml_info = parse_jrxml_fields(file_path)
        # .get() keeps this safe even if the parser's error dict is partial.
        text = f"[JRXML 模板: {jrxml_info.get('report_name', '')}]\n"
        text += (f"页面: {jrxml_info.get('page_width', '')}"
                 f"x{jrxml_info.get('page_height', '')}\n")
        text += "参数:\n" + "\n".join(
            f" {p['name']} ({p['type']})"
            for p in jrxml_info.get("parameters", []))
        text += "\n字段:\n" + "\n".join(
            f" {f['name']} ({f['type']})"
            for f in jrxml_info.get("fields", []))
        if jrxml_info.get("query"):
            text += f"\n查询:\n{jrxml_info['query']}"
        try:
            raw_xml = Path(file_path).read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            raw_xml = ""
        return {"filename": fname, "type": "jrxml", "text": text,
                "raw_xml": raw_xml, "jrxml_info": jrxml_info,
                "error": jrxml_info.get("error")}

    if suffix in (".zip", ".tar", ".gz", ".tgz"):
        tmpdir = tempfile.mkdtemp(prefix="kb_extract_")
        sub_results = []
        try:
            for ep in _extract_archive(file_path, tmpdir):
                entry_name = os.path.basename(ep)
                try:
                    sub = process_file_for_kb(kb_id, ep, source_name=entry_name)
                except Exception as e:
                    # Isolate a failing entry so the rest of the archive is kept.
                    _kb_parse_log.warning("归档条目处理失败 %s: %s", entry_name, e)
                    sub = {"filename": entry_name,
                           "type": Path(entry_name).suffix.lower().lstrip("."),
                           "text": "", "error": str(e)}
                sub_results.append(sub)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
        failed = [f"{s.get('filename', '')}: {s['error']}"
                  for s in sub_results if s.get("error")]
        return {"filename": fname, "type": "archive", "text": "",
                "archive_contents": sub_results,
                "error": "; ".join(failed) or None}

    parse_result = parse_file(str(dest))
    return {"filename": fname, "type": suffix.lstrip("."),
            # Normalise a None text so downstream .strip() calls are safe.
            "text": parse_result.get("text") or "",
            "error": parse_result.get("error")}
def chunk_file_results(results: list[dict], kb_name: str = "") -> list[dict]:
    """Turn ``process_file_for_kb`` results into retrieval chunks.

    - Archive results are expanded recursively into their contents.
    - Results with blank (or missing) text are skipped.
    - A JRXML result that carries ``raw_xml`` becomes one
      ``jrxml_template`` chunk holding the summary text plus the raw XML.
    - Any other text is split on blank lines. Paragraphs shorter than 10
      characters are dropped, and each remaining one becomes a chunk
      (``md_section`` for Markdown/untyped results, ``<type>_text``
      otherwise).

    Chunk ids are ``chunk_0``, ``chunk_1``, ... numbered across the whole
    output, including chunks that come from inside archives, so every id in
    the returned list is unique.

    Args:
        results: Result dicts as produced by ``process_file_for_kb``.
        kb_name: Copied into each chunk's metadata.

    Returns:
        List of ``{"id", "content", "metadata"}`` dicts.
    """
    chunks = []
    for content, metadata in _iter_chunk_payloads(results, kb_name):
        chunks.append({"id": f"chunk_{len(chunks)}",
                       "content": content,
                       "metadata": metadata})
    return chunks


def _iter_chunk_payloads(results, kb_name):
    """Yield ``(content, metadata)`` for every chunk, recursing into archives."""
    for r in results:
        if r.get("type") == "archive":
            yield from _iter_chunk_payloads(r.get("archive_contents") or [], kb_name)
            continue
        fname = r.get("filename", "")
        ftype = r.get("type", "")
        text = r.get("text") or ""
        if not text.strip():
            continue
        if ftype == "jrxml" and r.get("raw_xml"):
            jinfo = r.get("jrxml_info") or {}
            report_name = jinfo.get("report_name", "")
            yield (f"[JRXML 模板: {report_name}]\n{r['text']}\n\n"
                   f"\n{r['raw_xml']}\n",
                   {"chunk_type": "jrxml_template",
                    "source_file": fname,
                    "report_name": report_name,
                    "kb_name": kb_name,
                    "param_count": len(jinfo.get("parameters", [])),
                    "field_count": len(jinfo.get("fields", []))})
            continue
        chunk_type = "md_section" if ftype in ("md", "") else f"{ftype}_text"
        for para in text.split("\n\n"):
            para = para.strip()
            if len(para) < 10:
                continue
            yield para, {"chunk_type": chunk_type,
                         "source_file": fname, "kb_name": kb_name}
def extract_fields_with_llm(text: str, llm=None) -> list[dict]:
    """Extract field definitions from document text, preferring an LLM.

    Args:
        text: Document text (e.g. an interface specification).
        llm: Object with an ``invoke(prompt)`` method whose result has a
            ``content`` attribute (or is convertible with ``str``). When None,
            the Markdown-table heuristic ``_extract_fields_from_table`` is used.

    Returns:
        A list of field dicts, each guaranteed to have a non-empty string
        ``name`` and the keys ``description`` (default ``""``), ``type``
        (default ``"java.lang.String"``) and ``required`` (default False);
        extra keys produced by the LLM are kept. Falls back to the table
        heuristic when the LLM call fails, its reply has no JSON array, or
        the array contains no usable entries.
    """
    if llm is None:
        return _extract_fields_from_table(text)
    prompt = (
        "请分析以下接口文档内容,提取所有字段定义。\n"
        "对每个字段,输出: 字段名, 含义, 类型, 是否必需。\n"
        "以 JSON 数组格式输出,每个元素为 {\"name\": \"...\", "
        "\"description\": \"...\", \"type\": \"...\", \"required\": false}。\n\n"
        f"{text}"
    )
    defaults = {"description": "", "type": "java.lang.String", "required": False}
    try:
        response = llm.invoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        # The model may wrap the JSON array in prose; take the outermost [...].
        start = content.find("[")
        end = content.rfind("]") + 1
        if start >= 0 and end > start:
            parsed = json.loads(content[start:end])
            if isinstance(parsed, list):
                # LLM output is not schema-checked: keep only dict items with
                # a usable name, since callers index entries by "name".
                fields = [{"name": item["name"], **defaults, **item}
                          for item in parsed
                          if isinstance(item, dict)
                          and isinstance(item.get("name"), str)
                          and item["name"].strip()]
                if fields or not parsed:
                    return fields
    except Exception as e:
        _kb_parse_log.warning("LLM 字段提取失败,使用表格回退: %s", e)
    return _extract_fields_from_table(text)
def _extract_fields_from_table(text: str) -> list[dict]:
    """Heuristically extract field definitions from Markdown tables in *text*.

    Used when no LLM is available or LLM extraction fails. Columns are read
    positionally as name | description | required | type. This is presumably
    the layout of the ingested interface docs; TODO confirm.

    - A row is any line starting with "|"; cells are the text between the
      outer pipes.
    - The first row containing a header keyword (字段, 名称, 含义, 说明, 类型)
      is treated as the header and skipped. Rows before it are still parsed
      as data.
    - Alignment/separator rows (only "-", ":" and spaces) are skipped.

    Returns:
        A list of ``{"name", "description", "type", "required"}`` dicts with
        ``type`` defaulting to ``"java.lang.String"`` and ``required`` to
        False. ``required`` is True for 是/Y/y/yes/Yes/必填.
    """
    fields = []
    header_found = False
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("|"):
            continue
        cells = [c.strip() for c in line.split("|")[1:-1]]
        if not cells:
            continue
        if not header_found and any(
                h in c for c in cells
                for h in ("字段", "名称", "含义", "说明", "类型")):
            header_found = True
            continue
        if all(c.replace("-", "").replace(":", "").replace(" ", "") == ""
               for c in cells):
            continue
        if len(cells) < 2:
            continue
        # NOTE(review): the "L " replacement looks like a mangled literal;
        # confirm what prefix was meant to be stripped.
        name = cells[0].replace("**", "").replace("L ", "").replace("\\", "").strip()
        if not name or name == "---":
            continue
        field = {"name": name, "description": "", "type": "java.lang.String",
                 "required": False}
        if cells[1]:
            description = cells[1]
            # Markdown table cells cannot contain raw newlines; in-cell line
            # breaks are conventionally written as <br> tags, so fold them
            # into spaces.
            for br in ("<br>", "<br/>", "<br />"):
                description = description.replace(br, " ")
            field["description"] = description.strip()
        if len(cells) >= 3 and cells[2]:
            field["required"] = cells[2].strip() in ("是", "Y", "y", "yes", "Yes", "必填")
        if len(cells) >= 4 and cells[3]:
            field["type"] = cells[3].strip()
        fields.append(field)
    return fields
def build_kb_from_files(kb_id: str, file_paths: list[str],
                        llm=None) -> dict:
    """Build a KB's chunks, embeddings and field metadata from files.

    Steps:

    1. Process each file with ``process_file_for_kb``.
    2. Split the results into chunks and write them as JSON to the KB's
       chunks path.
    3. Pass the chunks to the KB searcher's ``add_chunks``.
    4. Collect template names and field definitions: JRXML parameters and
       fields first, then fields found in other documents by
       ``extract_fields_with_llm``. Files inside archives are included.
    5. Store the summary via ``update_kb_meta``.

    Args:
        kb_id: Target knowledge-base id.
        file_paths: Files to ingest.
        llm: Optional LLM object for ``extract_fields_with_llm``. When None,
            the Markdown-table heuristic is used.

    Returns:
        ``{"status", "field_count", "chunk_count", "template_count",
        "errors"}``. ``status`` is ``"ready"`` without errors and
        ``"partial"`` otherwise. ``errors`` holds ``{"file", "error"}`` dicts;
        ``"embedding"`` is used as the file for a failed ``add_chunks``.
    """
    from backend.kb_manager import update_kb_meta, get_kb_chunks_path
    from backend.kb_searcher import get_kb_searcher

    all_results = []
    errors = []
    for fp in file_paths:
        try:
            r = process_file_for_kb(kb_id, fp)
        except Exception as e:
            # One bad file must not abort the whole build.
            errors.append({"file": os.path.basename(fp), "error": str(e)})
            continue
        all_results.append(r)
        if r.get("error"):
            errors.append({"file": os.path.basename(fp), "error": r["error"]})

    chunks = chunk_file_results(all_results)
    chunks_path = get_kb_chunks_path(kb_id)
    if chunks_path:
        chunks_path.parent.mkdir(parents=True, exist_ok=True)
        with open(chunks_path, "w", encoding="utf-8") as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)

    searcher = get_kb_searcher(kb_id)
    if searcher and chunks:
        try:
            searcher.add_chunks(chunks)
        except Exception as e:
            errors.append({"file": "embedding", "error": str(e)})

    # Files packed inside archives contribute templates and fields exactly
    # like top-level uploads (chunk_file_results already recurses into them).
    flat_results = list(_iter_flat_results(all_results))
    all_fields = []
    template_names = []
    for r in flat_results:
        _collect_from_result(r, all_fields, template_names)

    seen_names = {f["name"] for f in all_fields}
    for r in flat_results:
        if r.get("type") == "jrxml":
            continue
        text = r.get("text") or ""
        if not text.strip():
            continue
        for ef in extract_fields_with_llm(text, llm):
            # Skip malformed entries instead of crashing on ef["name"].
            name = ef.get("name") if isinstance(ef, dict) else None
            if not isinstance(name, str) or not name or name in seen_names:
                continue
            seen_names.add(name)
            all_fields.append(ef)

    status = "ready" if not errors else "partial"
    update_kb_meta(kb_id, {
        "fields": all_fields, "templates": template_names,
        "file_count": len(file_paths), "chunk_count": len(chunks),
        "parse_status": status,
    })
    _kb_parse_log.info("KB 构建完成", extra={
        "kb_id": kb_id, "fields": len(all_fields),
        "templates": len(template_names), "chunks": len(chunks),
    })
    return {"status": status,
            "field_count": len(all_fields), "chunk_count": len(chunks),
            "template_count": len(template_names), "errors": errors}


def _iter_flat_results(results):
    """Yield non-archive results, descending recursively into archive contents."""
    for r in results:
        if r.get("type") == "archive":
            yield from _iter_flat_results(r.get("archive_contents") or [])
        else:
            yield r
def _collect_from_result(r: dict, all_fields: list, template_names: list) -> None:
    """Record a JRXML result's template name and its parameters/fields.

    Does nothing unless *r* has a ``jrxml_info`` with a non-empty
    ``report_name``. Otherwise it appends ``{"name", "file"}`` to
    *template_names*. It then appends each parameter and then each field to
    *all_fields* as ``{"name", "description", "type", "required": False}``,
    skipping names already present in *all_fields*. Both lists are mutated
    in place.
    """
    jinfo = r.get("jrxml_info")
    if not jinfo or not jinfo.get("report_name"):
        return
    template_names.append({"name": jinfo["report_name"],
                           "file": r.get("filename", "")})
    # Set-based dedup (one pass) instead of rescanning all_fields per item.
    seen = {f["name"] for f in all_fields}
    for item in [*jinfo.get("parameters", []), *jinfo.get("fields", [])]:
        if item["name"] in seen:
            continue
        seen.add(item["name"])
        all_fields.append({"name": item["name"],
                           "description": item.get("description", ""),
                           "type": item.get("type", "java.lang.String"),
                           "required": False})