"""kb_parser.py 测试 — JRXML 解析, 文件处理, 分块, 字段提取。""" import json import os import sys import tempfile from pathlib import Path from unittest.mock import patch, MagicMock import pytest sys.path.insert(0, str(Path(__file__).parent.parent)) from backend.kb_parser import ( parse_jrxml_fields, _extract_archive, process_file_for_kb, chunk_file_results, extract_fields_with_llm, _extract_fields_from_table, _find_tag, _find_all_tags, _collect_from_result, build_kb_from_files, ) SAMPLE_JRXML = """ 工单号 """ SAMPLE_JRXML_NO_NS = """ """ INVALID_XML = """ """ @pytest.fixture def jrxml_file(): with tempfile.NamedTemporaryFile(mode="w", suffix=".jrxml", delete=False, encoding="utf-8") as f: f.write(SAMPLE_JRXML) path = f.name yield path os.unlink(path) @pytest.fixture def temp_kb(monkeypatch): with tempfile.TemporaryDirectory(prefix="test_kb_parser_") as tmpdir: kb_data = Path(tmpdir) user_dir = kb_data / "default" kb_dir = user_dir / "abcdef1234567890abcd" raw_dir = kb_dir / "raw" raw_dir.mkdir(parents=True) monkeypatch.setattr( "backend.kb_manager.get_kb_raw_dir", lambda kb_id: raw_dir if kb_id == "abcdef1234567890abcd" else None) monkeypatch.setattr( "backend.kb_manager.get_kb_chunks_path", lambda kb_id: kb_dir / "chunks.json" if kb_id == "abcdef1234567890abcd" else None) monkeypatch.setattr( "backend.kb_manager.update_kb_meta", lambda kb_id, updates: None) yield {"kb_id": "abcdef1234567890abcd", "kb_dir": kb_dir, "raw_dir": raw_dir, "data_dir": kb_data} # ── JRXML 解析 ────────────────────────────────────────────────── class TestParseJrxmlFields: def test_parses_parameters(self, jrxml_file): result = parse_jrxml_fields(jrxml_file) assert result["error"] is None assert len(result["parameters"]) == 2 assert result["parameters"][0]["name"] == "billNo" assert result["parameters"][0]["type"] == "java.lang.String" assert result["parameters"][0]["description"] == "工单号" def test_parses_fields(self, jrxml_file): result = parse_jrxml_fields(jrxml_file) assert len(result["fields"]) == 2 field_names = [f["name"] for f in result["fields"]] assert "amount" in field_names assert "createDate" in field_names def test_parses_query(self, jrxml_file): result = parse_jrxml_fields(jrxml_file) assert "SELECT * FROM orders" in result["query"] def test_parses_report_metadata(self, jrxml_file): result = parse_jrxml_fields(jrxml_file) assert result["report_name"] == "TestReport" assert result["page_width"] == "595" assert result["page_height"] == "842" def test_parses_jrxml_without_namespace(self, tmp_path): fp = tmp_path / "simple.jrxml" fp.write_text(SAMPLE_JRXML_NO_NS, encoding="utf-8") result = parse_jrxml_fields(str(fp)) assert result["report_name"] == "SimpleReport" assert len(result["parameters"]) == 1 def test_invalid_xml_returns_error(self, tmp_path): fp = tmp_path / "bad.jrxml" fp.write_text(INVALID_XML, encoding="utf-8") result = parse_jrxml_fields(str(fp)) assert result["error"] is not None assert "解析失败" in result["error"] def test_empty_jrxml_has_no_fields(self, tmp_path): fp = tmp_path / "empty.jrxml" fp.write_text( '' '', encoding="utf-8") result = parse_jrxml_fields(str(fp)) assert result["parameters"] == [] assert result["fields"] == [] def test_nonexistent_file_raises(self): with pytest.raises(FileNotFoundError): parse_jrxml_fields("/nonexistent/path.jrxml") # ── 表格字段提取 ──────────────────────────────────────────────── class TestExtractFieldsFromTable: def test_extracts_from_markdown_table(self): text = """| 字段名 | 含义 | 必填 | 类型 | |--------|------|------|------| | billNo | 工单号 | 是 | String | | amount | 金额 | 否 | BigDecimal |""" fields = _extract_fields_from_table(text) assert len(fields) == 2 assert fields[0]["name"] == "billNo" assert fields[0]["description"] == "工单号" assert fields[0]["required"] is True assert fields[1]["name"] == "amount" def test_skips_separator_rows(self): text = """| 字段 | 说明 | |------|------| |------|------| | name | 名称 |""" fields = _extract_fields_from_table(text) assert len(fields) == 1 assert fields[0]["name"] == "name" def test_returns_empty_for_plain_text(self): fields = _extract_fields_from_table("这是一段普通文本,没有表格。") assert fields == [] def test_cells_with_bold_markers_stripped(self): text = """| 名称 | 含义 | |------|------| | **billNo** | 单号 |""" fields = _extract_fields_from_table(text) assert fields[0]["name"] == "billNo" # ── LLM 字段提取 ──────────────────────────────────────────────── class TestExtractFieldsWithLlm: def test_falls_back_to_table_when_no_llm(self): text = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |" fields = extract_fields_with_llm(text, llm=None) assert len(fields) >= 1 assert any(f["name"] == "code" for f in fields) def test_uses_llm_when_provided(self): mock_llm = MagicMock() mock_response = MagicMock() mock_response.content = '[{"name": "id", "description": "ID", "type": "Long", "required": true}]' mock_llm.invoke.return_value = mock_response fields = extract_fields_with_llm("some text", llm=mock_llm) assert len(fields) == 1 assert fields[0]["name"] == "id" def test_llm_failure_falls_back_to_table(self): mock_llm = MagicMock() mock_llm.invoke.side_effect = RuntimeError("LLM down") text = "| 字段 | 说明 |\n|------|------|\n| code | 编码 |" fields = extract_fields_with_llm(text, llm=mock_llm) assert any(f["name"] == "code" for f in fields) # ── 文件处理 ──────────────────────────────────────────────────── class TestProcessFileForKb: def test_process_jrxml_copies_and_parses(self, jrxml_file, temp_kb): result = process_file_for_kb(temp_kb["kb_id"], jrxml_file) assert result["type"] == "jrxml" assert result["jrxml_info"]["report_name"] == "TestReport" assert result["error"] is None copied = list(temp_kb["raw_dir"].glob("*.jrxml")) assert len(copied) == 1 def test_process_nonexistent_kb_returns_error(self, jrxml_file): result = process_file_for_kb("deadbeef1234567890abcd", jrxml_file) assert result["error"] is not None def test_process_text_file(self, tmp_path, temp_kb): fp = tmp_path / "readme.md" fp.write_text("# 标题\n\n这是一段内容。\n\n另一段内容。", encoding="utf-8") with patch("backend.kb_parser.parse_file") as mock_parse: mock_parse.return_value = {"text": "parsed content", "error": None} result = process_file_for_kb(temp_kb["kb_id"], str(fp)) assert result["filename"] is not None assert result["error"] is None # ── 分块 ──────────────────────────────────────────────────────── class TestChunkFileResults: def test_jrxml_result_produces_template_chunk(self, jrxml_file): info = parse_jrxml_fields(jrxml_file) raw = Path(jrxml_file).read_text(encoding="utf-8") results = [{ "filename": "test.jrxml", "type": "jrxml", "text": "text content", "raw_xml": raw, "jrxml_info": info, "error": None, }] chunks = chunk_file_results(results, kb_name="测试库") assert len(chunks) >= 1 tmpl = [c for c in chunks if c["metadata"]["chunk_type"] == "jrxml_template"] assert len(tmpl) == 1 assert tmpl[0]["metadata"]["report_name"] == "TestReport" assert "TestReport" in tmpl[0]["content"] def test_archive_result_recurses(self): results = [{ "filename": "bundle.zip", "type": "archive", "text": "", "archive_contents": [ {"filename": "inner.jrxml", "type": "jrxml", "text": "inner text", "raw_xml": "", "jrxml_info": {"report_name": "Inner", "parameters": [], "fields": []}, "error": None}, ], "error": None, }] chunks = chunk_file_results(results) assert any(c["metadata"]["report_name"] == "Inner" for c in chunks) def test_empty_text_skipped(self): results = [{"filename": "empty.md", "type": "md", "text": "", "error": None}] assert chunk_file_results(results) == [] def test_short_paragraphs_skipped(self): results = [{"filename": "short.txt", "type": "txt", "text": "hi", "error": None}] assert chunk_file_results(results) == [] def test_text_split_into_paragraphs(self): long_para = "A" * 50 results = [ {"filename": "doc.txt", "type": "txt", "text": f"{long_para}\n\n{long_para}\n\n{long_para}", "error": None}, ] chunks = chunk_file_results(results) assert len(chunks) == 3 # ── _collect_from_result ──────────────────────────────────────── class TestCollectFromResult: def test_collects_jrxml_parameters_as_fields(self): fields = [] templates = [] _collect_from_result({ "jrxml_info": { "report_name": "R1", "parameters": [{"name": "p1", "type": "String", "description": "参数1"}], "fields": [], }, "filename": "r1.jrxml", }, fields, templates) assert len(templates) == 1 assert any(f["name"] == "p1" for f in fields) def test_collects_jrxml_fields(self): fields = [] templates = [] _collect_from_result({ "jrxml_info": { "report_name": "R2", "parameters": [], "fields": [{"name": "f1", "type": "Double", "description": ""}], }, "filename": "r2.jrxml", }, fields, templates) assert any(f["name"] == "f1" for f in fields) def test_skips_non_jrxml(self): fields = [] templates = [] _collect_from_result({"type": "csv", "filename": "data.csv"}, fields, templates) assert templates == [] assert fields == [] def test_deduplicates_fields(self): fields = [] templates = [] info = {"report_name": "R", "parameters": [{"name": "dup", "type": "String", "description": ""}], "fields": []} _collect_from_result({"jrxml_info": info, "filename": "a.jrxml"}, fields, templates) _collect_from_result({"jrxml_info": info, "filename": "b.jrxml"}, fields, templates) assert sum(1 for f in fields if f["name"] == "dup") == 1