Files
intelligence_system/utils/mysql_agent.py
T
2025-09-18 17:03:24 +08:00

663 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import sys
import platform
import pandas as pd
import pymysql
import json
import numpy as np
from pymysql import cursors
from pymysql.err import MySQLError
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
import threading
from datetime import datetime
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """Cross-platform MySQL helper (Windows / macOS / Linux).

    The connection settings are supplied by the caller. Every operation opens
    its own connection (created with ``autocommit=True``); there is no
    connection pool and no multi-statement transaction management.

    The class is a process-wide singleton: ``__new__`` always returns the same
    object and ``__init__`` is a no-op once a configuration has been stored.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Check, lock, then check again so that concurrent first calls still
        # end up sharing one instance.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """Store the connection settings and bind a platform-tagged logger.

        Args:
            config: Must contain ``host``, ``port``, ``user``, ``password`` and
                ``database``. Optional keys: ``charset``, ``connect_timeout``,
                ``read_timeout``, ``write_timeout`` and ``ssl``.

        Raises:
            ValueError: If a required key is missing.
        """
        # Singleton: keep the configuration of the first successful init.
        if hasattr(self, 'config') and self.config:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        missing_keys = [key for key in required_keys if key not in config]
        if missing_keys:
            # Only the missing key names are logged: the config dict carries
            # the password and must not be written to the log.
            log.warning(f"数据库配置缺少必要参数,缺少: {missing_keys}")
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl')
        }

        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Wrap a table/column name in backticks, doubling embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    def _open_connection(self) -> "pymysql.connections.Connection":
        """Open one raw connection and make sure it has ``character_set_name``."""
        conn = pymysql.connect(**self.config)
        # NOTE(review): presumably required by the SQLAlchemy path in
        # query_to_df; only patched in when the driver lacks the method.
        if not hasattr(conn, 'character_set_name'):
            charset = self.config.get('charset', 'utf8mb4')
            conn.character_set_name = lambda: charset
        return conn

    def get_connection(self) -> "pymysql.connections.Connection":
        """Open and return a new database connection.

        On macOS with SSL configured the connection is pinged once. On Windows
        a failure whose message contains "timed out" is retried through
        ``_retry_connection``. The caller owns the connection and must close it.

        Raises:
            Exception: Whatever the driver raised, after logging it.
        """
        try:
            conn = self._open_connection()
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)
            self.log.trace("Database connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows connection timeout, retrying...")
                return self._retry_connection()
            self.log.error("Connection failed", error=error_msg, exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> "Optional[pymysql.connections.Connection]":
        """Retry connecting up to ``max_retries`` times, one second apart.

        Returns:
            The connection, or ``None`` when ``max_retries`` is not positive.

        Raises:
            Exception: The error of the last attempt when every attempt fails.
        """
        import time

        for attempt in range(max_retries):
            try:
                conn = self._open_connection()
                self.log.info(f"Connection established after {attempt + 1} attempts")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
        return None

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """Run a query and return the result set as a DataFrame.

        Args:
            sql: Query text.
            params: Parameters forwarded to ``pandas.read_sql``.
            parse_dates: Forwarded to ``pandas.read_sql``.

        Raises:
            Exception: Any connection, SQLAlchemy or query error, after logging.
        """
        conn = None
        engine = None
        try:
            self.log.debug("Executing SQL query", sql=sql)
            conn = self.get_connection()

            # pandas is given a throw-away SQLAlchemy engine that hands out the
            # single connection opened above.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool
            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={'charset': self.config.get('charset', 'utf8mb4')}
            )

            df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
            self.log.info("Query executed successfully", rows=len(df))
            return df
        except Exception as e:
            self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
            raise
        finally:
            if engine is not None:
                engine.dispose()
            # If the engine never checked the connection out (e.g. engine
            # creation failed) it is still open here; close it so it cannot leak.
            if conn is not None and getattr(conn, 'open', False):
                conn.close()

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False,
                       ignore_duplicates: Optional[bool] = None) -> int:
        """Insert the rows of ``df`` into ``table_name`` one row at a time.

        Columns that do not exist in the table are dropped, null-like values
        become ``NULL`` and columns holding dict/list values are JSON-encoded.

        Args:
            table_name: Target table.
            df: Rows to insert.
            chunk_size: Accepted for backward compatibility; not used.
            replace: Legacy flag, only consulted when ``ignore_duplicates`` is
                ``None``, in which case ``ignore_duplicates = not replace``.
                It does not replace existing rows.
            ignore_duplicates: When true, rows that fail (duplicate key or any
                other MySQL error) are logged and skipped; when false the
                first failing row aborts the call.

        Returns:
            Number of rows inserted.

        Raises:
            Exception: Any error that aborts the insertion, after logging.
        """
        if ignore_duplicates is None:
            ignore_duplicates = not replace

        if df.empty:
            self.log.warning("Attempted to insert empty DataFrame", table=table_name)
            return 0

        quoted_table = self._quote_identifier(table_name)
        conn = None
        cursor = None
        total_inserted = 0
        total_duplicated = 0
        total_failed = 0

        try:
            # 1. Connect.
            conn = self.get_connection()
            cursor = conn.cursor()
            self.log.debug(f"Established connection for inserting into {table_name}")

            # 2. Read the real column names of the table.
            cursor.execute(f"SHOW COLUMNS FROM {quoted_table}")
            db_columns = [col[0] for col in cursor.fetchall()]
            self.log.debug(f"Table {table_name} has columns: {db_columns}")

            # 3. Normalise every null-like value to None.
            cleaned_df = df.replace(
                [None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
                None
            ).copy()

            # 4. Keep only the columns the table actually has.
            df_columns = cleaned_df.columns.tolist()
            matched_columns = [col for col in df_columns if col in db_columns]
            unmatched_columns = [col for col in df_columns if col not in db_columns]
            if unmatched_columns:
                self.log.warning(
                    f"Table {table_name} dropping unmatched columns",
                    unmatched_columns=unmatched_columns,
                    count=len(unmatched_columns)
                )
            if not matched_columns:
                self.log.warning(f"No matched columns for {table_name}, abort insertion")
                return 0

            filtered_df = cleaned_df[matched_columns].copy()
            total_to_insert = len(filtered_df)
            self.log.debug(
                f"Filtered DataFrame for {table_name}: {total_to_insert} rows to insert"
            )

            # 5. A column containing any dict/list is JSON-encoded as a whole
            #    (every non-null value in it goes through json.dumps).
            for col in filtered_df.columns:
                has_complex_type = filtered_df[col].apply(
                    lambda x: isinstance(x, (dict, list)) if x is not None else False
                ).any()
                if has_complex_type:
                    self.log.debug(f"Column {col} in {table_name} has complex type, converting to JSON")
                    filtered_df.loc[:, col] = filtered_df[col].apply(
                        lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
                    )

            # 6. Build the parameterised INSERT statement.
            insert_columns = list(filtered_df.columns)
            columns_str = ', '.join(self._quote_identifier(col) for col in insert_columns)
            placeholders = ', '.join(['%s'] * len(insert_columns))
            insert_sql = f"INSERT INTO {quoted_table} ({columns_str}) VALUES ({placeholders})"
            self.log.trace(f"Generated insert SQL for {table_name}: {insert_sql}")

            # 7. Insert row by row so a failure can be attributed to one record.
            records = filtered_df.to_dict('records')
            indices = filtered_df.index.tolist()
            for i, (record, idx) in enumerate(zip(records, indices)):
                try:
                    data = tuple(record[col] for col in insert_columns)
                    cursor.execute(insert_sql, data)
                    total_inserted += 1
                    if (i + 1) % 100 == 0:
                        self.log.trace(
                            f"Inserted {i + 1}/{total_to_insert} rows into {table_name}"
                        )
                except MySQLError as e:
                    # Driver errors normally carry (code, message); fall back
                    # gracefully when only one argument is present.
                    error_code = e.args[0] if e.args else None
                    error_msg = e.args[1] if len(e.args) > 1 else str(e)

                    # 8. MySQL error 1062 = duplicate entry.
                    if error_code == 1062:
                        total_duplicated += 1
                        short_record = {
                            k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
                            for k, v in record.items()
                        }
                        self.log.warning(
                            f"Skipped duplicate record in {table_name}",
                            index=idx,
                            error_msg=error_msg,
                            record=short_record
                        )
                        if not ignore_duplicates:
                            raise
                    else:
                        total_failed += 1
                        self.log.error(
                            f"Failed to insert record in {table_name}",
                            index=idx,
                            error_code=error_code,
                            error_msg=error_msg,
                            record=record
                        )
                        if not ignore_duplicates:
                            raise

            conn.commit()

            # 9. Summary.
            self.log.info(
                f"Insertion summary for {table_name}",
                total_to_insert=total_to_insert,
                total_inserted=total_inserted,
                total_duplicated=total_duplicated,
                total_failed=total_failed
            )
            return total_inserted
        except Exception as e:
            if conn:
                # NOTE(review): connections are opened with autocommit=True, so
                # rows inserted before the failure are presumably not undone.
                conn.rollback()
            self.log.error(f"Batch insertion failed for {table_name}", error=str(e), exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
        """Return a primary-key column name of ``table_name``, or ``None``.

        Only the first row returned by the lookup is used, so for a composite
        key a single column name is returned. Lookup errors are logged and
        reported as ``None``.
        """
        try:
            cursor.execute("""
                SELECT COLUMN_NAME
                FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                WHERE TABLE_SCHEMA = %s
                  AND TABLE_NAME = %s
                  AND CONSTRAINT_NAME = 'PRIMARY'
            """, (self.config['database'], table_name))
            result = cursor.fetchone()
            return result[0] if result else None
        except Exception as e:
            self.log.warning(f"Failed to get primary key for {table_name}", error=str(e))
            return None

    def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
        """Return ``{column: {'type': ..., 'max_length': ...}}`` for a table.

        Returns an empty dict (and logs an error) when the table has no columns
        in ``information_schema`` for the configured database.

        Raises:
            Exception: Any connection or query error, after logging.
        """
        sql = """
            SELECT column_name, data_type, character_maximum_length
            FROM information_schema.columns
            WHERE table_schema = %s
              AND table_name = %s
        """
        params = (self.config['database'], table_name)
        try:
            conn = self.get_connection()
            try:
                # The cursor is managed by ``with`` so that a failure while
                # creating it cannot surface as a NameError during cleanup.
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    rows = list(cursor.fetchall())
            finally:
                conn.close()

            if not rows:
                self.log.error("No columns found in table", table=table_name)
                return {}

            schema = {}
            for row in rows:
                col_name = str(row[0]).strip()
                data_type = str(row[1]).strip()
                max_length = row[2] if row[2] else None
                schema[col_name] = {
                    'type': data_type,
                    'max_length': max_length
                }
            self.log.debug("Successfully fetched table schema",
                           table=table_name,
                           columns=list(schema.keys()))
            return schema
        except Exception as e:
            self.log.error("Failed to get table detailed info",
                           table=table_name,
                           error=str(e))
            raise

    def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                                 table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """Return a copy of ``df`` cleaned against ``table_schema``.

        * Columns missing from the schema are dropped.
        * Nulls in ``varchar``/``char``/``text`` columns become ``''``; nulls in
          other columns are left untouched.
        * ``varchar``/``char`` values are cast to ``str`` and truncated to the
          column's maximum length.
        * ``datetime``/``timestamp`` columns are converted with
          ``pandas.to_datetime``; if that raises, entries that cannot be
          parsed (nulls included) are set to the current time.
        """
        # 1. Keep only the columns the table has.
        df_columns = df.columns.tolist()
        valid_columns = [col for col in df_columns if col in table_schema]
        invalid_columns = [col for col in df_columns if col not in table_schema]
        if invalid_columns:
            self.log.warning("Dropping invalid columns not present in table",
                             table=table_name,
                             invalid_columns=invalid_columns,
                             count=len(invalid_columns))

        cleaned_df = df[valid_columns].copy()
        if cleaned_df.empty:
            return cleaned_df

        # 2. Clean column by column.
        for col in valid_columns:
            col_info = table_schema[col]
            data_type = col_info['type']
            max_length = col_info['max_length']

            # 2.1 Nulls: only text columns get a replacement value. Calling
            #     fillna(None) raises, so other types are deliberately skipped.
            null_count = cleaned_df[col].isnull().sum()
            if null_count and data_type in ['varchar', 'char', 'text']:
                # Assign the result back instead of a chained inplace fill,
                # which may act on a temporary and leave the frame unchanged.
                cleaned_df[col] = cleaned_df[col].fillna('')
                self.log.debug("Replaced null values",
                               table=table_name,
                               column=col,
                               default_value='',
                               count=null_count)

            # 2.2 Truncate over-long string values.
            if data_type in ['varchar', 'char'] and max_length:
                cleaned_df[col] = cleaned_df[col].astype(str)
                too_long_mask = cleaned_df[col].str.len() > max_length
                if too_long_mask.any():
                    cleaned_df.loc[too_long_mask, col] = (
                        cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
                    )
                    self.log.warning("Truncated overlength values",
                                     table=table_name,
                                     column=col,
                                     max_length=max_length,
                                     count=too_long_mask.sum())

            # 2.3 Date/time columns.
            if data_type in ['datetime', 'timestamp']:
                try:
                    cleaned_df[col] = pd.to_datetime(cleaned_df[col])
                except Exception as e:
                    self.log.warning("Failed to convert to datetime, using current time",
                                     table=table_name,
                                     column=col,
                                     error=str(e))
                    # Coerce so the column ends up as datetimes (not a mix of
                    # strings and datetimes), then fill whatever failed.
                    converted = pd.to_datetime(cleaned_df[col], errors='coerce')
                    cleaned_df[col] = converted.fillna(pd.Timestamp(datetime.now()))
        return cleaned_df

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """Update table rows from ``df``, matching rows on ``key_columns``.

        Only DataFrame columns that exist in the table are written; the key
        columns are used in the ``WHERE`` clause and are never updated.

        Returns:
            ``cursor.rowcount`` after the batched update, or 0 when there is
            nothing to update.

        Raises:
            ValueError: If ``key_columns`` is empty.
            Exception: Any lookup or database error, after logging.
        """
        if df.empty:
            self.log.warning("Attempted to update with empty DataFrame", table=table_name)
            return 0

        self.log.debug("Preparing to update table from DataFrame",
                       table=table_name,
                       key_columns=key_columns,
                       rows=len(df))
        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]
            key_columns = list(key_columns)
            if not key_columns:
                # Without keys the statement would end in a bare "WHERE".
                raise ValueError("key_columns must name at least one column")

            # Looked up before opening the update connection so two
            # connections are never held at the same time.
            table_info = self._get_table_detailed_info(table_name)
            set_columns = [col for col in df.columns
                           if col in table_info and col not in key_columns]
            if not set_columns:
                self.log.warning("No columns to update", table=table_name)
                return 0

            set_clause = ', '.join(f"{self._quote_identifier(col)}=%s" for col in set_columns)
            where_clause = ' AND '.join(f"{self._quote_identifier(col)}=%s" for col in key_columns)
            update_sql = (f"UPDATE {self._quote_identifier(table_name)} "
                          f"SET {set_clause} WHERE {where_clause}")
            self.log.trace("Generated update SQL", sql=update_sql)

            # itertuples keeps each column's own value type (iterrows upcasts
            # every row to one common dtype).
            ordered_columns = set_columns + key_columns
            update_data = list(df[ordered_columns].itertuples(index=False, name=None))

            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.executemany(update_sql, update_data)
                    total_updated = cursor.rowcount
                conn.commit()

            self.log.info("Data updated successfully",
                          table=table_name,
                          rows_updated=total_updated)
            return total_updated
        except Exception as e:
            self.log.error("Data update failed",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            raise

    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """Map every DataFrame column to a MySQL column type.

        Unknown dtypes fall back to ``TEXT``.
        """
        type_mapping = {
            'int64': 'BIGINT',
            'float64': 'DOUBLE',
            'datetime64[ns]': 'DATETIME',
            'object': 'TEXT',
            'bool': 'TINYINT(1)',
            'category': 'VARCHAR(255)'
        }
        sql_types = {
            col: type_mapping.get(str(dtype), 'TEXT')
            for col, dtype in df.dtypes.items()
        }
        self.log.debug("Mapped DataFrame types to SQL types",
                       mappings=sql_types)
        return sql_types

    def _build_create_table_sql(self, table_name: str, df: pd.DataFrame,
                                primary_key: Union[str, List[str], None] = None) -> str:
        """Build the ``CREATE TABLE`` statement for ``df``'s columns."""
        sql_types = self.df_to_sql_type(df)
        definitions = [
            f"{self._quote_identifier(col)} {sql_type}"
            for col, sql_type in sql_types.items()
        ]
        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            # Key columns that are not in the DataFrame are silently ignored.
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                quoted_pk = ', '.join(self._quote_identifier(col) for col in pk_columns)
                definitions.append(f"PRIMARY KEY ({quoted_pk})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)
        # Joined outside the f-string: a backslash inside an f-string
        # expression is a syntax error before Python 3.12.
        body = ',\n '.join(definitions)
        # The table name is inserted verbatim so that callers passing a
        # schema-qualified name keep working.
        return f"CREATE TABLE {table_name} (\n {body}\n)"

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """Create ``table_name`` with one column per DataFrame column.

        Returns:
            ``True`` when the table was created; ``False`` when it already
            exists or creation failed (the failure is logged, not raised).
        """
        if self.table_exists(table_name):
            self.log.warning("Table already exists", table=table_name)
            return False

        self.log.debug("Creating new table from DataFrame schema",
                       table=table_name,
                       columns=list(df.columns))
        try:
            create_sql = self._build_create_table_sql(table_name, df, primary_key)
            self.execute_sql(create_sql)
            self.log.info("Table created successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to create table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, Tuple[Tuple[Any, ...], ...], List[Dict[str, Any]]]:
        """Execute one statement on a fresh connection.

        Args:
            sql: Statement text.
            params: Parameters for the statement.
            fetch: When true return ``cursor.fetchall()``; otherwise commit
                and return ``cursor.rowcount``.

        Raises:
            Exception: Any connection or execution error, after logging.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # Non-Windows platforms get a longer execution time limit.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION max_execution_time=600000")
                    cursor.execute(sql, params)
                    if fetch:
                        result = cursor.fetchall()
                        self.log.debug("Query executed", rows=len(result))
                        return result
                    affected_rows = cursor.rowcount
                    conn.commit()
                    self.log.debug("Update executed", affected_rows=affected_rows)
                    return affected_rows
        except Exception as e:
            self.log.error("SQL execution failed",
                           sql=sql,
                           params=params,
                           error=str(e),
                           exc_info=True)
            raise

    def table_exists(self, table_name: str) -> bool:
        """Return whether ``table_name`` exists in the configured database.

        Any error during the check is reported as ``False``.
        """
        sql = """
            SELECT COUNT(*) as count
            FROM `information_schema`.`tables`
            WHERE `table_schema` = %s
              AND `table_name` = %s
        """
        params = (self.config['database'], table_name)
        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0][0] > 0  # rows are read by position
            self.log.debug("Checked table existence",
                           table=table_name,
                           exists=exists)
            return exists
        except Exception as e:
            # Make it visible that False here means "could not check".
            self.log.warning("Table existence check failed, assuming missing",
                             table=table_name,
                             error=str(e))
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop ``table_name``.

        Returns:
            ``True`` when the table was dropped; ``False`` when it does not
            exist or the drop failed (the failure is logged, not raised).
        """
        if not self.table_exists(table_name):
            self.log.warning("Table does not exist", table=table_name)
            return False
        try:
            self.execute_sql(f"DROP TABLE {self._quote_identifier(table_name)}")
            self.log.info("Table dropped successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to drop table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False

    def validate_connection(self) -> bool:
        """Return ``True`` when ``SELECT 1`` succeeds on a fresh connection."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    return cursor.fetchone()[0] == 1
        except Exception:
            # get_connection already logs connection failures.
            return False
# Platform-specific default configuration.
def get_default_config():
    """Return the built-in connection defaults for the current platform.

    Every platform shares the same host/port/user/password/database values.
    Windows uses shorter timeouts than the other platforms, and macOS
    ("Darwin") additionally gets an ``ssl`` entry pointing at a CA bundle.
    """
    system_name = platform.system()

    defaults = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    if system_name == 'Windows':
        timeouts = {'connect_timeout': 10, 'read_timeout': 30, 'write_timeout': 30}
    else:
        # macOS, Linux and any other platform share the longer timeouts.
        timeouts = {'connect_timeout': 15, 'read_timeout': 60, 'write_timeout': 60}

    settings = {**defaults, **timeouts}
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
if __name__ == "__main__":
    # Usage example: build the agent from the platform defaults, check that the
    # database is reachable and, if so, print the server version.
    agent = MySQLAgent(get_default_config())

    if not agent.validate_connection():
        print("Failed to connect to database")
    else:
        print("Database connection successful")
        version_frame = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_frame['version'].iloc[0]
        print(f"Database version: {server_version}")