"""KB parsing pipeline: file extraction -> field parsing -> chunking -> embedding.

Callers: api_server.py (upload endpoint), scripts/init_default_kb.py
"""

import os
import json
import shutil
import zipfile
import tarfile
import tempfile
import defusedxml.ElementTree as ET
from pathlib import Path
from typing import Iterator, Optional

from dotenv import load_dotenv

from backend.logger import get_logger
from backend.file_parser import parse_file

load_dotenv()

_kb_parse_log = get_logger("kb_parser")

# File suffixes that are handed to _extract_archive instead of parse_file.
_ARCHIVE_SUFFIXES = (".zip", ".tar", ".gz", ".tgz")

# Substrings that mark a Markdown table row as the header row.
_TABLE_HEADER_KEYWORDS = ("字段", "名称", "含义", "说明", "类型")

# Cell values in the "required" column that mean "yes".
_REQUIRED_MARKERS = ("是", "Y", "y", "yes", "Yes", "必填")

_DEFAULT_FIELD_TYPE = "java.lang.String"


def _local_name(tag) -> str:
    """Return an element tag without its ``{namespace}`` prefix.

    Non-string tags (e.g. comment or processing-instruction nodes) yield an
    empty string so that they never match a requested tag name.
    """
    if not isinstance(tag, str):
        return ""
    return tag.split("}")[-1] if "}" in tag else tag


def _find_tag(elem, tag):
    """Return the first element in *elem*'s subtree whose local name is *tag*.

    The search includes *elem* itself and ignores XML namespaces. Returns
    ``None`` when nothing matches.
    """
    return next((el for el in elem.iter() if _local_name(el.tag) == tag), None)


def _find_all_tags(elem, tag):
    """Return every element in *elem*'s subtree whose local name is *tag*.

    The search includes *elem* itself and ignores XML namespaces.
    """
    return [el for el in elem.iter() if _local_name(el.tag) == tag]


def _child_text(elem, tag: str) -> str:
    """Return the stripped text of the first *tag* under *elem*, or ``""``."""
    node = _find_tag(elem, tag)
    if node is not None and node.text:
        return node.text.strip()
    return ""


def parse_jrxml_fields(jrxml_path: str) -> dict:
    """Parse a JRXML file and extract its parameter and field definitions.

    Returns a dict with the keys ``report_name``, ``page_width``,
    ``page_height``, ``parameters``, ``fields``, ``query`` and ``error``.
    On a parse failure every key is still present (with empty values) and
    ``error`` holds the failure message, so callers can index the result
    without checking which keys exist.
    """
    try:
        root = ET.parse(jrxml_path).getroot()
    except (ET.ParseError, ValueError) as e:
        # ValueError covers the exceptions defusedxml raises when it rejects
        # a document (forbidden DTD / entities / external references).
        return {
            "report_name": "",
            "page_width": "",
            "page_height": "",
            "parameters": [],
            "fields": [],
            "query": "",
            "error": f"JRXML 解析失败: {e}",
        }

    parameters = [
        {
            "name": p.attrib.get("name", ""),
            "type": p.attrib.get("class", _DEFAULT_FIELD_TYPE),
            "description": _child_text(p, "parameterDescription"),
        }
        for p in _find_all_tags(root, "parameter")
    ]
    fields = [
        {
            "name": f.attrib.get("name", ""),
            "type": f.attrib.get("class", _DEFAULT_FIELD_TYPE),
            # Mirrors the parameterDescription handling above.
            "description": _child_text(f, "fieldDescription"),
        }
        for f in _find_all_tags(root, "field")
    ]

    return {
        "report_name": root.attrib.get("name", ""),
        "page_width": root.attrib.get("pageWidth", ""),
        "page_height": root.attrib.get("pageHeight", ""),
        "parameters": parameters,
        "fields": fields,
        "query": _child_text(root, "queryString"),
        "error": None,
    }


