Files
2025-10-17 17:59:28 +08:00

723 lines
28 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
import platform
import pandas as pd
import pymysql
import json
import numpy as np
from pymysql import cursors
from pymysql.err import MySQLError
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
import threading
from datetime import datetime
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """Cross-platform MySQL helper built on PyMySQL and pandas.

    The class is a process-wide singleton: every ``MySQLAgent(...)`` call
    returns the same object, and only the first successful ``__init__``
    stores a configuration (later configs are ignored).  No connection
    pool is kept; each public method opens its own short-lived connection.
    Connections are opened with ``autocommit=True``, so the explicit
    ``commit()``/``rollback()`` calls below do not provide multi-statement
    transactions.
    """

    _instance = None
    _lock = threading.Lock()

    # MySQL error code raised on a unique/primary key violation.
    _DUPLICATE_ENTRY = 1062
    # information_schema DATA_TYPE values treated as text when filling nulls.
    _STRING_TYPES = ('varchar', 'char', 'text')

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """Validate and store the connection settings.

        Args:
            config: Mapping with the required keys ``host``, ``port``,
                ``user``, ``password`` and ``database``.  Optional keys:
                ``charset``, ``connect_timeout``, ``read_timeout``,
                ``write_timeout`` and ``ssl``.

        Raises:
            ValueError: If any required key is missing.
        """
        # The singleton is already configured; keep the first configuration.
        if hasattr(self, 'config') and self.config:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        if not all(key in config for key in required_keys):
            # Never write the password to the log.
            safe_config = {
                key: ('***' if key == 'password' else value)
                for key, value in config.items()
            }
            log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{safe_config}")
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl'),
        }

        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Wrap *name* in backticks, doubling any embedded backtick."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _null_to_none(value: Any) -> Any:
        """Map scalar missing values (NaN, NaT, pd.NA) to ``None``."""
        if value is None:
            return None
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    @staticmethod
    def _error_details(exc: BaseException) -> Tuple[Any, str]:
        """Return ``(code, message)`` from a driver error without assuming two args."""
        code = exc.args[0] if exc.args else None
        message = exc.args[1] if len(exc.args) > 1 else str(exc)
        return code, message

    @staticmethod
    def _abbreviate(value: Any, limit: int = 100) -> Any:
        """Shorten long str/dict/list values for log output; leave others as is."""
        if isinstance(value, (str, dict, list)):
            text = str(value)
            if len(text) > limit:
                return text[:limit] + '...'
        return value

    @staticmethod
    def _sql_type_for_dtype(dtype: Any) -> str:
        """Pick a MySQL column type for a single pandas dtype."""
        types = pd.api.types
        # Categorical first so it always maps to VARCHAR regardless of categories.
        if isinstance(dtype, pd.CategoricalDtype):
            return 'VARCHAR(255)'
        if types.is_bool_dtype(dtype):
            return 'TINYINT(1)'
        if types.is_unsigned_integer_dtype(dtype):
            return 'BIGINT UNSIGNED'
        if types.is_integer_dtype(dtype):
            return 'BIGINT'
        if types.is_float_dtype(dtype):
            return 'DOUBLE'
        if types.is_datetime64_any_dtype(dtype):
            return 'DATETIME'
        return 'TEXT'

    def _open_connection(self):
        """Open a PyMySQL connection and make sure ``character_set_name`` exists."""
        conn = pymysql.connect(**self.config)
        # Some consumers call conn.character_set_name(); provide it when the
        # installed driver does not.
        if not hasattr(conn, 'character_set_name'):
            def _character_set_name():
                return self.config.get('charset', 'utf8mb4')
            conn.character_set_name = _character_set_name
        return conn

    # ------------------------------------------------------------------
    # Connections
    # ------------------------------------------------------------------
    def get_connection(self) -> "pymysql.connections.Connection":
        """Open and return a new database connection.

        On Windows a failure whose message contains ``"timed out"`` is
        retried through ``_retry_connection``; every other failure is
        logged and re-raised.  The caller owns the returned connection and
        must close it.
        """
        try:
            conn = self._open_connection()
            # On macOS with SSL configured, ping once right after connecting.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)
            self.log.trace("获取数据库连接成功")
            return conn
        except Exception as e:
            error_msg = str(e)
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows连接超时,正在重试...")
                return self._retry_connection()
            self.log.error("连接失败",
                           error=error_msg,
                           error_type=type(e).__name__,
                           host=self.config.get('host'),
                           port=self.config.get('port'),
                           database=self.config.get('database'),
                           exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> Optional[Any]:
        """Try to connect up to *max_retries* times, one second apart.

        Returns the connection on success and re-raises the last error when
        every attempt fails.  Returns ``None`` only if *max_retries* is < 1.
        """
        import time

        for attempt in range(max_retries):
            try:
                conn = self._open_connection()
                self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
        return None

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
        """Run a SELECT and return the result as a DataFrame.

        Args:
            sql: SQL text, using the driver's ``%s`` / ``%(name)s`` placeholders.
            params: Values bound to the placeholders.
            parse_dates: Passed straight to ``pandas.read_sql``.
            is_print: When true, log the number of rows returned.

        Raises:
            Exception: Any connection or query error, after it is logged.
        """
        engine = None
        conn = None
        try:
            self.log.debug("执行SQL查询", sql=sql)
            conn = self.get_connection()

            # Wrap the single PyMySQL connection in a SQLAlchemy engine so
            # pandas can use it.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool
            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={'charset': self.config.get('charset', 'utf8mb4')}
            )

            df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
            if is_print:
                self.log.info("查询执行成功", 行数=len(df))
            return df
        except Exception as e:
            self.log.error("SQL查询失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise
        finally:
            if engine is not None:
                engine.dispose()
            # If the engine never adopted (or did not close) the connection,
            # close it here so it is not leaked.
            if conn is not None and getattr(conn, 'open', False):
                conn.close()

    # ------------------------------------------------------------------
    # Inserts
    # ------------------------------------------------------------------
    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False,
                       ignore_duplicates: Optional[bool] = None) -> int:
        """Insert the rows of *df* into *table_name*, one statement per row.

        Columns of *df* that do not exist in the table are dropped with a
        warning; dict/list cells are serialised to JSON text.

        Args:
            table_name: Target table.
            df: Rows to insert.
            chunk_size: Accepted for backward compatibility; not used.
            replace: Legacy flag.  It only feeds the default of
                *ignore_duplicates* (``not replace``); no REPLACE statement
                is issued.
            ignore_duplicates: When true, rows that fail (duplicate key or
                any other MySQL error) are logged and skipped.  When false,
                the first failing row aborts the call.

        Returns:
            Number of rows inserted.  ``0`` for an empty frame or when no
            DataFrame column matches the table.

        Raises:
            Exception: The first row error when *ignore_duplicates* is
                false, or any connection-level failure.
        """
        if ignore_duplicates is None:
            ignore_duplicates = not replace

        if df.empty:
            self.log.warning("尝试插入空的DataFrame", table=table_name)
            return 0

        conn = None
        cursor = None
        total_inserted = 0
        total_duplicates = 0
        total_failed = 0
        failed_records = []  # every row that could not be inserted

        try:
            # 1. Connect.
            conn = self.get_connection()
            cursor = conn.cursor()
            self.log.debug(f"已建立连接,准备插入数据到 {table_name}")

            # 2. Read the table's real column names.
            quoted_table = self._quote_identifier(table_name)
            cursor.execute(f"SHOW COLUMNS FROM {quoted_table}")
            db_columns = [col[0] for col in cursor.fetchall()]
            self.log.debug(f"{table_name} 包含以下列:{db_columns}")

            # 3. Normalise the various spellings of "missing" to None.
            cleaned_df = df.replace(
                [None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
                None
            ).copy()

            # 4. Keep only the columns the table actually has.
            known_columns = set(db_columns)
            df_columns = cleaned_df.columns.tolist()
            matched_columns = [col for col in df_columns if col in known_columns]
            unmatched_columns = [col for col in df_columns if col not in known_columns]
            if unmatched_columns:
                self.log.warning(
                    f"{table_name} 中存在不匹配的列,已自动丢弃",
                    unmatched_columns=unmatched_columns,
                    count=len(unmatched_columns)
                )
            if not matched_columns:
                self.log.warning(f"{table_name} 没有匹配的列,终止插入操作")
                return 0

            filtered_df = cleaned_df[matched_columns].copy()
            total_to_insert = len(filtered_df)
            self.log.debug(
                f"{table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
            )

            # 5. Serialise dict/list cells to JSON.  Once a column contains
            #    any such cell, every non-null value in it is dumped.
            for col in filtered_df.columns:
                has_complex_type = filtered_df[col].apply(
                    lambda x: isinstance(x, (dict, list))
                ).any()
                if has_complex_type:
                    self.log.debug(f"{table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
                    filtered_df[col] = filtered_df[col].apply(
                        lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
                    )

            # 6. Build the parameterised INSERT statement.
            insert_columns = list(filtered_df.columns)
            columns_str = ', '.join(self._quote_identifier(col) for col in insert_columns)
            placeholders = ', '.join(['%s'] * len(insert_columns))
            insert_sql = f"INSERT INTO {quoted_table} ({columns_str}) VALUES ({placeholders})"
            self.log.trace(f"为表 {table_name} 生成的插入SQL: {insert_sql}")

            # 7. Insert row by row so a single bad row can be identified.
            records = filtered_df.to_dict('records')
            indices = filtered_df.index.tolist()
            for i, (record, idx) in enumerate(zip(records, indices)):
                try:
                    data = tuple(record[col] for col in insert_columns)
                    cursor.execute(insert_sql, data)
                    total_inserted += 1
                    if (i + 1) % 100 == 0:
                        self.log.trace(
                            f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
                        )
                except MySQLError as e:
                    error_code, error_message = self._error_details(e)
                    if error_code == self._DUPLICATE_ENTRY:
                        # 8. Duplicate key: log an abbreviated copy of the row.
                        total_duplicates += 1
                        short_record = {
                            k: self._abbreviate(v) for k, v in record.items()
                        }
                        self.log.warning(
                            f"{table_name} 中跳过重复记录",
                            index=idx,
                            error_message=error_message,
                            record=short_record
                        )
                        failure_type = 'duplicate'
                    else:
                        # Any other database error: log the full row.
                        total_failed += 1
                        self.log.error(
                            f"{table_name} 插入记录失败",
                            index=idx,
                            error_code=error_code,
                            error_message=error_message,
                            record=record
                        )
                        failure_type = 'error'

                    failed_records.append({
                        'index': idx,
                        'type': failure_type,
                        'error_code': error_code,
                        'error_message': error_message,
                        'record': record
                    })
                    if not ignore_duplicates:
                        raise

            conn.commit()

            # 9. Summary, followed by the details of every failed row.
            self.log.info(
                f"{table_name} 插入结果汇总",
                total_to_insert=total_to_insert,
                total_inserted=total_inserted,
                total_duplicates=total_duplicates,
                total_failed=total_failed,
                failed_records_count=len(failed_records)
            )
            if failed_records:
                self.log.error(
                    f"{table_name} 插入失败记录详情",
                    failed_records_summary=[
                        {
                            'index': r['index'],
                            'type': r['type'],
                            'error_code': r['error_code'],
                            'error_message': r['error_message']
                        } for r in failed_records
                    ],
                    detailed_failed_records=failed_records
                )
            return total_inserted
        except Exception as e:
            if conn:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # Do not let a failed rollback hide the original error.
                    self.log.warning(f"{table_name} 回滚失败", error=str(rollback_error))
            self.log.error(f"{table_name} 批量插入失败",
                           error=str(e),
                           error_type=type(e).__name__,
                           table_name=table_name,
                           total_records=len(df),
                           exc_info=True)
            if failed_records:
                self.log.error(
                    f"{table_name} 事务回滚,已失败的记录",
                    failed_records=failed_records,
                    failed_count=len(failed_records)
                )
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    # ------------------------------------------------------------------
    # Schema introspection
    # ------------------------------------------------------------------
    def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
        """Return one primary-key column of *table_name*, or ``None``.

        Only the first row returned by the query is used, so for a composite
        key a single column name is returned.  Lookup errors are logged and
        reported as ``None``.
        """
        try:
            cursor.execute("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                WHERE TABLE_SCHEMA = %s
                  AND TABLE_NAME = %s
                  AND CONSTRAINT_NAME = 'PRIMARY'
            """, (self.config['database'], table_name))
            result = cursor.fetchone()
            return result[0] if result else None
        except Exception as e:
            self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
            return None

    def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
        """Describe the columns of *table_name*.

        Returns:
            ``{column_name: {'type': data_type, 'max_length': int | None}}``
            read from ``information_schema.columns``; an empty dict when the
            table has no columns (or does not exist).

        Raises:
            Exception: Any connection or query error, after it is logged.
        """
        sql = (
            "SELECT column_name, data_type, character_maximum_length "
            "FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s"
        )
        params = (self.config['database'], table_name)
        try:
            conn = self.get_connection()
            cursor = None
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                # Materialise the rows so the result does not depend on the
                # cursor's container type.
                result_list = list(cursor.fetchall())
                if not result_list:
                    self.log.error("未在表中找到任何列", table=table_name)
                    return {}

                schema = {}
                for row in result_list:
                    col_name = str(row[0]).strip()
                    data_type = str(row[1]).strip()
                    max_length = row[2] if row[2] else None
                    schema[col_name] = {
                        'type': data_type,
                        'max_length': max_length
                    }
                self.log.debug("成功获取表结构信息",
                               table=table_name,
                               columns=list(schema.keys()))
                return schema
            finally:
                # cursor stays None when conn.cursor() itself failed.
                if cursor is not None:
                    cursor.close()
                conn.close()
        except Exception as e:
            self.log.error("获取表详细信息失败",
                           table=table_name,
                           error=str(e))
            raise

    def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                                 table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """Return a copy of *df* restricted and coerced to *table_schema*.

        Steps, per column that exists in the schema:
          1. Nulls in ``varchar``/``char``/``text`` columns become ``''``;
             nulls in other columns are left untouched.
          2. ``varchar``/``char`` values are cast to ``str`` and truncated to
             the column's ``max_length``.
          3. ``datetime``/``timestamp`` columns are parsed with
             ``pandas.to_datetime``; if strict parsing fails, values that
             cannot be parsed are replaced by the current time.

        Args:
            df: Input data; not modified.
            table_name: Used in log records only.
            table_schema: Mapping as returned by ``_get_table_detailed_info``.
        """
        # 1. Drop columns the table does not have.
        df_columns = df.columns.tolist()
        valid_columns = [col for col in df_columns if col in table_schema]
        invalid_columns = [col for col in df_columns if col not in table_schema]
        if invalid_columns:
            self.log.warning("丢弃表中不存在的无效列",
                             table=table_name,
                             无效列=invalid_columns,
                             数量=len(invalid_columns))

        cleaned_df = df[valid_columns].copy()
        if cleaned_df.empty:
            return cleaned_df

        # 2. Clean each remaining column according to its SQL type.
        for col in valid_columns:
            col_info = table_schema[col]
            data_type = col_info['type']
            max_length = col_info['max_length']

            # 2.1 Nulls.  Count before filling so the logged number is real.
            #     fillna(None) is rejected by pandas, so non-string columns
            #     simply keep their nulls.
            null_count = int(cleaned_df[col].isnull().sum())
            if null_count and data_type in self._STRING_TYPES:
                default_value = ''
                cleaned_df[col] = cleaned_df[col].fillna(default_value)
                self.log.debug("替换空值",
                               table=table_name,
                               column=col,
                               默认值=default_value,
                               数量=null_count)

            # 2.2 Over-long strings.
            if data_type in ('varchar', 'char') and max_length:
                as_text = cleaned_df[col].astype(str)
                cleaned_df[col] = as_text
                too_long_mask = as_text.str.len() > max_length
                if too_long_mask.any():
                    cleaned_df.loc[too_long_mask, col] = (
                        as_text[too_long_mask].str.slice(0, max_length)
                    )
                    self.log.warning("截断超长值",
                                     table=table_name,
                                     column=col,
                                     最大长度=max_length,
                                     数量=int(too_long_mask.sum()))

            # 2.3 Date/time columns.
            if data_type in ('datetime', 'timestamp'):
                try:
                    cleaned_df[col] = pd.to_datetime(cleaned_df[col])
                except Exception as e:
                    self.log.warning("转换为datetime失败,使用当前时间替代",
                                     table=table_name,
                                     column=col,
                                     错误=str(e))
                    original = cleaned_df[col]
                    coerced = pd.to_datetime(original, errors='coerce')
                    # Only values that were present but unparseable get the
                    # fallback; genuine nulls stay null.
                    unparseable = coerced.isna() & original.notna()
                    coerced.loc[unparseable] = datetime.now()
                    cleaned_df[col] = coerced

        return cleaned_df

    # ------------------------------------------------------------------
    # Updates
    # ------------------------------------------------------------------
    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """Update existing rows of *table_name* from *df*.

        One ``UPDATE ... WHERE key = %s`` is executed per DataFrame row.
        DataFrame columns that are not in the table are ignored; missing
        values (NaN/NaT/pd.NA) are written as SQL NULL.

        Args:
            table_name: Target table.
            df: New values; must contain every key column.
            key_columns: Column name(s) identifying the row to update.

        Returns:
            The cursor's reported row count, or ``0`` when *df* is empty or
            has no updatable column.

        Raises:
            ValueError: If *key_columns* is empty.
            KeyError: If a key column is missing from *df*.
            Exception: Any database error, after it is logged.
        """
        if df.empty:
            self.log.warning("尝试使用空的DataFrame进行更新", table=table_name)
            return 0

        self.log.debug("准备从DataFrame更新表数据",
                       table=table_name,
                       关键字列=key_columns,
                       行数=len(df))
        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]
            key_columns = list(key_columns)
            if not key_columns:
                raise ValueError("key_columns must name at least one column")
            missing_keys = [col for col in key_columns if col not in df.columns]
            if missing_keys:
                raise KeyError(f"key columns missing from DataFrame: {missing_keys}")

            # Look the schema up before opening the update connection so two
            # connections are never held at once.
            table_info = self._get_table_detailed_info(table_name)
            columns = [col for col in df.columns if col in table_info]
            set_columns = [col for col in columns if col not in key_columns]
            if not set_columns:
                self.log.warning("没有可更新的列", table=table_name)
                return 0

            set_clause = ', '.join(
                f"{self._quote_identifier(col)}=%s" for col in set_columns
            )
            where_clause = ' AND '.join(
                f"{self._quote_identifier(col)}=%s" for col in key_columns
            )
            # The table name is interpolated as given (it is not quoted here).
            update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
            self.log.trace("生成的更新SQL", sql=update_sql)

            # itertuples keeps each column's own dtype (iterrows would upcast
            # a mixed int/float row to float).
            ordered_columns = set_columns + key_columns
            update_data = [
                tuple(self._null_to_none(value) for value in row)
                for row in df[ordered_columns].itertuples(index=False, name=None)
            ]

            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount
                    conn.commit()

            self.log.info("数据更新成功",
                          table=table_name,
                          更新行数=total_updated)
            return total_updated
        except Exception as e:
            self.log.error("数据更新失败",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    # ------------------------------------------------------------------
    # DDL helpers
    # ------------------------------------------------------------------
    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """Map every column of *df* to a MySQL column type.

        Mapping: categorical -> ``VARCHAR(255)``, bool -> ``TINYINT(1)``,
        unsigned integers -> ``BIGINT UNSIGNED``, other integers ->
        ``BIGINT``, floats -> ``DOUBLE``, any datetime64 -> ``DATETIME``,
        everything else -> ``TEXT``.
        """
        sql_types = {
            col: self._sql_type_for_dtype(dtype)
            for col, dtype in df.dtypes.items()
        }
        self.log.debug("将DataFrame类型映射为SQL类型",
                       映射关系=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """Create *table_name* with one column per DataFrame column.

        Args:
            table_name: Name of the table to create.
            df: Frame whose columns and dtypes define the table.
            primary_key: Optional column name(s); names not present in *df*
                are ignored.

        Returns:
            ``True`` when the table was created; ``False`` when it already
            exists or the CREATE statement failed (the error is logged, not
            raised).
        """
        if self.table_exists(table_name):
            self.log.warning("表已存在", table=table_name)
            return False

        self.log.debug("根据DataFrame结构创建新表",
                       table=table_name,
                       columns=list(df.columns))
        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = [
                f"{self._quote_identifier(col)} {sql_type}"
                for col, sql_type in sql_types.items()
            ]

            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]
                pk_columns = [col for col in primary_key if col in sql_types]
                if pk_columns:
                    quoted_pk = ', '.join(self._quote_identifier(col) for col in pk_columns)
                    columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                    self.log.trace("设置主键",
                                   table=table_name,
                                   主键=pk_columns)

            # Join outside the f-string: a backslash inside an f-string
            # expression is a syntax error before Python 3.12.
            column_block = ',\n    '.join(columns_sql)
            create_sql = f"CREATE TABLE {table_name} (\n    {column_block}\n)"
            self.execute_sql(create_sql)
            self.log.info("表创建成功", table=table_name)
            return True
        except Exception as e:
            self.log.error("创建表失败",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one SQL statement on a fresh connection.

        Args:
            sql: Statement text with driver placeholders.
            params: Values bound to the placeholders.
            fetch: When true, return ``cursor.fetchall()``; otherwise commit
                and return the affected row count.

        Returns:
            The fetched rows (their shape depends on the cursor class; rows
            are indexed positionally elsewhere in this class) or the
            affected row count.

        Raises:
            Exception: Any database error, after it is logged.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # On non-Windows platforms raise the per-statement time
                    # limit to 600000 ms for this session.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION max_execution_time=600000")
                    cursor.execute(sql, params)

                    if fetch:
                        result = cursor.fetchall()
                        self.log.debug("查询执行完成", 行数=len(result))
                        return result

                    affected_rows = cursor.rowcount
                    conn.commit()
                    self.log.debug("更新执行完成", 受影响行数=affected_rows)
                    return affected_rows
        except Exception as e:
            self.log.error("SQL执行失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise

    def table_exists(self, table_name: str) -> bool:
        """Return whether *table_name* exists in the configured database.

        Any error during the lookup is reported as ``False`` (``execute_sql``
        has already logged it).
        """
        sql = (
            "SELECT COUNT(*) AS count "
            "FROM `information_schema`.`tables` "
            "WHERE `table_schema` = %s AND `table_name` = %s"
        )
        params = (self.config['database'], table_name)
        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0][0] > 0  # rows are positional tuples
            self.log.debug("检查表是否存在",
                           table=table_name,
                           存在=exists)
            return exists
        except Exception:
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop *table_name*.

        Returns:
            ``True`` when the table was dropped; ``False`` when it does not
            exist or the DROP statement failed (the error is logged).
        """
        if not self.table_exists(table_name):
            self.log.warning("表不存在", table=table_name)
            return False
        try:
            self.execute_sql(f"DROP TABLE {table_name}")
            self.log.info("表删除成功", table=table_name)
            return True
        except Exception as e:
            self.log.error("删除表失败",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def validate_connection(self) -> bool:
        """Return ``True`` when ``SELECT 1`` succeeds on a new connection."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    return cursor.fetchone()[0] == 1
        except Exception:
            return False
# 平台特定的默认配置(原有逻辑完全保留)
def get_default_config():
    """Build the default connection settings for the current platform.

    Every platform shares the same host, port, user, password and database.
    Windows uses 10/30/30 second connect/read/write timeouts; all other
    platforms use 15/60/60.  macOS additionally gets an ``ssl`` entry that
    points at a CA bundle path.  A new dict is returned on every call.
    """
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    system_name = platform.system()
    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and everything else share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, check that the
    # server answers, then print its version string.
    agent = MySQLAgent(get_default_config())
    if not agent.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")
        version_frame = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_frame['version'].iloc[0]
        print(f"数据库版本: {server_version}")