mysql数据链接更新

This commit is contained in:
z66
2025-09-18 17:03:24 +08:00
parent 9afa9d2e58
commit 20fd9587ee
5 changed files with 188086 additions and 726 deletions
+149 -116
View File
@@ -6,72 +6,108 @@ import os
import pickle
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import pymysql
from loguru import logger
from utils.mysql_agent import MySQLAgent
from typing import Dict, List, Optional, Any
# 数据库连接信息
# 数据库连接配置
local_DB_Config = {
'host': "localhost",
'user': "root",
'password': "123123",
'database': "intelligence_system",
'charset': 'utf8mb4'
'port': 3306,
'charset': 'utf8mb4',
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30,
'autocommit': True
}
# 表名
# 目标数据表名
table_name = "collector_rss_subscriptions"
def verify_database():
"""验证数据库连接和表结构"""
class NewsAPIClient:
"""新闻API客户端,用于获取和处理RSS源数据并写入数据库"""
def __init__(self):
    """Bind a logger tagged for this client and create the MySQL agent.

    The agent is constructed from the module-level ``local_DB_Config``
    dictionary; a completion message is logged afterwards.
    """
    tagged_logger = logger.bind(module="NewsAPIClient")
    self.logger = tagged_logger
    agent = MySQLAgent(local_DB_Config)
    self.db_agent = agent
    tagged_logger.info("新闻API客户端初始化完成,已连接到数据库")
def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
    """Wrap an operation outcome in the uniform result dictionary.

    Args:
        success: Outcome flag; coerced with ``bool``.
        message: Description of the outcome; coerced with ``str``.
        data: Optional payload, passed through unchanged.

    Returns:
        A dict with the keys ``success``, ``message`` and ``data``,
        in that order.
    """
    outcome = bool(success)
    text = str(message)
    return dict(success=outcome, message=text, data=data)