def _extract_archive(file_path: str, dest_dir: str) -> list[str]:
    """Extract a zip or tar archive into *dest_dir*.

    Members whose resolved path would fall outside *dest_dir* are skipped.
    For tar archives only regular files and directories are extracted;
    links and special files are skipped.

    Returns the resolved paths of the extracted regular files (directories
    are created but not listed). Returns an empty list when *file_path* is
    neither a zip nor a tar archive.
    """
    extracted: list[str] = []
    dest_prefix = os.path.realpath(dest_dir) + os.sep

    def _is_inside_dest(member_name: str) -> bool:
        target = os.path.realpath(os.path.join(dest_dir, member_name))
        return target.startswith(dest_prefix)

    if zipfile.is_zipfile(file_path):
        with zipfile.ZipFile(file_path, "r") as zf:
            for info in zf.infolist():
                if not _is_inside_dest(info.filename):
                    continue
                # zipfile normalises member names while extracting, so record
                # the path it actually wrote rather than the one we computed.
                written = zf.extract(info, dest_dir)
                if not info.is_dir():
                    extracted.append(os.path.realpath(written))
    elif tarfile.is_tarfile(file_path):
        # Extraction filters only exist on newer / patched Python versions.
        extract_kwargs = (
            {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        )
        with tarfile.open(file_path, "r:*") as tf:
            for member in tf.getmembers():
                # Skip symlinks, hardlinks, devices and FIFOs.
                if not (member.isfile() or member.isdir()):
                    continue
                if not _is_inside_dest(member.name):
                    continue
                tf.extract(member, dest_dir, **extract_kwargs)
                # Tar directory names do not end with "/", so test the
                # member type instead of the name.
                if member.isfile():
                    extracted.append(
                        os.path.realpath(os.path.join(dest_dir, member.name))
                    )
    return extracted


def _process_jrxml(file_path: str, fname: str) -> dict:
    """Build the result dict for a JRXML template file."""
    jrxml_info = parse_jrxml_fields(file_path)
    try:
        raw_xml = Path(file_path).read_text(encoding="utf-8")
    except (OSError, UnicodeError):
        raw_xml = ""

    if jrxml_info.get("error"):
        # Report the parse failure; empty text keeps the file out of chunking.
        return {
            "filename": fname,
            "type": "jrxml",
            "text": "",
            "raw_xml": raw_xml,
            "jrxml_info": jrxml_info,
            "error": jrxml_info["error"],
        }

    param_lines = "\n".join(
        f" {p['name']} ({p['type']})" for p in jrxml_info["parameters"]
    )
    field_lines = "\n".join(
        f" {f['name']} ({f['type']})" for f in jrxml_info["fields"]
    )
    text = (
        f"[JRXML 模板: {jrxml_info['report_name']}]\n"
        f"页面: {jrxml_info['page_width']}x{jrxml_info['page_height']}\n"
        "参数:\n" + param_lines + "\n字段:\n" + field_lines
    )
    if jrxml_info["query"]:
        text += f"\n查询:\n{jrxml_info['query']}"

    return {
        "filename": fname,
        "type": "jrxml",
        "text": text,
        "raw_xml": raw_xml,
        "jrxml_info": jrxml_info,
        "error": None,
    }


def _process_archive(kb_id: str, file_path: str, fname: str) -> dict:
    """Extract an archive to a temp dir and process each contained file."""
    tmpdir = tempfile.mkdtemp(prefix="kb_extract_")
    try:
        sub_results = []
        for entry_path in _extract_archive(file_path, tmpdir):
            entry_name = os.path.basename(entry_path)
            try:
                sub = process_file_for_kb(
                    kb_id, entry_path, source_name=entry_name
                )
            except Exception as e:
                # One unreadable entry must not discard the rest of the archive.
                _kb_parse_log.warning(
                    "archive entry failed: %s: %s", entry_name, e
                )
                sub = {
                    "filename": entry_name,
                    "type": Path(entry_name).suffix.lower().lstrip("."),
                    "text": "",
                    "error": str(e),
                }
            sub_results.append(sub)
        return {
            "filename": fname,
            "type": "archive",
            "text": "",
            "archive_contents": sub_results,
            "error": None,
        }
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)


def process_file_for_kb(kb_id: str, file_path: str, source_name: str = "") -> dict:
    """Copy one file into the KB's raw directory and extract its text.

    Args:
        kb_id: Identifier passed to ``get_kb_raw_dir``.
        file_path: Path of the file to ingest.
        source_name: Name to store the file under; defaults to the basename
            of *file_path*. Only its final path component is used, so the
            file can never be written outside the raw directory.

    Returns:
        A dict with ``filename``, ``type``, ``text`` and ``error``.
        JRXML results additionally carry ``raw_xml`` and ``jrxml_info``;
        archive results carry ``archive_contents`` (a list of result dicts,
        one per extracted file). When the KB does not exist the dict only
        contains ``error``.
    """
    from backend.kb_manager import get_kb_raw_dir

    raw_dir = get_kb_raw_dir(kb_id)
    if raw_dir is None:
        return {"error": "KB 不存在"}

    # Strip directory components: source_name may come from an upload.
    fname = os.path.basename(source_name or file_path)
    if fname in ("", ".", ".."):
        return {"filename": fname, "error": "文件名无效"}

    dest = raw_dir / fname
    shutil.copy2(file_path, dest)
    suffix = Path(fname).suffix.lower()

    if suffix == ".jrxml":
        return _process_jrxml(file_path, fname)
    if suffix in _ARCHIVE_SUFFIXES:
        return _process_archive(kb_id, file_path, fname)

    parse_result = parse_file(str(dest))
    return {
        "filename": fname,
        "type": suffix.lstrip("."),
        "text": parse_result.get("text", ""),
        "error": parse_result.get("error"),
    }


