minio对象存储数据库链接

This commit is contained in:
z66
2025-09-16 17:35:53 +08:00
parent 8e92acf5d5
commit 9afa9d2e58
10 changed files with 7291 additions and 347 deletions
+1
View File
@@ -2,6 +2,7 @@
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/tools/SQL.sql" dialect="MySQL" /> <file url="file://$PROJECT_DIR$/tools/SQL.sql" dialect="MySQL" />
<file url="file://$PROJECT_DIR$/tools/情报收集.sql" dialect="MySQL" />
<file url="PROJECT" dialect="MySQL" /> <file url="PROJECT" dialect="MySQL" />
</component> </component>
<component name="SqlResolveMappings"> <component name="SqlResolveMappings">
+277
View File
@@ -0,0 +1,277 @@
import feedparser
import requests
from datetime import datetime
import pandas as pd
import os
import pickle
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import pymysql
# Database connection settings handed to pymysql.connect().
# NOTE(review): the credentials are hard-coded in source; consider loading
# them from the environment or from a config file kept out of version control.
local_DB_Config = {
'host': "localhost",
'user': "root",
'password': "123123",
'database': "intelligence_system",
'charset': 'utf8mb4'
}
# Name of the table that receives the collected RSS entries.
table_name = "collector_rss_subscriptions"
def verify_database():
    """Check that the database is reachable and the target table is usable.

    Three things are verified in order: the table exists, its columns can be
    listed (they are printed), and a probe row can be inserted. The probe
    insert is rolled back so no test data is meant to remain.

    Returns:
        True when every check passed, False when the table is missing or any
        exception occurred (the exception is printed, not raised).
    """
    connection = None
    try:
        connection = pymysql.connect(**local_DB_Config)
        with connection.cursor() as cur:
            # Step 1: the table has to exist.
            cur.execute(f"SHOW TABLES LIKE '{table_name}'")
            table_row = cur.fetchone()
            if not table_row:
                print(f"错误: 表 {table_name} 不存在!")
                return False

            # Step 2: list the column names (first field of each DESCRIBE row).
            cur.execute(f"DESCRIBE {table_name}")
            column_names = [field[0] for field in cur.fetchall()]
            print("表列名:", column_names)

            # Step 3: make sure we are allowed to insert, then undo the insert.
            probe_sql = (
                f"INSERT INTO `{table_name}`\n"
                "(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)\n"
                "VALUES (%s, %s, %s, %s, %s)"
            )
            probe_row = ('测试标题', 'http://test.com', '测试内容', datetime.now(), '测试来源')
            cur.execute(probe_sql, probe_row)
            connection.rollback()

        print("数据库验证通过!")
        return True
    except Exception as exc:
        print("数据库验证失败:", exc)
        return False
    finally:
        # The connection only exists if pymysql.connect() succeeded.
        if connection is not None:
            connection.close()
def load_last_update_time():
    """Load the timestamp stored by the previous run, if there is one.

    The value is read from ``output/last_update.pkl`` below the current
    working directory.

    Returns:
        The cached ``datetime``, or ``None`` when the cache file is missing,
        unreadable, corrupt, or does not hold a ``datetime``. Returning
        ``None`` makes the caller treat the run as a first run instead of
        crashing on a damaged cache.
    """
    cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
    try:
        with open(cache_file, 'rb') as f:
            # NOTE: pickle executes arbitrary code on load; this is only
            # acceptable because the file is a local cache written by this
            # script. Never point it at a file from an untrusted source.
            last_update = pickle.load(f)
    except FileNotFoundError:
        return None
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError) as e:
        # Empty/truncated/corrupt cache: fall back to "no previous run".
        print(f"警告: 无法读取上次更新时间缓存, 将忽略该缓存: {e}")
        return None

    # Callers compare the value with datetimes and call .strftime() on it.
    if not isinstance(last_update, datetime):
        print(f"警告: 上次更新时间缓存内容无效, 将忽略该缓存: {last_update!r}")
        return None
    return last_update
def save_last_update_time(last_update):
    """Persist the timestamp of this run for the next run to pick up.

    The value is pickled to ``output/last_update.pkl`` below the current
    working directory; the directory is created when missing. The data is
    first written to a temporary file and then moved into place with
    ``os.replace`` so an interrupted write cannot leave a truncated cache.

    Args:
        last_update: The value to store (the script passes a ``datetime``).

    Raises:
        OSError: If the directory or file cannot be written.
    """
    cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    tmp_file = cache_file + '.tmp'
    try:
        with open(tmp_file, 'wb') as f:
            pickle.dump(last_update, f)
        # Swap the finished file in; the old cache stays intact until here.
        os.replace(tmp_file, cache_file)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        try:
            os.remove(tmp_file)
        except OSError:
            pass
        raise
