mysql数据链接更新

This commit is contained in:
z66
2025-09-18 17:03:24 +08:00
parent 9afa9d2e58
commit 20fd9587ee
5 changed files with 188086 additions and 726 deletions
+150 -117
View File
@@ -6,73 +6,109 @@ import os
import pickle import pickle
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import pymysql from loguru import logger
from utils.mysql_agent import MySQLAgent
from typing import Dict, List, Optional, Any
# 数据库连接信息 # 数据库连接配置
local_DB_Config = { local_DB_Config = {
'host': "localhost", 'host': "localhost",
'user': "root", 'user': "root",
'password': "123123", 'password': "123123",
'database': "intelligence_system", 'database': "intelligence_system",
'charset': 'utf8mb4' 'port': 3306,
'charset': 'utf8mb4',
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30,
'autocommit': True
} }
# 表名 # 目标数据表名
table_name = "collector_rss_subscriptions" table_name = "collector_rss_subscriptions"
def verify_database(): class NewsAPIClient:
"""验证数据库连接和表结构""" """新闻API客户端,用于获取和处理RSS源数据并写入数据库"""
def __init__(self):
"""初始化客户端并建立数据库连接"""
self.logger = logger.bind(module="NewsAPIClient")
self.db_agent = MySQLAgent(local_DB_Config)
self.logger.info("新闻API客户端初始化完成,已连接到数据库")
def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
"""统一返回结果格式"""
return {
'success': bool(success),
'message': str(message),
'data': data
}
def verify_database(self) -> bool:
"""验证数据库表结构是否符合要求(适配元组格式的查询结果)"""
try: try:
conn = pymysql.connect(**local_DB_Config) # 1. 检查表是否存在(execute_sql返回元组列表,如 [(table_name,)]
with conn.cursor() as cursor: result = self.db_agent.execute_sql(
# 检查表是否存在 f"SHOW TABLES LIKE '{table_name}'",
cursor.execute(f"SHOW TABLES LIKE '{table_name}'") fetch=True
if not cursor.fetchone(): )
print(f"错误: 表 {table_name} 不存在!") # 元组结果需通过索引0判断(若表存在,result是[(table_name,)], 否则为空列表)
if not result:
self.logger.error(f"{table_name} 不存在,请先创建表结构")
return False return False
# 检查表结构 # 2. 检查表字段是否完整(DESCRIBE返回的元组格式:(字段名, 类型, 是否为空, ...))
cursor.execute(f"DESCRIBE {table_name}") desc_result = self.db_agent.execute_sql(
columns = [col[0] for col in cursor.fetchall()] f"DESCRIBE {table_name}",
print("表列名:", columns) fetch=True
)
# 关键修改:用元组索引0提取字段名(而非字典键'Field'
columns = [col[0] for col in desc_result] # col是元组,col[0]即字段名
required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
'来源URL', '创建时间', '更新时间']
missing_cols = [col for col in required_columns if col not in columns]
# 检查插入权限 if missing_cols:
test_sql = f"""INSERT INTO `{table_name}` self.logger.error(f"{table_name} 缺少必要字段:{missing_cols}")
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`) return False
VALUES (%s, %s, %s, %s, %s)"""
cursor.execute(test_sql, ('测试标题', 'http://test.com', '测试内容', datetime.now(), '测试来源'))
conn.rollback()
print("数据库验证通过!") self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
return True return True
except Exception as e: except Exception as e:
print("数据库验证失败:", e) self.logger.error(f"数据库验证失败: {str(e)}", exc_info=True)
return False return False
finally:
if 'conn' in locals():
conn.close()
def load_last_update_time(self) -> Optional[datetime]:
def load_last_update_time(): """加载上次更新时间缓存"""
"""加载上次更新的时间"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl') cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
if os.path.exists(cache_file): if os.path.exists(cache_file):
try:
with open(cache_file, 'rb') as f: with open(cache_file, 'rb') as f:
return pickle.load(f) last_update = pickle.load(f)
self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
return last_update
except Exception as e:
self.logger.error(f"加载上次更新时间失败: {str(e)}", exc_info=True)
self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
return None return None
def save_last_update_time(self, last_update: datetime) -> None:
"""保存本次更新时间"""
try:
cache_dir = os.path.join(os.getcwd(), 'output')
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, 'last_update.pkl')
def save_last_update_time(last_update):
"""保存本次更新的时间"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
with open(cache_file, 'wb') as f: with open(cache_file, 'wb') as f:
pickle.dump(last_update, f) pickle.dump(last_update, f)
self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
except Exception as e:
self.logger.error(f"保存更新时间失败: {str(e)}", exc_info=True)
def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional[feedparser.FeedParserDict]:
def fetch_single_rss(url, timeout=15): """获取并解析单个RSS源"""
"""获取并解析单个 RSS 源"""
headers = { headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
} }
@@ -85,24 +121,25 @@ def fetch_single_rss(url, timeout=15):
feed = feedparser.parse(response.text) feed = feedparser.parse(response.text)
if feed.bozo: if feed.bozo:
print(f"警告: 解析可能存在问题: {feed.bozo_exception}") self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")
self.logger.debug(f"成功获取 {url} 的RSS数据")
return feed return feed
except requests.RequestException as e: except requests.RequestException as e:
print(f"{attempt + 1}尝试获取 {url} 失败: {e}") self.logger.warning(f"{attempt + 1} 次获取 {url} 失败: {str(e)}")
if attempt < 2: if attempt < 2:
time.sleep(5 * (attempt + 1)) time.sleep(3 * (attempt + 1)) # 指数退避重试
continue continue
self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
return None return None
def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> Dict[str, feedparser.FeedParserDict]:
def fetch_all_rss(urls, timeout=15): """并发获取多个RSS源"""
"""使用线程池并发获取多个RSS源"""
feeds = {} feeds = {}
with ThreadPoolExecutor(max_workers=3) as executor: with ThreadPoolExecutor(max_workers=3) as executor:
future_to_url = {executor.submit(fetch_single_rss, url, timeout): url for url in urls} future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
for future in as_completed(future_to_url): for future in as_completed(future_to_url):
url = future_to_url[future] url = future_to_url[future]
@@ -111,13 +148,13 @@ def fetch_all_rss(urls, timeout=15):
if feed: if feed:
feeds[url] = feed feeds[url] = feed
except Exception as e: except Exception as e:
print(f"获取 {url} 时发生异常: {e}") self.logger.error(f"处理 {url} 时发生异常: {str(e)}", exc_info=True)
self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
return feeds return feeds
def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
def process_feed_entry(entry, url): """处理单个RSS条目,转换为数据库兼容格式"""
"""处理单个RSS条目并返回结构化数据"""
# 处理标题 # 处理标题
title = entry.get('title', '无标题') or '无标题' title = entry.get('title', '无标题') or '无标题'
if len(title) > 255: if len(title) > 255:
@@ -131,7 +168,7 @@ def process_feed_entry(entry, url):
# 处理摘要 # 处理摘要
summary = entry.get('summary', '无内容摘要') summary = entry.get('summary', '无内容摘要')
content_list = entry.get('content', []) content_list = entry.get('content', [])
content = content_list[0].value if content_list else '' content = content_list[0].value if (content_list and hasattr(content_list[0], 'value')) else ''
description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要') description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
# 处理发布时间 # 处理发布时间
@@ -141,7 +178,7 @@ def process_feed_entry(entry, url):
else: else:
pub_str = entry.get('published', entry.get('updated', '')) pub_str = entry.get('published', entry.get('updated', ''))
try: try:
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z') entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
except: except:
entry_time = datetime.now() entry_time = datetime.now()
@@ -150,100 +187,89 @@ def process_feed_entry(entry, url):
if len(source_url) > 1024: if len(source_url) > 1024:
source_url = source_url[:1021] + '...' source_url = source_url[:1021] + '...'
# 当前时间(创建/更新时间)
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
return { return {
'文章标题': title, '文章标题': title,
'文章链接': link, '文章链接': link,
'文章摘要': description, '文章摘要': description,
'发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'), '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
'来源URL': source_url '来源URL': source_url,
'创建时间': current_time,
'更新时间': current_time
} }
def display_feed_info(self, feed: feedparser.FeedParserDict, last_update: Optional[datetime] = None,
def display_feed_info(feed, last_update=None, url=None): url: Optional[str] = None) -> Optional[datetime]:
"""处理并显示RSS源信息""" """处理RSS源信息并写入数据库"""
if not feed: if not feed:
print("无法显示信息:feed 为 None") self.logger.warning("无法处理空的RSS源数据")
return None return None
print("=" * 80) self.logger.info(f"开始处理 RSS 源: {url}")
print(f"处理 RSS 源: {url}")
entries = feed.entries entries = feed.entries
data_list = [] data_list = []
new_last_update = last_update new_last_update = last_update
for i, entry in enumerate(entries, 1): for i, entry in enumerate(entries, 1):
entry_data = process_feed_entry(entry, url) entry_data = self.process_feed_entry(entry, url)
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S') entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
# 过滤旧数据
if last_update and entry_time <= last_update: if last_update and entry_time <= last_update:
continue continue
# 更新最新时间戳
if new_last_update is None or entry_time > new_last_update: if new_last_update is None or entry_time > new_last_update:
new_last_update = entry_time new_last_update = entry_time
print(f"\n--- 条目 {i} ---") self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
print(f"标题: {entry_data['文章标题']}")
print(f"链接: {entry_data['文章链接']}")
print(f"摘要: {entry_data['文章摘要'][:100]}...")
print(f"时间: {entry_data['发布时间']}")
data_list.append(entry_data) data_list.append(entry_data)
# 写入数据库
if data_list: if data_list:
df = pd.DataFrame(data_list) df = pd.DataFrame(data_list)
write_to_database(df) self.write_to_database(df)
return new_last_update return new_last_update
# news_api.py 中的 write_to_database 方法可以保持简洁
def write_to_database(df): def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
"""将数据写入数据库"""
if df.empty: if df.empty:
print("没有新数据需要写入") self.logger.info("没有新数据需要写入数据库")
return return self._format_result(True, "没有新数据需要写入")
print("\n准备写入数据库的数据样例:")
print(df.iloc[0].to_dict())
try: try:
conn = pymysql.connect(**local_DB_Config) inserted_rows = self.db_agent.insert_from_df(
with conn.cursor() as cursor: table_name=table_name,
sql = f"""INSERT IGNORE INTO `{table_name}` df=df,
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`) chunk_size=500,
VALUES (%s, %s, %s, %s, %s)""" replace=False
)
success_count = 0 self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
for _, row in df.iterrows(): return self._format_result(
True,
f"成功写入 {inserted_rows}/{len(df)} 条记录",
{"success_count": inserted_rows, "total": len(df)}
)
except Exception as e:
self.logger.error(f"数据库写入失败: {str(e)}", exc_info=True)
return self._format_result(False, f"数据库操作失败: {str(e)}")
@classmethod
def main(cls):
"""主函数入口"""
try: try:
cursor.execute(sql, ( client = cls()
row['文章标题'],
row['文章链接'],
row['文章摘要'],
row['发布时间'],
row['来源URL']
))
success_count += cursor.rowcount
except Exception as e:
print(f"插入记录时出错: {e}")
print(f"问题数据: {row.to_dict()}")
continue
conn.commit() # 验证数据库
print(f"成功写入 {success_count}/{len(df)} 条记录") if not client.verify_database():
client.logger.error("数据库验证失败,程序终止")
except Exception as e:
print("数据库操作失败:", e)
finally:
if 'conn' in locals():
conn.close()
def main():
"""主函数"""
if not verify_database():
print("数据库验证失败,程序终止")
return return
# RSS源列表
rss_urls = [ rss_urls = [
"https://www.chinanews.com.cn/rss/finance.xml", "https://www.chinanews.com.cn/rss/finance.xml",
"https://www.chinanews.com.cn/rss/world.xml", "https://www.chinanews.com.cn/rss/world.xml",
@@ -251,27 +277,34 @@ def main():
"https://www.chinanews.com.cn/rss/scroll-news.xml" "https://www.chinanews.com.cn/rss/scroll-news.xml"
] ]
last_update = load_last_update_time() # 加载上次更新时间
last_update = client.load_last_update_time()
if last_update: if last_update:
print(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}") client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
print("\n开始获取RSS数据...") # 获取RSS数据
client.logger.info("开始获取RSS源数据...")
start_time = time.time() start_time = time.time()
feeds = fetch_all_rss(rss_urls) feeds = client.fetch_all_rss(rss_urls)
print(f"获取完成,耗时: {time.time() - start_time:.2f}") client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}")
# 处理并写入数据
new_last_update = None new_last_update = None
for url, feed in feeds.items(): for url, feed in feeds.items():
current_last_update = display_feed_info(feed, last_update, url) current_last_update = client.display_feed_info(feed, last_update, url)
if current_last_update and (new_last_update is None or current_last_update > new_last_update): if current_last_update and (new_last_update is None or current_last_update > new_last_update):
new_last_update = current_last_update new_last_update = current_last_update
# 保存最新更新时间
if new_last_update: if new_last_update:
save_last_update_time(new_last_update) client.save_last_update_time(new_last_update)
print(f"\n本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}") client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
else: else:
print("\n没有获取到新内容") client.logger.info("没有获取到新内容")
except Exception as e:
logger.error(f"程序运行出错: {str(e)}", exc_info=True)
if __name__ == "__main__": if __name__ == "__main__":
main() NewsAPIClient.main()
+123849
View File
File diff suppressed because it is too large Load Diff
+63455
View File
File diff suppressed because it is too large Load Diff
+101 -111
View File
@@ -3,18 +3,20 @@ import pandas as pd
from datetime import datetime from datetime import datetime
import time import time
import pymysql import pymysql
from utils.mysql_agent import MySQLAgent
import platform import platform
from concurrent.futures import ThreadPoolExecutor
from utils.mysql_agent import MySQLAgent
class TestMySQLAgent(unittest.TestCase): class TestMySQLAgent(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""初始化测试环境和测试表""" """初始化测试环境和测试表"""
# 创建唯一的测试数据库 # 创建唯一的测试数据库和表名(避免冲突)
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S") cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S") cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
# 基础配置 # 基础配置(根据实际环境修改)
cls.base_config = { cls.base_config = {
'host': 'localhost', 'host': 'localhost',
'port': 3306, 'port': 3306,
@@ -32,21 +34,19 @@ class TestMySQLAgent(unittest.TestCase):
'database': cls.test_db_name 'database': cls.test_db_name
}) })
# 创建测试表 # 创建测试表并插入初始数据
test_data = pd.DataFrame({ test_data = pd.DataFrame({
'id': [1, 2, 3], 'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'], 'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8], 'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']) 'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
}) })
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id') cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data) cls.db.insert_from_df(cls.test_table, test_data)
@classmethod @classmethod
def _create_test_database(cls): def _create_test_database(cls):
"""创建测试数据库""" """创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(
host=cls.base_config['host'], host=cls.base_config['host'],
port=cls.base_config['port'], port=cls.base_config['port'],
@@ -54,7 +54,6 @@ class TestMySQLAgent(unittest.TestCase):
password=cls.base_config['password'], password=cls.base_config['password'],
charset='utf8mb4' charset='utf8mb4'
) )
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -66,21 +65,14 @@ class TestMySQLAgent(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""清理测试数据库""" """清理测试环境"""
if hasattr(cls, 'db') and cls.db: if hasattr(cls, 'db') and cls.db:
# 删除测试表 # 删除测试表
if cls.db.table_exists(cls.test_table): if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table) cls.db.drop_table(cls.test_table)
# 删除测试数据库 # 删除测试数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -88,22 +80,24 @@ class TestMySQLAgent(unittest.TestCase):
finally: finally:
temp_conn.close() temp_conn.close()
def test_01_connection(self): def test_connection(self):
"""测试数据库连接""" """测试数据库连接"""
version = self.db.query_to_df("SELECT VERSION() as version") version_df = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version) self.assertIsNotNone(version_df)
print(f"\nDatabase version: {version['version'].iloc[0]}") self.assertEqual(len(version_df), 1)
print(f"Running on: {platform.system()} {platform.release()}") print(f"数据库版本: {version_df['version'].iloc[0]}")
def test_02_query_to_df(self): def test_query_to_df(self):
"""测试查询返回DataFrame""" """测试查询返回DataFrame"""
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,)) df = self.db.query_to_df(
self.assertEqual(len(df), 2) f"SELECT * FROM {self.test_table} WHERE id > %s",
params=(1,)
)
self.assertIsInstance(df, pd.DataFrame) self.assertIsInstance(df, pd.DataFrame)
print("\nQuery result sample:") self.assertEqual(len(df), 2) # id>1 的数据有2条
print(df.head()) self.assertIn('name', df.columns)
def test_03_insert_from_df(self): def test_insert_from_df(self):
"""测试DataFrame插入""" """测试DataFrame插入"""
new_data = pd.DataFrame({ new_data = pd.DataFrame({
'id': [4, 5], 'id': [4, 5],
@@ -112,55 +106,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05']) 'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
}) })
rows = self.db.insert_from_df(self.test_table, new_data) inserted_rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2) self.assertEqual(inserted_rows, 2)
# 验证数据 # 验证插入结果
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4") result_df = self.db.query_to_df(
self.assertEqual(len(df), 2) f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5']) )
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
def test_04_update_from_df(self): def test_update_from_df(self):
"""测试DataFrame更新""" """测试DataFrame更新"""
update_data = pd.DataFrame({ update_data = pd.DataFrame({
'id': [1, 2], 'id': [1, 2],
'name': ['Updated1', 'Updated2'] 'name': ['Updated1', 'Updated2']
}) })
rows = self.db.update_from_df(self.test_table, update_data, 'id') updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2) self.assertGreaterEqual(updated_rows, 2)
# 验证更新 # 验证更新结果
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)") result_df = self.db.query_to_df(
self.assertIn('Updated1', df['name'].values) f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
self.assertIn('Updated2', df['name'].values) )
self.assertIn('Updated1', result_df['name'].values)
self.assertIn('Updated2', result_df['name'].values)
def test_05_transaction(self): def test_transaction(self):
"""测试事务处理""" """测试事务处理"""
conn = self.db.begin_transaction() conn = self.db.begin_transaction()
try: try:
# 执行多个操作 # 执行事务内操作
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1") cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2") cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn) self.db.commit_transaction(conn)
except Exception: except Exception:
self.db.rollback_transaction(conn) self.db.rollback_transaction(conn)
raise raise
# 验证提交后的修改 # 验证事务提交结果
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)") result_df = self.db.query_to_df(
self.assertIn(99.9, df['value'].values) f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
self.assertIn(88.8, df['value'].values) )
self.assertIn(99.9, result_df['value'].values)
self.assertIn(88.8, result_df['value'].values)
def test_06_large_data(self): def test_large_data_insert(self):
"""测试大数据量操作""" """测试大数据量插入"""
# 生成测试数据 # 生成1000行测试数据
large_data = pd.DataFrame({ large_data = pd.DataFrame({
'id': range(1000, 2000), 'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)], 'name': [f"Item_{i}" for i in range(1000, 2000)],
@@ -168,59 +162,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.date_range('2023-01-01', periods=1000) 'created_at': pd.date_range('2023-01-01', periods=1000)
}) })
# Windows平台使用更小的批次 # 根据平台自动调整批次大小
chunk_size = 100 if platform.system() == 'Windows' else 500 chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time() start_time = time.time()
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size) inserted_rows = self.db.insert_from_df(
self.test_table,
large_data,
chunk_size=chunk_size
)
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertEqual(rows, 1000) self.assertEqual(inserted_rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})") print(f"插入1000行数据耗时: {elapsed:.2f} (批次大小: {chunk_size})")
# 验证数据 def test_concurrent_access(self):
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
"""测试并发访问""" """测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def worker(i): def query_worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,)) """并发查询工作函数"""
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id = %s",
params=(i % 3 + 1,) # 查询id=1,2,3循环
)
return len(df) return len(df)
# 20个线程执行100次查询
start_time = time.time() start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor: with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(worker, range(100))) results = list(executor.map(query_worker, range(100)))
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s") self.assertEqual(sum(results), 100) # 每次查询应返回1行
print(f"100次并发查询耗时: {elapsed:.2f}")
class TestPlatformSpecific(unittest.TestCase): class TestPlatformSpecific(unittest.TestCase):
"""平台特定功能测试"""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""创建临时测试数据库""" cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.base_config = { cls.base_config = {
'host': 'localhost', 'host': 'localhost',
'port': 3306, 'port': 3306,
'user': 'root', 'user': 'root',
'password': '123123', 'password': '123123'
'max_connections': 10
} }
# 创建数据库 # 创建测试数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -230,15 +220,8 @@ class TestPlatformSpecific(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""删除临时测试数据库""" """清理测试数据库"""
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -249,42 +232,49 @@ class TestPlatformSpecific(unittest.TestCase):
def test_windows_timeout(self): def test_windows_timeout(self):
"""测试Windows平台超时处理""" """测试Windows平台超时处理"""
if platform.system() != 'Windows': if platform.system() != 'Windows':
self.skipTest("Only runs on Windows") self.skipTest("仅在Windows平台运行")
config = { config = {
**self.base_config, **self.base_config,
'database': self.test_db_name, 'database': self.test_db_name,
'connect_timeout': 1, 'connect_timeout': 1,
'read_timeout': 1 'read_timeout': 1,
'write_timeout': 1
} }
db = MySQLAgent(config) db = MySQLAgent(config)
# 测试短超时查询 # 执行会超时查询(SLEEP(2)超过1秒超时设置)
start_time = time.time() with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
try: try:
db.query_to_df("SELECT SLEEP(2)") db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out")
except Exception as e: except Exception as e:
self.assertIn("timed out", str(e)) # 提取底层异常信息(可能被包装)
print(f"\nWindows timeout test: {str(e)}") while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
e = e.args[0]
raise e
def test_macos_ssl(self): error_msg = str(ctx.exception)
"""测试macOS SSL连接""" self.assertTrue(
"timed out" in error_msg or
"timeout" in error_msg or
"HY000" in error_msg, # MySQL超时错误码
f"未检测到超时异常,实际异常: {error_msg}"
)
def test_macos_ssl_connection(self):
"""测试macOS平台SSL连接"""
if platform.system() != 'Darwin': if platform.system() != 'Darwin':
self.skipTest("Only runs on macOS") self.skipTest("仅在macOS平台运行")
config = { config = {
**self.base_config, **self.base_config,
'database': self.test_db_name, 'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
} }
db = MySQLAgent(config) db = MySQLAgent(config)
version = db.query_to_df("SELECT VERSION() as version") version_df = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version) self.assertIsNotNone(version_df)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main(verbosity=2)
+389 -356
View File
@@ -3,13 +3,13 @@ import sys
import platform import platform
import pandas as pd import pandas as pd
import pymysql import pymysql
import json
import numpy as np
from pymysql import cursors from pymysql import cursors
from pymysql.err import MySQLError from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB from typing import Union, List, Dict, Any, Optional, Tuple, Literal
from typing import Union, List, Dict, Any, Optional, Tuple
import threading import threading
from datetime import datetime from datetime import datetime
import numpy as np
from pathlib import Path from pathlib import Path
# 导入日志系统 # 导入日志系统
@@ -20,7 +20,7 @@ class MySQLAgent:
""" """
全平台兼容的MySQL数据库操作类 全平台兼容的MySQL数据库操作类
支持Windows/macOS/Linux系统 支持Windows/macOS/Linux系统
配置参数从外部传入 配置参数从外部传入,不使用连接池和事务管理
""" """
_instance = None _instance = None
@@ -34,30 +34,14 @@ class MySQLAgent:
return cls._instance return cls._instance
def __init__(self, config: dict): def __init__(self, config: dict):
""" """初始化MySQL数据库连接(原有逻辑完全保留)"""
初始化MySQL数据库连接 if hasattr(self, 'config') and self.config:
Args:
config (dict): 数据库配置字典,包含以下键:
- host: 数据库主机
- port: 端口
- user: 用户名
- password: 密码
- database: 数据库名
- [可选] charset: 字符集(默认utf8mb4)
- [可选] max_connections: 最大连接数(默认5)
- [可选] connect_timeout: 连接超时(秒)
- [可选] read_timeout: 读取超时(秒)
- [可选] write_timeout: 写入超时(秒)
- [可选] ssl: SSL配置
"""
if hasattr(self, '_pool') and self._pool:
return return
# 基础配置校验 # 基础配置校验
required_keys = ['host', 'port', 'user', 'password', 'database'] required_keys = ['host', 'port', 'user', 'password', 'database']
if not all(key in config for key in required_keys): if not all(key in config for key in required_keys):
log.warning(f"数据库配置缺少必要参数,当前配置: {config}") log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{config}")
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}") raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
self.config = { self.config = {
@@ -67,7 +51,6 @@ class MySQLAgent:
'password': config['password'], 'password': config['password'],
'database': config['database'], 'database': config['database'],
'charset': config.get('charset', 'utf8mb4'), 'charset': config.get('charset', 'utf8mb4'),
'cursorclass': cursors.DictCursor,
'autocommit': True, 'autocommit': True,
'connect_timeout': config.get('connect_timeout', 10), 'connect_timeout': config.get('connect_timeout', 10),
'read_timeout': config.get('read_timeout', 30), 'read_timeout': config.get('read_timeout', 30),
@@ -79,86 +62,57 @@ class MySQLAgent:
current_platform = platform.system() current_platform = platform.system()
self.log = log.bind(module=f"MySQLAgent({current_platform})") self.log = log.bind(module=f"MySQLAgent({current_platform})")
# 创建连接池
self.pool_size = config.get('max_connections', 5)
self._pool = self._create_pool()
def _create_pool(self) -> PooledDB:
"""创建连接池"""
try:
# 线程安全的连接创建函数
def connect():
conn = pymysql.connect(**self.config)
conn.threadsafety = 1 # 显式设置线程安全级别
return conn
pool = PooledDB(
creator=connect,
mincached=1,
maxcached=3,
maxconnections=self.pool_size,
blocking=True,
ping=1 # 每次获取连接时ping数据库
)
self.log.info("连接池创建成功")
return pool
except Exception as e:
self.log.critical("连接池创建失败", error=str(e), exc_info=True)
raise
def get_connection(self) -> pymysql.connections.Connection: def get_connection(self) -> pymysql.connections.Connection:
"""获取数据库连接(修复字符集方法缺失问题""" """获取数据库连接(原有逻辑完全保留"""
try: try:
conn = self._pool.connection() conn = pymysql.connect(**self.config)
# 为连接添加字符集方法(兼容SQLAlchemy # 为连接添加 character_set_name 方法
if not hasattr(conn, 'character_set_name'): if not hasattr(conn, 'character_set_name'):
def _character_set_name(): def _character_set_name():
return self.config.get('charset', 'utf8mb4') return self.config.get('charset', 'utf8mb4')
conn.character_set_name = _character_set_name conn.character_set_name = _character_set_name
# macOS平台SSL特殊处理 # macOS需要特殊处理SSL
if platform.system() == 'Darwin' and self.config.get('ssl'): if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True) conn.ping(reconnect=True)
self.log.trace("获取数据库连接成功") self.log.trace("Database connection obtained")
return conn return conn
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
# Windows平台连接超时重试
if platform.system() == 'Windows' and "timed out" in error_msg: if platform.system() == 'Windows' and "timed out" in error_msg:
self.log.warning("Windows连接超时,尝试重试...") self.log.warning("Windows connection timeout, retrying...")
return self._retry_connection() return self._retry_connection()
self.log.error("获取连接失败", error=error_msg, exc_info=True) self.log.error("Connection failed", error=error_msg, exc_info=True)
raise raise
def _retry_connection(self, max_retries: int = 3) -> Any | None:
    """Try to connect up to ``max_retries`` times, pausing 1 s between tries.

    Returns:
        The first connection that opens successfully. When ``max_retries``
        is below 1 the loop never runs and ``None`` is returned.

    Raises:
        Exception: the error from the final attempt when every attempt fails.
    """
    attempt = 0
    while attempt < max_retries:
        try:
            connection = pymysql.connect(**self.config)
            self.log.info(f"Connection established after {attempt + 1} attempts")
            return connection
        except Exception:
            # Out of attempts: surface the last error to the caller.
            if attempt + 1 == max_retries:
                raise
            import time
            time.sleep(1)
        attempt += 1
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None, def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame: parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
"""执行SQL查询并返回DataFrame优化连接管理""" """执行SQL查询并返回DataFrame原有逻辑完全保留"""
conn = None
try: try:
self.log.debug("执行SQL查询", sql=sql) self.log.debug("Executing SQL query", sql=sql)
# 获取连接并确保字符集方法存在
conn = self.get_connection() conn = self.get_connection()
# 创建SQLAlchemy引擎(使用静态池避免连接重复创建) # 创建SQLAlchemy引擎
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool from sqlalchemy.pool import StaticPool
engine = create_engine( engine = create_engine(
@@ -170,180 +124,361 @@ class MySQLAgent:
# 执行查询 # 执行查询
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates) df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
self.log.info(f"查询成功,返回{len(df)}行数据") self.log.info("Query executed successfully", rows=len(df))
return df return df
except Exception as e: except Exception as e:
self.log.error(f"SQL查询失败{sql}", sql=sql, params=params, error=str(e), exc_info=True) self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
raise raise
finally: finally:
# 确保连接释放回池 if 'engine' in locals():
if conn: engine.dispose()
try:
conn.close()
except Exception as e:
self.log.warning("关闭连接失败", error=str(e))
def insert_from_df(self, table_name: str, df: pd.DataFrame,
                   chunk_size: int = 1000, replace: bool = False,
                   ignore_duplicates: bool = None) -> int:
    """Insert the rows of ``df`` into ``table_name`` one row at a time.

    Only DataFrame columns that exist in the table are written. Null-like
    cells (None, NaN, NaT, pd.NA and the strings '', 'nan', 'NaN', 'NAN')
    are sent as SQL NULL. Columns containing dict/list values are
    serialised to JSON text.

    Args:
        table_name: Target table.
        df: Rows to insert.
        chunk_size: Accepted for backward compatibility; rows are inserted
            individually, so it is not used.
        replace: Legacy flag. Only consulted when ``ignore_duplicates`` is
            None, in which case ``ignore_duplicates = not replace``. It does
            not replace existing rows.
        ignore_duplicates: When true, rows rejected by the database
            (duplicate key or any other MySQL error) are logged and skipped.
            When false the first such error aborts and rolls back the batch.

    Returns:
        Number of rows actually inserted.

    Raises:
        Exception: any failure outside the per-row skip logic; the
            transaction is rolled back before re-raising.
    """
    null_strings = {'nan', 'NaN', 'NAN', ''}

    def to_db_value(value):
        """Map every null-like cell to None so it is sent as SQL NULL."""
        if value is None:
            return None
        if isinstance(value, str):
            return None if value in null_strings else value
        if isinstance(value, (dict, list)):
            return value
        try:
            return None if pd.isna(value) else value
        except (TypeError, ValueError):
            # Array-like cells have no single truth value; keep them as is.
            return value

    # Backward compatibility: derive the new flag from the legacy one.
    if ignore_duplicates is None:
        ignore_duplicates = not replace

    if df.empty:
        self.log.warning("Attempted to insert empty DataFrame", table=table_name)
        return 0

    conn = None
    cursor = None
    total_inserted = 0
    total_duplicated = 0
    total_failed = 0

    try:
        conn = self.get_connection()
        cursor = conn.cursor()
        self.log.debug(f"Established connection for inserting into {table_name}")

        # Column names as the database actually defines them.
        cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
        db_columns = [col[0] for col in cursor.fetchall()]
        self.log.debug(f"Table {table_name} has columns: {db_columns}")

        # Keep only the DataFrame columns the table knows about.
        known = set(db_columns)
        df_columns = df.columns.tolist()
        matched_columns = [col for col in df_columns if col in known]
        unmatched_columns = [col for col in df_columns if col not in known]

        if unmatched_columns:
            self.log.warning(
                f"Table {table_name} dropping unmatched columns",
                unmatched_columns=unmatched_columns,
                count=len(unmatched_columns)
            )

        if not matched_columns:
            self.log.warning(f"No matched columns for {table_name}, abort insertion")
            return 0

        # Normalise nulls per cell: a DataFrame-level replace cannot turn
        # NaN into None inside a float column.
        records = [
            {col: to_db_value(value) for col, value in record.items()}
            for record in df[matched_columns].to_dict('records')
        ]
        indices = df.index.tolist()
        total_to_insert = len(records)
        self.log.debug(
            f"Filtered DataFrame for {table_name}: {total_to_insert} rows to insert"
        )

        # A column holding any dict/list is stored as JSON text throughout.
        complex_columns = [
            col for col in matched_columns
            if any(isinstance(record[col], (dict, list)) for record in records)
        ]
        for col in complex_columns:
            self.log.debug(f"Column {col} in {table_name} has complex type, converting to JSON")
            for record in records:
                if record[col] is not None:
                    record[col] = json.dumps(record[col], ensure_ascii=False)

        columns_str = ', '.join([f"`{col}`" for col in matched_columns])
        placeholders = ', '.join(['%s'] * len(matched_columns))
        insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
        self.log.trace(f"Generated insert SQL for {table_name}: {insert_sql}")

        # Row-by-row so a single rejected row can be identified and skipped.
        for i, (record, idx) in enumerate(zip(records, indices)):
            try:
                data = tuple(record[col] for col in matched_columns)
                cursor.execute(insert_sql, data)
                total_inserted += 1

                if (i + 1) % 100 == 0:
                    self.log.trace(
                        f"Inserted {i + 1}/{total_to_insert} rows into {table_name}"
                    )

            except MySQLError as e:
                error_code = e.args[0] if e.args else None
                error_msg = e.args[1] if len(e.args) > 1 else str(e)

                if error_code == 1062:  # ER_DUP_ENTRY
                    total_duplicated += 1
                    short_record = {
                        k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
                        for k, v in record.items()
                    }
                    self.log.warning(
                        f"Skipped duplicate record in {table_name}",
                        index=idx,
                        error_msg=error_msg,
                        record=short_record
                    )
                else:
                    total_failed += 1
                    self.log.error(
                        f"Failed to insert record in {table_name}",
                        index=idx,
                        error_code=error_code,
                        error_msg=error_msg,
                        record=record
                    )
                if not ignore_duplicates:
                    raise

        conn.commit()

        self.log.info(
            f"Insertion summary for {table_name}",
            total_to_insert=total_to_insert,
            total_inserted=total_inserted,
            total_duplicated=total_duplicated,
            total_failed=total_failed
        )
        return total_inserted

    except Exception as e:
        if conn:
            conn.rollback()
        self.log.error(f"Batch insertion failed for {table_name}", error=str(e), exc_info=True)
        raise
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
    """Return the leading primary-key column of ``table_name``.

    Args:
        table_name: Table to inspect in the schema ``self.config['database']``.
        cursor: An open cursor; it is used but not closed here.

    Returns:
        The first column of the primary key (by position within the key),
        or None when the table has no primary key or the lookup fails.
        Lookup failures are logged as warnings, never raised.
    """
    try:
        # ORDER BY makes the answer deterministic for composite keys.
        cursor.execute("""
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
            WHERE TABLE_SCHEMA = %s
              AND TABLE_NAME = %s
              AND CONSTRAINT_NAME = 'PRIMARY'
            ORDER BY ORDINAL_POSITION
        """, (self.config['database'], table_name))
        result = cursor.fetchone()
        if not result:
            return None
        # Support both tuple rows and dict rows.
        if isinstance(result, dict):
            return next(iter(result.values()))
        return result[0]
    except Exception as e:
        self.log.warning(f"Failed to get primary key for {table_name}", error=str(e))
        return None
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
    """Fetch column metadata for ``table_name`` from information_schema.

    Returns:
        ``{column_name: {'type': data_type, 'max_length': length_or_None}}``.
        An empty dict when the table has no columns (or does not exist).

    Raises:
        Exception: any connection or query error, after logging it.
    """
    sql = """
        SELECT column_name, data_type, character_maximum_length
        FROM information_schema.columns
        WHERE table_schema = %s
          AND table_name = %s
    """
    params = (self.config['database'], table_name)

    try:
        conn = self.get_connection()
        # Bound before the try so the cleanup below never hits an unbound
        # name when cursor creation itself fails.
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            # Materialise the rows so the cursor's result type does not matter.
            rows = list(cursor.fetchall())

            if not rows:
                self.log.error("No columns found in table", table=table_name)
                return {}

            schema = {}
            for row in rows:
                # Rows are read positionally (tuple-style cursor).
                col_name = str(row[0]).strip()
                data_type = str(row[1]).strip()
                max_length = row[2] if row[2] else None
                schema[col_name] = {
                    'type': data_type,
                    'max_length': max_length
                }

            self.log.debug("Successfully fetched table schema",
                           table=table_name,
                           columns=list(schema.keys()))
            return schema
        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
    except Exception as e:
        self.log.error("Failed to get table detailed info",
                       table=table_name,
                       error=str(e))
        raise
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
                             table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
    """Return a cleaned copy of ``df`` that fits ``table_schema``.

    Steps, per column that exists in the schema:
      * columns missing from the schema are dropped (with a warning);
      * nulls in varchar/char/text columns become ''; nulls in other
        column types are left untouched;
      * varchar/char values are converted to str and truncated to the
        column's ``max_length``;
      * datetime/timestamp columns are parsed; values that are present but
        cannot be parsed are replaced by the current time.

    Args:
        df: Input data; it is not modified.
        table_name: Used for log context only.
        table_schema: ``{column: {'type': ..., 'max_length': ...}}``.

    Returns:
        The cleaned DataFrame (possibly empty).
    """
    text_types = ('varchar', 'char', 'text')
    bounded_text_types = ('varchar', 'char')
    temporal_types = ('datetime', 'timestamp')

    # 1. Keep only columns the table defines.
    valid_columns = [col for col in df.columns if col in table_schema]
    invalid_columns = [col for col in df.columns if col not in table_schema]

    if invalid_columns:
        self.log.warning("Dropping invalid columns not present in table",
                         table=table_name,
                         invalid_columns=invalid_columns,
                         count=len(invalid_columns))

    cleaned_df = df[valid_columns].copy()
    if cleaned_df.empty:
        return cleaned_df

    # 2. Clean each remaining column according to its declared type.
    for col in valid_columns:
        col_info = table_schema[col]
        data_type = col_info['type']
        max_length = col_info['max_length']

        # 2.1 Nulls: count them before filling so the log is meaningful.
        # Non-text columns keep their nulls (there is no sensible default).
        null_mask = cleaned_df[col].isna()
        null_count = int(null_mask.sum())
        if null_count and data_type in text_types:
            cleaned_df[col] = cleaned_df[col].astype(object).where(~null_mask, '')
            self.log.debug("Replaced null values",
                           table=table_name,
                           column=col,
                           default_value='',
                           count=null_count)

        # 2.2 Bounded strings: stringify and truncate to the column width.
        if data_type in bounded_text_types and max_length:
            as_text = cleaned_df[col].astype(str)
            too_long_mask = as_text.str.len() > max_length
            cleaned_df[col] = as_text.where(~too_long_mask,
                                            as_text.str.slice(0, max_length))
            if too_long_mask.any():
                self.log.warning("Truncated overlength values",
                                 table=table_name,
                                 column=col,
                                 max_length=max_length,
                                 count=int(too_long_mask.sum()))

        # 2.3 Date/time: parse leniently, then patch only the values that
        # were present but unparsable (original nulls stay null).
        if data_type in temporal_types:
            parsed = pd.to_datetime(cleaned_df[col], errors='coerce')
            unparsable = parsed.isna() & ~null_mask
            if unparsable.any():
                self.log.warning("Failed to convert to datetime, using current time",
                                 table=table_name,
                                 column=col,
                                 count=int(unparsable.sum()))
                parsed = parsed.mask(unparsable, pd.Timestamp(datetime.now()))
            cleaned_df[col] = parsed

    return cleaned_df
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """Update rows of ``table_name`` from ``df``, matching on ``key_columns``.

    For every DataFrame row one ``UPDATE ... SET <non-key columns> WHERE
    <key columns>`` is executed. Only DataFrame columns that exist in the
    table are written.

    Args:
        table_name: Target table.
        df: Source rows; must contain every key column.
        key_columns: Column name, or list of names, identifying a row.

    Returns:
        The cursor's ``rowcount`` after the batch, or 0 when ``df`` is empty
        or no updatable column remains.

    Raises:
        ValueError: if a key column is missing from ``df``.
        Exception: any database error, after logging it.
    """
    if df.empty:
        self.log.warning("Attempted to update with empty DataFrame", table=table_name)
        return 0

    self.log.debug("Preparing to update table from DataFrame",
                   table=table_name,
                   key_columns=key_columns,
                   rows=len(df))

    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]

        # Fail early with a clear message instead of a KeyError mid-loop.
        missing_keys = [key for key in key_columns if key not in df.columns]
        if missing_keys:
            raise ValueError(f"Key columns missing from DataFrame: {missing_keys}")

        # Read the schema before opening the update connection so two
        # connections are never held at once.
        table_info = self._get_table_detailed_info(table_name)
        columns = [col for col in df.columns if col in table_info]
        set_columns = [col for col in columns if col not in key_columns]

        if not set_columns:
            self.log.warning("No columns to update", table=table_name)
            return 0

        # Backtick-quote identifiers so reserved words and odd names work.
        set_clause = ', '.join([f"`{col}`=%s" for col in set_columns])
        where_clause = ' AND '.join([f"`{col}`=%s" for col in key_columns])
        update_sql = f"UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}"
        self.log.trace("Generated update SQL", sql=update_sql)

        # Parameter order must match the placeholders: SET values, then keys.
        update_data = [
            tuple([row[col] for col in set_columns] + [row[col] for col in key_columns])
            for _, row in df.iterrows()
        ]

        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.executemany(update_sql, update_data)
                total_updated = cursor.rowcount
                conn.commit()

        self.log.info("Data updated successfully",
                      table=table_name,
                      rows_updated=total_updated)
        return total_updated

    except Exception as e:
        self.log.error("Data update failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise
def _get_table_info(self, table_name: str) -> Dict[str, str]:
    """Return ``{column_name: data_type}`` for ``table_name``.

    The lookup runs against information_schema for the schema named in
    ``self.config['database']``.

    Raises:
        Exception: any connection or query error, after logging it.
    """
    sql = """
        SELECT column_name, data_type
        FROM information_schema.columns
        WHERE table_schema = %s
          AND table_name = %s
    """

    def _name_and_type(row):
        # Read positionally: works for tuple rows and for dict rows whatever
        # letter case the server reports the column headers in.
        values = list(row.values()) if isinstance(row, dict) else row
        return values[0], values[1]

    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(sql, (self.config['database'], table_name))
                result = cursor.fetchall()
                return dict(_name_and_type(row) for row in result)
    except Exception as e:
        self.log.error(f"获取表{table_name}结构失败", error=str(e))
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]: def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
"""推断DataFrame各列的SQL类型(扩展类型映射""" """推断DataFrame各列的SQL类型(原有逻辑完全保留"""
type_mapping = { type_mapping = {
'int64': 'BIGINT', 'int64': 'BIGINT',
'int32': 'INT',
'int16': 'SMALLINT',
'int8': 'TINYINT',
'uint64': 'BIGINT UNSIGNED',
'float64': 'DOUBLE', 'float64': 'DOUBLE',
'float32': 'FLOAT',
'datetime64[ns]': 'DATETIME', 'datetime64[ns]': 'DATETIME',
'datetime64[ns, UTC]': 'DATETIME',
'timedelta64[ns]': 'TIME',
'object': 'TEXT', 'object': 'TEXT',
'string': 'VARCHAR(255)',
'bool': 'TINYINT(1)', 'bool': 'TINYINT(1)',
'category': 'VARCHAR(255)' 'category': 'VARCHAR(255)'
} }
@@ -353,136 +488,84 @@ class MySQLAgent:
dtype_str = str(dtype) dtype_str = str(dtype)
sql_types[col] = type_mapping.get(dtype_str, 'TEXT') sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
self.log.debug("DataFrame类型映射为SQL类型", mappings=sql_types) self.log.debug("Mapped DataFrame types to SQL types",
mappings=sql_types)
return sql_types return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create ``table_name`` with one column per DataFrame column.

    Column types come from ``self.df_to_sql_type(df)``.

    Args:
        table_name: Name of the table to create.
        df: DataFrame whose columns define the table.
        primary_key: Column name or list of names for the primary key;
            names that are not DataFrame columns are ignored.

    Returns:
        True when the table was created; False when it already exists or
        creation failed (the failure is logged, not raised).
    """
    if self.table_exists(table_name):
        self.log.warning("Table already exists", table=table_name)
        return False

    self.log.debug("Creating new table from DataFrame schema",
                   table=table_name,
                   columns=list(df.columns))

    def quote(identifier) -> str:
        # Backtick-quote and escape so reserved words and odd names are safe.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                quoted_pk = ', '.join(quote(col) for col in pk_columns)
                columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)

        # Joined outside the f-string: a backslash inside an f-string
        # expression is a syntax error before Python 3.12.
        column_block = ",\n ".join(columns_sql)
        create_sql = f"CREATE TABLE {quote(table_name)} (\n {column_block}\n)"

        self.execute_sql(create_sql)
        self.log.info("Table created successfully", table=table_name)
        return True
    except Exception as e:
        self.log.error("Failed to create table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Execute one SQL statement on a fresh connection.

    Args:
        sql: Statement text, with driver placeholders for ``params``.
        params: Values bound to the placeholders, or None.
        fetch: When true, return all result rows; otherwise commit and
            return the affected-row count.

    Returns:
        The fetched rows (as produced by the cursor) when ``fetch`` is true,
        else ``cursor.rowcount``.

    Raises:
        Exception: any connection or execution error, after logging it.
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                # Non-Windows platforms ask for a longer statement timeout.
                # Best effort only: servers without this session variable
                # must not make every statement fail.
                if platform.system() != 'Windows':
                    try:
                        cursor.execute("SET SESSION max_execution_time=600000")
                    except Exception as hint_error:
                        self.log.debug("max_execution_time not supported, continuing",
                                       error=str(hint_error))

                cursor.execute(sql, params)

                if fetch:
                    result = cursor.fetchall()
                    self.log.debug("Query executed", rows=len(result))
                    return result

                affected_rows = cursor.rowcount
                conn.commit()  # make the change durable right away
                self.log.debug("Update executed", affected_rows=affected_rows)
                return affected_rows
    except Exception as e:
        self.log.error("SQL execution failed",
                       sql=sql,
                       params=params,
                       error=str(e),
                       exc_info=True)
        raise
def begin_transaction(self) -> pymysql.connections.Connection:
    """Open a connection with autocommit off and return it.

    On macOS the session isolation level is set to READ COMMITTED, on Linux
    to REPEATABLE READ; other platforms keep the server default.

    Returns:
        The open connection; the caller must finish with
        ``commit_transaction`` or ``rollback_transaction``.

    Raises:
        Exception: any error while opening or configuring the connection.
            A connection that was already opened is closed before re-raising.
    """
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)

        isolation_by_platform = {
            'Darwin': "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED",
            'Linux': "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ",
        }
        isolation_sql = isolation_by_platform.get(platform.system())
        if isolation_sql:
            # Context manager so the helper cursor is always closed.
            with conn.cursor() as cursor:
                cursor.execute(isolation_sql)

        self.log.debug("事务开始")
        return conn
    except Exception as e:
        self.log.error("事务开始失败", error=str(e))
        # Nobody else holds this connection yet; do not leak it.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_error:
                self.log.warning("关闭连接失败", error=str(close_error))
        raise
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Commit the transaction on ``conn`` and close the connection.

    The connection is closed whether or not the commit succeeds; a failure
    to close is only logged.

    Raises:
        Exception: the commit error, after logging it.
    """
    def _close_quietly():
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务提交后关闭连接失败", error=str(close_error))

    try:
        conn.commit()
        self.log.debug("事务提交成功")
    except Exception as commit_error:
        self.log.error("事务提交失败", error=str(commit_error))
        raise
    finally:
        _close_quietly()
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Roll back the transaction on ``conn`` and close the connection.

    Never raises: a failed rollback is logged as an error and a failed
    close as a warning.
    """
    def _close_quietly():
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务回滚后关闭连接失败", error=str(close_error))

    try:
        conn.rollback()
        self.log.warning("事务已回滚")
    except Exception as rollback_error:
        self.log.error("事务回滚失败", error=str(rollback_error))
    finally:
        _close_quietly()
def table_exists(self, table_name: str) -> bool: def table_exists(self, table_name: str) -> bool:
"""检查表是否存在(优化SQL安全性""" """检查表是否存在(原有逻辑完全保留"""
sql = """ sql = """
SELECT COUNT(*) as count SELECT COUNT(*) as count
FROM `information_schema`.`tables` FROM `information_schema`.`tables`
@@ -490,64 +573,49 @@ class MySQLAgent:
AND `table_name` = %s \ AND `table_name` = %s \
""" """
params = (self.config['database'], table_name)
try: try:
result = self.execute_sql(sql, (self.config['database'], table_name), fetch=True) result = self.execute_sql(sql, params, fetch=True)
exists = result[0]['count'] > 0 exists = result[0][0] > 0 # 适配元组结果
self.log.debug(f"{table_name}存在性检查", exists=exists) self.log.debug("Checked table existence",
table=table_name,
exists=exists)
return exists return exists
except Exception as e: except Exception:
self.log.warning(f"{table_name}存在性检查失败", error=str(e))
return False return False
def drop_table(self, table_name: str) -> bool:
    """Drop ``table_name`` if it exists.

    Returns:
        True when the table was dropped; False when it does not exist or
        the drop failed (the failure is logged, not raised).
    """
    if not self.table_exists(table_name):
        self.log.warning("Table does not exist", table=table_name)
        return False

    try:
        # Backtick-quote (and escape) the identifier so reserved words and
        # unusual names cannot break or alter the statement.
        quoted_name = "`" + table_name.replace("`", "``") + "`"
        self.execute_sql(f"DROP TABLE {quoted_name}")
        self.log.info("Table dropped successfully", table=table_name)
        return True
    except Exception as e:
        self.log.error("Failed to drop table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def get_pool_status(self) -> Dict[str, int]:
    """Return a snapshot of the connection pool's counters.

    Reads private attributes of ``self._pool``.
    NOTE(review): these are DBUtils internals and may change between
    releases -- confirm against the installed version.

    Returns:
        A dict with ``max_connections``, ``active_connections``,
        ``idle_connections`` and ``shared_connections``.
    """
    def _count(value) -> int:
        # The pool keeps some of these as plain integers and others as
        # lists; an attribute may also be absent (e.g. no shared cache).
        if value is None:
            return 0
        if isinstance(value, int):
            return value
        return len(value)

    pool = self._pool
    status = {
        'max_connections': pool._maxconnections,
        'active_connections': _count(getattr(pool, '_connections', None)),
        'idle_connections': _count(getattr(pool, '_idle_cache', None)),
        'shared_connections': _count(getattr(pool, '_shared_cache', None))
    }
    self.log.debug("连接池状态", **status)
    return status
def validate_connection(self) -> bool:
    """Check that a connection can be opened and answers ``SELECT 1``.

    Returns:
        True when the probe query returns 1; False on any failure (the
        failure is logged as a warning, never raised).
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                row = cursor.fetchone()
    except Exception as e:
        self.log.warning("Connection health check failed", error=str(e))
        return False

    if not row:
        return False
    # Accept both tuple rows and dict rows.
    value = next(iter(row.values())) if isinstance(row, dict) else row[0]
    return value == 1
def __del__(self):
    """Close the connection pool, if one was created, when the agent dies.

    A failure to close is logged as an error and not propagated.
    """
    pool = getattr(self, '_pool', None)
    if not pool:
        return
    try:
        pool.close()
        self.log.info("连接池已关闭")
    except Exception as close_error:
        self.log.error("连接池关闭失败", error=str(close_error))
# Platform-specific default configuration (original logic kept unchanged)
def get_default_config(): def get_default_config():
"""获取各平台默认配置(优化默认参数)""" """获取各平台默认配置"""
current_platform = platform.system() current_platform = platform.system()
base_config = { base_config = {
@@ -555,76 +623,41 @@ def get_default_config():
'port': 3306, 'port': 3306,
'user': 'root', 'user': 'root',
'password': '123123', 'password': '123123',
'database': 'intelligence', 'database': 'intelligence_system',
'max_connections': 10, # 增加默认连接数
'charset': 'utf8mb4'
} }
if current_platform == 'Windows': if current_platform == 'Windows':
return { return {**base_config,
**base_config,
'connect_timeout': 10, 'connect_timeout': 10,
'read_timeout': 30, 'read_timeout': 30,
'write_timeout': 30, 'write_timeout': 30
'ssl': None # Windows默认禁用SSL
} }
elif current_platform == 'Darwin': # macOS elif current_platform == 'Darwin':
return { return {
**base_config, **base_config,
'connect_timeout': 15, 'connect_timeout': 15,
'read_timeout': 60, 'read_timeout': 60,
'write_timeout': 60, 'write_timeout': 60,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} # macOS默认SSL配置 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
} }
else: # Linux其他平台 else: # Linux其他平台
return { return {**base_config,
**base_config,
'connect_timeout': 15, 'connect_timeout': 15,
'read_timeout': 60, 'read_timeout': 60,
'write_timeout': 60, 'write_timeout': 60
'ssl': None # Linux默认禁用SSL
} }
if __name__ == "__main__":
    # Usage example: open an agent with the platform defaults, check the
    # connection and, when it works, print the server version.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("Failed to connect to database")
    else:
        print("Database connection successful")

        version = db.query_to_df("SELECT VERSION() as version")
        print(f"Database version: {version['version'].iloc[0]}")