优化任务调度说明

This commit is contained in:
z66
2025-10-17 17:59:28 +08:00
commit fd67231866
49 changed files with 300973 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
from .logger import CrossPlatformLog
+513
View File
@@ -0,0 +1,513 @@
import os
import shutil
import zipfile
import pickle
import pandas as pd
from datetime import datetime
from pathlib import Path, PurePath
from typing import Union, Optional, List, Dict, Any, Callable
from utils.logger import log
class FileHandler:
    """
    Cross-platform file utility class (Windows/macOS/Linux).

    Return conventions:
    - Methods that read file content return a DataFrame.
    - The file/directory operation methods return a dict of the form:
        {
            'success': bool,  # whether the operation succeeded
            'message': str,   # human-readable result description
            'data': Any       # optional payload of the operation
        }
    """
def __init__(self, base_path: Optional[Union[str, Path]] = None):
    """
    Create a file handler.

    :param base_path: optional root that relative paths are resolved against;
        stored in normalized form, or None when a falsy value is given
    """
    if base_path:
        self.base_path = self._normalize_path(base_path)
    else:
        self.base_path = None
    # Logger carrying the class name in its bound "module" field
    self.log = log.bind(module=type(self).__name__)
def _normalize_path(self, path: Union[str, Path]) -> Path:
"""统一转换为跨平台Path对象"""
return Path(str(path).replace('\\', '/'))
def _resolve_path(self, path: Union[str, Path]) -> Path:
"""解析路径(自动处理跨平台路径)"""
path = self._normalize_path(path)
if not path.is_absolute() and self.base_path:
return self._normalize_path(self.base_path / path)
return path
def _format_result(self,
success: bool,
message: str = "",
data: Optional[Any] = None) -> Dict[str, Any]:
"""统一返回结果格式"""
return {
'success': bool(success),
'message': str(message),
'data': data
}
def read_file(self,
file_path: Union[str, Path],
encoding: str = 'utf-8',
**kwargs) -> pd.DataFrame:
"""
读取文件内容为DataFrame(跨平台兼容)
:param file_path: 文件路径(自动处理跨平台格式)
:param encoding: 文件编码(默认utf-8
:return: 包含文件内容的DataFrame
:raises: 文件读取失败时抛出原始异常
"""
file_path = self._resolve_path(file_path)
try:
ext = self.get_file_extension(file_path)
if ext in ['csv', 'txt']:
df = pd.read_csv(file_path, encoding=encoding, **kwargs)
elif ext in ['xls', 'xlsx']:
df = pd.read_excel(file_path, **kwargs)
elif ext == 'json':
df = pd.read_json(file_path, encoding=encoding, **kwargs)
elif ext in ['pkl', 'pickle']:
# 统一将pickle内容转为DataFrame返回
obj = pd.read_pickle(file_path)
if isinstance(obj, pd.DataFrame):
df = obj
elif isinstance(obj, list):
df = pd.DataFrame(obj)
elif isinstance(obj, dict):
df = pd.DataFrame([obj])
else:
df = pd.DataFrame({'content': [obj]})
elif ext == 'parquet':
df = pd.read_parquet(file_path, **kwargs)
else:
with open(file_path, 'r', encoding=encoding) as f:
return pd.DataFrame({'content': [f.read()]})
self.log.debug(f"文件读取成功 | path={file_path} shape={df.shape}")
return df
except Exception as e:
self.log.error(f"文件读取失败 | path={file_path} error={str(e)}")
raise
def write_file(self,
file_path: Union[str, Path],
data: Union[pd.DataFrame, Dict, List],
encoding: str = 'utf-8',
**kwargs) -> Dict[str, Any]:
"""
写入文件(跨平台兼容)
:param file_path: 目标文件路径
:param data: 要写入的数据(支持DataFrame/dict/list
:param encoding: 文件编码(默认utf-8
:return: 操作结果字典
"""
file_path = self._resolve_path(file_path)
try:
# 自动创建父目录
parent_dir = file_path.parent
if not parent_dir.exists():
self.create_dir(parent_dir)
# 根据扩展名选择写入方式
ext = self.get_file_extension(file_path)
if ext in ['pkl', 'pickle']:
# 直接按原始对象进行pickle序列化
with open(file_path, 'wb') as f:
pickle.dump(data, f)
else:
# 统一数据格式到DataFrame
if isinstance(data, pd.DataFrame):
df = data
else:
df = pd.DataFrame(data if isinstance(data, list) else [data])
if ext in ['csv', 'txt']:
df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
elif ext in ['xls', 'xlsx']:
df.to_excel(file_path, index=False, **kwargs)
elif ext == 'json':
df.to_json(file_path, force_ascii=False, **kwargs)
elif ext == 'parquet':
df.to_parquet(file_path, **kwargs)
else:
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(data))
# 返回成功结果
return self._format_result(
True,
"文件写入成功",
{
'file_path': str(file_path),
'file_size': os.path.getsize(file_path)
}
)
except Exception as e:
return self._format_result(
False,
f"文件写入失败: {str(e)}",
{'file_path': str(file_path)}
)
def file_exists(self, file_path: Union[str, Path]) -> Dict[str, Any]:
"""
检查文件是否存在(跨平台兼容)
:return: 包含exists字段的结果字典
"""
file_path = self._resolve_path(file_path)
exists = file_path.is_file()
msg = f"文件{'' if exists else ''}存在: {file_path}"
return self._format_result(True, msg, {'exists': exists})
def dir_exists(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
"""
检查目录是否存在(跨平台兼容)
:return: 包含exists字段的结果字典
"""
dir_path = self._resolve_path(dir_path)
exists = dir_path.is_dir()
msg = f"目录{'' if exists else ''}存在: {dir_path}"
return self._format_result(True, msg, {'exists': exists})
def create_dir(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
"""
创建目录(跨平台兼容)
:return: 包含path字段的结果字典
"""
dir_path = self._resolve_path(dir_path)
try:
dir_path.mkdir(parents=True, exist_ok=True)
# Windows系统需要额外设置权限
if os.name == 'nt':
try:
os.chmod(dir_path, 0o777)
except:
pass
return self._format_result(True, "目录创建成功", {'path': str(dir_path)})
except Exception as e:
return self._format_result(False, f"目录创建失败: {str(e)}", {'path': str(dir_path)})
def delete_file(self, file_path: Union[str, Path]) -> Dict[str, Any]:
"""
删除文件(跨平台兼容)
:return: 包含path字段的结果字典
"""
file_path = self._resolve_path(file_path)
try:
if not file_path.exists():
return self._format_result(False, "文件不存在", {'path': str(file_path)})
file_path.unlink()
return self._format_result(True, "文件删除成功", {'path': str(file_path)})
except Exception as e:
return self._format_result(False, f"文件删除失败: {str(e)}", {'path': str(file_path)})
def delete_dir(self, dir_path: Union[str, Path]) -> Dict[str, Any]:
"""
删除目录及其内容(跨平台兼容)
:return: 包含path字段的结果字典
"""
dir_path = self._resolve_path(dir_path)
try:
if not dir_path.exists():
return self._format_result(False, "目录不存在", {'path': str(dir_path)})
shutil.rmtree(dir_path)
return self._format_result(True, "目录删除成功", {'path': str(dir_path)})
except Exception as e:
return self._format_result(False, f"目录删除失败: {str(e)}", {'path': str(dir_path)})
def list_files(self,
dir_path: Union[str, Path],
recursive: bool = False,
pattern: str = '*') -> Dict[str, Any]:
"""
列出目录中的文件(跨平台兼容)
:param recursive: 是否递归查找
:param pattern: 文件匹配模式(如*.txt
:return: 包含files字段的结果字典
"""
dir_path = self._resolve_path(dir_path)
try:
if recursive:
files = list(dir_path.rglob(pattern))
else:
files = list(dir_path.glob(pattern))
file_info = [
{
'path': str(f),
'name': f.name,
'size': f.stat().st_size,
'modified': datetime.fromtimestamp(f.stat().st_mtime).isoformat(),
'is_dir': f.is_dir()
} for f in files if f.is_file() # 只返回文件,不包括目录
]
return self._format_result(
True,
f"找到 {len(file_info)} 个文件",
{'files': file_info}
)
except Exception as e:
return self._format_result(
False,
f"列出文件失败: {str(e)}",
{'files': []}
)
def get_file_extension(self, file_path: Union[str, Path]) -> str:
"""
获取文件扩展名(跨平台兼容)
:return: 小写且不带点的扩展名(如 'jpg'
"""
file_path = self._resolve_path(file_path)
ext = file_path.suffix.lower().lstrip('.')
self.log.trace(f"获取文件扩展名 | path={file_path} ext={ext}")
return ext
def copy_file(self,
src_path: Union[str, Path],
dst_path: Union[str, Path]) -> Dict[str, Any]:
"""
复制文件(跨平台兼容)
:return: 包含source和destination字段的结果字典
"""
src_path = self._resolve_path(src_path)
dst_path = self._resolve_path(dst_path)
try:
if not src_path.exists():
return self._format_result(
False,
"源文件不存在",
{
'source': str(src_path),
'destination': str(dst_path)
}
)
# 确保目标目录存在
self.create_dir(dst_path.parent)
shutil.copy2(src_path, dst_path)
return self._format_result(
True,
"文件复制成功",
{
'source': str(src_path),
'destination': str(dst_path),
'file_size': dst_path.stat().st_size
}
)
except Exception as e:
return self._format_result(
False,
f"文件复制失败: {str(e)}",
{
'source': str(src_path),
'destination': str(dst_path)
}
)
def move_file(self,
src_path: Union[str, Path],
dst_path: Union[str, Path]) -> Dict[str, Any]:
"""
移动/重命名文件(跨平台兼容)
:return: 包含source和destination字段的结果字典
"""
src_path = self._resolve_path(src_path)
dst_path = self._resolve_path(dst_path)
try:
if not src_path.exists():
return self._format_result(
False,
"源文件不存在",
{
'source': str(src_path),
'destination': str(dst_path)
}
)
# 确保目标目录存在
self.create_dir(dst_path.parent)
shutil.move(src_path, dst_path)
return self._format_result(
True,
"文件移动成功",
{
'source': str(src_path),
'destination': str(dst_path)
}
)
except Exception as e:
return self._format_result(
False,
f"文件移动失败: {str(e)}",
{
'source': str(src_path),
'destination': str(dst_path)
}
)
def zip_files(self,
file_paths: List[Union[str, Path]],
zip_path: Union[str, Path]) -> Dict[str, Any]:
"""
压缩多个文件到zip(跨平台兼容)
:param file_paths: 要压缩的文件路径列表
:param zip_path: 目标zip文件路径
:return: 包含zip_path和file_count字段的结果字典
"""
zip_path = self._resolve_path(zip_path)
try:
# 确保目标目录存在
self.create_dir(zip_path.parent)
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
file_count = 0
for file_path in file_paths:
file_path = self._resolve_path(file_path)
if file_path.exists():
zipf.write(file_path, file_path.name)
file_count += 1
return self._format_result(
True,
"文件压缩成功",
{
'zip_path': str(zip_path),
'file_count': file_count,
'zip_size': os.path.getsize(zip_path)
}
)
except Exception as e:
return self._format_result(
False,
f"文件压缩失败: {str(e)}",
{
'zip_path': str(zip_path)
}
)
def unzip(self,
zip_path: Union[str, Path],
extract_to: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
"""
解压zip文件(跨平台兼容)
:param extract_to: 解压目标目录(默认为zip文件所在目录)
:return: 包含extract_to和file_count字段的结果字典
"""
zip_path = self._resolve_path(zip_path)
extract_to = self._resolve_path(extract_to) if extract_to else zip_path.parent
try:
if not zip_path.exists():
return self._format_result(
False,
"ZIP文件不存在",
{
'zip_path': str(zip_path),
'extract_to': str(extract_to)
}
)
# 确保目标目录存在
self.create_dir(extract_to)
with zipfile.ZipFile(zip_path, 'r') as zipf:
file_list = zipf.namelist()
zipf.extractall(extract_to)
return self._format_result(
True,
"文件解压成功",
{
'extract_to': str(extract_to),
'file_count': len(file_list)
}
)
except Exception as e:
return self._format_result(
False,
f"文件解压失败: {str(e)}",
{
'zip_path': str(zip_path),
'extract_to': str(extract_to)
}
)
# ---------------------------- 测试用例 ----------------------------
if __name__ == "__main__":
    # Manual smoke test for FileHandler.
    # Find the project root by walking up from this file until a marker
    # (.git / pyproject.toml / requirements.txt) is found. If none exists,
    # fall back to this file's directory instead of raising StopIteration.
    _this_file = Path(__file__).resolve()
    project_root = next(
        (p for p in _this_file.parents if
         (p / '.git').exists() or (p / 'pyproject.toml').exists() or (p / 'requirements.txt').exists()),
        _this_file.parent,
    )
    handler = FileHandler(project_root / "test")

    # Path normalization
    test_paths = [
        "normal/path",
        "windows\\style\\path",
        "mixed/path\\with\\both"
    ]
    print("=== 路径标准化测试 ===")
    for path in test_paths:
        resolved = handler._resolve_path(path)
        print(f"原始路径: {path} -> 标准化: {resolved} (类型: {type(resolved)})")

    # Directory operations
    print("\n=== 目录操作测试 ===")
    dir_result = handler.create_dir("test_dir")
    print(dir_result)

    # File write
    print("\n=== 文件操作测试 ===")
    test_data = [{"name": "Alice", "age": 25}, {"name": "Bob", "age": 30}]
    write_result = handler.write_file("test_dir/data.json", test_data)
    print(write_result)

    # File read
    try:
        df = handler.read_file("test_dir/data.json")
        print("\n读取文件内容:")
        print(df)
    except Exception as e:
        print(f"\n文件读取失败: {str(e)}")

    # File listing
    print("\n=== 文件列表测试 ===")
    list_result = handler.list_files("test_dir")
    print(list_result)

    # Zip / unzip round trip
    print("\n=== 压缩解压测试 ===")
    zip_result = handler.zip_files(
        ["test_dir/data.json"],
        "test_archive.zip"
    )
    print(zip_result)
    unzip_result = handler.unzip(
        "test_archive.zip",
        "extracted_files"
    )
    print(unzip_result)

    # Clean up everything the smoke test created
    print("\n=== 清理测试数据 ===")
    print(handler.delete_file("test_dir/data.json"))
    print(handler.delete_dir("test_dir"))
    print(handler.delete_file("test_archive.zip"))
    print(handler.delete_dir("extracted_files"))
+128
View File
@@ -0,0 +1,128 @@
import os
import sys
from pathlib import Path
from loguru import logger
import platform
from datetime import datetime
import zipfile
class CrossPlatformLog:
    """Cross-platform logging setup (Linux/Windows/Mac)."""
def __init__(self):
    """Resolve the log directory, then install the logging sinks."""
    log_directory = self._get_log_dir()
    self.log_dir = log_directory
    self._setup_logger()
def _get_log_dir(self):
    """
    Return the 'logs' directory located two levels above this file,
    creating it when it does not exist yet.
    """
    base_dir = Path(__file__).parent.parent  # treated as the project root
    log_dir = base_dir / "logs"
    log_dir.mkdir(exist_ok=True)
    if platform.system() == "Windows":
        # Best-effort attempt to make the directory writable. Only OS-level
        # failures are ignored, so KeyboardInterrupt/SystemExit still
        # propagate (the former bare except swallowed them).
        try:
            os.chmod(log_dir, 0o777)
        except OSError:
            pass
    return log_dir
def _setup_logger(self):
    """Remove the existing handlers, then add the console, main-file and error-file sinks."""
    logger.remove()  # drop the default configuration

    def _info_or_above(record):
        # Level number 20 is INFO
        return record["level"].no >= 20

    console_format = "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{module}</cyan> - <level>{message}</level>"
    logger.add(
        sys.stdout,
        level="INFO",
        format=console_format,
        filter=_info_or_above
    )
    # Main log file
    self._add_main_log()
    # Errors are stored in their own file as well
    self._add_error_log()
def _add_main_log(self):
    """Add the main 'application.log' sink (DEBUG and above)."""
    sink_path = str(self.log_dir / "application.log")
    sink_options = {
        'rotation': "20 MB",
        'compression': self._compress_log,
        'encoding': "utf-8",
        'level': "DEBUG",
        'retention': "30 days",
        'enqueue': True,
        # Callable format: renders the bound extra fields in readable form
        'format': self._format_with_extra,
    }
    logger.add(sink_path, **sink_options)
def _format_with_extra(self, record):
# 构造 extra 的可读字符串
extra_str = ""
if record["extra"]:
extra_items = []
for key, value in record["extra"].items():
if key == "extra_output": # 跳过自己,避免递归
continue
value_repr = repr(value)
# 对于错误信息,增加截断长度限制,避免丢失重要信息
if key in ["error", "error_message", "sql", "params"]:
if len(value_repr) > 500:
value_repr = value_repr[:497] + "..."
elif len(value_repr) > 200:
value_repr = value_repr[:197] + "..."
extra_items.append(f"\n{key}: {value_repr}")
extra_str = "".join(extra_items)
# 👉 直接将 extra_str 写入 message 或附加字段
record["extra"]["extra_output"] = extra_str
# ✅ 关键:返回的 format 字符串不再引用 {extra_output},而是使用 {extra[extra_output]}
return "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra[extra_output]}\n"
def _add_error_log(self):
    """Add the dedicated 'errors.log' sink (ERROR and above)."""
    error_log = self.log_dir / "errors.log"

    def _error_format(record):
        # Populate extra["extra_output"] here so this sink does not depend
        # on another sink having formatted the record first.
        self._format_with_extra(record)
        # Returned from a callable, the template is used verbatim. The old
        # string format already ended in "\n{exception}" and loguru appends
        # that suffix to string formats again, duplicating every traceback.
        return ("{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - "
                "{message}{extra[extra_output]}\n{exception}")

    logger.add(
        str(error_log),
        level="ERROR",
        format=_error_format,
        rotation="10 MB",
        retention="90 days",
        enqueue=True
    )
@staticmethod
def _compress_log(log_path):
"""通用日志压缩方法(兼容所有平台)"""
if not os.path.exists(log_path):
return
try:
zip_path = f"{log_path}.{datetime.now().strftime('%Y%m%d')}.zip"
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
zipf.write(log_path, arcname=os.path.basename(log_path))
os.remove(log_path)
return zip_path
except Exception as e:
print(f"日志压缩失败: {str(e)}")
return log_path # 返回原文件路径继续使用
@classmethod
def get_logger(cls, module_name=None):
    """Return a logger bound with module=<module_name> ("__main__" when falsy)."""
    bound_name = module_name if module_name else "__main__"
    return logger.bind(module=bound_name)
# Initialise the shared logger: constructing CrossPlatformLog runs its
# __init__, which sets up the log directory and the logging sinks.
log = CrossPlatformLog().get_logger()
+383
View File
@@ -0,0 +1,383 @@
import os
import sys
import platform
import threading
from typing import List, Dict, Optional, BinaryIO, Tuple, Any
from datetime import datetime, timedelta
import hashlib
from io import BytesIO
from minio import Minio
from minio.error import S3Error, MinioException
from utils.logger import log
class MinIOAgent:
    """
    MinIO object-storage helper intended to work on Windows/macOS/Linux.

    Provides upload, download, listing and related operations on binary
    data, and returns metadata dicts that callers can associate with MySQL
    records.
    """
    _instance = None  # the single shared instance (singleton)
    _lock = threading.Lock()  # guards the first creation of _instance
def __new__(cls, *args, **kwargs):
"""单例模式实现,确保全局只有一个实例"""
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: dict):
"""
初始化MinIO连接
参数:
config (dict): MinIO配置字典,包含以下键:
- endpoint: 服务端点(例:'localhost:9000'
- access_key: 访问密钥
- secret_key: 密钥
- [可选] secure: 是否使用SSL(默认False
- [可选] region: 区域
- [可选] timeout: 超时时间(秒,默认30)
"""
# 避免重复初始化
if hasattr(self, '_client') and self._client:
return
# 验证必要配置参数
required_keys = ['endpoint', 'access_key', 'secret_key']
if not all(key in config for key in required_keys):
raise ValueError(f"MinIO配置缺少必要参数,需要: {required_keys}")
# 整合配置,设置默认值
self.config = {
'endpoint': config['endpoint'],
'access_key': config['access_key'],
'secret_key': config['secret_key'],
'secure': config.get('secure', False),
'region': config.get('region'),
'timeout': config.get('timeout', 30)
}
# 初始化日志,绑定当前平台信息
current_platform = platform.system()
self.log = log.bind(module=f"MinIOAgent({current_platform})")
# 创建客户端实例
self._client = self._create_client()
# 验证连接是否有效
self._verify_connection()
def _create_client(self) -> Minio:
    """
    Build the Minio client from self.config.

    Raises:
        Exception: construction errors are logged with traceback and re-raised.
    """
    try:
        # NOTE: config['timeout'] is not forwarded to the client here.
        client = Minio(
            endpoint=self.config['endpoint'],
            access_key=self.config['access_key'],
            secret_key=self.config['secret_key'],
            secure=self.config['secure'],
            region=self.config['region']
        )
        self.log.info("MinIO客户端创建成功")
        return client
    except Exception as e:
        # loguru has no exc_info flag (it would just become an extra field);
        # opt(exception=True) attaches the active traceback.
        self.log.opt(exception=True).critical("创建MinIO客户端失败", 错误=str(e))
        raise
def _verify_connection(self) -> None:
"""验证与MinIO服务的连接是否正常"""
try:
# 通过列出存储桶来验证连接
self._client.list_buckets()
self.log.info(f"成功连接到MinIO服务:{self.config['endpoint']}")
except Exception as e:
self.log.critical("连接验证失败", 错误=str(e), exc_info=True)
raise
def create_bucket(self, bucket_name: str) -> bool:
"""
创建存储桶(如不存在)
参数:
bucket_name: 存储桶名称
返回:
是否成功创建(或已存在)
"""
try:
if not self._client.bucket_exists(bucket_name):
self._client.make_bucket(bucket_name)
self.log.info(f"存储桶创建成功:{bucket_name}")
return True
self.log.debug(f"存储桶已存在:{bucket_name}")
return True
except MinioException as e:
self.log.error(f"创建存储桶 {bucket_name} 失败", 错误=str(e), exc_info=True)
return False
def upload_bytes(self, bucket: str, object_name: str, data: bytes) -> Dict[str, Any]:
"""
上传二进制数据至MinIO
参数:
bucket: 存储桶名称
object_name: 对象名称(路径)
data: 二进制数据
返回:
包含元数据的字典:
- bucket: 存储桶名称
- object_name: 对象路径
- size: 数据大小(字节)
- etag: 服务器生成的哈希值
- content_type: 内容类型
- upload_time: 上传时间(UTC)
- local_hash: 本地计算的MD5哈希
"""
if not data:
raise ValueError("上传数据不能为空")
# 确保存储桶存在
self.create_bucket(bucket)
try:
# 计算本地哈希(用于数据完整性校验)
local_hash = hashlib.md5(data).hexdigest()
# 上传数据
result = self._client.put_object(
bucket_name=bucket,
object_name=object_name,
data=BytesIO(data),
length=len(data),
content_type=self._guess_content_type(object_name)
)
# 构建元数据
metadata = {
'bucket': bucket,
'object_name': object_name,
'size': len(data),
'etag': result.etag,
'content_type': result.content_type,
'upload_time': datetime.utcfromtimestamp(result.last_modified.timestamp()),
'local_hash': local_hash
}
self.log.info(
"文件上传成功",
存储桶=bucket,
对象名称=object_name,
大小=len(data)
)
return metadata
except MinioException as e:
self.log.error(
"文件上传失败",
存储桶=bucket,
对象名称=object_name,
错误=str(e),
exc_info=True
)
raise
def download_file(self, bucket: str, object_name: str, local_path: str) -> Dict[str, Any]:
"""
从MinIO下载文件至本地
参数:
bucket: 存储桶名称
object_name: 对象名称(路径)
local_path: 本地保存路径
返回:
包含下载信息的字典:
- local_path: 本地路径
- size: 文件大小
- download_time: 下载时间
"""
try:
# 创建父目录(如果不存在)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
# 下载文件
start_time = datetime.now()
self._client.fget_object(bucket, object_name, local_path)
download_time = datetime.now() - start_time
# 获取文件信息
stat = os.stat(local_path)
result = {
'local_path': local_path,
'size': stat.st_size,
'download_time': download_time.total_seconds(),
'downloaded_at': datetime.now()
}
self.log.info(
"文件下载成功",
存储桶=bucket,
对象名称=object_name,
本地路径=local_path,
大小=stat.st_size
)
return result
except MinioException as e:
self.log.error(
"文件下载失败",
存储桶=bucket,
对象名称=object_name,
错误=str(e),
exc_info=True
)
raise
except IOError as e:
self.log.error(
"本地文件操作失败",
本地路径=local_path,
错误=str(e),
exc_info=True
)
raise
def get_presigned_url(self, bucket: str, object_name: str, expires: int = 3600) -> Dict[str, str]:
"""
生成临时访问URL
参数:
bucket: 存储桶名称
object_name: 对象名称(路径)
expires: 过期时间(秒),默认3600秒
返回:
包含URL和过期信息的字典
"""
try:
url = self._client.presigned_get_object(
bucket_name=bucket,
object_name=object_name,
expires=expires
)
result = {
'presigned_url': url,
'expires_in': expires,
'expires_at': datetime.now() + timedelta(seconds=expires),
'bucket': bucket,
'object_name': object_name
}
self.log.debug(
"预签名URL生成成功",
存储桶=bucket,
对象名称=object_name,
过期时间=expires
)
return result
except MinioException as e:
self.log.error(
"生成预签名URL失败",
存储桶=bucket,
对象名称=object_name,
错误=str(e),
exc_info=True
)
raise
def list_objects(self, bucket: str, prefix: str = "") -> List[Dict[str, Any]]:
"""
查询指定前缀的对象列表及元数据
参数:
bucket: 存储桶名称
prefix: 对象路径前缀
返回:
对象信息列表,每个对象包含:
- bucket: 存储桶
- object_name: 对象名称
- size: 大小
- last_modified: 最后修改时间
- etag: 哈希值
- content_type: 内容类型
"""
try:
objects = self._client.list_objects(
bucket_name=bucket,
prefix=prefix,
recursive=True
)
result = []
for obj in objects:
# 获取详细元数据
stat = self._client.stat_object(bucket, obj.object_name)
result.append({
'bucket': bucket,
'object_name': obj.object_name,
'size': obj.size,
'last_modified': obj.last_modified,
'etag': stat.etag,
'content_type': stat.content_type
})
self.log.info(
"对象列表查询成功",
存储桶=bucket,
前缀=prefix,
数量=len(result)
)
return result
except MinioException as e:
self.log.error(
"查询对象列表失败",
存储桶=bucket,
前缀=prefix,
错误=str(e),
exc_info=True
)
raise
def delete_object(self, bucket: str, object_name: str) -> bool:
"""
删除指定对象
参数:
bucket: 存储桶名称
object_name: 对象名称(路径)
返回:
是否删除成功
"""
try:
self._client.remove_object(bucket, object_name)
self.log.info(
"对象删除成功",
存储桶=bucket,
对象名称=object_name
)
return True
except MinioException as e:
self.log.error(
"删除对象失败",
存储桶=bucket,
对象名称=object_name,
错误=str(e),
exc_info=True
)
return False
@staticmethod
def _guess_content_type(object_name: str) -> str:
"""根据文件名猜测内容类型"""
from mimetypes import guess_type
mime_type, _ = guess_type(object_name)
return mime_type or 'application/octet-stream' # 默认二进制流类型
+722
View File
@@ -0,0 +1,722 @@
import os
import sys
import platform
import pandas as pd
import pymysql
import json
import numpy as np
from pymysql import cursors
from pymysql.err import MySQLError
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
import threading
from datetime import datetime
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """
    MySQL helper intended to work on Windows/macOS/Linux.

    Configuration is passed in from outside; no connection pool and no
    transaction management are used.
    """
    _instance = None  # the single shared instance (singleton)
    _lock = threading.Lock()  # guards the first creation of _instance
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, config: dict):
    """
    Store the connection settings (only once for the shared instance).

    Args:
        config: must contain host, port, user, password and database; may
            contain charset, connect_timeout, read_timeout, write_timeout
            and ssl.

    Raises:
        ValueError: when a required key is missing.
    """
    if hasattr(self, 'config') and self.config:
        return
    required_keys = ['host', 'port', 'user', 'password', 'database']
    missing_keys = [key for key in required_keys if key not in config]
    if missing_keys:
        # Log only the missing key names: the config may hold the password
        # and must never be written to the log.
        log.warning(f"数据库配置缺少必要参数,缺失项:{missing_keys}")
        raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
    self.config = {
        'host': config['host'],
        'port': config['port'],
        'user': config['user'],
        'password': config['password'],
        'database': config['database'],
        'charset': config.get('charset', 'utf8mb4'),
        'autocommit': True,
        'connect_timeout': config.get('connect_timeout', 10),
        'read_timeout': config.get('read_timeout', 30),
        'write_timeout': config.get('write_timeout', 30),
        'ssl': config.get('ssl')
    }
    # Logger bound with the current platform name
    current_platform = platform.system()
    self.log = log.bind(module=f"MySQLAgent({current_platform})")
def get_connection(self) -> pymysql.connections.Connection:
    """
    Open a new database connection from self.config.

    On Windows a "timed out" failure is retried via _retry_connection.

    Raises:
        Exception: connection errors are logged with traceback and re-raised.
    """
    try:
        conn = pymysql.connect(**self.config)
        # Provide character_set_name() when the connection lacks it
        if not hasattr(conn, 'character_set_name'):
            def _character_set_name():
                return self.config.get('charset', 'utf8mb4')
            conn.character_set_name = _character_set_name
        # With SSL on macOS the connection is pinged once right away
        if platform.system() == 'Darwin' and self.config.get('ssl'):
            conn.ping(reconnect=True)
        self.log.trace("获取数据库连接成功")
        return conn
    except Exception as e:
        error_msg = str(e)
        if platform.system() == 'Windows' and "timed out" in error_msg:
            self.log.warning("Windows连接超时,正在重试...")
            return self._retry_connection()
        # loguru has no exc_info flag; opt(exception=True) attaches the traceback
        self.log.opt(exception=True).error(
            "连接失败",
            error=error_msg,
            error_type=type(e).__name__,
            host=self.config.get('host'),
            port=self.config.get('port'),
            database=self.config.get('database'))
        raise
def _retry_connection(self, max_retries: int = 3) -> Any | None:
    """
    Try to connect up to *max_retries* times, pausing one second between tries.

    Returns the first successful connection; the exception of the final
    attempt is re-raised.
    """
    import time
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            conn = pymysql.connect(**self.config)
            self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
            return conn
        except Exception:
            if attempt == final_attempt:
                raise
            time.sleep(1)
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
    """
    Run a SQL query and return the result as a DataFrame.

    Args:
        sql: query text
        params: parameters passed through to pandas.read_sql
        parse_dates: passed through to pandas.read_sql
        is_print: log the row count at INFO level when True

    Raises:
        Exception: query errors are logged with traceback and re-raised.
    """
    conn = None
    engine = None
    try:
        self.log.debug("执行SQL查询", sql=sql)
        conn = self.get_connection()
        # Wrap the one connection in a SQLAlchemy engine for pandas
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool
        engine = create_engine(
            "mysql+pymysql://",
            creator=lambda: conn,
            poolclass=StaticPool,
            connect_args={'charset': self.config.get('charset', 'utf8mb4')}
        )
        df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
        if is_print:
            self.log.info("查询执行成功", 行数=len(df))
        return df
    except Exception as e:
        # loguru has no exc_info flag; opt(exception=True) attaches the traceback
        self.log.opt(exception=True).error(
            "SQL查询失败",
            sql=sql,
            params=params,
            error=str(e),
            error_type=type(e).__name__)
        raise
    finally:
        if engine is not None:
            engine.dispose()
        elif conn is not None:
            # The engine was never created, so nothing else owns the
            # connection; close it here instead of leaking it.
            conn.close()
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False,
ignore_duplicates: bool = None) -> int:
"""
兼容旧接口的通用插入方法:保留replace参数,同时支持新的ignore_duplicates
自动处理重复数据,对所有数据源通用,插入失败的数据会通过日志记录
"""
# 【兼容性处理】如果未指定ignore_duplicates,用replace参数推导
if ignore_duplicates is None:
ignore_duplicates = not replace # 旧逻辑中replace=True表示替换,即不忽略重复
if df.empty:
self.log.warning("尝试插入空的DataFrame", table=table_name)
return 0
conn = None
cursor = None
total_inserted = 0
total_duplicates = 0
total_failed = 0
failed_records = [] # 存储所有失败的记录
try:
# 1. 建立数据库连接
conn = self.get_connection()
cursor = conn.cursor()
self.log.debug(f"已建立连接,准备插入数据到 {table_name}")
# 2. 获取数据库表的实际列名
cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
columns_info = cursor.fetchall()
db_columns = [col[0] for col in columns_info]
self.log.debug(f"{table_name} 包含以下列:{db_columns}")
# 3. 数据预处理:统一处理空值
cleaned_df = df.replace(
[None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
None
).copy()
# 4. 字段匹配:只保留与数据库匹配的列
df_columns = cleaned_df.columns.tolist()
matched_columns = [col for col in df_columns if col in db_columns]
unmatched_columns = [col for col in df_columns if col not in db_columns]
if unmatched_columns:
self.log.warning(
f"{table_name} 中存在不匹配的列,已自动丢弃",
unmatched_columns=unmatched_columns,
count=len(unmatched_columns)
)
if not matched_columns:
self.log.warning(f"{table_name} 没有匹配的列,终止插入操作")
return 0
filtered_df = cleaned_df[matched_columns].copy()
total_to_insert = len(filtered_df)
self.log.debug(
f"{table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
)
# 5. 处理复杂类型(dict/list转JSON
for col in filtered_df.columns:
has_complex_type = filtered_df[col].apply(
lambda x: isinstance(x, (dict, list)) if x is not None else False
).any()
if has_complex_type:
self.log.debug(f"{table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
filtered_df.loc[:, col] = filtered_df[col].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
)
# 6. 构建通用插入SQL
columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
placeholders = ', '.join(['%s'] * len(filtered_df.columns))
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
self.log.trace(f"为表 {table_name} 生成的插入SQL{insert_sql}")
# 7. 逐条插入(确保能捕获单条重复错误)
records = filtered_df.to_dict('records')
indices = filtered_df.index.tolist()
for i, (record, idx) in enumerate(zip(records, indices)):
try:
data = tuple(record[col] for col in filtered_df.columns)
cursor.execute(insert_sql, data)
total_inserted += 1
if (i + 1) % 100 == 0:
self.log.trace(
f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
)
except MySQLError as e:
# 8. 捕获重复错误(MySQL错误码1062)
if e.args[0] == 1062:
total_duplicates += 1
short_record = {
k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
for k, v in record.items()
}
self.log.warning(
f"{table_name} 中跳过重复记录",
index=idx,
error_message=e.args[1],
record=short_record
)
# 记录重复的记录
failed_records.append({
'index': idx,
'type': 'duplicate',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
if not ignore_duplicates:
raise
else:
# 其他数据库错误
total_failed += 1
# 记录失败的记录详情
failed_records.append({
'index': idx,
'type': 'error',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
self.log.error(
f"{table_name} 插入记录失败",
index=idx,
error_code=e.args[0],
error_message=e.args[1],
record=record # 完整记录写入日志
)
if not ignore_duplicates:
raise
# 提交事务
conn.commit()
# 9. 插入结果统计,包括失败记录汇总
self.log.info(
f"{table_name} 插入结果汇总",
total_to_insert=total_to_insert,
total_inserted=total_inserted,
total_duplicates=total_duplicates,
total_failed=total_failed,
failed_records_count=len(failed_records)
)
# 单独记录所有失败的数据详情
if failed_records:
self.log.error(
f"{table_name} 插入失败记录详情",
failed_records_summary=[
{
'index': r['index'],
'type': r['type'],
'error_code': r['error_code'],
'error_message': r['error_message']
} for r in failed_records
],
# 完整记录可以作为调试信息单独记录,避免日志过大
detailed_failed_records=failed_records
)
return total_inserted
except Exception as e:
if conn:
conn.rollback()
self.log.error(f"{table_name} 批量插入失败",
error=str(e),
error_type=type(e).__name__,
table_name=table_name,
total_records=len(df) if not df.empty else 0,
exc_info=True)
# 记录事务回滚时的失败记录
if failed_records:
self.log.error(
f"{table_name} 事务回滚,已失败的记录",
failed_records=failed_records,
failed_count=len(failed_records)
)
raise
finally:
if cursor:
cursor.close()
if conn:
conn.close()
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
"""【新增辅助方法】获取表的主键(用于replace逻辑的去重)"""
try:
cursor.execute("""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
AND CONSTRAINT_NAME = 'PRIMARY'
""", (self.config['database'], table_name))
result = cursor.fetchone()
return result[0] if result else None
except Exception as e:
self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
return None
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
"""获取表的详细结构信息(原有逻辑完全保留,供其他方法调用)"""
sql = """
SELECT column_name, data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = %s \
AND table_name = %s \
"""
params = (self.config['database'], table_name)
try:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(sql, params)
result = cursor.fetchall()
# 强制转换为列表,避免游标类型导致的解析问题
result_list = list(result)
if not result_list:
self.log.error("未在表中找到任何列", =table_name)
return {}
schema = {}
for row in result_list:
# 确保正确提取字段名(兼容元组格式)
col_name = str(row[0]).strip() # 强制转为字符串并去空格
data_type = str(row[1]).strip()
max_length = row[2] if row[2] else None
schema[col_name] = {
'type': data_type,
'max_length': max_length
}
self.log.debug("成功获取表结构信息",
=table_name,
=list(schema.keys()))
return schema
finally:
cursor.close()
conn.close()
except Exception as e:
self.log.error("获取表详细信息失败",
=table_name,
error=str(e))
raise
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
"""数据校验与清洗(原有逻辑完全保留,供其他方法调用)"""
# 1. 字段过滤:只保留表中存在的字段
df_columns = df.columns.tolist()
table_columns = list(table_schema.keys())
valid_columns = [col for col in df_columns if col in table_columns]
invalid_columns = [col for col in df_columns if col not in table_columns]
if invalid_columns:
self.log.warning("丢弃表中不存在的无效列",
=table_name,
无效列=invalid_columns,
数量=len(invalid_columns))
cleaned_df = df[valid_columns].copy()
if cleaned_df.empty:
return cleaned_df
# 2. 处理每个字段的数据
for col in valid_columns:
col_info = table_schema[col]
data_type = col_info['type']
max_length = col_info['max_length']
# 2.1 处理空值
if cleaned_df[col].isnull().any():
# 根据字段类型设置默认值
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
cleaned_df[col].fillna(default_value, inplace=True)
self.log.debug("替换空值",
=table_name,
=col,
默认值=default_value,
数量=cleaned_df[col].isnull().sum())
# 2.2 处理字符串类型的超长字段
if data_type in ['varchar', 'char'] and max_length:
# 确保是字符串类型
cleaned_df[col] = cleaned_df[col].astype(str)
# 截断超长内容
too_long_mask = cleaned_df[col].str.len() > max_length
if too_long_mask.any():
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
self.log.warning("截断超长值",
=table_name,
=col,
最大长度=max_length,
数量=too_long_mask.sum())
# 2.3 处理日期时间类型
if data_type in ['datetime', 'timestamp']:
try:
# 尝试转换为datetime类型
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
except Exception as e:
self.log.warning("转换为datetime失败,使用当前时间替代",
=table_name,
=col,
错误=str(e))
# 转换失败的用当前时间替代
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
cleaned_df.loc[invalid_mask, col] = datetime.now()
return cleaned_df
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """Update rows of ``table_name`` from ``df`` with one batched UPDATE.

    Every DataFrame column that exists in the table (according to
    ``self._get_table_detailed_info``) and is not a key column is written;
    ``key_columns`` form the ``WHERE`` clause. Missing values (NaN/NaT/None)
    are sent as ``None`` so the driver writes SQL NULL.

    :param table_name: target table; interpolated into the SQL text as given
    :param df: new values, one UPDATE per row
    :param key_columns: column name or list of names identifying a row
    :return: ``cursor.rowcount`` after ``executemany``; 0 when ``df`` is
        empty or no updatable column exists
    :raises ValueError: if ``key_columns`` is empty
    :raises KeyError: if a key column is absent from ``df``
    :raises Exception: any error from the schema lookup or the database is
        logged and re-raised
    """
    if df.empty:
        self.log.warning("尝试使用空的DataFrame进行更新", 表名=table_name)
        return 0

    self.log.debug("准备从DataFrame更新表数据",
                   表名=table_name,
                   关键字列=key_columns,
                   行数=len(df))

    def quote(identifier) -> str:
        # Backtick-quote a column name so reserved words, spaces and
        # non-ASCII headers coming from a DataFrame are valid in MySQL.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        key_columns = [key_columns] if isinstance(key_columns, str) else list(key_columns)
        if not key_columns:
            # An empty key list would produce "... WHERE " with no condition.
            raise ValueError("key_columns must not be empty")

        # Resolve the schema before opening the update connection so two
        # connections are not held at the same time.
        table_info = self._get_table_detailed_info(table_name)
        set_columns = [col for col in df.columns
                       if col in table_info and col not in key_columns]
        if not set_columns:
            self.log.warning("没有可更新的列", 表名=table_name)
            return 0

        missing_keys = [col for col in key_columns if col not in df.columns]
        if missing_keys:
            raise KeyError(f"key columns missing from DataFrame: {missing_keys}")

        set_clause = ', '.join(f"{quote(col)}=%s" for col in set_columns)
        where_clause = ' AND '.join(f"{quote(col)}=%s" for col in key_columns)
        update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
        self.log.trace("生成的更新SQL", sql=update_sql)

        # Parameter tuples: SET values first, key values last, matching the
        # placeholder order. Object dtype lets missing values become None.
        values = df[set_columns + key_columns].astype(object)
        values = values.where(values.notna(), None)
        update_data = list(values.itertuples(index=False, name=None))

        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.executemany(update_sql, update_data)
                updated_rows = cursor.rowcount
                conn.commit()

        self.log.info("数据更新成功",
                      表名=table_name,
                      更新行数=updated_rows)
        return updated_rows
    except Exception as e:
        self.log.error("数据更新失败",
                       表名=table_name,
                       error=str(e),
                       exc_info=True)
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
    """Infer a MySQL column type for every column of ``df``.

    Exact dtype names are looked up first (``int64``, ``float64``,
    ``datetime64[ns]``, ``object``, ``bool``, ``category``). Dtypes outside
    that table are then classified by family, so other integer widths,
    ``float32``, nullable integer/boolean dtypes and datetimes with a
    different unit or a timezone get a matching SQL type. Anything still
    unrecognised maps to ``TEXT``.

    :param df: DataFrame whose ``dtypes`` are inspected
    :return: mapping of column name to SQL type string
    """
    type_mapping = {
        'int64': 'BIGINT',
        'float64': 'DOUBLE',
        'datetime64[ns]': 'DATETIME',
        'object': 'TEXT',
        'bool': 'TINYINT(1)',
        'category': 'VARCHAR(255)',
    }
    dtype_checks = pd.api.types

    def infer(dtype) -> str:
        exact = type_mapping.get(str(dtype))
        if exact is not None:
            return exact
        # Categorical is tested before bool because a categorical whose
        # categories are booleans also satisfies is_bool_dtype.
        if isinstance(dtype, pd.CategoricalDtype):
            return 'VARCHAR(255)'
        if dtype_checks.is_bool_dtype(dtype):
            return 'TINYINT(1)'
        if dtype_checks.is_unsigned_integer_dtype(dtype):
            return 'BIGINT UNSIGNED'
        if dtype_checks.is_integer_dtype(dtype):
            return 'BIGINT'
        if dtype_checks.is_float_dtype(dtype):
            return 'DOUBLE'
        if dtype_checks.is_datetime64_any_dtype(dtype):
            return 'DATETIME'
        return 'TEXT'

    sql_types = {col: infer(dtype) for col, dtype in df.dtypes.items()}
    self.log.debug("将DataFrame类型映射为SQL类型",
                   映射关系=sql_types)
    return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create ``table_name`` with one column per column of ``df``.

    Column types come from ``self.df_to_sql_type(df)``. Column names are
    backtick-quoted in the generated DDL; ``table_name`` is interpolated as
    given. Nothing is created when ``self.table_exists`` reports the table.

    :param table_name: name of the table to create
    :param df: DataFrame supplying column names and dtypes (rows are not inserted)
    :param primary_key: column name or list of names for the primary key;
        names that are not columns of ``df`` are ignored
    :return: ``True`` when the CREATE statement ran, ``False`` when the table
        already exists or creation failed (the error is logged, not raised)
    """
    if self.table_exists(table_name):
        self.log.warning("表已存在", 表名=table_name)
        return False

    self.log.debug("根据DataFrame结构创建新表",
                   表名=table_name,
                   列名=list(df.columns))

    def quote(identifier) -> str:
        # Backtick-quote a column name so reserved words, spaces and
        # non-ASCII headers coming from a DataFrame are valid in MySQL.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                # NOTE(review): MySQL rejects TEXT columns in a key without a
                # prefix length — confirm callers only use numeric/date keys.
                columns_sql.append(f"PRIMARY KEY ({', '.join(quote(col) for col in pk_columns)})")
                self.log.trace("设置主键",
                               表名=table_name,
                               主键=pk_columns)

        # Join outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        column_block = ',\n '.join(columns_sql)
        create_sql = f"CREATE TABLE {table_name} (\n {column_block}\n)"
        self.execute_sql(create_sql)
        self.log.info("表创建成功", 表名=table_name)
        return True
    except Exception as e:
        self.log.error("创建表失败",
                       表名=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Run a single SQL statement on a connection from ``self.get_connection``.

    :param sql: statement text, passed to ``cursor.execute`` with ``params``
    :param params: bound parameters, or ``None``
    :param fetch: when true, return ``cursor.fetchall()`` and do not commit;
        otherwise commit and return ``cursor.rowcount``
    :return: fetched rows, or the affected-row count
    :raises Exception: any failure is logged with the statement and its
        parameters, then re-raised unchanged
    """
    try:
        with self.get_connection() as connection, connection.cursor() as cur:
            # On non-Windows hosts the session's statement time limit is
            # raised before the real statement runs.
            if platform.system() != 'Windows':
                cur.execute("SET SESSION max_execution_time=600000")

            cur.execute(sql, params)

            if not fetch:
                row_count = cur.rowcount
                connection.commit()  # commit immediately
                self.log.debug("更新执行完成", 受影响行数=row_count)
                return row_count

            rows = cur.fetchall()
            self.log.debug("查询执行完成", 行数=len(rows))
            return rows
    except Exception as exc:
        self.log.error("SQL执行失败",
                       sql=sql,
                       params=params,
                       error=str(exc),
                       error_type=type(exc).__name__,
                       exc_info=True)
        raise
def table_exists(self, table_name: str) -> bool:
    """Report whether ``table_name`` exists in the configured database.

    Looks the table up in ``information_schema.tables`` for the schema named
    by ``self.config['database']``. The count is read from the first result
    row whether the cursor yields tuples or dicts.

    :param table_name: table to look for
    :return: ``True`` if at least one matching table is found; ``False`` if
        none is found or the lookup fails (the failure is logged, not raised)
    """
    sql = """
        SELECT COUNT(*) as count
        FROM `information_schema`.`tables`
        WHERE `table_schema` = %s \
        AND `table_name` = %s \
        """
    params = (self.config['database'], table_name)
    try:
        result = self.execute_sql(sql, params, fetch=True)
        first_row = result[0]
        # Dict-style cursors key the row by column alias; tuple-style
        # cursors are positional. Either way the row has a single value.
        if isinstance(first_row, dict):
            count = next(iter(first_row.values()))
        else:
            count = first_row[0]
        exists = bool(count > 0)
        self.log.debug("检查表是否存在",
                       表名=table_name,
                       存在=exists)
        return exists
    except Exception as e:
        # Best-effort check: keep returning False, but leave a trace so a
        # failed lookup is not mistaken for a missing table.
        self.log.warning("检查表是否存在失败",
                         表名=table_name,
                         error=str(e))
        return False
def drop_table(self, table_name: str) -> bool:
    """Drop ``table_name`` if ``self.table_exists`` reports it.

    :param table_name: table to drop; interpolated into the SQL text as given
    :return: ``True`` when the DROP statement ran; ``False`` when the table
        is not found or the statement failed (the error is logged, not raised)
    """
    if not self.table_exists(table_name):
        self.log.warning("表不存在", 表名=table_name)
        return False

    try:
        self.execute_sql(f"DROP TABLE {table_name}")
        self.log.info("表删除成功", 表名=table_name)
        return True
    except Exception as e:
        self.log.error("删除表失败",
                       表名=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def validate_connection(self) -> bool:
    """Check that a connection can be opened and answers ``SELECT 1``.

    The single result value is read whether the cursor yields a tuple or a
    dict row.

    :return: ``True`` when the query returns 1; ``False`` on any other value
        or on any error (the error is logged, not raised)
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                row = cursor.fetchone()
        value = next(iter(row.values())) if isinstance(row, dict) else row[0]
        return bool(value == 1)
    except Exception as e:
        # Still a best-effort probe, but record why it failed so that wrong
        # credentials and an unreachable host can be told apart.
        self.log.warning("数据库连接验证失败", error=str(e))
        return False
# 平台特定的默认配置(原有逻辑完全保留)
def get_default_config():
    """Build the default connection settings for the current platform.

    The host, port, user, password and database are the same everywhere.
    Windows gets shorter timeouts than other systems, and macOS
    (``Darwin``) additionally gets an ``ssl`` entry pointing at a CA file.

    :return: a new dict of connection settings on every call
    """
    system_name = platform.system()
    on_windows = system_name == 'Windows'

    config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    # Timeouts in the same key order as before: connect, read, write.
    config['connect_timeout'] = 10 if on_windows else 15
    config['read_timeout'] = 30 if on_windows else 60
    config['write_timeout'] = 30 if on_windows else 60

    if system_name == 'Darwin':
        config['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}

    return config
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, then print the
    # server version when the connection check passes.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")
        version_df = db.query_to_df("SELECT VERSION() as version")
        print(f"数据库版本: {version_df['version'].iloc[0]}")