def fetch_single_rss(url, timeout=15, max_attempts=3, backoff=5):
    """Download one RSS feed and parse it with feedparser.

    Args:
        url: Address of the feed.
        timeout: Timeout in seconds passed to ``requests.get``.
        max_attempts: How many times the download is tried before giving up
            (defaults to the previously hard-coded 3).
        backoff: Base delay in seconds; after failed attempt ``k`` (1-based)
            the function sleeps ``backoff * k`` seconds before retrying
            (defaults to the previously hard-coded 5).

    Returns:
        The object returned by ``feedparser.parse``, or ``None`` when every
        attempt failed with a ``requests.RequestException``.
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    for attempt in range(max_attempts):
        try:
            response = requests.get(url, headers=headers, timeout=timeout)
            response.raise_for_status()
            # Decode with the encoding guessed from the body rather than the
            # one announced by the HTTP headers.
            response.encoding = response.apparent_encoding
            feed = feedparser.parse(response.text)
            if feed.bozo:
                # feedparser flags malformed feeds but still returns a result.
                print(f"警告: 解析可能存在问题: {feed.bozo_exception}")
            return feed
        except requests.RequestException as e:
            print(f"{attempt + 1} 次尝试获取 {url} 失败: {e}")
            # Wait a little longer after each failure; no wait after the last.
            if attempt < max_attempts - 1:
                time.sleep(backoff * (attempt + 1))
    return None
def fetch_all_rss(urls, timeout=15):
    """Fetch several RSS feeds concurrently with a small thread pool.

    Args:
        urls: Iterable of feed addresses.
        timeout: Per-request timeout in seconds, forwarded to
            ``fetch_single_rss``.

    Returns:
        A dict mapping each URL to its parsed feed. URLs whose download
        returned a falsy result or raised are left out (the error is printed).
    """
    results = {}
    with ThreadPoolExecutor(max_workers=3) as pool:
        # Remember which URL every submitted job belongs to.
        pending = {}
        for address in urls:
            job = pool.submit(fetch_single_rss, address, timeout)
            pending[job] = address

        # Handle jobs in completion order, not submission order.
        for job in as_completed(pending):
            source = pending[job]
            try:
                parsed = job.result()
            except Exception as exc:
                print(f"获取 {source} 时发生异常: {exc}")
                continue
            if parsed:
                results[source] = parsed
    return results
def process_feed_entry(entry, url):
    """Convert one feed entry into the row layout used by the database.

    Args:
        entry: Mapping-like feed entry (anything offering ``.get``), as
            produced by feedparser.
        url: Address of the feed the entry came from; may be ``None``.

    Returns:
        A dict with the keys ``文章标题``, ``文章链接``, ``文章摘要``,
        ``发布时间`` (formatted ``%Y-%m-%d %H:%M:%S``) and ``来源URL``.
        Title is capped at 255 characters, link and source URL at 1024.
    """
    # Title: fall back to a placeholder and cap the length.
    title = entry.get('title') or '无标题'
    if len(title) > 255:
        title = title[:252] + '...'

    # Link: same treatment with a larger cap.
    link = entry.get('link') or '无链接'
    if len(link) > 1024:
        link = link[:1021] + '...'

    # Summary: prefer the summary; when it is missing *or empty* fall back to
    # the first 200 characters of the first content item.
    summary = entry.get('summary') or ''
    if summary:
        description = summary
    else:
        content_list = entry.get('content') or []
        content = (content_list[0].get('value') or '') if content_list else ''
        description = content[:200] + '...' if content else '无内容摘要'

    # Publication time: use the pre-parsed time tuple when present.
    # NOTE(review): feedparser documents these tuples as UTC while the
    # last-resort fallback below uses local time — confirm which one the
    # database column is meant to hold.
    entry_time = None
    published_parsed = entry.get('published_parsed') or entry.get('updated_parsed')
    if published_parsed:
        try:
            entry_time = datetime(*published_parsed[:6])
        except (TypeError, ValueError):
            # e.g. an out-of-range field; try the raw string instead.
            entry_time = None
    if entry_time is None:
        pub_str = entry.get('published') or entry.get('updated') or ''
        try:
            # The UTC offset is parsed but dropped by the strftime below.
            entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z')
        except (TypeError, ValueError):
            entry_time = datetime.now()

    # Source URL: placeholder when unknown, capped like the link.
    source_url = url or '未知来源'
    if len(source_url) > 1024:
        source_url = source_url[:1021] + '...'

    return {
        '文章标题': title,
        '文章链接': link,
        '文章摘要': description,
        '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
        '来源URL': source_url
    }
def display_feed_info(feed, last_update=None, url=None):
    """Print the new entries of one feed and hand them to the database writer.

    Args:
        feed: Parsed feed exposing an ``entries`` attribute.
        last_update: Entries published at or before this time are skipped;
            ``None`` means every entry counts as new.
        url: Address of the feed, used for printing and as the source URL.

    Returns:
        The latest publication time among ``last_update`` and the new
        entries, or ``None`` when ``feed`` is falsy.
    """
    if not feed:
        print("无法显示信息:feed 为 None")
        return None

    print("=" * 80)
    print(f"处理 RSS 源: {url}")

    fresh_rows = []
    newest_seen = last_update
    for position, item in enumerate(feed.entries, 1):
        record = process_feed_entry(item, url)
        published_at = datetime.strptime(record['发布时间'], '%Y-%m-%d %H:%M:%S')

        # Skip everything that an earlier run has already covered.
        already_covered = last_update and published_at <= last_update
        if already_covered:
            continue

        # Track the most recent publication time encountered so far.
        if newest_seen is None or published_at > newest_seen:
            newest_seen = published_at

        print(f"\n--- 条目 {position} ---")
        print(f"标题: {record['文章标题']}")
        print(f"链接: {record['文章链接']}")
        print(f"摘要: {record['文章摘要'][:100]}...")
        print(f"时间: {record['发布时间']}")
        fresh_rows.append(record)

    # Only touch the database when there is something new.
    if fresh_rows:
        write_to_database(pd.DataFrame(fresh_rows))
    return newest_seen
def write_to_database(df):
    """Insert the rows of a DataFrame into the subscriptions table.

    Rows are inserted one at a time with ``INSERT IGNORE`` so a single bad
    row is reported and skipped instead of aborting the batch. All errors are
    printed, none are raised.

    Args:
        df: DataFrame with the columns ``文章标题``, ``文章链接``,
            ``文章摘要``, ``发布时间`` and ``来源URL``.
    """
    if df.empty:
        print("没有新数据需要写入")
        return

    print("\n准备写入数据库的数据样例:")
    print(df.iloc[0].to_dict())

    insert_sql = (
        f"INSERT IGNORE INTO `{table_name}`\n"
        "(`文章标题`, `文章链接`, `文章摘要`, `发布时间`, `来源URL`)\n"
        "VALUES (%s, %s, %s, %s, %s)"
    )
    # Column order must match the placeholder order in the statement above.
    column_order = ('文章标题', '文章链接', '文章摘要', '发布时间', '来源URL')

    connection = None
    try:
        connection = pymysql.connect(**local_DB_Config)
        with connection.cursor() as cur:
            written = 0
            for _, record in df.iterrows():
                values = tuple(record[name] for name in column_order)
                try:
                    cur.execute(insert_sql, values)
                except Exception as row_error:
                    print(f"插入记录时出错: {row_error}")
                    print(f"问题数据: {record.to_dict()}")
                else:
                    # rowcount is the number of rows this statement affected.
                    written += cur.rowcount
            connection.commit()
            print(f"成功写入 {written}/{len(df)} 条记录")
    except Exception as db_error:
        print("数据库操作失败:", db_error)
    finally:
        # The connection only exists if pymysql.connect() succeeded.
        if connection is not None:
            connection.close()
def main():
    """Run one collection pass over the configured RSS feeds.

    Steps: verify the database, load the previous watermark, fetch all feeds
    concurrently, store the new entries of every feed, then save the newest
    publication time as the watermark for the next run.
    """
    database_ok = verify_database()
    if not database_ok:
        print("数据库验证失败,程序终止")
        return

    rss_urls = [
        "https://www.chinanews.com.cn/rss/finance.xml",
        "https://www.chinanews.com.cn/rss/world.xml",
        "https://www.chinanews.com.cn/rss/china.xml",
        "https://www.chinanews.com.cn/rss/scroll-news.xml"
    ]

    last_update = load_last_update_time()
    if last_update:
        print(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

    print("\n开始获取RSS源数据...")
    started = time.time()
    feeds = fetch_all_rss(rss_urls)
    elapsed = time.time() - started
    print(f"获取完成,耗时: {elapsed:.2f}")

    # Keep the latest time reported by any of the feeds.
    latest = None
    for source, parsed in feeds.items():
        candidate = display_feed_info(parsed, last_update, source)
        if not candidate:
            continue
        if latest is None or candidate > latest:
            latest = candidate

    if not latest:
        print("\n没有获取到新的内容")
        return
    save_last_update_time(latest)
    print(f"\n本次最新更新时间: {latest.strftime('%Y-%m-%d %H:%M:%S')}")


if __name__ == "__main__":
    main()
+1
View File
@@ -5,6 +5,7 @@ class Config:
'port': 3306, 'port': 3306,
'user': 'root', 'user': 'root',
'password': '123123', 'password': '123123',
'database':"intelligence_system",
'max_connections': 10 'max_connections': 10
} }
+4846
View File
File diff suppressed because it is too large Load Diff
+1747
View File
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -3,6 +3,8 @@ import time
from datetime import datetime from datetime import datetime
from system_management.scheduler.task_scheduler import TaskScheduler from system_management.scheduler.task_scheduler import TaskScheduler
from utils.logger import CrossPlatformLog from utils.logger import CrossPlatformLog
from config import Config
# 初始化日志 # 初始化日志
log = CrossPlatformLog.get_logger("Main") log = CrossPlatformLog.get_logger("Main")
@@ -11,7 +13,7 @@ log = CrossPlatformLog.get_logger("Main")
class IntelligenceSystem: class IntelligenceSystem:
def __init__(self, db_config=None): def __init__(self, db_config=None):
"""初始化系统(仅作为容器,不包含业务逻辑)""" """初始化系统(仅作为容器,不包含业务逻辑)"""
self.scheduler = TaskScheduler(db_config, max_workers=5) self.scheduler = TaskScheduler(Config.MYSQL_CONFIG, max_workers=5)
self._running = False self._running = False
log.info("情报系统已初始化(Cron模式)") log.info("情报系统已初始化(Cron模式)")
+147 -53
View File
@@ -6,12 +6,14 @@ import croniter
import pytz import pytz
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd import pandas as pd
from sqlalchemy.exc import SQLAlchemyError
from utils.mysql_agent import MySQLAgent from utils.mysql_agent import MySQLAgent
from utils.logger import CrossPlatformLog from utils.logger import CrossPlatformLog
# 初始化调度器日志 # 初始化调度器日志
log = CrossPlatformLog.get_logger("TaskScheduler") log = CrossPlatformLog.get_logger("TaskScheduler")
class TaskScheduler: class TaskScheduler:
def __init__(self, db_config: Optional[Dict] = None, max_workers: int = 5): def __init__(self, db_config: Optional[Dict] = None, max_workers: int = 5):
"""初始化任务调度器(基于Cron表达式)""" """初始化任务调度器(基于Cron表达式)"""
@@ -20,31 +22,37 @@ class TaskScheduler:
log.info(f"任务调度器已初始化,最大工作线程数: {max_workers}") log.info(f"任务调度器已初始化,最大工作线程数: {max_workers}")
def check_and_run_tasks(self) -> Dict[str, int]: def check_and_run_tasks(self) -> Dict[str, int]:
"""检查并执行所有到期的任务""" """检查并执行所有到期的任务,优化空任务处理和异常容错"""
result = {'总任务数': 0, '成功': 0, '失败': 0} result = {'总任务数': 0, '成功': 0, '失败': 0}
try: try:
# 获取当前时间(带时区) # 获取当前时间(带时区转换为本地时间
now = datetime.now(pytz.timezone('Asia/Shanghai')).replace(tzinfo=None) tz = pytz.timezone('Asia/Shanghai')
now = datetime.now(tz).replace(tzinfo=None) # 移除时区信息,与数据库存储一致
log.debug(f"当前检查时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
# 查询所有到期的活跃任务 # 查询所有到期的活跃任务(使用参数化查询防止注入)
tasks_df = self.db.query_to_df(""" tasks_df = self.db.query_to_df("""
SELECT * FROM main_task SELECT *
WHERE is_active = 1 FROM main_task
AND next_run_time <= %s WHERE is_active = 1
AND is_running = 0 AND next_run_time <= %s
ORDER BY next_run_time AND is_running = 0
""", params=(now,)) ORDER BY next_run_time
""", params=(now,))
result['总任务数'] = len(tasks_df) result['总任务数'] = len(tasks_df)
if tasks_df.empty: if tasks_df.empty:
log.debug("没有到期的任务需要执行") # 空任务时输出INFO级日志,明确提示状态
log.info("当前没有到期的任务,等待新任务加入...")
return result return result
# 并发执行任务 # 并发执行任务
futures = [] futures = []
for _, task in tasks_df.iterrows(): for _, task in tasks_df.iterrows():
futures.append(self.executor.submit(self._process_single_task, task)) # 传递任务字典的副本避免线程安全问题
task_copy = task.to_dict()
futures.append(self.executor.submit(self._process_single_task, task_copy))
# 收集执行结果 # 收集执行结果
for future in as_completed(futures): for future in as_completed(futures):
@@ -54,7 +62,7 @@ class TaskScheduler:
else: else:
result['失败'] += 1 result['失败'] += 1
except Exception as e: except Exception as e:
log.error(f"任务线程执行失败: {str(e)}") log.error(f"任务线程执行失败: {str(e)}", exc_info=True)
result['失败'] += 1 result['失败'] += 1
log.info( log.info(
@@ -65,21 +73,28 @@ class TaskScheduler:
) )
return result return result
except SQLAlchemyError as e: # 数据库异常处理优化
log.error(f"数据库操作失败,将在下次轮询重试: {str(e)}", exc_info=True)
return result # 不中断,返回当前结果
except Exception as e: except Exception as e:
log.critical("调度器主循环执行失败", exc_info=True) log.error("调度器周期执行异常,将在下次轮询重试", exc_info=True)
raise return result # 不中断主循环,允许下次重试
def _process_single_task(self, task: Dict[str, Any]) -> bool: def _process_single_task(self, task: Dict[str, Any]) -> bool:
"""处理单个任务(线程安全)""" """处理单个任务(线程安全)"""
task_id = task['task_id'] task_id = task['task_id']
task_log = log.bind(task_id=task_id, task_name=task['task_name']) task_name = task['task_name']
task_log.info(f"开始执行任务: {task['task_name']}") task_log = log.bind(task_id=task_id, task_name=task_name)
task_log.info(f"开始执行任务: {task_name}")
try: try:
# 标记任务为运行中 # 标记任务为运行中(使用当前时间的时区感知对象)
tz = pytz.timezone(task.get('time_zone', 'Asia/Shanghai'))
current_time = datetime.now(tz).replace(tzinfo=None)
self._update_task_status(task_id, { self._update_task_status(task_id, {
'is_running': 1, 'is_running': 1,
'last_run_time': datetime.now() 'last_run_time': current_time
}) })
# 执行任务逻辑 # 执行任务逻辑
@@ -98,7 +113,7 @@ class TaskScheduler:
'run_count': task['run_count'] + 1, 'run_count': task['run_count'] + 1,
'next_run_time': next_run_time 'next_run_time': next_run_time
}) })
task_log.info(f"任务执行成功: {task['task_name']}") task_log.info(f"任务执行成功: {task_name}")
return True return True
except Exception as e: except Exception as e:
@@ -107,23 +122,35 @@ class TaskScheduler:
# 失败时计算下次重试时间(15分钟后) # 失败时计算下次重试时间(15分钟后)
next_retry_time = datetime.now() + pd.Timedelta(minutes=15) next_retry_time = datetime.now() + pd.Timedelta(minutes=15)
self._update_task_status(task_id, { # 即使任务执行失败,也要确保状态更新
'last_run_status': 'failed', try:
'is_running': 0, self._update_task_status(task_id, {
'next_run_time': next_retry_time 'last_run_status': 'failed',
}) 'is_running': 0,
'next_run_time': next_retry_time
})
except Exception as update_err:
task_log.error(f"任务失败后状态更新失败: {str(update_err)}", exc_info=True)
return False return False
def _execute_task_logic(self, task: Dict[str, Any]) -> None: def _execute_task_logic(self, task: Dict[str, Any]) -> None:
"""执行任务的具体逻辑(动态导入模块)""" """执行任务的具体逻辑(动态导入模块)"""
start_time = time.time() start_time = time.time()
task_log = log.bind(task_id=task['task_id'], module=task['module_path']) task_id = task['task_id']
module_path = task['module_path']
task_log = log.bind(task_id=task_id, module=module_path)
try: try:
# 动态导入任务模块 # 动态导入任务模块(增加模块存在性检查)
module = importlib.import_module(task['module_path']) try:
if not hasattr(module, 'main'): module = importlib.import_module(module_path)
raise ImportError(f"模块 {task['module_path']} 中未找到 main() 函数") except ImportError as e:
raise ImportError(f"模块 {module_path} 导入失败: {str(e)}")
# 检查main函数是否存在
if not hasattr(module, 'main') or not callable(module.main):
raise AttributeError(f"模块 {module_path} 中未找到可调用的 main() 函数")
task_log.debug("开始执行模块中的 main() 函数") task_log.debug("开始执行模块中的 main() 函数")
module.main() # 调用任务主函数 module.main() # 调用任务主函数
@@ -137,7 +164,7 @@ class TaskScheduler:
"""基于Cron表达式计算下次运行时间""" """基于Cron表达式计算下次运行时间"""
try: try:
tz = pytz.timezone(time_zone) tz = pytz.timezone(time_zone)
now = datetime.now(tz) now = datetime.now(tz) # 使用任务指定时区的当前时间
cron = croniter.croniter(cron_expr, now) cron = croniter.croniter(cron_expr, now)
next_run = cron.get_next(datetime) next_run = cron.get_next(datetime)
return next_run.replace(tzinfo=None) # 移除时区信息,适应数据库存储 return next_run.replace(tzinfo=None) # 移除时区信息,适应数据库存储
@@ -146,12 +173,26 @@ class TaskScheduler:
raise ValueError(f"无效的Cron表达式: {cron_expr}") raise ValueError(f"无效的Cron表达式: {cron_expr}")
def _update_task_status(self, task_id: int, updates: Dict[str, Any]) -> None: def _update_task_status(self, task_id: int, updates: Dict[str, Any]) -> None:
"""更新任务状态到数据库""" """更新任务状态到数据库(适配SQLAlchemy的参数传递方式)"""
set_clause = ", ".join([f"{k}=%s" for k in updates.keys()]) if not updates:
log.warning(f"任务ID {task_id} 未提供任何更新字段")
return
# 构建UPDATE语句(确保字段名安全)
valid_fields = {'is_running', 'last_run_time', 'last_run_status',
'run_count', 'next_run_time', 'updated_at'}
filtered_updates = {k: v for k, v in updates.items() if k in valid_fields}
if not filtered_updates:
log.warning(f"任务ID {task_id} 没有有效的更新字段")
return
set_clause = ", ".join([f"{k}=%s" for k in filtered_updates.keys()])
sql = f"UPDATE main_task SET {set_clause}, updated_at=NOW() WHERE task_id=%s" sql = f"UPDATE main_task SET {set_clause}, updated_at=NOW() WHERE task_id=%s"
params = list(updates.values()) + [task_id] params = list(filtered_updates.values()) + [task_id]
try: try:
# 执行更新并获取受影响的行数
affected_rows = self.db.execute_sql(sql, params=params) affected_rows = self.db.execute_sql(sql, params=params)
if affected_rows != 1: if affected_rows != 1:
log.warning( log.warning(
@@ -160,51 +201,104 @@ class TaskScheduler:
预期影响行数=1, 预期影响行数=1,
实际影响行数=affected_rows 实际影响行数=affected_rows
) )
except SQLAlchemyError as e:
log.error(f"任务状态更新失败(数据库错误),task_id: {task_id}", exc_info=True)
raise
except Exception as e: except Exception as e:
log.error(f"任务状态更新失败,task_id: {task_id}", exc_info=True) log.error(f"任务状态更新失败,task_id: {task_id}", exc_info=True)
raise raise
def add_task(self, def add_task(self,
task_name: str, task_name: str,
task_type: str, task_type: str,
module_path: str, module_path: str,
cron_expression: str, cron_expression: str,
time_zone: str = 'Asia/Shanghai') -> int: time_zone: str = 'Asia/Shanghai') -> int:
"""添加新的Cron任务""" """添加新的Cron任务"""
if not cron_expression: if not cron_expression:
raise ValueError("Cron表达式不能为空") raise ValueError("Cron表达式不能为空")
# 验证模块是否存在(提前检查,避免添加无效任务)
try:
importlib.import_module(module_path)
except ImportError as e:
raise ValueError(f"模块 {module_path} 不存在: {str(e)}")
# 计算首次运行时间 # 计算首次运行时间
first_run_time = self._calculate_next_run_time(cron_expression, time_zone) first_run_time = self._calculate_next_run_time(cron_expression, time_zone)
# 插入数据库 # 插入数据库
sql = """ sql = """
INSERT INTO main_task INSERT INTO main_task
(task_name, task_type, module_path, cron_expression, time_zone, (task_name, task_type, module_path, cron_expression, time_zone,
next_run_time, is_active) next_run_time, is_active, created_at, updated_at)
VALUES (%s, %s, %s, %s, %s, %s, 1) VALUES (%s, %s, %s, %s, %s, %s, 1, NOW(), NOW()) \
""" """
params = (task_name, task_type, module_path, cron_expression, time_zone, first_run_time) params = (task_name, task_type, module_path, cron_expression, time_zone, first_run_time)
try: try:
self.db.execute_sql(sql, params=params) self.db.execute_sql(sql, params=params)
task_id = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id").iloc[0]['id'] # 获取插入的任务ID
result_df = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id")
if result_df.empty or 'id' not in result_df.columns:
raise ValueError("无法获取新添加任务的ID")
task_id = result_df.iloc[0]['id']
log.info( log.info(
f"新任务添加成功", "新任务添加成功",
task_id=task_id, task_id=task_id,
task_name=task_name, task_name=task_name,
cron表达式=cron_expression, cron表达式=cron_expression,
首次运行时间=first_run_time 首次运行时间=first_run_time.strftime('%Y-%m-%d %H:%M:%S')
) )
return task_id return task_id
except SQLAlchemyError as e:
log.error(f"添加任务失败(数据库错误): {task_name}", exc_info=True)
raise
except Exception as e: except Exception as e:
log.error(f"添加任务失败: {task_name}", exc_info=True) log.error(f"添加任务失败: {task_name}", exc_info=True)
raise raise
def get_pending_tasks_count(self) -> int: def get_pending_tasks_count(self) -> int:
"""获取当前等待执行任务数量""" """获取待执行任务数量(用于优雅关闭)"""
result = self.db.query_to_df(""" try:
SELECT COUNT(*) AS count FROM main_task tz = pytz.timezone('Asia/Shanghai')
WHERE is_active = 1 AND next_run_time <= %s now = datetime.now(tz).replace(tzinfo=None)
""", params=(datetime.now(),)) sql = """
return result.iloc[0]['count'] if not result.empty else 0 SELECT COUNT(*) as cnt
FROM main_task
WHERE is_active = 1
AND next_run_time <= %s
AND is_running = 0
"""
df = self.db.query_to_df(sql, params=(now,))
return df['cnt'].iloc[0] if not df.empty else 0
except Exception as e:
log.error(f"查询待执行任务数量失败: {str(e)}", exc_info=True)
return 0 # 出错时返回0,避免影响关闭流程
def get_pending_tasks(self) -> List[Dict[str, Any]]:
    """Return all tasks that are active, due and not currently running.

    "Due" means ``next_run_time`` is not later than the current
    Asia/Shanghai wall-clock time (compared as a naive datetime).

    Returns:
        One dict per matching ``main_task`` row, ordered by
        ``next_run_time``; an empty list when nothing is due or when the
        query failed (the failure is logged, not raised).
    """
    pending_sql = (
        "\n"
        "SELECT *\n"
        "FROM main_task\n"
        "WHERE is_active = 1\n"
        "AND next_run_time <= %s\n"
        "AND is_running = 0\n"
        "ORDER BY next_run_time\n"
    )
    try:
        shanghai = pytz.timezone('Asia/Shanghai')
        # Strip tzinfo so the value is a naive datetime for the query.
        cutoff = datetime.now(shanghai).replace(tzinfo=None)
        due_df = self.db.query_to_df(pending_sql, params=(cutoff,))

        if not due_df.empty:
            log.info(f"查询到{len(due_df)}个待执行任务")
            return due_df.to_dict('records')

        log.info("当前任务列表为空,等待新任务加入...")
        return []
    except Exception as e:
        log.error(f"查询待执行任务失败,将重试: {str(e)}", exc_info=True)
        return []
+6
View File
@@ -0,0 +1,6 @@
use intelligence_system;
SELECT * FROM main_task
WHERE is_active = 1
AND next_run_time <= %s
AND is_running = 0
ORDER BY next_run_time;
+25 -2
View File
@@ -35,6 +35,7 @@ class CrossPlatformLog:
"""配置跨平台日志处理器""" """配置跨平台日志处理器"""
logger.remove() # 清除默认配置 logger.remove() # 清除默认配置
# 统一控制台输出格式 # 统一控制台输出格式
logger.add( logger.add(
sys.stdout, sys.stdout,
@@ -58,11 +59,33 @@ class CrossPlatformLog:
compression=self._compress_log, compression=self._compress_log,
encoding="utf-8", encoding="utf-8",
level="DEBUG", level="DEBUG",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}", # 👇 增加 {extra} 输出,并美化结构
# format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra_output}",
retention="30 days", retention="30 days",
enqueue=True # 线程安全 enqueue=True,
# 👇 动态处理 extra 字段为可读格式
format=self._format_with_extra, # 使用自定义格式函数
) )
def _format_with_extra(self, record):
    """Build the loguru format string for one record, including its extras.

    Every key bound in ``record["extra"]`` is rendered on its own line as
    ``key: repr(value)`` (reprs longer than 200 characters are cut to 197
    characters plus ``...``). The rendered text is stored back into the
    record under ``extra["extra_output"]`` and referenced from the returned
    template, so braces inside the values are never interpreted as format
    fields.

    Args:
        record: The loguru record dict passed to a callable ``format``.

    Returns:
        The format template for this record. Because loguru does not append
        the line ending or the ``{exception}`` field when ``format`` is a
        callable, both are included here; otherwise tracebacks would be
        missing from this sink.
    """
    extra_items = []
    for key, value in record["extra"].items():
        # Skip our own helper field so it is not echoed back (or nested)
        # when the same extra dict is formatted more than once.
        if key == "extra_output":
            continue
        value_repr = repr(value)
        if len(value_repr) > 200:
            value_repr = value_repr[:197] + "..."
        extra_items.append(f"\n{key}: {value_repr}")

    # Always set the field, even when empty, so the template below resolves.
    record["extra"]["extra_output"] = "".join(extra_items)

    return (
        "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - "
        "{message}{extra[extra_output]}\n{exception}"
    )
def _add_error_log(self): def _add_error_log(self):
"""错误日志专用配置""" """错误日志专用配置"""
error_log = self.log_dir / "errors.log" error_log = self.log_dir / "errors.log"
+238 -291
View File
@@ -54,9 +54,10 @@ class MySQLAgent:
if hasattr(self, '_pool') and self._pool: if hasattr(self, '_pool') and self._pool:
return return
# 基础配置 # 基础配置校验
required_keys = ['host', 'port', 'user', 'password', 'database'] required_keys = ['host', 'port', 'user', 'password', 'database']
if not all(key in config for key in required_keys): if not all(key in config for key in required_keys):
log.warning(f"数据库配置缺少必要参数,当前配置: {config}")
raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}") raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")
self.config = { self.config = {
@@ -74,7 +75,7 @@ class MySQLAgent:
'ssl': config.get('ssl') 'ssl': config.get('ssl')
} }
# 初始化log # 初始化日志
current_platform = platform.system() current_platform = platform.system()
self.log = log.bind(module=f"MySQLAgent({current_platform})") self.log = log.bind(module=f"MySQLAgent({current_platform})")
@@ -85,7 +86,7 @@ class MySQLAgent:
def _create_pool(self) -> PooledDB: def _create_pool(self) -> PooledDB:
"""创建连接池""" """创建连接池"""
try: try:
# 使用包装函数确保线程安全 # 线程安全的连接创建函数
def connect(): def connect():
conn = pymysql.connect(**self.config) conn = pymysql.connect(**self.config)
conn.threadsafety = 1 # 显式设置线程安全级别 conn.threadsafety = 1 # 显式设置线程安全级别
@@ -97,49 +98,43 @@ class MySQLAgent:
maxcached=3, maxcached=3,
maxconnections=self.pool_size, maxconnections=self.pool_size,
blocking=True, blocking=True,
ping=1 ping=1 # 每次获取连接时ping数据库
) )
self.log.info("Connection pool created") self.log.info("连接池创建成功")
return pool return pool
except Exception as e: except Exception as e:
self.log.critical("Failed to create connection pool", self.log.critical("连接池创建失败", error=str(e), exc_info=True)
error=str(e),
exc_info=True)
raise raise
def get_connection(self) -> pymysql.connections.Connection: def get_connection(self) -> pymysql.connections.Connection:
""" """获取数据库连接(修复字符集方法缺失问题)"""
获取数据库连接
Returns:
pymysql.connections.Connection: 数据库连接对象
Raises:
MySQLError: 如果获取连接失败
"""
try: try:
conn = self._pool.connection() conn = self._pool.connection()
# macOS需要特殊处理SSL # 为连接添加字符集方法(兼容SQLAlchemy)
if not hasattr(conn, 'character_set_name'):
def _character_set_name():
return self.config.get('charset', 'utf8mb4')
conn.character_set_name = _character_set_name
# macOS平台SSL特殊处理
if platform.system() == 'Darwin' and self.config.get('ssl'): if platform.system() == 'Darwin' and self.config.get('ssl'):
conn.ping(reconnect=True) conn.ping(reconnect=True)
self.log.trace("Database connection obtained") self.log.trace("获取数据库连接成功")
return conn return conn
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
# Windows平台连接超时重试
# Windows特定错误处理
if platform.system() == 'Windows' and "timed out" in error_msg: if platform.system() == 'Windows' and "timed out" in error_msg:
self.log.warning("Windows connection timeout, retrying...") self.log.warning("Windows连接超时,尝试重试...")
return self._retry_connection() return self._retry_connection()
self.log.error("Connection failed", self.log.error("获取连接失败", error=error_msg, exc_info=True)
error=error_msg,
exc_info=True)
raise raise
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection: def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
@@ -147,100 +142,78 @@ class MySQLAgent:
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
conn = self._pool.connection() conn = self._pool.connection()
self.log.info(f"Connection established after {attempt + 1} attempts") self.log.info(f"{attempt + 1}次尝试连接成功")
return conn return conn
except Exception: except Exception:
if attempt == max_retries - 1: if attempt == max_retries - 1:
raise raise
import time import time
time.sleep(1) time.sleep(1) # 重试间隔1秒
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None, def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
parse_dates: Union[List[str], bool] = True) -> pd.DataFrame: parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
""" """执行SQL查询并返回DataFrame(优化连接管理)"""
执行SQL查询并返回DataFrame conn = None
Args:
sql (str): SQL查询语句
params (Union[tuple, dict, None]): 查询参数
parse_dates (Union[List[str], bool]): 自动解析日期字段
Returns:
pd.DataFrame: 查询结果
Raises:
MySQLError: 如果查询失败
"""
try: try:
self.log.debug("Executing SQL query", sql=sql) self.log.debug("执行SQL查询", sql=sql)
conn = self.get_connection()
with self.get_connection() as conn: # 创建SQLAlchemy引擎(使用静态池避免连接重复创建)
# Linux/macOS需要更长的查询超时 from sqlalchemy import create_engine
if platform.system() != 'Windows': from sqlalchemy.pool import StaticPool
conn.cursor().execute("SET SESSION wait_timeout=600") engine = create_engine(
"mysql+pymysql://",
creator=lambda: conn,
poolclass=StaticPool,
connect_args={'charset': self.config.get('charset', 'utf8mb4')}
)
df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates) # 执行查询
df = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
# Windows平台需要手动关闭游标 self.log.info(f"查询成功,返回{len(df)}行数据")
if platform.system() == 'Windows':
conn.cursor().close()
self.log.info("Query executed successfully", rows=len(df))
return df return df
except Exception as e: except Exception as e:
self.log.error("SQL query failed", self.log.error(f"SQL查询失败{sql}", sql=sql, params=params, error=str(e), exc_info=True)
sql=sql,
params=params,
error=str(e),
exc_info=True)
raise raise
finally:
# 确保连接释放回池
if conn:
try:
conn.close()
except Exception as e:
self.log.warning("关闭连接失败", error=str(e))
def insert_from_df(self, table_name: str, df: pd.DataFrame, def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False) -> int: chunk_size: int = 1000, replace: bool = False) -> int:
""" """将DataFrame数据插入到数据库表(优化批量处理)"""
将DataFrame数据插入到数据库表(修复版)
Args:
table_name (str): 目标表名
df (pd.DataFrame): 要插入的数据
chunk_size (int): 分批插入大小
replace (bool): 是否替换现有数据
Returns:
int: 插入的总行数
Raises:
MySQLError: 如果插入失败
"""
if df.empty: if df.empty:
self.log.warning("Attempted to insert empty DataFrame", table=table_name) self.log.warning(f"尝试插入空DataFrame到表{table_name}")
return 0 return 0
self.log.debug("Preparing to insert DataFrame", self.log.debug(f"准备插入DataFrame到表{table_name}", rows=len(df), chunk_size=chunk_size)
table=table_name,
rows=len(df), # 根据平台自动调整批次大小
chunk_size=chunk_size) current_platform = platform.system()
if current_platform == 'Windows' and chunk_size > 500:
chunk_size = 500
self.log.debug(f"Windows平台自动调整批次大小为{chunk_size}")
elif current_platform == 'Linux' and chunk_size < 1000:
chunk_size = 1000
self.log.debug(f"Linux平台自动调整批次大小为{chunk_size}")
try: try:
method = 'replace' if replace else 'append' method = 'replace' if replace else 'append'
total_rows = 0 total_rows = 0
# 创建临时SQLAlchemy引擎(不创建新连接池)
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
# 获取当前连接并包装
conn = self.get_connection() conn = self.get_connection()
# 修复连接对象缺少character_set_name的问题 # 创建SQLAlchemy引擎
if not hasattr(conn, 'character_set_name'): from sqlalchemy import create_engine
conn.character_set_name = lambda: self.config.get('charset', 'utf8mb4') from sqlalchemy.pool import StaticPool
engine = create_engine( engine = create_engine(
"mysql+pymysql://", "mysql+pymysql://",
creator=lambda: conn, creator=lambda: conn,
poolclass=StaticPool, # 使用静态池避免创建新连接 poolclass=StaticPool,
connect_args={ connect_args={
'charset': self.config.get('charset', 'utf8mb4'), 'charset': self.config.get('charset', 'utf8mb4'),
'autocommit': True 'autocommit': True
@@ -249,9 +222,9 @@ class MySQLAgent:
try: try:
for i in range(0, len(df), chunk_size): for i in range(0, len(df), chunk_size):
chunk = df.iloc[i:i + chunk_size] chunk = df.iloc[i:i + chunk_size].copy() # 使用copy避免SettingWithCopyWarning
# macOS需要特殊处理datetime # macOS平台datetime特殊处理
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
for col in chunk.select_dtypes(include=['datetime64']): for col in chunk.select_dtypes(include=['datetime64']):
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S') chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
@@ -264,56 +237,37 @@ class MySQLAgent:
method='multi' method='multi'
) )
total_rows += len(chunk) total_rows += len(chunk)
method = 'append' # 第一次之后都使用追加模式 method = 'append' # 首次后使用追加模式
self.log.trace(f"Inserted chunk {i // chunk_size + 1}", self.log.trace(f"插入第{i // chunk_size + 1}批数据", rows=len(chunk), total=total_rows)
rows=len(chunk),
total_inserted=total_rows)
self.log.info("Data inserted successfully", self.log.info(f"数据插入成功,表{table_name}共插入{total_rows}")
table=table_name,
total_rows=total_rows)
return total_rows return total_rows
finally: finally:
# 确保连接正确关闭
engine.dispose() engine.dispose()
conn.close() conn.close()
except Exception as e: except Exception as e:
self.log.error("Data insertion failed", self.log.error(f"数据插入失败,表{table_name}", error=str(e), exc_info=True)
table=table_name,
error=str(e),
exc_info=True)
raise raise
def update_from_df(self, table_name: str, df: pd.DataFrame, def update_from_df(self, table_name: str, df: pd.DataFrame,
key_columns: Union[str, List[str]]) -> int: key_columns: Union[str, List[str]]) -> int:
""" """使用DataFrame数据更新数据库表(优化事务处理)"""
使用DataFrame数据更新数据库表
Args:
table_name (str): 目标表名
df (pd.DataFrame): 包含更新数据
key_columns (Union[str, List[str]]): 用于匹配记录的关键列
Returns:
int: 更新的总行数
Raises:
MySQLError: 如果更新失败
"""
if df.empty: if df.empty:
self.log.warning("Attempted to update with empty DataFrame", table=table_name) self.log.warning(f"尝试用空DataFrame更新表{table_name}")
return 0 return 0
self.log.debug("Preparing to update table from DataFrame", self.log.debug(f"准备从DataFrame更新表{table_name}", key_columns=key_columns, rows=len(df))
table=table_name,
key_columns=key_columns,
rows=len(df))
try: try:
if isinstance(key_columns, str): if isinstance(key_columns, str):
key_columns = [key_columns] key_columns = [key_columns]
# 验证关键列存在性
missing_keys = [key for key in key_columns if key not in df.columns]
if missing_keys:
raise ValueError(f"DataFrame中缺少关键列: {missing_keys}")
total_updated = 0 total_updated = 0
conn = self.begin_transaction() conn = self.begin_transaction()
@@ -322,32 +276,29 @@ class MySQLAgent:
# 获取表结构信息 # 获取表结构信息
table_info = self._get_table_info(table_name) table_info = self._get_table_info(table_name)
columns = [col for col in df.columns if col in table_info] valid_columns = [col for col in df.columns if col in table_info]
if not valid_columns:
self.log.warning(f"DataFrame列与表{table_name}无匹配")
return 0
# 构建UPDATE语句模板 # 构建UPDATE语句
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns]) set_clause = ', '.join([f"`{col}`=%s" for col in valid_columns if col not in key_columns])
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns]) where_clause = ' AND '.join([f"`{col}`=%s" for col in key_columns])
update_sql = f"UPDATE `{table_name}` SET {set_clause} WHERE {where_clause}"
self.log.trace("生成更新SQL", sql=update_sql)
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" # 准备更新数据
self.log.trace("Generated update SQL", sql=update_sql)
# 准备数据
update_data = [] update_data = []
for _, row in df.iterrows(): for _, row in df.iterrows():
# SET部分的值 set_values = [row[col] for col in valid_columns if col not in key_columns]
set_values = [row[col] for col in columns if col not in key_columns]
# WHERE部分的值
key_values = [row[col] for col in key_columns] key_values = [row[col] for col in key_columns]
update_data.append(tuple(set_values + key_values)) update_data.append(tuple(set_values + key_values))
# 执行批量更新 # 执行批量更新
cursor.executemany(update_sql, update_data) cursor.executemany(update_sql, update_data)
total_updated = cursor.rowcount total_updated = cursor.rowcount
self.commit_transaction(conn) self.commit_transaction(conn)
self.log.info("Data updated successfully", self.log.info(f"数据更新成功,表{table_name}共更新{total_updated}")
table=table_name,
rows_updated=total_updated)
return total_updated return total_updated
except Exception as e: except Exception as e:
@@ -355,61 +306,44 @@ class MySQLAgent:
raise raise
except Exception as e: except Exception as e:
self.log.error("Data update failed", self.log.error(f"数据更新失败,表{table_name}", error=str(e), exc_info=True)
table=table_name,
error=str(e),
exc_info=True)
raise raise
def _get_table_info(self, table_name: str) -> Dict[str, str]:
    """Return a mapping of column name to data type for ``table_name``.

    The lookup queries ``information_schema.columns`` for the schema named by
    ``self.config['database']``; schema and table are bound as parameters.

    Args:
        table_name: Name of the table to inspect.

    Returns:
        ``{column_name: data_type}``; empty when no column matches.

    Raises:
        Exception: Any connection or query error, re-raised after logging.
    """
    # Alias both columns explicitly. Rows are read by lower-case key below,
    # and MySQL 8.0 reports information_schema headers in upper case unless
    # an alias pins the spelling.
    sql = (
        "SELECT column_name AS column_name, data_type AS data_type "
        "FROM information_schema.columns "
        "WHERE table_schema = %s AND table_name = %s"
    )
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(sql, (self.config['database'], table_name))
                return {
                    row['column_name']: row['data_type']
                    for row in cursor.fetchall()
                }
    except Exception as e:
        self.log.error(f"获取表{table_name}结构失败", error=str(e))
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]: def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
""" """推断DataFrame各列的SQL类型(扩展类型映射)"""
推断DataFrame各列的SQL类型
Args:
df (pd.DataFrame): 输入数据框
Returns:
Dict[str, str]: 列名到SQL类型的映射
"""
type_mapping = { type_mapping = {
'int64': 'BIGINT', 'int64': 'BIGINT',
'int32': 'INT',
'int16': 'SMALLINT',
'int8': 'TINYINT',
'uint64': 'BIGINT UNSIGNED',
'float64': 'DOUBLE', 'float64': 'DOUBLE',
'float32': 'FLOAT',
'datetime64[ns]': 'DATETIME', 'datetime64[ns]': 'DATETIME',
'datetime64[ns, UTC]': 'DATETIME',
'timedelta64[ns]': 'TIME',
'object': 'TEXT', 'object': 'TEXT',
'string': 'VARCHAR(255)',
'bool': 'TINYINT(1)', 'bool': 'TINYINT(1)',
'category': 'VARCHAR(255)' 'category': 'VARCHAR(255)'
} }
@@ -419,217 +353,201 @@ class MySQLAgent:
dtype_str = str(dtype) dtype_str = str(dtype)
sql_types[col] = type_mapping.get(dtype_str, 'TEXT') sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
self.log.debug("Mapped DataFrame types to SQL types", self.log.debug("DataFrame类型映射为SQL类型", mappings=sql_types)
mappings=sql_types)
return sql_types return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create ``table_name`` with one column per DataFrame column.

    Column types come from ``self.df_to_sql_type(df)``. Columns named
    ``create_time``/``created_at`` or ``update_time``/``updated_at`` whose
    inferred type is not ``DATETIME`` are declared as ``DATETIME`` with an
    automatic ``CURRENT_TIMESTAMP`` default.

    Args:
        table_name: Name of the table to create.
        df: DataFrame whose columns define the table layout.
        primary_key: Column name, or list of names, for the primary key.
            Names that are not DataFrame columns are ignored with a warning.

    Returns:
        True when the table was created; False when it already exists, the
        DataFrame has no columns, or creation failed (the error is logged).
    """
    if self.table_exists(table_name):
        self.log.warning(f"表{table_name}已存在")
        return False

    self.log.debug(f"根据DataFrame结构创建表{table_name}", columns=list(df.columns))

    try:
        sql_types = self.df_to_sql_type(df)
        if not sql_types:
            # A CREATE TABLE statement without columns is invalid SQL.
            self.log.warning(f"DataFrame没有列,无法创建表{table_name}")
            return False

        columns_sql = []
        for col, sql_type in sql_types.items():
            # str() keeps non-string column labels from raising here.
            lowered = str(col).lower()
            if lowered in ('create_time', 'created_at') and sql_type != 'DATETIME':
                col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP"
            elif lowered in ('update_time', 'updated_at') and sql_type != 'DATETIME':
                col_def = f"`{col}` DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"
            else:
                col_def = f"`{col}` {sql_type}"
            columns_sql.append(col_def)

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            missing = [col for col in primary_key if col not in sql_types]
            if missing:
                # Surface ignored key columns instead of dropping them silently.
                self.log.warning(f"表{table_name}的主键列在DataFrame中不存在,已忽略", missing=missing)
            pk_columns = [f"`{col}`" for col in primary_key if col in sql_types]
            if pk_columns:
                columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
                self.log.trace(f"表{table_name}设置主键", primary_key=pk_columns)

        # The separator is built outside the f-string: a backslash inside an
        # f-string replacement field is a SyntaxError before Python 3.12.
        separator = ",\n    "
        create_sql = f"CREATE TABLE `{table_name}` (\n    {separator.join(columns_sql)}\n)"
        self.execute_sql(create_sql)
        self.log.info(f"表{table_name}创建成功")
        return True
    except Exception as e:
        self.log.error(f"表{table_name}创建失败", error=str(e), exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Execute one SQL statement on a connection from ``get_connection()``.

    Args:
        sql: Statement text, with driver placeholders for ``params``.
        params: Values bound to the placeholders, or None.
        fetch: When True, return all fetched rows; otherwise return the
            cursor's ``rowcount``.

    Returns:
        The fetched rows when ``fetch`` is True, else the affected row count.

    Raises:
        Exception: Any error raised while executing ``sql``; it is logged
            and re-raised. The cursor and connection are closed either way.
    """
    conn = None
    cursor = None
    try:
        conn = self.get_connection()
        cursor = conn.cursor()

        # Raise the statement timeout on non-Windows clients (600000 ms = 10 min).
        # The setting is optional, so a server that rejects it must not
        # prevent the actual statement from running.
        if platform.system() != 'Windows':
            try:
                cursor.execute("SET SESSION max_execution_time=600000")
            except Exception as e:
                self.log.debug("设置会话执行超时失败,已忽略", error=str(e))

        cursor.execute(sql, params)

        if fetch:
            result = cursor.fetchall()
            self.log.debug(f"查询执行完成,返回{len(result)}行")
            return result
        else:
            affected_rows = cursor.rowcount
            self.log.debug(f"更新执行完成,影响{affected_rows}行")
            return affected_rows
    except Exception as e:
        self.log.error("SQL执行失败", sql=sql, params=params, error=str(e), exc_info=True)
        raise
    finally:
        # Cleanup failures are logged only, so they never mask the real error.
        if cursor:
            try:
                cursor.close()
            except Exception as e:
                self.log.warning("关闭游标失败", error=str(e))
        if conn:
            try:
                conn.close()
            except Exception as e:
                self.log.warning("关闭连接失败", error=str(e))
def begin_transaction(self) -> pymysql.connections.Connection:
    """Open a connection with autocommit disabled and return it.

    On macOS clients the session isolation level is set to READ COMMITTED
    and on Linux clients to REPEATABLE READ; other platforms keep the
    server default. The caller owns the returned connection and is expected
    to finish with ``commit_transaction`` or ``rollback_transaction``, both
    of which close it.

    Returns:
        The open connection carrying the transaction.

    Raises:
        Exception: Any error while acquiring or configuring the connection.
            The partially configured connection is closed before re-raising.
    """
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)

        # Platform-specific isolation level, keyed by the client OS name.
        isolation_by_platform = {
            'Darwin': "READ COMMITTED",
            'Linux': "REPEATABLE READ",
        }
        isolation = isolation_by_platform.get(platform.system())
        if isolation:
            # Context manager so the helper cursor is closed after use.
            with conn.cursor() as cursor:
                cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation}")

        self.log.debug("事务开始")
        return conn
    except Exception as e:
        self.log.error("事务开始失败", error=str(e))
        # Nobody else holds this connection yet, so release it here.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_error:
                self.log.warning("事务开始失败后关闭连接失败", error=str(close_error))
        raise
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Commit the transaction on ``conn`` and close the connection.

    Args:
        conn: Connection carrying the open transaction.

    Raises:
        Exception: Any error raised by ``conn.commit()``, re-raised after
            logging. The connection is closed whether or not it succeeds.
    """
    def _release() -> None:
        # A failing close is logged only; it must not hide the commit result.
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务提交后关闭连接失败", error=str(close_error))

    try:
        conn.commit()
        self.log.debug("事务提交成功")
    except Exception as commit_error:
        self.log.error("事务提交失败", error=str(commit_error))
        raise
    finally:
        _release()
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Roll back the transaction on ``conn`` and close the connection.

    Best effort: a failing rollback is logged and not re-raised, and the
    connection is closed in either case.

    Args:
        conn: Connection carrying the open transaction.
    """
    def _release() -> None:
        # Close errors are reported as warnings only.
        try:
            conn.close()
        except Exception as close_error:
            self.log.warning("事务回滚后关闭连接失败", error=str(close_error))

    try:
        conn.rollback()
        self.log.warning("事务已回滚")
    except Exception as rollback_error:
        self.log.error("事务回滚失败", error=str(rollback_error))
    finally:
        _release()
def table_exists(self, table_name: str) -> bool:
    """Report whether ``table_name`` exists in the configured database.

    Counts matching rows in ``information_schema.tables`` for the schema
    named by ``self.config['database']``.

    Args:
        table_name: Name of the table to look for.

    Returns:
        True when at least one matching table row is found. False when none
        is found, and also when the check itself fails (logged as a warning).
    """
    # The trailing backslashes are line continuations inside the literal;
    # they only fold the following line into the same line of SQL text.
    query = """
          SELECT COUNT(*) as count
          FROM `information_schema`.`tables`
          WHERE `table_schema` = %s \
            AND `table_name` = %s \
          """
    lookup_args = (self.config['database'], table_name)

    try:
        rows = self.execute_sql(query, lookup_args, fetch=True)
        found = rows[0]['count'] > 0
        self.log.debug(f"表{table_name}存在性检查", exists=found)
        return found
    except Exception as check_error:
        self.log.warning(f"表{table_name}存在性检查失败", error=str(check_error))
        return False
def drop_table(self, table_name: str) -> bool:
    """Drop ``table_name`` if it exists.

    Args:
        table_name: Name of the table to remove.

    Returns:
        True when the DROP statement ran without error; False when the
        table was not found or the statement failed (the error is logged).
    """
    table_present = self.table_exists(table_name)
    if not table_present:
        self.log.warning(f"表{table_name}不存在,无法删除")
        return False

    statement = f"DROP TABLE `{table_name}`"
    try:
        self.execute_sql(statement)
    except Exception as drop_error:
        self.log.error(f"表{table_name}删除失败", error=str(drop_error), exc_info=True)
        return False

    self.log.info(f"表{table_name}删除成功")
    return True
def get_pool_status(self) -> Dict[str, int]:
    """Return a snapshot of the connection pool's counters.

    Reads private attributes of ``self._pool`` and logs the result at
    debug level.

    Returns:
        Dict with ``max_connections``, ``active_connections``,
        ``idle_connections`` and ``shared_connections``.
    """
    pool = self._pool

    def _as_count(value) -> int:
        # Accept either a plain counter or a sized container.
        return value if isinstance(value, int) else len(value)

    # NOTE(review): these attribute names match DBUtils PooledDB, where
    # `_connections` is an int counter (len() on it raises TypeError) and
    # `_shared_cache` is only created when connection sharing is enabled.
    # Confirm against the pool class built in __init__.
    status = {
        'max_connections': pool._maxconnections,
        'active_connections': _as_count(pool._connections),
        'idle_connections': len(pool._idle_cache),
        'shared_connections': len(getattr(pool, '_shared_cache', ())),
    }
    self.log.debug("连接池状态", **status)
    return status
def validate_connection(self) -> bool:
    """Check that a connection can be obtained and answers a trivial query.

    Runs ``SELECT 1 AS health_check`` and reads the value by column name,
    so the cursor must return rows addressable by key.

    Returns:
        True when the query returns 1; False when it returns anything else
        or when any step raises (the error is logged as a warning).
    """
    probe = "SELECT 1 AS health_check"
    try:
        with self.get_connection() as conn, conn.cursor() as cursor:
            cursor.execute(probe)
            row = cursor.fetchone()
            return row['health_check'] == 1
    except Exception as check_error:
        self.log.warning("连接健康检查失败", error=str(check_error))
        return False
def __del__(self):
    """Close the connection pool when the agent is garbage-collected.

    Does nothing when ``_pool`` was never set or is falsy. A failing close
    is logged as an error instead of being raised.
    """
    pool = getattr(self, '_pool', None)
    if not pool:
        return

    try:
        pool.close()
        self.log.info("连接池已关闭")
    except Exception as close_error:
        self.log.error("连接池关闭失败", error=str(close_error))
# 平台特定的默认配置
def get_default_config(): def get_default_config():
"""获取各平台默认配置""" """获取各平台默认配置(优化默认参数)"""
current_platform = platform.system() current_platform = platform.system()
base_config = { base_config = {
@@ -638,7 +556,8 @@ def get_default_config():
'user': 'root', 'user': 'root',
'password': '123123', 'password': '123123',
'database': 'intelligence', 'database': 'intelligence',
'max_connections': 5 'max_connections': 10, # 增加默认连接数
'charset': 'utf8mb4'
} }
if current_platform == 'Windows': if current_platform == 'Windows':
@@ -646,38 +565,66 @@ def get_default_config():
**base_config, **base_config,
'connect_timeout': 10, 'connect_timeout': 10,
'read_timeout': 30, 'read_timeout': 30,
'write_timeout': 30 'write_timeout': 30,
'ssl': None # Windows默认禁用SSL
} }
elif current_platform == 'Darwin': elif current_platform == 'Darwin': # macOS
return { return {
**base_config, **base_config,
'connect_timeout': 15, 'connect_timeout': 15,
'read_timeout': 60, 'read_timeout': 60,
'write_timeout': 60, 'write_timeout': 60,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} 'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'} # macOS默认SSL配置
} }
else: # Linux其他平台 else: # Linux其他平台
return { return {
**base_config, **base_config,
'connect_timeout': 15, 'connect_timeout': 15,
'read_timeout': 60, 'read_timeout': 60,
'write_timeout': 60 'write_timeout': 60,
'ssl': None # Linux默认禁用SSL
} }
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, then run a small
    # create / insert / query / drop round trip on a scratch table.
    try:
        db = MySQLAgent(get_default_config())

        if db.validate_connection():
            print("数据库连接成功")

            # Report the server version.
            version_df = db.query_to_df("SELECT VERSION() as version")
            print(f"数据库版本: {version_df['version'].iloc[0]}")

            # Show the connection pool counters.
            print("连接池状态:", db.get_pool_status())

            test_df = pd.DataFrame({
                'id': [1, 2, 3],
                'name': ['测试1', '测试2', '测试3'],
                'value': [10.5, 20.3, 30.8],
                'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
            })

            # create_table_from_df returns False when the table already
            # exists or creation fails. In that case skip the round trip so
            # a pre-existing `test_table` is never written to or dropped.
            if db.create_table_from_df('test_table', test_df, primary_key='id'):
                print("测试表创建成功")
                try:
                    rows_inserted = db.insert_from_df('test_table', test_df)
                    print(f"插入了{rows_inserted}行数据")

                    result_df = db.query_to_df("SELECT * FROM test_table")
                    print("查询结果:")
                    print(result_df)
                finally:
                    # Always remove the table this run created, even when
                    # the insert or the query above failed.
                    if db.drop_table('test_table'):
                        print("测试表已删除")
            else:
                print("测试表创建失败或已存在,跳过插入示例")
        else:
            print("数据库连接失败")
    except Exception as e:
        print(f"示例执行失败: {str(e)}")