mysql数据链接更新

This commit is contained in:
z66
2025-09-19 13:57:55 +08:00
parent 20fd9587ee
commit 8bf6bfe2fb
5 changed files with 560 additions and 112 deletions
+139 -99
View File
@@ -65,7 +65,7 @@ class MySQLAgent:
def get_connection(self) -> pymysql.connections.Connection:
"""获取数据库连接(原有逻辑完全保留)"""
try:
conn = pymysql.connect(**self.config)
conn = pymysql.connect(** self.config)
# 为连接添加 character_set_name 方法
if not hasattr(conn, 'character_set_name'):
@@ -78,16 +78,16 @@ class MySQLAgent:
if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True)
self.log.trace("Database connection obtained")
self.log.trace("获取数据库连接成功")
return conn
except Exception as e:
error_msg = str(e)
if platform.system() == 'Windows' and "timed out" in error_msg:
self.log.warning("Windows connection timeout, retrying...")
self.log.warning("Windows连接超时,正在重试...")
return self._retry_connection()
self.log.error("Connection failed", error=error_msg, exc_info=True)
self.log.error("连接失败", error=error_msg, exc_info=True)
raise
def _retry_connection(self, max_retries: int = 3) -> Any | None:
@@ -95,7 +95,7 @@ class MySQLAgent:
for attempt in range(max_retries):
try:
conn = pymysql.connect(**self.config)
self.log.info(f"Connection established after {attempt + 1} attempts")
self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
return conn
except Exception:
if attempt == max_retries - 1:
@@ -107,7 +107,7 @@ class MySQLAgent:
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
"""执行SQL查询并返回DataFrame(原有逻辑完全保留)"""
try:
self.log.debug("Executing SQL query", sql=sql)
self.log.debug("执行SQL查询", sql=sql)
# 获取连接并确保字符集方法存在
conn = self.get_connection()
@@ -124,49 +124,50 @@ class MySQLAgent:
# 执行查询
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
self.log.info("Query executed successfully", rows=len(df))
self.log.info("查询执行成功", 行数=len(df))
return df
except Exception as e:
self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
self.log.error("SQL查询失败", sql=sql, params=params, error=str(e), exc_info=True)
raise
finally:
if 'engine' in locals():
engine.dispose()
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False, # 保留replace参数
ignore_duplicates: bool = None) -> int: # 新增ignore_duplicates参数
chunk_size: int = 1000, replace: bool = False,
ignore_duplicates: bool = None) -> int:
"""
兼容旧接口的通用插入方法:保留replace参数,同时支持新的ignore_duplicates
自动处理重复数据,对所有数据源通用
自动处理重复数据,对所有数据源通用,插入失败的数据会通过日志记录
"""
# 【兼容性处理】如果未指定ignore_duplicates,用replace参数推导(replace=True时不忽略重复)
# 【兼容性处理】如果未指定ignore_duplicates,用replace参数推导
if ignore_duplicates is None:
ignore_duplicates = not replace # 旧逻辑中replace=True表示替换,即不忽略重复
if df.empty:
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
self.log.warning("尝试插入空的DataFrame", table=table_name)
return 0
conn = None
cursor = None
total_inserted = 0
total_duplicated = 0
total_duplicates = 0
total_failed = 0
failed_records = [] # 存储所有失败的记录
try:
# 1. 建立数据库连接
conn = self.get_connection()
cursor = conn.cursor()
self.log.debug(f"Established connection for inserting into {table_name}")
self.log.debug(f"已建立连接,准备插入数据到 {table_name}")
# 2. 获取数据库表的实际列名
cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
columns_info = cursor.fetchall()
db_columns = [col[0] for col in columns_info]
self.log.debug(f"Table {table_name} has columns: {db_columns}")
self.log.debug(f"表 {table_name} 包含以下列:{db_columns}")
# 3. 数据预处理:统一处理空值
cleaned_df = df.replace(
@@ -181,19 +182,19 @@ class MySQLAgent:
if unmatched_columns:
self.log.warning(
f"Table {table_name} dropping unmatched columns",
f"表 {table_name} 中存在不匹配的列,已自动丢弃",
unmatched_columns=unmatched_columns,
count=len(unmatched_columns)
)
if not matched_columns:
self.log.warning(f"No matched columns for {table_name}, abort insertion")
self.log.warning(f"{table_name} 没有匹配的列,终止插入操作")
return 0
filtered_df = cleaned_df[matched_columns].copy()
total_to_insert = len(filtered_df)
self.log.debug(
f"Filtered DataFrame for {table_name}: {total_to_insert} rows to insert"
f"{table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
)
# 5. 处理复杂类型(dict/list转JSON)
@@ -203,7 +204,7 @@ class MySQLAgent:
).any()
if has_complex_type:
self.log.debug(f"Column {col} in {table_name} has complex type, converting to JSON")
self.log.debug(f"表 {table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
filtered_df.loc[:, col] = filtered_df[col].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
)
@@ -212,7 +213,7 @@ class MySQLAgent:
columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
placeholders = ', '.join(['%s'] * len(filtered_df.columns))
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
self.log.trace(f"Generated insert SQL for {table_name}: {insert_sql}")
self.log.trace(f"为表 {table_name} 生成的插入SQL:{insert_sql}")
# 7. 逐条插入(确保能捕获单条重复错误)
records = filtered_df.to_dict('records')
@@ -226,34 +227,50 @@ class MySQLAgent:
if (i + 1) % 100 == 0:
self.log.trace(
f"Inserted {i + 1}/{total_to_insert} rows into {table_name}"
f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
)
except MySQLError as e:
# 8. 捕获重复错误(MySQL错误码1062)
if e.args[0] == 1062:
total_duplicated += 1
total_duplicates += 1
short_record = {
k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
for k, v in record.items()
}
self.log.warning(
f"Skipped duplicate record in {table_name}",
f"{table_name} 中跳过重复记录",
index=idx,
error_msg=e.args[1],
error_message=e.args[1],
record=short_record
)
# 记录重复的记录
failed_records.append({
'index': idx,
'type': 'duplicate',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
if not ignore_duplicates:
raise
else:
# 其他数据库错误
total_failed += 1
# 记录失败的记录详情
failed_records.append({
'index': idx,
'type': 'error',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
self.log.error(
f"Failed to insert record in {table_name}",
f"{table_name} 插入记录失败",
index=idx,
error_code=e.args[0],
error_msg=e.args[1],
record=record
error_message=e.args[1],
record=record # 完整记录写入日志
)
if not ignore_duplicates:
raise
@@ -261,21 +278,44 @@ class MySQLAgent:
# 提交事务
conn.commit()
# 9. 插入结果统计
# 9. 插入结果统计,包括失败记录汇总
self.log.info(
f"Insertion summary for {table_name}",
f"{table_name} 插入结果汇总",
total_to_insert=total_to_insert,
total_inserted=total_inserted,
total_duplicated=total_duplicated,
total_failed=total_failed
total_duplicates=total_duplicates,
total_failed=total_failed,
failed_records_count=len(failed_records)
)
# 单独记录所有失败的数据详情
if failed_records:
self.log.error(
f"{table_name} 插入失败记录详情",
failed_records_summary=[
{
'index': r['index'],
'type': r['type'],
'error_code': r['error_code'],
'error_message': r['error_message']
} for r in failed_records
],
# 完整记录可以作为调试信息单独记录,避免日志过大
detailed_failed_records=failed_records
)
return total_inserted
except Exception as e:
if conn:
conn.rollback()
self.log.error(f"Batch insertion failed for {table_name}", error=str(e), exc_info=True)
self.log.error(f"{table_name} 批量插入失败", error=str(e), exc_info=True)
# 记录事务回滚时的失败记录
if failed_records:
self.log.error(
f"{table_name} 事务回滚,已失败的记录",
failed_records=failed_records
)
raise
finally:
if cursor:
@@ -296,7 +336,7 @@ class MySQLAgent:
result = cursor.fetchone()
return result[0] if result else None
except Exception as e:
self.log.warning(f"Failed to get primary key for {table_name}", error=str(e))
self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
return None
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
@@ -319,7 +359,7 @@ class MySQLAgent:
# 强制转换为列表,避免游标类型导致的解析问题
result_list = list(result)
if not result_list:
self.log.error("No columns found in table", table=table_name)
self.log.error("未在表中找到任何列", 表=table_name)
return {}
schema = {}
@@ -334,16 +374,16 @@ class MySQLAgent:
'max_length': max_length
}
self.log.debug("Successfully fetched table schema",
table=table_name,
columns=list(schema.keys()))
self.log.debug("成功获取表结构信息",
表=table_name,
列=list(schema.keys()))
return schema
finally:
cursor.close()
conn.close()
except Exception as e:
self.log.error("Failed to get table detailed info",
table=table_name,
self.log.error("获取表详细信息失败",
表=table_name,
error=str(e))
raise
@@ -358,10 +398,10 @@ class MySQLAgent:
invalid_columns = [col for col in df_columns if col not in table_columns]
if invalid_columns:
self.log.warning("Dropping invalid columns not present in table",
table=table_name,
invalid_columns=invalid_columns,
count=len(invalid_columns))
self.log.warning("丢弃表中不存在的无效列",
表=table_name,
无效列=invalid_columns,
数量=len(invalid_columns))
cleaned_df = df[valid_columns].copy()
if cleaned_df.empty:
@@ -378,11 +418,11 @@ class MySQLAgent:
# 根据字段类型设置默认值
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
cleaned_df[col].fillna(default_value, inplace=True)
self.log.debug("Replaced null values",
table=table_name,
column=col,
default_value=default_value,
count=cleaned_df[col].isnull().sum())
self.log.debug("替换空值",
表=table_name,
列=col,
默认值=default_value,
数量=cleaned_df[col].isnull().sum())
# 2.2 处理字符串类型的超长字段
if data_type in ['varchar', 'char'] and max_length:
@@ -392,11 +432,11 @@ class MySQLAgent:
too_long_mask = cleaned_df[col].str.len() > max_length
if too_long_mask.any():
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
self.log.warning("Truncated overlength values",
table=table_name,
column=col,
max_length=max_length,
count=too_long_mask.sum())
self.log.warning("截断超长值",
表=table_name,
列=col,
最大长度=max_length,
数量=too_long_mask.sum())
# 2.3 处理日期时间类型
if data_type in ['datetime', 'timestamp']:
@@ -404,10 +444,10 @@ class MySQLAgent:
# 尝试转换为datetime类型
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
except Exception as e:
self.log.warning("Failed to convert to datetime, using current time",
table=table_name,
column=col,
error=str(e))
self.log.warning("转换为datetime失败,使用当前时间替代",
表=table_name,
列=col,
错误=str(e))
# 转换失败的用当前时间替代
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
cleaned_df.loc[invalid_mask, col] = datetime.now()
@@ -418,19 +458,19 @@ class MySQLAgent:
key_columns: Union[str, List[str]]) -> int:
"""使用DataFrame数据更新数据库表(原有逻辑完全保留)"""
if df.empty:
self.log.warning("Attempted to update with empty DataFrame", table=table_name)
self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
return 0
self.log.debug("Preparing to update table from DataFrame",
table=table_name,
key_columns=key_columns,
rows=len(df))
self.log.debug("准备从DataFrame更新表数据",
表=table_name,
关键字列=key_columns,
行数=len(df))
try:
if isinstance(key_columns, str):
key_columns = [key_columns]
total_updated = 0
总更新数 = 0
with self.get_connection() as conn:
with conn.cursor() as cursor:
# 获取表结构信息
@@ -442,11 +482,11 @@ class MySQLAgent:
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
if not set_clause:
self.log.warning("No columns to update", table=table_name)
self.log.warning("没有可更新的列", 表=table_name)
return 0
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
self.log.trace("Generated update SQL", sql=update_sql)
self.log.trace("生成的更新SQL", sql=update_sql)
# 准备数据
update_data = []
@@ -457,17 +497,17 @@ class MySQLAgent:
# 执行批量更新
cursor.executemany(update_sql, update_data)
total_updated = cursor.rowcount
总更新数 = cursor.rowcount
conn.commit()
self.log.info("Data updated successfully",
table=table_name,
rows_updated=total_updated)
return total_updated
self.log.info("数据更新成功",
表=table_name,
更新行数=总更新数)
return 总更新数
except Exception as e:
self.log.error("Data update failed",
table=table_name,
self.log.error("数据更新失败",
表=table_name,
error=str(e),
exc_info=True)
raise
@@ -488,20 +528,20 @@ class MySQLAgent:
dtype_str = str(dtype)
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
self.log.debug("Mapped DataFrame types to SQL types",
mappings=sql_types)
self.log.debug("DataFrame类型映射为SQL类型",
映射关系=sql_types)
return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
primary_key: Union[str, List[str], None] = None) -> bool:
"""根据DataFrame结构创建表(原有逻辑完全保留)"""
if self.table_exists(table_name):
self.log.warning("Table already exists", table=table_name)
self.log.warning("表已存在", 表=table_name)
return False
self.log.debug("Creating new table from DataFrame schema",
table=table_name,
columns=list(df.columns))
self.log.debug("根据DataFrame结构创建新表",
表=table_name,
列=list(df.columns))
try:
sql_types = self.df_to_sql_type(df)
@@ -517,19 +557,19 @@ class MySQLAgent:
pk_columns = [col for col in primary_key if col in sql_types]
if pk_columns:
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
self.log.trace("Set primary key",
table=table_name,
primary_key=pk_columns)
self.log.trace("设置主键",
表=table_name,
主键=pk_columns)
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
self.execute_sql(create_sql)
self.log.info("Table created successfully", table=table_name)
self.log.info("表创建成功", 表=table_name)
return True
except Exception as e:
self.log.error("Failed to create table",
table=table_name,
self.log.error("创建表失败",
表=table_name,
error=str(e),
exc_info=True)
return False
@@ -548,16 +588,16 @@ class MySQLAgent:
if fetch:
result = cursor.fetchall()
self.log.debug("Query executed", rows=len(result))
self.log.debug("查询执行完成", 行数=len(result))
return result
else:
affected_rows = cursor.rowcount
conn.commit() # 立即提交
self.log.debug("Update executed", affected_rows=affected_rows)
self.log.debug("更新执行完成", 受影响行数=affected_rows)
return affected_rows
except Exception as e:
self.log.error("SQL execution failed",
self.log.error("SQL执行失败",
sql=sql,
params=params,
error=str(e),
@@ -578,9 +618,9 @@ class MySQLAgent:
try:
result = self.execute_sql(sql, params, fetch=True)
exists = result[0][0] > 0 # 适配元组结果
self.log.debug("Checked table existence",
table=table_name,
exists=exists)
self.log.debug("检查表是否存在",
表=table_name,
存在=exists)
return exists
except Exception:
return False
@@ -588,16 +628,16 @@ class MySQLAgent:
def drop_table(self, table_name: str) -> bool:
"""删除表(原有逻辑完全保留)"""
if not self.table_exists(table_name):
self.log.warning("Table does not exist", table=table_name)
self.log.warning("表不存在", 表=table_name)
return False
try:
self.execute_sql(f"DROP TABLE {table_name}")
self.log.info("Table dropped successfully", table=table_name)
self.log.info("表删除成功", 表=table_name)
return True
except Exception as e:
self.log.error("Failed to drop table",
table=table_name,
self.log.error("删除表失败",
表=table_name,
error=str(e),
exc_info=True)
return False
@@ -641,7 +681,7 @@ def get_default_config():
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
else: # Linux和其他平台
return {**base_config,
return {** base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60
@@ -654,10 +694,10 @@ if __name__ == "__main__":
# 测试连接
if db.validate_connection():
print("Database connection successful")
print("数据库连接成功")
# 获取数据库版本
version = db.query_to_df("SELECT VERSION() as version")
print(f"Database version: {version['version'].iloc[0]}")
print(f"数据库版本: {version['version'].iloc[0]}")
else:
print("Failed to connect to database")
print("连接数据库失败")