mysql数据链接更新

This commit is contained in:
z66
2025-09-18 17:03:24 +08:00
parent 9afa9d2e58
commit 20fd9587ee
5 changed files with 188086 additions and 726 deletions
+258 -225
View File
@@ -6,272 +6,305 @@ import os
import pickle import pickle
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import pymysql from loguru import logger
from utils.mysql_agent import MySQLAgent
from typing import Dict, List, Optional, Any
# 数据库连接信息 # 数据库连接配置
local_DB_Config = { local_DB_Config = {
'host': "localhost", 'host': "localhost",
'user': "root", 'user': "root",
'password': "123123", 'password': "123123",
'database': "intelligence_system", 'database': "intelligence_system",
'charset': 'utf8mb4' 'port': 3306,
'charset': 'utf8mb4',
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30,
'autocommit': True
} }
# 表名 # 目标数据表名
table_name = "collector_rss_subscriptions" table_name = "collector_rss_subscriptions"
def verify_database(): class NewsAPIClient:
"""验证数据库连接和表结构""" """新闻API客户端,用于获取和处理RSS源数据并写入数据库"""
try:
conn = pymysql.connect(**local_DB_Config) def __init__(self):
with conn.cursor() as cursor: """初始化客户端并建立数据库连接"""
# 检查表是否存在 self.logger = logger.bind(module="NewsAPIClient")
cursor.execute(f"SHOW TABLES LIKE '{table_name}'") self.db_agent = MySQLAgent(local_DB_Config)
if not cursor.fetchone(): self.logger.info("新闻API客户端初始化完成,已连接到数据库")
print(f"错误: 表 {table_name} 不存在!")
def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
"""统一返回结果格式"""
return {
'success': bool(success),
'message': str(message),
'data': data
}
def verify_database(self) -> bool:
"""验证数据库表结构是否符合要求(适配元组格式的查询结果)"""
try:
# 1. 检查表是否存在(execute_sql返回元组列表,如 [(table_name,)]
result = self.db_agent.execute_sql(
f"SHOW TABLES LIKE '{table_name}'",
fetch=True
)
# 元组结果需通过索引0判断(若表存在,result是[(table_name,)], 否则为空列表)
if not result:
self.logger.error(f"{table_name} 不存在,请先创建表结构")
return False return False
# 检查表结构 # 2. 检查表字段是否完整(DESCRIBE返回的元组格式:(字段名, 类型, 是否为空, ...))
cursor.execute(f"DESCRIBE {table_name}") desc_result = self.db_agent.execute_sql(
columns = [col[0] for col in cursor.fetchall()] f"DESCRIBE {table_name}",
print("表列名:", columns) fetch=True
)
# 关键修改:用元组索引0提取字段名(而非字典键'Field'
columns = [col[0] for col in desc_result] # col是元组,col[0]即字段名
required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
'来源URL', '创建时间', '更新时间']
missing_cols = [col for col in required_columns if col not in columns]
# 检查插入权限 if missing_cols:
test_sql = f"""INSERT INTO `{table_name}` self.logger.error(f"{table_name} 缺少必要字段:{missing_cols}")
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`) return False
VALUES (%s, %s, %s, %s, %s)"""
cursor.execute(test_sql, ('测试标题', 'http://test.com', '测试内容', datetime.now(), '测试来源'))
conn.rollback()
print("数据库验证通过!") self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
return True return True
except Exception as e:
print("数据库验证失败:", e)
return False
finally:
if 'conn' in locals():
conn.close()
except Exception as e:
self.logger.error(f"数据库验证失败: {str(e)}", exc_info=True)
return False
def load_last_update_time(): def load_last_update_time(self) -> Optional[datetime]:
"""加载上次更新时间""" """加载上次更新时间缓存"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl') cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
if os.path.exists(cache_file): if os.path.exists(cache_file):
with open(cache_file, 'rb') as f:
return pickle.load(f)
return None
def save_last_update_time(last_update):
"""保存本次更新的时间"""
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
with open(cache_file, 'wb') as f:
pickle.dump(last_update, f)
def fetch_single_rss(url, timeout=15):
"""获取并解析单个 RSS 源"""
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
for attempt in range(3):
try:
response = requests.get(url, headers=headers, timeout=timeout)
response.raise_for_status()
response.encoding = response.apparent_encoding
feed = feedparser.parse(response.text)
if feed.bozo:
print(f"警告: 解析可能存在问题: {feed.bozo_exception}")
return feed
except requests.RequestException as e:
print(f"{attempt + 1} 次尝试获取 {url} 失败: {e}")
if attempt < 2:
time.sleep(5 * (attempt + 1))
continue
return None
def fetch_all_rss(urls, timeout=15):
"""使用线程池并发获取多个RSS源"""
feeds = {}
with ThreadPoolExecutor(max_workers=3) as executor:
future_to_url = {executor.submit(fetch_single_rss, url, timeout): url for url in urls}
for future in as_completed(future_to_url):
url = future_to_url[future]
try: try:
feed = future.result() with open(cache_file, 'rb') as f:
if feed: last_update = pickle.load(f)
feeds[url] = feed self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
return last_update
except Exception as e: except Exception as e:
print(f"获取 {url} 时发生异常: {e}") self.logger.error(f"加载上次更新时间失败: {str(e)}", exc_info=True)
self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
return feeds
def process_feed_entry(entry, url):
"""处理单个RSS条目并返回结构化数据"""
# 处理标题
title = entry.get('title', '无标题') or '无标题'
if len(title) > 255:
title = title[:252] + '...'
# 处理链接
link = entry.get('link', '无链接') or '无链接'
if len(link) > 1024:
link = link[:1021] + '...'
# 处理摘要
summary = entry.get('summary', '无内容摘要')
content_list = entry.get('content', [])
content = content_list[0].value if content_list else ''
description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
# 处理发布时间
published_parsed = entry.get('published_parsed') or entry.get('updated_parsed')
if published_parsed:
entry_time = datetime(*published_parsed[:6])
else:
pub_str = entry.get('published', entry.get('updated', ''))
try:
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z')
except:
entry_time = datetime.now()
# 处理来源URL
source_url = url or '未知来源'
if len(source_url) > 1024:
source_url = source_url[:1021] + '...'
return {
'文章标题': title,
'文章链接': link,
'文章摘要': description,
'发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
'来源URL': source_url
}
def display_feed_info(feed, last_update=None, url=None):
"""处理并显示RSS源信息"""
if not feed:
print("无法显示信息:feed 为 None")
return None return None
print("=" * 80) def save_last_update_time(self, last_update: datetime) -> None:
print(f"处理 RSS 源: {url}") """保存本次更新时间"""
entries = feed.entries try:
data_list = [] cache_dir = os.path.join(os.getcwd(), 'output')
new_last_update = last_update os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, 'last_update.pkl')
for i, entry in enumerate(entries, 1): with open(cache_file, 'wb') as f:
entry_data = process_feed_entry(entry, url) pickle.dump(last_update, f)
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S') self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
except Exception as e:
self.logger.error(f"保存更新时间失败: {str(e)}", exc_info=True)
if last_update and entry_time <= last_update: def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional[feedparser.FeedParserDict]:
continue """获取并解析单个RSS源"""
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
if new_last_update is None or entry_time > new_last_update: for attempt in range(3):
new_last_update = entry_time try:
response = requests.get(url, headers=headers, timeout=timeout)
response.raise_for_status()
response.encoding = response.apparent_encoding
feed = feedparser.parse(response.text)
print(f"\n--- 条目 {i} ---") if feed.bozo:
print(f"标题: {entry_data['文章标题']}") self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")
print(f"链接: {entry_data['文章链接']}")
print(f"摘要: {entry_data['文章摘要'][:100]}...")
print(f"时间: {entry_data['发布时间']}")
data_list.append(entry_data) self.logger.debug(f"成功获取 {url} 的RSS数据")
return feed
if data_list: except requests.RequestException as e:
df = pd.DataFrame(data_list) self.logger.warning(f"{attempt + 1} 次获取 {url} 失败: {str(e)}")
write_to_database(df) if attempt < 2:
time.sleep(3 * (attempt + 1)) # 指数退避重试
continue
return new_last_update self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
return None
def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> Dict[str, feedparser.FeedParserDict]:
"""并发获取多个RSS源"""
feeds = {}
with ThreadPoolExecutor(max_workers=3) as executor:
future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
def write_to_database(df): for future in as_completed(future_to_url):
"""将数据写入数据库""" url = future_to_url[future]
if df.empty:
print("没有新数据需要写入")
return
print("\n准备写入数据库的数据样例:")
print(df.iloc[0].to_dict())
try:
conn = pymysql.connect(**local_DB_Config)
with conn.cursor() as cursor:
sql = f"""INSERT IGNORE INTO `{table_name}`
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)
VALUES (%s, %s, %s, %s, %s)"""
success_count = 0
for _, row in df.iterrows():
try: try:
cursor.execute(sql, ( feed = future.result()
row['文章标题'], if feed:
row['文章链接'], feeds[url] = feed
row['文章摘要'],
row['发布时间'],
row['来源URL']
))
success_count += cursor.rowcount
except Exception as e: except Exception as e:
print(f"插入记录时出错: {e}") self.logger.error(f"处理 {url} 时发生异常: {str(e)}", exc_info=True)
print(f"问题数据: {row.to_dict()}")
continue
conn.commit() self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
print(f"成功写入 {success_count}/{len(df)} 条记录") return feeds
except Exception as e: def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
print("数据库操作失败:", e) """处理单个RSS条目,转换为数据库兼容格式"""
finally: # 处理标题
if 'conn' in locals(): title = entry.get('title', '无标题') or '无标题'
conn.close() if len(title) > 255:
title = title[:252] + '...'
# 处理链接
link = entry.get('link', '无链接') or '无链接'
if len(link) > 1024:
link = link[:1021] + '...'
def main(): # 处理摘要
"""主函数""" summary = entry.get('summary', '无内容摘要')
if not verify_database(): content_list = entry.get('content', [])
print("数据库验证失败,程序终止") content = content_list[0].value if (content_list and hasattr(content_list[0], 'value')) else ''
return description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
rss_urls = [ # 处理发布时间
"https://www.chinanews.com.cn/rss/finance.xml", published_parsed = entry.get('published_parsed') or entry.get('updated_parsed')
"https://www.chinanews.com.cn/rss/world.xml", if published_parsed:
"https://www.chinanews.com.cn/rss/china.xml", entry_time = datetime(*published_parsed[:6])
"https://www.chinanews.com.cn/rss/scroll-news.xml" else:
] pub_str = entry.get('published', entry.get('updated', ''))
try:
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
except:
entry_time = datetime.now()
last_update = load_last_update_time() # 处理来源URL
if last_update: source_url = url or '未知来源'
print(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}") if len(source_url) > 1024:
source_url = source_url[:1021] + '...'
print("\n开始获取RSS源数据...") # 当前时间(创建/更新时间)
start_time = time.time() current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
feeds = fetch_all_rss(rss_urls)
print(f"获取完成,耗时: {time.time() - start_time:.2f}")
new_last_update = None return {
for url, feed in feeds.items(): '文章标题': title,
current_last_update = display_feed_info(feed, last_update, url) '文章链接': link,
if current_last_update and (new_last_update is None or current_last_update > new_last_update): '文章摘要': description,
new_last_update = current_last_update '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
'来源URL': source_url,
'创建时间': current_time,
'更新时间': current_time
}
if new_last_update: def display_feed_info(self, feed: feedparser.FeedParserDict, last_update: Optional[datetime] = None,
save_last_update_time(new_last_update) url: Optional[str] = None) -> Optional[datetime]:
print(f"\n本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}") """处理RSS源信息并写入数据库"""
else: if not feed:
print("\n没有获取到新的内容") self.logger.warning("无法处理空的RSS源数据")
return None
self.logger.info(f"开始处理 RSS 源: {url}")
entries = feed.entries
data_list = []
new_last_update = last_update
for i, entry in enumerate(entries, 1):
entry_data = self.process_feed_entry(entry, url)
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
# 过滤旧数据
if last_update and entry_time <= last_update:
continue
# 更新最新时间戳
if new_last_update is None or entry_time > new_last_update:
new_last_update = entry_time
self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
data_list.append(entry_data)
# 写入数据库
if data_list:
df = pd.DataFrame(data_list)
self.write_to_database(df)
return new_last_update
# news_api.py 中的 write_to_database 方法可以保持简洁
def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
if df.empty:
self.logger.info("没有新数据需要写入数据库")
return self._format_result(True, "没有新数据需要写入")
try:
inserted_rows = self.db_agent.insert_from_df(
table_name=table_name,
df=df,
chunk_size=500,
replace=False
)
self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
return self._format_result(
True,
f"成功写入 {inserted_rows}/{len(df)} 条记录",
{"success_count": inserted_rows, "total": len(df)}
)
except Exception as e:
self.logger.error(f"数据库写入失败: {str(e)}", exc_info=True)
return self._format_result(False, f"数据库操作失败: {str(e)}")
@classmethod
def main(cls):
"""主函数入口"""
try:
client = cls()
# 验证数据库
if not client.verify_database():
client.logger.error("数据库验证失败,程序终止")
return
# RSS源列表
rss_urls = [
"https://www.chinanews.com.cn/rss/finance.xml",
"https://www.chinanews.com.cn/rss/world.xml",
"https://www.chinanews.com.cn/rss/china.xml",
"https://www.chinanews.com.cn/rss/scroll-news.xml"
]
# 加载上次更新时间
last_update = client.load_last_update_time()
if last_update:
client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
# 获取RSS数据
client.logger.info("开始获取RSS源数据...")
start_time = time.time()
feeds = client.fetch_all_rss(rss_urls)
client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}")
# 处理并写入数据
new_last_update = None
for url, feed in feeds.items():
current_last_update = client.display_feed_info(feed, last_update, url)
if current_last_update and (new_last_update is None or current_last_update > new_last_update):
new_last_update = current_last_update
# 保存最新更新时间
if new_last_update:
client.save_last_update_time(new_last_update)
client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
else:
client.logger.info("没有获取到新内容")
except Exception as e:
logger.error(f"程序运行出错: {str(e)}", exc_info=True)
if __name__ == "__main__": if __name__ == "__main__":
main() NewsAPIClient.main()
+123849
View File
File diff suppressed because it is too large Load Diff
+63455
View File
File diff suppressed because it is too large Load Diff
+104 -114
View File
@@ -3,18 +3,20 @@ import pandas as pd
from datetime import datetime from datetime import datetime
import time import time
import pymysql import pymysql
from utils.mysql_agent import MySQLAgent
import platform import platform
from concurrent.futures import ThreadPoolExecutor
from utils.mysql_agent import MySQLAgent
class TestMySQLAgent(unittest.TestCase): class TestMySQLAgent(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""初始化测试环境和测试表""" """初始化测试环境和测试表"""
# 创建唯一的测试数据库 # 创建唯一的测试数据库和表名(避免冲突)
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S") cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S") cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
# 基础配置 # 基础配置(根据实际环境修改)
cls.base_config = { cls.base_config = {
'host': 'localhost', 'host': 'localhost',
'port': 3306, 'port': 3306,
@@ -32,21 +34,19 @@ class TestMySQLAgent(unittest.TestCase):
'database': cls.test_db_name 'database': cls.test_db_name
}) })
# 创建测试表 # 创建测试表并插入初始数据
test_data = pd.DataFrame({ test_data = pd.DataFrame({
'id': [1, 2, 3], 'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'], 'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8], 'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']) 'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
}) })
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id') cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data) cls.db.insert_from_df(cls.test_table, test_data)
@classmethod @classmethod
def _create_test_database(cls): def _create_test_database(cls):
"""创建测试数据库""" """创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(
host=cls.base_config['host'], host=cls.base_config['host'],
port=cls.base_config['port'], port=cls.base_config['port'],
@@ -54,7 +54,6 @@ class TestMySQLAgent(unittest.TestCase):
password=cls.base_config['password'], password=cls.base_config['password'],
charset='utf8mb4' charset='utf8mb4'
) )
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -66,21 +65,14 @@ class TestMySQLAgent(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""清理测试数据库""" """清理测试环境"""
if hasattr(cls, 'db') and cls.db: if hasattr(cls, 'db') and cls.db:
# 删除测试表 # 删除测试表
if cls.db.table_exists(cls.test_table): if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table) cls.db.drop_table(cls.test_table)
# 删除测试数据库 # 删除测试数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -88,22 +80,24 @@ class TestMySQLAgent(unittest.TestCase):
finally: finally:
temp_conn.close() temp_conn.close()
def test_01_connection(self): def test_connection(self):
"""测试数据库连接""" """测试数据库连接"""
version = self.db.query_to_df("SELECT VERSION() as version") version_df = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version) self.assertIsNotNone(version_df)
print(f"\nDatabase version: {version['version'].iloc[0]}") self.assertEqual(len(version_df), 1)
print(f"Running on: {platform.system()} {platform.release()}") print(f"数据库版本: {version_df['version'].iloc[0]}")
def test_02_query_to_df(self): def test_query_to_df(self):
"""测试查询返回DataFrame""" """测试查询返回DataFrame"""
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,)) df = self.db.query_to_df(
self.assertEqual(len(df), 2) f"SELECT * FROM {self.test_table} WHERE id > %s",
params=(1,)
)
self.assertIsInstance(df, pd.DataFrame) self.assertIsInstance(df, pd.DataFrame)
print("\nQuery result sample:") self.assertEqual(len(df), 2) # id>1 的数据有2条
print(df.head()) self.assertIn('name', df.columns)
def test_03_insert_from_df(self): def test_insert_from_df(self):
"""测试DataFrame插入""" """测试DataFrame插入"""
new_data = pd.DataFrame({ new_data = pd.DataFrame({
'id': [4, 5], 'id': [4, 5],
@@ -112,55 +106,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05']) 'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
}) })
rows = self.db.insert_from_df(self.test_table, new_data) inserted_rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2) self.assertEqual(inserted_rows, 2)
# 验证数据 # 验证插入结果
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4") result_df = self.db.query_to_df(
self.assertEqual(len(df), 2) f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5']) )
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
def test_04_update_from_df(self): def test_update_from_df(self):
"""测试DataFrame更新""" """测试DataFrame更新"""
update_data = pd.DataFrame({ update_data = pd.DataFrame({
'id': [1, 2], 'id': [1, 2],
'name': ['Updated1', 'Updated2'] 'name': ['Updated1', 'Updated2']
}) })
rows = self.db.update_from_df(self.test_table, update_data, 'id') updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2) self.assertGreaterEqual(updated_rows, 2)
# 验证更新 # 验证更新结果
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)") result_df = self.db.query_to_df(
self.assertIn('Updated1', df['name'].values) f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
self.assertIn('Updated2', df['name'].values) )
self.assertIn('Updated1', result_df['name'].values)
self.assertIn('Updated2', result_df['name'].values)
def test_05_transaction(self): def test_transaction(self):
"""测试事务处理""" """测试事务处理"""
conn = self.db.begin_transaction() conn = self.db.begin_transaction()
try: try:
# 执行多个操作 # 执行事务内操作
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1") cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2") cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn) self.db.commit_transaction(conn)
except Exception: except Exception:
self.db.rollback_transaction(conn) self.db.rollback_transaction(conn)
raise raise
# 验证提交后的修改 # 验证事务提交结果
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)") result_df = self.db.query_to_df(
self.assertIn(99.9, df['value'].values) f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
self.assertIn(88.8, df['value'].values) )
self.assertIn(99.9, result_df['value'].values)
self.assertIn(88.8, result_df['value'].values)
def test_06_large_data(self): def test_large_data_insert(self):
"""测试大数据量操作""" """测试大数据量插入"""
# 生成测试数据 # 生成1000行测试数据
large_data = pd.DataFrame({ large_data = pd.DataFrame({
'id': range(1000, 2000), 'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)], 'name': [f"Item_{i}" for i in range(1000, 2000)],
@@ -168,59 +162,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.date_range('2023-01-01', periods=1000) 'created_at': pd.date_range('2023-01-01', periods=1000)
}) })
# Windows平台使用更小的批次 # 根据平台自动调整批次大小
chunk_size = 100 if platform.system() == 'Windows' else 500 chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time() start_time = time.time()
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size) inserted_rows = self.db.insert_from_df(
self.test_table,
large_data,
chunk_size=chunk_size
)
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertEqual(rows, 1000) self.assertEqual(inserted_rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})") print(f"插入1000行数据耗时: {elapsed:.2f} (批次大小: {chunk_size})")
# 验证数据 def test_concurrent_access(self):
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
"""测试并发访问""" """测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def worker(i): def query_worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,)) """并发查询工作函数"""
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id = %s",
params=(i % 3 + 1,) # 查询id=1,2,3循环
)
return len(df) return len(df)
# 20个线程执行100次查询
start_time = time.time() start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor: with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(worker, range(100))) results = list(executor.map(query_worker, range(100)))
elapsed = time.time() - start_time elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s") self.assertEqual(sum(results), 100) # 每次查询应返回1行
print(f"100次并发查询耗时: {elapsed:.2f}")
class TestPlatformSpecific(unittest.TestCase): class TestPlatformSpecific(unittest.TestCase):
"""平台特定功能测试"""
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
"""创建临时测试数据库""" cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.base_config = { cls.base_config = {
'host': 'localhost', 'host': 'localhost',
'port': 3306, 'port': 3306,
'user': 'root', 'user': 'root',
'password': '123123', 'password': '123123'
'max_connections': 10
} }
# 创建数据库 # 创建测试数据库
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}") cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -230,15 +220,8 @@ class TestPlatformSpecific(unittest.TestCase):
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
"""删除临时测试数据库""" """清理测试数据库"""
temp_conn = pymysql.connect( temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try: try:
with temp_conn.cursor() as cursor: with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}") cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -249,42 +232,49 @@ class TestPlatformSpecific(unittest.TestCase):
def test_windows_timeout(self): def test_windows_timeout(self):
"""测试Windows平台超时处理""" """测试Windows平台超时处理"""
if platform.system() != 'Windows': if platform.system() != 'Windows':
self.skipTest("Only runs on Windows") self.skipTest("仅在Windows平台运行")
config = { config = {
**self.base_config, **self.base_config,
'database': self.test_db_name, 'database': self.test_db_name,
'connect_timeout': 1, 'connect_timeout': 1,
'read_timeout': 1 'read_timeout': 1,
'write_timeout': 1
} }
db = MySQLAgent(config) db = MySQLAgent(config)
# 测试短超时查询 # 执行会超时查询(SLEEP(2)超过1秒超时设置)
start_time = time.time() with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
try: try:
db.query_to_df("SELECT SLEEP(2)") db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out") except Exception as e:
except Exception as e: # 提取底层异常信息(可能被包装)
self.assertIn("timed out", str(e)) while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
print(f"\nWindows timeout test: {str(e)}") e = e.args[0]
raise e
def test_macos_ssl(self): error_msg = str(ctx.exception)
"""测试macOS SSL连接""" self.assertTrue(
"timed out" in error_msg or
"timeout" in error_msg or
"HY000" in error_msg, # MySQL超时错误码
f"未检测到超时异常,实际异常: {error_msg}"
)
def test_macos_ssl_connection(self):
"""测试macOS平台SSL连接"""
if platform.system() != 'Darwin': if platform.system() != 'Darwin':
self.skipTest("Only runs on macOS") self.skipTest("仅在macOS平台运行")
config = { config = {
**self.base_config, **self.base_config,
'database': self.test_db_name, 'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
} }
db = MySQLAgent(config) db = MySQLAgent(config)
version = db.query_to_df("SELECT VERSION() as version") version_df = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version) self.assertIsNotNone(version_df)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main(verbosity=2)
+420 -387
View File
File diff suppressed because it is too large Load Diff