import os
import platform
import time
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import pandas as pd
import pymysql

from utils.mysql_agent import MySQLAgent


def _mysql_server_config():
    """Return the MySQL server settings used by these tests.

    Every value can be overridden through an environment variable
    (MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASSWORD). The defaults are
    the values this suite previously hard-coded, so existing setups keep
    working unchanged.
    """
    return {
        'host': os.environ.get('MYSQL_HOST', 'localhost'),
        'port': int(os.environ.get('MYSQL_PORT', '3306')),
        'user': os.environ.get('MYSQL_USER', 'root'),
        'password': os.environ.get('MYSQL_PASSWORD', '123123'),
    }


def _unique_name(prefix):
    """Build a database/table name that is unique per test run.

    The timestamp keeps names readable. The random suffix prevents
    collisions when two runs start within the same second (e.g. parallel
    CI jobs).
    """
    stamp = datetime.now().strftime('%Y%m%d%H%M%S')
    return f"{prefix}_{stamp}_{uuid.uuid4().hex[:8]}"


def _server_connection(config):
    """Open a raw pymysql connection to the server, without a default database.

    Only the keys pymysql.connect() accepts are forwarded. Passing through
    agent-only settings such as 'max_connections' would raise TypeError.
    """
    return pymysql.connect(
        host=config['host'],
        port=config['port'],
        user=config['user'],
        password=config['password'],
        charset='utf8mb4',
    )


