"""
batch_chunker.py

Unified batch-chunking entry point that processes JRXML and Markdown files
found under a single directory tree.
"""

import os
import sys
import json
import time
from pathlib import Path
from datetime import datetime
from collections import Counter, defaultdict

from jrxml_chunker import JRXMLSemanticChunker
from md_chunker import MarkdownSemanticChunker, save_chunks_to_json

SUPPORTED_EXTENSIONS = ('.jrxml', '.JRXML', '.md', '.markdown')

# Suffixes are compared lower-cased so that the result does not depend on
# whether the underlying filesystem is case-sensitive.
_SUPPORTED_SUFFIXES = frozenset(ext.lower() for ext in SUPPORTED_EXTENSIONS)


def _collect_supported_files(input_path):
    """Return supported files under *input_path*, grouped by lower-cased suffix.

    The tree is walked once. Each regular file whose suffix (ignoring case)
    is listed in SUPPORTED_EXTENSIONS appears exactly once, which avoids the
    duplicates that separate ``*.jrxml`` / ``*.JRXML`` globs produce on
    case-insensitive filesystems. Paths are sorted for a deterministic order.
    """
    grouped = defaultdict(list)
    for candidate in sorted(input_path.rglob("*")):
        suffix = candidate.suffix.lower()
        if suffix in _SUPPORTED_SUFFIXES and candidate.is_file():
            grouped[suffix].append(candidate)
    return grouped


def _process_files(files, chunker, input_path, stats, all_chunks, *,
                   label, type_key, type_prefix):
    """Chunk every file in *files* and record the outcome in *stats*.

    Args:
        files: Paths to process.
        chunker: Object exposing ``chunk_file(path_str)``.
        input_path: Root directory; file keys are stored relative to it.
        stats: Statistics dict of the current run (mutated in place).
        all_chunks: List that receives the chunks of successful files.
        label: Short tag used in progress output (e.g. ``"JRXML"``).
        type_key: Key in ``stats["files_by_type"]`` and in failure records.
        type_prefix: Prefix used for the ``stats["chunk_types"]`` keys.
    """
    total = len(files)
    for index, file_path in enumerate(files, 1):
        relative_path = file_path.relative_to(input_path)
        try:
            file_start = time.time()
            chunks = list(chunker.chunk_file(str(file_path)))
            file_duration = time.time() - file_start
            # Tally chunk types BEFORE touching shared state: if a chunk is
            # malformed the file must be recorded as failed only, not as
            # both succeeded and failed with its chunks already appended.
            # NOTE(review): assumes each chunk is a mapping with a
            # 'chunk_type' key - confirm against the chunker modules.
            type_counts = Counter(
                f"{type_prefix}_{chunk['chunk_type']}" for chunk in chunks
            )
        except Exception as e:
            # Per-file boundary: one broken file must not abort the batch.
            stats["failed"] += 1
            stats["failed_files"].append(
                {"file": str(relative_path), "type": type_key, "error": str(e)}
            )
            print(f"[{index}/{total}] ❌ {label}: {relative_path} → {e}")
            continue

        all_chunks.extend(chunks)
        stats["success"] += 1
        stats["files_by_type"][type_key] += 1
        stats["total_chunks"] += len(chunks)
        stats["chunks_per_file"][str(relative_path)] = len(chunks)
        stats["chunk_types"].update(type_counts)
        print(f"[{index}/{total}] ✅ {label}: {relative_path} → {len(chunks)} chunks ({file_duration:.2f}s)")


def _merge_stats(previous, current, total_found, total_duration):
    """Combine the stats of an earlier run with those of the current run.

    Missing keys in *previous* are tolerated. Failure records of files that
    were attempted again in this run are replaced by the new outcome, so a
    file that keeps failing (or finally succeeds) is not counted repeatedly.
    """
    merged = dict(previous)
    merged["total_files"] = total_found
    merged["success"] = previous.get("success", 0) + current["success"]
    merged["total_chunks"] = previous.get("total_chunks", 0) + current["total_chunks"]
    merged["processing_time"] = round(previous.get("processing_time", 0) + total_duration, 2)
    merged["finished_at"] = current["finished_at"]

    chunks_per_file = dict(previous.get("chunks_per_file") or {})
    chunks_per_file.update(current["chunks_per_file"])
    merged["chunks_per_file"] = chunks_per_file

    chunk_types = Counter(previous.get("chunk_types") or {})
    chunk_types.update(current["chunk_types"])
    merged["chunk_types"] = dict(chunk_types)

    files_by_type = dict(previous.get("files_by_type") or {})
    for type_key, count in current["files_by_type"].items():
        files_by_type[type_key] = files_by_type.get(type_key, 0) + count
    merged["files_by_type"] = files_by_type

    # Files that failed earlier are not in chunks_per_file, so they are
    # retried on every incremental run; drop their stale failure records.
    attempted = set(current["chunks_per_file"])
    attempted.update(entry["file"] for entry in current["failed_files"])
    carried_over = [
        entry for entry in (previous.get("failed_files") or [])
        if entry.get("file") not in attempted
    ]
    merged["failed_files"] = carried_over + list(current["failed_files"])
    merged["failed"] = len(merged["failed_files"])
    return merged


