837 lines
33 KiB
Python
837 lines
33 KiB
Python
import os
|
||
import sys
|
||
import platform
|
||
import pandas as pd
|
||
import pymysql
|
||
import json
|
||
import numpy as np
|
||
from pymysql import cursors
|
||
from pymysql.err import MySQLError
|
||
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
|
||
import threading
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
# 导入日志系统
|
||
from utils.logger import log
|
||
|
||
|
||
class MySQLAgent:
    """
    Cross-platform MySQL helper built on pymysql and pandas.

    Intended to work on Windows, macOS and Linux. Connection settings are
    passed in by the caller. No connection pool is used: every operation
    opens its own connection (configured with ``autocommit=True``) and
    releases it when done.

    The class is a process-wide singleton: ``__new__`` always returns the
    same object, and once a configuration has been stored by ``__init__``
    later ``MySQLAgent(...)`` calls ignore their ``config`` argument.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        # Check, take the lock, then check again so concurrent first calls
        # cannot create two instances.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """
        Validate ``config`` and store the connection settings.

        ``__init__`` runs on every ``MySQLAgent(...)`` call; if a
        configuration is already stored it returns immediately.

        Args:
            config: Must contain ``host``, ``port``, ``user``, ``password``
                and ``database``. Optional keys: ``charset`` (default
                ``utf8mb4``), ``connect_timeout`` (10), ``read_timeout``
                (30), ``write_timeout`` (30) and ``ssl``.

        Raises:
            ValueError: If a required key is missing.
        """
        if hasattr(self, 'config') and self.config:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            # Log only the missing key names: dumping the whole config would
            # write the database password into the log.
            log.warning(f"数据库配置缺少必要参数,缺少: {missing_keys}")
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl')
        }

        # Logger bound with the current platform name.
        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Wrap a table/column/index name in backticks, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _error_details(error: BaseException) -> Tuple[Any, Any]:
        """Return ``(code, message)`` from an exception without assuming ``args`` has two items."""
        args = getattr(error, 'args', ())
        code = args[0] if args else None
        message = args[1] if len(args) > 1 else str(error)
        return code, message

    @staticmethod
    def _to_db_value(value: Any) -> Any:
        """Convert a DataFrame cell to a driver-friendly value (null-like -> None, numpy scalar -> Python scalar)."""
        if value is None or isinstance(value, (dict, list, tuple, set)):
            return value
        try:
            if pd.isna(value):
                return None
        except (TypeError, ValueError):
            # pd.isna returned something that is not a single boolean
            # (e.g. an array-like cell); leave the value untouched.
            return value
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            return value.item()
        return value

    def get_connection(self) -> "pymysql.connections.Connection":
        """
        Open a new pymysql connection using the stored configuration.

        On Windows, an error whose text contains "timed out" triggers
        ``_retry_connection``; any other failure is logged and re-raised.

        Returns:
            An open pymysql connection. The caller is responsible for
            closing it.
        """
        try:
            conn = pymysql.connect(**self.config)

            # Provide a character_set_name() method when the connection
            # object does not already have one.
            if not hasattr(conn, 'character_set_name'):
                def _character_set_name():
                    return self.config.get('charset', 'utf8mb4')

                conn.character_set_name = _character_set_name

            # On macOS with SSL configured, ping once right after connecting.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)

            self.log.trace("获取数据库连接成功")
            return conn

        except Exception as e:
            error_msg = str(e)
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows连接超时,正在重试...")
                return self._retry_connection()

            self.log.error("连接失败",
                           error=error_msg,
                           error_type=type(e).__name__,
                           host=self.config.get('host'),
                           port=self.config.get('port'),
                           database=self.config.get('database'),
                           exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> Optional[Any]:
        """
        Retry ``pymysql.connect`` up to ``max_retries`` times, one second apart.

        Returns:
            The connection from the first successful attempt, or ``None``
            when ``max_retries`` is not positive.

        Raises:
            Exception: The error from the last attempt when all attempts fail.
        """
        import time

        for attempt in range(max_retries):
            try:
                conn = pymysql.connect(**self.config)
                self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
        return None

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True,
                    is_print: bool = True) -> pd.DataFrame:
        """
        Run a query and return the result as a DataFrame.

        A fresh connection is wrapped in a throw-away SQLAlchemy engine and
        handed to ``pandas.read_sql``.

        Args:
            sql: SQL text to execute.
            params: Parameters forwarded to ``pandas.read_sql``.
            parse_dates: Forwarded to ``pandas.read_sql``.
            is_print: When true, log the row count on success.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        conn = None
        engine = None
        try:
            self.log.debug("执行SQL查询", sql=sql)

            conn = self.get_connection()

            # Build an engine that always hands out this one connection.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool
            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={'charset': self.config.get('charset', 'utf8mb4')}
            )

            df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
            if is_print:
                self.log.info("查询执行成功", 行数=len(df))

            return df

        except Exception as e:
            self.log.error("SQL查询失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise
        finally:
            if engine is not None:
                engine.dispose()
            # If the engine never took ownership of the connection (for
            # example the sqlalchemy import failed), close it here so it
            # does not leak.
            if conn is not None and getattr(conn, 'open', False):
                conn.close()

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False,
                       ignore_duplicates: Optional[bool] = None) -> int:
        """
        Insert the rows of ``df`` into ``table_name`` one row at a time.

        Only DataFrame columns that exist in the table are inserted; other
        columns are dropped with a warning. Plain ``INSERT INTO`` is used, so
        existing rows are never overwritten or deleted.

        Args:
            table_name: Target table.
            df: Data to insert.
            chunk_size: Accepted for backward compatibility; not used by the
                row-by-row implementation.
            replace: Legacy flag. Only consulted when ``ignore_duplicates``
                is ``None``, in which case ``ignore_duplicates = not replace``.
            ignore_duplicates: When true, rows that fail (duplicate key, MySQL
                error 1062, or any other MySQL error) are logged and skipped.
                When false, the first such error is re-raised.

        Returns:
            Number of rows inserted successfully.

        Raises:
            Exception: Connection failures, and row errors when
                ``ignore_duplicates`` is false.

        Note:
            Connections are opened with ``autocommit=True``, so rows inserted
            before a failure are presumably already committed and the
            ``rollback()`` in the error path cannot undo them — confirm
            against the server configuration before relying on atomicity.
        """
        # Compatibility: derive ignore_duplicates from the legacy flag.
        if ignore_duplicates is None:
            ignore_duplicates = not replace

        if df.empty:
            self.log.warning("尝试插入空的DataFrame", table=table_name)
            return 0

        conn = None
        cursor = None
        total_inserted = 0
        total_duplicates = 0
        total_failed = 0
        failed_records = []  # every skipped or failed row, with its error
        quoted_table = self._quote_identifier(table_name)

        try:
            # 1. Connect.
            conn = self.get_connection()
            cursor = conn.cursor()
            self.log.debug(f"已建立连接,准备插入数据到 {table_name}")

            # 2. Read the table's real column names.
            cursor.execute(f"SHOW COLUMNS FROM {quoted_table}")
            columns_info = cursor.fetchall()
            db_columns = [col[0] for col in columns_info]
            self.log.debug(f"表 {table_name} 包含以下列:{db_columns}")

            # 3. Normalise null-like values to None.
            cleaned_df = df.replace(
                [None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
                None
            ).copy()

            # 4. Keep only the columns the table actually has.
            db_column_set = set(db_columns)
            df_columns = cleaned_df.columns.tolist()
            matched_columns = [col for col in df_columns if col in db_column_set]
            unmatched_columns = [col for col in df_columns if col not in db_column_set]

            if unmatched_columns:
                self.log.warning(
                    f"表 {table_name} 中存在不匹配的列,已自动丢弃",
                    unmatched_columns=unmatched_columns,
                    count=len(unmatched_columns)
                )

            if not matched_columns:
                self.log.warning(f"表 {table_name} 没有匹配的列,终止插入操作")
                return 0

            filtered_df = cleaned_df[matched_columns].copy()
            total_to_insert = len(filtered_df)
            self.log.debug(
                f"表 {table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
            )

            # 5. Columns holding dict/list values are serialised to JSON
            #    (every non-null value in such a column is dumped).
            for col in filtered_df.columns:
                has_complex_type = filtered_df[col].apply(
                    lambda x: isinstance(x, (dict, list))
                ).any()

                if has_complex_type:
                    self.log.debug(f"表 {table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
                    filtered_df[col] = filtered_df[col].apply(
                        lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
                    )

            # 6. Build the parameterised INSERT statement.
            insert_columns = list(filtered_df.columns)
            columns_str = ', '.join(self._quote_identifier(col) for col in insert_columns)
            placeholders = ', '.join(['%s'] * len(insert_columns))
            insert_sql = f"INSERT INTO {quoted_table} ({columns_str}) VALUES ({placeholders})"
            self.log.trace(f"为表 {table_name} 生成的插入SQL:{insert_sql}")

            # 7. Insert row by row so a single bad row can be identified.
            records = filtered_df.to_dict('records')
            indices = filtered_df.index.tolist()

            for i, (record, idx) in enumerate(zip(records, indices)):
                try:
                    # Any NaN/NA that survived step 3 is sent as NULL.
                    data = tuple(self._to_db_value(record[col]) for col in insert_columns)
                    cursor.execute(insert_sql, data)
                    total_inserted += 1

                    if (i + 1) % 100 == 0:
                        self.log.trace(
                            f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
                        )

                except MySQLError as e:
                    error_code, error_message = self._error_details(e)

                    # 8. Duplicate key (MySQL error 1062).
                    if error_code == 1062:
                        total_duplicates += 1
                        short_record = {
                            k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
                            for k, v in record.items()
                        }
                        self.log.warning(
                            f"表 {table_name} 中跳过重复记录",
                            index=idx,
                            error_message=error_message,
                            record=short_record
                        )
                        failed_records.append({
                            'index': idx,
                            'type': 'duplicate',
                            'error_code': error_code,
                            'error_message': error_message,
                            'record': record
                        })
                        if not ignore_duplicates:
                            raise
                    else:
                        # Any other database error.
                        total_failed += 1
                        failed_records.append({
                            'index': idx,
                            'type': 'error',
                            'error_code': error_code,
                            'error_message': error_message,
                            'record': record
                        })
                        self.log.error(
                            f"表 {table_name} 插入记录失败",
                            index=idx,
                            error_code=error_code,
                            error_message=error_message,
                            record=record  # full record goes to the log
                        )
                        if not ignore_duplicates:
                            raise

            conn.commit()

            # 9. Summary, including skipped/failed rows.
            self.log.info(
                f"表 {table_name} 插入结果汇总",
                total_to_insert=total_to_insert,
                total_inserted=total_inserted,
                total_duplicates=total_duplicates,
                total_failed=total_failed,
                failed_records_count=len(failed_records)
            )

            if failed_records:
                self.log.error(
                    f"表 {table_name} 插入失败记录详情",
                    failed_records_summary=[
                        {
                            'index': r['index'],
                            'type': r['type'],
                            'error_code': r['error_code'],
                            'error_message': r['error_message']
                        } for r in failed_records
                    ],
                    detailed_failed_records=failed_records
                )

            return total_inserted

        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # A failed rollback (e.g. dead connection) must not hide
                    # the error that brought us here.
                    self.log.warning(f"表 {table_name} 回滚失败", error=str(rollback_error))
            self.log.error(f"表 {table_name} 批量插入失败",
                           error=str(e),
                           error_type=type(e).__name__,
                           table_name=table_name,
                           total_records=len(df),
                           exc_info=True)
            if failed_records:
                self.log.error(
                    f"表 {table_name} 事务回滚,已失败的记录",
                    failed_records=failed_records,
                    failed_count=len(failed_records)
                )
            raise
        finally:
            try:
                if cursor is not None:
                    cursor.close()
            finally:
                if conn is not None:
                    conn.close()

    def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
        """
        Look up a primary-key column of ``table_name`` using ``cursor``.

        Returns:
            The first column name returned by the query (for a composite key
            only one column is returned), or ``None`` when the table has no
            primary key or the lookup fails.
        """
        try:
            cursor.execute("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                WHERE TABLE_SCHEMA = %s
                  AND TABLE_NAME = %s
                  AND CONSTRAINT_NAME = 'PRIMARY'
            """, (self.config['database'], table_name))
            result = cursor.fetchone()
            return result[0] if result else None
        except Exception as e:
            self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
            return None

    def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
        """
        Read column metadata for ``table_name`` from ``information_schema``.

        Returns:
            ``{column_name: {'type': data_type, 'max_length': int or None}}``;
            an empty dict when no columns are found.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        sql = """
            SELECT column_name, data_type, character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = %s
              AND table_name = %s
        """
        params = (self.config['database'], table_name)

        try:
            conn = self.get_connection()
            cursor = None
            try:
                cursor = conn.cursor()
                cursor.execute(sql, params)
                # Materialise as a list so the cursor type does not matter.
                result_list = list(cursor.fetchall())
                if not result_list:
                    self.log.error("未在表中找到任何列", 表=table_name)
                    return {}

                schema = {}
                for row in result_list:
                    col_name = str(row[0]).strip()
                    data_type = str(row[1]).strip()
                    max_length = row[2] if row[2] else None

                    schema[col_name] = {
                        'type': data_type,
                        'max_length': max_length
                    }

                self.log.debug("成功获取表结构信息",
                               表=table_name,
                               列=list(schema.keys()))
                return schema
            finally:
                # cursor stays None if conn.cursor() itself failed.
                try:
                    if cursor is not None:
                        cursor.close()
                finally:
                    conn.close()
        except Exception as e:
            self.log.error("获取表详细信息失败",
                           表=table_name,
                           error=str(e))
            raise

    def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                                 table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """
        Filter and clean ``df`` so it matches ``table_schema``.

        Steps, per column that exists in the schema:
          * ``varchar``/``char``/``text``: nulls become the empty string;
            other types keep their nulls.
          * ``varchar``/``char`` with a known ``max_length``: values are cast
            to ``str`` and truncated to ``max_length`` characters.
          * ``datetime``/``timestamp``: values are converted with
            ``pandas.to_datetime``; if that raises, unparseable (and null)
            entries are replaced by the current time.

        Args:
            df: Input data; not modified.
            table_name: Used for log messages only.
            table_schema: Mapping as returned by ``_get_table_detailed_info``.

        Returns:
            A cleaned copy containing only columns present in the schema.
        """
        # 1. Keep only columns the table has.
        df_columns = df.columns.tolist()
        valid_columns = [col for col in df_columns if col in table_schema]
        invalid_columns = [col for col in df_columns if col not in table_schema]

        if invalid_columns:
            self.log.warning("丢弃表中不存在的无效列",
                             表=table_name,
                             无效列=invalid_columns,
                             数量=len(invalid_columns))

        cleaned_df = df[valid_columns].copy()
        if cleaned_df.empty:
            return cleaned_df

        # 2. Clean each column according to its SQL type.
        for col in valid_columns:
            col_info = table_schema[col]
            data_type = col_info['type']
            max_length = col_info['max_length']

            # 2.1 Nulls. Only string-like columns get a replacement value;
            # fillna(None) is not a valid call, so other types are left as-is.
            null_mask = cleaned_df[col].isnull()
            if null_mask.any() and data_type in ('varchar', 'char', 'text'):
                default_value = ''
                cleaned_df[col] = cleaned_df[col].astype(object).where(~null_mask, default_value)
                self.log.debug("替换空值",
                               表=table_name,
                               列=col,
                               默认值=default_value,
                               数量=int(null_mask.sum()))

            # 2.2 Truncate over-long strings.
            if data_type in ('varchar', 'char') and max_length:
                text = cleaned_df[col].astype(str)
                too_long_mask = text.str.len() > max_length
                if too_long_mask.any():
                    text = text.str.slice(0, max_length)
                    self.log.warning("截断超长值",
                                     表=table_name,
                                     列=col,
                                     最大长度=max_length,
                                     数量=int(too_long_mask.sum()))
                cleaned_df[col] = text

            # 2.3 Date/time columns.
            if data_type in ('datetime', 'timestamp'):
                try:
                    cleaned_df[col] = pd.to_datetime(cleaned_df[col])
                except Exception as e:
                    self.log.warning("转换为datetime失败,使用当前时间替代",
                                     表=table_name,
                                     列=col,
                                     错误=str(e))
                    # Coerce what can be parsed and substitute "now" for the
                    # rest, so the whole column ends up as datetimes.
                    coerced = pd.to_datetime(cleaned_df[col], errors='coerce')
                    cleaned_df[col] = coerced.where(coerced.notna(), datetime.now())

        return cleaned_df

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """
        Update rows of ``table_name`` from ``df``, matching on ``key_columns``.

        For every DataFrame row one ``UPDATE ... SET <non-key columns> WHERE
        <key columns>`` is executed. Only DataFrame columns that exist in the
        table are written.

        Args:
            table_name: Target table.
            df: Source data; must contain all key columns.
            key_columns: Column name or list of names used in the WHERE clause.

        Returns:
            The cursor's row count after the batch, or 0 when ``df`` is empty
            or there is nothing to update.

        Raises:
            KeyError: If a key column is missing from ``df``.
            Exception: Any database failure is logged and re-raised.
        """
        if df.empty:
            self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
            return 0

        self.log.debug("准备从DataFrame更新表数据",
                       表=table_name,
                       关键字列=key_columns,
                       行数=len(df))

        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]

            missing_keys = [col for col in key_columns if col not in df.columns]
            if missing_keys:
                raise KeyError(f"key columns missing from DataFrame: {missing_keys}")

            # Fetch the schema first; it uses its own connection, so there is
            # no need to hold a second one open while it runs.
            table_info = self._get_table_detailed_info(table_name)
            columns = [col for col in df.columns if col in table_info]
            set_columns = [col for col in columns if col not in key_columns]

            if not set_columns:
                self.log.warning("没有可更新的列", 表=table_name)
                return 0

            # Identifiers are backtick-quoted so reserved words work.
            set_clause = ', '.join(f"{self._quote_identifier(col)}=%s" for col in set_columns)
            where_clause = ' AND '.join(f"{self._quote_identifier(col)}=%s" for col in key_columns)
            update_sql = (
                f"UPDATE {self._quote_identifier(table_name)} "
                f"SET {set_clause} WHERE {where_clause}"
            )
            self.log.trace("生成的更新SQL", sql=update_sql)

            # NaN/NA cells become NULL and numpy scalars become Python values.
            update_data = [
                tuple(self._to_db_value(row[col]) for col in set_columns + list(key_columns))
                for _, row in df.iterrows()
            ]

            # NOTE(review): relies on the pymysql connection context manager
            # closing the connection on exit — confirm for the installed version.
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount
                    conn.commit()

            self.log.info("数据更新成功",
                          表=table_name,
                          更新行数=total_updated)
            return total_updated

        except Exception as e:
            self.log.error("数据更新失败",
                           表=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """
        Map each DataFrame column to a MySQL column type by dtype name.

        Dtypes not listed in the mapping fall back to ``TEXT``.

        Returns:
            ``{column_name: sql_type}`` in DataFrame column order.
        """
        type_mapping = {
            'int64': 'BIGINT',
            'float64': 'DOUBLE',
            'datetime64[ns]': 'DATETIME',
            'object': 'TEXT',
            'bool': 'TINYINT(1)',
            'category': 'VARCHAR(255)'
        }

        sql_types = {
            col: type_mapping.get(str(dtype), 'TEXT')
            for col, dtype in df.dtypes.items()
        }

        self.log.debug("将DataFrame类型映射为SQL类型",
                       映射关系=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """
        Create ``table_name`` with columns derived from ``df``.

        Args:
            table_name: Name of the table to create.
            df: DataFrame whose columns and dtypes define the table.
            primary_key: Optional column name(s); names that are not
                DataFrame columns are ignored.

        Returns:
            ``True`` when the table was created; ``False`` when it already
            exists or creation failed (the failure is logged, not raised).
        """
        if self.table_exists(table_name):
            self.log.warning("表已存在", 表=table_name)
            return False

        self.log.debug("根据DataFrame结构创建新表",
                       表=table_name,
                       列=list(df.columns))

        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = [
                f"{self._quote_identifier(col)} {sql_type}"
                for col, sql_type in sql_types.items()
            ]

            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]
                pk_columns = [col for col in primary_key if col in sql_types]
                if pk_columns:
                    quoted_pk = ', '.join(self._quote_identifier(col) for col in pk_columns)
                    columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                    self.log.trace("设置主键",
                                   表=table_name,
                                   主键=pk_columns)

            # Join outside the f-string: a backslash inside an f-string
            # expression is a syntax error before Python 3.12.
            column_block = ",\n  ".join(columns_sql)
            create_sql = f"CREATE TABLE {self._quote_identifier(table_name)} (\n  {column_block}\n)"

            self.execute_sql(create_sql)
            self.log.info("表创建成功", 表=table_name)
            return True

        except Exception as e:
            self.log.error("创建表失败",
                           表=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def create_table_if_not_exists(self, table_name: str, create_sql: str) -> bool:
        """
        Execute a caller-supplied ``CREATE TABLE`` statement.

        Args:
            table_name: Used for log messages only.
            create_sql: Complete ``CREATE TABLE`` statement. A warning is
                logged when it does not contain ``IF NOT EXISTS``; the
                statement is executed either way.

        Returns:
            ``True`` when the statement executed without error.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        if "IF NOT EXISTS" not in create_sql.upper():
            self.log.warning("CREATE TABLE 语句建议使用 IF NOT EXISTS 以保证安全性")

        try:
            self.execute_sql(create_sql)
            self.log.info(f"成功创建/检查表(表已存在时不会删除数据): {table_name}")
            return True
        except Exception as e:
            self.log.error(f"创建/检查表失败(可能是数据库连接问题): {str(e)}",
                           table=table_name, exc_info=True)
            raise

    def add_unique_index_if_not_exists(self, table_name: str, index_name: str,
                                       column_name: str, column_length: int = 500,
                                       check_duplicates: bool = True) -> bool:
        """
        Add a unique index on one column unless the index already exists.

        Args:
            table_name: Table to alter.
            index_name: Name of the index.
            column_name: Column to index.
            column_length: Prefix length used in the index definition.
            check_duplicates: When true, look for duplicate non-empty values
                first and skip the index if any are found.

        Returns:
            ``True`` when the index exists after the call; ``False`` when
            duplicates prevented its creation.

        Raises:
            Exception: Failures other than "index already exists" are logged
                and re-raised.
        """
        quoted_table = self._quote_identifier(table_name)
        quoted_column = self._quote_identifier(column_name)

        try:
            # 1. Does the index already exist?
            check_index_sql = """
                SELECT COUNT(*) as cnt
                FROM INFORMATION_SCHEMA.STATISTICS
                WHERE TABLE_SCHEMA = %s
                  AND TABLE_NAME = %s
                  AND INDEX_NAME = %s
            """
            result = self.query_to_df(
                check_index_sql,
                params=(self.config['database'], table_name, index_name),
                is_print=False
            )

            if not result.empty and result['cnt'].iloc[0] > 0:
                self.log.debug(f"唯一索引 {index_name} 已存在,跳过添加")
                return True

            # 2. Optionally look for duplicates that would make the index fail.
            if check_duplicates:
                check_duplicates_sql = f"""
                    SELECT {quoted_column}, COUNT(*) as cnt
                    FROM {quoted_table}
                    WHERE {quoted_column} IS NOT NULL AND {quoted_column} != ''
                    GROUP BY {quoted_column}
                    HAVING cnt > 1
                    LIMIT 1
                """
                duplicates = self.query_to_df(check_duplicates_sql, is_print=False)

                if not duplicates.empty:
                    self.log.warning(
                        f"表 {table_name} 中存在重复的 {column_name} 数据,无法添加唯一索引。"
                        "现有数据不会被删除。",
                        duplicate_count=len(duplicates)
                    )
                    return False

            # 3. Add the index.
            add_index_sql = f"""
                ALTER TABLE {quoted_table}
                ADD UNIQUE KEY {self._quote_identifier(index_name)} ({quoted_column}({int(column_length)}))
            """
            self.execute_sql(add_index_sql)
            self.log.info(f"成功添加唯一索引 {index_name}(现有数据不受影响)")
            return True

        except Exception as e:
            error_msg = str(e)
            # An "already exists" style error is treated as success.
            if "Duplicate key name" in error_msg or "already exists" in error_msg.lower():
                self.log.debug(f"唯一索引 {index_name} 已存在,跳过添加")
                return True
            self.log.warning(f"添加唯一索引时出现问题(不影响现有数据): {error_msg}")
            raise

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """
        Execute one SQL statement on a fresh connection.

        On non-Windows platforms ``SET SESSION max_execution_time=600000`` is
        issued before the statement.

        Args:
            sql: Statement to execute.
            params: Parameters for the statement.
            fetch: When true, return ``cursor.fetchall()``; otherwise commit
                and return the affected row count.

        Returns:
            The fetched rows (callers in this class index them as tuples,
            e.g. ``result[0][0]``) or the affected row count.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION max_execution_time=600000")

                    cursor.execute(sql, params)

                    if fetch:
                        result = cursor.fetchall()
                        self.log.debug("查询执行完成", 行数=len(result))
                        return result

                    affected_rows = cursor.rowcount
                    conn.commit()
                    self.log.debug("更新执行完成", 受影响行数=affected_rows)
                    return affected_rows

        except Exception as e:
            self.log.error("SQL执行失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise

    def table_exists(self, table_name: str) -> bool:
        """
        Check whether ``table_name`` exists in the configured database.

        Returns:
            ``True`` when ``information_schema.tables`` lists the table.
            ``False`` when it does not, and also when the lookup itself fails
            (``execute_sql`` logs that failure).
        """
        sql = """
            SELECT COUNT(*) as count
            FROM `information_schema`.`tables`
            WHERE `table_schema` = %s
              AND `table_name` = %s
        """

        params = (self.config['database'], table_name)

        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0][0] > 0  # rows are indexed positionally
            self.log.debug("检查表是否存在",
                           表=table_name,
                           存在=exists)
            return exists
        except Exception:
            return False

    def drop_table(self, table_name: str) -> bool:
        """
        Drop ``table_name`` if it exists.

        Returns:
            ``True`` when the table was dropped; ``False`` when it does not
            exist or the drop failed (the failure is logged, not raised).
        """
        if not self.table_exists(table_name):
            self.log.warning("表不存在", 表=table_name)
            return False

        try:
            self.execute_sql(f"DROP TABLE {self._quote_identifier(table_name)}")
            self.log.info("表删除成功", 表=table_name)
            return True
        except Exception as e:
            self.log.error("删除表失败",
                           表=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def validate_connection(self) -> bool:
        """Return ``True`` when ``SELECT 1`` succeeds on a fresh connection, else ``False``."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    return cursor.fetchone()[0] == 1
        except Exception:
            return False
|
||
|
||
|
||
# 平台特定的默认配置(原有逻辑完全保留)
|
||
def get_default_config():
    """
    Build the default connection settings for the current platform.

    Every platform shares the same host/port/user/password/database.
    Windows uses the shorter timeouts (10/30/30 seconds); all other
    platforms use 15/60/60, and macOS additionally gets an ``ssl`` entry
    pointing at a CA bundle path.

    Returns:
        A new dict suitable for passing to ``MySQLAgent``.
    """
    system_name = platform.system()

    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and everything else share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
|
||
|
||
|
||
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, verify the
    # connection and print the server version.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")

        # Ask the server for its version string.
        version_df = db.query_to_df("SELECT VERSION() as version")
        server_version = version_df['version'].iloc[0]
        print(f"数据库版本: {server_version}")
|