def _create_database(config, db_name):
    """Create ``db_name`` if it does not exist yet.

    Server-wide settings (e.g. GLOBAL max_connections) are intentionally left
    alone. Changing them needs admin privileges and affects every other
    client of the server.
    """
    conn = _server_connection(config)
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{db_name}`")
        conn.commit()
    finally:
        conn.close()


def _drop_database(config, db_name):
    """Drop ``db_name`` if it exists."""
    conn = _server_connection(config)
    try:
        with conn.cursor() as cursor:
            cursor.execute(f"DROP DATABASE IF EXISTS `{db_name}`")
        conn.commit()
    finally:
        conn.close()


class TestMySQLAgent(unittest.TestCase):
    """Integration tests for MySQLAgent against a live MySQL server."""

    @classmethod
    def setUpClass(cls):
        """Create a throw-away test database and an agent connected to it."""
        # Unique names so concurrent runs don't collide.
        cls.test_db_name = _unique_name('test_db')
        cls.test_table = _unique_name('test_table')

        # Server settings (override via MYSQL_* env vars) plus the agent's pool size.
        cls.base_config = {**_mysql_server_config(), 'max_connections': 10}

        cls._create_test_database()
        try:
            cls.db = MySQLAgent({
                **cls.base_config,
                'database': cls.test_db_name,
            })
        except Exception:
            # tearDownClass is not called when setUpClass fails, so clean up
            # the database here instead of leaking it.
            _drop_database(cls.base_config, cls.test_db_name)
            raise

    @classmethod
    def _create_test_database(cls):
        """Create the per-run test database (no-op if it already exists)."""
        _create_database(cls.base_config, cls.test_db_name)

    @staticmethod
    def _baseline_df():
        """Return the three rows (ids 1-3) every test starts from."""
        return pd.DataFrame({
            'id': [1, 2, 3],
            'name': ['Test1', 'Test2', 'Test3'],
            'value': [10.5, 20.3, 30.8],
            'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']),
        })

    def setUp(self):
        """Recreate the test table with only the baseline rows before each test.

        Several tests insert or modify rows. Without this reset, results
        depend on the alphabetical order unittest runs tests in. For example,
        rows 4, 5 and 1000-1999 inserted by earlier tests broke
        test_query_to_df's row count.
        """
        if self.db.table_exists(self.test_table):
            self.db.drop_table(self.test_table)
        baseline = self._baseline_df()
        self.db.create_table_from_df(self.test_table, baseline, primary_key='id')
        self.db.insert_from_df(self.test_table, baseline)

    @classmethod
    def tearDownClass(cls):
        """Drop the test table and the test database."""
        try:
            db = getattr(cls, 'db', None)
            if db is not None and db.table_exists(cls.test_table):
                db.drop_table(cls.test_table)
        finally:
            # Always drop the database, even if dropping the table failed.
            _drop_database(cls.base_config, cls.test_db_name)

    def test_connection(self):
        """The agent can run a trivial query against the server."""
        version_df = self.db.query_to_df("SELECT VERSION() as version")
        self.assertIsNotNone(version_df)
        self.assertEqual(len(version_df), 1)
        print(f"数据库版本: {version_df['version'].iloc[0]}")

    def test_query_to_df(self):
        """A parameterized query returns a DataFrame with the matching rows."""
        df = self.db.query_to_df(
            f"SELECT * FROM {self.test_table} WHERE id > %s",
            params=(1,),
        )
        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(len(df), 2)  # baseline rows with id > 1: ids 2 and 3
        self.assertIn('name', df.columns)

    def test_insert_from_df(self):
        """insert_from_df inserts every row and reports the inserted count."""
        new_data = pd.DataFrame({
            'id': [4, 5],
            'name': ['Test4', 'Test5'],
            'value': [40.1, 50.2],
            'created_at': pd.to_datetime(['2023-01-04', '2023-01-05']),
        })
        inserted_rows = self.db.insert_from_df(self.test_table, new_data)
        self.assertEqual(inserted_rows, 2)

        # ORDER BY makes the comparison independent of the server's row order.
        result_df = self.db.query_to_df(
            f"SELECT name FROM {self.test_table} WHERE id IN (4,5) ORDER BY id"
        )
        self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])

    def test_update_from_df(self):
        """update_from_df updates rows matched on the key column."""
        update_data = pd.DataFrame({
            'id': [1, 2],
            'name': ['Updated1', 'Updated2'],
        })
        updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
        self.assertGreaterEqual(updated_rows, 2)

        result_df = self.db.query_to_df(
            f"SELECT name FROM {self.test_table} WHERE id IN (1,2) ORDER BY id"
        )
        self.assertEqual(result_df['name'].tolist(), ['Updated1', 'Updated2'])

    def test_transaction(self):
        """Statements run inside a committed transaction are persisted."""
        conn = self.db.begin_transaction()
        try:
            cursor = conn.cursor()
            try:
                cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
                cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
            finally:
                cursor.close()
            self.db.commit_transaction(conn)
        except Exception:
            self.db.rollback_transaction(conn)
            raise

        result_df = self.db.query_to_df(
            f"SELECT value FROM {self.test_table} WHERE id IN (1,2) ORDER BY id"
        )
        # Compare as floats with a tolerance. The column type is chosen by
        # create_table_from_df and may come back as float or Decimal.
        values = [float(v) for v in result_df['value']]
        self.assertEqual(len(values), 2)
        self.assertAlmostEqual(values[0], 99.9, places=4)
        self.assertAlmostEqual(values[1], 88.8, places=4)

    def test_large_data_insert(self):
        """Inserting 1000 rows in chunks inserts all of them."""
        large_data = pd.DataFrame({
            'id': range(1000, 2000),
            'name': [f"Item_{i}" for i in range(1000, 2000)],
            'value': [i * 0.1 for i in range(1000, 2000)],
            'created_at': pd.date_range('2023-01-01', periods=1000),
        })

        # Smaller chunks on Windows, larger elsewhere.
        chunk_size = 100 if platform.system() == 'Windows' else 500

        start_time = time.perf_counter()
        inserted_rows = self.db.insert_from_df(
            self.test_table,
            large_data,
            chunk_size=chunk_size,
        )
        elapsed = time.perf_counter() - start_time

        self.assertEqual(inserted_rows, 1000)
        print(f"插入1000行数据耗时: {elapsed:.2f}秒 (批次大小: {chunk_size})")

    def test_concurrent_access(self):
        """Many threads can query through the agent at the same time."""

        def query_worker(i):
            """Query one baseline row (ids 1, 2, 3 in rotation); return its row count."""
            df = self.db.query_to_df(
                f"SELECT * FROM {self.test_table} WHERE id = %s",
                params=(i % 3 + 1,),
            )
            return len(df)

        # 100 queries spread over 20 worker threads.
        start_time = time.perf_counter()
        with ThreadPoolExecutor(max_workers=20) as executor:
            results = list(executor.map(query_worker, range(100)))
        elapsed = time.perf_counter() - start_time

        self.assertEqual(sum(results), 100)  # each query must return exactly 1 row
        print(f"100次并发查询耗时: {elapsed:.2f}秒")


class TestPlatformSpecific(unittest.TestCase):
    """Platform-specific connection tests (each skips on other platforms)."""

    @classmethod
    def setUpClass(cls):
        """Create a throw-away database for the platform tests."""
        cls.test_db_name = _unique_name('test_platform_db')
        cls.base_config = _mysql_server_config()
        _create_database(cls.base_config, cls.test_db_name)

    @classmethod
    def tearDownClass(cls):
        """Drop the platform test database."""
        _drop_database(cls.base_config, cls.test_db_name)

    def test_windows_timeout(self):
        """A query longer than read_timeout fails with a timeout error (Windows only)."""
        if platform.system() != 'Windows':
            self.skipTest("仅在Windows平台运行")

        config = {
            **self.base_config,
            'database': self.test_db_name,
            'connect_timeout': 1,
            'read_timeout': 1,
            'write_timeout': 1,
        }
        db = MySQLAgent(config)

        # SLEEP(2) exceeds the 1-second read timeout.
        with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
            try:
                db.query_to_df("SELECT SLEEP(2)")
            except Exception as e:
                # The agent may wrap the driver error; unwrap nested
                # exceptions stored as args[0] to reach the underlying one.
                while e.args and isinstance(e.args[0], Exception):
                    e = e.args[0]
                raise e

        error_msg = str(ctx.exception)
        self.assertTrue(
            "timed out" in error_msg
            or "timeout" in error_msg
            or "HY000" in error_msg,  # generic MySQL SQLSTATE seen on timeouts
            f"未检测到超时异常,实际异常: {error_msg}",
        )

    def test_macos_ssl_connection(self):
        """The agent can connect over SSL using a CA bundle (macOS only)."""
        if platform.system() != 'Darwin':
            self.skipTest("仅在macOS平台运行")

        # The default is the Intel Homebrew location; override with MYSQL_SSL_CA
        # (e.g. /opt/homebrew/etc/openssl@3/cert.pem on Apple Silicon).
        ca_path = os.environ.get('MYSQL_SSL_CA', '/usr/local/etc/openssl/cert.pem')
        if not os.path.exists(ca_path):
            self.skipTest(f"CA bundle not found: {ca_path}")

        config = {
            **self.base_config,
            'database': self.test_db_name,
            'ssl': {'ca': ca_path},
        }
        db = MySQLAgent(config)
        version_df = db.query_to_df("SELECT VERSION() as version")
        self.assertIsNotNone(version_df)


if __name__ == '__main__':
    unittest.main(verbosity=2)