def batch_chunk_with_report(input_dir: str = None, output_dir: str = None, max_chunk_size: int = 2000, incremental: bool = False):
    """Chunk all JRXML and Markdown files under a directory and write a report.

    Writes ``all_chunks.json`` and ``processing_stats.json`` into the output
    directory and prints progress and a summary to stdout.

    Args:
        input_dir: Directory that is scanned recursively.
        output_dir: Directory for the result files. Defaults to a sibling of
            the input directory named ``<input stem>_chunks``.
        max_chunk_size: Maximum number of characters per chunk, forwarded to
            both chunkers.
        incremental: When true and earlier results exist in the output
            directory, only files not listed in the earlier
            ``chunks_per_file`` are processed and results are merged.

    Returns:
        A dict with the keys ``"chunks"``, ``"stats"`` and ``"output_path"``,
        or ``None`` when the input directory is missing or invalid.
    """
    if input_dir is None:
        print("错误:请指定输入目录")
        return None

    input_path = Path(input_dir).resolve()
    if not input_path.exists():
        print(f"❌ 目录不存在: {input_path}")
        return None
    if not input_path.is_dir():
        print(f"❌ 不是目录: {input_path}")
        return None

    if output_dir is None:
        # NOTE(review): .stem truncates dotted directory names ("v1.0" ->
        # "v1"); kept as-is so existing output locations stay discoverable.
        output_dir = input_path.parent / f"{input_path.stem}_chunks"
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"\n{'='*60}")
    print(f"统一批量分块 v1.0" + (" [增量模式]" if incremental else ""))
    print(f"{'='*60}")
    print(f"输入目录: {input_path}")
    print(f"输出目录: {output_path}")
    print(f"{'='*60}\n")

    # Incremental mode: load earlier results so processed files can be skipped.
    existing_chunks = []
    existing_stats = {}
    processed_files = set()
    if incremental:
        existing_chunks_path = output_path / "all_chunks.json"
        existing_stats_path = output_path / "processing_stats.json"
        if existing_chunks_path.exists() and existing_stats_path.exists():
            with open(existing_chunks_path, 'r', encoding='utf-8') as f:
                existing_chunks = json.load(f)
            with open(existing_stats_path, 'r', encoding='utf-8') as f:
                existing_stats = json.load(f)
            processed_files = set(existing_stats.get("chunks_per_file", {}).keys())
            print(f"增量模式: 已有 {len(existing_chunks)} 个 chunks, {len(processed_files)} 个已处理文件")
        else:
            print(f"增量模式: 未找到已有数据,切换为全量处理")
            incremental = False

    jrxml_chunker = JRXMLSemanticChunker(max_chunk_size=max_chunk_size)
    md_chunker = MarkdownSemanticChunker(max_chunk_size=max_chunk_size)

    files_by_ext = _collect_supported_files(input_path)
    total_found = sum(len(found) for found in files_by_ext.values())

    # Incremental mode: drop files already recorded by an earlier run.
    if incremental and processed_files:
        skipped = 0
        for ext, found in list(files_by_ext.items()):
            remaining = [
                path for path in found
                if str(path.relative_to(input_path)) not in processed_files
            ]
            skipped += len(found) - len(remaining)
            files_by_ext[ext] = remaining
        print(f"扫描到 {total_found} 个文件, 跳过 {skipped} 个已处理")
    else:
        print(f"扫描到 {total_found} 个文件")

    total_files = sum(len(found) for found in files_by_ext.values())
    for ext, found in files_by_ext.items():
        if found:
            print(f" {ext}: {len(found)} 个")

    if total_files == 0:
        print("✅ 没有新文件需要处理")
        result_stats = existing_stats.copy() if (incremental and processed_files) else {}
        return {
            "chunks": existing_chunks,
            "stats": result_stats,
            "output_path": str(output_path)
        }

    all_chunks = []
    stats = {
        "total_files": total_found,
        "success": 0,
        "failed": 0,
        "total_chunks": 0,
        "failed_files": [],
        "chunks_per_file": {},
        "chunk_types": Counter(),
        "files_by_type": {"jrxml": 0, "markdown": 0},
        "started_at": datetime.now().isoformat()
    }

    start_time = time.time()

    jrxml_files = files_by_ext.get('.jrxml', [])
    if jrxml_files:
        print(f"\n📄 处理 JRXML 文件 ({len(jrxml_files)} 个)...")
        _process_files(jrxml_files, jrxml_chunker, input_path, stats, all_chunks,
                       label="JRXML", type_key="jrxml", type_prefix="jrxml")

    md_files = files_by_ext.get('.md', []) + files_by_ext.get('.markdown', [])
    if md_files:
        print(f"\n📝 处理 Markdown 文件 ({len(md_files)} 个)...")
        _process_files(md_files, md_chunker, input_path, stats, all_chunks,
                       label="MD", type_key="markdown", type_prefix="md")

    total_duration = time.time() - start_time
    stats["processing_time"] = round(total_duration, 2)
    stats["finished_at"] = datetime.now().isoformat()

    if incremental:
        # Merge whenever earlier results were loaded, even if they held no
        # chunks; otherwise the earlier chunks_per_file would be overwritten
        # and those files reprocessed on the next incremental run.
        merged_chunks = existing_chunks + all_chunks
        print(f"\n合并: 已有 {len(existing_chunks)} + 新增 {len(all_chunks)} = {len(merged_chunks)} 个 chunks")
        all_chunks = merged_chunks
        stats_serializable = _merge_stats(existing_stats, stats, total_found, total_duration)
    else:
        stats_serializable = dict(stats)
        stats_serializable["chunk_types"] = dict(stats["chunk_types"])

    # Persist all chunks (delegated to the md_chunker helper).
    all_chunks_path = output_path / "all_chunks.json"
    save_chunks_to_json(all_chunks, str(all_chunks_path))

    # Persist the statistics report.
    stats_path = output_path / "processing_stats.json"
    with open(stats_path, "w", encoding="utf-8") as f:
        json.dump(stats_serializable, f, ensure_ascii=False, indent=2)

    # Summary: cumulative numbers in incremental mode, run numbers otherwise.
    files_by_type = stats_serializable.get("files_by_type", {})
    jrxml_count = files_by_type.get("jrxml", 0)
    md_count = files_by_type.get("markdown", 0)

    print(f"\n{'='*60}")
    print(f"处理完成!")
    print(f"{'='*60}")
    print(f"✅ 成功: {stats_serializable['success']} 文件 (JRXML: {jrxml_count}, MD: {md_count})")
    print(f"❌ 失败: {stats_serializable['failed']} 文件")
    print(f"📦 总 Chunks: {stats_serializable['total_chunks']}")
    print(f"⏱️ 总耗时: {total_duration:.2f}s")
    print(f"📂 输出目录: {output_path}")
    print(f"\n主要文件:")
    print(f" - {all_chunks_path}")
    print(f" - {stats_path}")

    display_types = stats_serializable.get("chunk_types", {})
    if display_types:
        print(f"\nChunk 类型分布 (前 10):")
        sorted_types = sorted(display_types.items(), key=lambda x: -x[1])[:10]
        for ct, count in sorted_types:
            print(f" {ct}: {count}")

    # Only failures of the current run are listed in detail.
    if stats["failed_files"]:
        print(f"\n⚠️ 失败文件详情:")
        for fail in stats["failed_files"][:10]:
            print(f" - {fail['file']} ({fail['type']}): {fail['error']}")

    return {
        "chunks": all_chunks,
        "stats": stats_serializable,
        "output_path": str(output_path)
    }


