280 lines
9.5 KiB
Python
280 lines
9.5 KiB
Python
import os
import platform
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import pandas as pd
import pymysql

from utils.mysql_agent import MySQLAgent


class TestMySQLAgent(unittest.TestCase):
    """Integration tests for MySQLAgent against a live MySQL server.

    A throw-away database is created once for the class. The test table is
    rebuilt with the same three baseline rows (ids 1-3) before every test, so
    no test depends on the side effects or execution order of another test.
    """

    @classmethod
    def setUpClass(cls):
        """Create the test database and connect a MySQLAgent to it."""
        # Unique database/table names to avoid collisions. A single timestamp
        # keeps both suffixes identical even if the clock ticks in between.
        suffix = datetime.now().strftime('%Y%m%d%H%M%S')
        cls.test_db_name = f"test_db_{suffix}"
        cls.test_table = f"test_table_{suffix}"

        # Connection settings. Defaults target a local development server and
        # can be overridden via MYSQL_TEST_* environment variables (e.g. in CI).
        cls.base_config = {
            'host': os.environ.get('MYSQL_TEST_HOST', 'localhost'),
            'port': int(os.environ.get('MYSQL_TEST_PORT', '3306')),
            'user': os.environ.get('MYSQL_TEST_USER', 'root'),
            'password': os.environ.get('MYSQL_TEST_PASSWORD', '123123'),
            # MySQLAgent option; pymysql.connect() does not accept it.
            'max_connections': 10,
        }

        cls._create_test_database()
        try:
            cls.db = MySQLAgent({
                **cls.base_config,
                'database': cls.test_db_name,
            })
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # drop the freshly created database here instead of leaking it.
            cls._drop_test_database()
            raise

    def setUp(self):
        """Rebuild the test table with the baseline rows before each test."""
        self._reset_test_table()

    @classmethod
    def tearDownClass(cls):
        """Drop the test table and the test database."""
        db = getattr(cls, 'db', None)
        try:
            if db is not None and db.table_exists(cls.test_table):
                db.drop_table(cls.test_table)
        finally:
            # Always drop the database, even if dropping the table failed.
            cls._drop_test_database()

    @classmethod
    def _admin_connect(cls):
        """Open a raw pymysql connection with no default database, for DDL.

        Only the keys pymysql.connect() understands are forwarded:
        ``base_config`` also holds ``max_connections``, which
        pymysql.connect() rejects with TypeError.
        """
        return pymysql.connect(
            host=cls.base_config['host'],
            port=cls.base_config['port'],
            user=cls.base_config['user'],
            password=cls.base_config['password'],
            charset='utf8mb4',
        )

    @classmethod
    def _run_admin_statement(cls, sql):
        """Execute one statement on a short-lived admin connection and commit."""
        conn = cls._admin_connect()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)
            conn.commit()
        finally:
            conn.close()

    @classmethod
    def _create_test_database(cls):
        """Create the per-run test database.

        Server-wide settings such as ``max_connections`` are deliberately left
        alone: a test must not change configuration seen by other clients.
        """
        cls._run_admin_statement(
            f"CREATE DATABASE IF NOT EXISTS `{cls.test_db_name}`"
        )

    @classmethod
    def _drop_test_database(cls):
        """Drop the per-run test database if it exists."""
        cls._run_admin_statement(f"DROP DATABASE IF EXISTS `{cls.test_db_name}`")

    @staticmethod
    def _initial_data():
        """Return a fresh DataFrame holding the three baseline rows (ids 1-3)."""
        return pd.DataFrame({
            'id': [1, 2, 3],
            'name': ['Test1', 'Test2', 'Test3'],
            'value': [10.5, 20.3, 30.8],
            'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']),
        })

    @classmethod
    def _reset_test_table(cls):
        """Drop (if present) and recreate the test table with the baseline rows."""
        if cls.db.table_exists(cls.test_table):
            cls.db.drop_table(cls.test_table)
        data = cls._initial_data()
        cls.db.create_table_from_df(cls.test_table, data, primary_key='id')
        cls.db.insert_from_df(cls.test_table, data)

    def test_connection(self):
        """A trivial query round-trips through the agent."""
        version_df = self.db.query_to_df("SELECT VERSION() as version")
        self.assertIsNotNone(version_df)
        self.assertEqual(len(version_df), 1)
        print(f"数据库版本: {version_df['version'].iloc[0]}")

    def test_query_to_df(self):
        """A parameterized query returns a DataFrame with the matching rows."""
        df = self.db.query_to_df(
            f"SELECT * FROM {self.test_table} WHERE id > %s",
            params=(1,),
        )
        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(len(df), 2)  # only baseline ids 2 and 3 satisfy id > 1
        self.assertIn('name', df.columns)

    def test_insert_from_df(self):
        """insert_from_df reports the inserted row count and persists the rows."""
        new_data = pd.DataFrame({
            'id': [4, 5],
            'name': ['Test4', 'Test5'],
            'value': [40.1, 50.2],
            'created_at': pd.to_datetime(['2023-01-04', '2023-01-05']),
        })

        inserted_rows = self.db.insert_from_df(self.test_table, new_data)
        self.assertEqual(inserted_rows, 2)

        # ORDER BY makes the expected list order deterministic.
        result_df = self.db.query_to_df(
            f"SELECT name FROM {self.test_table} WHERE id IN (4,5) ORDER BY id"
        )
        self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])

    def test_update_from_df(self):
        """update_from_df rewrites the columns of the rows matched on the key."""
        update_data = pd.DataFrame({
            'id': [1, 2],
            'name': ['Updated1', 'Updated2'],
        })

        updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
        self.assertGreaterEqual(updated_rows, 2)

        result_df = self.db.query_to_df(
            f"SELECT name FROM {self.test_table} WHERE id IN (1,2) ORDER BY id"
        )
        self.assertEqual(result_df['name'].tolist(), ['Updated1', 'Updated2'])

    def test_transaction(self):
        """Statements committed inside an explicit transaction are persisted."""
        conn = self.db.begin_transaction()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
                cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
            finally:
                cursor.close()
            self.db.commit_transaction(conn)
        except Exception:
            self.db.rollback_transaction(conn)
            raise

        result_df = self.db.query_to_df(
            f"SELECT value FROM {self.test_table} WHERE id IN (1,2) ORDER BY id"
        )
        values = [float(v) for v in result_df['value']]
        self.assertEqual(len(values), 2)
        # Approximate comparison: the column type is chosen by
        # create_table_from_df and may be single-precision FLOAT, which cannot
        # store 99.9 / 88.8 exactly.
        self.assertAlmostEqual(values[0], 99.9, places=4)
        self.assertAlmostEqual(values[1], 88.8, places=4)

    def test_large_data_insert(self):
        """Insert 1000 rows in chunks and verify they all land in the table."""
        large_data = pd.DataFrame({
            'id': range(1000, 2000),
            'name': [f"Item_{i}" for i in range(1000, 2000)],
            'value': [i * 0.1 for i in range(1000, 2000)],
            'created_at': pd.date_range('2023-01-01', periods=1000),
        })

        # Batch size is chosen per platform (smaller batches on Windows).
        chunk_size = 100 if platform.system() == 'Windows' else 500
        start_time = time.perf_counter()
        inserted_rows = self.db.insert_from_df(
            self.test_table,
            large_data,
            chunk_size=chunk_size,
        )
        elapsed = time.perf_counter() - start_time

        self.assertEqual(inserted_rows, 1000)
        count_df = self.db.query_to_df(
            f"SELECT COUNT(*) AS cnt FROM {self.test_table} WHERE id BETWEEN 1000 AND 1999"
        )
        self.assertEqual(int(count_df['cnt'].iloc[0]), 1000)
        print(f"插入1000行数据耗时: {elapsed:.2f}秒 (批次大小: {chunk_size})")

    def test_concurrent_access(self):
        """Many threads can query through the shared agent at the same time."""

        def query_worker(i):
            """Fetch a single baseline row, cycling through ids 1, 2, 3."""
            df = self.db.query_to_df(
                f"SELECT * FROM {self.test_table} WHERE id = %s",
                params=(i % 3 + 1,),
            )
            return len(df)

        # 100 queries spread across 20 worker threads.
        start_time = time.perf_counter()
        with ThreadPoolExecutor(max_workers=20) as executor:
            results = list(executor.map(query_worker, range(100)))
        elapsed = time.perf_counter() - start_time

        self.assertEqual(sum(results), 100)  # each query must return exactly one row
        print(f"100次并发查询耗时: {elapsed:.2f}秒")