def _walk_results(results: list[dict]) -> Iterator[dict]:
    """Yield every result dict depth-first, descending into archives.

    Archive results are yielded too, immediately before their contents.
    """
    for result in results:
        yield result
        if result.get("type") == "archive":
            yield from _walk_results(result.get("archive_contents", []))


def chunk_file_results(results: list[dict], kb_name: str = "") -> list[dict]:
    """Split processed file results into chunks for indexing.

    A JRXML template becomes a single chunk holding its summary text and raw
    XML. Any other file is split on blank lines, and paragraphs shorter than
    10 characters are dropped. Files inside archives are chunked in place.

    Chunk ids are ``chunk_<n>`` with ``n`` counting up across the whole
    call, so they stay unique even when archives are present.
    """
    chunks: list[dict] = []
    for r in _walk_results(results):
        ftype = r.get("type", "")
        if ftype == "archive":
            continue  # its contents are yielded separately by _walk_results
        fname = r.get("filename", "")
        text = r.get("text") or ""
        if not text.strip():
            continue

        if ftype == "jrxml" and r.get("raw_xml"):
            jinfo = r.get("jrxml_info") or {}
            report_name = jinfo.get("report_name", "")
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": f"[JRXML 模板: {report_name}]\n{text}\n\n"
                           f"\n{r['raw_xml']}\n",
                "metadata": {
                    "chunk_type": "jrxml_template",
                    "source_file": fname,
                    "report_name": report_name,
                    "kb_name": kb_name,
                    "param_count": len(jinfo.get("parameters", [])),
                    "field_count": len(jinfo.get("fields", [])),
                },
            })
            continue

        chunk_type = "md_section" if ftype in ("md", "") else f"{ftype}_text"
        for para in (p.strip() for p in text.split("\n\n")):
            if len(para) < 10:
                continue
            chunks.append({
                "id": f"chunk_{len(chunks)}",
                "content": para,
                "metadata": {
                    "chunk_type": chunk_type,
                    "source_file": fname,
                    "kb_name": kb_name,
                },
            })
    return chunks


def extract_fields_with_llm(text: str, llm=None) -> list[dict]:
    """Extract field definitions from document text, preferring the LLM.

    With no *llm*, or when the LLM call fails or yields no JSON array, the
    Markdown-table parser is used instead. Items in the LLM's array that are
    not dicts or have no ``name`` are dropped, so every returned entry can
    be indexed by ``"name"``.
    """
    if llm is None:
        return _extract_fields_from_table(text)

    prompt = (
        "请分析以下接口文档内容,提取所有字段定义。\n"
        "对每个字段,输出: 字段名, 含义, 类型, 是否必需。\n"
        "以 JSON 数组格式输出,每个元素为 {\"name\": \"...\", "
        "\"description\": \"...\", \"type\": \"...\", \"required\": false}。\n\n"
        f"{text}"
    )
    try:
        response = llm.invoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        # Take the outermost [...] so surrounding prose / code fences are ignored.
        start = content.find("[")
        end = content.rfind("]") + 1
        if start >= 0 and end > start:
            parsed = json.loads(content[start:end])
            if isinstance(parsed, list):
                return [
                    item for item in parsed
                    if isinstance(item, dict) and item.get("name")
                ]
    except Exception as e:
        # Any LLM / parsing failure degrades to the table parser.
        _kb_parse_log.warning("LLM 字段提取失败,使用表格回退: %s", e)
    return _extract_fields_from_table(text)


def _split_table_row(line: str) -> list[str]:
    """Split a Markdown table row (starting with ``|``) into stripped cells.

    The trailing pipe is optional.
    """
    row = line[1:]
    if row.endswith("|"):
        row = row[:-1]
    return [cell.strip() for cell in row.split("|")]


def _is_separator_row(cells: list[str]) -> bool:
    """Return True for a ``|---|:---:|`` style row (also an all-empty row)."""
    return all(
        c.replace("-", "").replace(":", "").replace(" ", "") == ""
        for c in cells
    )


