import json
import os
import platform
import sys
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError

# Project logging system (exposes ``bind`` and keyword-context log calls).
from utils.logger import log


class MySQLAgent:
    """
    Cross-platform (Windows/macOS/Linux) helper for MySQL database operations.

    Configuration is passed in from outside. No connection pool is used: each
    method opens its own short-lived pymysql connection and closes it when it
    is done. Connections are created with ``autocommit=True``;
    ``insert_from_df`` is the only method that opens an explicit transaction.

    The class is a process-wide singleton: once an instance has been
    initialised with a valid config, later constructions return the same
    instance and their ``config`` argument is ignored.
    """

    _instance = None
    _lock = threading.Lock()

    # Cell values that insert_from_df stores as SQL NULL.
    _NULL_STRING_TOKENS = frozenset({'nan', 'NaN', 'NAN', ''})

    def __new__(cls, *args, **kwargs):
        # Double-checked creation under a class-level lock.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: dict):
        """Validate *config* and store the pymysql connection parameters.

        Args:
            config: must contain ``host``, ``port``, ``user``, ``password`` and
                ``database``. ``charset`` (default ``utf8mb4``),
                ``connect_timeout`` (10), ``read_timeout`` (30),
                ``write_timeout`` (30) and ``ssl`` are optional.

        Raises:
            ValueError: if a required key is missing.
        """
        # Singleton: an already-initialised instance keeps its first config.
        if hasattr(self, 'config') and self.config:
            return

        required_keys = ['host', 'port', 'user', 'password', 'database']
        if not all(key in config for key in required_keys):
            # Mask the password so credentials never end up in log files.
            safe_config = {**config, 'password': '******'} if 'password' in config else config
            log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{safe_config}")
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

        self.config = {
            'host': config['host'],
            'port': config['port'],
            'user': config['user'],
            'password': config['password'],
            'database': config['database'],
            'charset': config.get('charset', 'utf8mb4'),
            'autocommit': True,
            'connect_timeout': config.get('connect_timeout', 10),
            'read_timeout': config.get('read_timeout', 30),
            'write_timeout': config.get('write_timeout', 30),
            'ssl': config.get('ssl'),
        }

        current_platform = platform.system()
        self.log = log.bind(module=f"MySQLAgent({current_platform})")

    # ------------------------------------------------------------------ #
    # Small internal helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return *name* as a backtick-quoted MySQL identifier (embedded backticks doubled)."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _mysql_error_parts(error: Exception) -> Tuple[Optional[int], str]:
        """Split a pymysql error into ``(error_code, message)``.

        Server-side errors carry ``(code, message)`` in ``args``. Errors raised
        before the query reaches the server may carry only a message, so
        ``args[1]`` cannot be indexed blindly. ``error_code`` is None when the
        first arg is not an int.
        """
        args = getattr(error, 'args', ())
        code = args[0] if args and isinstance(args[0], int) else None
        message = args[1] if code is not None and len(args) > 1 else str(error)
        return code, message

    @classmethod
    def _normalize_db_value(cls, value: Any) -> Any:
        """Convert one DataFrame cell into a value pymysql can bind.

        - ``None``, NaN/NaT/``pd.NA`` and the strings in ``_NULL_STRING_TOKENS``
          become ``None`` (SQL NULL);
        - ``dict``/``list`` become JSON text (non-ASCII kept as-is);
        - anything else is returned unchanged.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return None if value in cls._NULL_STRING_TOKENS else value
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    def _open_connection(self) -> pymysql.connections.Connection:
        """Open a raw pymysql connection and make sure it exposes ``character_set_name()``."""
        conn = pymysql.connect(**self.config)
        # Shim kept from the original code; presumably needed by SQLAlchemy's
        # MySQL dialect on pymysql builds lacking the method (TODO confirm).
        if not hasattr(conn, 'character_set_name'):
            conn.character_set_name = lambda: self.config.get('charset', 'utf8mb4')
        return conn

    # ------------------------------------------------------------------ #
    # Connections
    # ------------------------------------------------------------------ #
    def get_connection(self) -> pymysql.connections.Connection:
        """Open and return a new connection; the caller is responsible for closing it.

        On Windows a "timed out" failure is retried via ``_retry_connection``.
        Any other failure is logged and re-raised.
        """
        try:
            conn = self._open_connection()

            # macOS with SSL: ping (reconnecting if needed) right after connecting.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)

            self.log.trace("获取数据库连接成功")
            return conn
        except Exception as e:
            error_msg = str(e)
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.log.warning("Windows连接超时,正在重试...")
                return self._retry_connection()

            self.log.error("连接失败",
                           error=error_msg,
                           error_type=type(e).__name__,
                           host=self.config.get('host'),
                           port=self.config.get('port'),
                           database=self.config.get('database'),
                           exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> Optional[pymysql.connections.Connection]:
        """Retry connecting up to *max_retries* times, sleeping 1 s between attempts.

        Returns the first successful connection. The last attempt's exception is
        re-raised. Returns None only when ``max_retries <= 0``.
        """
        for attempt in range(max_retries):
            try:
                conn = self._open_connection()
                self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)
        return None

    # ------------------------------------------------------------------ #
    # Queries
    # ------------------------------------------------------------------ #
    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
        """Run *sql* and return the result as a DataFrame.

        Args:
            sql: query text using pymysql ``%s`` placeholders.
            params: values bound to the placeholders.
            parse_dates: forwarded to ``pd.read_sql``.
            is_print: when True, log the row count at INFO level.

        Raises:
            Exception: any connection/query error, after logging it.
        """
        conn = None
        engine = None
        try:
            self.log.debug("执行SQL查询", sql=sql)

            conn = self.get_connection()

            # Imported lazily so that SQLAlchemy is only required by this method.
            from sqlalchemy import create_engine
            from sqlalchemy.pool import StaticPool

            # A single-connection engine wrapping the connection opened above.
            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,
                connect_args={'charset': self.config.get('charset', 'utf8mb4')}
            )

            df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
            if is_print:
                self.log.info("查询执行成功", 行数=len(df))
            return df

        except Exception as e:
            self.log.error("SQL查询失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise
        finally:
            if engine is not None:
                engine.dispose()
            # dispose() may not close a connection the pool never handed out
            # (e.g. failure before read_sql ran), so close it explicitly.
            if conn is not None and conn.open:
                conn.close()

    # ------------------------------------------------------------------ #
    # Inserts / updates
    # ------------------------------------------------------------------ #
    def insert_from_df(self, table_name: str, df: pd.DataFrame, chunk_size: int = 1000,
                       replace: bool = False, ignore_duplicates: bool = None) -> int:
        """Insert the rows of *df* into *table_name*, one row at a time.

        Only DataFrame columns that exist in the table are inserted; the others
        are dropped with a warning. Cell values are normalised before binding:
        None/NaN/NaT/pd.NA and the strings 'nan', 'NaN', 'NAN', '' become NULL,
        and dict/list values are stored as JSON text.

        Safety notes:
        - Uses plain ``INSERT INTO`` (not ``REPLACE INTO`` or
          ``INSERT ... ON DUPLICATE KEY UPDATE``). Existing rows are never
          overwritten or deleted.
        - All rows are written inside one explicit transaction. If an exception
          escapes (connection failure, or a failing row while
          ``ignore_duplicates`` is False), everything is rolled back.
        - When ``ignore_duplicates`` is True, rows failing with a duplicate-key
          error (MySQL 1062) *or* any other MySQL error are skipped and logged,
          and the remaining rows are committed.
          NOTE(review): errors that abort the whole InnoDB transaction
          (e.g. deadlocks) are also just skipped here — confirm acceptable.

        Args:
            table_name: target table in the configured database.
            df: rows to insert.
            chunk_size: currently unused; kept for backward compatibility.
            replace: legacy flag. When ``ignore_duplicates`` is None it is set
                to ``not replace``. It does NOT switch to REPLACE semantics.
            ignore_duplicates: skip failing rows instead of raising.

        Returns:
            Number of rows inserted.

        Raises:
            MySQLError: the first failing row when ``ignore_duplicates`` is False.
            Exception: connection/metadata errors, after logging and rollback.
        """
        # Backward compatibility: derive ignore_duplicates from the legacy flag.
        if ignore_duplicates is None:
            ignore_duplicates = not replace

        if df.empty:
            self.log.warning("尝试插入空的DataFrame", table=table_name)
            return 0

        conn = None
        cursor = None
        total_inserted = 0
        total_duplicates = 0
        total_failed = 0
        failed_records = []  # every skipped/failed row, for the summary logs
        quoted_table = self._quote_identifier(table_name)

        try:
            # 1. Connect.
            conn = self.get_connection()
            cursor = conn.cursor()
            self.log.debug(f"已建立连接,准备插入数据到 {table_name}")

            # 2. Real column names of the target table.
            cursor.execute(f"SHOW COLUMNS FROM {quoted_table}")
            db_columns = [col[0] for col in cursor.fetchall()]
            self.log.debug(f"表 {table_name} 包含以下列:{db_columns}")

            # 3. Keep only DataFrame columns that exist in the table.
            db_column_set = set(db_columns)
            df_columns = df.columns.tolist()
            matched_columns = [col for col in df_columns if col in db_column_set]
            unmatched_columns = [col for col in df_columns if col not in db_column_set]

            if unmatched_columns:
                self.log.warning(
                    f"表 {table_name} 中存在不匹配的列,已自动丢弃",
                    unmatched_columns=unmatched_columns,
                    count=len(unmatched_columns)
                )

            if not matched_columns:
                self.log.warning(f"表 {table_name} 没有匹配的列,终止插入操作")
                return 0

            filtered_df = df[matched_columns]
            total_to_insert = len(filtered_df)
            self.log.debug(
                f"表 {table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
            )

            # 4. Report columns holding dict/list values (JSON-encoded per value below).
            for col in matched_columns:
                if filtered_df[col].apply(lambda x: isinstance(x, (dict, list))).any():
                    self.log.debug(f"表 {table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")

            # 5. Build the INSERT statement.
            columns_str = ', '.join(self._quote_identifier(col) for col in matched_columns)
            placeholders = ', '.join(['%s'] * len(matched_columns))
            insert_sql = f"INSERT INTO {quoted_table} ({columns_str}) VALUES ({placeholders})"
            self.log.trace(f"为表 {table_name} 生成的插入SQL:{insert_sql}")

            # The connection runs with autocommit=True. Without an explicit BEGIN
            # every row would be committed immediately and the rollback below
            # could not undo a partially inserted batch.
            conn.begin()

            # 6. Insert row by row so a single duplicate can be detected and skipped.
            records = filtered_df.to_dict('records')
            for i, (idx, raw_record) in enumerate(zip(filtered_df.index, records)):
                record = {col: self._normalize_db_value(raw_record[col]) for col in matched_columns}
                try:
                    cursor.execute(insert_sql, tuple(record[col] for col in matched_columns))
                    total_inserted += 1
                    if (i + 1) % 100 == 0:
                        self.log.trace(
                            f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
                        )
                except MySQLError as e:
                    error_code, error_message = self._mysql_error_parts(e)
                    if error_code == 1062:  # ER_DUP_ENTRY
                        total_duplicates += 1
                        short_record = {
                            k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
                            for k, v in record.items()
                        }
                        self.log.warning(
                            f"表 {table_name} 中跳过重复记录",
                            index=idx,
                            error_message=error_message,
                            record=short_record
                        )
                        failed_records.append({
                            'index': idx,
                            'type': 'duplicate',
                            'error_code': error_code,
                            'error_message': error_message,
                            'record': record
                        })
                    else:
                        total_failed += 1
                        failed_records.append({
                            'index': idx,
                            'type': 'error',
                            'error_code': error_code,
                            'error_message': error_message,
                            'record': record
                        })
                        self.log.error(
                            f"表 {table_name} 插入记录失败",
                            index=idx,
                            error_code=error_code,
                            error_message=error_message,
                            record=record  # full record, for debugging
                        )
                    if not ignore_duplicates:
                        raise

            conn.commit()

            # 7. Summary, including failed rows.
            self.log.info(
                f"表 {table_name} 插入结果汇总",
                total_to_insert=total_to_insert,
                total_inserted=total_inserted,
                total_duplicates=total_duplicates,
                total_failed=total_failed,
                failed_records_count=len(failed_records)
            )

            if failed_records:
                self.log.error(
                    f"表 {table_name} 插入失败记录详情",
                    failed_records_summary=[
                        {
                            'index': r['index'],
                            'type': r['type'],
                            'error_code': r['error_code'],
                            'error_message': r['error_message']
                        }
                        for r in failed_records
                    ],
                    detailed_failed_records=failed_records
                )

            return total_inserted

        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception as rollback_error:
                    # Never let a failing rollback hide the original error.
                    self.log.warning("事务回滚失败", table=table_name, error=str(rollback_error))
            self.log.error(f"表 {table_name} 批量插入失败",
                           error=str(e),
                           error_type=type(e).__name__,
                           table_name=table_name,
                           total_records=len(df) if not df.empty else 0,
                           exc_info=True)
            if failed_records:
                self.log.error(
                    f"表 {table_name} 事务回滚,已失败的记录",
                    failed_records=failed_records,
                    failed_count=len(failed_records)
                )
            raise
        finally:
            if cursor is not None:
                cursor.close()
            if conn is not None and conn.open:
                conn.close()

    def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
        """Return the first column of *table_name*'s PRIMARY KEY, or None.

        For a composite key only the first column (by ordinal position) is
        returned. Lookup errors are logged and yield None.
        """
        try:
            cursor.execute("""
                           SELECT COLUMN_NAME
                           FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
                           WHERE TABLE_SCHEMA = %s
                             AND TABLE_NAME = %s
                             AND CONSTRAINT_NAME = 'PRIMARY'
                           ORDER BY ORDINAL_POSITION
                           """, (self.config['database'], table_name))
            result = cursor.fetchone()
            return result[0] if result else None
        except Exception as e:
            self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
            return None

    def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
        """Return ``{column_name: {'type': data_type, 'max_length': int | None}}`` for *table_name*.

        Returns an empty dict (and logs an error) when the table has no columns
        in the configured schema. Query errors are logged and re-raised.
        """
        sql = """
              SELECT column_name, data_type, character_maximum_length
              FROM information_schema.columns
              WHERE table_schema = %s
                AND table_name = %s
              """
        params = (self.config['database'], table_name)

        try:
            conn = self.get_connection()
            try:
                with conn.cursor() as cursor:
                    cursor.execute(sql, params)
                    # Materialise as a list regardless of the cursor class.
                    result_list = list(cursor.fetchall())
            finally:
                conn.close()

            if not result_list:
                self.log.error("未在表中找到任何列", 表=table_name)
                return {}

            schema = {}
            for row in result_list:
                col_name = str(row[0]).strip()
                data_type = str(row[1]).strip()
                max_length = row[2] if row[2] else None
                schema[col_name] = {
                    'type': data_type,
                    'max_length': max_length
                }

            self.log.debug("成功获取表结构信息", 表=table_name, 列=list(schema.keys()))
            return schema
        except Exception as e:
            self.log.error("获取表详细信息失败", 表=table_name, error=str(e))
            raise

    def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                                 table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
        """Return a cleaned copy of *df* that fits *table_schema*.

        - drops columns absent from the schema;
        - fills nulls with '' in varchar/char/text columns (other types keep nulls);
        - truncates varchar/char values longer than ``max_length``;
        - converts datetime/timestamp columns with ``pd.to_datetime``. If that
          fails, unparseable or null values are replaced by the current time.
        """
        # 1. Keep only columns present in the table.
        df_columns = df.columns.tolist()
        valid_columns = [col for col in df_columns if col in table_schema]
        invalid_columns = [col for col in df_columns if col not in table_schema]

        if invalid_columns:
            self.log.warning("丢弃表中不存在的无效列",
                             表=table_name,
                             无效列=invalid_columns,
                             数量=len(invalid_columns))

        cleaned_df = df[valid_columns].copy()
        if cleaned_df.empty:
            return cleaned_df

        # 2. Per-column cleaning.
        for col in valid_columns:
            col_info = table_schema[col]
            data_type = col_info['type']
            max_length = col_info['max_length']

            # 2.1 Nulls: only string-like columns get a default. fillna(None)
            # would raise, so nulls elsewhere are left untouched.
            null_count = int(cleaned_df[col].isnull().sum())
            default_value = '' if data_type in ['varchar', 'char', 'text'] else None
            if null_count and default_value is not None:
                # Assign back instead of chained inplace (ineffective under Copy-on-Write).
                cleaned_df[col] = cleaned_df[col].fillna(default_value)
                self.log.debug("替换空值",
                               表=table_name,
                               列=col,
                               默认值=default_value,
                               数量=null_count)

            # 2.2 Truncate over-long strings.
            if data_type in ['varchar', 'char'] and max_length:
                cleaned_df[col] = cleaned_df[col].astype(str)
                too_long_mask = cleaned_df[col].str.len() > max_length
                if too_long_mask.any():
                    cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
                    self.log.warning("截断超长值",
                                     表=table_name,
                                     列=col,
                                     最大长度=max_length,
                                     数量=too_long_mask.sum())

            # 2.3 Datetime columns.
            if data_type in ['datetime', 'timestamp']:
                try:
                    cleaned_df[col] = pd.to_datetime(cleaned_df[col])
                except Exception as e:
                    self.log.warning("转换为datetime失败,使用当前时间替代",
                                     表=table_name,
                                     列=col,
                                     错误=str(e))
                    # Keep every value that does parse; replace the rest with now().
                    converted = pd.to_datetime(cleaned_df[col], errors='coerce')
                    cleaned_df[col] = converted.where(converted.notna(), pd.Timestamp(datetime.now()))

        return cleaned_df

    def update_from_df(self, table_name: str, df: pd.DataFrame,
                       key_columns: Union[str, List[str]]) -> int:
        """Run one ``UPDATE ... SET ... WHERE key=...`` per DataFrame row.

        Args:
            table_name: table to update.
            df: rows holding both the key columns and the values to set. Only
                columns that exist in the table are used for SET.
            key_columns: column name(s) used in the WHERE clause.

        Returns:
            ``cursor.rowcount`` after ``executemany`` (0 if nothing to update).

        Raises:
            Exception: any error, after logging it.
        """
        if df.empty:
            self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
            return 0

        self.log.debug("准备从DataFrame更新表数据",
                       表=table_name,
                       关键字列=key_columns,
                       行数=len(df))

        try:
            if isinstance(key_columns, str):
                key_columns = [key_columns]

            updated_rows = 0
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    table_info = self._get_table_detailed_info(table_name)
                    columns = [col for col in df.columns if col in table_info]
                    set_columns = [col for col in columns if col not in key_columns]

                    if not set_columns:
                        self.log.warning("没有可更新的列", 表=table_name)
                        return 0

                    # Quote identifiers so reserved words (order, key, ...) work.
                    set_clause = ', '.join(f"{self._quote_identifier(col)}=%s" for col in set_columns)
                    where_clause = ' AND '.join(f"{self._quote_identifier(col)}=%s" for col in key_columns)
                    update_sql = (f"UPDATE {self._quote_identifier(table_name)} "
                                  f"SET {set_clause} WHERE {where_clause}")
                    self.log.trace("生成的更新SQL", sql=update_sql)

                    # SET values first, then key values, matching placeholder order.
                    # object dtype + where(): keeps ints as ints (iterrows would
                    # upcast) and turns NaN/NaT into None -> SQL NULL.
                    ordered = set_columns + list(key_columns)
                    frame = df[ordered]
                    values = frame.astype(object).where(frame.notna(), None)
                    update_data = list(values.itertuples(index=False, name=None))

                    cursor.executemany(update_sql, update_data)
                    updated_rows = cursor.rowcount

                conn.commit()

            self.log.info("数据更新成功", 表=table_name, 更新行数=updated_rows)
            return updated_rows
        except Exception as e:
            self.log.error("数据更新失败", 表=table_name, error=str(e), exc_info=True)
            raise

    # ------------------------------------------------------------------ #
    # DDL helpers
    # ------------------------------------------------------------------ #
    def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
        """Map each DataFrame column's dtype to a MySQL column type (unknown dtypes -> TEXT)."""
        type_mapping = {
            'int64': 'BIGINT',
            'float64': 'DOUBLE',
            'datetime64[ns]': 'DATETIME',
            'object': 'TEXT',
            'bool': 'TINYINT(1)',
            'category': 'VARCHAR(255)'
        }

        sql_types = {}
        for col, dtype in df.dtypes.items():
            dtype_str = str(dtype)
            sql_types[col] = type_mapping.get(dtype_str, 'TEXT')

        self.log.debug("将DataFrame类型映射为SQL类型", 映射关系=sql_types)
        return sql_types

    def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                             primary_key: Union[str, List[str], None] = None) -> bool:
        """Create *table_name* with one column per DataFrame column.

        Returns False (without raising) if the table already exists or creation
        fails; True on success. Primary-key names not present in *df* are ignored.
        """
        if self.table_exists(table_name):
            self.log.warning("表已存在", 表=table_name)
            return False

        self.log.debug("根据DataFrame结构创建新表", 表=table_name, 列=list(df.columns))

        try:
            sql_types = self.df_to_sql_type(df)
            columns_sql = [f"{self._quote_identifier(col)} {sql_type}"
                           for col, sql_type in sql_types.items()]

            if primary_key:
                if isinstance(primary_key, str):
                    primary_key = [primary_key]
                pk_columns = [col for col in primary_key if col in sql_types]
                if pk_columns:
                    pk_list = ', '.join(self._quote_identifier(col) for col in pk_columns)
                    columns_sql.append(f"PRIMARY KEY ({pk_list})")
                    self.log.trace("设置主键", 表=table_name, 主键=pk_columns)

            # Join outside the f-string: a backslash inside an f-string
            # expression is a SyntaxError before Python 3.12.
            column_block = ',\n '.join(columns_sql)
            create_sql = f"CREATE TABLE {self._quote_identifier(table_name)} (\n {column_block}\n)"
            self.execute_sql(create_sql)
            self.log.info("表创建成功", 表=table_name)
            return True
        except Exception as e:
            self.log.error("创建表失败", 表=table_name, error=str(e), exc_info=True)
            return False

    def create_table_if_not_exists(self, table_name: str, create_sql: str) -> bool:
        """
        Execute a caller-supplied CREATE TABLE statement.

        Args:
            table_name: table name, used only for logging.
            create_sql: full CREATE TABLE statement. It should contain
                ``IF NOT EXISTS``; a warning is logged if it does not, but the
                statement is still executed.

        Returns:
            True when the statement executed (including when the table already
            existed, given ``IF NOT EXISTS``).

        Raises:
            Exception: any execution/connection error, after logging it.
        """
        if "IF NOT EXISTS" not in create_sql.upper():
            self.log.warning(f"CREATE TABLE 语句建议使用 IF NOT EXISTS 以保证安全性")

        try:
            self.execute_sql(create_sql)
            self.log.info(f"成功创建/检查表(表已存在时不会删除数据): {table_name}")
            return True
        except Exception as e:
            self.log.error(f"创建/检查表失败(可能是数据库连接问题): {str(e)}", table=table_name, exc_info=True)
            raise

    def add_unique_index_if_not_exists(self, table_name: str, index_name: str, column_name: str,
                                       column_length: int = 500, check_duplicates: bool = True) -> bool:
        """
        Add a UNIQUE prefix index on ``column_name(column_length)`` unless it already exists.

        Args:
            table_name: table name.
            index_name: index name, checked in INFORMATION_SCHEMA.STATISTICS.
            column_name: indexed column (inserted into SQL unquoted).
            column_length: prefix length for the index (VARCHAR/TEXT columns).
            check_duplicates: if True, first look for a duplicated non-NULL,
                non-empty value and skip adding the index if one exists.

        Returns:
            True if the index exists or was added; False if duplicates prevented it.

        Raises:
            Exception: any other error, after logging a warning.

        Only adds an index; never deletes rows.
        """
        try:
            # 1. Does the index already exist?
            check_index_sql = f"""
                SELECT COUNT(*) as cnt 
                FROM INFORMATION_SCHEMA.STATISTICS 
                WHERE TABLE_SCHEMA = %s 
                AND TABLE_NAME = %s 
                AND INDEX_NAME = %s
            """
            result = self.query_to_df(
                check_index_sql,
                params=(self.config['database'], table_name, index_name),
                is_print=False
            )

            if not result.empty and result['cnt'].iloc[0] > 0:
                self.log.debug(f"唯一索引 {index_name} 已存在,跳过添加")
                return True

            # 2. Optionally look for duplicates first.
            if check_duplicates:
                check_duplicates_sql = f"""
                    SELECT {column_name}, COUNT(*) as cnt 
                    FROM `{table_name}` 
                    WHERE {column_name} IS NOT NULL AND {column_name} != ''
                    GROUP BY {column_name} 
                    HAVING cnt > 1 
                    LIMIT 1
                """
                duplicates = self.query_to_df(check_duplicates_sql, is_print=False)
                if not duplicates.empty:
                    self.log.warning(
                        f"表 {table_name} 中存在重复的 {column_name} 数据,无法添加唯一索引。"
                        "现有数据不会被删除。",
                        duplicate_count=len(duplicates)
                    )
                    return False

            # 3. Add the index.
            add_index_sql = f"""
                ALTER TABLE `{table_name}` 
                ADD UNIQUE KEY `{index_name}` ({column_name}({column_length}))
            """
            self.execute_sql(add_index_sql)
            self.log.info(f"成功添加唯一索引 {index_name}(现有数据不受影响)")
            return True

        except Exception as e:
            error_msg = str(e)
            # A concurrent creator may have added the index in the meantime.
            if "Duplicate key name" in error_msg or "already exists" in error_msg.lower():
                self.log.debug(f"唯一索引 {index_name} 已存在,跳过添加")
                return True
            else:
                self.log.warning(f"添加唯一索引时出现问题(不影响现有数据): {error_msg}")
                raise

    # ------------------------------------------------------------------ #
    # Generic execution
    # ------------------------------------------------------------------ #
    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one statement on a fresh connection.

        Returns ``cursor.fetchall()`` when *fetch* is True, otherwise the affected
        row count (after an explicit commit). Errors are logged and re-raised.
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    # Non-Windows: raise the session statement time limit to 600 s.
                    # NOTE(review): max_execution_time is a MySQL (5.7.8+) variable;
                    # confirm the target server supports it.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION max_execution_time=600000")

                    cursor.execute(sql, params)

                    if fetch:
                        result = cursor.fetchall()
                        self.log.debug("查询执行完成", 行数=len(result))
                        return result
                    else:
                        affected_rows = cursor.rowcount
                        conn.commit()
                        self.log.debug("更新执行完成", 受影响行数=affected_rows)
                        return affected_rows
        except Exception as e:
            self.log.error("SQL执行失败",
                           sql=sql,
                           params=params,
                           error=str(e),
                           error_type=type(e).__name__,
                           exc_info=True)
            raise

    def table_exists(self, table_name: str) -> bool:
        """Return True if *table_name* exists in the configured database.

        Any error (including connection failure) is swallowed and reported as False.
        """
        sql = """
              SELECT COUNT(*) as count
              FROM `information_schema`.`tables`
              WHERE `table_schema` = %s
                AND `table_name` = %s
              """
        params = (self.config['database'], table_name)

        try:
            result = self.execute_sql(sql, params, fetch=True)
            exists = result[0][0] > 0  # default cursor returns tuples
            self.log.debug("检查表是否存在", 表=table_name, 存在=exists)
            return exists
        except Exception:
            return False

    def drop_table(self, table_name: str) -> bool:
        """Drop *table_name*. Returns False if it does not exist or the drop fails."""
        if not self.table_exists(table_name):
            self.log.warning("表不存在", 表=table_name)
            return False

        try:
            self.execute_sql(f"DROP TABLE {self._quote_identifier(table_name)}")
            self.log.info("表删除成功", 表=table_name)
            return True
        except Exception as e:
            self.log.error("删除表失败", 表=table_name, error=str(e), exc_info=True)
            return False

    def validate_connection(self) -> bool:
        """Return True if ``SELECT 1`` succeeds on a fresh connection; never raises."""
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cursor:
                    cursor.execute("SELECT 1")
                    return cursor.fetchone()[0] == 1
        except Exception:
            return False


def get_default_config():
    """Return a default local connection config with per-OS timeouts (macOS also sets an SSL CA path)."""
    current_platform = platform.system()
    # NOTE(review): hard-coded local development credentials; do not use in production.
    base_config = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    if current_platform == 'Windows':
        return {
            **base_config,
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30
        }
    elif current_platform == 'Darwin':
        return {
            **base_config,
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        }
    else:  # Linux and any other platform
        return {
            **base_config,
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60
        }


if __name__ == "__main__":
    # Usage example: connect with the default config and print the server version.
    db = MySQLAgent(get_default_config())

    if db.validate_connection():
        print("数据库连接成功")
        version = db.query_to_df("SELECT VERSION() as version")
        print(f"数据库版本: {version['version'].iloc[0]}")
    else:
        print("连接数据库失败")