class TestPlatformSpecific(unittest.TestCase):
    """Platform-specific connection tests (Windows timeouts, macOS SSL)."""

    @classmethod
    def setUpClass(cls):
        """Create a dedicated, uniquely named test database."""
        cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        # Defaults target a local development server; override them via
        # MYSQL_TEST_* environment variables (e.g. in CI).
        cls.base_config = {
            'host': os.environ.get('MYSQL_TEST_HOST', 'localhost'),
            'port': int(os.environ.get('MYSQL_TEST_PORT', '3306')),
            'user': os.environ.get('MYSQL_TEST_USER', 'root'),
            'password': os.environ.get('MYSQL_TEST_PASSWORD', '123123'),
        }
        cls._run_admin_statement(
            f"CREATE DATABASE IF NOT EXISTS `{cls.test_db_name}`"
        )

    @classmethod
    def tearDownClass(cls):
        """Drop the test database."""
        cls._run_admin_statement(f"DROP DATABASE IF EXISTS `{cls.test_db_name}`")

    @classmethod
    def _run_admin_statement(cls, sql):
        """Execute one statement on a short-lived connection with no default database."""
        conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql)
            conn.commit()
        finally:
            conn.close()

    @staticmethod
    def _unwrap_exception(exc, expected):
        """Return the first exception of type ``expected`` wrapped inside ``exc``.

        Follows an exception passed as the first positional argument and
        explicit ``raise ... from`` chains (``__cause__``). It stops at the
        first match, so a matching driver error is not unwrapped further into
        its own low-level cause. If nothing matches, the innermost exception
        reached is returned.
        """
        seen = set()
        while not isinstance(exc, expected) and id(exc) not in seen:
            seen.add(id(exc))
            if exc.args and isinstance(exc.args[0], Exception):
                inner = exc.args[0]
            else:
                inner = exc.__cause__
            if inner is None:
                break
            exc = inner
        return exc

    def test_windows_timeout(self):
        """On Windows, a query outlasting read_timeout raises a timeout error."""
        if platform.system() != 'Windows':
            self.skipTest("仅在Windows平台运行")

        config = {
            **self.base_config,
            'database': self.test_db_name,
            'connect_timeout': 1,
            'read_timeout': 1,
            'write_timeout': 1,
        }
        db = MySQLAgent(config)

        # SLEEP(2) exceeds the 1-second read timeout configured above.
        expected = (pymysql.OperationalError, TimeoutError)
        with self.assertRaises(expected) as ctx:
            try:
                db.query_to_df("SELECT SLEEP(2)")
            except Exception as exc:
                # The agent may wrap the driver error; surface the underlying one.
                raise self._unwrap_exception(exc, expected)

        error_msg = str(ctx.exception)
        lowered = error_msg.lower()  # match regardless of message casing
        self.assertTrue(
            "timed out" in lowered
            or "timeout" in lowered
            or "hy000" in lowered,  # SQLSTATE accepted as a timeout indicator
            f"未检测到超时异常,实际异常: {error_msg}"
        )

    def test_macos_ssl_connection(self):
        """On macOS, connect with a CA bundle and check that TLS is negotiated."""
        if platform.system() != 'Darwin':
            self.skipTest("仅在macOS平台运行")

        # Default CA bundle path; override via MYSQL_TEST_SSL_CA when it lives
        # elsewhere. A missing bundle is an environment issue, not a failure.
        ca_path = os.environ.get('MYSQL_TEST_SSL_CA', '/usr/local/etc/openssl/cert.pem')
        if not os.path.isfile(ca_path):
            self.skipTest(f"未找到CA证书文件: {ca_path}")

        config = {
            **self.base_config,
            'database': self.test_db_name,
            'ssl': {'ca': ca_path},
        }
        db = MySQLAgent(config)
        version_df = db.query_to_df("SELECT VERSION() as version")
        self.assertIsNotNone(version_df)

        # A successful query alone does not prove encryption: a driver can
        # proceed in plaintext when the server does not offer TLS. A non-empty
        # Ssl_cipher status value means this session is encrypted.
        cipher_df = db.query_to_df("SHOW SESSION STATUS LIKE 'Ssl_cipher'")
        self.assertEqual(len(cipher_df), 1)
        cipher = cipher_df.iloc[0, -1]
        self.assertTrue(isinstance(cipher, str) and cipher.strip(), "连接未启用SSL")


if __name__ == '__main__':
    # Run every TestCase in this module, printing each test name and its outcome.
    unittest.main(verbosity=2)