mysql数据链接更新
This commit is contained in:
+149
-116
@@ -6,72 +6,108 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
import pymysql
|
from loguru import logger
|
||||||
|
from utils.mysql_agent import MySQLAgent
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
# 数据库连接信息
|
# 数据库连接配置
|
||||||
local_DB_Config = {
|
local_DB_Config = {
|
||||||
'host': "localhost",
|
'host': "localhost",
|
||||||
'user': "root",
|
'user': "root",
|
||||||
'password': "123123",
|
'password': "123123",
|
||||||
'database': "intelligence_system",
|
'database': "intelligence_system",
|
||||||
'charset': 'utf8mb4'
|
'port': 3306,
|
||||||
|
'charset': 'utf8mb4',
|
||||||
|
'connect_timeout': 10,
|
||||||
|
'read_timeout': 30,
|
||||||
|
'write_timeout': 30,
|
||||||
|
'autocommit': True
|
||||||
}
|
}
|
||||||
|
|
||||||
# 表名
|
# 目标数据表名
|
||||||
table_name = "collector_rss_subscriptions"
|
table_name = "collector_rss_subscriptions"
|
||||||
|
|
||||||
|
|
||||||
def verify_database():
|
class NewsAPIClient:
|
||||||
"""验证数据库连接和表结构"""
|
"""新闻API客户端,用于获取和处理RSS源数据并写入数据库"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化客户端并建立数据库连接"""
|
||||||
|
self.logger = logger.bind(module="NewsAPIClient")
|
||||||
|
self.db_agent = MySQLAgent(local_DB_Config)
|
||||||
|
self.logger.info("新闻API客户端初始化完成,已连接到数据库")
|
||||||
|
|
||||||
|
def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
|
||||||
|
"""统一返回结果格式"""
|
||||||
|
return {
|
||||||
|
'success': bool(success),
|
||||||
|
'message': str(message),
|
||||||
|
'data': data
|
||||||
|
}
|
||||||
|
|
||||||
|
def verify_database(self) -> bool:
|
||||||
|
"""验证数据库表结构是否符合要求(适配元组格式的查询结果)"""
|
||||||
try:
|
try:
|
||||||
conn = pymysql.connect(**local_DB_Config)
|
# 1. 检查表是否存在(execute_sql返回元组列表,如 [(table_name,)])
|
||||||
with conn.cursor() as cursor:
|
result = self.db_agent.execute_sql(
|
||||||
# 检查表是否存在
|
f"SHOW TABLES LIKE '{table_name}'",
|
||||||
cursor.execute(f"SHOW TABLES LIKE '{table_name}'")
|
fetch=True
|
||||||
if not cursor.fetchone():
|
)
|
||||||
print(f"错误: 表 {table_name} 不存在!")
|
# 元组结果需通过索引0判断(若表存在,result是[(table_name,)], 否则为空列表)
|
||||||
|
if not result:
|
||||||
|
self.logger.error(f"表 {table_name} 不存在,请先创建表结构")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查表结构
|
# 2. 检查表字段是否完整(DESCRIBE返回的元组格式:(字段名, 类型, 是否为空, ...))
|
||||||
cursor.execute(f"DESCRIBE {table_name}")
|
desc_result = self.db_agent.execute_sql(
|
||||||
columns = [col[0] for col in cursor.fetchall()]
|
f"DESCRIBE {table_name}",
|
||||||
print("表列名:", columns)
|
fetch=True
|
||||||
|
)
|
||||||
|
# 关键修改:用元组索引0提取字段名(而非字典键'Field')
|
||||||
|
columns = [col[0] for col in desc_result] # col是元组,col[0]即字段名
|
||||||
|
required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
|
||||||
|
'来源URL', '创建时间', '更新时间']
|
||||||
|
missing_cols = [col for col in required_columns if col not in columns]
|
||||||
|
|
||||||
# 检查插入权限
|
if missing_cols:
|
||||||
test_sql = f"""INSERT INTO `{table_name}`
|
self.logger.error(f"表 {table_name} 缺少必要字段:{missing_cols}")
|
||||||
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)
|
return False
|
||||||
VALUES (%s, %s, %s, %s, %s)"""
|
|
||||||
cursor.execute(test_sql, ('测试标题', 'http://test.com', '测试内容', datetime.now(), '测试来源'))
|
|
||||||
conn.rollback()
|
|
||||||
|
|
||||||
print("数据库验证通过!")
|
self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("数据库验证失败:", e)
|
self.logger.error(f"数据库验证失败: {str(e)}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
finally:
|
|
||||||
if 'conn' in locals():
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
def load_last_update_time(self) -> Optional[datetime]:
|
||||||
def load_last_update_time():
|
"""加载上次更新时间缓存"""
|
||||||
"""加载上次更新的时间"""
|
|
||||||
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
|
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
|
||||||
if os.path.exists(cache_file):
|
if os.path.exists(cache_file):
|
||||||
|
try:
|
||||||
with open(cache_file, 'rb') as f:
|
with open(cache_file, 'rb') as f:
|
||||||
return pickle.load(f)
|
last_update = pickle.load(f)
|
||||||
|
self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
return last_update
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"加载上次更新时间失败: {str(e)}", exc_info=True)
|
||||||
|
self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def save_last_update_time(self, last_update: datetime) -> None:
|
||||||
|
"""保存本次更新时间"""
|
||||||
|
try:
|
||||||
|
cache_dir = os.path.join(os.getcwd(), 'output')
|
||||||
|
os.makedirs(cache_dir, exist_ok=True)
|
||||||
|
cache_file = os.path.join(cache_dir, 'last_update.pkl')
|
||||||
|
|
||||||
def save_last_update_time(last_update):
|
|
||||||
"""保存本次更新的时间"""
|
|
||||||
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
|
|
||||||
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
|
||||||
with open(cache_file, 'wb') as f:
|
with open(cache_file, 'wb') as f:
|
||||||
pickle.dump(last_update, f)
|
pickle.dump(last_update, f)
|
||||||
|
self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"保存更新时间失败: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional[feedparser.FeedParserDict]:
|
||||||
def fetch_single_rss(url, timeout=15):
|
|
||||||
"""获取并解析单个RSS源"""
|
"""获取并解析单个RSS源"""
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
||||||
@@ -85,24 +121,25 @@ def fetch_single_rss(url, timeout=15):
|
|||||||
feed = feedparser.parse(response.text)
|
feed = feedparser.parse(response.text)
|
||||||
|
|
||||||
if feed.bozo:
|
if feed.bozo:
|
||||||
print(f"警告: 解析可能存在问题: {feed.bozo_exception}")
|
self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")
|
||||||
|
|
||||||
|
self.logger.debug(f"成功获取 {url} 的RSS数据")
|
||||||
return feed
|
return feed
|
||||||
|
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
print(f"第 {attempt + 1} 次尝试获取 {url} 失败: {e}")
|
self.logger.warning(f"第 {attempt + 1} 次获取 {url} 失败: {str(e)}")
|
||||||
if attempt < 2:
|
if attempt < 2:
|
||||||
time.sleep(5 * (attempt + 1))
|
time.sleep(3 * (attempt + 1)) # 指数退避重试
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> Dict[str, feedparser.FeedParserDict]:
|
||||||
def fetch_all_rss(urls, timeout=15):
|
"""并发获取多个RSS源"""
|
||||||
"""使用线程池并发获取多个RSS源"""
|
|
||||||
feeds = {}
|
feeds = {}
|
||||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||||
future_to_url = {executor.submit(fetch_single_rss, url, timeout): url for url in urls}
|
future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
|
||||||
|
|
||||||
for future in as_completed(future_to_url):
|
for future in as_completed(future_to_url):
|
||||||
url = future_to_url[future]
|
url = future_to_url[future]
|
||||||
@@ -111,13 +148,13 @@ def fetch_all_rss(urls, timeout=15):
|
|||||||
if feed:
|
if feed:
|
||||||
feeds[url] = feed
|
feeds[url] = feed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取 {url} 时发生异常: {e}")
|
self.logger.error(f"处理 {url} 时发生异常: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
|
||||||
return feeds
|
return feeds
|
||||||
|
|
||||||
|
def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
|
||||||
def process_feed_entry(entry, url):
|
"""处理单个RSS条目,转换为数据库兼容格式"""
|
||||||
"""处理单个RSS条目并返回结构化数据"""
|
|
||||||
# 处理标题
|
# 处理标题
|
||||||
title = entry.get('title', '无标题') or '无标题'
|
title = entry.get('title', '无标题') or '无标题'
|
||||||
if len(title) > 255:
|
if len(title) > 255:
|
||||||
@@ -131,7 +168,7 @@ def process_feed_entry(entry, url):
|
|||||||
# 处理摘要
|
# 处理摘要
|
||||||
summary = entry.get('summary', '无内容摘要')
|
summary = entry.get('summary', '无内容摘要')
|
||||||
content_list = entry.get('content', [])
|
content_list = entry.get('content', [])
|
||||||
content = content_list[0].value if content_list else ''
|
content = content_list[0].value if (content_list and hasattr(content_list[0], 'value')) else ''
|
||||||
description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
|
description = summary if summary != '无内容摘要' else (content[:200] + '...' if content else '无内容摘要')
|
||||||
|
|
||||||
# 处理发布时间
|
# 处理发布时间
|
||||||
@@ -141,7 +178,7 @@ def process_feed_entry(entry, url):
|
|||||||
else:
|
else:
|
||||||
pub_str = entry.get('published', entry.get('updated', ''))
|
pub_str = entry.get('published', entry.get('updated', ''))
|
||||||
try:
|
try:
|
||||||
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z')
|
entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
|
||||||
except:
|
except:
|
||||||
entry_time = datetime.now()
|
entry_time = datetime.now()
|
||||||
|
|
||||||
@@ -150,100 +187,89 @@ def process_feed_entry(entry, url):
|
|||||||
if len(source_url) > 1024:
|
if len(source_url) > 1024:
|
||||||
source_url = source_url[:1021] + '...'
|
source_url = source_url[:1021] + '...'
|
||||||
|
|
||||||
|
# 当前时间(创建/更新时间)
|
||||||
|
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'文章标题': title,
|
'文章标题': title,
|
||||||
'文章链接': link,
|
'文章链接': link,
|
||||||
'文章摘要': description,
|
'文章摘要': description,
|
||||||
'发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
|
'发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||||
'来源URL': source_url
|
'来源URL': source_url,
|
||||||
|
'创建时间': current_time,
|
||||||
|
'更新时间': current_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def display_feed_info(self, feed: feedparser.FeedParserDict, last_update: Optional[datetime] = None,
|
||||||
def display_feed_info(feed, last_update=None, url=None):
|
url: Optional[str] = None) -> Optional[datetime]:
|
||||||
"""处理并显示RSS源信息"""
|
"""处理RSS源信息并写入数据库"""
|
||||||
if not feed:
|
if not feed:
|
||||||
print("无法显示信息:feed 为 None")
|
self.logger.warning("无法处理空的RSS源数据")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
print("=" * 80)
|
self.logger.info(f"开始处理 RSS 源: {url}")
|
||||||
print(f"处理 RSS 源: {url}")
|
|
||||||
entries = feed.entries
|
entries = feed.entries
|
||||||
data_list = []
|
data_list = []
|
||||||
new_last_update = last_update
|
new_last_update = last_update
|
||||||
|
|
||||||
for i, entry in enumerate(entries, 1):
|
for i, entry in enumerate(entries, 1):
|
||||||
entry_data = process_feed_entry(entry, url)
|
entry_data = self.process_feed_entry(entry, url)
|
||||||
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
|
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
|
# 过滤旧数据
|
||||||
if last_update and entry_time <= last_update:
|
if last_update and entry_time <= last_update:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 更新最新时间戳
|
||||||
if new_last_update is None or entry_time > new_last_update:
|
if new_last_update is None or entry_time > new_last_update:
|
||||||
new_last_update = entry_time
|
new_last_update = entry_time
|
||||||
|
|
||||||
print(f"\n--- 条目 {i} ---")
|
self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
|
||||||
print(f"标题: {entry_data['文章标题']}")
|
|
||||||
print(f"链接: {entry_data['文章链接']}")
|
|
||||||
print(f"摘要: {entry_data['文章摘要'][:100]}...")
|
|
||||||
print(f"时间: {entry_data['发布时间']}")
|
|
||||||
|
|
||||||
data_list.append(entry_data)
|
data_list.append(entry_data)
|
||||||
|
|
||||||
|
# 写入数据库
|
||||||
if data_list:
|
if data_list:
|
||||||
df = pd.DataFrame(data_list)
|
df = pd.DataFrame(data_list)
|
||||||
write_to_database(df)
|
self.write_to_database(df)
|
||||||
|
|
||||||
return new_last_update
|
return new_last_update
|
||||||
|
|
||||||
|
# news_api.py 中的 write_to_database 方法可以保持简洁
|
||||||
def write_to_database(df):
|
def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||||
"""将数据写入数据库"""
|
|
||||||
if df.empty:
|
if df.empty:
|
||||||
print("没有新数据需要写入")
|
self.logger.info("没有新数据需要写入数据库")
|
||||||
return
|
return self._format_result(True, "没有新数据需要写入")
|
||||||
|
|
||||||
print("\n准备写入数据库的数据样例:")
|
|
||||||
print(df.iloc[0].to_dict())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = pymysql.connect(**local_DB_Config)
|
inserted_rows = self.db_agent.insert_from_df(
|
||||||
with conn.cursor() as cursor:
|
table_name=table_name,
|
||||||
sql = f"""INSERT IGNORE INTO `{table_name}`
|
df=df,
|
||||||
(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)
|
chunk_size=500,
|
||||||
VALUES (%s, %s, %s, %s, %s)"""
|
replace=False
|
||||||
|
)
|
||||||
|
|
||||||
success_count = 0
|
self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
|
||||||
for _, row in df.iterrows():
|
return self._format_result(
|
||||||
|
True,
|
||||||
|
f"成功写入 {inserted_rows}/{len(df)} 条记录",
|
||||||
|
{"success_count": inserted_rows, "total": len(df)}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"数据库写入失败: {str(e)}", exc_info=True)
|
||||||
|
return self._format_result(False, f"数据库操作失败: {str(e)}")
|
||||||
|
@classmethod
|
||||||
|
def main(cls):
|
||||||
|
"""主函数入口"""
|
||||||
try:
|
try:
|
||||||
cursor.execute(sql, (
|
client = cls()
|
||||||
row['文章标题'],
|
|
||||||
row['文章链接'],
|
|
||||||
row['文章摘要'],
|
|
||||||
row['发布时间'],
|
|
||||||
row['来源URL']
|
|
||||||
))
|
|
||||||
success_count += cursor.rowcount
|
|
||||||
except Exception as e:
|
|
||||||
print(f"插入记录时出错: {e}")
|
|
||||||
print(f"问题数据: {row.to_dict()}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
conn.commit()
|
# 验证数据库
|
||||||
print(f"成功写入 {success_count}/{len(df)} 条记录")
|
if not client.verify_database():
|
||||||
|
client.logger.error("数据库验证失败,程序终止")
|
||||||
except Exception as e:
|
|
||||||
print("数据库操作失败:", e)
|
|
||||||
finally:
|
|
||||||
if 'conn' in locals():
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
if not verify_database():
|
|
||||||
print("数据库验证失败,程序终止")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# RSS源列表
|
||||||
rss_urls = [
|
rss_urls = [
|
||||||
"https://www.chinanews.com.cn/rss/finance.xml",
|
"https://www.chinanews.com.cn/rss/finance.xml",
|
||||||
"https://www.chinanews.com.cn/rss/world.xml",
|
"https://www.chinanews.com.cn/rss/world.xml",
|
||||||
@@ -251,27 +277,34 @@ def main():
|
|||||||
"https://www.chinanews.com.cn/rss/scroll-news.xml"
|
"https://www.chinanews.com.cn/rss/scroll-news.xml"
|
||||||
]
|
]
|
||||||
|
|
||||||
last_update = load_last_update_time()
|
# 加载上次更新时间
|
||||||
|
last_update = client.load_last_update_time()
|
||||||
if last_update:
|
if last_update:
|
||||||
print(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
|
||||||
print("\n开始获取RSS源数据...")
|
# 获取RSS数据
|
||||||
|
client.logger.info("开始获取RSS源数据...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
feeds = fetch_all_rss(rss_urls)
|
feeds = client.fetch_all_rss(rss_urls)
|
||||||
print(f"获取完成,耗时: {time.time() - start_time:.2f}秒")
|
client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}秒")
|
||||||
|
|
||||||
|
# 处理并写入数据
|
||||||
new_last_update = None
|
new_last_update = None
|
||||||
for url, feed in feeds.items():
|
for url, feed in feeds.items():
|
||||||
current_last_update = display_feed_info(feed, last_update, url)
|
current_last_update = client.display_feed_info(feed, last_update, url)
|
||||||
if current_last_update and (new_last_update is None or current_last_update > new_last_update):
|
if current_last_update and (new_last_update is None or current_last_update > new_last_update):
|
||||||
new_last_update = current_last_update
|
new_last_update = current_last_update
|
||||||
|
|
||||||
|
# 保存最新更新时间
|
||||||
if new_last_update:
|
if new_last_update:
|
||||||
save_last_update_time(new_last_update)
|
client.save_last_update_time(new_last_update)
|
||||||
print(f"\n本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
else:
|
else:
|
||||||
print("\n没有获取到新的内容")
|
client.logger.info("没有获取到新内容")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"程序运行出错: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
NewsAPIClient.main()
|
||||||
|
|||||||
+123849
File diff suppressed because it is too large
Load Diff
+63455
File diff suppressed because it is too large
Load Diff
+101
-111
@@ -3,18 +3,20 @@ import pandas as pd
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import time
|
import time
|
||||||
import pymysql
|
import pymysql
|
||||||
from utils.mysql_agent import MySQLAgent
|
|
||||||
import platform
|
import platform
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from utils.mysql_agent import MySQLAgent
|
||||||
|
|
||||||
|
|
||||||
class TestMySQLAgent(unittest.TestCase):
|
class TestMySQLAgent(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""初始化测试环境和测试表"""
|
"""初始化测试环境和测试表"""
|
||||||
# 创建唯一的测试数据库名
|
# 创建唯一的测试数据库和表名(避免冲突)
|
||||||
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
|
|
||||||
# 基础配置
|
# 基础配置(根据实际环境修改)
|
||||||
cls.base_config = {
|
cls.base_config = {
|
||||||
'host': 'localhost',
|
'host': 'localhost',
|
||||||
'port': 3306,
|
'port': 3306,
|
||||||
@@ -32,21 +34,19 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
'database': cls.test_db_name
|
'database': cls.test_db_name
|
||||||
})
|
})
|
||||||
|
|
||||||
# 创建测试表
|
# 创建测试表并插入初始数据
|
||||||
test_data = pd.DataFrame({
|
test_data = pd.DataFrame({
|
||||||
'id': [1, 2, 3],
|
'id': [1, 2, 3],
|
||||||
'name': ['Test1', 'Test2', 'Test3'],
|
'name': ['Test1', 'Test2', 'Test3'],
|
||||||
'value': [10.5, 20.3, 30.8],
|
'value': [10.5, 20.3, 30.8],
|
||||||
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
|
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||||
})
|
})
|
||||||
|
|
||||||
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
|
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
|
||||||
cls.db.insert_from_df(cls.test_table, test_data)
|
cls.db.insert_from_df(cls.test_table, test_data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_test_database(cls):
|
def _create_test_database(cls):
|
||||||
"""创建测试数据库"""
|
"""创建测试数据库"""
|
||||||
# 使用临时连接创建数据库
|
|
||||||
temp_conn = pymysql.connect(
|
temp_conn = pymysql.connect(
|
||||||
host=cls.base_config['host'],
|
host=cls.base_config['host'],
|
||||||
port=cls.base_config['port'],
|
port=cls.base_config['port'],
|
||||||
@@ -54,7 +54,6 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
password=cls.base_config['password'],
|
password=cls.base_config['password'],
|
||||||
charset='utf8mb4'
|
charset='utf8mb4'
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with temp_conn.cursor() as cursor:
|
with temp_conn.cursor() as cursor:
|
||||||
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||||
@@ -66,21 +65,14 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
"""清理测试数据库"""
|
"""清理测试环境"""
|
||||||
if hasattr(cls, 'db') and cls.db:
|
if hasattr(cls, 'db') and cls.db:
|
||||||
# 删除测试表
|
# 删除测试表
|
||||||
if cls.db.table_exists(cls.test_table):
|
if cls.db.table_exists(cls.test_table):
|
||||||
cls.db.drop_table(cls.test_table)
|
cls.db.drop_table(cls.test_table)
|
||||||
|
|
||||||
# 删除测试数据库
|
# 删除测试数据库
|
||||||
temp_conn = pymysql.connect(
|
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||||
host=cls.base_config['host'],
|
|
||||||
port=cls.base_config['port'],
|
|
||||||
user=cls.base_config['user'],
|
|
||||||
password=cls.base_config['password'],
|
|
||||||
charset='utf8mb4'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with temp_conn.cursor() as cursor:
|
with temp_conn.cursor() as cursor:
|
||||||
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||||
@@ -88,22 +80,24 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
finally:
|
finally:
|
||||||
temp_conn.close()
|
temp_conn.close()
|
||||||
|
|
||||||
def test_01_connection(self):
|
def test_connection(self):
|
||||||
"""测试数据库连接"""
|
"""测试数据库连接"""
|
||||||
version = self.db.query_to_df("SELECT VERSION() as version")
|
version_df = self.db.query_to_df("SELECT VERSION() as version")
|
||||||
self.assertIsNotNone(version)
|
self.assertIsNotNone(version_df)
|
||||||
print(f"\nDatabase version: {version['version'].iloc[0]}")
|
self.assertEqual(len(version_df), 1)
|
||||||
print(f"Running on: {platform.system()} {platform.release()}")
|
print(f"数据库版本: {version_df['version'].iloc[0]}")
|
||||||
|
|
||||||
def test_02_query_to_df(self):
|
def test_query_to_df(self):
|
||||||
"""测试查询返回DataFrame"""
|
"""测试查询返回DataFrame"""
|
||||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
|
df = self.db.query_to_df(
|
||||||
self.assertEqual(len(df), 2)
|
f"SELECT * FROM {self.test_table} WHERE id > %s",
|
||||||
|
params=(1,)
|
||||||
|
)
|
||||||
self.assertIsInstance(df, pd.DataFrame)
|
self.assertIsInstance(df, pd.DataFrame)
|
||||||
print("\nQuery result sample:")
|
self.assertEqual(len(df), 2) # id>1 的数据有2条
|
||||||
print(df.head())
|
self.assertIn('name', df.columns)
|
||||||
|
|
||||||
def test_03_insert_from_df(self):
|
def test_insert_from_df(self):
|
||||||
"""测试DataFrame插入"""
|
"""测试DataFrame插入"""
|
||||||
new_data = pd.DataFrame({
|
new_data = pd.DataFrame({
|
||||||
'id': [4, 5],
|
'id': [4, 5],
|
||||||
@@ -112,55 +106,55 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
|
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
|
||||||
})
|
})
|
||||||
|
|
||||||
rows = self.db.insert_from_df(self.test_table, new_data)
|
inserted_rows = self.db.insert_from_df(self.test_table, new_data)
|
||||||
self.assertEqual(rows, 2)
|
self.assertEqual(inserted_rows, 2)
|
||||||
|
|
||||||
# 验证数据
|
# 验证插入结果
|
||||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
|
result_df = self.db.query_to_df(
|
||||||
self.assertEqual(len(df), 2)
|
f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
|
||||||
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
|
)
|
||||||
|
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
|
||||||
|
|
||||||
def test_04_update_from_df(self):
|
def test_update_from_df(self):
|
||||||
"""测试DataFrame更新"""
|
"""测试DataFrame更新"""
|
||||||
update_data = pd.DataFrame({
|
update_data = pd.DataFrame({
|
||||||
'id': [1, 2],
|
'id': [1, 2],
|
||||||
'name': ['Updated1', 'Updated2']
|
'name': ['Updated1', 'Updated2']
|
||||||
})
|
})
|
||||||
|
|
||||||
rows = self.db.update_from_df(self.test_table, update_data, 'id')
|
updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
|
||||||
self.assertGreaterEqual(rows, 2)
|
self.assertGreaterEqual(updated_rows, 2)
|
||||||
|
|
||||||
# 验证更新
|
# 验证更新结果
|
||||||
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
|
result_df = self.db.query_to_df(
|
||||||
self.assertIn('Updated1', df['name'].values)
|
f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
|
||||||
self.assertIn('Updated2', df['name'].values)
|
)
|
||||||
|
self.assertIn('Updated1', result_df['name'].values)
|
||||||
|
self.assertIn('Updated2', result_df['name'].values)
|
||||||
|
|
||||||
def test_05_transaction(self):
|
def test_transaction(self):
|
||||||
"""测试事务处理"""
|
"""测试事务处理"""
|
||||||
conn = self.db.begin_transaction()
|
conn = self.db.begin_transaction()
|
||||||
try:
|
try:
|
||||||
# 执行多个操作
|
# 执行事务内操作
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
|
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
|
||||||
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
|
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
|
||||||
|
|
||||||
# 验证事务内修改
|
|
||||||
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
|
|
||||||
self.assertEqual(cursor.fetchone()['value'], 99.9)
|
|
||||||
|
|
||||||
self.db.commit_transaction(conn)
|
self.db.commit_transaction(conn)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.db.rollback_transaction(conn)
|
self.db.rollback_transaction(conn)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 验证提交后的修改
|
# 验证事务提交结果
|
||||||
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
|
result_df = self.db.query_to_df(
|
||||||
self.assertIn(99.9, df['value'].values)
|
f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
|
||||||
self.assertIn(88.8, df['value'].values)
|
)
|
||||||
|
self.assertIn(99.9, result_df['value'].values)
|
||||||
|
self.assertIn(88.8, result_df['value'].values)
|
||||||
|
|
||||||
def test_06_large_data(self):
|
def test_large_data_insert(self):
|
||||||
"""测试大数据量操作"""
|
"""测试大数据量插入"""
|
||||||
# 生成测试数据
|
# 生成1000行测试数据
|
||||||
large_data = pd.DataFrame({
|
large_data = pd.DataFrame({
|
||||||
'id': range(1000, 2000),
|
'id': range(1000, 2000),
|
||||||
'name': [f"Item_{i}" for i in range(1000, 2000)],
|
'name': [f"Item_{i}" for i in range(1000, 2000)],
|
||||||
@@ -168,59 +162,55 @@ class TestMySQLAgent(unittest.TestCase):
|
|||||||
'created_at': pd.date_range('2023-01-01', periods=1000)
|
'created_at': pd.date_range('2023-01-01', periods=1000)
|
||||||
})
|
})
|
||||||
|
|
||||||
# Windows平台使用更小的批次
|
# 根据平台自动调整批次大小
|
||||||
chunk_size = 100 if platform.system() == 'Windows' else 500
|
chunk_size = 100 if platform.system() == 'Windows' else 500
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
|
inserted_rows = self.db.insert_from_df(
|
||||||
|
self.test_table,
|
||||||
|
large_data,
|
||||||
|
chunk_size=chunk_size
|
||||||
|
)
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
self.assertEqual(rows, 1000)
|
self.assertEqual(inserted_rows, 1000)
|
||||||
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
|
print(f"插入1000行数据耗时: {elapsed:.2f}秒 (批次大小: {chunk_size})")
|
||||||
|
|
||||||
# 验证数据
|
def test_concurrent_access(self):
|
||||||
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
|
|
||||||
self.assertEqual(df['cnt'].iloc[0], 1000)
|
|
||||||
|
|
||||||
def test_07_concurrent_access(self):
|
|
||||||
"""测试并发访问"""
|
"""测试并发访问"""
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
def worker(i):
|
def query_worker(i):
|
||||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
|
"""并发查询工作函数"""
|
||||||
|
df = self.db.query_to_df(
|
||||||
|
f"SELECT * FROM {self.test_table} WHERE id = %s",
|
||||||
|
params=(i % 3 + 1,) # 查询id=1,2,3循环
|
||||||
|
)
|
||||||
return len(df)
|
return len(df)
|
||||||
|
|
||||||
|
# 20个线程执行100次查询
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with ThreadPoolExecutor(max_workers=20) as executor:
|
with ThreadPoolExecutor(max_workers=20) as executor:
|
||||||
results = list(executor.map(worker, range(100)))
|
results = list(executor.map(query_worker, range(100)))
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
self.assertEqual(sum(results), 100)
|
|
||||||
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
|
self.assertEqual(sum(results), 100) # 每次查询应返回1行
|
||||||
|
print(f"100次并发查询耗时: {elapsed:.2f}秒")
|
||||||
|
|
||||||
|
|
||||||
class TestPlatformSpecific(unittest.TestCase):
|
class TestPlatformSpecific(unittest.TestCase):
|
||||||
|
"""平台特定功能测试"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
"""创建临时测试数据库"""
|
cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
|
||||||
cls.base_config = {
|
cls.base_config = {
|
||||||
'host': 'localhost',
|
'host': 'localhost',
|
||||||
'port': 3306,
|
'port': 3306,
|
||||||
'user': 'root',
|
'user': 'root',
|
||||||
'password': '123123',
|
'password': '123123'
|
||||||
'max_connections': 10
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建数据库
|
# 创建测试数据库
|
||||||
temp_conn = pymysql.connect(
|
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||||
host=cls.base_config['host'],
|
|
||||||
port=cls.base_config['port'],
|
|
||||||
user=cls.base_config['user'],
|
|
||||||
password=cls.base_config['password'],
|
|
||||||
charset='utf8mb4'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with temp_conn.cursor() as cursor:
|
with temp_conn.cursor() as cursor:
|
||||||
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||||
@@ -230,15 +220,8 @@ class TestPlatformSpecific(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
"""删除临时测试数据库"""
|
"""清理测试数据库"""
|
||||||
temp_conn = pymysql.connect(
|
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||||
host=cls.base_config['host'],
|
|
||||||
port=cls.base_config['port'],
|
|
||||||
user=cls.base_config['user'],
|
|
||||||
password=cls.base_config['password'],
|
|
||||||
charset='utf8mb4'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with temp_conn.cursor() as cursor:
|
with temp_conn.cursor() as cursor:
|
||||||
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||||
@@ -249,42 +232,49 @@ class TestPlatformSpecific(unittest.TestCase):
|
|||||||
def test_windows_timeout(self):
|
def test_windows_timeout(self):
|
||||||
"""测试Windows平台超时处理"""
|
"""测试Windows平台超时处理"""
|
||||||
if platform.system() != 'Windows':
|
if platform.system() != 'Windows':
|
||||||
self.skipTest("Only runs on Windows")
|
self.skipTest("仅在Windows平台运行")
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
**self.base_config,
|
**self.base_config,
|
||||||
'database': self.test_db_name,
|
'database': self.test_db_name,
|
||||||
'connect_timeout': 1,
|
'connect_timeout': 1,
|
||||||
'read_timeout': 1
|
'read_timeout': 1,
|
||||||
|
'write_timeout': 1
|
||||||
}
|
}
|
||||||
|
|
||||||
db = MySQLAgent(config)
|
db = MySQLAgent(config)
|
||||||
|
|
||||||
# 测试短超时查询
|
# 执行会超时的查询(SLEEP(2)超过1秒超时设置)
|
||||||
start_time = time.time()
|
with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
|
||||||
try:
|
try:
|
||||||
db.query_to_df("SELECT SLEEP(2)")
|
db.query_to_df("SELECT SLEEP(2)")
|
||||||
self.fail("Should have timed out")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.assertIn("timed out", str(e))
|
# 提取底层异常信息(可能被包装)
|
||||||
print(f"\nWindows timeout test: {str(e)}")
|
while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
|
||||||
|
e = e.args[0]
|
||||||
|
raise e
|
||||||
|
|
||||||
def test_macos_ssl(self):
|
error_msg = str(ctx.exception)
|
||||||
"""测试macOS SSL连接"""
|
self.assertTrue(
|
||||||
|
"timed out" in error_msg or
|
||||||
|
"timeout" in error_msg or
|
||||||
|
"HY000" in error_msg, # MySQL超时错误码
|
||||||
|
f"未检测到超时异常,实际异常: {error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_macos_ssl_connection(self):
|
||||||
|
"""测试macOS平台SSL连接"""
|
||||||
if platform.system() != 'Darwin':
|
if platform.system() != 'Darwin':
|
||||||
self.skipTest("Only runs on macOS")
|
self.skipTest("仅在macOS平台运行")
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
**self.base_config,
|
**self.base_config,
|
||||||
'database': self.test_db_name,
|
'database': self.test_db_name,
|
||||||
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
}
|
}
|
||||||
|
|
||||||
db = MySQLAgent(config)
|
db = MySQLAgent(config)
|
||||||
version = db.query_to_df("SELECT VERSION() as version")
|
version_df = db.query_to_df("SELECT VERSION() as version")
|
||||||
self.assertIsNotNone(version)
|
self.assertIsNotNone(version_df)
|
||||||
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main(verbosity=2)
|
||||||
+389
-356
@@ -3,13 +3,13 @@ import sys
|
|||||||
import platform
|
import platform
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pymysql
|
import pymysql
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
from pymysql import cursors
|
from pymysql import cursors
|
||||||
from pymysql.err import MySQLError
|
from pymysql.err import MySQLError
|
||||||
from dbutils.pooled_db import PooledDB
|
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
|
||||||
from typing import Union, List, Dict, Any, Optional, Tuple
|
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# 导入日志系统
|
# 导入日志系统
|
||||||
@@ -20,7 +20,7 @@ class MySQLAgent:
|
|||||||
"""
|
"""
|
||||||
全平台兼容的MySQL数据库操作类
|
全平台兼容的MySQL数据库操作类
|
||||||
支持Windows/macOS/Linux系统
|
支持Windows/macOS/Linux系统
|
||||||
配置参数从外部传入
|
配置参数从外部传入,不使用连接池和事务管理
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -34,30 +34,14 @@ class MySQLAgent:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
"""
|
"""初始化MySQL数据库连接(原有逻辑完全保留)"""
|
||||||
初始化MySQL数据库连接
|
if hasattr(self, 'config') and self.config:
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): 数据库配置字典,包含以下键:
|
|
||||||
- host: 数据库主机
|
|
||||||
- port: 端口
|
|
||||||
- user: 用户名
|
|
||||||
- password: 密码
|
|
||||||
- database: 数据库名
|
|
||||||
- [可选] charset: 字符集(默认utf8mb4)
|
|
||||||
- [可选] max_connections: 最大连接数(默认5)
|
|
||||||
- [可选] connect_timeout: 连接超时(秒)
|
|
||||||
- [可选] read_timeout: 读取超时(秒)
|
|
||||||
- [可选] write_timeout: 写入超时(秒)
|
|
||||||
- [可选] ssl: SSL配置
|
|
||||||
"""
|
|
||||||
if hasattr(self, '_pool') and self._pool:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 基础配置校验
|
# 基础配置校验
|
||||||
required_keys = ['host', 'port', 'user', 'password', 'database']
|
required_keys = ['host', 'port', 'user', 'password', 'database']
|
||||||
if not all(key in config for key in required_keys):
|
if not all(key in config for key in required_keys):
|
||||||
log.warning(f"数据库配置缺少必要参数,当前配置: {config}")
|
log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{config}")
|
||||||
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
|
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
|
||||||
|
|
||||||
self.config = {
|
self.config = {
|
||||||
@@ -67,7 +51,6 @@ class MySQLAgent:
|
|||||||
'password': config['password'],
|
'password': config['password'],
|
||||||
'database': config['database'],
|
'database': config['database'],
|
||||||
'charset': config.get('charset', 'utf8mb4'),
|
'charset': config.get('charset', 'utf8mb4'),
|
||||||
'cursorclass': cursors.DictCursor,
|
|
||||||
'autocommit': True,
|
'autocommit': True,
|
||||||
'connect_timeout': config.get('connect_timeout', 10),
|
'connect_timeout': config.get('connect_timeout', 10),
|
||||||
'read_timeout': config.get('read_timeout', 30),
|
'read_timeout': config.get('read_timeout', 30),
|
||||||
@@ -79,86 +62,57 @@ class MySQLAgent:
|
|||||||
current_platform = platform.system()
|
current_platform = platform.system()
|
||||||
self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
||||||
|
|
||||||
# 创建连接池
|
|
||||||
self.pool_size = config.get('max_connections', 5)
|
|
||||||
self._pool = self._create_pool()
|
|
||||||
|
|
||||||
def _create_pool(self) -> PooledDB:
|
|
||||||
"""创建连接池"""
|
|
||||||
try:
|
|
||||||
# 线程安全的连接创建函数
|
|
||||||
def connect():
|
|
||||||
conn = pymysql.connect(**self.config)
|
|
||||||
conn.threadsafety = 1 # 显式设置线程安全级别
|
|
||||||
return conn
|
|
||||||
|
|
||||||
pool = PooledDB(
|
|
||||||
creator=connect,
|
|
||||||
mincached=1,
|
|
||||||
maxcached=3,
|
|
||||||
maxconnections=self.pool_size,
|
|
||||||
blocking=True,
|
|
||||||
ping=1 # 每次获取连接时ping数据库
|
|
||||||
)
|
|
||||||
|
|
||||||
self.log.info("连接池创建成功")
|
|
||||||
return pool
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.log.critical("连接池创建失败", error=str(e), exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_connection(self) -> pymysql.connections.Connection:
|
def get_connection(self) -> pymysql.connections.Connection:
|
||||||
"""获取数据库连接(修复字符集方法缺失问题)"""
|
"""获取数据库连接(原有逻辑完全保留)"""
|
||||||
try:
|
try:
|
||||||
conn = self._pool.connection()
|
conn = pymysql.connect(**self.config)
|
||||||
|
|
||||||
# 为连接添加字符集方法(兼容SQLAlchemy)
|
# 为连接添加 character_set_name 方法
|
||||||
if not hasattr(conn, 'character_set_name'):
|
if not hasattr(conn, 'character_set_name'):
|
||||||
def _character_set_name():
|
def _character_set_name():
|
||||||
return self.config.get('charset', 'utf8mb4')
|
return self.config.get('charset', 'utf8mb4')
|
||||||
|
|
||||||
conn.character_set_name = _character_set_name
|
conn.character_set_name = _character_set_name
|
||||||
|
|
||||||
# macOS平台SSL特殊处理
|
# macOS需要特殊处理SSL
|
||||||
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
||||||
conn.ping(reconnect=True)
|
conn.ping(reconnect=True)
|
||||||
|
|
||||||
self.log.trace("获取数据库连接成功")
|
self.log.trace("Database connection obtained")
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
# Windows平台连接超时重试
|
|
||||||
if platform.system() == 'Windows' and "timed out" in error_msg:
|
if platform.system() == 'Windows' and "timed out" in error_msg:
|
||||||
self.log.warning("Windows连接超时,尝试重试...")
|
self.log.warning("Windows connection timeout, retrying...")
|
||||||
return self._retry_connection()
|
return self._retry_connection()
|
||||||
|
|
||||||
self.log.error("获取连接失败", error=error_msg, exc_info=True)
|
self.log.error("Connection failed", error=error_msg, exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
|
def _retry_connection(self, max_retries: int = 3) -> Any | None:
|
||||||
"""Windows平台连接重试机制"""
|
"""Windows平台连接重试机制(原有逻辑完全保留)"""
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
conn = self._pool.connection()
|
conn = pymysql.connect(**self.config)
|
||||||
self.log.info(f"第{attempt + 1}次尝试连接成功")
|
self.log.info(f"Connection established after {attempt + 1} attempts")
|
||||||
return conn
|
return conn
|
||||||
except Exception:
|
except Exception:
|
||||||
if attempt == max_retries - 1:
|
if attempt == max_retries - 1:
|
||||||
raise
|
raise
|
||||||
import time
|
import time
|
||||||
time.sleep(1) # 重试间隔1秒
|
time.sleep(1)
|
||||||
|
|
||||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
|
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
|
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
|
||||||
"""执行SQL查询并返回DataFrame(优化连接管理)"""
|
"""执行SQL查询并返回DataFrame(原有逻辑完全保留)"""
|
||||||
conn = None
|
|
||||||
try:
|
try:
|
||||||
self.log.debug("执行SQL查询", sql=sql)
|
self.log.debug("Executing SQL query", sql=sql)
|
||||||
|
|
||||||
|
# 获取连接并确保字符集方法存在
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
|
|
||||||
# 创建SQLAlchemy引擎(使用静态池避免连接重复创建)
|
# 创建SQLAlchemy引擎
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.pool import StaticPool
|
from sqlalchemy.pool import StaticPool
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
@@ -170,180 +124,361 @@ class MySQLAgent:
|
|||||||
|
|
||||||
# 执行查询
|
# 执行查询
|
||||||
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
|
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
|
||||||
self.log.info(f"查询成功,返回{len(df)}行数据")
|
self.log.info("Query executed successfully", rows=len(df))
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(f"SQL查询失败{sql}", sql=sql, params=params, error=str(e), exc_info=True)
|
self.log.error("SQL query failed", sql=sql, params=params, error=str(e), exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# 确保连接释放回池
|
if 'engine' in locals():
|
||||||
if conn:
|
engine.dispose()
|
||||||
try:
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("关闭连接失败", error=str(e))
|
|
||||||
|
|
||||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
chunk_size: int = 1000, replace: bool = False) -> int:
|
chunk_size: int = 1000, replace: bool = False, # 保留replace参数
|
||||||
"""将DataFrame数据插入到数据库表(优化批量处理)"""
|
ignore_duplicates: bool = None) -> int: # 新增ignore_duplicates参数
|
||||||
|
"""
|
||||||
|
兼容旧接口的通用插入方法:保留replace参数,同时支持新的ignore_duplicates
|
||||||
|
自动处理重复数据,对所有数据源通用
|
||||||
|
"""
|
||||||
|
# 【兼容性处理】如果未指定ignore_duplicates,用replace参数推导(replace=True时不忽略重复)
|
||||||
|
if ignore_duplicates is None:
|
||||||
|
ignore_duplicates = not replace # 旧逻辑中replace=True表示替换,即不忽略重复
|
||||||
|
|
||||||
if df.empty:
|
if df.empty:
|
||||||
self.log.warning(f"尝试插入空DataFrame到表{table_name}")
|
self.log.warning("Attempted to insert empty DataFrame", table=table_name)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
self.log.debug(f"准备插入DataFrame到表{table_name}", rows=len(df), chunk_size=chunk_size)
|
conn = None
|
||||||
|
cursor = None
|
||||||
# 根据平台自动调整批次大小
|
total_inserted = 0
|
||||||
current_platform = platform.system()
|
total_duplicated = 0
|
||||||
if current_platform == 'Windows' and chunk_size > 500:
|
total_failed = 0
|
||||||
chunk_size = 500
|
|
||||||
self.log.debug(f"Windows平台自动调整批次大小为{chunk_size}")
|
|
||||||
elif current_platform == 'Linux' and chunk_size < 1000:
|
|
||||||
chunk_size = 1000
|
|
||||||
self.log.debug(f"Linux平台自动调整批次大小为{chunk_size}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
method = 'replace' if replace else 'append'
|
# 1. 建立数据库连接
|
||||||
total_rows = 0
|
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
self.log.debug(f"Established connection for inserting into {table_name}")
|
||||||
|
|
||||||
# 创建SQLAlchemy引擎
|
# 2. 获取数据库表的实际列名
|
||||||
from sqlalchemy import create_engine
|
cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
|
||||||
from sqlalchemy.pool import StaticPool
|
columns_info = cursor.fetchall()
|
||||||
engine = create_engine(
|
db_columns = [col[0] for col in columns_info]
|
||||||
"mysql+pymysql://",
|
self.log.debug(f"Table {table_name} has columns: {db_columns}")
|
||||||
creator=lambda: conn,
|
|
||||||
poolclass=StaticPool,
|
# 3. 数据预处理:统一处理空值
|
||||||
connect_args={
|
cleaned_df = df.replace(
|
||||||
'charset': self.config.get('charset', 'utf8mb4'),
|
[None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
|
||||||
'autocommit': True
|
None
|
||||||
}
|
).copy()
|
||||||
|
|
||||||
|
# 4. 字段匹配:只保留与数据库匹配的列
|
||||||
|
df_columns = cleaned_df.columns.tolist()
|
||||||
|
matched_columns = [col for col in df_columns if col in db_columns]
|
||||||
|
unmatched_columns = [col for col in df_columns if col not in db_columns]
|
||||||
|
|
||||||
|
if unmatched_columns:
|
||||||
|
self.log.warning(
|
||||||
|
f"Table {table_name} dropping unmatched columns",
|
||||||
|
unmatched_columns=unmatched_columns,
|
||||||
|
count=len(unmatched_columns)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not matched_columns:
|
||||||
|
self.log.warning(f"No matched columns for {table_name}, abort insertion")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
filtered_df = cleaned_df[matched_columns].copy()
|
||||||
|
total_to_insert = len(filtered_df)
|
||||||
|
self.log.debug(
|
||||||
|
f"Filtered DataFrame for {table_name}: {total_to_insert} rows to insert"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 处理复杂类型(dict/list转JSON)
|
||||||
|
for col in filtered_df.columns:
|
||||||
|
has_complex_type = filtered_df[col].apply(
|
||||||
|
lambda x: isinstance(x, (dict, list)) if x is not None else False
|
||||||
|
).any()
|
||||||
|
|
||||||
|
if has_complex_type:
|
||||||
|
self.log.debug(f"Column {col} in {table_name} has complex type, converting to JSON")
|
||||||
|
filtered_df.loc[:, col] = filtered_df[col].apply(
|
||||||
|
lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 构建通用插入SQL
|
||||||
|
columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
|
||||||
|
placeholders = ', '.join(['%s'] * len(filtered_df.columns))
|
||||||
|
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
|
||||||
|
self.log.trace(f"Generated insert SQL for {table_name}: {insert_sql}")
|
||||||
|
|
||||||
|
# 7. 逐条插入(确保能捕获单条重复错误)
|
||||||
|
records = filtered_df.to_dict('records')
|
||||||
|
indices = filtered_df.index.tolist()
|
||||||
|
|
||||||
|
for i, (record, idx) in enumerate(zip(records, indices)):
|
||||||
try:
|
try:
|
||||||
for i in range(0, len(df), chunk_size):
|
data = tuple(record[col] for col in filtered_df.columns)
|
||||||
chunk = df.iloc[i:i + chunk_size].copy() # 使用copy避免SettingWithCopyWarning
|
cursor.execute(insert_sql, data)
|
||||||
|
total_inserted += 1
|
||||||
|
|
||||||
# macOS平台datetime特殊处理
|
if (i + 1) % 100 == 0:
|
||||||
if platform.system() == 'Darwin':
|
self.log.trace(
|
||||||
for col in chunk.select_dtypes(include=['datetime64']):
|
f"Inserted {i + 1}/{total_to_insert} rows into {table_name}"
|
||||||
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
|
|
||||||
chunk.to_sql(
|
|
||||||
table_name,
|
|
||||||
engine,
|
|
||||||
if_exists=method,
|
|
||||||
index=False,
|
|
||||||
method='multi'
|
|
||||||
)
|
)
|
||||||
total_rows += len(chunk)
|
|
||||||
method = 'append' # 首次后使用追加模式
|
|
||||||
self.log.trace(f"插入第{i // chunk_size + 1}批数据", rows=len(chunk), total=total_rows)
|
|
||||||
|
|
||||||
self.log.info(f"数据插入成功,表{table_name}共插入{total_rows}行")
|
except MySQLError as e:
|
||||||
return total_rows
|
# 8. 捕获重复错误(MySQL错误码1062)
|
||||||
finally:
|
if e.args[0] == 1062:
|
||||||
engine.dispose()
|
total_duplicated += 1
|
||||||
conn.close()
|
short_record = {
|
||||||
|
k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
|
||||||
|
for k, v in record.items()
|
||||||
|
}
|
||||||
|
self.log.warning(
|
||||||
|
f"Skipped duplicate record in {table_name}",
|
||||||
|
index=idx,
|
||||||
|
error_msg=e.args[1],
|
||||||
|
record=short_record
|
||||||
|
)
|
||||||
|
if not ignore_duplicates:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# 其他数据库错误
|
||||||
|
total_failed += 1
|
||||||
|
self.log.error(
|
||||||
|
f"Failed to insert record in {table_name}",
|
||||||
|
index=idx,
|
||||||
|
error_code=e.args[0],
|
||||||
|
error_msg=e.args[1],
|
||||||
|
record=record
|
||||||
|
)
|
||||||
|
if not ignore_duplicates:
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 提交事务
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
# 9. 插入结果统计
|
||||||
|
self.log.info(
|
||||||
|
f"Insertion summary for {table_name}",
|
||||||
|
total_to_insert=total_to_insert,
|
||||||
|
total_inserted=total_inserted,
|
||||||
|
total_duplicated=total_duplicated,
|
||||||
|
total_failed=total_failed
|
||||||
|
)
|
||||||
|
|
||||||
|
return total_inserted
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(f"数据插入失败,表{table_name}", error=str(e), exc_info=True)
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
self.log.error(f"Batch insertion failed for {table_name}", error=str(e), exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
|
||||||
|
"""【新增辅助方法】获取表的主键(用于replace逻辑的去重)"""
|
||||||
|
try:
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT COLUMN_NAME
|
||||||
|
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
|
||||||
|
WHERE TABLE_SCHEMA = %s
|
||||||
|
AND TABLE_NAME = %s
|
||||||
|
AND CONSTRAINT_NAME = 'PRIMARY'
|
||||||
|
""", (self.config['database'], table_name))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
return result[0] if result else None
|
||||||
|
except Exception as e:
|
||||||
|
self.log.warning(f"Failed to get primary key for {table_name}", error=str(e))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""获取表的详细结构信息(原有逻辑完全保留,供其他方法调用)"""
|
||||||
|
sql = """
|
||||||
|
SELECT column_name, data_type, character_maximum_length
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = %s \
|
||||||
|
AND table_name = %s \
|
||||||
|
"""
|
||||||
|
params = (self.config['database'], table_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
result = cursor.fetchall()
|
||||||
|
|
||||||
|
# 强制转换为列表,避免游标类型导致的解析问题
|
||||||
|
result_list = list(result)
|
||||||
|
if not result_list:
|
||||||
|
self.log.error("No columns found in table", table=table_name)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
schema = {}
|
||||||
|
for row in result_list:
|
||||||
|
# 确保正确提取字段名(兼容元组格式)
|
||||||
|
col_name = str(row[0]).strip() # 强制转为字符串并去空格
|
||||||
|
data_type = str(row[1]).strip()
|
||||||
|
max_length = row[2] if row[2] else None
|
||||||
|
|
||||||
|
schema[col_name] = {
|
||||||
|
'type': data_type,
|
||||||
|
'max_length': max_length
|
||||||
|
}
|
||||||
|
|
||||||
|
self.log.debug("Successfully fetched table schema",
|
||||||
|
table=table_name,
|
||||||
|
columns=list(schema.keys()))
|
||||||
|
return schema
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to get table detailed info",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
|
||||||
|
table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
|
||||||
|
"""数据校验与清洗(原有逻辑完全保留,供其他方法调用)"""
|
||||||
|
# 1. 字段过滤:只保留表中存在的字段
|
||||||
|
df_columns = df.columns.tolist()
|
||||||
|
table_columns = list(table_schema.keys())
|
||||||
|
|
||||||
|
valid_columns = [col for col in df_columns if col in table_columns]
|
||||||
|
invalid_columns = [col for col in df_columns if col not in table_columns]
|
||||||
|
|
||||||
|
if invalid_columns:
|
||||||
|
self.log.warning("Dropping invalid columns not present in table",
|
||||||
|
table=table_name,
|
||||||
|
invalid_columns=invalid_columns,
|
||||||
|
count=len(invalid_columns))
|
||||||
|
|
||||||
|
cleaned_df = df[valid_columns].copy()
|
||||||
|
if cleaned_df.empty:
|
||||||
|
return cleaned_df
|
||||||
|
|
||||||
|
# 2. 处理每个字段的数据
|
||||||
|
for col in valid_columns:
|
||||||
|
col_info = table_schema[col]
|
||||||
|
data_type = col_info['type']
|
||||||
|
max_length = col_info['max_length']
|
||||||
|
|
||||||
|
# 2.1 处理空值
|
||||||
|
if cleaned_df[col].isnull().any():
|
||||||
|
# 根据字段类型设置默认值
|
||||||
|
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
|
||||||
|
cleaned_df[col].fillna(default_value, inplace=True)
|
||||||
|
self.log.debug("Replaced null values",
|
||||||
|
table=table_name,
|
||||||
|
column=col,
|
||||||
|
default_value=default_value,
|
||||||
|
count=cleaned_df[col].isnull().sum())
|
||||||
|
|
||||||
|
# 2.2 处理字符串类型的超长字段
|
||||||
|
if data_type in ['varchar', 'char'] and max_length:
|
||||||
|
# 确保是字符串类型
|
||||||
|
cleaned_df[col] = cleaned_df[col].astype(str)
|
||||||
|
# 截断超长内容
|
||||||
|
too_long_mask = cleaned_df[col].str.len() > max_length
|
||||||
|
if too_long_mask.any():
|
||||||
|
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
|
||||||
|
self.log.warning("Truncated overlength values",
|
||||||
|
table=table_name,
|
||||||
|
column=col,
|
||||||
|
max_length=max_length,
|
||||||
|
count=too_long_mask.sum())
|
||||||
|
|
||||||
|
# 2.3 处理日期时间类型
|
||||||
|
if data_type in ['datetime', 'timestamp']:
|
||||||
|
try:
|
||||||
|
# 尝试转换为datetime类型
|
||||||
|
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
|
||||||
|
except Exception as e:
|
||||||
|
self.log.warning("Failed to convert to datetime, using current time",
|
||||||
|
table=table_name,
|
||||||
|
column=col,
|
||||||
|
error=str(e))
|
||||||
|
# 转换失败的用当前时间替代
|
||||||
|
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
|
||||||
|
cleaned_df.loc[invalid_mask, col] = datetime.now()
|
||||||
|
|
||||||
|
return cleaned_df
|
||||||
|
|
||||||
def update_from_df(self, table_name: str, df: pd.DataFrame,
|
def update_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
key_columns: Union[str, List[str]]) -> int:
|
key_columns: Union[str, List[str]]) -> int:
|
||||||
"""使用DataFrame数据更新数据库表(优化事务处理)"""
|
"""使用DataFrame数据更新数据库表(原有逻辑完全保留)"""
|
||||||
if df.empty:
|
if df.empty:
|
||||||
self.log.warning(f"尝试用空DataFrame更新表{table_name}")
|
self.log.warning("Attempted to update with empty DataFrame", table=table_name)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
self.log.debug(f"准备从DataFrame更新表{table_name}", key_columns=key_columns, rows=len(df))
|
self.log.debug("Preparing to update table from DataFrame",
|
||||||
|
table=table_name,
|
||||||
|
key_columns=key_columns,
|
||||||
|
rows=len(df))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(key_columns, str):
|
if isinstance(key_columns, str):
|
||||||
key_columns = [key_columns]
|
key_columns = [key_columns]
|
||||||
|
|
||||||
# 验证关键列存在性
|
|
||||||
missing_keys = [key for key in key_columns if key not in df.columns]
|
|
||||||
if missing_keys:
|
|
||||||
raise ValueError(f"DataFrame中缺少关键列: {missing_keys}")
|
|
||||||
|
|
||||||
total_updated = 0
|
total_updated = 0
|
||||||
conn = self.begin_transaction()
|
with self.get_connection() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
try:
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# 获取表结构信息
|
# 获取表结构信息
|
||||||
table_info = self._get_table_info(table_name)
|
table_info = self._get_table_detailed_info(table_name)
|
||||||
valid_columns = [col for col in df.columns if col in table_info]
|
columns = [col for col in df.columns if col in table_info]
|
||||||
if not valid_columns:
|
|
||||||
self.log.warning(f"DataFrame列与表{table_name}无匹配")
|
# 构建UPDATE语句模板
|
||||||
|
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns])
|
||||||
|
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
|
||||||
|
|
||||||
|
if not set_clause:
|
||||||
|
self.log.warning("No columns to update", table=table_name)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 构建UPDATE语句
|
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
|
||||||
set_clause = ', '.join([f"`{col}`=%s" for col in valid_columns if col not in key_columns])
|
self.log.trace("Generated update SQL", sql=update_sql)
|
||||||
where_clause = ' AND '.join([f"`{col}`=%s" for col in key_columns])
|
|
||||||
update_sql = f"UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}"
|
|
||||||
self.log.trace("生成更新SQL", sql=update_sql)
|
|
||||||
|
|
||||||
# 准备更新数据
|
# 准备数据
|
||||||
update_data = []
|
update_data = []
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
set_values = [row[col] for col in valid_columns if col not in key_columns]
|
set_values = [row[col] for col in columns if col not in key_columns]
|
||||||
key_values = [row[col] for col in key_columns]
|
key_values = [row[col] for col in key_columns]
|
||||||
update_data.append(tuple(set_values + key_values))
|
update_data.append(tuple(set_values + key_values))
|
||||||
|
|
||||||
# 执行批量更新
|
# 执行批量更新
|
||||||
cursor.executemany(update_sql, update_data)
|
cursor.executemany(update_sql, update_data)
|
||||||
total_updated = cursor.rowcount
|
total_updated = cursor.rowcount
|
||||||
self.commit_transaction(conn)
|
conn.commit()
|
||||||
self.log.info(f"数据更新成功,表{table_name}共更新{total_updated}行")
|
|
||||||
|
self.log.info("Data updated successfully",
|
||||||
|
table=table_name,
|
||||||
|
rows_updated=total_updated)
|
||||||
return total_updated
|
return total_updated
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.rollback_transaction(conn)
|
self.log.error("Data update failed",
|
||||||
raise
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
except Exception as e:
|
exc_info=True)
|
||||||
self.log.error(f"数据更新失败,表{table_name}", error=str(e), exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _get_table_info(self, table_name: str) -> Dict[str, str]:
|
|
||||||
"""获取表结构信息(优化SQL安全性)"""
|
|
||||||
sql = """
|
|
||||||
SELECT column_name, data_type
|
|
||||||
FROM information_schema.columns
|
|
||||||
WHERE table_schema = %s \
|
|
||||||
AND table_name = %s \
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
with self.get_connection() as conn:
|
|
||||||
with conn.cursor() as cursor:
|
|
||||||
cursor.execute(sql, (self.config['database'], table_name))
|
|
||||||
result = cursor.fetchall()
|
|
||||||
return {row['column_name']: row['data_type'] for row in result}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error(f"获取表{table_name}结构失败", error=str(e))
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
|
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
|
||||||
"""推断DataFrame各列的SQL类型(扩展类型映射)"""
|
"""推断DataFrame各列的SQL类型(原有逻辑完全保留)"""
|
||||||
type_mapping = {
|
type_mapping = {
|
||||||
'int64': 'BIGINT',
|
'int64': 'BIGINT',
|
||||||
'int32': 'INT',
|
|
||||||
'int16': 'SMALLINT',
|
|
||||||
'int8': 'TINYINT',
|
|
||||||
'uint64': 'BIGINT UNSIGNED',
|
|
||||||
'float64': 'DOUBLE',
|
'float64': 'DOUBLE',
|
||||||
'float32': 'FLOAT',
|
|
||||||
'datetime64[ns]': 'DATETIME',
|
'datetime64[ns]': 'DATETIME',
|
||||||
'datetime64[ns, UTC]': 'DATETIME',
|
|
||||||
'timedelta64[ns]': 'TIME',
|
|
||||||
'object': 'TEXT',
|
'object': 'TEXT',
|
||||||
'string': 'VARCHAR(255)',
|
|
||||||
'bool': 'TINYINT(1)',
|
'bool': 'TINYINT(1)',
|
||||||
'category': 'VARCHAR(255)'
|
'category': 'VARCHAR(255)'
|
||||||
}
|
}
|
||||||
@@ -353,136 +488,84 @@ class MySQLAgent:
|
|||||||
dtype_str = str(dtype)
|
dtype_str = str(dtype)
|
||||||
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
|
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
|
||||||
|
|
||||||
self.log.debug("DataFrame类型映射为SQL类型", mappings=sql_types)
|
self.log.debug("Mapped DataFrame types to SQL types",
|
||||||
|
mappings=sql_types)
|
||||||
return sql_types
|
return sql_types
|
||||||
|
|
||||||
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
|
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
|
||||||
primary_key: Union[str, List[str], None] = None) -> bool:
|
primary_key: Union[str, List[str], None] = None) -> bool:
|
||||||
"""根据DataFrame结构创建表(增强表结构定义)"""
|
"""根据DataFrame结构创建表(原有逻辑完全保留)"""
|
||||||
if self.table_exists(table_name):
|
if self.table_exists(table_name):
|
||||||
self.log.warning(f"表{table_name}已存在")
|
self.log.warning("Table already exists", table=table_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.log.debug(f"根据DataFrame结构创建表{table_name}", columns=list(df.columns))
|
self.log.debug("Creating new table from DataFrame schema",
|
||||||
|
table=table_name,
|
||||||
|
columns=list(df.columns))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sql_types = self.df_to_sql_type(df)
|
sql_types = self.df_to_sql_type(df)
|
||||||
columns_sql = []
|
columns_sql = []
|
||||||
|
|
||||||
for col, sql_type in sql_types.items():
|
for col, sql_type in sql_types.items():
|
||||||
# 特殊字段处理
|
col_def = f"{col} {sql_type}"
|
||||||
if col.lower() in ['create_time', 'created_at'] and sql_type != 'DATETIME':
|
|
||||||
col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP"
|
|
||||||
elif col.lower() in ['update_time', 'updated_at'] and sql_type != 'DATETIME':
|
|
||||||
col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
|
|
||||||
else:
|
|
||||||
col_def = f"`{col}` {sql_type}"
|
|
||||||
columns_sql.append(col_def)
|
columns_sql.append(col_def)
|
||||||
|
|
||||||
# 处理主键
|
|
||||||
if primary_key:
|
if primary_key:
|
||||||
if isinstance(primary_key, str):
|
if isinstance(primary_key, str):
|
||||||
primary_key = [primary_key]
|
primary_key = [primary_key]
|
||||||
pk_columns = [f"`{col}`" for col in primary_key if col in sql_types]
|
pk_columns = [col for col in primary_key if col in sql_types]
|
||||||
if pk_columns:
|
if pk_columns:
|
||||||
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
|
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
|
||||||
self.log.trace(f"表{table_name}设置主键", primary_key=pk_columns)
|
self.log.trace("Set primary key",
|
||||||
|
table=table_name,
|
||||||
|
primary_key=pk_columns)
|
||||||
|
|
||||||
|
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
|
||||||
|
|
||||||
create_sql = f"CREATE TABLE `{table_name}` (\n {',\n '.join(columns_sql)}\n)"
|
|
||||||
self.execute_sql(create_sql)
|
self.execute_sql(create_sql)
|
||||||
self.log.info(f"表{table_name}创建成功")
|
self.log.info("Table created successfully", table=table_name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(f"表{table_name}创建失败", error=str(e), exc_info=True)
|
self.log.error("Failed to create table",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||||
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||||
"""执行SQL语句(增强资源管理)"""
|
"""执行SQL语句(原有逻辑完全保留)"""
|
||||||
conn = None
|
|
||||||
cursor = None
|
|
||||||
try:
|
try:
|
||||||
conn = self.get_connection()
|
with self.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
with conn.cursor() as cursor:
|
||||||
|
# Linux/macOS需要更长的执行时间
|
||||||
# 非Windows平台延长执行超时
|
|
||||||
if platform.system() != 'Windows':
|
if platform.system() != 'Windows':
|
||||||
cursor.execute("SET SESSION max_execution_time=600000") # 10分钟
|
cursor.execute("SET SESSION max_execution_time=600000")
|
||||||
|
|
||||||
cursor.execute(sql, params)
|
cursor.execute(sql, params)
|
||||||
|
|
||||||
if fetch:
|
if fetch:
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
self.log.debug(f"查询执行完成,返回{len(result)}行")
|
self.log.debug("Query executed", rows=len(result))
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
affected_rows = cursor.rowcount
|
affected_rows = cursor.rowcount
|
||||||
self.log.debug(f"更新执行完成,影响{affected_rows}行")
|
conn.commit() # 立即提交
|
||||||
|
self.log.debug("Update executed", affected_rows=affected_rows)
|
||||||
return affected_rows
|
return affected_rows
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error("SQL执行失败", sql=sql, params=params, error=str(e), exc_info=True)
|
self.log.error("SQL execution failed",
|
||||||
|
sql=sql,
|
||||||
|
params=params,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
if cursor:
|
|
||||||
try:
|
|
||||||
cursor.close()
|
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("关闭游标失败", error=str(e))
|
|
||||||
if conn:
|
|
||||||
try:
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("关闭连接失败", error=str(e))
|
|
||||||
|
|
||||||
def begin_transaction(self) -> pymysql.connections.Connection:
|
|
||||||
"""开始事务(增强隔离级别处理)"""
|
|
||||||
try:
|
|
||||||
conn = self.get_connection()
|
|
||||||
conn.autocommit(False)
|
|
||||||
|
|
||||||
# 平台特定事务配置
|
|
||||||
if platform.system() == 'Darwin':
|
|
||||||
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
|
|
||||||
elif platform.system() == 'Linux':
|
|
||||||
conn.cursor().execute("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
|
|
||||||
|
|
||||||
self.log.debug("事务开始")
|
|
||||||
return conn
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error("事务开始失败", error=str(e))
|
|
||||||
raise
|
|
||||||
|
|
||||||
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
|
|
||||||
"""提交事务"""
|
|
||||||
try:
|
|
||||||
conn.commit()
|
|
||||||
self.log.debug("事务提交成功")
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error("事务提交失败", error=str(e))
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("事务提交后关闭连接失败", error=str(e))
|
|
||||||
|
|
||||||
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
|
|
||||||
"""回滚事务"""
|
|
||||||
try:
|
|
||||||
conn.rollback()
|
|
||||||
self.log.warning("事务已回滚")
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error("事务回滚失败", error=str(e))
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("事务回滚后关闭连接失败", error=str(e))
|
|
||||||
|
|
||||||
def table_exists(self, table_name: str) -> bool:
|
def table_exists(self, table_name: str) -> bool:
|
||||||
"""检查表是否存在(优化SQL安全性)"""
|
"""检查表是否存在(原有逻辑完全保留)"""
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(*) as count
|
SELECT COUNT(*) as count
|
||||||
FROM `information_schema`.`tables`
|
FROM `information_schema`.`tables`
|
||||||
@@ -490,64 +573,49 @@ class MySQLAgent:
|
|||||||
AND `table_name` = %s \
|
AND `table_name` = %s \
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
params = (self.config['database'], table_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.execute_sql(sql, (self.config['database'], table_name), fetch=True)
|
result = self.execute_sql(sql, params, fetch=True)
|
||||||
exists = result[0]['count'] > 0
|
exists = result[0][0] > 0 # 适配元组结果
|
||||||
self.log.debug(f"表{table_name}存在性检查", exists=exists)
|
self.log.debug("Checked table existence",
|
||||||
|
table=table_name,
|
||||||
|
exists=exists)
|
||||||
return exists
|
return exists
|
||||||
except Exception as e:
|
except Exception:
|
||||||
self.log.warning(f"表{table_name}存在性检查失败", error=str(e))
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def drop_table(self, table_name: str) -> bool:
|
def drop_table(self, table_name: str) -> bool:
|
||||||
"""删除表(增加二次确认日志)"""
|
"""删除表(原有逻辑完全保留)"""
|
||||||
if not self.table_exists(table_name):
|
if not self.table_exists(table_name):
|
||||||
self.log.warning(f"表{table_name}不存在,无法删除")
|
self.log.warning("Table does not exist", table=table_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.execute_sql(f"DROP TABLE `{table_name}`")
|
self.execute_sql(f"DROP TABLE {table_name}")
|
||||||
self.log.info(f"表{table_name}删除成功")
|
self.log.info("Table dropped successfully", table=table_name)
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.log.error(f"表{table_name}删除失败", error=str(e), exc_info=True)
|
self.log.error("Failed to drop table",
|
||||||
|
table=table_name,
|
||||||
|
error=str(e),
|
||||||
|
exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_pool_status(self) -> Dict[str, int]:
|
|
||||||
"""获取连接池状态"""
|
|
||||||
status = {
|
|
||||||
'max_connections': self._pool._maxconnections,
|
|
||||||
'active_connections': len(self._pool._connections),
|
|
||||||
'idle_connections': len(self._pool._idle_cache),
|
|
||||||
'shared_connections': len(self._pool._shared_cache)
|
|
||||||
}
|
|
||||||
self.log.debug("连接池状态", **status)
|
|
||||||
return status
|
|
||||||
|
|
||||||
def validate_connection(self) -> bool:
|
def validate_connection(self) -> bool:
|
||||||
"""验证连接是否有效(增强健康检查)"""
|
"""验证连接是否有效(原有逻辑完全保留)"""
|
||||||
try:
|
try:
|
||||||
with self.get_connection() as conn:
|
with self.get_connection() as conn:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
cursor.execute("SELECT 1 AS health_check")
|
cursor.execute("SELECT 1")
|
||||||
result = cursor.fetchone()
|
return cursor.fetchone()[0] == 1
|
||||||
return result['health_check'] == 1
|
except Exception:
|
||||||
except Exception as e:
|
|
||||||
self.log.warning("连接健康检查失败", error=str(e))
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""析构函数(确保连接池关闭)"""
|
|
||||||
if hasattr(self, '_pool') and self._pool:
|
|
||||||
try:
|
|
||||||
self._pool.close()
|
|
||||||
self.log.info("连接池已关闭")
|
|
||||||
except Exception as e:
|
|
||||||
self.log.error("连接池关闭失败", error=str(e))
|
|
||||||
|
|
||||||
|
|
||||||
|
# 平台特定的默认配置(原有逻辑完全保留)
|
||||||
def get_default_config():
|
def get_default_config():
|
||||||
"""获取各平台默认配置(优化默认参数)"""
|
"""获取各平台默认配置"""
|
||||||
current_platform = platform.system()
|
current_platform = platform.system()
|
||||||
|
|
||||||
base_config = {
|
base_config = {
|
||||||
@@ -555,76 +623,41 @@ def get_default_config():
|
|||||||
'port': 3306,
|
'port': 3306,
|
||||||
'user': 'root',
|
'user': 'root',
|
||||||
'password': '123123',
|
'password': '123123',
|
||||||
'database': 'intelligence',
|
'database': 'intelligence_system',
|
||||||
'max_connections': 10, # 增加默认连接数
|
|
||||||
'charset': 'utf8mb4'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if current_platform == 'Windows':
|
if current_platform == 'Windows':
|
||||||
return {
|
return {**base_config,
|
||||||
**base_config,
|
|
||||||
'connect_timeout': 10,
|
'connect_timeout': 10,
|
||||||
'read_timeout': 30,
|
'read_timeout': 30,
|
||||||
'write_timeout': 30,
|
'write_timeout': 30
|
||||||
'ssl': None # Windows默认禁用SSL
|
|
||||||
}
|
}
|
||||||
elif current_platform == 'Darwin': # macOS
|
elif current_platform == 'Darwin':
|
||||||
return {
|
return {
|
||||||
**base_config,
|
**base_config,
|
||||||
'connect_timeout': 15,
|
'connect_timeout': 15,
|
||||||
'read_timeout': 60,
|
'read_timeout': 60,
|
||||||
'write_timeout': 60,
|
'write_timeout': 60,
|
||||||
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} # macOS默认SSL配置
|
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||||
}
|
}
|
||||||
else: # Linux及其他平台
|
else: # Linux和其他平台
|
||||||
return {
|
return {**base_config,
|
||||||
**base_config,
|
|
||||||
'connect_timeout': 15,
|
'connect_timeout': 15,
|
||||||
'read_timeout': 60,
|
'read_timeout': 60,
|
||||||
'write_timeout': 60,
|
'write_timeout': 60
|
||||||
'ssl': None # Linux默认禁用SSL
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 使用示例
|
# 使用示例(原有逻辑完全保留)
|
||||||
try:
|
|
||||||
db = MySQLAgent(get_default_config())
|
db = MySQLAgent(get_default_config())
|
||||||
|
|
||||||
# 测试连接
|
# 测试连接
|
||||||
if db.validate_connection():
|
if db.validate_connection():
|
||||||
print("数据库连接成功")
|
print("Database connection successful")
|
||||||
|
|
||||||
# 获取数据库版本
|
# 获取数据库版本
|
||||||
version_df = db.query_to_df("SELECT VERSION() as version")
|
version = db.query_to_df("SELECT VERSION() as version")
|
||||||
print(f"数据库版本: {version_df['version'].iloc[0]}")
|
print(f"Database version: {version['version'].iloc[0]}")
|
||||||
|
|
||||||
# 查看连接池状态
|
|
||||||
print("连接池状态:", db.get_pool_status())
|
|
||||||
|
|
||||||
# 创建测试表
|
|
||||||
test_df = pd.DataFrame({
|
|
||||||
'id': [1, 2, 3],
|
|
||||||
'name': ['测试1', '测试2', '测试3'],
|
|
||||||
'value': [10.5, 20.3, 30.8],
|
|
||||||
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
|
|
||||||
})
|
|
||||||
db.create_table_from_df('test_table', test_df, primary_key='id')
|
|
||||||
print("测试表创建成功")
|
|
||||||
|
|
||||||
# 插入数据
|
|
||||||
rows_inserted = db.insert_from_df('test_table', test_df)
|
|
||||||
print(f"插入了{rows_inserted}行数据")
|
|
||||||
|
|
||||||
# 查询数据
|
|
||||||
result_df = db.query_to_df("SELECT * FROM test_table")
|
|
||||||
print("查询结果:")
|
|
||||||
print(result_df)
|
|
||||||
|
|
||||||
# 清理测试表
|
|
||||||
db.drop_table('test_table')
|
|
||||||
print("测试表已删除")
|
|
||||||
else:
|
else:
|
||||||
print("数据库连接失败")
|
print("Failed to connect to database")
|
||||||
except Exception as e:
|
|
||||||
print(f"示例执行失败: {str(e)}")
|
|
||||||
Reference in New Issue
Block a user