def _extract_fields_from_table(text: str) -> list[dict]:
    """Extract field definitions from Markdown tables in *text*.

    Rows are ignored until a header row containing one of the known header
    keywords is seen. Columns are read positionally: name, description,
    required flag, type. Rows with fewer than two cells are skipped.
    """
    fields = []
    header_found = False
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line.startswith("|"):
            continue
        cells = _split_table_row(line)
        if not cells:
            continue
        if not header_found:
            if any(h in c for c in cells for h in _TABLE_HEADER_KEYWORDS):
                header_found = True
            continue
        if _is_separator_row(cells) or len(cells) < 2:
            continue

        name = (
            cells[0].replace("**", "").replace("L ", "").replace("\\", "").strip()
        )
        if not name or name == "---":
            continue

        field = {
            "name": name,
            "description": "",
            "type": _DEFAULT_FIELD_TYPE,
            "required": False,
        }
        if cells[1]:
            field["description"] = cells[1].replace("\n", " ").strip()
        if len(cells) >= 3 and cells[2]:
            field["required"] = cells[2].strip() in _REQUIRED_MARKERS
        if len(cells) >= 4 and cells[3]:
            field["type"] = cells[3].strip()
        fields.append(field)
    return fields


def build_kb_from_files(kb_id: str, file_paths: list[str], llm=None) -> dict:
    """Ingest files into a KB: parse, chunk, index, and record metadata.

    Each file is processed independently; a failure is recorded in the
    returned ``errors`` list and does not stop the build. Files inside
    archives contribute errors, templates and fields just like top-level
    files.

    Returns:
        A dict with ``status`` (``"ready"`` or ``"partial"``),
        ``field_count``, ``chunk_count``, ``template_count`` and ``errors``.
    """
    from backend.kb_manager import update_kb_meta, get_kb_chunks_path
    from backend.kb_searcher import get_kb_searcher

    all_results = []
    errors = []
    for fp in file_paths:
        try:
            result = process_file_for_kb(kb_id, fp)
        except Exception as e:
            errors.append({"file": os.path.basename(fp), "error": str(e)})
            continue
        all_results.append(result)
        for node in _walk_results([result]):
            if node.get("error"):
                errors.append({
                    "file": node.get("filename") or os.path.basename(fp),
                    "error": node["error"],
                })

    chunks = chunk_file_results(all_results)
    chunks_path = get_kb_chunks_path(kb_id)
    if chunks_path:
        chunks_path.parent.mkdir(parents=True, exist_ok=True)
        with open(chunks_path, "w", encoding="utf-8") as f:
            json.dump(chunks, f, ensure_ascii=False, indent=2)

    searcher = get_kb_searcher(kb_id)
    if searcher and chunks:
        try:
            searcher.add_chunks(chunks)
        except Exception as e:
            errors.append({"file": "embedding", "error": str(e)})

    # JRXML-declared parameters/fields are collected first, so they take
    # precedence over same-named fields extracted from documents.
    all_fields = []
    template_names = []
    nodes = list(_walk_results(all_results))
    for node in nodes:
        _collect_from_result(node, all_fields, template_names)

    seen_names = {f.get("name") for f in all_fields}
    for node in nodes:
        if node.get("type") in ("archive", "jrxml"):
            continue
        text = node.get("text") or ""
        if not text.strip():
            continue
        for extracted in extract_fields_with_llm(text, llm):
            name = extracted.get("name")
            if name not in seen_names:
                seen_names.add(name)
                all_fields.append(extracted)

    status = "partial" if errors else "ready"
    update_kb_meta(kb_id, {
        "fields": all_fields,
        "templates": template_names,
        "file_count": len(file_paths),
        "chunk_count": len(chunks),
        "parse_status": status,
    })
    _kb_parse_log.info("KB 构建完成", extra={
        "kb_id": kb_id,
        "fields": len(all_fields),
        "templates": len(template_names),
        "chunks": len(chunks),
    })
    return {
        "status": status,
        "field_count": len(all_fields),
        "chunk_count": len(chunks),
        "template_count": len(template_names),
        "errors": errors,
    }


def _collect_from_result(r: dict, all_fields: list, template_names: list) -> None:
    """Append a JRXML result's template name and fields to the accumulators.

    Results without ``jrxml_info`` or without a report name are ignored.
    Parameters are added before fields, and a name already present in
    *all_fields* is not added again. Both lists are mutated in place.
    """
    jinfo = r.get("jrxml_info")
    if not jinfo or not jinfo.get("report_name"):
        return

    template_names.append(
        {"name": jinfo["report_name"], "file": r.get("filename", "")}
    )
    seen_names = {f.get("name") for f in all_fields}
    for item in [*jinfo.get("parameters", []), *jinfo.get("fields", [])]:
        if item["name"] in seen_names:
            continue
        seen_names.add(item["name"])
        all_fields.append({
            "name": item["name"],
            "description": item.get("description", ""),
            "type": item.get("type", _DEFAULT_FIELD_TYPE),
            "required": False,
        })