数据库操作

This commit is contained in:
2025-08-06 16:24:17 +08:00
parent aa0b71a90b
commit c8d268647f
11 changed files with 1344 additions and 706 deletions
-1
View File
@@ -1 +0,0 @@
{"a":{"0":1},"b":{"0":2}}
+291
View File
@@ -0,0 +1,291 @@
import unittest
import pandas as pd
from datetime import datetime
import tempfile
import time
import pymysql
from storage.mysql_agent import MySQLAgent
import platform
class TestMySQLAgent(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""初始化测试环境和测试表"""
# 创建唯一的测试数据库名
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
# 基础配置
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'max_connections': 10
}
# 创建测试数据库
cls._create_test_database()
# 初始化数据库连接
cls.db = MySQLAgent({
**cls.base_config,
'database': cls.test_db_name
})
# 创建测试表
test_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
})
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data)
@classmethod
def _create_test_database(cls):
"""创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
cursor.execute(f"USE {cls.test_db_name}")
cursor.execute("SET GLOBAL max_connections = 100")
temp_conn.commit()
finally:
temp_conn.close()
@classmethod
def tearDownClass(cls):
"""清理测试数据库"""
if hasattr(cls, 'db') and cls.db:
# 删除测试表
if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table)
# 删除测试数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
def test_01_connection(self):
"""测试数据库连接"""
version = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nDatabase version: {version['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
def test_02_query_to_df(self):
"""测试查询返回DataFrame"""
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
self.assertEqual(len(df), 2)
self.assertIsInstance(df, pd.DataFrame)
print("\nQuery result sample:")
print(df.head())
def test_03_insert_from_df(self):
"""测试DataFrame插入"""
new_data = pd.DataFrame({
'id': [4, 5],
'name': ['Test4', 'Test5'],
'value': [40.1, 50.2],
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
})
rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2)
# 验证数据
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
self.assertEqual(len(df), 2)
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
def test_04_update_from_df(self):
"""测试DataFrame更新"""
update_data = pd.DataFrame({
'id': [1, 2],
'name': ['Updated1', 'Updated2']
})
rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2)
# 验证更新
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn('Updated1', df['name'].values)
self.assertIn('Updated2', df['name'].values)
def test_05_transaction(self):
"""测试事务处理"""
conn = self.db.begin_transaction()
try:
# 执行多个操作
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn)
except Exception:
self.db.rollback_transaction(conn)
raise
# 验证提交后的修改
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn(99.9, df['value'].values)
self.assertIn(88.8, df['value'].values)
def test_06_large_data(self):
"""测试大数据量操作"""
# 生成测试数据
large_data = pd.DataFrame({
'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)],
'value': [i * 0.1 for i in range(1000, 2000)],
'created_at': pd.date_range('2023-01-01', periods=1000)
})
# Windows平台使用更小的批次
chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time()
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
elapsed = time.time() - start_time
self.assertEqual(rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
# 验证数据
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
"""测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
return len(df)
start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(worker, range(100)))
elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
class TestPlatformSpecific(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""创建临时测试数据库"""
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123',
'max_connections': 10
}
# 创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
@classmethod
def tearDownClass(cls):
"""删除临时测试数据库"""
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
temp_conn.commit()
finally:
temp_conn.close()
def test_windows_timeout(self):
"""测试Windows平台超时处理"""
if platform.system() != 'Windows':
self.skipTest("Only runs on Windows")
config = {
**self.base_config,
'database': self.test_db_name,
'connect_timeout': 1,
'read_timeout': 1
}
db = MySQLAgent(config)
# 测试短超时查询
start_time = time.time()
try:
db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out")
except Exception as e:
self.assertIn("timed out", str(e))
print(f"\nWindows timeout test: {str(e)}")
def test_macos_ssl(self):
"""测试macOS SSL连接"""
if platform.system() != 'Darwin':
self.skipTest("Only runs on macOS")
config = {
**self.base_config,
'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
db = MySQLAgent(config)
version = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
if __name__ == '__main__':
unittest.main()
+13 -3
View File
@@ -3,6 +3,8 @@ import pandas as pd
import os
from pathlib import Path
from utils.file_handler import FileHandler
from datetime import datetime
@pytest.fixture
def temp_dir(tmp_path):
@@ -11,11 +13,13 @@ def temp_dir(tmp_path):
test_dir.mkdir()
return test_dir
@pytest.fixture
def file_handler(temp_dir):
"""创建FileHandler实例"""
return FileHandler(temp_dir)
@pytest.fixture
def sample_dataframe():
"""创建测试用DataFrame"""
@@ -25,6 +29,7 @@ def sample_dataframe():
'value': [10.5, 20.3, 30.1]
})
@pytest.fixture
def sample_text_file(temp_dir):
"""创建测试文本文件"""
@@ -55,30 +60,33 @@ def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
assert df.shape == (3, 3)
assert list(df.columns) == ['id', 'name', 'value']
def test_read_write_json(file_handler, temp_dir, sample_dataframe):
"""测试JSON文件读写"""
test_file = temp_dir / "test.json"
# 测试写入
write_result = file_handler.write_file(test_file, sample_dataframe)
assert write_result.iloc[0]['success'] == True
assert write_result.iloc[0]['success'] == True
# 测试读取
df = file_handler.read_file(test_file)
assert df.shape == (3, 3)
def test_read_write_excel(file_handler, temp_dir, sample_dataframe):
"""测试Excel文件读写"""
test_file = temp_dir / "test.xlsx"
# 测试写入
write_result = file_handler.write_file(test_file, sample_dataframe)
assert write_result.iloc[0]['success'] == True
assert write_result.iloc[0]['success'] == True
# 测试读取
df = file_handler.read_file(test_file)
assert df.shape == (3, 3)
def test_read_write_csv(file_handler, temp_dir, sample_dataframe):
"""测试CSV文件读写"""
test_file = temp_dir / "test.csv"
@@ -119,6 +127,7 @@ def test_file_operations(file_handler, sample_text_file):
assert delete_df.iloc[0]['deleted'] == True
assert not os.path.exists(sample_text_file)
def test_directory_operations(file_handler, temp_dir):
"""测试目录操作"""
test_dir = temp_dir / "subdir"
@@ -160,6 +169,7 @@ def test_zip_operations(file_handler, temp_dir, sample_dataframe):
assert os.path.exists(extract_dir / "file1.txt")
assert os.path.exists(extract_dir / "file2.csv")
def test_zip_directory(file_handler, temp_dir):
"""测试目录压缩"""
# 创建测试目录结构
@@ -174,4 +184,4 @@ def test_zip_directory(file_handler, temp_dir):
zip_path = temp_dir / "dir.zip"
zip_result = file_handler.zip_dir(test_dir, zip_path)
assert zip_result.iloc[0]['zipped'] == True
assert zip_result.iloc[0]['file_count'] == 2
assert zip_result.iloc[0]['file_count'] == 2