def verify_database(self) -> bool:
"""验证数据库表结构是否符合要求(适配元组格式的查询结果)"""
try:
conn = pymysql.connect(**local_DB_Config)
with conn.cursor() as cursor:
# 检查表是否存在
cursor.execute(f"SHOW TABLES LIKE '{table_name}'")
if not cursor.fetchone():
print(f"错误: 表 {table_name} 不存在!")
# 1. 检查表是否存在(execute_sql返回元组列表,如 [(table_name,)]
result = self.db_agent.execute_sql(
f"SHOW TABLES LIKE '{table_name}'",
fetch=True
)
# 元组结果需通过索引0判断(若表存在,result是[(table_name,)], 否则为空列表)
if not result:
self.logger.error(f"{table_name} 不存在,请先创建表结构")
return False
# 检查表结构
cursor.execute(f"DESCRIBE {table_name}")
columns = [col[0] for col in cursor.fetchall()]
print("表列名:", columns)
# 2. 检查表字段是否完整(DESCRIBE返回的元组格式:(字段名, 类型, 是否为空, ...))
desc_result = self.db_agent.execute_sql(
f"DESCRIBE {table_name}",
fetch=True
)
# 关键修改:用元组索引0提取字段名(而非字典键'Field'
columns = [col[0] for col in desc_result] # col是元组,col[0]即字段名
required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
'来源URL', '创建时间', '更新时间']
missing_cols = [col for col in required_columns if col not in columns]
# 检查插入权限
test_sql = f"""INSERT INTO `{table_name}`
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)
VALUES (%s, %s, %s, %s, %s)"""
cursor.execute(test_sql, ('测试标题', 'http://test.com', '测试内容', datetime.now(), '测试来源'))
conn.rollback()
if missing_cols:
self.logger.error(f"{table_name} 缺少必要字段:{missing_cols}")
return False
print("数据库验证通过!")
self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
return True
except Exception as e:
print("数据库验证失败:", e)
self.logger.error(f"数据库验证失败: {str(e)}", exc_info=True)
return False
finally:
if 'conn' in locals():
conn.close()
def load_last_update_time():
"""加载上次更新的时间"""
def load_last_update_time(self) -> Optional[datetime]:
"""加载上次更新时间缓存"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
if os.path.exists(cache_file):
try:
with open(cache_file, 'rb') as f:
return pickle.load(f)
last_update = pickle.load(f)
self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
return last_update
except Exception as e:
self.logger.error(f"加载上次更新时间失败: {str(e)}", exc_info=True)
self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
return None
def save_last_update_time(self, last_update: datetime) -> None:
"""保存本次更新时间"""
try:
cache_dir = os.path.join(os.getcwd(), 'output')
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, 'last_update.pkl')
def save_last_update_time(last_update):
"""保存本次更新的时间"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
with open(cache_file, 'wb') as f:
pickle.dump(last_update, f)
self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
except Exception as e:
self.logger.error(f"保存更新时间失败: {str(e)}", exc_info=True)
def fetch_single_rss(url, timeout=15):
def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional[feedparser.FeedParserDict]:
"""获取并解析单个RSS源"""
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
@@ -85,24 +121,25 @@ def fetch_single_rss(url, timeout=15):
feed = feedparser.parse(response.text)
if feed.bozo:
print(f"警告: 解析可能存在问题: {feed.bozo_exception}")
self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")
self.logger.debug(f"成功获取 {url} 的RSS数据")
return feed
except requests.RequestException as e:
print(f"{attempt + 1}尝试获取 {url} 失败: {e}")
self.logger.warning(f"{attempt + 1} 次获取 {url} 失败: {str(e)}")
if attempt < 2:
time.sleep(5 * (attempt + 1))
time.sleep(3 * (attempt + 1)) # 指数退避重试
continue
self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
return None
def fetch_all_rss(urls, timeout=15):
"""使用线程池并发获取多个RSS源"""
def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> Dict[str, feedparser.FeedParserDict]:
"""并发获取多个RSS源"""
feeds = {}
with ThreadPoolExecutor(max_workers=3) as executor:
future_to_url = {executor.submit(fetch_single_rss, url, timeout): url for url in urls}
future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
for future in as_completed(future_to_url):
url = future_to_url[future]
@@ -111,13 +148,13 @@ def fetch_all_rss(urls, timeout=15):
if feed:
feeds[url] = feed
except Exception as e:
print(f"获取 {url} 时发生异常: {e}")
self.logger.error(f"处理 {url} 时发生异常: {str(e)}", exc_info=True)
self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
return feeds
def process_feed_entry(entry, url):
"""处理单个RSS条目并返回结构化数据"""
def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
"""处理单个RSS条目,转换为数据库兼容格式"""
# 处理标题
title = entry.get('title', '无标题') or '无标题'
if len(title) > 255:
@@ -131,7 +168,7 @@ def process_feed_entry(entry, url):
# 处理摘要
summary = entry.get('summary', '无内容摘要')
content_list = entry.get('content', [])
content = content_list[0].value if content_list else ''
content = content_list[0].value if (content_list and hasattr(content_list[0], 'value')) else ''
description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
# 处理发布时间
@@ -141,7 +178,7 @@ def process_feed_entry(entry, url):
else:
pub_str = entry.get('published', entry.get('updated', ''))
try:
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z')
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
except:
entry_time = datetime.now()
@@ -150,100 +187,89 @@ def process_feed_entry(entry, url):
if len(source_url) > 1024:
source_url = source_url[:1021] + '...'
# 当前时间(创建/更新时间)
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
return {
'文章标题': title,
'文章链接': link,
'文章摘要': description,
'发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
'来源URL': source_url
'来源URL': source_url,
'创建时间': current_time,
'更新时间': current_time
}
def display_feed_info(feed, last_update=None, url=None):
"""处理并显示RSS源信息"""
def display_feed_info(self, feed: feedparser.FeedParserDict, last_update: Optional[datetime] = None,
url: Optional[str] = None) -> Optional[datetime]:
"""处理RSS源信息并写入数据库"""
if not feed:
print("无法显示信息:feed 为 None")
self.logger.warning("无法处理空的RSS源数据")
return None
print("=" * 80)
print(f"处理 RSS 源: {url}")
self.logger.info(f"开始处理 RSS 源: {url}")
entries = feed.entries
data_list = []
new_last_update = last_update
for i, entry in enumerate(entries, 1):
entry_data = process_feed_entry(entry, url)
entry_data = self.process_feed_entry(entry, url)
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
# 过滤旧数据
if last_update and entry_time <= last_update:
continue
# 更新最新时间戳
if new_last_update is None or entry_time > new_last_update:
new_last_update = entry_time
print(f"\n--- 条目 {i} ---")
print(f"标题: {entry_data['文章标题']}")
print(f"链接: {entry_data['文章链接']}")
print(f"摘要: {entry_data['文章摘要'][:100]}...")
print(f"时间: {entry_data['发布时间']}")
self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
data_list.append(entry_data)
# 写入数据库
if data_list:
df = pd.DataFrame(data_list)
write_to_database(df)
self.write_to_database(df)
return new_last_update
def write_to_database(df):
"""将数据写入数据库"""
# news_api.py 中的 write_to_database 方法可以保持简洁
def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
if df.empty:
print("没有新数据需要写入")
return
print("\n准备写入数据库的数据样例:")
print(df.iloc[0].to_dict())
self.logger.info("没有新数据需要写入数据库")
return self._format_result(True, "没有新数据需要写入")
try:
conn = pymysql.connect(**local_DB_Config)
with conn.cursor() as cursor:
sql = f"""INSERT IGNORE INTO `{table_name}`
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)
VALUES (%s, %s, %s, %s, %s)"""
inserted_rows = self.db_agent.insert_from_df(
table_name=table_name,
df=df,
chunk_size=500,
replace=False
)
success_count = 0
for _, row in df.iterrows():
self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
return self._format_result(
True,
f"成功写入 {inserted_rows}/{len(df)} 条记录",
{"success_count": inserted_rows, "total": len(df)}
)
except Exception as e:
self.logger.error(f"数据库写入失败: {str(e)}", exc_info=True)
return self._format_result(False, f"数据库操作失败: {str(e)}")
@classmethod
def main(cls):
"""主函数入口"""
try:
cursor.execute(sql, (
row['文章标题'],
row['文章链接'],
row['文章摘要'],
row['发布时间'],
row['来源URL']
))
success_count += cursor.rowcount
except Exception as e:
print(f"插入记录时出错: {e}")
print(f"问题数据: {row.to_dict()}")
continue
client = cls()
conn.commit()
print(f"成功写入 {success_count}/{len(df)} 条记录")
except Exception as e:
print("数据库操作失败:", e)
finally:
if 'conn' in locals():
conn.close()
def main():
"""主函数"""
if not verify_database():
print("数据库验证失败,程序终止")
# 验证数据库
if not client.verify_database():
client.logger.error("数据库验证失败,程序终止")
return
# RSS源列表
rss_urls = [
"https://www.chinanews.com.cn/rss/finance.xml",
"https://www.chinanews.com.cn/rss/world.xml",
@@ -251,27 +277,34 @@ def main():
"https://www.chinanews.com.cn/rss/scroll-news.xml"
]
last_update = load_last_update_time()
# 加载上次更新时间
last_update = client.load_last_update_time()
if last_update:
print(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
print("\n开始获取RSS数据...")
# 获取RSS数据
client.logger.info("开始获取RSS源数据...")
start_time = time.time()
feeds = fetch_all_rss(rss_urls)
print(f"获取完成,耗时: {time.time() - start_time:.2f}")
feeds = client.fetch_all_rss(rss_urls)
client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}")
# 处理并写入数据
new_last_update = None
for url, feed in feeds.items():
current_last_update = display_feed_info(feed, last_update, url)
current_last_update = client.display_feed_info(feed, last_update, url)
if current_last_update and (new_last_update is None or current_last_update > new_last_update):
new_last_update = current_last_update
# 保存最新更新时间
if new_last_update:
save_last_update_time(new_last_update)
print(f"\n本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
client.save_last_update_time(new_last_update)
client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
else:
print("\n没有获取到新内容")
client.logger.info("没有获取到新内容")
except Exception as e:
logger.error(f"程序运行出错: {str(e)}", exc_info=True)
if __name__ == "__main__":
main()
NewsAPIClient.main()
+123849
View File
File diff suppressed because it is too large Load Diff
+63455
View File
File diff suppressed because it is too large Load Diff
+101 -111
View File
@@ -3,18 +3,20 @@ import pandas as pd
from datetime import datetime
import time
import pymysql
from utils.mysql_agent import MySQLAgent
import platform
from concurrent.futures import ThreadPoolExecutor
from utils.mysql_agent import MySQLAgent
class TestMySQLAgent(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""初始化测试环境和测试表"""
# 创建唯一的测试数据库
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
# 创建唯一的测试数据库和表名(避免冲突)
cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
# 基础配置
# 基础配置(根据实际环境修改)
cls.base_config = {
'host': 'localhost',
'port': 3306,
@@ -32,21 +34,19 @@ class TestMySQLAgent(unittest.TestCase):
'database': cls.test_db_name
})
# 创建测试表
# 创建测试表并插入初始数据
test_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
})
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data)
@classmethod
def _create_test_database(cls):
"""创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
@@ -54,7 +54,6 @@ class TestMySQLAgent(unittest.TestCase):
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -66,21 +65,14 @@ class TestMySQLAgent(unittest.TestCase):
@classmethod
def tearDownClass(cls):
"""清理测试数据库"""
"""清理测试环境"""
if hasattr(cls, 'db') and cls.db:
# 删除测试表
if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table)
# 删除测试数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -88,22 +80,24 @@ class TestMySQLAgent(unittest.TestCase):
finally:
temp_conn.close()
def test_01_connection(self):
def test_connection(self):
"""测试数据库连接"""
version = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nDatabase version: {version['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
version_df = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version_df)
self.assertEqual(len(version_df), 1)
print(f"数据库版本: {version_df['version'].iloc[0]}")
def test_02_query_to_df(self):
def test_query_to_df(self):
"""测试查询返回DataFrame"""
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
self.assertEqual(len(df), 2)
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id > %s",
params=(1,)
)
self.assertIsInstance(df, pd.DataFrame)
print("\nQuery result sample:")
print(df.head())
self.assertEqual(len(df), 2) # id>1 的数据有2条
self.assertIn('name', df.columns)
def test_03_insert_from_df(self):
def test_insert_from_df(self):
"""测试DataFrame插入"""
new_data = pd.DataFrame({
'id': [4, 5],
@@ -112,55 +106,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
})
rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2)
inserted_rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(inserted_rows, 2)
# 验证数据
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
self.assertEqual(len(df), 2)
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
# 验证插入结果
result_df = self.db.query_to_df(
f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
)
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
def test_04_update_from_df(self):
def test_update_from_df(self):
"""测试DataFrame更新"""
update_data = pd.DataFrame({
'id': [1, 2],
'name': ['Updated1', 'Updated2']
})
rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2)
updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(updated_rows, 2)
# 验证更新
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn('Updated1', df['name'].values)
self.assertIn('Updated2', df['name'].values)
# 验证更新结果
result_df = self.db.query_to_df(
f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
)
self.assertIn('Updated1', result_df['name'].values)
self.assertIn('Updated2', result_df['name'].values)
def test_05_transaction(self):
def test_transaction(self):
"""测试事务处理"""
conn = self.db.begin_transaction()
try:
# 执行多个操作
# 执行事务内操作
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn)
except Exception:
self.db.rollback_transaction(conn)
raise
# 验证提交后的修改
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn(99.9, df['value'].values)
self.assertIn(88.8, df['value'].values)
# 验证事务提交结果
result_df = self.db.query_to_df(
f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
)
self.assertIn(99.9, result_df['value'].values)
self.assertIn(88.8, result_df['value'].values)
def test_06_large_data(self):
"""测试大数据量操作"""
# 生成测试数据
def test_large_data_insert(self):
"""测试大数据量插入"""
# 生成1000行测试数据
large_data = pd.DataFrame({
'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)],
@@ -168,59 +162,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.date_range('2023-01-01', periods=1000)
})
# Windows平台使用更小的批次
# 根据平台自动调整批次大小
chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time()
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
inserted_rows = self.db.insert_from_df(
self.test_table,
large_data,
chunk_size=chunk_size
)
elapsed = time.time() - start_time
self.assertEqual(rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
self.assertEqual(inserted_rows, 1000)
print(f"插入1000行数据耗时: {elapsed:.2f} (批次大小: {chunk_size})")
# 验证数据
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
def test_concurrent_access(self):
"""测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
def query_worker(i):
"""并发查询工作函数"""
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id = %s",
params=(i % 3 + 1,) # 查询id=1,2,3循环
)
return len(df)
# 20个线程执行100次查询
start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(worker, range(100)))
results = list(executor.map(query_worker, range(100)))
elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
self.assertEqual(sum(results), 100) # 每次查询应返回1行
print(f"100次并发查询耗时: {elapsed:.2f}")
class TestPlatformSpecific(unittest.TestCase):
"""平台特定功能测试"""
@classmethod
def setUpClass(cls):
"""创建临时测试数据库"""
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'max_connections': 10
'password': '123123'
}
# 创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
# 创建测试数据库
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -230,15 +220,8 @@ class TestPlatformSpecific(unittest.TestCase):
@classmethod
def tearDownClass(cls):
"""删除临时测试数据库"""
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
"""清理测试数据库"""
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -249,42 +232,49 @@ class TestPlatformSpecific(unittest.TestCase):
def test_windows_timeout(self):
"""测试Windows平台超时处理"""
if platform.system() != 'Windows':
self.skipTest("Only runs on Windows")
self.skipTest("仅在Windows平台运行")
config = {
**self.base_config,
'database': self.test_db_name,
'connect_timeout': 1,
'read_timeout': 1
'read_timeout': 1,
'write_timeout': 1
}
db = MySQLAgent(config)
# 测试短超时查询
start_time = time.time()
# 执行会超时查询(SLEEP(2)超过1秒超时设置)
with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
try:
db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out")
except Exception as e:
self.assertIn("timed out", str(e))
print(f"\nWindows timeout test: {str(e)}")
# 提取底层异常信息(可能被包装)
while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
e = e.args[0]
raise e
def test_macos_ssl(self):
"""测试macOS SSL连接"""
error_msg = str(ctx.exception)
self.assertTrue(
"timed out" in error_msg or
"timeout" in error_msg or
"HY000" in error_msg, # MySQL超时错误码
f"未检测到超时异常,实际异常: {error_msg}"
)
def test_macos_ssl_connection(self):
"""测试macOS平台SSL连接"""
if platform.system() != 'Darwin':
self.skipTest("Only runs on macOS")
self.skipTest("仅在macOS平台运行")
config = {
**self.base_config,
'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
db = MySQLAgent(config)
version = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
version_df = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version_df)
if __name__ == '__main__':
unittest.main()
unittest.main(verbosity=2)
+389 -356
View File
@@ -3,13 +3,13 @@ import sys
import platform
import pandas as pd
import pymysql
import json
import numpy as np
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入日志系统
@@ -20,7 +20,7 @@ class MySQLAgent:
"""
全平台兼容的MySQL数据库操作类
支持Windows/macOS/Linux系统
配置参数从外部传入
配置参数从外部传入,不使用连接池和事务管理
"""
_instance = None
@@ -34,30 +34,14 @@ class MySQLAgent:
return cls._instance
def __init__(self, config: dict):
"""
初始化MySQL数据库连接
Args:
config (dict): 数据库配置字典,包含以下键:
- host: 数据库主机
- port: 端口
- user: 用户名
- password: 密码
- database: 数据库名
- [可选] charset: 字符集(默认utf8mb4)
- [可选] max_connections: 最大连接数(默认5)
- [可选] connect_timeout: 连接超时(秒)
- [可选] read_timeout: 读取超时(秒)
- [可选] write_timeout: 写入超时(秒)
- [可选] ssl: SSL配置
"""
if hasattr(self, '_pool') and self._pool:
"""初始化MySQL数据库连接(原有逻辑完全保留)"""
if hasattr(self, 'config') and self.config:
return
# 基础配置校验
required_keys = ['host', 'port', 'user', 'password', 'database']
if not all(key in config for key in required_keys):
log.warning(f"数据库配置缺少必要参数,当前配置: {config}")
log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{config}")
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
self.config = {
@@ -67,7 +51,6 @@ class MySQLAgent:
'password': config['password'],
'database': config['database'],
'charset': config.get('charset', 'utf8mb4'),
'cursorclass': cursors.DictCursor,
'autocommit': True,
'connect_timeout': config.get('connect_timeout', 10),
'read_timeout': config.get('read_timeout', 30),
@@ -79,86 +62,57 @@ class MySQLAgent:
current_platform = platform.system()
self.log = log.bind(module=f"MySQLAgent({current_platform})")
# 创建连接池
self.pool_size = config.get('max_connections', 5)
self._pool = self._create_pool()
def _create_pool(self) -> PooledDB:
    """Build the ``PooledDB`` connection pool for this agent.

    Returns:
        A pool whose connections are opened with ``self.config`` and whose
        size is capped at ``self.pool_size``.

    Raises:
        Exception: Any failure while building the pool is logged at
            critical level and re-raised unchanged.
    """

    def _open_connection():
        # Factory handed to the pool; every pooled connection is opened here.
        connection = pymysql.connect(**self.config)
        connection.threadsafety = 1  # explicitly set the thread-safety level
        return connection

    try:
        pool_options = dict(
            creator=_open_connection,
            mincached=1,
            maxcached=3,
            maxconnections=self.pool_size,
            blocking=True,
            ping=1,  # ping the database each time a connection is fetched
        )
        new_pool = PooledDB(**pool_options)
        self.log.info("连接池创建成功")
        return new_pool
    except Exception as exc:
        self.log.critical("连接池创建失败", error=str(exc), exc_info=True)
        raise
def get_connection(self) -> pymysql.connections.Connection:
"""获取数据库连接(修复字符集方法缺失问题"""
"""获取数据库连接(原有逻辑完全保留"""
try:
conn = self._pool.connection()
conn = pymysql.connect(**self.config)
# 为连接添加字符集方法(兼容SQLAlchemy
# 为连接添加 character_set_name 方法
if not hasattr(conn, 'character_set_name'):
def _character_set_name():
return self.config.get('charset', 'utf8mb4')
conn.character_set_name = _character_set_name
# macOS平台SSL特殊处理
# macOS需要特殊处理SSL
if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True)
self.log.trace("获取数据库连接成功")
self.log.trace("Database connection obtained")
return conn
except Exception as e:
error_msg = str(e)
# Windows平台连接超时重试
if platform.system() == 'Windows' and "timed out" in error_msg:
self.log.warning("Windows连接超时,尝试重试...")
self.log.warning("Windows connection timeout, retrying...")
return self._retry_connection()
self.log.error("获取连接失败", error=error_msg, exc_info=True)
self.log.error("Connection failed", error=error_msg, exc_info=True)
raise
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
"""Windows平台连接重试机制"""
def _retry_connection(self, max_retries: int = 3) -> Any | None:
"""Windows平台连接重试机制(原有逻辑完全保留)"""
for attempt in range(max_retries):
try:
conn = self._pool.connection()
self.log.info(f"{attempt + 1}次尝试连接成功")
conn = pymysql.connect(**self.config)
self.log.info(f"Connection established after {attempt + 1} attempts")
return conn
except Exception:
if attempt == max_retries - 1:
raise
import time
time.sleep(1) # 重试间隔1秒
time.sleep(1)
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
"""执行SQL查询并返回DataFrame优化连接管理"""
conn = None
"""执行SQL查询并返回DataFrame原有逻辑完全保留"""
try:
self.log.debug("执行SQL查询", sql=sql)
self.log.debug("Executing SQL query", sql=sql)
# 获取连接并确保字符集方法存在
conn = self.get_connection()
# 创建SQLAlchemy引擎(使用静态池避免连接重复创建)
# 创建SQLAlchemy引擎
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
engine = create_engine(
@@ -170,180 +124,361 @@ class MySQLAgent:
# 执行查询
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
self.log.info(f"查询成功,返回{len(df)}行数据")
self.log.info("Query executed successfully", rows=len(df))
return df
except Exception as e:
self.log.error(f"SQL查询失败{sql}", sql=sql, params=params, error=str(e), exc_info=True)
self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
raise
finally:
# 确保连接释放回池
if conn:
try:
conn.close()
except Exception as e:
self.log.warning("关闭连接失败", error=str(e))
if 'engine' in locals():
engine.dispose()
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False) -> int:
"""将DataFrame数据插入到数据库表(优化批量处理)"""
chunk_size: int = 1000, replace: bool = False, # 保留replace参数
ignore_duplicates: bool = None) -> int: # 新增ignore_duplicates参数
"""
兼容旧接口的通用插入方法:保留replace参数,同时支持新的ignore_duplicates
自动处理重复数据,对所有数据源通用
"""
# 【兼容性处理】如果未指定ignore_duplicates,用replace参数推导(replace=True时不忽略重复)
if ignore_duplicates is None:
ignore_duplicates = not replace # 旧逻辑中replace=True表示替换,即不忽略重复
if df.empty:
self.log.warning(f"尝试插入空DataFrame到表{table_name}")
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
return 0
self.log.debug(f"准备插入DataFrame到表{table_name}", rows=len(df), chunk_size=chunk_size)
# 根据平台自动调整批次大小
current_platform = platform.system()
if current_platform == 'Windows' and chunk_size > 500:
chunk_size = 500
self.log.debug(f"Windows平台自动调整批次大小为{chunk_size}")
elif current_platform == 'Linux' and chunk_size < 1000:
chunk_size = 1000
self.log.debug(f"Linux平台自动调整批次大小为{chunk_size}")
conn = None
cursor = None
total_inserted = 0
total_duplicated = 0
total_failed = 0
try:
method = 'replace' if replace else 'append'
total_rows = 0
# 1. 建立数据库连接
conn = self.get_connection()
cursor = conn.cursor()
self.log.debug(f"Established connection for inserting into {table_name}")
# 创建SQLAlchemy引擎
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
engine = create_engine(
"mysql+pymysql://",
creator=lambda: conn,
poolclass=StaticPool,
connect_args={
'charset': self.config.get('charset', 'utf8mb4'),
'autocommit': True
}
# 2. 获取数据库表的实际列名
cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
columns_info = cursor.fetchall()
db_columns = [col[0] for col in columns_info]
self.log.debug(f"Table {table_name} has columns: {db_columns}")
# 3. 数据预处理:统一处理空值
cleaned_df = df.replace(
[None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
None
).copy()
# 4. 字段匹配:只保留与数据库匹配的列
df_columns = cleaned_df.columns.tolist()
matched_columns = [col for col in df_columns if col in db_columns]
unmatched_columns = [col for col in df_columns if col not in db_columns]
if unmatched_columns:
self.log.warning(
f"Table {table_name} dropping unmatched columns",
unmatched_columns=unmatched_columns,
count=len(unmatched_columns)
)
if not matched_columns:
self.log.warning(f"No matched columns for {table_name}, abort insertion")
return 0
filtered_df = cleaned_df[matched_columns].copy()
total_to_insert = len(filtered_df)
self.log.debug(
f"Filtered DataFrame for {table_name}: {total_to_insert} rows to insert"
)
# 5. 处理复杂类型(dict/list转JSON
for col in filtered_df.columns:
has_complex_type = filtered_df[col].apply(
lambda x: isinstance(x, (dict, list)) if x is not None else False
).any()
if has_complex_type:
self.log.debug(f"Column {col} in {table_name} has complex type, converting to JSON")
filtered_df.loc[:, col] = filtered_df[col].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
)
# 6. 构建通用插入SQL
columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
placeholders = ', '.join(['%s'] * len(filtered_df.columns))
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
self.log.trace(f"Generated insert SQL for {table_name}: {insert_sql}")
# 7. 逐条插入(确保能捕获单条重复错误)
records = filtered_df.to_dict('records')
indices = filtered_df.index.tolist()
for i, (record, idx) in enumerate(zip(records, indices)):
try:
for i in range(0, len(df), chunk_size):
chunk = df.iloc[i:i + chunk_size].copy() # 使用copy避免SettingWithCopyWarning
data = tuple(record[col] for col in filtered_df.columns)
cursor.execute(insert_sql, data)
total_inserted += 1
# macOS平台datetime特殊处理
if platform.system() == 'Darwin':
for col in chunk.select_dtypes(include=['datetime64']):
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
chunk.to_sql(
table_name,
engine,
if_exists=method,
index=False,
method='multi'
if (i + 1) % 100 == 0:
self.log.trace(
f"Inserted {i + 1}/{total_to_insert} rows into {table_name}"
)
total_rows += len(chunk)
method = 'append' # 首次后使用追加模式
self.log.trace(f"插入第{i // chunk_size + 1}批数据", rows=len(chunk), total=total_rows)
self.log.info(f"数据插入成功,表{table_name}共插入{total_rows}")
return total_rows
finally:
engine.dispose()
conn.close()
except MySQLError as e:
# 8. 捕获重复错误(MySQL错误码1062)
if e.args[0] == 1062:
total_duplicated += 1
short_record = {
k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
for k, v in record.items()
}
self.log.warning(
f"Skipped duplicate record in {table_name}",
index=idx,
error_msg=e.args[1],
record=short_record
)
if not ignore_duplicates:
raise
else:
# 其他数据库错误
total_failed += 1
self.log.error(
f"Failed to insert record in {table_name}",
index=idx,
error_code=e.args[0],
error_msg=e.args[1],
record=record
)
if not ignore_duplicates:
raise
# 提交事务
conn.commit()
# 9. 插入结果统计
self.log.info(
f"Insertion summary for {table_name}",
total_to_insert=total_to_insert,
total_inserted=total_inserted,
total_duplicated=total_duplicated,
total_failed=total_failed
)
return total_inserted
except Exception as e:
self.log.error(f"数据插入失败,表{table_name}", error=str(e), exc_info=True)
if conn:
conn.rollback()
self.log.error(f"Batch insertion failed for {table_name}", error=str(e), exc_info=True)
raise
finally:
if cursor:
cursor.close()
if conn:
conn.close()
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
    """Return the leading primary-key column of ``table_name``.

    Helper for the de-duplication ("replace") logic.

    Args:
        table_name: Table to inspect inside ``self.config['database']``.
        cursor: An open DB-API cursor; it is used but not closed here.

    Returns:
        The name of the first primary-key column (by key position), or
        ``None`` when the table has no primary key or the lookup fails.
    """
    try:
        # ORDER BY makes the result deterministic for composite primary keys:
        # without it the server may return the key columns in any order.
        cursor.execute("""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s
              AND TABLE_NAME = %s
              AND CONSTRAINT_NAME = 'PRIMARY'
            ORDER BY ORDINAL_POSITION
        """, (self.config['database'], table_name))
        row = cursor.fetchone()
        if not row:
            return None
        # Accept both tuple-style and dict-style cursor rows.
        if isinstance(row, dict):
            return next(iter(row.values()))
        return row[0]
    except Exception as e:
        # Best effort: callers treat "unknown" the same as "no primary key".
        self.log.warning(f"Failed to get primary key for {table_name}", error=str(e))
        return None
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
    """Fetch column metadata for ``table_name`` from ``information_schema``.

    Args:
        table_name: Table to describe inside ``self.config['database']``.

    Returns:
        A mapping ``{column_name: {'type': data_type, 'max_length': n}}``
        where ``max_length`` is ``None`` when the server reports no (or a
        zero) character length. An empty dict is returned, and an error is
        logged, when no columns are found.

    Raises:
        Exception: Any connection or query error is logged and re-raised.
    """
    sql = """
          SELECT column_name, data_type, character_maximum_length
          FROM information_schema.columns
          WHERE table_schema = %s \
            AND table_name = %s \
          """
    params = (self.config['database'], table_name)
    try:
        conn = self.get_connection()
        # Initialised before the inner try so the cleanup below never hits an
        # unbound name when conn.cursor() itself fails.
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            # Materialise the rows so the cursor type does not matter later.
            rows = list(cursor.fetchall())
            if not rows:
                self.log.error("No columns found in table", table=table_name)
                return {}
            schema = {}
            for row in rows:
                # Accept dict-style rows as well as tuples; values are taken
                # in SELECT order either way.
                fields = list(row.values()) if isinstance(row, dict) else row
                col_name = str(fields[0]).strip()
                data_type = str(fields[1]).strip()
                max_length = fields[2] if fields[2] else None
                schema[col_name] = {
                    'type': data_type,
                    'max_length': max_length
                }
            self.log.debug("Successfully fetched table schema",
                           table=table_name,
                           columns=list(schema.keys()))
            return schema
        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
    except Exception as e:
        self.log.error("Failed to get table detailed info",
                       table=table_name,
                       error=str(e))
        raise
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                             table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
    """Return a copy of ``df`` restricted and normalised to ``table_schema``.

    Steps, applied per column:
      1. Columns absent from ``table_schema`` are dropped (with a warning).
      2. Nulls become ``''`` for varchar/char/text columns and ``None``
         for every other type.
      3. varchar/char values are cast to ``str`` and truncated to the
         column's ``max_length``.
      4. datetime/timestamp columns are parsed; if strict parsing fails,
         values that cannot be parsed are replaced with the current time.

    Args:
        df: Input data; it is not modified.
        table_name: Used only for log context.
        table_schema: ``{column: {'type': ..., 'max_length': ...}}``.

    Returns:
        The cleaned DataFrame (possibly empty).
    """
    string_types = ('varchar', 'char', 'text')

    # 1. Keep only the columns that exist in the target table.
    valid_columns = [col for col in df.columns if col in table_schema]
    invalid_columns = [col for col in df.columns if col not in table_schema]
    if invalid_columns:
        self.log.warning("Dropping invalid columns not present in table",
                         table=table_name,
                         invalid_columns=invalid_columns,
                         count=len(invalid_columns))
    cleaned_df = df[valid_columns].copy()
    if cleaned_df.empty:
        return cleaned_df

    # 2. Normalise each remaining column.
    for col in valid_columns:
        col_info = table_schema[col]
        data_type = col_info['type']
        max_length = col_info['max_length']

        # 2.1 Nulls. The count is taken BEFORE filling so the log is truthful.
        null_mask = cleaned_df[col].isnull()
        null_count = int(null_mask.sum())
        if null_count:
            default_value = '' if data_type in string_types else None
            as_objects = cleaned_df[col].astype(object)
            if default_value is None:
                # fillna(None) raises ValueError, so swap NaN/NaT for None
                # through a mask on an object-typed column instead.
                cleaned_df[col] = as_objects.where(~null_mask, None)
            else:
                # Assign the result back: in-place fillna on a column
                # selection is not guaranteed to write through.
                cleaned_df[col] = as_objects.fillna(default_value)
            self.log.debug("Replaced null values",
                           table=table_name,
                           column=col,
                           default_value=default_value,
                           count=null_count)

        # 2.2 Truncate over-length varchar/char values.
        if data_type in ('varchar', 'char') and max_length:
            cleaned_df[col] = cleaned_df[col].astype(str)
            too_long_mask = cleaned_df[col].str.len() > max_length
            if too_long_mask.any():
                cleaned_df.loc[too_long_mask, col] = (
                    cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
                )
                self.log.warning("Truncated overlength values",
                                 table=table_name,
                                 column=col,
                                 max_length=max_length,
                                 count=too_long_mask.sum())

        # 2.3 Date/time columns.
        if data_type in ('datetime', 'timestamp'):
            try:
                cleaned_df[col] = pd.to_datetime(cleaned_df[col])
            except (ValueError, TypeError, OverflowError) as e:
                self.log.warning("Failed to convert to datetime, using current time",
                                 table=table_name,
                                 column=col,
                                 error=str(e))
                # Keep every value that does parse and substitute "now" for
                # the rest, so the column ends up uniformly datetime-typed.
                parsed = pd.to_datetime(cleaned_df[col], errors='coerce')
                cleaned_df[col] = parsed.fillna(pd.Timestamp(datetime.now()))
    return cleaned_df
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """Update rows of ``table_name`` from ``df``, matching on ``key_columns``.

    One ``UPDATE ... SET ... WHERE ...`` statement is generated and executed
    with ``executemany``, one parameter tuple per DataFrame row. Only
    DataFrame columns that exist in the table are written; key columns are
    used in the WHERE clause and never in SET.

    Args:
        table_name: Target table.
        df: Source rows.
        key_columns: Column name(s) identifying the row to update.

    Returns:
        ``cursor.rowcount`` after the batch, or 0 when ``df`` is empty or
        there is nothing to set.

    Raises:
        ValueError: A key column is missing from ``df``.
        Exception: Any database error, after rollback and logging.
    """
    if df.empty:
        self.log.warning("Attempted to update with empty DataFrame", table=table_name)
        return 0
    self.log.debug("Preparing to update table from DataFrame",
                   table=table_name,
                   key_columns=key_columns,
                   rows=len(df))
    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]
        key_columns = list(key_columns)

        # Validate that every key column is present in the DataFrame.
        missing_keys = [key for key in key_columns if key not in df.columns]
        if missing_keys:
            raise ValueError(f"DataFrame中缺少关键列: {missing_keys}")

        # Restrict the SET list to columns that really exist in the table.
        table_info = self._get_table_detailed_info(table_name)
        set_columns = [col for col in df.columns
                       if col in table_info and col not in key_columns]
        if not set_columns:
            self.log.warning("No columns to update", table=table_name)
            return 0

        # Identifiers are backtick-quoted so reserved words and unusual
        # column names do not break the statement.
        set_clause = ', '.join(f"`{col}`=%s" for col in set_columns)
        where_clause = ' AND '.join(f"`{col}`=%s" for col in key_columns)
        update_sql = f"UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}"
        self.log.trace("Generated update SQL", sql=update_sql)

        # Build parameters column-wise rather than with iterrows(): iterrows
        # coerces each row to a single dtype (ints become floats next to a
        # float column). Missing values are sent as None (SQL NULL).
        frame = df[set_columns + key_columns].astype(object)
        frame = frame.where(frame.notna(), None)
        update_data = list(frame.itertuples(index=False, name=None))

        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            try:
                cursor.executemany(update_sql, update_data)
                total_updated = cursor.rowcount
            finally:
                cursor.close()
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

        self.log.info("Data updated successfully",
                      table=table_name,
                      rows_updated=total_updated)
        return total_updated
    except Exception as e:
        self.log.error("Data update failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise


def _get_table_info(self, table_name: str) -> Dict[str, str]:
    """Return ``{column_name: data_type}`` for ``table_name``.

    Thin view over ``_get_table_detailed_info`` so both helpers share one
    query and one row-parsing path.

    Raises:
        Exception: Lookup errors are logged and re-raised.
    """
    try:
        detailed = self._get_table_detailed_info(table_name)
        return {col: info['type'] for col, info in detailed.items()}
    except Exception as e:
        self.log.error(f"获取表{table_name}结构失败", error=str(e))
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
"""推断DataFrame各列的SQL类型(扩展类型映射"""
"""推断DataFrame各列的SQL类型(原有逻辑完全保留"""
type_mapping = {
'int64': 'BIGINT',
'int32': 'INT',
'int16': 'SMALLINT',
'int8': 'TINYINT',
'uint64': 'BIGINT UNSIGNED',
'float64': 'DOUBLE',
'float32': 'FLOAT',
'datetime64[ns]': 'DATETIME',
'datetime64[ns, UTC]': 'DATETIME',
'timedelta64[ns]': 'TIME',
'object': 'TEXT',
'string': 'VARCHAR(255)',
'bool': 'TINYINT(1)',
'category': 'VARCHAR(255)'
}
@@ -353,136 +488,84 @@ class MySQLAgent:
dtype_str = str(dtype)
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
self.log.debug("DataFrame类型映射为SQL类型", mappings=sql_types)
self.log.debug("Mapped DataFrame types to SQL types",
mappings=sql_types)
return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create ``table_name`` with one column per DataFrame column.

    Column SQL types come from ``self.df_to_sql_type(df)``.

    Args:
        table_name: Name of the table to create.
        df: DataFrame whose columns define the table layout.
        primary_key: Optional column name(s) for the PRIMARY KEY clause;
            names that are not DataFrame columns are ignored.

    Returns:
        ``True`` when the table was created; ``False`` when it already
        exists or creation failed (the failure is logged, not raised).
    """
    if self.table_exists(table_name):
        self.log.warning("Table already exists", table=table_name)
        return False
    self.log.debug("Creating new table from DataFrame schema",
                   table=table_name,
                   columns=list(df.columns))

    def quote(identifier) -> str:
        # Backtick-quote an identifier, doubling any embedded backtick.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        # Optional primary key.
        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                quoted_pk = ', '.join(quote(col) for col in pk_columns)
                columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)

        # Joined outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        column_block = ",\n    ".join(columns_sql)
        create_sql = f"CREATE TABLE {quote(table_name)} (\n    {column_block}\n)"
        self.execute_sql(create_sql)
        self.log.info("Table created successfully", table=table_name)
        return True
    except Exception as e:
        self.log.error("Failed to create table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Execute one SQL statement on a fresh connection.

    Args:
        sql: Statement text, with ``%s`` / ``%(name)s`` placeholders.
        params: Values bound to the placeholders, if any.
        fetch: When ``True`` return ``cursor.fetchall()``; otherwise commit
            and return ``cursor.rowcount``.

    Returns:
        The fetched rows (in whatever row shape the cursor produces) when
        ``fetch`` is true, else the number of affected rows.

    Raises:
        Exception: Any execution error, after it has been logged.
    """
    conn = None
    cursor = None
    try:
        conn = self.get_connection()
        cursor = conn.cursor()
        # Non-Windows platforms get a longer statement timeout (10 minutes).
        if platform.system() != 'Windows':
            cursor.execute("SET SESSION max_execution_time=600000")
        cursor.execute(sql, params)
        if fetch:
            result = cursor.fetchall()
            self.log.debug("Query executed", rows=len(result))
            return result
        affected_rows = cursor.rowcount
        # Commit explicitly so writes persist even when the connection was
        # not opened in autocommit mode.
        conn.commit()
        self.log.debug("Update executed", affected_rows=affected_rows)
        return affected_rows
    except Exception as e:
        self.log.error("SQL execution failed",
                       sql=sql,
                       params=params,
                       error=str(e),
                       exc_info=True)
        raise
    finally:
        # Cleanup is best effort: a failing close must not hide the result
        # or the original exception.
        if cursor:
            try:
                cursor.close()
            except Exception as e:
                self.log.warning("关闭游标失败", error=str(e))
        if conn:
            try:
                conn.close()
            except Exception as e:
                self.log.warning("关闭连接失败", error=str(e))
def begin_transaction(self) -> "pymysql.connections.Connection":
    """Open a connection with autocommit disabled and return it.

    On macOS the session isolation level is set to READ COMMITTED and on
    Linux to REPEATABLE READ; other platforms keep the server default.
    The caller owns the returned connection and must finish it with
    ``commit_transaction`` or ``rollback_transaction``.

    Raises:
        Exception: Any setup error, after logging. A connection that was
            already opened is closed before the error propagates.
    """
    isolation_statements = {
        'Darwin': "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED",
        'Linux': "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ",
    }
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)
        statement = isolation_statements.get(platform.system())
        if statement:
            # Use a named cursor so it can be closed instead of leaked.
            cursor = conn.cursor()
            try:
                cursor.execute(statement)
            finally:
                cursor.close()
        self.log.debug("事务开始")
        return conn
    except Exception as e:
        self.log.error("事务开始失败", error=str(e))
        # Nobody else holds this connection yet, so release it here.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_error:
                self.log.warning("关闭连接失败", error=str(close_error))
        raise
def commit_transaction(self, conn: "pymysql.connections.Connection") -> None:
    """Commit the transaction on ``conn`` and then close the connection.

    The connection is closed whether or not the commit succeeds. A failure
    to close is logged as a warning and otherwise ignored.

    Raises:
        Exception: The commit error, after it has been logged.
    """
    try:
        conn.commit()
        self.log.debug("事务提交成功")
    except Exception as commit_error:
        self.log.error("事务提交失败", error=str(commit_error))
        raise
    finally:
        # Runs on both the success and the failure path.
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务提交后关闭连接失败", error=str(close_error))
def rollback_transaction(self, conn: "pymysql.connections.Connection") -> None:
    """Roll back the transaction on ``conn`` and then close the connection.

    Nothing is raised from here: a failed rollback is logged as an error,
    a failed close as a warning. The connection close is attempted in
    either case.
    """
    try:
        conn.rollback()
        self.log.warning("事务已回滚")
    except Exception as rollback_error:
        self.log.error("事务回滚失败", error=str(rollback_error))
    finally:
        # Release the connection regardless of the rollback outcome.
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务回滚后关闭连接失败", error=str(close_error))
def table_exists(self, table_name: str) -> bool:
"""检查表是否存在(优化SQL安全性"""
"""检查表是否存在(原有逻辑完全保留"""
sql = """
SELECT COUNT(*) as count
FROM `information_schema`.`tables`
@@ -490,64 +573,49 @@ class MySQLAgent:
AND `table_name` = %s \
"""
params = (self.config['database'], table_name)
try:
result = self.execute_sql(sql, (self.config['database'], table_name), fetch=True)
exists = result[0]['count'] > 0
self.log.debug(f"{table_name}存在性检查", exists=exists)
result = self.execute_sql(sql, params, fetch=True)
exists = result[0][0] > 0 # 适配元组结果
self.log.debug("Checked table existence",
table=table_name,
exists=exists)
return exists
except Exception as e:
self.log.warning(f"{table_name}存在性检查失败", error=str(e))
except Exception:
return False
def drop_table(self, table_name: str) -> bool:
    """Drop ``table_name`` if it exists.

    Returns:
        ``True`` when the table was dropped; ``False`` when it does not
        exist or the DROP failed (the failure is logged, not raised).
    """
    if not self.table_exists(table_name):
        self.log.warning("Table does not exist", table=table_name)
        return False
    try:
        # Backtick-quote the identifier (doubling embedded backticks) so
        # reserved words and unusual names cannot break the statement.
        quoted = "`" + table_name.replace("`", "``") + "`"
        self.execute_sql(f"DROP TABLE {quoted}")
        self.log.info("Table dropped successfully", table=table_name)
        return True
    except Exception as e:
        self.log.error("Failed to drop table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def get_pool_status(self) -> Dict[str, int]:
    """Return a snapshot of the connection pool's counters.

    Reads private attributes of ``self._pool`` and logs the result at
    debug level.

    NOTE(review): ``len()`` is applied to ``_connections``, ``_idle_cache``
    and ``_shared_cache``, so all three are assumed to be sized containers
    - confirm against the pool implementation in use.
    """
    pool = self._pool
    status = dict(
        max_connections=pool._maxconnections,
        active_connections=len(pool._connections),
        idle_connections=len(pool._idle_cache),
        shared_connections=len(pool._shared_cache),
    )
    self.log.debug("连接池状态", **status)
    return status
def validate_connection(self) -> bool:
    """Check that a connection can be obtained and answers ``SELECT 1``.

    Works with both tuple-style and dict-style cursor rows.

    Returns:
        ``True`` when the probe query returns 1; ``False`` on any failure
        (the failure is logged as a warning, never raised).
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1 AS health_check")
                row = cursor.fetchone()
        if not row:
            return False
        # Dict cursors key the value by its alias; tuple cursors by position.
        value = row['health_check'] if isinstance(row, dict) else row[0]
        return value == 1
    except Exception as e:
        self.log.warning("连接健康检查失败", error=str(e))
        return False
def __del__(self):
    """Finalizer: close the connection pool when one is attached.

    Does nothing when ``_pool`` is missing or falsy. A failure while
    closing is logged as an error rather than raised.
    """
    pool = getattr(self, '_pool', None)
    if not pool:
        return
    try:
        pool.close()
        self.log.info("连接池已关闭")
    except Exception as close_error:
        self.log.error("连接池关闭失败", error=str(close_error))
# 平台特定的默认配置(原有逻辑完全保留)
def get_default_config():
"""获取各平台默认配置(优化默认参数)"""
"""获取各平台默认配置"""
current_platform = platform.system()
base_config = {
@@ -555,76 +623,41 @@ def get_default_config():
'port': 3306,
'user': 'root',
'password': '123123',
'database': 'intelligence',
'max_connections': 10, # 增加默认连接数
'charset': 'utf8mb4'
'database': 'intelligence_system',
}
if current_platform == 'Windows':
return {
**base_config,
return {**base_config,
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30,
'ssl': None # Windows默认禁用SSL
'write_timeout': 30
}
elif current_platform == 'Darwin': # macOS
elif current_platform == 'Darwin':
return {
**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} # macOS默认SSL配置
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
else: # Linux其他平台
return {
**base_config,
else: # Linux其他平台
return {**base_config,
'connect_timeout': 15,
'read_timeout': 60,
'write_timeout': 60,
'ssl': None # Linux默认禁用SSL
'write_timeout': 60
}
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, then print the
    # server version. Errors are reported on stdout instead of surfacing as
    # a traceback, since this is only a smoke test.
    try:
        db = MySQLAgent(get_default_config())

        # Test the connection.
        if db.validate_connection():
            print("Database connection successful")

            # Fetch the database version.
            version = db.query_to_df("SELECT VERSION() as version")
            print(f"Database version: {version['version'].iloc[0]}")
        else:
            print("Failed to connect to database")
    except Exception as e:
        print(f"示例执行失败: {str(e)}")