优化任务调度说明
This commit is contained in:
@@ -0,0 +1 @@
|
||||
from .logger import CrossPlatformLog
|
||||
@@ -0,0 +1,513 @@
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Union, Optional, List, Dict, Any, Callable
|
||||
from utils.logger import log
|
||||
|
||||
class FileHandler:
    """
    Cross-platform file helper (Windows/macOS/Linux).

    Return conventions:
        - ``read_file`` returns a ``pandas.DataFrame`` and re-raises on failure.
        - ``get_file_extension`` returns a plain ``str``.
        - Every other public method returns a result dictionary::

            {
                'success': bool,   # whether the operation succeeded
                'message': str,    # human-readable outcome
                'data': Any        # operation-specific payload (optional)
            }
    """

    def __init__(self, base_path: Optional[Union[str, Path]] = None):
        """
        Create a handler.

        :param base_path: directory that relative paths are resolved against;
            when omitted, relative paths are left relative.
        """
        self.base_path = self._normalize_path(base_path) if base_path else None
        self.log = log.bind(module=self.__class__.__name__)

    def _normalize_path(self, path: Union[str, Path]) -> Path:
        """Return *path* as a ``Path`` with backslashes turned into forward slashes."""
        return Path(str(path).replace('\\', '/'))

    def _resolve_path(self, path: Union[str, Path]) -> Path:
        """Normalize *path* and, if it is relative and a base path is set, anchor it there."""
        path = self._normalize_path(path)
        if not path.is_absolute() and self.base_path:
            return self._normalize_path(self.base_path / path)
        return path

    def _format_result(self,
                       success: bool,
                       message: str = "",
                       data: Optional[Any] = None) -> Dict[str, Any]:
        """Build the uniform ``{'success', 'message', 'data'}`` result dictionary."""
        return {
            'success': bool(success),
            'message': str(message),
            'data': data
        }

    def read_file(self,
                  file_path: Union[str, Path],
                  encoding: str = 'utf-8',
                  **kwargs) -> pd.DataFrame:
        """
        Read a file into a DataFrame, choosing the reader from the extension.

        csv/txt, xls/xlsx, json, pkl/pickle and parquet go through the matching
        pandas reader; anything else is read as text into a single-cell frame
        with a ``content`` column.

        :param file_path: file to read (resolved against ``base_path``)
        :param encoding: text encoding for csv/txt/json and the plain-text fallback
        :param kwargs: forwarded to the pandas reader (not used for pickle)
        :return: DataFrame holding the file content
        :raises Exception: the original error is logged and re-raised
        """
        file_path = self._resolve_path(file_path)
        try:
            ext = self.get_file_extension(file_path)

            if ext in ('csv', 'txt'):
                df = pd.read_csv(file_path, encoding=encoding, **kwargs)
            elif ext in ('xls', 'xlsx'):
                df = pd.read_excel(file_path, **kwargs)
            elif ext == 'json':
                df = pd.read_json(file_path, encoding=encoding, **kwargs)
            elif ext in ('pkl', 'pickle'):
                # SECURITY: unpickling can execute arbitrary code; only read
                # pickle files that come from a trusted source.
                obj = pd.read_pickle(file_path)
                # Coerce whatever was pickled into a DataFrame.
                if isinstance(obj, pd.DataFrame):
                    df = obj
                elif isinstance(obj, list):
                    df = pd.DataFrame(obj)
                elif isinstance(obj, dict):
                    df = pd.DataFrame([obj])
                else:
                    df = pd.DataFrame({'content': [obj]})
            elif ext == 'parquet':
                df = pd.read_parquet(file_path, **kwargs)
            else:
                # Unknown extension: treat as plain text. Assign instead of
                # returning early so this path is logged like the others.
                with open(file_path, 'r', encoding=encoding) as f:
                    df = pd.DataFrame({'content': [f.read()]})

            self.log.debug(f"文件读取成功 | path={file_path} shape={df.shape}")
            return df
        except Exception as e:
            self.log.error(f"文件读取失败 | path={file_path} error={str(e)}")
            raise

    def write_file(self,
                   file_path: Union[str, Path],
                   data: Union[pd.DataFrame, Dict, List],
                   encoding: str = 'utf-8',
                   **kwargs) -> Dict[str, Any]:
        """
        Write *data* to a file, choosing the writer from the extension.

        pkl/pickle pickles the object as-is; csv/txt, xls/xlsx, json and
        parquet convert to a DataFrame first; any other extension writes
        ``str(data)`` as text.

        :param file_path: destination (parent directories are created)
        :param data: DataFrame, dict or list to write
        :param encoding: text encoding for csv/txt and the plain-text fallback
        :param kwargs: forwarded to the pandas writer
        :return: result dict; ``data`` holds ``file_path`` and, on success, ``file_size``
        """
        file_path = self._resolve_path(file_path)
        try:
            parent_dir = file_path.parent
            if not parent_dir.exists():
                self.create_dir(parent_dir)

            ext = self.get_file_extension(file_path)

            if ext in ('pkl', 'pickle'):
                with open(file_path, 'wb') as f:
                    pickle.dump(data, f)
            elif ext in ('csv', 'txt', 'xls', 'xlsx', 'json', 'parquet'):
                # Only tabular formats need a DataFrame; build it lazily so the
                # text fallback cannot fail on an unconvertible object.
                if isinstance(data, pd.DataFrame):
                    df = data
                else:
                    df = pd.DataFrame(data if isinstance(data, list) else [data])

                if ext in ('csv', 'txt'):
                    df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
                elif ext in ('xls', 'xlsx'):
                    df.to_excel(file_path, index=False, **kwargs)
                elif ext == 'json':
                    df.to_json(file_path, force_ascii=False, **kwargs)
                else:
                    df.to_parquet(file_path, **kwargs)
            else:
                with open(file_path, 'w', encoding=encoding) as f:
                    f.write(str(data))

            return self._format_result(
                True,
                "文件写入成功",
                {
                    'file_path': str(file_path),
                    'file_size': os.path.getsize(file_path)
                }
            )
        except Exception as e:
            return self._format_result(
                False,
                f"文件写入失败: {e}",
                {'file_path': str(file_path)}
            )

    def file_exists(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Check whether *file_path* is an existing regular file.

        :return: result dict whose ``data`` is ``{'exists': bool}``
        """
        file_path = self._resolve_path(file_path)
        exists = file_path.is_file()
        msg = f"文件{'' if exists else '不'}存在: {file_path}"
        return self._format_result(True, msg, {'exists': exists})

    def dir_exists(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Check whether *dir_path* is an existing directory.

        :return: result dict whose ``data`` is ``{'exists': bool}``
        """
        dir_path = self._resolve_path(dir_path)
        exists = dir_path.is_dir()
        msg = f"目录{'' if exists else '不'}存在: {dir_path}"
        return self._format_result(True, msg, {'exists': exists})

    def create_dir(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Create *dir_path* including missing parents (no error if it exists).

        :return: result dict whose ``data`` is ``{'path': str}``
        """
        dir_path = self._resolve_path(dir_path)
        try:
            dir_path.mkdir(parents=True, exist_ok=True)

            if os.name == 'nt':
                # Best-effort permission widening on Windows; a failure here
                # must not fail the directory creation itself.
                try:
                    os.chmod(dir_path, 0o777)
                except OSError:
                    pass

            return self._format_result(True, "目录创建成功", {'path': str(dir_path)})
        except Exception as e:
            return self._format_result(False, f"目录创建失败: {e}", {'path': str(dir_path)})

    def delete_file(self, file_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Delete a single file.

        :return: result dict whose ``data`` is ``{'path': str}``; ``success`` is
            False when the path does not exist or cannot be removed
        """
        file_path = self._resolve_path(file_path)
        try:
            if not file_path.exists():
                return self._format_result(False, "文件不存在", {'path': str(file_path)})

            file_path.unlink()
            return self._format_result(True, "文件删除成功", {'path': str(file_path)})
        except Exception as e:
            return self._format_result(False, f"文件删除失败: {e}", {'path': str(file_path)})

    def delete_dir(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Delete a directory together with everything inside it.

        :return: result dict whose ``data`` is ``{'path': str}``; ``success`` is
            False when the path does not exist or cannot be removed
        """
        dir_path = self._resolve_path(dir_path)
        try:
            if not dir_path.exists():
                return self._format_result(False, "目录不存在", {'path': str(dir_path)})

            shutil.rmtree(dir_path)
            return self._format_result(True, "目录删除成功", {'path': str(dir_path)})
        except Exception as e:
            return self._format_result(False, f"目录删除失败: {e}", {'path': str(dir_path)})

    def list_files(self,
                   dir_path: Union[str, Path],
                   recursive: bool = False,
                   pattern: str = '*') -> Dict[str, Any]:
        """
        List the regular files in a directory (directories are excluded).

        :param dir_path: directory to scan
        :param recursive: descend into sub-directories when True
        :param pattern: glob pattern such as ``*.txt``
        :return: result dict whose ``data['files']`` is a list of dicts with
            ``path``, ``name``, ``size``, ``modified`` (ISO string) and ``is_dir``
        """
        dir_path = self._resolve_path(dir_path)
        try:
            matches = dir_path.rglob(pattern) if recursive else dir_path.glob(pattern)

            file_info = []
            for entry in matches:
                if not entry.is_file():
                    continue
                stat = entry.stat()  # one stat call serves both size and mtime
                file_info.append({
                    'path': str(entry),
                    'name': entry.name,
                    'size': stat.st_size,
                    'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
                    # Kept for backward compatibility; only files are listed.
                    'is_dir': False
                })

            return self._format_result(
                True,
                f"找到 {len(file_info)} 个文件",
                {'files': file_info}
            )
        except Exception as e:
            return self._format_result(
                False,
                f"列出文件失败: {e}",
                {'files': []}
            )

    def get_file_extension(self, file_path: Union[str, Path]) -> str:
        """
        Return the lower-cased extension of *file_path* without the dot.

        :return: e.g. ``'jpg'``; empty string when the name has no suffix
        """
        file_path = self._resolve_path(file_path)
        ext = file_path.suffix.lower().lstrip('.')
        self.log.trace(f"获取文件扩展名 | path={file_path} ext={ext}")
        return ext

    def copy_file(self,
                  src_path: Union[str, Path],
                  dst_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Copy a file (metadata included), creating the destination directory.

        :return: result dict with ``source`` and ``destination``, plus
            ``file_size`` on success
        """
        src_path = self._resolve_path(src_path)
        dst_path = self._resolve_path(dst_path)
        endpoints = {'source': str(src_path), 'destination': str(dst_path)}
        try:
            if not src_path.exists():
                return self._format_result(False, "源文件不存在", dict(endpoints))

            self.create_dir(dst_path.parent)
            shutil.copy2(src_path, dst_path)
            return self._format_result(
                True,
                "文件复制成功",
                {**endpoints, 'file_size': dst_path.stat().st_size}
            )
        except Exception as e:
            return self._format_result(False, f"文件复制失败: {e}", dict(endpoints))

    def move_file(self,
                  src_path: Union[str, Path],
                  dst_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Move or rename a file, creating the destination directory.

        :return: result dict with ``source`` and ``destination``
        """
        src_path = self._resolve_path(src_path)
        dst_path = self._resolve_path(dst_path)
        endpoints = {'source': str(src_path), 'destination': str(dst_path)}
        try:
            if not src_path.exists():
                return self._format_result(False, "源文件不存在", dict(endpoints))

            self.create_dir(dst_path.parent)
            shutil.move(src_path, dst_path)
            return self._format_result(True, "文件移动成功", dict(endpoints))
        except Exception as e:
            return self._format_result(False, f"文件移动失败: {e}", dict(endpoints))

    def zip_files(self,
                  file_paths: List[Union[str, Path]],
                  zip_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Compress several files into one zip archive.

        Each file is stored under its base name only; paths that do not exist
        are skipped silently and are not counted.

        :param file_paths: files to add
        :param zip_path: archive to create (overwritten if present)
        :return: result dict with ``zip_path``, ``file_count`` and ``zip_size``
        """
        zip_path = self._resolve_path(zip_path)
        try:
            self.create_dir(zip_path.parent)

            file_count = 0
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                for file_path in file_paths:
                    file_path = self._resolve_path(file_path)
                    if file_path.exists():
                        zipf.write(file_path, file_path.name)
                        file_count += 1

            return self._format_result(
                True,
                "文件压缩成功",
                {
                    'zip_path': str(zip_path),
                    'file_count': file_count,
                    'zip_size': os.path.getsize(zip_path)
                }
            )
        except Exception as e:
            return self._format_result(
                False,
                f"文件压缩失败: {e}",
                {'zip_path': str(zip_path)}
            )

    def unzip(self,
              zip_path: Union[str, Path],
              extract_to: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
        """
        Extract a zip archive.

        :param zip_path: archive to extract
        :param extract_to: target directory (defaults to the archive's directory)
        :return: result dict with ``extract_to`` and ``file_count`` (number of
            archive members) on success
        """
        zip_path = self._resolve_path(zip_path)
        extract_to = self._resolve_path(extract_to) if extract_to else zip_path.parent
        locations = {'zip_path': str(zip_path), 'extract_to': str(extract_to)}

        try:
            if not zip_path.exists():
                return self._format_result(False, "ZIP文件不存在", dict(locations))

            self.create_dir(extract_to)

            with zipfile.ZipFile(zip_path, 'r') as zipf:
                file_list = zipf.namelist()
                zipf.extractall(extract_to)

            return self._format_result(
                True,
                "文件解压成功",
                {
                    'extract_to': str(extract_to),
                    'file_count': len(file_list)
                }
            )
        except Exception as e:
            return self._format_result(False, f"文件解压失败: {e}", dict(locations))
|
||||
|
||||
|
||||
# ---------------------------- 测试用例 ----------------------------
|
||||
if __name__ == "__main__":
    # Manual smoke test: exercises FileHandler against <project root>/test.
    # The project root is the first ancestor directory that carries one of
    # the usual repository marker files.
    _root_markers = ('.git', 'pyproject.toml', 'requirements.txt')
    project_root = next(
        parent for parent in Path(__file__).resolve().parents
        if any((parent / marker).exists() for marker in _root_markers)
    )
    handler = FileHandler(project_root / "test")

    # --- path normalisation ---
    test_paths = [
        "normal/path",
        "windows\\style\\path",
        "mixed/path\\with\\both"
    ]

    print("=== 路径标准化测试 ===")
    for path in test_paths:
        resolved = handler._resolve_path(path)
        print(f"原始路径: {path} -> 标准化: {resolved} (类型: {type(resolved)})")

    # --- directory creation ---
    print("\n=== 目录操作测试 ===")
    dir_result = handler.create_dir("test_dir")
    print(dir_result)

    # --- write ---
    print("\n=== 文件操作测试 ===")
    test_data = [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}]
    write_result = handler.write_file("test_dir/data.json", test_data)
    print(write_result)

    # --- read back (read_file raises instead of returning a result dict) ---
    try:
        df = handler.read_file("test_dir/data.json")
    except Exception as e:
        print(f"\n文件读取失败: {str(e)}")
    else:
        print("\n读取文件内容:")
        print(df)

    # --- listing ---
    print("\n=== 文件列表测试 ===")
    list_result = handler.list_files("test_dir")
    print(list_result)

    # --- zip / unzip round trip ---
    print("\n=== 压缩解压测试 ===")
    zip_result = handler.zip_files(["test_dir/data.json"], "test_archive.zip")
    print(zip_result)

    unzip_result = handler.unzip("test_archive.zip", "extracted_files")
    print(unzip_result)

    # --- cleanup, in the same order as the artefacts were listed originally ---
    print("\n=== 清理测试数据 ===")
    for cleanup, target in (
        (handler.delete_file, "test_dir/data.json"),
        (handler.delete_dir, "test_dir"),
        (handler.delete_file, "test_archive.zip"),
        (handler.delete_dir, "extracted_files"),
    ):
        print(cleanup(target))
|
||||
+128
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import platform
|
||||
from datetime import datetime
|
||||
import zipfile
|
||||
|
||||
|
||||
class CrossPlatformLog:
    """
    Cross-platform logging setup (Linux/Windows/macOS) built on loguru.

    Instantiating the class reconfigures the global loguru ``logger`` with
    three sinks: stdout (INFO+), ``logs/application.log`` (DEBUG+) and
    ``logs/errors.log`` (ERROR+).
    """

    # Bound-extra keys that are allowed a longer repr before truncation.
    _LONG_EXTRA_KEYS = ("error", "error_message", "sql", "params")

    def __init__(self):
        self.log_dir = self._get_log_dir()
        self._setup_logger()

    def _get_log_dir(self):
        """Return ``<two directories above this file>/logs``, creating it if needed."""
        base_dir = Path(__file__).parent.parent
        log_dir = base_dir / "logs"

        log_dir.mkdir(parents=True, exist_ok=True)

        if platform.system() == "Windows":
            # Best-effort: widen permissions so the log files stay writable.
            # Failing to do so must not prevent logging from starting.
            try:
                os.chmod(log_dir, 0o777)
            except OSError:
                pass

        return log_dir

    def _setup_logger(self):
        """Drop loguru's default sink and install the console, main and error sinks."""
        logger.remove()

        # Console sink; level="INFO" already excludes DEBUG/TRACE records,
        # so no additional filter is needed.
        logger.add(
            sys.stdout,
            level="INFO",
            format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{module}</cyan> - <level>{message}</level>",
        )

        self._add_main_log()
        self._add_error_log()

    def _add_main_log(self):
        """Register ``application.log``: DEBUG+, 20 MB rotation, zipped, kept 30 days."""
        main_log = self.log_dir / "application.log"
        logger.add(
            str(main_log),
            rotation="20 MB",
            compression=self._compress_log,
            encoding="utf-8",
            level="DEBUG",
            retention="30 days",
            enqueue=True,
            # Callable format so bound/keyword extras are rendered readably.
            format=self._format_with_extra,
        )

    @classmethod
    def _render_extra(cls, extra) -> str:
        """
        Render the record's ``extra`` mapping as indented ``→ key: repr`` lines.

        The ``extra_output`` key (this method's own previous result) is skipped.
        Values are truncated to 500 characters for the keys in
        ``_LONG_EXTRA_KEYS`` and to 200 characters otherwise.
        """
        rendered = []
        for key, value in extra.items():
            if key == "extra_output":
                continue
            value_repr = repr(value)
            limit = 500 if key in cls._LONG_EXTRA_KEYS else 200
            if len(value_repr) > limit:
                value_repr = value_repr[:limit - 3] + "..."
            rendered.append(f"\n → {key}: {value_repr}")
        return "".join(rendered)

    def _format_with_extra(self, record):
        """
        Dynamic loguru format for the main log.

        Stores the rendered extras in ``record["extra"]["extra_output"]`` and
        returns a template that references it; loguru then fills the template
        from the record, so braces inside the values are never re-interpreted.
        """
        record["extra"]["extra_output"] = self._render_extra(record["extra"])
        return "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra[extra_output]}\n"

    def _format_error_with_extra(self, record):
        """
        Dynamic loguru format for the error log.

        A callable is used instead of a format string for two reasons:
        loguru appends ``"\\n{exception}"`` to string formats on its own, so a
        string that already ends in ``{exception}`` prints every traceback
        twice; and rendering the extras here means this sink no longer relies
        on another sink's formatter having populated ``extra_output`` first.
        """
        record["extra"]["extra_output"] = self._render_extra(record["extra"])
        return "{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - {message}{extra[extra_output]}\n{exception}"

    def _add_error_log(self):
        """Register ``errors.log``: ERROR+, 10 MB rotation, kept 90 days."""
        error_log = self.log_dir / "errors.log"
        logger.add(
            str(error_log),
            level="ERROR",
            format=self._format_error_with_extra,
            rotation="10 MB",
            retention="90 days",
            enqueue=True
        )

    @staticmethod
    def _compress_log(log_path):
        """
        Zip a rotated log file and delete the uncompressed original.

        :param log_path: path of the rotated log file
        :return: path of the created zip; the original path if compression
            failed; ``None`` if *log_path* does not exist
        """
        if not os.path.exists(log_path):
            return

        try:
            zip_path = f"{log_path}.{datetime.now().strftime('%Y%m%d')}.zip"
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                zipf.write(log_path, arcname=os.path.basename(log_path))
            os.remove(log_path)
            return zip_path
        except Exception as e:
            # Logging is being rotated right now, so report via stdout and
            # keep the uncompressed file rather than lose it.
            print(f"日志压缩失败: {str(e)}")
            return log_path

    @classmethod
    def get_logger(cls, module_name=None):
        """Return the global logger with ``module`` bound in its extras (default ``"__main__"``)."""
        return logger.bind(module=module_name or "__main__")
|
||||
|
||||
|
||||
# 初始化全局日志器
|
||||
log = CrossPlatformLog().get_logger()
|
||||
@@ -0,0 +1,383 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import threading
|
||||
from typing import List, Dict, Optional, BinaryIO, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
from minio import Minio
|
||||
from minio.error import S3Error, MinioException
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MinIOAgent:
    """
    Process-wide singleton wrapper around a MinIO client.

    Provides upload, download, listing, deletion and presigned-URL helpers
    that return plain metadata dictionaries (intended to be stored alongside
    database records).

    NOTE: because the class is a singleton, the configuration passed to the
    first successful construction wins; later calls with a different config
    return the already-initialised instance unchanged.
    """

    _instance = None           # the single shared instance
    _lock = threading.Lock()   # guards first-time creation of _instance

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it under the lock on first use."""
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """
        Connect to MinIO.

        :param config: dictionary with
            - endpoint: service endpoint, e.g. ``'localhost:9000'``
            - access_key: access key
            - secret_key: secret key
            - secure (optional): use TLS, default False
            - region (optional): region name
            - timeout (optional): seconds, default 30
        :raises ValueError: a required key is missing
        :raises Exception: client creation or the connection check failed
        """
        # Already initialised (singleton): keep the existing client.
        if hasattr(self, '_client') and self._client:
            return

        required_keys = ['endpoint', 'access_key', 'secret_key']
        if not all(key in config for key in required_keys):
            raise ValueError(f"MinIO配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'endpoint': config['endpoint'],
            'access_key': config['access_key'],
            'secret_key': config['secret_key'],
            'secure': config.get('secure', False),
            'region': config.get('region'),
            # NOTE(review): stored but not passed to the client anywhere in
            # this class — confirm whether it should configure the HTTP client.
            'timeout': config.get('timeout', 30)
        }

        current_platform = platform.system()
        self.log = log.bind(module=f"MinIOAgent({current_platform})")

        self._client = self._create_client()
        self._verify_connection()

    def _log_failure(self, message: str, **context) -> None:
        """
        Log *message* at ERROR level with the active exception's traceback.

        loguru ignores the stdlib ``exc_info=True`` idiom (it would merely be
        stored as an extra field), so the traceback is attached through
        ``opt(exception=True)``. ``depth=1`` attributes the record to the caller.
        """
        self.log.opt(exception=True, depth=1).error(message, **context)

    def _create_client(self) -> Minio:
        """Build the ``Minio`` client from ``self.config``; log and re-raise on failure."""
        try:
            client = Minio(
                endpoint=self.config['endpoint'],
                access_key=self.config['access_key'],
                secret_key=self.config['secret_key'],
                secure=self.config['secure'],
                region=self.config['region']
            )
            self.log.info("MinIO客户端创建成功")
            return client
        except Exception as e:
            self.log.opt(exception=True).critical("创建MinIO客户端失败", 错误=str(e))
            raise

    def _verify_connection(self) -> None:
        """Check connectivity by listing buckets; log and re-raise on failure."""
        try:
            self._client.list_buckets()
            self.log.info(f"成功连接到MinIO服务:{self.config['endpoint']}")
        except Exception as e:
            self.log.opt(exception=True).critical("连接验证失败", 错误=str(e))
            raise

    def create_bucket(self, bucket_name: str) -> bool:
        """
        Create *bucket_name* if it does not exist yet.

        :return: True if the bucket exists afterwards, False on a MinIO error
        """
        try:
            if not self._client.bucket_exists(bucket_name):
                self._client.make_bucket(bucket_name)
                self.log.info(f"存储桶创建成功:{bucket_name}")
                return True
            self.log.debug(f"存储桶已存在:{bucket_name}")
            return True
        except MinioException as e:
            self._log_failure(f"创建存储桶 {bucket_name} 失败", 错误=str(e))
            return False

    def upload_bytes(self, bucket: str, object_name: str, data: bytes) -> Dict[str, Any]:
        """
        Upload in-memory bytes as an object.

        :param bucket: bucket name (created if missing)
        :param object_name: object key
        :param data: payload; must be non-empty
        :return: dict with ``bucket``, ``object_name``, ``size``, ``etag``,
            ``content_type``, ``upload_time`` (naive UTC datetime) and
            ``local_hash`` (MD5 hex digest of *data*)
        :raises ValueError: *data* is empty
        :raises MinioException: the upload failed
        """
        from datetime import timezone  # local import: only needed here

        if not data:
            raise ValueError("上传数据不能为空")

        self.create_bucket(bucket)

        try:
            local_hash = hashlib.md5(data).hexdigest()
            content_type = self._guess_content_type(object_name)

            result = self._client.put_object(
                bucket_name=bucket,
                object_name=object_name,
                data=BytesIO(data),
                length=len(data),
                content_type=content_type
            )

            # The write result carries no content type, and its last_modified
            # is normally unset for put_object, so neither may be read blindly:
            # report the content type that was sent and fall back to "now".
            last_modified = getattr(result, 'last_modified', None)
            if last_modified is not None:
                upload_time = datetime.fromtimestamp(
                    last_modified.timestamp(), tz=timezone.utc
                ).replace(tzinfo=None)
            else:
                upload_time = datetime.now(timezone.utc).replace(tzinfo=None)

            metadata = {
                'bucket': bucket,
                'object_name': object_name,
                'size': len(data),
                'etag': result.etag,
                'content_type': content_type,
                'upload_time': upload_time,
                'local_hash': local_hash
            }

            self.log.info(
                "文件上传成功",
                存储桶=bucket,
                对象名称=object_name,
                大小=len(data)
            )
            return metadata

        except MinioException as e:
            self._log_failure(
                "文件上传失败",
                存储桶=bucket,
                对象名称=object_name,
                错误=str(e)
            )
            raise

    def download_file(self, bucket: str, object_name: str, local_path: str) -> Dict[str, Any]:
        """
        Download an object to a local file.

        :param bucket: bucket name
        :param object_name: object key
        :param local_path: destination path (parent directories are created)
        :return: dict with ``local_path``, ``size`` (bytes), ``download_time``
            (elapsed seconds) and ``downloaded_at`` (local datetime)
        :raises MinioException: the download failed
        :raises IOError: a local file-system operation failed
        """
        try:
            # dirname is '' for a bare file name, and makedirs('') raises.
            parent_dir = os.path.dirname(local_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            start_time = datetime.now()
            self._client.fget_object(bucket, object_name, local_path)
            download_time = datetime.now() - start_time

            stat = os.stat(local_path)

            result = {
                'local_path': local_path,
                'size': stat.st_size,
                'download_time': download_time.total_seconds(),
                'downloaded_at': datetime.now()
            }

            self.log.info(
                "文件下载成功",
                存储桶=bucket,
                对象名称=object_name,
                本地路径=local_path,
                大小=stat.st_size
            )
            return result

        except MinioException as e:
            self._log_failure(
                "文件下载失败",
                存储桶=bucket,
                对象名称=object_name,
                错误=str(e)
            )
            raise
        except IOError as e:
            self._log_failure(
                "本地文件操作失败",
                本地路径=local_path,
                错误=str(e)
            )
            raise

    def get_presigned_url(self, bucket: str, object_name: str, expires: int = 3600) -> Dict[str, str]:
        """
        Create a temporary GET URL for an object.

        :param bucket: bucket name
        :param object_name: object key
        :param expires: lifetime in seconds (default 3600)
        :return: dict with ``presigned_url``, ``expires_in`` (seconds),
            ``expires_at`` (local datetime), ``bucket`` and ``object_name``
        :raises MinioException: URL generation failed
        """
        try:
            lifetime = timedelta(seconds=expires)
            url = self._client.presigned_get_object(
                bucket_name=bucket,
                object_name=object_name,
                # The client expects a timedelta, not a number of seconds.
                expires=lifetime
            )

            result = {
                'presigned_url': url,
                'expires_in': expires,
                'expires_at': datetime.now() + lifetime,
                'bucket': bucket,
                'object_name': object_name
            }

            self.log.debug(
                "预签名URL生成成功",
                存储桶=bucket,
                对象名称=object_name,
                过期时间=expires
            )
            return result

        except MinioException as e:
            self._log_failure(
                "生成预签名URL失败",
                存储桶=bucket,
                对象名称=object_name,
                错误=str(e)
            )
            raise

    def list_objects(self, bucket: str, prefix: str = "") -> List[Dict[str, Any]]:
        """
        List objects under *prefix* (recursively) with their metadata.

        One extra ``stat_object`` request is issued per object to obtain the
        etag and content type.

        :param bucket: bucket name
        :param prefix: key prefix, default ``""`` (everything)
        :return: list of dicts with ``bucket``, ``object_name``, ``size``,
            ``last_modified``, ``etag`` and ``content_type``
        :raises MinioException: listing or stat failed
        """
        try:
            objects = self._client.list_objects(
                bucket_name=bucket,
                prefix=prefix,
                recursive=True
            )

            result = []
            for obj in objects:
                stat = self._client.stat_object(bucket, obj.object_name)
                result.append({
                    'bucket': bucket,
                    'object_name': obj.object_name,
                    'size': obj.size,
                    'last_modified': obj.last_modified,
                    'etag': stat.etag,
                    'content_type': stat.content_type
                })

            self.log.info(
                "对象列表查询成功",
                存储桶=bucket,
                前缀=prefix,
                数量=len(result)
            )
            return result

        except MinioException as e:
            self._log_failure(
                "查询对象列表失败",
                存储桶=bucket,
                前缀=prefix,
                错误=str(e)
            )
            raise

    def delete_object(self, bucket: str, object_name: str) -> bool:
        """
        Delete one object.

        :return: True on success, False when MinIO reported an error
        """
        try:
            self._client.remove_object(bucket, object_name)
            self.log.info(
                "对象删除成功",
                存储桶=bucket,
                对象名称=object_name
            )
            return True
        except MinioException as e:
            self._log_failure(
                "删除对象失败",
                存储桶=bucket,
                对象名称=object_name,
                错误=str(e)
            )
            return False

    @staticmethod
    def _guess_content_type(object_name: str) -> str:
        """Guess the MIME type from the object name; default to a binary stream."""
        from mimetypes import guess_type
        mime_type, _ = guess_type(object_name)
        return mime_type or 'application/octet-stream'
|
||||
@@ -0,0 +1,722 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
import json
|
||||
import numpy as np
|
||||
from pymysql import cursors
|
||||
from pymysql.err import MySQLError
|
||||
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# 导入日志系统
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MySQLAgent:
|
||||
"""
|
||||
全平台兼容的MySQL数据库操作类
|
||||
支持Windows/macOS/Linux系统
|
||||
配置参数从外部传入,不使用连接池和事务管理
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict):
    """
    Store the connection settings; no connection is opened here.

    :param config: dictionary with the required keys ``host``, ``port``,
        ``user``, ``password`` and ``database``, and the optional keys
        ``charset`` (default ``utf8mb4``), ``connect_timeout`` (default 10),
        ``read_timeout`` (default 30), ``write_timeout`` (default 30) and ``ssl``
    :raises ValueError: a required key is missing
    """
    # Already configured (the class is a singleton): keep the first config.
    if hasattr(self, 'config') and self.config:
        return

    required_keys = ['host', 'port', 'user', 'password', 'database']
    if not all(key in config for key in required_keys):
        # Never write credentials to the log: mask the password before
        # reporting which settings were received.
        safe_config = {
            key: ('***' if key == 'password' else value)
            for key, value in config.items()
        }
        log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{safe_config}")
        raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

    self.config = {
        'host': config['host'],
        'port': config['port'],
        'user': config['user'],
        'password': config['password'],
        'database': config['database'],
        'charset': config.get('charset', 'utf8mb4'),
        'autocommit': True,
        'connect_timeout': config.get('connect_timeout', 10),
        'read_timeout': config.get('read_timeout', 30),
        'write_timeout': config.get('write_timeout', 30),
        'ssl': config.get('ssl')
    }

    current_platform = platform.system()
    self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
||||
|
||||
def get_connection(self) -> pymysql.connections.Connection:
    """
    Open and return a new database connection built from ``self.config``.

    On Windows a failure whose message contains "timed out" is retried via
    ``_retry_connection``; any other failure is logged and re-raised.

    :return: an open connection that exposes ``character_set_name()``
    :raises Exception: the connection could not be established
    """
    def _ensure_charset_accessor(conn):
        # Some consumers call conn.character_set_name(); provide it when the
        # driver's connection object does not.
        if not hasattr(conn, 'character_set_name'):
            def _character_set_name():
                return self.config.get('charset', 'utf8mb4')

            conn.character_set_name = _character_set_name
        return conn

    try:
        conn = _ensure_charset_accessor(pymysql.connect(**self.config))

        # On macOS with SSL configured, ping (reconnecting if needed) before
        # handing the connection out.
        if platform.system() == 'Darwin' and self.config.get('ssl'):
            conn.ping(reconnect=True)

        self.log.trace("获取数据库连接成功")
        return conn

    except Exception as e:
        error_msg = str(e)
        if platform.system() == 'Windows' and "timed out" in error_msg:
            self.log.warning("Windows连接超时,正在重试...")
            # Apply the same shim so retried connections behave identically.
            return _ensure_charset_accessor(self._retry_connection())

        # loguru does not honour exc_info=True; opt(exception=True) is what
        # attaches the traceback to the record.
        self.log.opt(exception=True).error(
            "连接失败",
            error=error_msg,
            error_type=type(e).__name__,
            host=self.config.get('host'),
            port=self.config.get('port'),
            database=self.config.get('database'),
        )
        raise
|
||||
|
||||
def _retry_connection(self, max_retries: int = 3) -> Any | None:
|
||||
"""Windows平台连接重试机制(原有逻辑完全保留)"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn = pymysql.connect(**self.config)
|
||||
self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
|
||||
return conn
|
||||
except Exception:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
    """
    Run a SQL query and return the result as a DataFrame.

    A fresh connection is obtained from ``self.get_connection()`` and wrapped
    in a SQLAlchemy engine (StaticPool, ``creator`` returning that connection)
    so that ``pandas.read_sql`` can consume it. The engine is disposed in all
    cases once it has been created.

    :param sql: query text
    :param params: parameters forwarded to ``pandas.read_sql``
    :param parse_dates: forwarded to ``pandas.read_sql``
    :param is_print: when truthy, log the number of rows returned
    :return: query result
    :raises: any exception raised while connecting or querying (logged first)
    """
    engine = None  # stays None until create_engine succeeds
    try:
        self.log.debug("执行SQL查询", sql=sql)

        connection = self.get_connection()

        # Imported lazily, as in the rest of this method's history.
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool

        charset = self.config.get('charset', 'utf8mb4')
        engine = create_engine(
            "mysql+pymysql://",
            creator=lambda: connection,
            poolclass=StaticPool,
            connect_args={'charset': charset},
        )

        frame = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
        if is_print:
            self.log.info("查询执行成功", 行数=len(frame))
        return frame

    except Exception as exc:
        self.log.error(
            "SQL查询失败",
            sql=sql,
            params=params,
            error=str(exc),
            error_type=type(exc).__name__,
            exc_info=True,
        )
        raise
    finally:
        if engine is not None:
            engine.dispose()
|
||||
|
||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
                   chunk_size: int = 1000, replace: bool = False,
                   ignore_duplicates: bool = None) -> int:
    """
    Insert the rows of ``df`` into ``table_name`` one row at a time.

    Only DataFrame columns that also exist in the table (per ``SHOW COLUMNS``)
    are written; the others are dropped with a warning. dict/list cells are
    serialised to JSON. Each row is executed individually so that a rejected
    row can be identified and logged.

    :param table_name: target table (interpolated into the SQL inside backticks)
    :param df: data to insert
    :param chunk_size: accepted for backward compatibility; this implementation
        always inserts row by row and does not use it
    :param replace: legacy flag, consulted only when ``ignore_duplicates`` is
        None, in which case ``ignore_duplicates = not replace``
    :param ignore_duplicates: when True, rows rejected with a ``MySQLError``
        (duplicate key 1062 or any other code) are logged and skipped; when
        False the first such error is re-raised
    :return: number of rows whose INSERT succeeded
    :raises: any exception raised during the insert, after ``rollback()`` has
        been called and the failure has been logged
    """
    # Backward compatibility: derive ignore_duplicates from the legacy flag.
    if ignore_duplicates is None:
        ignore_duplicates = not replace

    if df.empty:
        self.log.warning("尝试插入空的DataFrame", table=table_name)
        return 0

    def _error_parts(exc):
        # Not every MySQLError carries (code, message); never index blindly,
        # otherwise an IndexError raised here would mask the real failure.
        code = exc.args[0] if exc.args else None
        message = exc.args[1] if len(exc.args) > 1 else str(exc)
        return code, message

    conn = None
    cursor = None
    total_inserted = 0
    total_duplicates = 0
    total_failed = 0
    failed_records = []  # every rejected row (duplicates and other errors)

    try:
        # 1. Open the connection.
        conn = self.get_connection()
        cursor = conn.cursor()
        self.log.debug(f"已建立连接,准备插入数据到 {table_name}")

        # 2. Read the table's actual column names.
        cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
        columns_info = cursor.fetchall()
        db_columns = [col[0] for col in columns_info]
        self.log.debug(f"表 {table_name} 包含以下列:{db_columns}")

        # 3. Normalise the various "empty" markers to None.
        # NOTE(review): some older pandas releases are reported to treat
        # ``value=None`` here as a forward-fill rather than a replacement —
        # confirm against the pinned pandas version.
        cleaned_df = df.replace(
            [None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
            None
        ).copy()

        # 4. Keep only the columns that exist in the table.
        df_columns = cleaned_df.columns.tolist()
        matched_columns = [col for col in df_columns if col in db_columns]
        unmatched_columns = [col for col in df_columns if col not in db_columns]

        if unmatched_columns:
            self.log.warning(
                f"表 {table_name} 中存在不匹配的列,已自动丢弃",
                unmatched_columns=unmatched_columns,
                count=len(unmatched_columns)
            )

        if not matched_columns:
            self.log.warning(f"表 {table_name} 没有匹配的列,终止插入操作")
            return 0

        filtered_df = cleaned_df[matched_columns].copy()
        total_to_insert = len(filtered_df)
        self.log.debug(
            f"表 {table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
        )

        # 5. Serialise dict/list cells to JSON text.
        for col in filtered_df.columns:
            has_complex_type = filtered_df[col].apply(
                lambda x: isinstance(x, (dict, list)) if x is not None else False
            ).any()

            if has_complex_type:
                self.log.debug(f"表 {table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
                filtered_df.loc[:, col] = filtered_df[col].apply(
                    lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
                )

        # 6. Build the parameterised INSERT statement.
        columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
        placeholders = ', '.join(['%s'] * len(filtered_df.columns))
        insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
        self.log.trace(f"为表 {table_name} 生成的插入SQL:{insert_sql}")

        # 7. Insert row by row so a single rejected row can be isolated.
        records = filtered_df.to_dict('records')
        indices = filtered_df.index.tolist()

        for i, (record, idx) in enumerate(zip(records, indices)):
            try:
                data = tuple(record[col] for col in filtered_df.columns)
                cursor.execute(insert_sql, data)
                total_inserted += 1

                if (i + 1) % 100 == 0:
                    self.log.trace(
                        f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
                    )

            except MySQLError as e:
                error_code, error_message = _error_parts(e)

                # 8. MySQL error 1062 = duplicate key.
                if error_code == 1062:
                    total_duplicates += 1
                    short_record = {
                        k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
                        for k, v in record.items()
                    }
                    self.log.warning(
                        f"表 {table_name} 中跳过重复记录",
                        index=idx,
                        error_message=error_message,
                        record=short_record
                    )
                    failed_records.append({
                        'index': idx,
                        'type': 'duplicate',
                        'error_code': error_code,
                        'error_message': error_message,
                        'record': record
                    })
                    if not ignore_duplicates:
                        raise
                else:
                    # Any other database error for this row.
                    total_failed += 1
                    failed_records.append({
                        'index': idx,
                        'type': 'error',
                        'error_code': error_code,
                        'error_message': error_message,
                        'record': record
                    })
                    self.log.error(
                        f"表 {table_name} 插入记录失败",
                        index=idx,
                        error_code=error_code,
                        error_message=error_message,
                        record=record  # full record goes to the log
                    )
                    if not ignore_duplicates:
                        raise

        conn.commit()

        # 9. Summary, including how many rows were rejected.
        self.log.info(
            f"表 {table_name} 插入结果汇总",
            total_to_insert=total_to_insert,
            total_inserted=total_inserted,
            total_duplicates=total_duplicates,
            total_failed=total_failed,
            failed_records_count=len(failed_records)
        )

        # Log the rejected rows separately from the summary.
        if failed_records:
            self.log.error(
                f"表 {table_name} 插入失败记录详情",
                failed_records_summary=[
                    {
                        'index': r['index'],
                        'type': r['type'],
                        'error_code': r['error_code'],
                        'error_message': r['error_message']
                    } for r in failed_records
                ],
                detailed_failed_records=failed_records
            )

        return total_inserted

    except Exception as e:
        if conn:
            # NOTE(review): if the connection runs with autocommit enabled,
            # rows inserted before the failure are presumably already
            # persisted and this rollback cannot undo them — confirm.
            conn.rollback()
        self.log.error(f"表 {table_name} 批量插入失败",
                       error=str(e),
                       error_type=type(e).__name__,
                       table_name=table_name,
                       total_records=len(df) if not df.empty else 0,
                       exc_info=True)
        # Also report the rows that had already been rejected before the abort.
        if failed_records:
            self.log.error(
                f"表 {table_name} 事务回滚,已失败的记录",
                failed_records=failed_records,
                failed_count=len(failed_records)
            )
        raise
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
|
||||
|
||||
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
|
||||
"""【新增辅助方法】获取表的主键(用于replace逻辑的去重)"""
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
|
||||
WHERE TABLE_SCHEMA = %s
|
||||
AND TABLE_NAME = %s
|
||||
AND CONSTRAINT_NAME = 'PRIMARY'
|
||||
""", (self.config['database'], table_name))
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
except Exception as e:
|
||||
self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
|
||||
return None
|
||||
|
||||
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取表的详细结构信息(原有逻辑完全保留,供其他方法调用)"""
|
||||
sql = """
|
||||
SELECT column_name, data_type, character_maximum_length
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s \
|
||||
AND table_name = %s \
|
||||
"""
|
||||
params = (self.config['database'], table_name)
|
||||
|
||||
try:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# 强制转换为列表,避免游标类型导致的解析问题
|
||||
result_list = list(result)
|
||||
if not result_list:
|
||||
self.log.error("未在表中找到任何列", 表=table_name)
|
||||
return {}
|
||||
|
||||
schema = {}
|
||||
for row in result_list:
|
||||
# 确保正确提取字段名(兼容元组格式)
|
||||
col_name = str(row[0]).strip() # 强制转为字符串并去空格
|
||||
data_type = str(row[1]).strip()
|
||||
max_length = row[2] if row[2] else None
|
||||
|
||||
schema[col_name] = {
|
||||
'type': data_type,
|
||||
'max_length': max_length
|
||||
}
|
||||
|
||||
self.log.debug("成功获取表结构信息",
|
||||
表=table_name,
|
||||
列=list(schema.keys()))
|
||||
return schema
|
||||
finally:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
self.log.error("获取表详细信息失败",
|
||||
表=table_name,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
|
||||
table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
|
||||
"""数据校验与清洗(原有逻辑完全保留,供其他方法调用)"""
|
||||
# 1. 字段过滤:只保留表中存在的字段
|
||||
df_columns = df.columns.tolist()
|
||||
table_columns = list(table_schema.keys())
|
||||
|
||||
valid_columns = [col for col in df_columns if col in table_columns]
|
||||
invalid_columns = [col for col in df_columns if col not in table_columns]
|
||||
|
||||
if invalid_columns:
|
||||
self.log.warning("丢弃表中不存在的无效列",
|
||||
表=table_name,
|
||||
无效列=invalid_columns,
|
||||
数量=len(invalid_columns))
|
||||
|
||||
cleaned_df = df[valid_columns].copy()
|
||||
if cleaned_df.empty:
|
||||
return cleaned_df
|
||||
|
||||
# 2. 处理每个字段的数据
|
||||
for col in valid_columns:
|
||||
col_info = table_schema[col]
|
||||
data_type = col_info['type']
|
||||
max_length = col_info['max_length']
|
||||
|
||||
# 2.1 处理空值
|
||||
if cleaned_df[col].isnull().any():
|
||||
# 根据字段类型设置默认值
|
||||
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
|
||||
cleaned_df[col].fillna(default_value, inplace=True)
|
||||
self.log.debug("替换空值",
|
||||
表=table_name,
|
||||
列=col,
|
||||
默认值=default_value,
|
||||
数量=cleaned_df[col].isnull().sum())
|
||||
|
||||
# 2.2 处理字符串类型的超长字段
|
||||
if data_type in ['varchar', 'char'] and max_length:
|
||||
# 确保是字符串类型
|
||||
cleaned_df[col] = cleaned_df[col].astype(str)
|
||||
# 截断超长内容
|
||||
too_long_mask = cleaned_df[col].str.len() > max_length
|
||||
if too_long_mask.any():
|
||||
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
|
||||
self.log.warning("截断超长值",
|
||||
表=table_name,
|
||||
列=col,
|
||||
最大长度=max_length,
|
||||
数量=too_long_mask.sum())
|
||||
|
||||
# 2.3 处理日期时间类型
|
||||
if data_type in ['datetime', 'timestamp']:
|
||||
try:
|
||||
# 尝试转换为datetime类型
|
||||
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
|
||||
except Exception as e:
|
||||
self.log.warning("转换为datetime失败,使用当前时间替代",
|
||||
表=table_name,
|
||||
列=col,
|
||||
错误=str(e))
|
||||
# 转换失败的用当前时间替代
|
||||
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
|
||||
cleaned_df.loc[invalid_mask, col] = datetime.now()
|
||||
|
||||
return cleaned_df
|
||||
|
||||
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """
    Update rows of ``table_name`` from ``df``, matching rows on ``key_columns``.

    Every DataFrame column that exists in the table and is not a key column
    is written; one parameterised UPDATE is executed per DataFrame row via
    ``executemany``. Table and column names are backtick-quoted, matching
    ``insert_from_df``.

    :param table_name: table to update
    :param df: new values; must contain every key column
    :param key_columns: column name or list of names used in the WHERE clause
    :return: ``cursor.rowcount`` after the batch; 0 when ``df`` is empty or no
        updatable column is found
    :raises ValueError: when ``key_columns`` is empty
    :raises: any exception raised while reading the schema or updating (logged first)
    """
    if df.empty:
        self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
        return 0

    self.log.debug("准备从DataFrame更新表数据",
                   表=table_name,
                   关键字列=key_columns,
                   行数=len(df))

    def _quote(identifier):
        # Backtick-quote an identifier, doubling any embedded backtick.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]
        if not key_columns:
            # An empty WHERE clause would produce invalid SQL.
            raise ValueError("key_columns must name at least one column")

        # Read the schema before opening the write connection so two
        # connections are never held at the same time.
        table_info = self._get_table_detailed_info(table_name)
        columns = [col for col in df.columns if col in table_info]
        value_columns = [col for col in columns if col not in key_columns]

        set_clause = ', '.join(f"{_quote(col)}=%s" for col in value_columns)
        where_clause = ' AND '.join(f"{_quote(col)}=%s" for col in key_columns)

        if not set_clause:
            self.log.warning("没有可更新的列", 表=table_name)
            return 0

        update_sql = f"UPDATE {_quote(table_name)} SET {set_clause} WHERE {where_clause}"
        self.log.trace("生成的更新SQL", sql=update_sql)

        # One parameter tuple per row: SET values first, then key values.
        update_data = [
            tuple(row[col] for col in value_columns) + tuple(row[col] for col in key_columns)
            for _, row in df.iterrows()
        ]

        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.executemany(update_sql, update_data)
                updated_rows = cursor.rowcount
                conn.commit()

        self.log.info("数据更新成功",
                      表=table_name,
                      更新行数=updated_rows)
        return updated_rows

    except Exception as e:
        self.log.error("数据更新失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        raise
|
||||
|
||||
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
    """
    Map each DataFrame column to a MySQL column type based on its dtype.

    Dtypes are classified by family rather than by exact name, so e.g.
    ``int32``, nullable ``Int64``, ``float32`` and timezone-aware or
    non-nanosecond datetimes are recognised as well:

      * categorical      -> ``VARCHAR(255)``
      * boolean          -> ``TINYINT(1)``
      * unsigned integer -> ``BIGINT UNSIGNED``
      * other integer    -> ``BIGINT``
      * float            -> ``DOUBLE``
      * datetime64       -> ``DATETIME``
      * anything else    -> ``TEXT``

    :param df: DataFrame whose dtypes are inspected
    :return: ``{column_name: sql_type}`` in column order
    """
    kinds = pd.api.types

    def _sql_type_for(dtype):
        # Categorical first: a categorical with boolean categories would
        # otherwise be reported as a bool dtype.
        if isinstance(dtype, pd.CategoricalDtype):
            return 'VARCHAR(255)'
        if kinds.is_bool_dtype(dtype):
            return 'TINYINT(1)'
        if kinds.is_unsigned_integer_dtype(dtype):
            return 'BIGINT UNSIGNED'
        if kinds.is_integer_dtype(dtype):
            return 'BIGINT'
        if kinds.is_float_dtype(dtype):
            return 'DOUBLE'
        if kinds.is_datetime64_any_dtype(dtype):
            return 'DATETIME'
        return 'TEXT'

    sql_types = {col: _sql_type_for(dtype) for col, dtype in df.dtypes.items()}

    self.log.debug("将DataFrame类型映射为SQL类型",
                   映射关系=sql_types)
    return sql_types
|
||||
|
||||
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """
    Create ``table_name`` with one column per DataFrame column.

    Column types come from ``self.df_to_sql_type(df)``. Table and column
    names are backtick-quoted, matching ``insert_from_df``. The statement is
    assembled without a backslash inside an f-string expression, which is a
    syntax error before Python 3.12.

    :param table_name: name of the table to create
    :param df: DataFrame supplying column names and dtypes
    :param primary_key: column name or list of names; names not present in
        the DataFrame are ignored
    :return: True when the table was created; False when it already exists
        or creation fails (the failure is logged)
    """
    if self.table_exists(table_name):
        self.log.warning("表已存在", 表=table_name)
        return False

    self.log.debug("根据DataFrame结构创建新表",
                   表=table_name,
                   列=list(df.columns))

    def _quote(identifier):
        # Backtick-quote an identifier, doubling any embedded backtick.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{_quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                pk_list = ', '.join(_quote(col) for col in pk_columns)
                columns_sql.append(f"PRIMARY KEY ({pk_list})")
                self.log.trace("设置主键",
                               表=table_name,
                               主键=pk_columns)

        # Join outside the f-string: no backslash may appear inside an
        # f-string expression on Python < 3.12.
        column_block = ",\n    ".join(columns_sql)
        create_sql = f"CREATE TABLE {_quote(table_name)} (\n    {column_block}\n)"

        self.execute_sql(create_sql)
        self.log.info("表创建成功", 表=table_name)
        return True

    except Exception as e:
        self.log.error("创建表失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        return False
|
||||
|
||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """
    Execute one SQL statement on a fresh connection.

    On non-Windows platforms the session's ``max_execution_time`` is raised
    to 600000 before the statement runs.

    :param sql: statement text
    :param params: parameters passed to ``cursor.execute``
    :param fetch: when True, return ``cursor.fetchall()``; otherwise commit
        and return ``cursor.rowcount``
    :return: fetched rows, or the affected-row count
    :raises: any exception raised while connecting or executing (logged first)
    """
    try:
        with self.get_connection() as conn, conn.cursor() as cursor:
            # Linux/macOS sessions get a longer execution limit.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")

            cursor.execute(sql, params)

            if fetch:
                rows = cursor.fetchall()
                self.log.debug("查询执行完成", 行数=len(rows))
                return rows

            row_count = cursor.rowcount
            conn.commit()  # commit right away
            self.log.debug("更新执行完成", 受影响行数=row_count)
            return row_count

    except Exception as exc:
        self.log.error(
            "SQL执行失败",
            sql=sql,
            params=params,
            error=str(exc),
            error_type=type(exc).__name__,
            exc_info=True,
        )
        raise
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
    """
    Report whether ``table_name`` exists in the configured database.

    The check counts matching rows in ``information_schema.tables`` and reads
    the count positionally from the first result row.

    :param table_name: table to look for
    :return: True when the count is greater than zero; False otherwise,
        including when the lookup itself raises
    """
    lookup_sql = """
        SELECT COUNT(*) as count
        FROM `information_schema`.`tables`
        WHERE `table_schema` = %s \
          AND `table_name` = %s \
    """
    lookup_args = (self.config['database'], table_name)

    try:
        rows = self.execute_sql(lookup_sql, lookup_args, fetch=True)
        found = rows[0][0] > 0  # rows are tuples
        self.log.debug("检查表是否存在",
                       表=table_name,
                       存在=found)
    except Exception:
        # Any failure is reported as "does not exist".
        return False
    return found
|
||||
|
||||
def drop_table(self, table_name: str) -> bool:
    """
    Drop ``table_name`` if it exists.

    The table name is backtick-quoted (embedded backticks doubled), matching
    ``insert_from_df``, so names containing spaces or reserved words work.

    :param table_name: table to drop
    :return: True when the table was dropped; False when it does not exist
        or the DROP fails (the failure is logged)
    """
    if not self.table_exists(table_name):
        self.log.warning("表不存在", 表=table_name)
        return False

    quoted_name = "`" + str(table_name).replace("`", "``") + "`"
    try:
        self.execute_sql(f"DROP TABLE {quoted_name}")
        self.log.info("表删除成功", 表=table_name)
        return True
    except Exception as e:
        self.log.error("删除表失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        return False
|
||||
|
||||
def validate_connection(self) -> bool:
    """
    Check that a connection can be opened and answers ``SELECT 1``.

    :return: True when the first column of the result row equals 1;
        False when it does not or when any exception is raised
    """
    try:
        with self.get_connection() as conn, conn.cursor() as cursor:
            cursor.execute("SELECT 1")
            first_row = cursor.fetchone()
            return first_row[0] == 1
    except Exception:
        return False
|
||||
|
||||
|
||||
# 平台特定的默认配置(原有逻辑完全保留)
|
||||
def get_default_config():
    """
    Build the default connection settings for the current platform.

    All platforms share the same host, port, user, password and database.
    Windows uses 10/30/30-second connect/read/write timeouts; every other
    platform uses 15/60/60, and macOS additionally gets an ``ssl`` entry
    pointing at a CA bundle path.

    NOTE(review): the credentials below are hard-coded defaults.

    :return: a new dict of connection settings
    """
    system_name = platform.system()

    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and anything else share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage example: connect with the platform defaults and print the server version.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")

        # Ask the server for its version string.
        version = db.query_to_df("SELECT VERSION() as version")
        print(f"数据库版本: {version['version'].iloc[0]}")
|
||||
Reference in New Issue
Block a user