def _main():
    """Command-line entry point: parse ``sys.argv`` and run the batch."""
    if len(sys.argv) < 2:
        print("=" * 60)
        print("统一批量分块 v1.0")
        print("支持 JRXML 和 Markdown 文件")
        print("=" * 60)
        print("\n用法:")
        print(" python batch_chunker.py <目录路径>")
        print(" python batch_chunker.py <目录路径> --output <输出目录>")
        print(" python batch_chunker.py <目录路径> --incremental")
        print("\n示例:")
        print(" python batch_chunker.py ./jrxml_source")
        print(" python batch_chunker.py ./docs")
        print(" python batch_chunker.py ./ --output ./chunks")
        print(" python batch_chunker.py ./jrxml_source --incremental # 增量分块")
        sys.exit(0)

    input_path = sys.argv[1]
    output_dir = None

    if "--output" in sys.argv:
        idx = sys.argv.index("--output")
        # A trailing --output without a value is ignored.
        if idx + 1 < len(sys.argv):
            output_dir = sys.argv[idx + 1]

    incremental = "--incremental" in sys.argv

    if os.path.isdir(input_path):
        batch_chunk_with_report(input_path, output_dir, incremental=incremental)
    else:
        print(f"❌ 路径无效或不是目录: {input_path}")


if __name__ == "__main__":
    _main()