Compare commits
6 Commits
1dfc5f1024
..
master
| Author | SHA1 | Date | |
|---|---|---|---|
| c894e344aa | |||
| 5d1155bd20 | |||
| fc18fa74c3 | |||
| c5f6e8288d | |||
| e1db06dd79 | |||
| fd67231866 |
Generated
+7
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="DataSourcePerFileMappings">
|
||||
<file url="file://$PROJECT_DIR$/tools/SQL.sql" value="36976640-4e4b-40d7-80c5-f77ff8c735e5" />
|
||||
<file url="file://$PROJECT_DIR$/tools/情报收集.sql" value="36976640-4e4b-40d7-80c5-f77ff8c735e5" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+15
@@ -0,0 +1,15 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" uploadOnCheckin="ee95b41f-4bcf-4810-b328-4a7a4f66093f" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="gitea">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
</serverData>
|
||||
<option name="myUploadOnCheckinName" value="gitea" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+12
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="MaterialThemeProjectNewConfig">
|
||||
<option name="metadata">
|
||||
<MTProjectMetadataState>
|
||||
<option name="migrated" value="true" />
|
||||
<option name="pristineConfig" value="false" />
|
||||
<option name="userId" value="-2834c26c:198a1f98ccf:-7ffe" />
|
||||
</MTProjectMetadataState>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
||||
Generated
+1
-1
@@ -3,5 +3,5 @@
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.13 (intelligence_system)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="intelligence" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="intelligence_system" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
Generated
+6
-1
@@ -1,7 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SqlDialectMappings">
|
||||
<file url="file://$PROJECT_DIR$/tools/SQL.sql" dialect="MySQL" />
|
||||
<file url="file://$PROJECT_DIR$/tools/情报收集.sql" dialect="MySQL" />
|
||||
<file url="PROJECT" dialect="MySQL" />
|
||||
</component>
|
||||
<component name="SqlResolveMappings">
|
||||
<file url="file://$PROJECT_DIR$/storage/mysql_agent.py" scope="{"node":{ "@negative":"1", "group":{ "@kind":"root", "node":{ "@negative":"1" } } }}" />
|
||||
<file url="file://$PROJECT_DIR$/utils/mysql_agent.py" scope="{"node":{ "@negative":"1", "group":{ "@kind":"root", "node":{ "@negative":"1" } } }}" />
|
||||
<file url="PROJECT" scope="{"node":{ "@negative":"1", "group":{ "@kind":"root", "node":{ "@negative":"1" } } }}" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+14
@@ -0,0 +1,14 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="WebServers">
|
||||
<option name="servers">
|
||||
<webServer id="ee95b41f-4bcf-4810-b328-4a7a4f66093f" name="gitea" url="">
|
||||
<fileTransfer accessType="WEBDAV" port="6180">
|
||||
<advancedOptions>
|
||||
<advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
|
||||
</advancedOptions>
|
||||
</fileTransfer>
|
||||
</webServer>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,326 @@
|
||||
import feedparser
|
||||
import requests
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from loguru import logger
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
# Add the parent directory to the Python path to find utils module
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
from utils.mysql_agent import MySQLAgent
|
||||
|
||||
# Database connection settings.
# The literals below are local-development defaults. Host, user and password
# can be overridden through environment variables so that real credentials do
# not have to be committed to the repository.
local_DB_Config = {
    'host': os.getenv('INTEL_DB_HOST', "localhost"),
    'user': os.getenv('INTEL_DB_USER', "root"),
    'password': os.getenv('INTEL_DB_PASSWORD', "123123"),
    'database': "intelligence_system",
    'port': 3306,
    'charset': 'utf8mb4',
    'connect_timeout': 10,
    'read_timeout': 30,
    'write_timeout': 30,
    'autocommit': True,
}
|
||||
|
||||
# 目标数据表名
|
||||
table_name = "collector_rss_subscriptions"
|
||||
|
||||
|
||||
class NewsAPIClient:
|
||||
"""新闻API客户端,用于获取和处理RSS源数据并写入数据库"""
|
||||
|
||||
def __init__(self):
    """Bind a module-tagged logger and create the database agent."""
    bound_logger = logger.bind(module="NewsAPIClient")
    self.logger = bound_logger
    # MySQLAgent receives the module-level connection settings.
    self.db_agent = MySQLAgent(local_DB_Config)
    bound_logger.info("新闻API客户端初始化完成,已连接到数据库")
|
||||
|
||||
def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
|
||||
"""统一返回结果格式"""
|
||||
return {
|
||||
'success': bool(success),
|
||||
'message': str(message),
|
||||
'data': data
|
||||
}
|
||||
|
||||
def verify_database(self) -> bool:
    """Check that the target table exists and has every required column.

    Rows returned by ``db_agent.execute_sql`` are accepted either as mappings
    keyed by ``'Field'`` or as sequences with the column name at index 0, so
    the check does not depend on the cursor class the agent uses.

    Returns:
        True when the table exists and no required column is missing;
        False otherwise, including when any database call raises.
    """
    required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
                        '来源URL', '创建时间', '更新时间']
    try:
        # 1. The table must exist (an empty result means it does not).
        existing = self.db_agent.execute_sql(
            f"SHOW TABLES LIKE '{table_name}'",
            fetch=True
        )
        if not existing:
            self.logger.error(f"表 {table_name} 不存在,请先创建表结构")
            return False

        # 2. Every required column must be present.
        described = self.db_agent.execute_sql(
            f"DESCRIBE {table_name}",
            fetch=True
        )
        columns = [
            row['Field'] if isinstance(row, dict) else row[0]
            for row in described
        ]
        missing_cols = [col for col in required_columns if col not in columns]
        if missing_cols:
            self.logger.error(f"表 {table_name} 缺少必要字段:{missing_cols}")
            return False

        self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
        return True

    except Exception as e:
        # loguru has no ``exc_info`` keyword: passing it makes loguru run
        # str.format() on the message (which breaks on braces in the error
        # text) and logs no traceback. ``opt(exception=True)`` attaches it.
        self.logger.opt(exception=True).error("数据库验证失败: {}", e)
        return False
|
||||
|
||||
def load_last_update_time(self) -> Optional[datetime]:
|
||||
"""加载上次更新时间缓存"""
|
||||
cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
|
||||
if os.path.exists(cache_file):
|
||||
try:
|
||||
with open(cache_file, 'rb') as f:
|
||||
last_update = pickle.load(f)
|
||||
self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
return last_update
|
||||
except Exception as e:
|
||||
self.logger.error(f"加载上次更新时间失败: {str(e)}", exc_info=True)
|
||||
self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
|
||||
return None
|
||||
|
||||
def save_last_update_time(self, last_update: datetime) -> None:
|
||||
"""保存本次更新时间"""
|
||||
try:
|
||||
cache_dir = os.path.join(os.getcwd(), 'output')
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_file = os.path.join(cache_dir, 'last_update.pkl')
|
||||
|
||||
with open(cache_file, 'wb') as f:
|
||||
pickle.dump(last_update, f)
|
||||
self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"保存更新时间失败: {str(e)}", exc_info=True)
|
||||
|
||||
def fetch_single_rss(self, url: str, timeout: int = 15) -> Optional[feedparser.FeedParserDict]:
    """Download and parse one RSS feed, trying up to three times.

    Args:
        url: Feed address.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The parsed feed, or None when all three attempts fail with a
        ``requests.RequestException``.
    """
    request_headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }

    attempt = 0
    while attempt < 3:
        try:
            resp = requests.get(url, headers=request_headers, timeout=timeout)
            resp.raise_for_status()
            # Decode using the encoding detected from the body.
            resp.encoding = resp.apparent_encoding
            parsed = feedparser.parse(resp.text)

            if parsed.bozo:
                self.logger.warning(f"解析 {url} 存在潜在问题: {parsed.bozo_exception}")

            self.logger.debug(f"成功获取 {url} 的RSS数据")
            return parsed
        except requests.RequestException as err:
            self.logger.warning(f"第 {attempt + 1} 次获取 {url} 失败: {str(err)}")
            if attempt < 2:
                # Linear back-off: 3 s after the first failure, 6 s after the second.
                time.sleep(3 * (attempt + 1))
        attempt += 1

    self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
    return None
|
||||
|
||||
def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> Dict[str, feedparser.FeedParserDict]:
|
||||
"""并发获取多个RSS源"""
|
||||
feeds = {}
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
|
||||
|
||||
for future in as_completed(future_to_url):
|
||||
url = future_to_url[future]
|
||||
try:
|
||||
feed = future.result()
|
||||
if feed:
|
||||
feeds[url] = feed
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理 {url} 时发生异常: {str(e)}", exc_info=True)
|
||||
|
||||
self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
|
||||
return feeds
|
||||
|
||||
def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
    """Convert one RSS entry into a row dict for the subscriptions table.

    Args:
        entry: Mapping-like feed entry (``.get`` is the only access used).
        url: URL of the feed the entry came from.

    Returns:
        Dict with the seven column values. The three timestamps are
        formatted as ``'%Y-%m-%d %H:%M:%S'`` strings.
    """
    from datetime import timezone  # local import: the file imports only `datetime`

    # Title, capped at 255 characters.
    title = entry.get('title') or '无标题'
    if len(title) > 255:
        title = title[:252] + '...'

    # Link, capped at 1024 characters.
    link = entry.get('link') or '无链接'
    if len(link) > 1024:
        link = link[:1021] + '...'

    # Summary: prefer the feed summary. A missing, None or empty summary
    # falls back to the first 200 characters of the first content item.
    summary = entry.get('summary')
    if summary:
        description = summary
    else:
        content_list = entry.get('content') or []
        content = (getattr(content_list[0], 'value', '') or '') if content_list else ''
        description = content[:200] + '...' if content else '无内容摘要'

    # Publication time as a naive UTC datetime.
    published_parsed = entry.get('published_parsed') or entry.get('updated_parsed')
    if published_parsed:
        # feedparser documents its *_parsed tuples as UTC.
        entry_time = datetime(*published_parsed[:6])
    else:
        pub_str = entry.get('published') or entry.get('updated') or ''
        try:
            aware = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z')
            # Convert to UTC before dropping tzinfo so this path agrees
            # with the *_parsed path above.
            entry_time = aware.astimezone(timezone.utc).replace(tzinfo=None)
        except (ValueError, TypeError):
            # No usable date: use the current UTC time. Local time on a
            # host ahead of UTC would push the caller's watermark into
            # the future and hide later entries.
            entry_time = datetime.now(timezone.utc).replace(tzinfo=None)

    # Source URL, capped at 1024 characters.
    source_url = url or '未知来源'
    if len(source_url) > 1024:
        source_url = source_url[:1021] + '...'

    # Row creation / update time (local clock, as before).
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    return {
        '文章标题': title,
        '文章链接': link,
        '文章摘要': description,
        '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
        '来源URL': source_url,
        '创建时间': current_time,
        '更新时间': current_time
    }
|
||||
|
||||
def display_feed_info(self, feed: feedparser.FeedParserDict, last_update: Optional[datetime] = None,
|
||||
url: Optional[str] = None) -> Optional[datetime]:
|
||||
"""处理RSS源信息并写入数据库"""
|
||||
if not feed:
|
||||
self.logger.warning("无法处理空的RSS源数据")
|
||||
return None
|
||||
|
||||
self.logger.info(f"开始处理 RSS 源: {url}")
|
||||
entries = feed.entries
|
||||
data_list = []
|
||||
new_last_update = last_update
|
||||
|
||||
for i, entry in enumerate(entries, 1):
|
||||
entry_data = self.process_feed_entry(entry, url)
|
||||
entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
|
||||
|
||||
# 过滤旧数据
|
||||
if last_update and entry_time <= last_update:
|
||||
continue
|
||||
|
||||
# 更新最新时间戳
|
||||
if new_last_update is None or entry_time > new_last_update:
|
||||
new_last_update = entry_time
|
||||
|
||||
self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
|
||||
data_list.append(entry_data)
|
||||
|
||||
# 写入数据库
|
||||
if data_list:
|
||||
df = pd.DataFrame(data_list)
|
||||
self.write_to_database(df)
|
||||
|
||||
return new_last_update
|
||||
|
||||
# rss_subscriptions.py 中的 write_to_database 方法可以保持简洁
|
||||
def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
|
||||
if df.empty:
|
||||
self.logger.info("没有新数据需要写入数据库")
|
||||
return self._format_result(True, "没有新数据需要写入")
|
||||
|
||||
try:
|
||||
inserted_rows = self.db_agent.insert_from_df(
|
||||
table_name=table_name,
|
||||
df=df,
|
||||
chunk_size=500,
|
||||
ignore_duplicates=True
|
||||
)
|
||||
|
||||
self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
|
||||
return self._format_result(
|
||||
True,
|
||||
f"成功写入 {inserted_rows}/{len(df)} 条记录",
|
||||
{"success_count": inserted_rows, "total": len(df)}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"数据库写入失败",
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
table_name=table_name,
|
||||
record_count=len(df),
|
||||
sample_records=df.head(2).to_dict('records') if not df.empty else [],
|
||||
exc_info=True
|
||||
)
|
||||
return self._format_result(False, f"数据库操作失败: {str(e)}")
|
||||
@classmethod
def main(cls):
    """Run one collection cycle: verify the table, fetch feeds, store new rows.

    The newest publication time seen across all feeds is saved as the
    watermark for the next run. Any exception is logged, not raised.
    """
    try:
        client = cls()

        # Abort early when the table is missing or incomplete.
        if not client.verify_database():
            client.logger.error("数据库验证失败,程序终止")
            return

        # Feeds to collect.
        rss_urls = [
            "https://www.chinanews.com.cn/rss/finance.xml",
            "https://www.chinanews.com.cn/rss/world.xml",
            "https://www.chinanews.com.cn/rss/china.xml",
            "https://www.chinanews.com.cn/rss/scroll-news.xml"
        ]

        # Watermark from the previous run (None on the first run).
        last_update = client.load_last_update_time()
        if last_update:
            client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

        # Fetch all feeds and time the operation.
        client.logger.info("开始获取RSS源数据...")
        start_time = time.time()
        feeds = client.fetch_all_rss(rss_urls)
        client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}秒")

        # Store new entries and keep the newest timestamp across feeds.
        new_last_update = None
        for url, feed in feeds.items():
            current_last_update = client.display_feed_info(feed, last_update, url)
            if current_last_update and (new_last_update is None or current_last_update > new_last_update):
                new_last_update = current_last_update

        if new_last_update:
            client.save_last_update_time(new_last_update)
            client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
        else:
            client.logger.info("没有获取到新内容")

    except Exception as e:
        # loguru ignores ``exc_info``; opt(exception=True) logs the traceback
        # and "{}" keeps braces in the error text from breaking formatting.
        logger.opt(exception=True).error("程序运行出错: {}", e)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NewsAPIClient.main()
|
||||
@@ -0,0 +1,44 @@
|
||||
import os
|
||||
|
||||
|
||||
class Config:
    """Static application settings.

    Secrets are read from environment variables. The database and MinIO
    passwords keep their previous development literals as fallbacks. The
    Baidu API key has no fallback: the key that used to be committed here
    must be treated as leaked and rotated, and a valid key must be supplied
    through ``BAIDU_API_KEY``.
    """

    # Remote MySQL instance.
    MYSQL_CONFIG = {
        'host': '123.60.167.249',
        'port': 3306,
        'user': 'intelligence',
        'password': os.getenv('MYSQL_PASSWORD', '123123'),
        'database': "intelligence_system",
        'max_connections': 10
    }

    # Local MySQL instance used when working offline.
    OFFLINE_MYSQL_CONFIG = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': os.getenv('OFFLINE_MYSQL_PASSWORD', '123123'),
        'database': "intelligence_system",
        'max_connections': 10
    }

    # MinIO object storage.
    MINIO_CONFIG = {
        'endpoint': '127.0.0.1:9005',
        'access_key': 'admin',
        'secret_key': os.getenv('MINIO_SECRET_KEY', 'abc88888888'),
        'secure': False  # the community edition does not enable SSL by default
    }

    # Baidu AI API (Qianfan platform). The key comes only from the environment.
    BAIDU_AI_CONFIG = {
        'api_key': os.getenv('BAIDU_API_KEY', ''),  # Baidu Qianfan API key
        'model': 'ernie-x1-turbo-32k',  # model to use
    }

    # AI processor settings.
    AI_PROCESSOR_CONFIG = {
        'batch_size': 10,  # default batch size
        'delay': 1.5,  # delay between records in seconds, to avoid API rate limits
        'source_table': 'processed_rss_data',  # source data table
        'result_table': 'ai_processor_rss_analysis',  # AI analysis result table
    }
|
||||
@@ -1,273 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
配置初始化模块
|
||||
功能:
|
||||
1. 自动生成默认配置文件
|
||||
2. 多环境配置支持(dev/test/prod)
|
||||
3. 敏感信息加密存储
|
||||
4. 配置完整性检查与修复
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from cryptography.fernet import Fernet
|
||||
import hashlib
|
||||
|
||||
# 初始化日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger('config_init')
|
||||
|
||||
class ConfigInitializer:
|
||||
"""配置初始化工具类"""
|
||||
|
||||
def __init__(self, app_name: str = "intelligence_system"):
|
||||
self.system = platform.system().lower()
|
||||
self.app_name = app_name
|
||||
self.config_dir = self._get_config_dir()
|
||||
self.config_file = self.config_dir / "config.json"
|
||||
self.secret_key_file = self.config_dir / ".secret.key"
|
||||
self._fernet = None
|
||||
|
||||
# 确保配置目录存在
|
||||
self.config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 设置文件权限(非Windows)
|
||||
if self.system != 'windows':
|
||||
os.chmod(self.config_dir, 0o700)
|
||||
|
||||
def _get_config_dir(self) -> Path:
|
||||
"""获取适合当前平台的配置目录路径"""
|
||||
if self.system == 'windows':
|
||||
return Path(os.environ['APPDATA']) / self.app_name
|
||||
elif self.system == 'darwin': # macOS
|
||||
return Path.home() / "Library" / "Application Support" / self.app_name
|
||||
else: # Linux及其他Unix-like
|
||||
xdg_config = os.getenv('XDG_CONFIG_HOME', '~/.config')
|
||||
return Path(xdg_config).expanduser() / self.app_name
|
||||
|
||||
def _init_encryption(self):
    """Load the Fernet key, generating and storing one on first use."""
    key_path = self.secret_key_file
    if not key_path.exists():
        key_path.write_bytes(Fernet.generate_key())
        if self.system != 'windows':
            key_path.chmod(0o600)  # owner read/write only

    self._fernet = Fernet(key_path.read_bytes())
|
||||
|
||||
def encrypt_value(self, plaintext: str) -> str:
|
||||
"""加密敏感信息"""
|
||||
if not self._fernet:
|
||||
self._init_encryption()
|
||||
return self._fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
def decrypt_value(self, ciphertext: str) -> str:
|
||||
"""解密信息"""
|
||||
if not self._fernet:
|
||||
self._init_encryption()
|
||||
return self._fernet.decrypt(ciphertext.encode()).decode()
|
||||
|
||||
def _get_default_config(self) -> Dict[str, Any]:
|
||||
"""获取默认配置模板"""
|
||||
return {
|
||||
"system": {
|
||||
"env": "dev", # dev/test/prod
|
||||
"log_level": "INFO",
|
||||
"max_threads": max(1, os.cpu_count() or 4),
|
||||
"data_dir": str(self.config_dir / "data")
|
||||
},
|
||||
"api": {
|
||||
"newsapi": {
|
||||
"endpoint": "https://newsapi.org/v2",
|
||||
"key": "" # 需加密存储
|
||||
},
|
||||
"weibo": {
|
||||
"version": "2",
|
||||
"access_token": "" # 需加密存储
|
||||
}
|
||||
},
|
||||
"database": {
|
||||
"type": "sqlite",
|
||||
"path": str(self.config_dir / "data.db")
|
||||
},
|
||||
"network": {
|
||||
"timeout": 30,
|
||||
"retries": 3,
|
||||
"proxy": "" # 示例: http://user:pass@proxy:port
|
||||
}
|
||||
}
|
||||
|
||||
def _migrate_old_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""旧配置迁移(兼容性处理)"""
|
||||
# 示例:将旧版api_key迁移到新版结构
|
||||
if 'api_key' in config:
|
||||
config.setdefault('api', {})['newsapi'] = {
|
||||
'key': config.pop('api_key')
|
||||
}
|
||||
return config
|
||||
|
||||
def _validate_config(self, config: Dict[str, Any]) -> bool:
|
||||
"""验证配置完整性"""
|
||||
required_keys = {
|
||||
"system": ["env", "log_level"],
|
||||
"api/newsapi": ["endpoint"]
|
||||
}
|
||||
|
||||
for path, keys in required_keys.items():
|
||||
current = config
|
||||
for part in path.split('/'):
|
||||
current = current.get(part, {})
|
||||
if not isinstance(current, dict):
|
||||
return False
|
||||
|
||||
for key in keys:
|
||||
if key not in current:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _repair_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""自动修复缺失的配置项"""
|
||||
default_config = self._get_default_config()
|
||||
|
||||
def _merge(current, default):
|
||||
for key, value in default.items():
|
||||
if key not in current:
|
||||
current[key] = value
|
||||
elif isinstance(value, dict):
|
||||
_merge(current[key], value)
|
||||
return current
|
||||
|
||||
return _merge(config, default_config)
|
||||
|
||||
def init_config(self, force: bool = False) -> bool:
    """Create, or load and repair, the configuration file, then save it.

    An existing file is loaded, migrated and repaired unless ``force`` is
    true or loading fails; in those cases the default template is used.
    Sensitive fields are encrypted before saving. A field whose value
    already decrypts with the current key is left untouched, so repeated
    calls do not encrypt it again.

    Args:
        force: Discard any existing file and regenerate from defaults.

    Returns:
        True when a new configuration was created from the defaults, False
        when an existing one was kept (possibly repaired).
    """
    from cryptography.fernet import InvalidToken  # local: file imports only Fernet

    config = None

    # Reuse the existing file unless a reset was requested.
    if self.config_file.exists() and not force:
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Migrate legacy keys, then fill in anything missing.
            config = self._migrate_old_config(config)
            if not self._validate_config(config):
                config = self._repair_config(config)
                logger.warning("自动修复不完整的配置文件")

        except Exception as e:
            logger.error(f"加载现有配置失败: {str(e)}")
            config = None

    created = config is None
    if created:
        config = self._get_default_config()
        logger.info("创建新的配置文件")

    # Encrypt sensitive fields that still hold plaintext.
    self._init_encryption()
    for field in (
        "api/newsapi/key",
        "api/weibo/access_token",
        "network/proxy",
    ):
        *parents, leaf = field.split('/')
        node = config
        for part in parents:
            node = node.setdefault(part, {})

        value = node.get(leaf)
        if not value or not isinstance(value, str):
            continue
        try:
            # Decrypts cleanly: already encrypted, so leave it alone.
            self.decrypt_value(value)
        except InvalidToken:
            node[leaf] = self.encrypt_value(value)

    # Save the configuration.
    with open(self.config_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2, ensure_ascii=False)

    # Restrict the file to the owner on non-Windows systems.
    if self.system != 'windows':
        os.chmod(self.config_file, 0o600)

    return created
|
||||
|
||||
def get_config_hash(self) -> str:
|
||||
"""获取配置文件哈希值(用于检测变更)"""
|
||||
if not self.config_file.exists():
|
||||
return ""
|
||||
|
||||
with open(self.config_file, 'rb') as f:
|
||||
return hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
def create_env_specific_config(self, env: str = None) -> bool:
|
||||
"""
|
||||
创建环境特定配置
|
||||
参数:
|
||||
env: 环境类型(dev/test/prod)
|
||||
"""
|
||||
if not self.config_file.exists():
|
||||
self.init_config()
|
||||
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
base_config = json.load(f)
|
||||
|
||||
env = env or base_config['system']['env']
|
||||
env_config = {
|
||||
f"env_{env}": {
|
||||
"api": {
|
||||
"newsapi": {"endpoint": self._get_env_endpoint(env)}
|
||||
},
|
||||
"database": {
|
||||
"path": str(self.config_dir / f"data_{env}.db")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
env_file = self.config_dir / f"config.{env}.json"
|
||||
with open(env_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(env_config, f, indent=2)
|
||||
|
||||
return True
|
||||
|
||||
def _get_env_endpoint(self, env: str) -> str:
|
||||
"""获取环境特定的API端点"""
|
||||
endpoints = {
|
||||
"dev": "http://dev-api.example.com",
|
||||
"test": "https://test-api.example.com",
|
||||
"prod": "https://api.example.com"
|
||||
}
|
||||
return endpoints.get(env, endpoints['dev'])
|
||||
|
||||
# 快捷初始化函数
|
||||
def init_app_config(app_name: str = None, force: bool = False) -> bool:
    """Initialise the application configuration in one call.

    Args:
        app_name: Application name. When omitted or empty, the default name
            of ``ConfigInitializer`` is used.
        force: Regenerate the configuration even if one already exists.

    Returns:
        The result of ``ConfigInitializer.init_config``.
    """
    # Passing None through would override the constructor's default name and
    # fail when the config path is built.
    initializer = ConfigInitializer(app_name) if app_name else ConfigInitializer()
    return initializer.init_config(force)
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
    # Manual smoke test: create the base config, a prod overlay, and run an
    # encrypt/decrypt round trip.
    demo = ConfigInitializer()
    created = demo.init_config()
    if created:
        print("配置文件已生成:", demo.config_file)

    # Example environment-specific configuration.
    demo.create_env_specific_config("prod")
    print("生产环境配置已生成")

    # Encryption demonstration.
    token = demo.encrypt_value("my_secret_key")
    print("加密示例:", token)
    print("解密测试:", demo.decrypt_value(token))
|
||||
@@ -1,409 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
from pymysql import cursors
|
||||
from pymysql.err import MySQLError
|
||||
from dbutils.pooled_db import PooledDB
|
||||
from typing import Union, List, Dict, Any, Optional, Tuple
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 导入您的日志系统
|
||||
from utils.logger import log as logger
|
||||
|
||||
class MySQLAgent:
|
||||
"""
|
||||
全平台兼容的MySQL数据库操作类
|
||||
支持Windows/macOS/Linux系统
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
# 各平台特定的配置
|
||||
PLATFORM_CONFIG = {
|
||||
'Windows': {
|
||||
'socket_timeout': 30,
|
||||
'connect_timeout': 10,
|
||||
'ssl': None
|
||||
},
|
||||
'Darwin': { # macOS
|
||||
'socket_timeout': 60,
|
||||
'connect_timeout': 15,
|
||||
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||
},
|
||||
'Linux': {
|
||||
'socket_timeout': 60,
|
||||
'connect_timeout': 15,
|
||||
'ssl': None
|
||||
}
|
||||
}
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict = None):
    """Build the connection settings and create the connection pool.

    Because the class hands out one shared instance, a call on an instance
    that already has a pool returns immediately.

    Args:
        config: Connection settings (host, port, user, password, database,
            charset, max_connections and optional ``connect_timeout`` /
            ``read_timeout`` / ``write_timeout``). When falsy,
            ``config.settings.DATABASE_CONFIG`` is used.
    """
    if hasattr(self, '_pool') and self._pool:
        return

    if not config:
        from config.settings import DATABASE_CONFIG
        config = DATABASE_CONFIG

    current_platform = platform.system()

    # The logger must exist before _create_pool(), which logs through it.
    self.logger = logger.bind(module=f"MySQLAgent({current_platform})")

    # Copy so the shared class-level table is never mutated.
    platform_config = dict(self.PLATFORM_CONFIG.get(current_platform, {}))
    # ``socket_timeout`` is not a PyMySQL connect argument; express it through
    # the read/write timeouts that PyMySQL does accept.
    socket_timeout = platform_config.pop('socket_timeout', None)
    if socket_timeout is not None:
        platform_config.setdefault('read_timeout', socket_timeout)
        platform_config.setdefault('write_timeout', socket_timeout)

    # Base settings; platform-specific values are merged last.
    self.config = {
        'host': config.get('host', 'localhost'),
        'port': config.get('port', 3306),
        'user': config.get('user', 'root'),
        'password': config.get('password', ''),
        'database': config.get('database', 'intelligence_system'),
        'charset': config.get('charset', 'utf8mb4'),
        'cursorclass': cursors.DictCursor,
        'autocommit': True,
        **platform_config
    }

    # Timeouts given explicitly by the caller win over platform defaults.
    for timeout_key in ('connect_timeout', 'read_timeout', 'write_timeout'):
        if timeout_key in config:
            self.config[timeout_key] = config[timeout_key]

    if current_platform == 'Windows':
        self.config['ssl'] = None  # no SSL settings are used on Windows
    elif current_platform == 'Darwin':
        # Drop the SSL settings when the expected CA bundle is absent.
        ssl_settings = self.config.get('ssl') or {}
        if not os.path.exists(ssl_settings.get('ca', '')):
            self.config['ssl'] = None
            logger.warning("macOS SSL certificate not found, disabling SSL")

    self.pool_size = config.get('max_connections', 5)
    self._pool = self._create_pool()
|
||||
|
||||
def _create_pool(self) -> PooledDB:
    """Create the ``PooledDB`` pool from ``self.config`` and ``self.pool_size``.

    Raises:
        Exception: Whatever pool construction raises, after logging it.
    """
    try:
        pool_config = dict(
            creator=pymysql,
            maxconnections=self.pool_size,
            mincached=1,
            maxcached=3,
            blocking=True,
            ping=1,  # connection liveness check setting passed to PooledDB
        )
        # Connection settings are applied on top of the pool settings.
        pool_config.update(self.config)

        # The liveness check is switched off on Windows.
        if platform.system() == 'Windows':
            pool_config['ping'] = 0

        new_pool = PooledDB(**pool_config)
        self.logger.info(f"Connection pool created for {platform.system()}")
        return new_pool

    except Exception as exc:
        self.logger.critical("Failed to create connection pool",
                             error=str(exc),
                             exc_info=True)
        raise
|
||||
|
||||
def _handle_path(self, path: str) -> str:
|
||||
"""处理跨平台路径问题"""
|
||||
if platform.system() == 'Windows':
|
||||
return path.replace('/', '\\')
|
||||
return path
|
||||
|
||||
def get_connection(self) -> pymysql.connections.Connection:
|
||||
"""
|
||||
获取数据库连接(跨平台兼容)
|
||||
|
||||
Returns:
|
||||
pymysql.connections.Connection: 数据库连接
|
||||
|
||||
Raises:
|
||||
MySQLError: 如果连接失败
|
||||
"""
|
||||
try:
|
||||
conn = self._pool.connection()
|
||||
|
||||
# macOS需要特殊处理SSL
|
||||
if platform.system() == 'Darwin' and self.config.get('ssl'):
|
||||
conn.ping(reconnect=True)
|
||||
|
||||
self.logger.trace("Connection obtained")
|
||||
return conn
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
|
||||
# Windows特定错误处理
|
||||
if platform.system() == 'Windows' and "timed out" in error_msg:
|
||||
self.logger.warning("Windows connection timeout, retrying...")
|
||||
return self._retry_connection()
|
||||
|
||||
self.logger.error("Connection failed",
|
||||
error=error_msg,
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
|
||||
"""Windows平台连接重试机制"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
conn = self._pool.connection()
|
||||
self.logger.info(f"Connection established after {attempt+1} attempts")
|
||||
return conn
|
||||
except Exception:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
    """Run a query and return the result as a DataFrame.

    Args:
        sql: SQL statement.
        params: Parameters bound to the statement.
        parse_dates: Forwarded to ``pandas.read_sql``.

    Returns:
        The query result.

    Raises:
        Exception: Any failure, after logging it.
    """
    try:
        with self.get_connection() as conn:
            # Outside Windows, raise the session wait timeout first. The
            # cursor used for this is closed explicitly so it is not leaked.
            if platform.system() != 'Windows':
                setup_cursor = conn.cursor()
                try:
                    setup_cursor.execute("SET SESSION wait_timeout=600")
                finally:
                    setup_cursor.close()

            df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)

            self.logger.info("Query executed", rows=len(df))
            return df

    except Exception as e:
        self.logger.error("Query failed",
                          sql=sql,
                          params=params,
                          error=str(e),
                          exc_info=True)
        raise
|
||||
|
||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
chunk_size: int = 1000, replace: bool = False) -> int:
|
||||
"""
|
||||
跨平台数据插入
|
||||
|
||||
Args:
|
||||
table_name (str): 表名
|
||||
df (pd.DataFrame): 数据
|
||||
chunk_size (int): 分批大小
|
||||
replace (bool): 是否替换
|
||||
|
||||
Returns:
|
||||
int: 插入行数
|
||||
"""
|
||||
if df.empty:
|
||||
self.logger.warning("Empty DataFrame", table=table_name)
|
||||
return 0
|
||||
|
||||
try:
|
||||
method = 'replace' if replace else 'append'
|
||||
total_rows = 0
|
||||
|
||||
with self.get_connection() as conn:
|
||||
# 各平台不同的分批策略
|
||||
if platform.system() == 'Windows':
|
||||
chunk_size = min(chunk_size, 500) # Windows上减小批次
|
||||
|
||||
for i in range(0, len(df), chunk_size):
|
||||
chunk = df.iloc[i:i + chunk_size]
|
||||
|
||||
# macOS需要特殊处理datetime
|
||||
if platform.system() == 'Darwin':
|
||||
for col in chunk.select_dtypes(include=['datetime64']):
|
||||
chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
chunk.to_sql(
|
||||
table_name,
|
||||
conn,
|
||||
if_exists=method,
|
||||
index=False,
|
||||
method='multi'
|
||||
)
|
||||
total_rows += len(chunk)
|
||||
method = 'append'
|
||||
|
||||
self.logger.info("Data inserted", table=table_name, rows=total_rows)
|
||||
return total_rows
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Insert failed",
|
||||
table=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||
"""
|
||||
跨平台SQL执行
|
||||
|
||||
Args:
|
||||
sql (str): SQL语句
|
||||
params (Union[tuple, dict, None]): 参数
|
||||
fetch (bool): 是否获取结果
|
||||
|
||||
Returns:
|
||||
Union[int, List[Dict[str, Any]]]: 结果
|
||||
"""
|
||||
conn = None
|
||||
cursor = None
|
||||
try:
|
||||
conn = self.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Linux/macOS需要更长的执行时间
|
||||
if platform.system() != 'Windows':
|
||||
cursor.execute("SET SESSION max_execution_time=600000")
|
||||
|
||||
cursor.execute(sql, params)
|
||||
|
||||
if fetch:
|
||||
result = cursor.fetchall()
|
||||
self.logger.debug("Query executed", rows=len(result))
|
||||
return result
|
||||
else:
|
||||
affected_rows = cursor.rowcount
|
||||
self.logger.debug("Update executed", affected_rows=affected_rows)
|
||||
return affected_rows
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("SQL execution failed",
|
||||
sql=sql,
|
||||
params=params,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
def begin_transaction(self) -> pymysql.connections.Connection:
    """Open a connection with autocommit disabled and hand it to the caller.

    The caller owns the returned connection and is expected to finish with
    ``commit_transaction`` or ``rollback_transaction``, both of which close it.

    Returns:
        pymysql.connections.Connection: Connection with autocommit turned off.

    Raises:
        Exception: Any setup error is logged and re-raised; a connection that
            was already acquired is closed first so it does not leak.
    """
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)

        # macOS gets an explicit READ COMMITTED isolation level for the session
        if platform.system() == 'Darwin':
            cursor = conn.cursor()
            try:
                cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
            finally:
                # The original left this helper cursor open
                cursor.close()

        self.logger.debug("Transaction started")
        return conn
    except Exception as e:
        self.logger.error("Begin transaction failed", error=str(e))
        # Nobody else holds a reference to the half-initialised connection,
        # so release it here before propagating the error.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_error:
                self.logger.warning("Failed to close connection after begin failure",
                                    error=str(close_error))
        raise
|
||||
|
||||
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Commit the transaction on ``conn`` and close the connection.

    Args:
        conn: Connection whose pending work should be committed.

    Raises:
        Exception: A commit failure is logged and re-raised; the connection is
            closed either way.
    """
    try:
        conn.commit()
        self.logger.debug("Transaction committed")
    except Exception as exc:
        failure_text = str(exc)
        self.logger.error("Commit failed", error=failure_text)
        raise
    finally:
        # The connection is single-use: it is released on success and failure
        conn.close()
|
||||
|
||||
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Roll back the transaction on ``conn`` and close the connection.

    A rollback failure is logged but not re-raised; the connection is closed
    in every case.

    Args:
        conn: Connection whose pending work should be discarded.
    """
    try:
        conn.rollback()
        self.logger.warning("Transaction rolled back")
    except Exception as exc:
        failure_text = str(exc)
        self.logger.error("Rollback failed", error=failure_text)
    finally:
        conn.close()
|
||||
|
||||
def __del__(self):
    """Destructor: close the connection pool if this instance has one."""
    # Nothing to clean up when no `_pool` attribute was ever set
    if not hasattr(self, '_pool'):
        return

    try:
        self._pool.close()
        self.logger.info("Connection pool closed")
    except Exception as exc:
        # Errors are only logged, never propagated out of the destructor
        self.logger.error("Failed to close pool", error=str(exc))
|
||||
|
||||
|
||||
# 平台特定的默认配置
|
||||
def get_default_config():
    """Return the default connection settings for the current platform.

    Every platform shares the same host/port/user/password/database/pool
    values. Windows uses 10/30/30 second connect/read/write timeouts; all
    other platforms use 15/60/60, and macOS additionally gets an ``ssl`` entry
    with a CA bundle path.

    Returns:
        dict: A new configuration dictionary on every call.
    """
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '',
        'database': 'intelligence_system',
        'max_connections': 5,
    }

    system_name = platform.system()

    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and every other platform share the longer timeouts
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
    # Pick the defaults that match the platform this script runs on
    config = get_default_config()

    # Build the database agent from those defaults
    db = MySQLAgent(config)

    # Smoke test: read the server version and report the host platform
    try:
        df = db.query_to_df("SELECT VERSION() as version")
        server_version = df['version'].iloc[0]
        print(f"Database version: {server_version}")
        print(f"Running on: {platform.system()} {platform.release()}")
    except Exception as exc:
        print(f"Error: {exc}")
|
||||
-120
@@ -1,120 +0,0 @@
|
||||
## 情报收集系统设计
|
||||
|
||||
### 参考文档
|
||||
https://alidocs.dingtalk.com/i/nodes/NZQYprEoWoexdo1ohPdxXvDbJ1waOeDk?utm_scene=team_space
|
||||
|
||||
### 程序框架
|
||||
```angular2html
|
||||
intelligence_system/
|
||||
├── config/ # 系统配置中心
|
||||
│ ├── __init__.py # 配置包初始化
|
||||
│ ├── settings.py # 主配置文件(数据库连接、API密钥等)
|
||||
│ └── scheduler_rules.yaml # 任务调度规则
|
||||
|
||||
├── data_collection/ # 数据采集层
|
||||
│ ├── spiders/ # 网络爬虫子系统
|
||||
│ │ ├── weibo_spider.py # 黑猫爬虫
|
||||
│ │
|
||||
│ ├── api_integration/ # API接口子系统
|
||||
│ │ ├── news_api.py # 新闻接口
|
||||
│ │
|
||||
│ └── internal/ # 内部数据收集
|
||||
│ ├── jian_dao_cloud.py # 简道云表单收集器
|
||||
|
||||
├── data_processing/ # 数据处理层
|
||||
│ ├── structured/ # 结构化数据处理
|
||||
│ │ ├── data_cleaner.py # 数据清洗(去重/标准化)
|
||||
│ │ └── schema_mapper.py # 数据结构转换器
|
||||
│ │
|
||||
│ ├── unstructured/ # 非结构化数据处理
|
||||
│ │ ├── text_parser.py # 文本解析(PDF/HTML等)
|
||||
│ │ ├── image_analyzer.py # 图像识别(OpenCV集成)
|
||||
│ │ └── video_processor.py # 音视频分离分析
|
||||
│ │
|
||||
│ └── ai_engine/ # AI分析核心
|
||||
│ ├── nlp_processor.py # 自然语言处理引擎
|
||||
│ ├── sentiment_analyzer.py # 情感分析模型
|
||||
│ └── topic_modeler.py # LDA主题建模工具
|
||||
|
||||
├── storage/ # 数据存储层
|
||||
│ ├── mysql_agent.py # MySQL读写管理器
|
||||
│ └── query_builder.py # SQL动态构建器
|
||||
|
||||
├── services/ # 应用服务层
|
||||
│ ├── monitoring/ # 舆情监控
|
||||
│ │ ├── opinion_monitor.py # 实时舆情追踪
|
||||
│ │ └── brand_reputation.py # 品牌口碑分析
|
||||
│ │
|
||||
│ ├── analysis/ # 竞品分析
|
||||
│ │ ├── competitor_tracker.py # 竞品动态监控
|
||||
│ │ └── swot_generator.py # SWOT分析报告
|
||||
│ │
|
||||
│ ├── reporting/ # 报告服务
|
||||
│ │ ├── daily_reporter.py # 自动化日报生成
|
||||
│ │ └── weekly_digest.py # 周报汇编系统
|
||||
│ │
|
||||
│ └── alert/ # 预警服务
|
||||
│ ├── alert_trigger.py # 动态阈值告警
|
||||
│ └── notification_center.py # 邮件/短信通知
|
||||
|
||||
├── system_management/ # 系统管理层
|
||||
│ ├── scheduler/ # 任务调度
|
||||
│ │ └── task_scheduler.py # 任务调度器
|
||||
│ │
|
||||
│ └── monitor/ # 系统监控
|
||||
│ ├── health_monitor.py # 服务健康检测
|
||||
│ └── performance_watcher.py # 资源占用监控
|
||||
|
||||
├── utils/ # 工具库
|
||||
│ ├── file_handler.py # 通用文件操作
|
||||
│ ├── logger.py # 日志系统
|
||||
│ └── datetime_parser.py # 时间格式处理
|
||||
|
||||
└── main.py # 系统入口(启动所有服务)
|
||||
```
|
||||
### 程序设计原则
|
||||
1. 所有程序尽可能在py文件中运行,尽量避免使用命令行执行
|
||||
2. 配置需要在配置类中定义
|
||||
3. 密钥等信息直接放在配置类中
|
||||
|
||||
### 主程序设计
|
||||
主程序需要一次启动,一直运行,启动时运行一次(在代码中可取消),之后每天定时生成一次报告
|
||||
|
||||
主程序包含爬虫/api调度器。该调度器通过查询mysql中任务调度情况按需执行,db文件中应包含任务名称、
|
||||
任务路径、任务执行频率(支持按天、按周,按分钟)、上次执行时间、下次执行时间等信息
|
||||
|
||||
主程序应包含数据处理调度器,根据数据类别分别处理,如文本数据处理调度器、图片数据处理调度器等,
|
||||
每天定时拉取db获取到的原始数据,分别进行处理,处理完成后将结果保存到mysql中
|
||||
|
||||
主程序应包含日报、周报等生成,根据时间定时生成报告,报告需要存储
|
||||
|
||||
### 日志设计
|
||||
日志系统应兼容多个平台,如win、mac和linux,日志需要保存为log文件,并且在日志大于20mb时自动压缩
|
||||
|
||||
### 数据库链接设计
|
||||
数据存储放在数据库中,数据库类型为mysql,数据库名称为intelligence_system
|
||||
|
||||
数据库表的命名规则与目录一致,数据采集类以collector_为开头,数据处理类以processor_为开
|
||||
头,数据存储类以storage_为开头,应用层类以application_为开头
|
||||
依次类推。
|
||||
|
||||
数据库链接为通用配置,要求数据采集或处理类等,可以直接调用封装好的数据库
|
||||
链接,不必每次都重新写。
|
||||
该链接包含表的增删改查功能,以及执行sql语句功能
|
||||
|
||||
数据库结构:
|
||||
1. collector_news_api:新闻api数据表
|
||||
2. collector_complaint_spider:投诉数据表
|
||||
3. processor_text_processor:文本处理数据表
|
||||
4. processor_image_processor:图片处理数据表
|
||||
5. main_task 任务调度表
|
||||
6. application_reporter_daily:日报数据表
|
||||
7. application_reporter_monthly:周报数据表
|
||||
|
||||
### 数据采集设计
|
||||
每一个数据采集均为独立python文件,里面执行主程序均为main,以方便调度
|
||||
每一个数据采集均会根据规则创建数据库表,数据采集类的表以collector_为开头(或者统一维护到一个表中,按来源区分)
|
||||
|
||||
### 数据处理
|
||||
从多个数据库库表中获取数据,对数据进行处理,处理完成后将结果保存到数据库中,处理结果可能存储在多个表中
|
||||
数据处理数据库表以processor_为开头
|
||||
@@ -0,0 +1,27 @@
|
||||
# 列出所有任务
|
||||
python system_management/scheduler/task_management.py list
|
||||
|
||||
# 只显示活跃任务
|
||||
python system_management/scheduler/task_management.py list --active-only
|
||||
|
||||
# 查看任务详情
|
||||
python system_management/scheduler/task_management.py show 1
|
||||
|
||||
# 更新任务Cron表达式
|
||||
python system_management/scheduler/task_management.py update <task_id> --cron "0 10 * * *"
|
||||
|
||||
# 启用任务
|
||||
python system_management/scheduler/task_management.py toggle 1 --activate
|
||||
|
||||
# 禁用任务
|
||||
python system_management/scheduler/task_management.py toggle 1 --deactivate
|
||||
|
||||
# 手动执行任务
|
||||
python system_management/scheduler/task_management.py run 1
|
||||
|
||||
# 添加新任务
|
||||
python system_management/scheduler/task_management.py add \
|
||||
--name "hourly_data_check" \
|
||||
--type "processor" \
|
||||
--module "processors.data_checker" \
|
||||
--cron "0 * * * *"
|
||||
@@ -0,0 +1,292 @@
|
||||
# 对象存储数据库操作.md
|
||||
|
||||
## 1. 类概述
|
||||
|
||||
`MinIOAgent` 是一个全平台兼容的对象存储操作类,支持 Windows/macOS/Linux 系统,提供对象存储的桶管理、对象操作、权限控制等功能。
|
||||
|
||||
### 核心特性:
|
||||
|
||||
- ✅ 连接池管理与自动重连
|
||||
- ✅ 全平台兼容的对象操作接口
|
||||
- ✅ 支持大文件分块上传/下载
|
||||
- ✅ 预签名 URL 生成(临时访问)
|
||||
- ✅ 完善的日志记录与错误处理
|
||||
- ✅ 批量操作与前缀筛选
|
||||
|
||||
---
|
||||
|
||||
## 2. 初始化配置
|
||||
|
||||
### 基本配置参数
|
||||
```python
|
||||
Config = {
|
||||
'endpoint': '127.0.0.1:9005', # 对象存储服务地址
|
||||
'access_key': 'minioadmin', # 访问密钥
|
||||
'secret_key': 'minioadmin', # 密钥
|
||||
'secure': False, # 是否启用SSL(社区版默认False)
|
||||
'region': 'us-east-1', # 区域(默认值)
|
||||
'timeout': 300, # 超时时间(秒)
|
||||
'max_pool_connections': 10 # 连接池最大连接数
|
||||
}
|
||||
```
|
||||
|
||||
### 各平台特殊配置
|
||||
| 平台 | 超时设置(秒) | 分块大小建议 | 并发数建议 |
|
||||
|---------|----------------|--------------|------------|
|
||||
| Windows | 300 | 5MB-10MB | 2-4 |
|
||||
| macOS | 300 | 10MB-20MB | 4-8 |
|
||||
| Linux | 300 | 20MB-50MB | 8-16 |
|
||||
|
||||
### 初始化示例
|
||||
```python
|
||||
from utils.minio_agent import MinIOAgent
|
||||
|
||||
# 基础初始化
|
||||
config = {
|
||||
'endpoint': '127.0.0.1:9005',
|
||||
'access_key': 'minioadmin',
|
||||
'secret_key': 'minioadmin',
|
||||
'secure': False
|
||||
}
|
||||
|
||||
# 创建客户端实例
|
||||
minio_client = MinIOAgent(config)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. 桶(Bucket)管理
|
||||
|
||||
### 桶操作
|
||||
```python
|
||||
# 创建桶
|
||||
if minio_client.create_bucket('my-bucket'):
|
||||
print("桶创建成功")
|
||||
|
||||
# 检查桶是否存在
|
||||
if minio_client.bucket_exists('my-bucket'):
|
||||
print("桶已存在")
|
||||
|
||||
# 列出所有桶
|
||||
buckets = minio_client.list_buckets()
|
||||
for bucket in buckets:
|
||||
print(f"桶名称: {bucket['name']}, 创建时间: {bucket['creation_date']}")
|
||||
|
||||
# 删除桶(需先清空桶内对象)
|
||||
if minio_client.delete_bucket('my-bucket'):
|
||||
print("桶删除成功")
|
||||
```
|
||||
|
||||
### 桶策略管理
|
||||
```python
|
||||
# 获取桶策略
|
||||
policy = minio_client.get_bucket_policy('my-bucket')
|
||||
print(policy)
|
||||
|
||||
# 设置公共读策略
|
||||
public_read_policy = {
|
||||
"Version": "2012-10-17",
|
||||
"Statement": [{
|
||||
"Effect": "Allow",
|
||||
"Principal": "*",
|
||||
"Action": ["s3:GetObject"],
|
||||
"Resource": ["arn:aws:s3:::my-bucket/*"]
|
||||
}]
|
||||
}
|
||||
minio_client.set_bucket_policy('my-bucket', public_read_policy)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 对象(Object)操作
|
||||
|
||||
### 上传对象
|
||||
```python
|
||||
# 从文件上传
|
||||
upload_meta = minio_client.upload_file(
|
||||
bucket_name='my-bucket',
|
||||
object_name='documents/report.pdf',
|
||||
file_path='/local/path/to/report.pdf',
|
||||
content_type='application/pdf' # MIME类型
|
||||
)
|
||||
print(f"上传成功,大小: {upload_meta['size']} bytes")
|
||||
|
||||
# 从字节流上传
|
||||
data = b"test data"
|
||||
upload_meta = minio_client.upload_bytes(
|
||||
bucket_name='my-bucket',
|
||||
object_name='test/data.bin',
|
||||
data=data
|
||||
)
|
||||
|
||||
# 大文件分块上传
|
||||
upload_meta = minio_client.upload_large_file(
|
||||
bucket_name='my-bucket',
|
||||
object_name='videos/large_file.mp4',
|
||||
file_path='/local/path/to/large.mp4',
|
||||
part_size=5*1024*1024 # 5MB分块
|
||||
)
|
||||
```
|
||||
|
||||
### 下载对象
|
||||
```python
|
||||
# 下载到文件
|
||||
download_meta = minio_client.download_file(
|
||||
bucket_name='my-bucket',
|
||||
object_name='documents/report.pdf',
|
||||
file_path='/local/save/path/report.pdf'
|
||||
)
|
||||
|
||||
# 下载为字节流
|
||||
data = minio_client.download_bytes(
|
||||
bucket_name='my-bucket',
|
||||
object_name='test/data.bin'
|
||||
)
|
||||
print(f"下载数据: {data}")
|
||||
```
|
||||
|
||||
### 查询与列举对象
|
||||
```python
|
||||
# 列举桶内所有对象
|
||||
objects = minio_client.list_objects('my-bucket')
|
||||
for obj in objects:
|
||||
print(f"对象: {obj['object_name']}, 大小: {obj['size']}")
|
||||
|
||||
# 按前缀筛选(类似文件夹)
|
||||
pdf_files = minio_client.list_objects(
|
||||
bucket_name='my-bucket',
|
||||
prefix='documents/', # 前缀(类似文件夹路径)
|
||||
recursive=False # 是否递归查询子目录
|
||||
)
|
||||
|
||||
# 获取对象元信息
|
||||
meta = minio_client.get_object_metadata(
|
||||
bucket_name='my-bucket',
|
||||
object_name='documents/report.pdf'
|
||||
)
|
||||
print(f"内容类型: {meta['content_type']}, 最后修改: {meta['last_modified']}")
|
||||
```
|
||||
|
||||
### 删除对象
|
||||
```python
|
||||
# 删除单个对象
|
||||
if minio_client.delete_object('my-bucket', 'test/data.bin'):
|
||||
print("对象删除成功")
|
||||
|
||||
# 批量删除对象
|
||||
delete_count = minio_client.delete_objects(
|
||||
bucket_name='my-bucket',
|
||||
object_names=['file1.txt', 'file2.txt', 'docs/report.pdf']
|
||||
)
|
||||
print(f"成功删除 {delete_count} 个对象")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 高级功能
|
||||
|
||||
### 预签名 URL(临时访问)
|
||||
```python
|
||||
# 生成下载预签名URL(有效期30分钟)
|
||||
download_url = minio_client.get_presigned_url(
|
||||
bucket_name='my-bucket',
|
||||
object_name='documents/report.pdf',
|
||||
expires=1800, # 有效期(秒)
|
||||
method='GET' # 访问方法(GET下载,PUT上传)
|
||||
)
|
||||
print(f"临时下载链接: {download_url}")
|
||||
|
||||
# 生成上传预签名URL(允许客户端直接上传)
|
||||
upload_url = minio_client.get_presigned_url(
|
||||
bucket_name='my-bucket',
|
||||
object_name='user_uploads/image.jpg',
|
||||
expires=3600,
|
||||
method='PUT'
|
||||
)
|
||||
```
|
||||
|
||||
### 批量操作
|
||||
```python
|
||||
# 批量复制对象(同桶内)
|
||||
copy_results = minio_client.copy_objects(
|
||||
source_bucket='my-bucket',
|
||||
dest_bucket='my-bucket',
|
||||
object_mapping={
|
||||
'documents/report.pdf': 'archive/report_2024.pdf',
|
||||
'data/raw.csv': 'data/backup/raw_2024.csv'
|
||||
}
|
||||
)
|
||||
|
||||
# 批量移动对象(跨桶)
|
||||
move_results = minio_client.move_objects(
|
||||
source_bucket='my-bucket',
|
||||
dest_bucket='archive-bucket',
|
||||
object_prefix='2023/' # 移动所有以2023/为前缀的对象
|
||||
)
|
||||
```
|
||||
|
||||
### 生命周期管理
|
||||
```python
|
||||
# 设置对象生命周期规则(自动迁移/删除)
|
||||
rule = {
|
||||
"Rules": [{
|
||||
"ID": "archive-old-files",
|
||||
"Status": "Enabled",
|
||||
"Prefix": "logs/",
|
||||
"Expiration": {
|
||||
"Days": 90 # 90天后自动删除
|
||||
},
|
||||
"Transition": {
|
||||
"Days": 30, # 30天后迁移到低频存储
|
||||
"StorageClass": "STANDARD_IA"
|
||||
}
|
||||
}]
|
||||
}
|
||||
minio_client.set_bucket_lifecycle('my-bucket', rule)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 异常处理
|
||||
|
||||
```python
|
||||
from minio.error import S3Error
|
||||
|
||||
try:
|
||||
# 尝试上传对象
|
||||
minio_client.upload_file(
|
||||
bucket_name='my-bucket',
|
||||
object_name='critical/data.csv',
|
||||
file_path='/local/data.csv'
|
||||
)
|
||||
except S3Error as e:
|
||||
if e.code == 'NoSuchBucket':
|
||||
print("桶不存在,创建后重试")
|
||||
minio_client.create_bucket('my-bucket')
|
||||
elif e.code == 'AccessDenied':
|
||||
print("权限不足,请检查密钥")
|
||||
else:
|
||||
print(f"上传失败: {e}")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 性能优化建议
|
||||
|
||||
1. **大文件处理**:
|
||||
- 超过100MB的文件建议使用分块上传(`upload_large_file`)
|
||||
- 根据网络状况调整分块大小(5-50MB)
|
||||
|
||||
2. **批量操作**:
|
||||
- 列举对象时使用前缀筛选减少返回数据量
|
||||
- 批量删除/复制时单次操作不超过1000个对象
|
||||
|
||||
3. **缓存策略**:
|
||||
- 对频繁访问的对象使用预签名URL并设置合理过期时间
|
||||
- 客户端缓存对象元数据减少请求次数
|
||||
|
||||
4. **并发控制**:
|
||||
- 多线程操作时控制并发数(参考平台建议值)
|
||||
- 避免同时对同一对象进行写操作
|
||||
@@ -0,0 +1,2 @@
|
||||
## 开发进度
|
||||
###
|
||||
+1
-1
@@ -28,7 +28,7 @@
|
||||
|
||||
### 基本配置参数
|
||||
```python
|
||||
{
|
||||
Config = {
|
||||
'host': 'localhost', # 数据库主机
|
||||
'port': 3306, # 端口
|
||||
'user': 'root', # 用户名
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/logs" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="intelligence" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="intelligence_system" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
Binary file not shown.
Binary file not shown.
+133929
-108
File diff suppressed because it is too large
Load Diff
+71015
File diff suppressed because it is too large
Load Diff
@@ -1,111 +1,134 @@
|
||||
# main.py
|
||||
import signal
|
||||
import time
|
||||
from datetime import datetime
|
||||
from system_management.scheduler.task_scheduler import TaskScheduler
|
||||
from utils.logger import CrossPlatformLog
|
||||
from config import Config
|
||||
|
||||
|
||||
# 初始化日志
|
||||
log = CrossPlatformLog.get_logger("Main")
|
||||
|
||||
|
||||
class IntelligenceSystem:
|
||||
def __init__(self, db_config=None):
|
||||
self.scheduler = TaskScheduler(db_config)
|
||||
def __init__(self, db_config=None, run_all_on_startup=False):
|
||||
"""初始化系统(仅作为容器,不包含业务逻辑)
|
||||
|
||||
Args:
|
||||
db_config: 数据库配置
|
||||
run_all_on_startup: 启动时是否立即执行所有到期任务(默认False)
|
||||
"""
|
||||
self.scheduler = TaskScheduler(Config.MYSQL_CONFIG, max_workers=5)
|
||||
self._running = False
|
||||
log.info("IntelligenceSystem initialized")
|
||||
self.run_all_on_startup = run_all_on_startup
|
||||
log.info(f"情报系统已初始化(Cron模式),启动时执行任务: {run_all_on_startup}")
|
||||
|
||||
def run(self):
|
||||
"""启动系统主循环"""
|
||||
def start(self):
|
||||
"""启动系统主入口"""
|
||||
self._running = True
|
||||
self._register_signal_handlers()
|
||||
self._setup_signal_handlers()
|
||||
log.info("系统启动 - 运行在Cron调度模式")
|
||||
|
||||
# 启动时执行所有到期任务(如果开关开启)
|
||||
if self.run_all_on_startup:
|
||||
print(f"\n{'='*60}")
|
||||
print("🚀 启动时执行所有到期任务...")
|
||||
print(f"{'='*60}\n")
|
||||
log.info("启动时执行所有到期任务")
|
||||
result = self.scheduler.check_and_run_tasks(print_empty_status=True)
|
||||
print(f"\n启动任务执行完成: 总数={result['总任务数']}, 成功={result['成功']}, 失败={result['失败']}\n")
|
||||
|
||||
log.info("Starting main loop")
|
||||
# 时间追踪变量
|
||||
last_status_print_time = time.time() # 上次打印状态的时间
|
||||
last_hourly_report_time = time.time() # 上次小时统计的时间
|
||||
status_print_interval = 60 # 每分钟打印一次状态(60秒)
|
||||
hourly_report_interval = 3600 # 每小时统计一次(3600秒)
|
||||
|
||||
try:
|
||||
# 主循环 - 仅负责定期检查任务
|
||||
while self._running:
|
||||
start_time = time.time()
|
||||
self._run_cycle()
|
||||
current_time = time.time()
|
||||
|
||||
# 判断是否需要打印状态(每分钟一次)
|
||||
should_print_status = (current_time - last_status_print_time) >= status_print_interval
|
||||
|
||||
# 检查并执行到期任务
|
||||
self.scheduler.check_and_run_tasks(print_empty_status=should_print_status)
|
||||
|
||||
# 更新最后打印时间
|
||||
if should_print_status:
|
||||
last_status_print_time = current_time
|
||||
|
||||
# 检查是否需要进行小时统计(每小时一次)
|
||||
if (current_time - last_hourly_report_time) >= hourly_report_interval:
|
||||
self._print_hourly_stats()
|
||||
last_hourly_report_time = current_time
|
||||
|
||||
# 精确控制循环间隔(扣除执行时间)
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, 60 - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
# 短间隔轮询(每10秒检查一次,保证Cron时间精度)
|
||||
time.sleep(10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Received keyboard interrupt")
|
||||
except Exception as e:
|
||||
log.critical(
|
||||
"System crashed",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
log.critical("系统主循环崩溃", exc_info=True)
|
||||
finally:
|
||||
self.shutdown()
|
||||
|
||||
def _run_cycle(self):
|
||||
"""单个运行周期"""
|
||||
try:
|
||||
# 1. 执行任务调度
|
||||
result = self.scheduler.run_pending_tasks()
|
||||
|
||||
# 2. 每小时记录系统状态
|
||||
if datetime.now().minute == 0:
|
||||
self._log_system_status()
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Cycle execution failed",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def _log_system_status(self):
|
||||
"""记录系统状态"""
|
||||
try:
|
||||
status_df = self.scheduler.get_task_status()
|
||||
pending = len(status_df[status_df['next_run_time'] <= datetime.now()])
|
||||
|
||||
log.info(
|
||||
"System status",
|
||||
pending_tasks=pending,
|
||||
active_tasks=len(status_df),
|
||||
last_success=status_df['last_run_time'].max()
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Failed to log system status",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
def _register_signal_handlers(self):
|
||||
"""注册信号处理"""
|
||||
def _setup_signal_handlers(self):
|
||||
"""设置系统信号处理器"""
|
||||
signal.signal(signal.SIGINT, self._handle_shutdown)
|
||||
signal.signal(signal.SIGTERM, self._handle_shutdown)
|
||||
log.debug("Signal handlers registered")
|
||||
log.debug("信号处理器已注册")
|
||||
|
||||
def _handle_shutdown(self, signum, frame):
|
||||
"""处理关闭信号"""
|
||||
log.info(
|
||||
f"Processing shutdown signal {signum}",
|
||||
signal=signum
|
||||
)
|
||||
"""处理系统关闭信号"""
|
||||
log.info(f"收到关闭信号 {signum},开始关闭系统")
|
||||
self._running = False
|
||||
|
||||
def _print_hourly_stats(self):
|
||||
"""打印并重置小时统计信息"""
|
||||
stats = self.scheduler.get_and_reset_hourly_stats()
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"📊 小时任务统计报告 - {now}")
|
||||
print(f"{'='*60}")
|
||||
print(f" 总任务数: {stats['总数']}")
|
||||
print(f" 成功: {stats['成功']}")
|
||||
print(f" 失败: {stats['失败']}")
|
||||
if stats['总数'] > 0:
|
||||
success_rate = (stats['成功'] / stats['总数']) * 100
|
||||
print(f" 成功率: {success_rate:.1f}%")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
log.info(
|
||||
"小时任务统计",
|
||||
总任务数=stats['总数'],
|
||||
成功=stats['成功'],
|
||||
失败=stats['失败']
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
"""关闭系统"""
|
||||
log.info("Performing system shutdown")
|
||||
# 此处可添加其他清理逻辑
|
||||
log.success("System shutdown completed")
|
||||
"""优雅关闭系统"""
|
||||
log.info("开始优雅关闭系统")
|
||||
|
||||
# 等待所有正在执行的任务完成
|
||||
self.scheduler.executor.shutdown(wait=True, cancel_futures=False)
|
||||
|
||||
# 记录最终状态
|
||||
pending_count = self.scheduler.get_pending_tasks_count()
|
||||
log.info(
|
||||
"系统关闭完成",
|
||||
pending_tasks=pending_count,
|
||||
shutdown_time=datetime.now()
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
system = IntelligenceSystem()
|
||||
system.run()
|
||||
# 启动系统 - 仅作为入口,不包含调度逻辑
|
||||
# run_all_on_startup=True: 启动时立即执行所有到期任务
|
||||
# run_all_on_startup=False: 启动时不执行任务,等待下次调度周期
|
||||
system = IntelligenceSystem(run_all_on_startup=False)
|
||||
system.start()
|
||||
except Exception as e:
|
||||
log.critical(
|
||||
"System startup failed",
|
||||
exc_info=True
|
||||
)
|
||||
log.critical("情报系统启动失败", exc_info=True)
|
||||
raise
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,453 @@
|
||||
# RSS数据AI处理模块
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import pandas as pd
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from openai import OpenAI
|
||||
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(os.path.dirname(current_dir))
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
from utils.mysql_agent import MySQLAgent
|
||||
from utils.logger import log
|
||||
from config import Config
|
||||
|
||||
|
||||
class RSSDataAIProcessor:
|
||||
"""RSS数据AI处理主类
|
||||
|
||||
负责:
|
||||
- 从数据库加载未处理的RSS数据
|
||||
- 调用AI进行分析
|
||||
- 保存分析结果
|
||||
- 更新处理状态
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Set up logging, database access, processing defaults and the AI client.

    The AI client is only created when ``Config.BAIDU_AI_CONFIG`` provides an
    ``api_key``; otherwise ``self.ai_client`` is None and warnings are logged.
    """
    self.log = log.bind(module="RSSDataAIProcessor")
    self.db_agent = MySQLAgent(Config.MYSQL_CONFIG)

    # Table names and batch defaults come from the processor configuration
    processor_cfg = Config.AI_PROCESSOR_CONFIG
    self.source_table = processor_cfg['source_table']
    self.ai_table = processor_cfg['result_table']
    self.default_batch_size = processor_cfg['batch_size']
    self.default_delay = processor_cfg['delay']

    # The Baidu Qianfan endpoint is accessed through the OpenAI client class
    baidu_cfg = Config.BAIDU_AI_CONFIG
    self.api_key = baidu_cfg.get('api_key')
    if not self.api_key:
        self.ai_client = None
        self.log.warning("百度AI未配置,AI处理功能将不可用")
        self.log.warning("请在config.py中配置 BAIDU_AI_CONFIG['api_key']")
        return

    self.ai_client = OpenAI(
        base_url='https://qianfan.baidubce.com/v2',
        api_key=self.api_key
    )
    self.model = baidu_cfg.get('model', 'ernie-x1-turbo-32k')
    self.log.info("RSS数据AI处理器初始化完成")
|
||||
|
||||
def is_configured(self) -> bool:
    """Report whether an AI client was created for this processor."""
    client = self.ai_client
    return client is not None
|
||||
|
||||
def main(self, batch_size: Optional[int] = 200, delay: Optional[float] = None) -> Dict[str, Any]:
    """Run one complete batch: load unprocessed rows, analyse, save and mark them.

    Args:
        batch_size: Maximum number of rows to load. A falsy value (None or 0)
            falls back to ``self.default_batch_size``.
        delay: Seconds to pause between consecutive rows. None falls back to
            ``self.default_delay``; an explicit 0 disables the pause.

    Returns:
        dict: Always contains 'success', 'message', 'processed_count' and
        'failed_count'. When rows were handled it also contains
        'total_count', 'saved_count', 'relevant_count' and 'processing_time'.
        Errors are reported through 'success'/'message' instead of raising.
    """
    batch_size = batch_size or self.default_batch_size
    # `delay or default` would silently replace an explicit 0 with the default
    if delay is None:
        delay = self.default_delay

    try:
        # 1. Refuse to run without an AI client
        if not self.is_configured():
            error_msg = "百度AI未配置,请在config.py中配置 BAIDU_AI_CONFIG['api_key']"
            self.log.error(error_msg)
            return {
                'success': False,
                'message': error_msg,
                'processed_count': 0,
                'failed_count': 0
            }

        self.log.info(f"开始批量处理数据,批次大小: {batch_size}, 延迟: {delay}秒")

        # 2. Make sure the flag column and the result table exist
        self.ensure_ai_processed_column()
        if not self.db_agent.table_exists(self.ai_table):
            self.create_ai_result_table()

        # 3. Load the rows that still need analysis
        df = self.load_unprocessed_data(batch_size)
        if df.empty:
            self.log.info("没有需要处理的数据")
            return {
                'success': True,
                'message': '没有需要处理的数据',
                'processed_count': 0,
                'failed_count': 0
            }

        # 4. Analyse row by row
        results = []
        processed_ids = []
        failed_count = 0
        total = len(df)

        # enumerate() gives a real position; the iterrows() label is only a
        # position when the frame happens to carry a default 0..n-1 index.
        for position, (_, record) in enumerate(df.iterrows()):
            try:
                self.log.debug(f"处理记录 {record['id']} ({position + 1}/{total})")

                result = self.process_single_record(record.to_dict())

                if result:
                    results.append(result)
                    processed_ids.append(record['id'])
                else:
                    failed_count += 1

                # Pause between rows (not after the last one) to avoid API rate limits
                if delay > 0 and position < total - 1:
                    time.sleep(delay)

            except Exception as e:
                self.log.error(f"处理记录 {record['id']} 异常: {str(e)}", exc_info=True)
                failed_count += 1

        # 5. Persist the analysis results
        saved_count = 0
        if results:
            saved_count = self.save_ai_results(results)

        # 6. Flag rows as processed only when their results were stored.
        # save_ai_results returns 0 when the insert failed; flagging the rows
        # anyway would lose the analysis because they are never loaded again.
        # NOTE(review): assumes a 0 return with non-empty results means failure
        # rather than "every row was a duplicate" - confirm against insert_from_df.
        save_failed = bool(results) and saved_count == 0
        if save_failed:
            self.log.error("AI结果保存失败,未标记为已处理,记录将在下次运行时重试")
            failed_count += len(processed_ids)
            processed_ids = []
        elif processed_ids:
            self.mark_as_processed(processed_ids)

        # 7. Summarise the run
        stats = {
            'success': not save_failed,
            'message': 'AI结果保存失败' if save_failed else 'AI处理完成',
            'total_count': total,
            'processed_count': len(processed_ids),
            'saved_count': saved_count,
            'failed_count': failed_count,
            'relevant_count': sum(1 for r in results if r.get('是否相关')),
            'processing_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        self.log.info("批量处理完成", **stats)
        return stats

    except Exception as e:
        error_msg = f"批量处理失败: {str(e)}"
        self.log.error(error_msg, exc_info=True)
        return {
            'success': False,
            'message': error_msg,
            'processed_count': 0,
            'failed_count': 0
        }
|
||||
|
||||
def ensure_ai_processed_column(self):
    """Ensure the source table has the '是否ai处理' flag column, adding it if missing.

    Queries ``information_schema.COLUMNS`` for the column and issues an
    ``ALTER TABLE ... ADD COLUMN`` when the count is zero.

    Raises:
        Exception: Any failure while checking or altering is logged and re-raised.
    """
    try:
        # Count matching columns in the configured schema/table
        check_sql = """
            SELECT COUNT(*) as count
            FROM information_schema.COLUMNS
            WHERE TABLE_SCHEMA = %s
            AND TABLE_NAME = %s
            AND COLUMN_NAME = '是否ai处理'
        """

        result = self.db_agent.execute_sql(
            check_sql,
            params=(Config.MYSQL_CONFIG['database'], self.source_table),
            fetch=True
        )

        # Rows are mappings when the agent uses a dict cursor and sequences
        # otherwise; `result[0][0]` only works for the latter, so accept both.
        first_row = result[0]
        if isinstance(first_row, dict):
            column_count = first_row['count']
        else:
            column_count = first_row[0]

        if column_count == 0:
            # Column is missing: add it with a default of 0 (not processed)
            alter_sql = f"""
                ALTER TABLE {self.source_table}
                ADD COLUMN 是否ai处理 TINYINT(1) DEFAULT 0 COMMENT 'AI处理标记:0-未处理,1-已处理'
            """
            self.db_agent.execute_sql(alter_sql)
            self.log.info(f"成功为表 {self.source_table} 添加 '是否ai处理' 字段")
        else:
            self.log.debug(f"表 {self.source_table} 已存在 '是否ai处理' 字段")

    except Exception as e:
        self.log.error(f"检查/添加字段失败: {str(e)}", exc_info=True)
        raise
|
||||
|
||||
def create_ai_result_table(self):
    """Create the AI analysis result table if it does not exist yet.

    Raises:
        Exception: A failure to execute the DDL is logged and re-raised.
    """
    # One row per analysed article; source_id points back at the source table
    ddl = f"""
        CREATE TABLE IF NOT EXISTS {self.ai_table} (
            id INT AUTO_INCREMENT PRIMARY KEY COMMENT '主键ID',
            source_id INT NOT NULL COMMENT '来源数据ID(processed_rss_data.id)',
            文章标题 TEXT COMMENT '文章标题',
            文章摘要 TEXT COMMENT '文章摘要',
            发布时间 DATETIME COMMENT '发布时间',
            来源URL VARCHAR(1024) COMMENT '来源URL',
            文章链接 VARCHAR(1024) COMMENT '文章链接',
            是否相关 BOOLEAN COMMENT 'AI判断是否与汽车后市场相关',
            相关度评分 INT COMMENT '相关度评分(0-100)',
            标签 TEXT COMMENT 'AI生成的标签(JSON数组)',
            分类 VARCHAR(100) COMMENT 'AI判断的主要分类',
            分析说明 TEXT COMMENT 'AI分析说明',
            处理时间 DATETIME COMMENT 'AI处理时间',
            创建时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '记录创建时间',
            更新时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '记录更新时间',
            INDEX idx_source_id (source_id),
            INDEX idx_是否相关 (是否相关),
            INDEX idx_分类 (分类),
            INDEX idx_处理时间 (处理时间)
        ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='RSS数据AI分析结果表'
    """

    try:
        self.db_agent.execute_sql(ddl)
        self.log.info(f"成功创建AI结果表: {self.ai_table}")
    except Exception as exc:
        self.log.error(f"创建AI结果表失败: {str(exc)}", exc_info=True)
        raise
|
||||
|
||||
def load_unprocessed_data(self, limit: int = 100) -> pd.DataFrame:
    """Load rows of the source table that the AI has not processed yet.

    Rows whose '是否ai处理' flag is 0 or NULL are selected, newest '创建时间' first.

    Args:
        limit: Maximum number of rows to fetch.

    Returns:
        The query result as a DataFrame; an empty DataFrame when loading fails
        (the error is logged, not raised).
    """
    try:
        query = f"""
            SELECT id, 文章标题, 文章摘要, 发布时间, 来源URL, 文章链接
            FROM {self.source_table}
            WHERE 是否ai处理 = 0 OR 是否ai处理 IS NULL
            ORDER BY 创建时间 DESC
            LIMIT %s
        """

        pending = self.db_agent.query_to_df(query, params=(limit,), is_print=False)
        self.log.info(f"成功加载 {len(pending)} 条未处理的数据")
        return pending

    except Exception as exc:
        self.log.error(f"加载未处理数据失败: {str(exc)}", exc_info=True)
        return pd.DataFrame()
|
||||
|
||||
def analyze_news(self, title: str, summary: str) -> Dict[str, Any]:
    """Ask the AI model whether a news item relates to the automotive aftermarket.

    The model is prompted to answer with a JSON object. The reply is parsed
    after stripping an optional Markdown code fence; if that fails, the text
    between the first '{' and the last '}' is tried, so replies that wrap the
    object in explanatory prose are still accepted.

    Args:
        title: Article title inserted into the prompt.
        summary: Article summary inserted into the prompt.

    Returns:
        dict with keys 'is_relevant', 'relevance_score', 'tags', 'category'
        and 'analysis'. Missing keys get defaults. On a parse failure or any
        other error a "not relevant" result is returned whose 'analysis'
        describes the failure; this method does not raise.
    """
    # Prompt text is kept as-is (original wording)
    prompt = f"""分析以下新闻是否与汽车后市场相关,返回JSON格式:

标题:{title}
摘要:{summary}

返回格式:
{{
"is_relevant": true/false,
"relevance_score": 0-100,
"tags": ["标签1", "标签2"],
"category": "分类(配件/维修/保养/改装/美容/装饰/二手车/金融/保险/其他)",
"analysis": "简要说明"
}}

注意:只返回JSON格式的结果,不要包含其他说明文字。"""

    try:
        # Call the chat-completions endpoint with a single user message
        response = self.ai_client.chat.completions.create(
            model=self.model,
            messages=[{
                "role": "user",
                "content": prompt
            }]
        )

        raw_content = response.choices[0].message.content

        # Strip a Markdown code fence when the model added one
        if '```json' in raw_content:
            json_str = raw_content.split('```json')[1].split('```')[0].strip()
        elif '```' in raw_content:
            json_str = raw_content.split('```')[1].split('```')[0].strip()
        else:
            json_str = raw_content.strip()

        try:
            result = json.loads(json_str)
        except json.JSONDecodeError:
            # Retry on the outermost {...} span before giving up; a second
            # failure propagates to the JSONDecodeError handler below.
            start = json_str.find('{')
            end = json_str.rfind('}')
            if start == -1 or end <= start:
                raise
            result = json.loads(json_str[start:end + 1])

        # Fill in any field the model left out
        return {
            'is_relevant': result.get('is_relevant', False),
            'relevance_score': result.get('relevance_score', 0),
            'tags': result.get('tags', []),
            'category': result.get('category', '其他'),
            'analysis': result.get('analysis', '')
        }

    except json.JSONDecodeError as e:
        self.log.warning(f"JSON解析失败: {str(e)}, 原始响应: {raw_content[:200]}")
        return {
            'is_relevant': False,
            'relevance_score': 0,
            'tags': [],
            'category': '其他',
            'analysis': f"解析失败: {raw_content[:100]}"
        }
    except Exception as e:
        self.log.error(f"AI调用异常: {str(e)}", exc_info=True)
        return {
            'is_relevant': False,
            'relevance_score': 0,
            'tags': [],
            'category': '其他',
            'analysis': f"处理异常: {str(e)}"
        }
|
||||
|
||||
def process_single_record(self, record: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Run the AI analysis for one source record and build the row to store.

    Args:
        record: Source row as a dict. Reads ``id``, ``文章标题``, ``文章摘要``,
            ``发布时间``, ``来源URL`` and ``文章链接``.

    Returns:
        The result row, or ``None`` when the AI client is not configured,
        the record has neither title nor summary, or processing raised.
    """
    if not self.is_configured():
        self.log.error("AI客户端未配置,无法处理数据")
        return None

    def _as_text(value: Any) -> str:
        # Missing values (None / NaN) must become "" instead of the literal
        # strings "None" / "nan", which would defeat the emptiness check below.
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return str(value).strip()

    try:
        title = _as_text(record.get('文章标题', ''))
        summary = _as_text(record.get('文章摘要', ''))

        if not title and not summary:
            self.log.warning(f"记录 {record.get('id')} 标题和摘要均为空,跳过处理")
            return None

        # Read the id before the AI call so a record without one fails fast
        # instead of after a wasted request.
        source_id = record['id']

        # Call the AI analysis.
        analysis_result = self.analyze_news(title, summary)

        # Build the result row; tags are stored as a JSON string.
        return {
            'source_id': source_id,
            '文章标题': title,
            '文章摘要': summary,
            '发布时间': record.get('发布时间'),
            '来源URL': record.get('来源URL'),
            '文章链接': record.get('文章链接'),
            '是否相关': analysis_result.get('is_relevant', False),
            '相关度评分': analysis_result.get('relevance_score', 0),
            '标签': json.dumps(analysis_result.get('tags', []), ensure_ascii=False),
            '分类': analysis_result.get('category', '其他'),
            '分析说明': analysis_result.get('analysis', ''),
            '处理时间': datetime.now(),
        }

    except Exception as e:
        self.log.error(f"处理记录 {record.get('id')} 失败: {str(e)}", exc_info=True)
        return None
|
||||
|
||||
def save_ai_results(self, results: List[Dict[str, Any]]) -> int:
    """Persist a batch of AI analysis rows into the AI result table.

    Args:
        results: Result rows (dicts) to store.

    Returns:
        The row count reported by ``insert_from_df``; 0 when ``results`` is
        empty or the insert raised (the error is logged, not re-raised).
    """
    if results:
        try:
            saved_count = self.db_agent.insert_from_df(
                table_name=self.ai_table,
                df=pd.DataFrame(results),
                ignore_duplicates=True,
            )
            self.log.info(f"成功保存 {saved_count} 条AI处理结果")
            return saved_count
        except Exception as exc:
            self.log.error(f"保存AI处理结果失败: {str(exc)}", exc_info=True)
    return 0
|
||||
|
||||
def mark_as_processed(self, ids: List[int]) -> bool:
    """Set the ``是否ai处理`` flag to 1 on the given source-table rows.

    Args:
        ids: Primary-key values of the rows to flag.

    Returns:
        ``True`` when the update ran (or ``ids`` was empty), ``False`` when
        it raised; the error is logged rather than propagated.
    """
    if not ids:
        return True

    # One %s placeholder per id keeps the statement parameterised.
    placeholders = ','.join('%s' for _ in ids)
    try:
        statement = f"""
            UPDATE {self.source_table}
            SET 是否ai处理 = 1
            WHERE id IN ({placeholders})
        """
        self.db_agent.execute_sql(statement, params=ids)
    except Exception as exc:
        self.log.error(f"标记记录为已处理失败: {str(exc)}", exc_info=True)
        return False

    self.log.info(f"成功标记 {len(ids)} 条记录为已处理")
    return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry: run the processor once and print a summary.
    processor = RSSDataAIProcessor()
    result = processor.main()

    banner = "=" * 60

    if result['success']:
        # (label, result key, default) for each summary line, in print order.
        summary_fields = (
            ("总记录数", 'total_count', 0),
            ("成功处理", 'processed_count', 0),
            ("保存记录", 'saved_count', 0),
            ("失败记录", 'failed_count', 0),
            ("相关记录", 'relevant_count', 0),
            ("处理时间", 'processing_time', ''),
        )
        print("\n" + banner)
        print("✓ AI处理完成")
        print(banner)
        for label, key, default in summary_fields:
            print(f"{label}: {result.get(key, default)}")
        print(banner + "\n")
    else:
        print("\n" + banner)
        print("✗ 处理失败")
        print(banner)
        print(f"错误信息: {result['message']}")
        # Hint for the API-key environment variable.
        print("\n提示: 请设置环境变量")
        print(" Windows: $env:BAIDU_API_KEY = 'your_key'")
        print(" Linux/Mac: export BAIDU_API_KEY='your_key'")
        print(banner + "\n")
|
||||
@@ -0,0 +1,37 @@
|
||||
汽车配件
|
||||
汽车维修
|
||||
汽车保养
|
||||
汽车改装
|
||||
汽车美容
|
||||
汽车装饰
|
||||
轮胎
|
||||
机油
|
||||
刹车片
|
||||
火花塞
|
||||
滤清器
|
||||
蓄电池
|
||||
车灯
|
||||
保险杠
|
||||
车门
|
||||
座椅
|
||||
方向盘
|
||||
仪表盘
|
||||
音响
|
||||
导航
|
||||
汽车用品
|
||||
车载设备
|
||||
汽车电子
|
||||
汽车安全
|
||||
汽车保险
|
||||
二手车
|
||||
汽车交易
|
||||
汽车金融
|
||||
汽车租赁
|
||||
汽车服务
|
||||
4S店
|
||||
汽修店
|
||||
汽车后市场
|
||||
汽车产业链
|
||||
汽车供应链
|
||||
汽车
|
||||
车
|
||||
@@ -0,0 +1,409 @@
|
||||
# RSS数据处理模块 - 汽车后市场新闻分词和过滤
|
||||
import pandas as pd
|
||||
import jieba
|
||||
import jieba.posseg as pseg
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
if parent_dir not in sys.path:
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
from utils.mysql_agent import MySQLAgent
|
||||
from utils.logger import log
|
||||
from config import Config
|
||||
|
||||
class RSSDataProcessor:
    """RSS data processor for automotive-aftermarket news.

    Loads unprocessed rows from the RSS collection table, segments title and
    summary with jieba, flags rows that mention an industry keyword, stores
    the flagged rows in the processed table and marks the source rows as
    processed.
    """

    def __init__(self):
        """Set up logging, the database agent, file paths and built-in word lists."""
        self.log = log.bind(module="RSSDataProcessor")
        self.db_agent = MySQLAgent(Config.MYSQL_CONFIG)
        self.table_name = "collector_rss_subscriptions"
        self.processed_table_name = "processed_rss_data"

        # Project root is the parent of the directory containing this file.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.project_root = os.path.dirname(current_dir)

        # Word-list files live under <project root>/processors.
        self.keywords_file = os.path.join(self.project_root, "processors", "keywords.txt")
        self.stopwords_file = os.path.join(self.project_root, "processors", "stopwords.txt")

        # Built-in aftermarket keywords, used when the keyword file cannot be read.
        self.auto_aftermarket_keywords = {
            '汽车配件', '汽车维修', '汽车保养', '汽车改装', '汽车美容', '汽车装饰',
            '轮胎', '机油', '刹车片', '火花塞', '滤清器', '蓄电池', '车灯',
            '保险杠', '车门', '座椅', '方向盘', '仪表盘', '音响', '导航',
            '汽车用品', '车载设备', '汽车电子', '汽车安全', '汽车保险',
            '二手车', '汽车交易', '汽车金融', '汽车租赁', '汽车服务',
            '4S店', '汽修店', '汽车后市场', '汽车产业链', '汽车供应链', '汽车', '车'
        }

        # Built-in stopwords, used when the stopword file cannot be read.
        self.stopwords = {
            '的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个',
            '上', '也', '很', '到', '说', '要', '去', '你', '会', '着', '没有', '看', '好',
            '自己', '这', '那', '它', '他', '她', '我们', '你们', '他们', '什么', '怎么',
            '为什么', '因为', '所以', '但是', '然后', '如果', '虽然', '而且', '或者',
            '可以', '应该', '必须', '需要', '想要', '希望', '觉得', '认为', '知道',
            '了解', '明白', '清楚', '简单', '容易', '困难', '重要', '主要', '基本',
            '一般', '特别', '非常', '十分', '相当', '比较', '更加', '最', '更',
            '已经', '正在', '将要', '可能', '也许', '大概', '大约', '左右', '上下',
            '今天', '明天', '昨天', '现在', '以前', '以后', '时候', '时间', '地方',
            '这里', '那里', '这样', '那样', '如此', '这样', '那样', '如何', '怎样'
        }

        # Keyword cache so the file is read at most once per instance.
        self._cached_keywords = None

        self.log.info("RSS数据处理器初始化完成")

    def load_keywords(self, keywords_file: Optional[str] = None) -> set:
        """Return the aftermarket keyword set, reading it from file on first use.

        The first successful call populates a per-instance cache; every later
        call returns that cache and ignores ``keywords_file``.

        Args:
            keywords_file: Path of a UTF-8 file with one keyword per line.
                Defaults to ``self.keywords_file``.

        Returns:
            The keyword set. Falls back to the built-in keywords when the
            file is missing or cannot be read.
        """
        if self._cached_keywords is not None:
            return self._cached_keywords

        if keywords_file is None:
            keywords_file = self.keywords_file

        keywords = set()
        try:
            if os.path.exists(keywords_file):
                with open(keywords_file, 'r', encoding='utf-8') as f:
                    keywords = set(line.strip() for line in f if line.strip())
                self.log.info(f"成功加载汽车后市场关键词,共 {len(keywords)} 个")
            else:
                self.log.warning(f"关键词文件不存在: {keywords_file}")
                keywords = self.auto_aftermarket_keywords
        except Exception as e:
            self.log.error(f"加载关键词失败: {str(e)}")
            keywords = self.auto_aftermarket_keywords

        self._cached_keywords = keywords
        return keywords

    def load_rss_data(self, limit: int = 1000) -> pd.DataFrame:
        """Load up to ``limit`` unprocessed RSS rows, newest first.

        Returns:
            The rows as a DataFrame; an empty DataFrame when the query fails
            (the error is logged).
        """
        try:
            sql = f"""
                SELECT id, 文章标题, 文章摘要, 发布时间, 来源URL, 文章链接
                FROM {self.table_name}
                WHERE 是否已处理 = 0
                ORDER BY 发布时间 DESC
                LIMIT %s
            """

            df = self.db_agent.query_to_df(sql, params=(limit,), is_print=False)
            self.log.info(f"成功加载 {len(df)} 条未处理的RSS数据")
            return df

        except Exception as e:
            self.log.error(f"加载RSS数据失败: {str(e)}", exc_info=True)
            return pd.DataFrame()

    def mark_as_processed(self, ids: List[int]) -> bool:
        """Set ``是否已处理 = 1`` on the source rows with the given ids.

        Returns:
            ``True`` when the update ran (or ``ids`` was empty), ``False``
            when it raised.
        """
        if not ids:
            return True

        try:
            # One placeholder per id for the parameterised IN clause.
            id_placeholders = ','.join(['%s'] * len(ids))
            sql = f"""
                UPDATE {self.table_name}
                SET 是否已处理 = 1
                WHERE id IN ({id_placeholders})
            """

            self.db_agent.execute_sql(sql, params=ids)
            self.log.info(f"成功标记 {len(ids)} 条数据为已处理")
            return True

        except Exception as e:
            self.log.error(f"标记数据为已处理失败: {str(e)}", exc_info=True)
            return False

    def load_stopwords(self, stopwords_file: Optional[str] = None) -> set:
        """Load the stopword set from file.

        Args:
            stopwords_file: Path of a UTF-8 file with one stopword per line.
                Defaults to ``self.stopwords_file``.

        Returns:
            The stopword set; the built-in ``self.stopwords`` when the file
            is missing or cannot be read.
        """
        if stopwords_file is None:
            stopwords_file = self.stopwords_file

        try:
            if os.path.exists(stopwords_file):
                with open(stopwords_file, 'r', encoding='utf-8') as f:
                    stopwords = set(line.strip() for line in f if line.strip())
                self.log.info(f"成功加载停用词表,共 {len(stopwords)} 个词")
                return stopwords
            else:
                self.log.warning(f"停用词文件不存在: {stopwords_file},使用默认停用词")
                return self.stopwords
        except Exception as e:
            self.log.error(f"加载停用词表失败: {str(e)}")
            return self.stopwords

    def add_custom_dict(self, custom_dict_file: Optional[str] = None):
        """Register an optional user dictionary and all keywords with jieba.

        Args:
            custom_dict_file: Optional jieba user-dictionary file; loaded
                only when given and present on disk.
        """
        if custom_dict_file and os.path.exists(custom_dict_file):
            try:
                jieba.load_userdict(custom_dict_file)
                self.log.info("成功加载自定义词典")
            except Exception as e:
                self.log.warning(f"加载自定义词典失败: {str(e)}")

        # Add every keyword to jieba's dictionary as a noun.
        keywords = self.load_keywords()
        for keyword in keywords:
            jieba.add_word(keyword, freq=1000, tag='n')

    def segment_and_pos(self, text: str, stopwords: set) -> List[str]:
        """Segment ``text`` and keep the noun-class words that are not stopwords.

        Args:
            text: Text to segment; empty or NaN input yields ``[]``.
            stopwords: Words to drop.

        Returns:
            Kept words in order of appearance (pure-digit tokens removed).
        """
        if not text or pd.isna(text):
            return []

        words = pseg.cut(str(text))
        result = []
        # POS tags kept: n (noun), vn (verbal noun), np (noun phrase),
        # ns (place name), nr (person name), nt (organisation name).
        # Plain verbs ('v') are not in the set.
        allowed_flags = {'n', 'vn', 'np', 'ns', 'nr', 'nt'}

        for word, flag in words:
            word = word.strip()
            if (len(word) >= 1 and
                    word not in stopwords and
                    flag in allowed_flags and
                    not word.isdigit()):  # drop pure numbers
                result.append(word)

        return result

    def is_auto_aftermarket_related(self, text: str) -> bool:
        """Return whether ``text`` mentions any aftermarket keyword.

        A case-insensitive substring match is tried first; if it finds
        nothing, the segmented words are compared with the keyword set.
        """
        if not text:
            return False

        text_lower = str(text).lower()
        keywords = self.load_keywords()

        # Lower-case the keyword as well: the text is lower-cased, so a
        # keyword such as '4S店' could otherwise never match.
        if any(keyword.lower() in text_lower for keyword in keywords):
            return True

        # Fall back to exact matches on segmented words.
        words = self.segment_and_pos(text, self.stopwords)
        return any(word in keywords for word in words)

    def process_dataframe(self, df: pd.DataFrame, stopwords: set) -> pd.DataFrame:
        """Segment and flag every row of ``df``.

        Adds the columns ``combined_text``, ``segmented_words``,
        ``is_auto_related``, ``is_filtered`` and ``processed_time`` to ``df``
        itself and returns it (the input is modified in place).
        """
        if df.empty:
            self.log.warning("输入的DataFrame为空")
            return df

        # Make sure both text columns are strings without NaN.
        df['文章标题'] = df['文章标题'].fillna('').astype(str)
        df['文章摘要'] = df['文章摘要'].fillna('').astype(str)

        # Title and summary are segmented together.
        df['combined_text'] = df['文章标题'] + ' ' + df['文章摘要']

        df['segmented_words'] = df['combined_text'].apply(lambda x: self.segment_and_pos(x, stopwords))

        # A row is kept as soon as any keyword appears.
        df['is_auto_related'] = df['combined_text'].apply(self.is_auto_aftermarket_related)
        df['is_filtered'] = df['is_auto_related']

        df['processed_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

        self.log.info(f"数据处理完成,共处理 {len(df)} 条记录")
        return df

    def filter_auto_aftermarket_news(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of the rows whose ``is_filtered`` flag is true."""
        if df.empty:
            return df

        filtered_df = df[df['is_filtered'] == True].copy()

        self.log.info(f"过滤出 {len(filtered_df)} 条汽车后市场相关新闻")
        return filtered_df

    def save_to_database(self, df: pd.DataFrame) -> bool:
        """Write processed rows to the processed table, creating it if needed.

        Returns:
            ``True`` on success; ``False`` when ``df`` is empty or saving
            raised (the error is logged).
        """
        if df.empty:
            self.log.warning("没有数据需要保存")
            return False

        try:
            save_df = df[['文章标题', '文章摘要', '发布时间', '来源URL', '文章链接',
                          'segmented_words', 'is_auto_related', 'processed_time']].copy()

            # Store the word list as one space-separated string.
            save_df['分词结果'] = save_df['segmented_words'].apply(lambda x: ' '.join(x))

            # Rename to the table's Chinese column names.
            save_df = save_df.rename(columns={
                'is_auto_related': '是否汽车相关',
                'processed_time': '处理时间'
            })

            save_df = save_df.drop('segmented_words', axis=1)

            if not self.db_agent.table_exists(self.processed_table_name):
                self.create_processed_table()

            inserted_rows = self.db_agent.insert_from_df(
                table_name=self.processed_table_name,
                df=save_df,
                ignore_duplicates=True
            )

            self.log.info(f"成功保存 {inserted_rows} 条处理结果到数据库")
            return True

        except Exception as e:
            self.log.error(f"保存到数据库失败: {str(e)}", exc_info=True)
            return False

    def create_processed_table(self):
        """Create the processed-result table if it does not exist.

        Raises:
            Exception: Whatever ``execute_sql`` raised, after logging it.
        """
        create_sql = f"""
            CREATE TABLE IF NOT EXISTS {self.processed_table_name} (
                id INT AUTO_INCREMENT PRIMARY KEY,
                文章标题 TEXT,
                文章摘要 TEXT,
                发布时间 DATETIME,
                来源URL VARCHAR(1024),
                文章链接 VARCHAR(1024),
                分词结果 TEXT,
                是否汽车相关 BOOLEAN,
                处理时间 DATETIME,
                创建时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                更新时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """

        try:
            self.db_agent.execute_sql(create_sql)
            self.log.info(f"成功创建处理结果表: {self.processed_table_name}")
        except Exception as e:
            self.log.error(f"创建表失败: {str(e)}", exc_info=True)
            raise

    def get_processing_statistics(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Summarise a processed DataFrame.

        Returns:
            ``{}`` for an empty frame, otherwise ``total_articles``,
            ``filtered_articles``, ``filter_rate`` and ``processing_time``.
        """
        if df.empty:
            return {}

        total_count = len(df)
        filtered_count = len(df[df['is_filtered'] == True])

        stats = {
            'total_articles': total_count,
            'filtered_articles': filtered_count,
            'filter_rate': filtered_count / total_count if total_count > 0 else 0,
            'processing_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        }

        return stats

    def process_rss_data(self, limit: int = 1000, save_to_db: bool = True) -> Dict[str, Any]:
        """Run the full pipeline: load, segment, filter, save, mark processed.

        Args:
            limit: Maximum number of unprocessed rows to load.
            save_to_db: Whether matched rows are written to the processed table.

        Returns:
            ``{'success': False, 'message': ...}`` when nothing was loaded or
            an exception occurred; otherwise a dict with ``success``,
            ``message``, ``statistics`` and ``filtered_data``.
        """
        try:
            self.log.info("开始处理RSS数据...")

            # 1. Load RSS data.
            df = self.load_rss_data(limit)
            if df.empty:
                self.log.warning("没有加载到RSS数据")
                return {'success': False, 'message': '没有数据可处理'}

            # 2. Load stopwords.
            stopwords = self.load_stopwords()

            # 3. Register keywords with jieba.
            self.add_custom_dict()

            # 4. Segment and flag.
            processed_df = self.process_dataframe(df, stopwords)

            # 5. Keep aftermarket-related rows.
            filtered_df = self.filter_auto_aftermarket_news(processed_df)

            # 6. Statistics.
            stats = self.get_processing_statistics(processed_df)

            # 7. Save matched rows.
            save_success = True
            if save_to_db and not filtered_df.empty:
                save_success = self.save_to_database(filtered_df)
                stats['save_success'] = save_success

            # 8. Mark source rows as processed, but only if nothing was lost.
            if not save_success:
                # Marking after a failed save would drop the matched articles
                # for good; leave the batch unmarked so the next run retries it.
                stats['mark_success'] = False
                self.log.warning("保存失败,本批数据未标记为已处理,将在下次运行时重试")
            elif not df.empty and 'id' in df.columns:
                processed_ids = df['id'].tolist()
                mark_success = self.mark_as_processed(processed_ids)
                stats['mark_success'] = mark_success
                if not mark_success:
                    self.log.warning("部分数据标记为已处理失败")

            # 9. Report.
            self.log.info("RSS数据处理完成", **stats)

            return {
                'success': True,
                'message': 'RSS数据处理完成',
                'statistics': stats,
                'filtered_data': filtered_df
            }

        except Exception as e:
            self.log.error(f"RSS数据处理失败: {str(e)}", exc_info=True)
            return {'success': False, 'message': f'处理失败: {str(e)}'}
|
||||
|
||||
|
||||
def main():
    """Entry point: process up to 5000 unprocessed RSS rows and print the outcome.

    Any exception is caught and printed instead of propagating.
    """
    try:
        # Build the processor and run the pipeline, saving matches to the DB.
        outcome = RSSDataProcessor().process_rss_data(limit=5000, save_to_db=True)

        if not outcome['success']:
            print(f"处理失败: {outcome['message']}")
        else:
            print("RSS数据处理完成!")
            print(f"处理统计: {outcome['statistics']}")

    except Exception as exc:
        print(f"程序运行出错: {str(exc)}")


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,100 @@
|
||||
的
|
||||
了
|
||||
在
|
||||
是
|
||||
我
|
||||
有
|
||||
和
|
||||
就
|
||||
不
|
||||
人
|
||||
都
|
||||
一
|
||||
一个
|
||||
上
|
||||
也
|
||||
很
|
||||
到
|
||||
说
|
||||
要
|
||||
去
|
||||
你
|
||||
会
|
||||
着
|
||||
没有
|
||||
看
|
||||
好
|
||||
自己
|
||||
这
|
||||
那
|
||||
它
|
||||
他
|
||||
她
|
||||
我们
|
||||
你们
|
||||
他们
|
||||
什么
|
||||
怎么
|
||||
为什么
|
||||
因为
|
||||
所以
|
||||
但是
|
||||
然后
|
||||
如果
|
||||
虽然
|
||||
而且
|
||||
或者
|
||||
可以
|
||||
应该
|
||||
必须
|
||||
需要
|
||||
想要
|
||||
希望
|
||||
觉得
|
||||
认为
|
||||
知道
|
||||
了解
|
||||
明白
|
||||
清楚
|
||||
简单
|
||||
容易
|
||||
困难
|
||||
重要
|
||||
主要
|
||||
基本
|
||||
一般
|
||||
特别
|
||||
非常
|
||||
十分
|
||||
相当
|
||||
比较
|
||||
更加
|
||||
最
|
||||
更
|
||||
已经
|
||||
正在
|
||||
将要
|
||||
可能
|
||||
也许
|
||||
大概
|
||||
大约
|
||||
左右
|
||||
上下
|
||||
今天
|
||||
明天
|
||||
昨天
|
||||
现在
|
||||
以前
|
||||
以后
|
||||
时候
|
||||
时间
|
||||
地方
|
||||
这里
|
||||
那里
|
||||
这样
|
||||
那样
|
||||
如此
|
||||
这样
|
||||
那样
|
||||
如何
|
||||
怎样
|
||||
@@ -0,0 +1,148 @@
|
||||
## 情报收集系统设计
|
||||
|
||||
### 参考文档
|
||||
https://alidocs.dingtalk.com/i/nodes/NZQYprEoWoexdo1ohPdxXvDbJ1waOeDk?utm_scene=team_space
|
||||
|
||||
### 程序框架(当前实现)
|
||||
```text
|
||||
intelligence_system/
|
||||
├── collectors/ # 数据采集层
|
||||
│ ├── complaint_spider.py # 投诉信息爬虫(结构化入库/附件走MinIO)
|
||||
│ ├── rss_subscriptions.py # RSS 订阅抓取
|
||||
│ └── internal/ # 内部数据收集(保留)
|
||||
│ └── jian_dao_cloud.py # 简道云表单收集器(示例/占位)
|
||||
│
|
||||
├── processors/ # 数据处理层
|
||||
│ ├── processor_rss_data.py # RSS数据清洗、分词、过滤与入库
|
||||
│ ├── keywords.txt # 行业关键词(用于分词/过滤)
|
||||
│ ├── stopwords.txt # 停用词
|
||||
│ └── ai_engine/
|
||||
│ └── ai_proessor_rss_data # 预留(AI分析扩展占位)
|
||||
│
|
||||
├── services/ # 应用服务层(保留)
|
||||
│ ├── monitoring/ # 舆情监控
|
||||
│ │ ├── opinion_monitor.py # 实时舆情追踪(占位)
|
||||
│ │ └── brand_reputation.py # 品牌口碑分析(占位)
|
||||
│ ├── analysis/ # 竞品分析
|
||||
│ │ ├── competitor_tracker.py # 竞品动态监控(占位)
|
||||
│ │ └── swot_generator.py # SWOT分析报告(占位)
|
||||
│ ├── reporting/ # 报告服务
|
||||
│ │ ├── daily_reporter.py # 自动化日报生成(占位)
|
||||
│ │ └── weekly_digest.py # 周报汇编系统(占位)
|
||||
│ └── alert/ # 预警服务
|
||||
│ ├── alert_trigger.py # 动态阈值告警(占位)
|
||||
│ └── notification_center.py # 邮件/短信通知(占位)
|
||||
│
|
||||
├── applications/ # 应用层
|
||||
│ ├── alert.py # 告警触发/通知(占位/实现中)
|
||||
│ └── reporter/
|
||||
│ ├── daily.py # 日报生成
|
||||
│ └── monthly.py # 月报生成
|
||||
│
|
||||
├── system_management/ # 系统管理层
|
||||
│ ├── scheduler/
|
||||
│ │ ├── task_scheduler.py # 任务调度器(Cron表达式 + 线程池)
|
||||
│ │ └── task_management.py # 任务管理辅助
|
||||
│ └── monitor/ # 系统监控(目录占位)
|
||||
│
|
||||
├── utils/ # 工具库
|
||||
│ ├── file_handler.py # 通用文件操作
|
||||
│ ├── logger.py # 跨平台日志系统(Loguru)
|
||||
│ ├── mysql_agent.py # MySQL读写管理器
|
||||
│ └── minio_agent.py # MinIO对象存储客户端
|
||||
│
|
||||
├── config.py # 配置加载与管理(含数据库/存储配置)
|
||||
├── main.py # 系统入口(Cron轮询 + 调度执行)
|
||||
└── requirements.txt # 依赖清单
|
||||
```
|
||||
|
||||
### 程序设计原则
|
||||
1. 所有程序尽可能在py文件中运行,尽量避免使用命令行执行
|
||||
2. 配置需要在配置类中定义
|
||||
3. 密钥等信息直接放在配置类中
|
||||
4. 数据存储遵循"结构化存MySQL,非结构化存MinIO"原则,通过元数据关联
|
||||
|
||||
### 主程序与调度设计(已实现)
|
||||
主程序以长运行进程方式启动,进入轻量轮询循环(每10秒)。调度器按Cron表达式在`main_task`表中拉取到期任务,使用线程池异步执行,并在每分钟输出运行状态、每小时汇总统计。
|
||||
|
||||
- 调度器能力:
|
||||
- 基于`croniter`解析Cron表达式,支持时区(默认`Asia/Shanghai`)
|
||||
- 线程池并发执行,信号量限制最大并发(与`max_workers`一致)
|
||||
- 任务入口动态解析:支持`package.module`、`package.module.ClassName.main`、`package.module.func` 等形式
|
||||
- 成功/失败后自动计算`next_run_time`或设置15分钟后重试
|
||||
- 关键字段自动更新:`is_running`、`last_run_time`、`last_run_status`、`run_count`、`next_run_time`
|
||||
|
||||
- 主循环:
|
||||
- 每10秒检查一次待运行任务
|
||||
- 每分钟打印当前周期统计;每小时写入累计统计日志
|
||||
- 支持`SIGINT/SIGTERM`优雅关闭,等待正在运行的任务完成
|
||||
|
||||
### 日志设计(已实现)
|
||||
跨平台日志系统(Loguru)输出至`logs/`目录:
|
||||
|
||||
- application.log:主日志,`rotation = 20MB`,达到阈值后压缩为`application.log.YYYYMMDD.zip`,`retention = 30天`
|
||||
- errors.log:错误日志(ERROR及以上),`rotation = 10MB`,`retention = 90天`
|
||||
- 结构化扩展字段:日志支持`extra`键值对,自动美化并对长字段(如`sql`、`params`)截断
|
||||
|
||||
建议记录的业务事件:
|
||||
- MySQL读写操作要点(表名、影响行数、事务状态)
|
||||
- MinIO对象操作(对象路径、大小、耗时、状态)
|
||||
- 任务执行上下文(task_id、task_name、module_path、耗时、状态)
|
||||
|
||||
### 存储系统设计(MinIO+MySQL)
|
||||
#### 核心存储分工
|
||||
| 存储类型 | 适用数据 | 核心作用 |
|
||||
|----------|----------|----------|
|
||||
| MySQL | 结构化数据、元数据、关系型数据 | 存储业务逻辑数据、非结构化数据的索引信息、任务调度信息等 |
|
||||
| MinIO | 非结构化数据 | 存储图片、视频、PDF文档、原始爬取文件等二进制/大文件数据 |
|
||||
|
||||
|
||||
#### 核心存储配置
|
||||
1. **MySQL配置**
|
||||
- 数据库名称:`intelligence_system`
|
||||
- 连接管理:通过`utils/mysql_agent.py`封装线程安全的连接池,提供结构化数据的增删改查及SQL执行能力
|
||||
- 适配特性:支持多平台(Windows/macOS/Linux)的超时配置和批处理优化
|
||||
|
||||
2. **MinIO配置**
|
||||
- 存储桶命名规则:按数据类型划分,如`collector-images`(采集层图片)、`processor-videos`(处理层视频)
|
||||
- 连接管理:通过`utils/minio_agent.py`封装客户端,提供对象上传、下载、删除、查询URL等能力
|
||||
- 路径规则:`{数据层}/{来源}/{时间戳}_{唯一ID}.{后缀}`(例:`collector/weibo_spider/20240520_12345.jpg`)
|
||||
|
||||
|
||||
#### 表命名规则(扩展)
|
||||
- 数据采集类:以`collector_`为前缀(存储采集到的结构化数据及MinIO对象元数据)
|
||||
- 数据处理类:以`processor_`为前缀(存储处理结果的结构化数据及MinIO处理后对象的元数据)
|
||||
- 数据存储类:以`storage_`为前缀(存储MinIO对象的索引信息,如哈希、大小、访问权限等)
|
||||
- 应用层类:以`application_`为前缀(对应业务应用数据)
|
||||
- 系统类:如任务调度表等采用功能命名(如`main_task`)
|
||||
|
||||
|
||||
#### 核心表结构(当前落地)
|
||||
1. `main_task`:任务调度表(`task_name`、`task_type`、`module_path`、`cron_expression`、`time_zone`、`run_count`、`is_running`、`last_run_time`、`last_run_status`、`next_run_time`、`is_active` 等)
|
||||
2. `collector_rss_subscriptions`:RSS源采集数据(`文章标题`、`文章摘要`、`发布时间`、`来源URL`、`文章链接`、`是否已处理` 等)
|
||||
3. `processed_rss_data`:RSS处理结果(`分词结果`、`是否汽车相关`、`处理时间` 等)
|
||||
4. `collector_complaint_spider`:投诉信息爬虫数据(含文本与附件MinIO路径`attachment_minio_path`等)
|
||||
5. 可选:`storage_object_index`(建议用于统一索引MinIO对象元数据)
|
||||
|
||||
### 数据采集设计
|
||||
1. 结构化数据(RSS、投诉文本):写入`collector_`前缀表
|
||||
2. 非结构化数据(附件/图片等):
|
||||
- 使用`utils/minio_agent.py`上传至对应存储桶
|
||||
- 将对象路径与元数据写入业务表或`storage_object_index`
|
||||
3. 采集模块需同时处理MySQL与MinIO交互,确保关联完整
|
||||
|
||||
### 数据处理设计(RSS流程已实现)
|
||||
`processors/processor_rss_data.py`流程:
|
||||
- 从`collector_rss_subscriptions`加载未处理数据(可配置`limit`)
|
||||
- 加载停用词与行业关键词(`stopwords.txt` / `keywords.txt`),并动态注入`jieba`词典
|
||||
- 标注词性并过滤停用词与纯数字,仅保留名词类词性(`n`/`vn`/`np`/`ns`/`nr`/`nt`)的词汇
|
||||
- 标记与过滤:出现任一行业关键词即视为相关,进入保存
|
||||
- 将结果写入`processed_rss_data`,并回写源表`是否已处理 = 1`
|
||||
- 输出处理统计(总量、命中量、命中率、时间)
|
||||
|
||||
### 依赖与运行
|
||||
- 依赖:见`requirements.txt`(pandas、SQLAlchemy、PyMySQL、croniter、pytz、loguru、jieba、feedparser、beautifulsoup4、minio 等)
|
||||
- 配置:在`config.py`中设置`MYSQL_CONFIG`与MinIO参数
|
||||
- 运行:
|
||||
- 启动主程序:`python main.py`
|
||||
- 添加任务:向`main_task`插入记录,`module_path`可指向如`processors.processor_rss_data.main`
|
||||
@@ -0,0 +1,18 @@
|
||||
croniter==3.0.3
|
||||
dbutils==3.1.2
|
||||
loguru==0.7.3
|
||||
minio==7.2.16
|
||||
numpy==2.3.3
|
||||
pandas==2.3.2
|
||||
pymysql==1.1.2
|
||||
pytest==8.4.2
|
||||
pytz==2025.2
|
||||
Requests==2.32.5
|
||||
SQLAlchemy==2.0.43
|
||||
tenacity==9.1.2
|
||||
beautifulsoup4==4.13.5
|
||||
feedparser==6.0.11
|
||||
Markdown==3.9
|
||||
openai==1.107.3
|
||||
tqdm==4.67.1
|
||||
jieba==0.42.1
|
||||
@@ -1,683 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
from pymysql import cursors
|
||||
from pymysql.err import MySQLError
|
||||
from dbutils.pooled_db import PooledDB
|
||||
from typing import Union, List, Dict, Any, Optional, Tuple
|
||||
import threading
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 导入日志系统
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MySQLAgent:
|
||||
"""
|
||||
全平台兼容的MySQL数据库操作类
|
||||
支持Windows/macOS/Linux系统
|
||||
配置参数从外部传入
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
    """Return the shared instance, creating it under the class lock on first use."""
    if cls._instance:
        return cls._instance
    with cls._lock:
        # Check again inside the lock before creating the instance.
        if not cls._instance:
            cls._instance = super().__new__(cls)
    return cls._instance
|
||||
|
||||
def __init__(self, config: dict):
    """Validate the configuration and create the connection pool.

    Does nothing when this instance already holds a pool.

    Args:
        config: Database settings. Required keys: ``host``, ``port``,
            ``user``, ``password``, ``database``. Optional keys:
            ``charset`` (default ``utf8mb4``), ``max_connections``
            (default 5), ``connect_timeout`` (default 10),
            ``read_timeout`` / ``write_timeout`` (default 30), ``ssl``.

    Raises:
        ValueError: If any required key is missing.
    """
    # Skip re-initialisation when a pool has already been built.
    if getattr(self, '_pool', None):
        return

    # Required settings.
    required_keys = ['host', 'port', 'user', 'password', 'database']
    for key in required_keys:
        if key not in config:
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

    settings = {key: config[key] for key in required_keys}
    settings['charset'] = config.get('charset', 'utf8mb4')
    settings['cursorclass'] = cursors.DictCursor
    settings['autocommit'] = True
    for option, default in (('connect_timeout', 10),
                            ('read_timeout', 30),
                            ('write_timeout', 30)):
        settings[option] = config.get(option, default)
    settings['ssl'] = config.get('ssl')
    self.config = settings

    # Logger tagged with the current platform name.
    self.log = log.bind(module=f"MySQLAgent({platform.system()})")

    # Connection pool.
    self.pool_size = config.get('max_connections', 5)
    self._pool = self._create_pool()
|
||||
|
||||
def _create_pool(self) -> PooledDB:
    """Build the ``PooledDB`` connection pool from ``self.config``.

    Returns:
        The pool (1 connection cached at start, at most 3 idle, at most
        ``self.pool_size`` total, blocking when exhausted).

    Raises:
        Exception: Whatever pool creation raised, after logging it.
    """
    def open_connection():
        # Factory handed to the pool; tags each connection with an explicit
        # thread-safety level.
        connection = pymysql.connect(**self.config)
        connection.threadsafety = 1
        return connection

    pool_options = dict(
        creator=open_connection,
        mincached=1,
        maxcached=3,
        maxconnections=self.pool_size,
        blocking=True,
        ping=1,
    )
    try:
        pool = PooledDB(**pool_options)
        self.log.info("Connection pool created")
        return pool
    except Exception as exc:
        self.log.critical("Failed to create connection pool",
                          error=str(exc),
                          exc_info=True)
        raise
|
||||
|
||||
def get_connection(self) -> pymysql.connections.Connection:
    """Take a connection from the pool.

    Returns:
        A pooled database connection.

    Raises:
        Exception: If the pool cannot provide a connection. On Windows a
            "timed out" failure is retried via ``_retry_connection`` first.
    """
    try:
        connection = self._pool.connection()

        # On macOS with SSL configured, ping (and reconnect) before use.
        needs_ping = platform.system() == 'Darwin' and self.config.get('ssl')
        if needs_ping:
            connection.ping(reconnect=True)

        self.log.trace("Database connection obtained")
        return connection

    except Exception as exc:
        reason = str(exc)

        # Windows-specific handling: retry on connection timeouts.
        timed_out_on_windows = platform.system() == 'Windows' and "timed out" in reason
        if not timed_out_on_windows:
            self.log.error("Connection failed",
                           error=reason,
                           exc_info=True)
            raise

        self.log.warning("Windows connection timeout, retrying...")
        return self._retry_connection()
|
||||
|
||||
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
    """Retry taking a connection from the pool (used after a Windows timeout).

    Args:
        max_retries: Maximum number of attempts. Values below 1 are treated
            as 1 so the method never falls through and returns ``None``.

    Returns:
        A pooled database connection.

    Raises:
        Exception: The error of the final failed attempt.
    """
    import time  # kept function-local as before, but no longer inside the loop

    attempts = max(1, max_retries)
    for attempt in range(1, attempts + 1):
        try:
            conn = self._pool.connection()
        except Exception:
            if attempt == attempts:
                raise
            time.sleep(1)  # short pause before the next attempt
        else:
            self.log.info(f"Connection established after {attempt} attempts")
            return conn
|
||||
|
||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
    """Run a SQL query and return the result as a DataFrame.

    Args:
        sql: SQL query text.
        params: Query parameters passed to ``pandas.read_sql``.
        parse_dates: Passed unchanged to ``pandas.read_sql``.
            NOTE(review): the default is a bare ``True``; confirm that
            pandas treats a boolean here as intended, or pass column names.

    Returns:
        The query result.

    Raises:
        Exception: Whatever the connection or query raised, after logging it.
    """
    try:
        self.log.debug("Executing SQL query", sql=sql)

        with self.get_connection() as conn:
            # Outside Windows, raise the session wait_timeout before querying.
            if platform.system() != 'Windows':
                cursor = conn.cursor()
                try:
                    cursor.execute("SET SESSION wait_timeout=600")
                finally:
                    # Previously this cursor was never closed.
                    cursor.close()

            df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)

            # The former Windows branch opened a new cursor only to close it
            # again, which had no effect; it has been removed.

            self.log.info("Query executed successfully", rows=len(df))
            return df

    except Exception as e:
        self.log.error("SQL query failed",
                       sql=sql,
                       params=params,
                       error=str(e),
                       exc_info=True)
        raise
|
||||
|
||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
                   chunk_size: int = 1000, replace: bool = False) -> int:
    """Insert a DataFrame into a table in chunks via ``DataFrame.to_sql``.

    Args:
        table_name: Target table.
        df: Rows to insert; not modified.
        chunk_size: Number of rows per ``to_sql`` call.
        replace: When true, the first chunk is written with
            ``if_exists='replace'`` and later chunks are appended. See the
            pandas ``to_sql`` documentation for what 'replace' does to an
            existing table before relying on it.

    Returns:
        Number of rows handed to ``to_sql`` (0 for an empty DataFrame).

    Raises:
        Exception: Whatever the insert raised, after logging it.
    """
    if df.empty:
        self.log.warning("Attempted to insert empty DataFrame", table=table_name)
        return 0

    self.log.debug("Preparing to insert DataFrame",
                   table=table_name,
                   rows=len(df),
                   chunk_size=chunk_size)

    try:
        method = 'replace' if replace else 'append'
        total_rows = 0

        # Temporary SQLAlchemy engine wrapped around one pooled connection.
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool

        conn = self.get_connection()

        # Work around connection objects lacking character_set_name.
        if not hasattr(conn, 'character_set_name'):
            conn.character_set_name = lambda: self.config.get('charset', 'utf8mb4')

        engine = create_engine(
            "mysql+pymysql://",
            creator=lambda: conn,
            poolclass=StaticPool,  # static pool: reuse the connection above
            connect_args={
                'charset': self.config.get('charset', 'utf8mb4'),
                'autocommit': True
            }
        )

        try:
            for i in range(0, len(df), chunk_size):
                chunk = df.iloc[i:i + chunk_size]

                # On macOS datetime columns are sent as formatted strings.
                if platform.system() == 'Darwin':
                    datetime_cols = list(chunk.select_dtypes(include=['datetime64']).columns)
                    if datetime_cols:
                        # Copy first: assigning into an iloc slice triggers
                        # SettingWithCopyWarning and may touch the caller's frame.
                        chunk = chunk.copy()
                        for col in datetime_cols:
                            chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')

                chunk.to_sql(
                    table_name,
                    engine,
                    if_exists=method,
                    index=False,
                    method='multi'
                )
                total_rows += len(chunk)
                method = 'append'  # only the first chunk may replace
                self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
                               rows=len(chunk),
                               total_inserted=total_rows)

            self.log.info("Data inserted successfully",
                          table=table_name,
                          total_rows=total_rows)
            return total_rows
        finally:
            # Always release the engine and the pooled connection.
            engine.dispose()
            conn.close()

    except Exception as e:
        self.log.error("Data insertion failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise
|
||||
|
||||
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """
    Update rows of a database table from the rows of a DataFrame.

    One ``UPDATE`` statement is generated and executed once per DataFrame row
    with ``executemany`` inside a single transaction. Only DataFrame columns
    that also exist in the target table are written; ``key_columns`` are used
    for the ``WHERE`` clause and are never part of the ``SET`` clause.

    Args:
        table_name (str): Target table name (interpolated into the SQL as-is).
        df (pd.DataFrame): Rows carrying the new values and the key values.
        key_columns (Union[str, List[str]]): Column(s) identifying the rows.

    Returns:
        int: Row count reported by the cursor after the batch update
        (0 when ``df`` is empty).

    Raises:
        ValueError: If no key column is given, a key column is missing from
            ``df``, or there is no non-key column to update.
        Exception: Any database error; the transaction is rolled back first.
    """
    if df.empty:
        self.log.warning("Attempted to update with empty DataFrame", table=table_name)
        return 0

    self.log.debug("Preparing to update table from DataFrame",
                   table=table_name,
                   key_columns=key_columns,
                   rows=len(df))

    def quote(identifier) -> str:
        # Backtick-quote a column name so reserved words / spaces are legal.
        return "`" + str(identifier).replace("`", "``") + "`"

    def to_param(value):
        # NaN/NaT cannot be sent to MySQL; store them as NULL instead.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        return value

    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]
        key_columns = list(key_columns)

        # Validate before opening a transaction so bad input never needs a rollback.
        if not key_columns:
            raise ValueError("At least one key column is required")
        missing_keys = [col for col in key_columns if col not in df.columns]
        if missing_keys:
            raise ValueError(f"Key columns missing from DataFrame: {missing_keys}")

        # Only columns that really exist in the table can be updated.
        table_info = self._get_table_info(table_name)
        set_columns = [col for col in df.columns
                       if col in table_info and col not in key_columns]
        if not set_columns:
            raise ValueError(
                f"No updatable columns for table {table_name}: "
                "every matching DataFrame column is a key column"
            )

        set_clause = ', '.join(f"{quote(col)}=%s" for col in set_columns)
        where_clause = ' AND '.join(f"{quote(col)}=%s" for col in key_columns)
        update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
        self.log.trace("Generated update SQL", sql=update_sql)

        # itertuples keeps each column's own dtype (iterrows would upcast a
        # whole row to a common dtype, e.g. turning integer keys into floats).
        ordered = df[set_columns + key_columns]
        update_data = [
            tuple(to_param(value) for value in record)
            for record in ordered.itertuples(index=False, name=None)
        ]

        conn = self.begin_transaction()
        try:
            cursor = conn.cursor()
            try:
                cursor.executemany(update_sql, update_data)
                total_updated = cursor.rowcount
            finally:
                cursor.close()
        except Exception:
            # rollback_transaction also closes the connection.
            self.rollback_transaction(conn)
            raise

        # commit_transaction closes the connection whether or not it succeeds.
        self.commit_transaction(conn)
        self.log.info("Data updated successfully",
                      table=table_name,
                      rows_updated=total_updated)
        return total_updated

    except Exception as e:
        self.log.error("Data update failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise
|
||||
|
||||
def _get_table_info(self, table_name: str) -> Dict[str, str]:
    """
    Fetch the column names and data types of a table.

    The lookup is done against ``information_schema.columns`` for the database
    named in ``self.config['database']``.

    Args:
        table_name (str): Name of the table to inspect.

    Returns:
        Dict[str, str]: Mapping of column name to data type (empty when the
        table has no columns or does not exist).

    Raises:
        Exception: Any error raised while querying; it is logged and re-raised.
    """
    # Explicit aliases pin the result keys: MySQL 8 reports INFORMATION_SCHEMA
    # column headers in uppercase, which would break the dict lookups below.
    sql = """
        SELECT COLUMN_NAME AS column_name, DATA_TYPE AS data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
    """

    params = (self.config['database'], table_name)

    try:
        with self.get_connection() as conn:
            cursor = conn.cursor()
            try:
                cursor.execute(sql, params)
                result = cursor.fetchall()
            finally:
                cursor.close()
            # Rows are accessed by key, i.e. a dict-style cursor is expected.
            return {row['column_name']: row['data_type'] for row in result}

    except Exception as e:
        self.log.error("Failed to get table info",
                       table=table_name,
                       error=str(e))
        raise
|
||||
|
||||
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
    """
    Infer a SQL column type for every column of a DataFrame.

    The dtype name is first looked up in an exact-match table. Dtypes that are
    not listed there (e.g. ``int32``, ``float32``, nullable ``Int64``,
    timezone-aware datetimes) are classified by dtype family instead of
    silently falling back to ``TEXT``.

    Args:
        df (pd.DataFrame): Input frame whose ``dtypes`` are inspected.

    Returns:
        Dict[str, str]: Mapping of column name to SQL type; ``TEXT`` is used
        for anything that cannot be classified.
    """
    exact_types = {
        'int64': 'BIGINT',
        'float64': 'DOUBLE',
        'datetime64[ns]': 'DATETIME',
        'object': 'TEXT',
        'bool': 'TINYINT(1)',
        'category': 'VARCHAR(255)'
    }
    dtype_checks = pd.api.types

    def infer(dtype) -> str:
        # Exact names keep the historical mapping untouched.
        mapped = exact_types.get(str(dtype))
        if mapped is not None:
            return mapped
        # Bool must be tested before the integer family.
        if dtype_checks.is_bool_dtype(dtype):
            return 'TINYINT(1)'
        if dtype_checks.is_unsigned_integer_dtype(dtype):
            return 'BIGINT UNSIGNED'
        if dtype_checks.is_integer_dtype(dtype):
            return 'BIGINT'
        if dtype_checks.is_float_dtype(dtype):
            return 'DOUBLE'
        if dtype_checks.is_datetime64_any_dtype(dtype):
            return 'DATETIME'
        return 'TEXT'

    sql_types = {col: infer(dtype) for col, dtype in df.dtypes.items()}

    self.log.debug("Mapped DataFrame types to SQL types",
                   mappings=sql_types)
    return sql_types
|
||||
|
||||
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """
    Create a table whose columns mirror the structure of a DataFrame.

    Column types come from ``df_to_sql_type``. Nothing is created when the
    table already exists or when the DataFrame has no columns.

    Args:
        table_name (str): Name of the table to create (interpolated as-is).
        df (pd.DataFrame): Reference frame providing column names and dtypes.
        primary_key (Union[str, List[str], None]): Column(s) forming the
            primary key; names not present in ``df`` are ignored.

    Returns:
        bool: True when the table was created, False when it already existed,
        had no columns, or creation failed (the error is logged).
    """
    if self.table_exists(table_name):
        self.log.warning("Table already exists", table=table_name)
        return False

    self.log.debug("Creating new table from DataFrame schema",
                   table=table_name,
                   columns=list(df.columns))

    def quote(identifier) -> str:
        # Backtick-quote a column name so reserved words / spaces are legal.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        if not sql_types:
            # "CREATE TABLE t ()" is not valid SQL.
            self.log.warning("Cannot create table without columns", table=table_name)
            return False

        columns_sql = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                # TEXT/BLOB key parts need an explicit prefix length in MySQL.
                pk_parts = [
                    f"{quote(col)}(255)" if sql_types[col].upper() == 'TEXT' else quote(col)
                    for col in pk_columns
                ]
                columns_sql.append(f"PRIMARY KEY ({', '.join(pk_parts)})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)

        # Joined outside the f-string: a backslash inside an f-string
        # expression is a syntax error on Python versions before 3.12.
        column_block = ",\n  ".join(columns_sql)
        create_sql = f"CREATE TABLE {table_name} (\n  {column_block}\n)"

        self.execute_sql(create_sql)
        self.log.info("Table created successfully", table=table_name)
        return True

    except Exception as e:
        self.log.error("Failed to create table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
|
||||
|
||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """
    Run a single SQL statement on a connection obtained from the agent.

    Args:
        sql (str): SQL statement to execute.
        params (Union[tuple, dict, None]): Parameters bound to the statement.
        fetch (bool): When True, fetch and return all result rows.

    Returns:
        Union[int, List[Dict[str, Any]]]:
            - ``cursor.rowcount`` when ``fetch`` is False
            - the fetched rows when ``fetch`` is True

    Raises:
        Exception: Any execution error; it is logged and re-raised.
    """
    connection = None
    cur = None
    try:
        connection = self.get_connection()
        cur = connection.cursor()

        # Non-Windows platforms get a longer per-session execution limit.
        if platform.system() != 'Windows':
            cur.execute("SET SESSION max_execution_time=600000")

        cur.execute(sql, params)

        if not fetch:
            row_count = cur.rowcount
            self.log.debug("Update executed", affected_rows=row_count)
            return row_count

        rows = cur.fetchall()
        self.log.debug("Query executed", rows=len(rows))
        return rows

    except Exception as exc:
        self.log.error("SQL execution failed",
                       sql=sql,
                       params=params,
                       error=str(exc),
                       exc_info=True)
        raise
    finally:
        # Cursor first, then the connection.
        for resource in (cur, connection):
            if resource:
                resource.close()
|
||||
|
||||
def begin_transaction(self) -> pymysql.connections.Connection:
    """
    Open a connection with autocommit disabled and return it.

    The caller owns the returned connection and must finish with
    ``commit_transaction`` or ``rollback_transaction`` (both close it).

    Returns:
        pymysql.connections.Connection: Connection with autocommit turned off.

    Raises:
        Exception: Any error while obtaining or preparing the connection; the
            connection is closed before the error is re-raised.
    """
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)

        # On macOS the session isolation level is set explicitly.
        if platform.system() == 'Darwin':
            with conn.cursor() as cursor:
                cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")

        self.log.debug("Transaction started")
        return conn
    except Exception as e:
        self.log.error("Begin transaction_failed", error=str(e))
        # Nobody else holds this connection yet, so release it here.
        if conn is not None:
            try:
                conn.close()
            except Exception as close_err:
                self.log.error("Failed to close connection after begin failure",
                               error=str(close_err))
        raise
|
||||
|
||||
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
    """
    Commit the work done on ``conn`` and close the connection.

    Args:
        conn: Connection on which the transaction was started.

    Raises:
        Exception: Any commit error; it is logged and re-raised. The
            connection is closed in every case.
    """
    try:
        conn.commit()
        self.log.debug("Transaction committed")
    except Exception as commit_err:
        failure_text = str(commit_err)
        self.log.error("Commit failed", error=failure_text)
        raise
    finally:
        # The connection is always closed, success or not.
        conn.close()
|
||||
|
||||
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
    """
    Roll back the work done on ``conn`` and close the connection.

    A rollback failure is logged but not re-raised. The connection is closed
    in every case.

    Args:
        conn: Connection on which the transaction was started.
    """
    try:
        conn.rollback()
        self.log.warning("Transaction rolled back")
    except Exception as rollback_err:
        failure_text = str(rollback_err)
        self.log.error("Rollback failed", error=failure_text)
    finally:
        # The connection is always closed, success or not.
        conn.close()
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
    """
    Check whether a table exists in the configured database.

    Args:
        table_name (str): Name of the table to look for.

    Returns:
        bool: True when ``information_schema.tables`` lists the table in
        ``self.config['database']``. False when it does not, or when the
        lookup itself fails (the failure is logged).
    """
    sql = """
        SELECT COUNT(*) as count
        FROM `information_schema`.`tables`
        WHERE `table_schema` = %s AND `table_name` = %s
    """

    try:
        params = (self.config['database'], table_name)
        result = self.execute_sql(sql, params, fetch=True)
        exists = result[0]['count'] > 0
        self.log.debug("Checked table existence",
                       table=table_name,
                       exists=exists)
        return exists
    except Exception as e:
        # Still best-effort (False), but leave a trace: a lookup failure is
        # otherwise indistinguishable from "table missing".
        self.log.error("Failed to check table existence",
                       table=table_name,
                       error=str(e))
        return False
|
||||
|
||||
def drop_table(self, table_name: str) -> bool:
    """
    Drop a table if it currently exists.

    Args:
        table_name (str): Name of the table to drop (interpolated as-is).

    Returns:
        bool: True when the table was dropped; False when it did not exist or
        the drop failed (the error is logged).
    """
    table_present = self.table_exists(table_name)
    if not table_present:
        self.log.warning("Table does not exist", table=table_name)
        return False

    try:
        drop_statement = f"DROP TABLE {table_name}"
        self.execute_sql(drop_statement)
        self.log.info("Table dropped successfully", table=table_name)
        return True
    except Exception as drop_err:
        self.log.error("Failed to drop table",
                       table=table_name,
                       error=str(drop_err),
                       exc_info=True)
        return False
|
||||
|
||||
def get_pool_status(self) -> Dict[str, int]:
    """
    Report counters read from the connection pool's internal attributes.

    Returns:
        Dict[str, int]: ``max`` (configured connection limit), ``active``
        (the pool's connection counter), ``idle`` and ``shared`` (sizes of
        the idle and shared caches; 0 when the pool has no such cache).
    """
    pool = self._pool
    return {
        'max': pool._maxconnections,
        'active': pool._connections,
        # The caches are optional pool internals: a pool created without
        # connection sharing has no _shared_cache attribute at all.
        'idle': len(getattr(pool, '_idle_cache', ())),
        'shared': len(getattr(pool, '_shared_cache', ()))
    }
|
||||
|
||||
def validate_connection(self) -> bool:
    """
    Check that a connection can be obtained and answers ``SELECT 1``.

    Returns:
        bool: True when the query returns the value 1; False when it does
        not or when any step raises.
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                row = cursor.fetchone()

        if row is None:
            return False
        # Rows may be dicts (dict-style cursor, as used elsewhere in this
        # class) or plain sequences; read the single value either way.
        value = next(iter(row.values())) if isinstance(row, dict) else row[0]
        return value == 1
    except Exception as e:
        self.log.debug("Connection validation failed", error=str(e))
        return False
|
||||
|
||||
def __del__(self):
    """Finalizer: close the connection pool if one was created."""
    if not hasattr(self, '_pool'):
        # Initialisation never got as far as creating the pool.
        return

    try:
        self._pool.close()
        self.log.info("Connection pool closed")
    except Exception as close_err:
        self.log.error("Failed to close pool", error=str(close_err))
|
||||
|
||||
|
||||
# 平台特定的默认配置
|
||||
def get_default_config():
    """
    Build the default connection settings for the current platform.

    Returns:
        dict: Base connection settings plus connect/read/write timeouts
        (10/30/30 on Windows, 15/60/60 elsewhere); macOS additionally gets an
        ``ssl`` entry pointing at a CA bundle path.
    """
    system_name = platform.system()

    # NOTE(review): credentials are hard-coded defaults.
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence',
        'max_connections': 5
    }

    timeout_keys = ('connect_timeout', 'read_timeout', 'write_timeout')
    timeout_values = (10, 30, 30) if system_name == 'Windows' else (15, 60, 60)
    settings.update(zip(timeout_keys, timeout_values))

    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}

    return settings
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage example: connect with the platform defaults and print some facts.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("Failed to connect to database")
    else:
        print("Database connection successful")

        # Database server version
        version = db.query_to_df("SELECT VERSION() as version")
        server_version = version['version'].iloc[0]
        print(f"Database version: {server_version}")

        # Connection pool counters
        print("Connection pool status:", db.get_pool_status())
|
||||
@@ -0,0 +1,2 @@
|
||||
# Makes system_management a package
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
# Makes system_management.scheduler a package
|
||||
from .task_scheduler import TaskScheduler
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,190 @@
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from system_management.scheduler.task_scheduler import TaskScheduler
|
||||
from system_management.scheduler.task_scheduler import TaskManager
|
||||
from config import Config
|
||||
from utils.logger import CrossPlatformLog
|
||||
|
||||
# 初始化日志
|
||||
log = CrossPlatformLog.get_logger("TaskManagement")
|
||||
|
||||
|
||||
def _build_parser():
    """Build the argument parser with one sub-command per task operation."""
    parser = argparse.ArgumentParser(description="任务管理工具")
    subparsers = parser.add_subparsers(dest="command", help="可用命令")

    # list
    list_parser = subparsers.add_parser("list", help="列出所有任务")
    list_parser.add_argument("--active-only", action="store_true", help="只显示活跃任务")

    # show
    show_parser = subparsers.add_parser("show", help="显示任务详情")
    show_parser.add_argument("task_id", type=int, help="任务ID")

    # update
    update_parser = subparsers.add_parser("update", help="更新任务属性")
    update_parser.add_argument("task_id", type=int, help="任务ID")
    update_parser.add_argument("--name", help="任务名称")
    update_parser.add_argument("--type", help="任务类型")
    update_parser.add_argument("--module", help="模块路径")
    update_parser.add_argument("--cron", help="Cron表达式")
    update_parser.add_argument("--timezone", help="时区")

    # toggle
    toggle_parser = subparsers.add_parser("toggle", help="启用/禁用任务")
    toggle_parser.add_argument("task_id", type=int, help="任务ID")
    toggle_parser.add_argument("--activate", action="store_true", help="启用任务")
    toggle_parser.add_argument("--deactivate", action="store_true", help="禁用任务")

    # delete
    delete_parser = subparsers.add_parser("delete", help="删除任务")
    delete_parser.add_argument("task_id", type=int, help="任务ID")

    # run
    run_parser = subparsers.add_parser("run", help="手动执行任务")
    run_parser.add_argument("task_id", type=int, help="任务ID")

    # add
    add_parser = subparsers.add_parser("add", help="添加新任务")
    add_parser.add_argument("--name", required=True, help="任务名称")
    add_parser.add_argument("--type", required=True, help="任务类型")
    add_parser.add_argument("--module", required=True, help="模块路径")
    add_parser.add_argument("--cron", required=True, help="Cron表达式")
    add_parser.add_argument("--timezone", default="Asia/Shanghai", help="时区")

    return parser


def _handle_list(args, scheduler, manager):
    """Print the task table, optionally restricted to active tasks."""
    try:
        tasks = manager.get_all_tasks(args.active_only)
        manager.print_task_table(tasks)
        log.info(f"列出任务完成,共{len(tasks)}个任务")
    except Exception as e:
        log.error(f"列出任务失败: {str(e)}", exc_info=True)


def _handle_show(args, scheduler, manager):
    """Print every field of one task, or report that it does not exist."""
    try:
        task = manager.get_task_by_id(args.task_id)
        if task:
            print("\n===== 任务详情 =====")
            for key, value in task.items():
                print(f"{key}: {value}")
            print("====================")
            log.info(f"显示任务详情成功,任务ID: {args.task_id}")
        else:
            log.warning(f"未找到任务ID: {args.task_id}")
            print(f"任务ID {args.task_id} 不存在")
    except Exception:
        log.error(f"显示任务详情失败,任务ID: {args.task_id}", exc_info=True)


def _handle_update(args, scheduler, manager):
    """Apply the task fields given on the command line to one task."""
    try:
        # CLI option -> task field; options left unset are skipped.
        field_values = (
            ('task_name', args.name),
            ('task_type', args.type),
            ('module_path', args.module),
            ('cron_expression', args.cron),
            ('time_zone', args.timezone),
        )
        updates = {field: value for field, value in field_values if value}

        if not updates:
            log.warning("未提供任何更新字段")
            print("请至少指定一个更新字段")
            return

        if manager.update_task(args.task_id, updates):
            log.info(f"任务ID {args.task_id} 更新成功")
            print(f"任务ID {args.task_id} 更新成功")
        else:
            log.warning(f"任务ID {args.task_id} 更新失败")
            print(f"任务ID {args.task_id} 更新失败")
    except Exception:
        log.error(f"更新任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_toggle(args, scheduler, manager):
    """Activate or deactivate one task; exactly one flag must be given."""
    try:
        if args.activate and args.deactivate:
            log.warning("不能同时指定 --activate 和 --deactivate")
            print("不能同时指定 --activate 和 --deactivate")
            return
        if not args.activate and not args.deactivate:
            log.warning("请指定 --activate 或 --deactivate")
            print("请指定 --activate 或 --deactivate")
            return

        activate = bool(args.activate)
        success = manager.toggle_task_status(args.task_id, activate)
        action = "启用" if activate else "禁用"

        if success:
            log.info(f"任务ID {args.task_id} {action}成功")
            print(f"任务ID {args.task_id} {action}成功")
        else:
            log.warning(f"任务ID {args.task_id} {action}失败")
            print(f"任务ID {args.task_id} {action}失败")
    except Exception:
        log.error(f"切换任务状态失败,任务ID: {args.task_id}", exc_info=True)


def _handle_delete(args, scheduler, manager):
    """Delete one task after an interactive y/n confirmation."""
    try:
        confirm = input(f"确定要删除任务ID {args.task_id} 吗? (y/n) ")
        if confirm.lower() == 'y':
            if manager.delete_task(args.task_id):
                log.info(f"任务ID {args.task_id} 删除成功")
                print(f"任务ID {args.task_id} 删除成功")
            else:
                log.warning(f"任务ID {args.task_id} 删除失败")
                print(f"任务ID {args.task_id} 删除失败")
        else:
            log.info(f"用户取消删除任务ID {args.task_id}")
            print("操作已取消")
    except Exception:
        log.error(f"删除任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_run(args, scheduler, manager):
    """Run one task immediately and report the outcome."""
    try:
        log.info(f"开始手动执行任务ID {args.task_id}")
        print(f"正在手动执行任务ID {args.task_id}...")
        if manager.run_task_manually(args.task_id):
            log.info(f"任务ID {args.task_id} 执行成功")
            print(f"任务ID {args.task_id} 执行成功")
        else:
            log.warning(f"任务ID {args.task_id} 执行失败")
            print(f"任务ID {args.task_id} 执行失败")
    except Exception:
        log.error(f"手动执行任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_add(args, scheduler, manager):
    """Register a new task with the scheduler and print its id."""
    try:
        task_id = scheduler.add_task(
            task_name=args.name,
            task_type=args.type,
            module_path=args.module,
            cron_expression=args.cron,
            time_zone=args.timezone
        )
        log.info(f"新任务添加成功,ID: {task_id}")
        print(f"新任务添加成功,ID: {task_id}")
    except Exception as e:
        log.error(f"添加任务失败: {str(e)}", exc_info=True)
        print(f"添加任务失败: {str(e)}")


def main():
    """
    Entry point of the task-management command line tool.

    Arguments are parsed before the scheduler is created, so ``--help``,
    invalid arguments and a missing sub-command are handled without
    constructing the scheduler. When no sub-command is given the help text
    is printed.
    """
    parser = _build_parser()
    args = parser.parse_args()

    handlers = {
        "list": _handle_list,
        "show": _handle_show,
        "update": _handle_update,
        "toggle": _handle_toggle,
        "delete": _handle_delete,
        "run": _handle_run,
        "add": _handle_add,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        return

    # Only build the scheduler and manager once a real command is known.
    scheduler = TaskScheduler(Config.MYSQL_CONFIG)
    manager = TaskManager(scheduler)
    handler(args, scheduler, manager)


if __name__ == "__main__":
    main()
|
||||
@@ -1,277 +1,484 @@
|
||||
# system_management/scheduler/task_scheduler.py
|
||||
import importlib
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
from storage.mysql_agent import MySQLAgent
|
||||
from pathlib import Path
|
||||
|
||||
# 使用您的日志系统
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any
|
||||
import croniter
|
||||
import pytz
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import pandas as pd
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from utils.mysql_agent import MySQLAgent
|
||||
from utils.logger import CrossPlatformLog
|
||||
|
||||
# 初始化调度器日志
|
||||
log = CrossPlatformLog.get_logger("TaskScheduler")
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
def __init__(self, db_config: Optional[Dict] = None):
|
||||
"""
|
||||
初始化任务调度器
|
||||
def __init__(self, db_config: Optional[Dict] = None, max_workers: int = 5):
|
||||
"""初始化任务调度器(基于Cron表达式)"""
|
||||
self.db = MySQLAgent(db_config or {})
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
# 并发容量控制:限制同时运行的后台任务不超过 max_workers
|
||||
self._running_semaphore = threading.Semaphore(max_workers)
|
||||
# 任务统计
|
||||
self.hourly_stats = {'成功': 0, '失败': 0, '总数': 0}
|
||||
self.hourly_stats_lock = threading.Lock()
|
||||
log.info(f"任务调度器已初始化,最大工作线程数: {max_workers}")
|
||||
|
||||
def _resolve_callable(self, module_path: str):
    """
    Resolve a dotted path to the callable used as a task entry point.

    Supported forms:
      - ``package.module``                  -> the module's ``main``
      - ``package.module.ClassName``        -> the class's ``main``
      - ``package.module.func_name``        -> the function itself
      - ``package.module.ClassName.method`` -> that attribute, if callable

    For a class, a static/class method ``main`` (or a ``main`` needing no
    positional argument) is returned directly; otherwise the class is
    instantiated with no arguments and the bound ``main`` is returned, so the
    result can always be called without arguments.

    Args:
        module_path: Dotted import path of the entry point.

    Returns:
        The resolved callable.

    Raises:
        ImportError: If the path is empty or not a string, or if no prefix of
            it can be imported and resolved. The most relevant underlying
            error is attached as ``__cause__``.
    """
    import inspect

    if not module_path or not isinstance(module_path, str):
        raise ImportError("无效的模块路径")

    def needs_instance(func) -> bool:
        # True when the first parameter is a required positional one (i.e.
        # an ordinary instance method looked up on the class).
        try:
            params = list(inspect.signature(func).parameters.values())
        except (TypeError, ValueError):
            return False
        if not params:
            return False
        first = params[0]
        positional = (inspect.Parameter.POSITIONAL_ONLY,
                      inspect.Parameter.POSITIONAL_OR_KEYWORD)
        return first.kind in positional and first.default is inspect.Parameter.empty

    def entry_from_class(cls):
        # Pick a zero-argument callable ``main`` from a class.
        static_main = inspect.getattr_static(cls, 'main', None)
        class_main = getattr(cls, 'main', None)
        if callable(class_main) and (
                isinstance(static_main, (staticmethod, classmethod))
                or not needs_instance(class_main)):
            return class_main
        try:
            instance = cls()
        except Exception as exc:
            raise AttributeError(
                f"类 {cls.__name__} 实例化失败,无法获取 main(): {exc}"
            ) from exc
        bound_main = getattr(instance, 'main', None)
        if callable(bound_main):
            return bound_main
        # The class itself is never used as the entry point: calling it
        # would only construct an instance without running main.
        raise AttributeError(f"类 {cls.__name__} 缺少可调用的 main() 作为任务入口")

    parts = module_path.split('.')
    primary_error = None   # first failure that is not just "prefix is not a module"
    fallback_error = None  # last "no such module" failure for a prefix of the path

    # Try the longest prefix as a module first, then progressively shorter ones.
    for end in range(len(parts), 0, -1):
        module_name = '.'.join(parts[:end])
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as exc:
            missing = exc.name or ''
            if missing and (module_name == missing
                            or module_name.startswith(missing + '.')):
                # Expected while the tail of the path still names attributes.
                fallback_error = exc
            elif primary_error is None:
                # Something the module itself imports is missing.
                primary_error = exc
            continue
        except Exception as exc:
            if primary_error is None:
                primary_error = exc
            continue

        try:
            # Walk the remaining path segments as attributes.
            target = module
            for attr in parts[end:]:
                if not hasattr(target, attr):
                    raise AttributeError(f"在 {target} 中未找到属性: {attr}")
                target = getattr(target, attr)

            if isinstance(target, type):
                return entry_from_class(target)

            # Non-class target: use it directly when callable.
            if callable(target):
                return target

            # Otherwise fall back to a ``main`` attribute on the object.
            object_main = getattr(target, 'main', None)
            if callable(object_main):
                return object_main

            raise AttributeError(f"路径 {module_path} 未解析到可调用入口(缺少 main 或不可调用)")
        except Exception as exc:
            if primary_error is None:
                primary_error = exc
            continue

    error = primary_error if primary_error is not None else fallback_error
    raise ImportError(f"模块 {module_path} 导入/解析失败: {str(error)}") from error
|
||||
|
||||
def check_and_run_tasks(self, print_empty_status: bool = False) -> Dict[str, int]:
|
||||
"""检查并执行所有到期的任务,优化空任务处理和异常容错
|
||||
|
||||
Args:
|
||||
db_config (Optional[Dict]): 可选的数据库配置,默认使用MySQLAgent默认配置
|
||||
print_empty_status: 是否打印空任务状态(默认False,避免频繁输出)
|
||||
"""
|
||||
self.db = MySQLAgent(db_config or {}) # 使用您提供的MySQLAgent
|
||||
self._init_task_table()
|
||||
log.info("TaskScheduler initialized")
|
||||
|
||||
def _init_task_table(self):
|
||||
"""确保任务表存在并包含必要字段"""
|
||||
if not self.db.table_exists("main_task"):
|
||||
log.info("Creating main_task table")
|
||||
create_sql = """
|
||||
CREATE TABLE main_task (
|
||||
task_id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
task_name VARCHAR(100) NOT NULL,
|
||||
module_path VARCHAR(255) NOT NULL COMMENT '例如data_collection.spiders.weibo_spider',
|
||||
frequency_type ENUM('minute','hourly','daily','weekly','monthly') NOT NULL,
|
||||
frequency_value INT DEFAULT NULL COMMENT '间隔数值',
|
||||
last_run_time DATETIME DEFAULT NULL,
|
||||
next_run_time DATETIME DEFAULT NULL,
|
||||
last_run_status VARCHAR(20) DEFAULT NULL,
|
||||
is_active TINYINT(1) DEFAULT 1,
|
||||
is_running TINYINT(1) DEFAULT 0,
|
||||
run_count INT DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
|
||||
INDEX idx_next_run (next_run_time),
|
||||
INDEX idx_active (is_active)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
|
||||
"""
|
||||
self.db.execute_sql(create_sql)
|
||||
log.success("main_task table created")
|
||||
|
||||
def run_pending_tasks(self) -> Dict[str, int]:
|
||||
"""
|
||||
执行所有到期的活跃任务
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 包含执行结果的字典 {
|
||||
'total': 总任务数,
|
||||
'success': 成功数,
|
||||
'failed': 失败数
|
||||
}
|
||||
"""
|
||||
result = {'total': 0, 'success': 0, 'failed': 0}
|
||||
result = {'总任务数': 0, '成功': 0, '失败': 0}
|
||||
|
||||
try:
|
||||
# 使用您提供的query_to_df方法获取任务
|
||||
tasks_df = self.db.query_to_df(
|
||||
"SELECT * FROM main_task "
|
||||
"WHERE is_active = 1 AND next_run_time <= %s "
|
||||
"ORDER BY next_run_time",
|
||||
params=(datetime.now(),)
|
||||
)
|
||||
# 获取当前时间(带时区转换为本地时间)
|
||||
tz = pytz.timezone('Asia/Shanghai')
|
||||
now = datetime.now(tz).replace(tzinfo=None) # 移除时区信息,与数据库存储一致
|
||||
log.debug(f"当前检查时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
result['total'] = len(tasks_df)
|
||||
# 查询所有到期的活跃任务(使用参数化查询防止注入)
|
||||
tasks_df = self.db.query_to_df("""
|
||||
SELECT *
|
||||
FROM main_task
|
||||
WHERE is_active = 1
|
||||
AND next_run_time <= %s
|
||||
AND is_running = 0
|
||||
ORDER BY next_run_time
|
||||
""", params=(now,),is_print=False)
|
||||
|
||||
result['总任务数'] = len(tasks_df)
|
||||
if tasks_df.empty:
|
||||
log.debug("No pending tasks found")
|
||||
# 空任务时根据参数决定是否输出
|
||||
if print_empty_status:
|
||||
print(f"当前没有到期的任务,等待新任务加入...{now.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
return result
|
||||
|
||||
# 并发执行任务
|
||||
futures = []
|
||||
for _, task in tasks_df.iterrows():
|
||||
task_id = task['task_id']
|
||||
log.bind(task_id=task_id).info(
|
||||
f"Starting task {task['task_name']}"
|
||||
)
|
||||
|
||||
# 标记任务为执行中
|
||||
self._update_task_status(
|
||||
task_id,
|
||||
{'is_running': 1, 'last_run_time': datetime.now()}
|
||||
)
|
||||
# 传递任务字典的副本避免线程安全问题
|
||||
task_copy = task.to_dict()
|
||||
futures.append(self.executor.submit(self._process_single_task, task_copy))
|
||||
|
||||
# 收集执行结果
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
self._execute_single_task(task)
|
||||
self._update_task_status(
|
||||
task_id,
|
||||
{
|
||||
'last_run_status': 'success',
|
||||
'is_running': 0,
|
||||
'run_count': task['run_count'] + 1,
|
||||
'next_run_time': self._calculate_next_run(
|
||||
task['frequency_type'],
|
||||
task['frequency_value']
|
||||
)
|
||||
}
|
||||
)
|
||||
result['success'] += 1
|
||||
log.bind(task_id=task_id).success("Task completed")
|
||||
if future.result():
|
||||
result['成功'] += 1
|
||||
else:
|
||||
result['失败'] += 1
|
||||
except Exception as e:
|
||||
log.bind(task_id=task_id).error(
|
||||
f"Task failed: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
self._update_task_status(
|
||||
task_id,
|
||||
{
|
||||
'last_run_status': 'failed',
|
||||
'is_running': 0,
|
||||
'next_run_time': self._calculate_next_run(
|
||||
task['frequency_type'],
|
||||
task['frequency_value'],
|
||||
retry=True
|
||||
)
|
||||
}
|
||||
)
|
||||
result['failed'] += 1
|
||||
log.error(f"任务线程执行失败: {str(e)}", exc_info=True)
|
||||
result['失败'] += 1
|
||||
|
||||
# 更新小时统计
|
||||
with self.hourly_stats_lock:
|
||||
self.hourly_stats['成功'] += result['成功']
|
||||
self.hourly_stats['失败'] += result['失败']
|
||||
self.hourly_stats['总数'] += result['总任务数']
|
||||
|
||||
log.info(
|
||||
"Scheduler cycle completed",
|
||||
total_tasks=result['total'],
|
||||
success=result['success'],
|
||||
failed=result['failed']
|
||||
"任务调度周期完成",
|
||||
总任务数=result['总任务数'],
|
||||
成功=result['成功'],
|
||||
失败=result['失败']
|
||||
)
|
||||
return result
|
||||
|
||||
except SQLAlchemyError as e: # 数据库异常处理优化
|
||||
log.error(f"数据库操作失败,将在下次轮询重试: {str(e)}", exc_info=True)
|
||||
return result # 不中断,返回当前结果
|
||||
except Exception as e:
|
||||
log.critical(
|
||||
"Scheduler main loop failed",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
log.error("调度器周期执行异常,将在下次轮询重试", exc_info=True)
|
||||
return result # 不中断主循环,允许下次重试
|
||||
|
||||
def _execute_single_task(self, task: Dict) -> None:
|
||||
"""执行单个任务模块"""
|
||||
start_time = time.time()
|
||||
task_log = log.bind(
|
||||
task_id=task['task_id'],
|
||||
module=task['module_path']
|
||||
)
|
||||
def _process_single_task(self, task: Dict[str, Any]) -> bool:
|
||||
"""处理单个任务(线程安全)"""
|
||||
task_id = task['task_id']
|
||||
task_name = task['task_name']
|
||||
task_log = log.bind(task_id=task_id, task_name=task_name)
|
||||
task_log.info(f"开始执行任务: {task_name}")
|
||||
|
||||
try:
|
||||
module = importlib.import_module(task['module_path'])
|
||||
# 阻塞等待可用的执行槽位,保证同时运行的任务不超过最大工作线程数
|
||||
self._running_semaphore.acquire()
|
||||
|
||||
if not hasattr(module, 'main'):
|
||||
raise ImportError(f"Module has no main() function")
|
||||
# 标记任务为运行中(使用当前时间的时区感知对象)
|
||||
tz = pytz.timezone(task.get('time_zone', 'Asia/Shanghai'))
|
||||
current_time = datetime.now(tz).replace(tzinfo=None)
|
||||
|
||||
# 执行任务
|
||||
task_log.debug("Task execution started")
|
||||
module.main()
|
||||
self._update_task_status(task_id, {
|
||||
'is_running': 1,
|
||||
'last_run_time': current_time
|
||||
})
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
task_log.info(
|
||||
f"Task completed in {elapsed:.2f}s",
|
||||
duration=elapsed
|
||||
)
|
||||
# 将任务主体放到后台线程执行,当前线程快速返回
|
||||
self.executor.submit(self._run_task_async, task.copy())
|
||||
task_log.debug("任务已提交至后台执行队列")
|
||||
return True # 表示已成功提交
|
||||
|
||||
except Exception as e:
|
||||
task_log.error(
|
||||
"Task execution failed",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
task_log.error(f"任务执行失败: {str(e)}", exc_info=True)
|
||||
|
||||
def _update_task_status(self, task_id: int, updates: Dict) -> None:
|
||||
"""更新任务状态"""
|
||||
set_clause = ", ".join([f"{k}=%s" for k in updates.keys()])
|
||||
sql = f"UPDATE main_task SET {set_clause} WHERE task_id=%s"
|
||||
# 失败时计算下次重试时间(15分钟后)
|
||||
next_retry_time = datetime.now() + pd.Timedelta(minutes=15)
|
||||
|
||||
params = list(updates.values()) + [task_id]
|
||||
# 即使任务执行失败,也要确保状态更新
|
||||
try:
|
||||
self._update_task_status(task_id, {
|
||||
'last_run_status': 'failed',
|
||||
'is_running': 0,
|
||||
'next_run_time': next_retry_time
|
||||
})
|
||||
except Exception as update_err:
|
||||
task_log.error(f"任务失败后状态更新失败: {str(update_err)}", exc_info=True)
|
||||
|
||||
# 若已占用并发槽位,释放之
|
||||
try:
|
||||
self._running_semaphore.release()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def _run_task_async(self, task: Dict[str, Any]) -> None:
|
||||
"""在后台线程中执行任务主体,并在结束后更新状态"""
|
||||
task_id = task['task_id']
|
||||
task_name = task['task_name']
|
||||
task_log = log.bind(task_id=task_id, task_name=task_name)
|
||||
try:
|
||||
affected = self.db.execute_sql(sql, params=params)
|
||||
if affected != 1:
|
||||
log.warning(
|
||||
"Unexpected row count in update",
|
||||
task_id=task_id,
|
||||
expected=1,
|
||||
affected=affected
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Failed to update task status",
|
||||
task_id=task_id,
|
||||
exc_info=True
|
||||
# 如果 module_path 指向类,先实例化以触发初始化日志,然后执行 main
|
||||
self._execute_task_logic(task)
|
||||
|
||||
# 成功后计算下次运行时间
|
||||
next_run_time = self._calculate_next_run_time(
|
||||
cron_expr=task['cron_expression'],
|
||||
time_zone=task.get('time_zone', 'Asia/Shanghai')
|
||||
)
|
||||
raise
|
||||
|
||||
def _calculate_next_run(self, freq_type: str, freq_value: Optional[int] = None,
|
||||
retry: bool = False) -> datetime:
|
||||
self._update_task_status(task_id, {
|
||||
'last_run_status': 'success',
|
||||
'is_running': 0,
|
||||
'run_count': task['run_count'] + 1,
|
||||
'next_run_time': next_run_time
|
||||
})
|
||||
task_log.info(f"任务执行成功: {task_name}")
|
||||
except Exception:
|
||||
task_log.error("任务后台执行失败", exc_info=True)
|
||||
next_retry_time = datetime.now() + pd.Timedelta(minutes=15)
|
||||
try:
|
||||
self._update_task_status(task_id, {
|
||||
'last_run_status': 'failed',
|
||||
'is_running': 0,
|
||||
'next_run_time': next_retry_time
|
||||
})
|
||||
except Exception:
|
||||
task_log.error("任务失败后状态更新失败(后台)", exc_info=True)
|
||||
finally:
|
||||
# 释放并发槽位
|
||||
try:
|
||||
self._running_semaphore.release()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _execute_task_logic(self, task):
|
||||
"""
|
||||
计算下次执行时间(带重试逻辑)
|
||||
执行任务逻辑的核心方法
|
||||
支持类方法、静态方法和实例方法的调用
|
||||
"""
|
||||
base_time = datetime.now()
|
||||
module_path = task.get('module_path')
|
||||
if not module_path:
|
||||
raise ValueError("任务缺少 module_path 配置")
|
||||
|
||||
if retry:
|
||||
# 失败后15分钟重试
|
||||
log.debug("Calculating retry time")
|
||||
return base_time + timedelta(minutes=15)
|
||||
# 解析模块路径和类名
|
||||
try:
|
||||
path_parts = module_path.split('.')
|
||||
if len(path_parts) < 2:
|
||||
raise ValueError(f"无效的模块路径: {module_path}")
|
||||
|
||||
if freq_type == 'minute':
|
||||
delta = timedelta(minutes=freq_value or 1)
|
||||
elif freq_type == 'hourly':
|
||||
delta = timedelta(hours=freq_value or 1)
|
||||
elif freq_type == 'daily':
|
||||
delta = timedelta(days=freq_value or 1)
|
||||
elif freq_type == 'weekly':
|
||||
delta = timedelta(weeks=freq_value or 1)
|
||||
elif freq_type == 'monthly':
|
||||
# 处理月末日期特殊情况
|
||||
next_month = (base_time.replace(day=1) + timedelta(days=32)).replace(day=1)
|
||||
last_day = (next_month - timedelta(days=1)).day
|
||||
day = min(base_time.day, last_day)
|
||||
return base_time.replace(day=1, month=next_month.month, day=day)
|
||||
module_name = '.'.join(path_parts[:-1])
|
||||
class_name = path_parts[-1]
|
||||
method_name = 'main' # 默认方法名
|
||||
except Exception as e:
|
||||
raise ValueError(f"解析模块路径失败: {str(e)}")
|
||||
|
||||
# 动态导入模块
|
||||
try:
|
||||
import importlib
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
raise ImportError(f"无法导入模块 {module_name}: {str(e)}")
|
||||
|
||||
# 获取类和方法
|
||||
if not hasattr(module, class_name):
|
||||
raise AttributeError(f"模块 {module_name} 中未找到类 {class_name}")
|
||||
|
||||
cls = getattr(module, class_name)
|
||||
|
||||
# 检查是否存在指定方法
|
||||
if not hasattr(cls, method_name):
|
||||
raise AttributeError(f"类 {class_name} 中未找到方法 {method_name}")
|
||||
|
||||
method = getattr(cls, method_name)
|
||||
|
||||
# 根据方法类型决定如何调用
|
||||
import inspect
|
||||
callable_entry = None
|
||||
|
||||
# 判断是否为静态方法或类方法
|
||||
if isinstance(method, staticmethod):
|
||||
# 静态方法可以直接调用
|
||||
callable_entry = method
|
||||
elif isinstance(method, classmethod):
|
||||
# 类方法需要传入类作为第一个参数
|
||||
callable_entry = method
|
||||
else:
|
||||
raise ValueError(f"Unknown frequency type: {freq_type}")
|
||||
# 实例方法或普通函数
|
||||
try:
|
||||
# 尝试检查方法签名
|
||||
sig = inspect.signature(method)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
return base_time + delta
|
||||
# 如果第一个参数是self且没有默认值,则认为是实例方法
|
||||
if params and params[0].name == 'self' and params[0].default == inspect.Parameter.empty:
|
||||
# 创建实例并获取绑定方法
|
||||
instance = cls()
|
||||
callable_entry = getattr(instance, method_name)
|
||||
else:
|
||||
# 可能是普通函数或者是带有默认self参数的方法
|
||||
callable_entry = method
|
||||
except Exception:
|
||||
# 如果检查签名失败,默认尝试创建实例
|
||||
try:
|
||||
instance = cls()
|
||||
callable_entry = getattr(instance, method_name)
|
||||
except Exception:
|
||||
# 如果创建实例也失败,则直接调用方法(适用于不需要self的特殊情况)
|
||||
callable_entry = method
|
||||
|
||||
def add_task(self, task_name: str, module_path: str, frequency_type: str,
|
||||
frequency_value: Optional[int] = None) -> int:
|
||||
"""
|
||||
添加新任务到调度系统
|
||||
"""
|
||||
# 执行任务
|
||||
if not callable(callable_entry):
|
||||
raise TypeError(f"{module_path}.{method_name} 不是可调用对象")
|
||||
|
||||
try:
|
||||
# 执行任务逻辑
|
||||
callable_entry()
|
||||
except Exception as e:
|
||||
self.logger.error(f"任务逻辑执行失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def _calculate_next_run_time(self, cron_expr: str, time_zone: str = 'Asia/Shanghai') -> datetime:
    """Compute the next run time for a cron expression.

    The current time is taken in ``time_zone``, croniter yields the next
    matching time after it, and the result is returned with tzinfo removed
    (the original code notes this is done for database storage).

    Args:
        cron_expr: Cron expression to evaluate.
        time_zone: Zone name handed to ``pytz.timezone``.

    Returns:
        The next run time as a naive ``datetime``.

    Raises:
        ValueError: If the time zone cannot be resolved, or if the cron
            expression cannot be evaluated.
    """
    # Resolve the zone on its own so that a bad zone name is not
    # misreported as an invalid cron expression.
    try:
        tz = pytz.timezone(time_zone)
    except Exception as e:
        log.error(f"时区解析失败: {time_zone}, 错误: {str(e)}")
        raise ValueError(f"无效的时区: {time_zone}") from e

    try:
        now = datetime.now(tz)  # current time in the task's own zone
        cron = croniter.croniter(cron_expr, now)
        next_run = cron.get_next(datetime)
        return next_run.replace(tzinfo=None)  # drop tzinfo for storage
    except Exception as e:
        log.error(f"Cron表达式解析失败: {cron_expr}, 错误: {str(e)}")
        raise ValueError(f"无效的Cron表达式: {cron_expr}") from e
|
||||
|
||||
def _update_task_status(self, task_id: int, updates: Dict[str, Any]) -> None:
    """Write status fields of one task back to the ``main_task`` table.

    Only whitelisted column names are accepted, so the column list that is
    interpolated into the SQL text can never come from arbitrary input;
    values are always passed as bound parameters.

    Args:
        task_id: Primary key of the task row to update.
        updates: Mapping of column name to new value. Keys outside the
            whitelist are ignored.

    Raises:
        SQLAlchemyError: Re-raised after logging when the database layer
            fails.
        Exception: Any other error from ``self.db.execute_sql`` is logged
            and re-raised.
    """
    if not updates:
        log.warning(f"任务ID {task_id} 未提供任何更新字段")
        return

    # Whitelist of updatable columns (keeps interpolated names safe).
    valid_fields = {'is_running', 'last_run_time', 'last_run_status',
                    'run_count', 'next_run_time', 'updated_at'}
    filtered_updates = {k: v for k, v in updates.items() if k in valid_fields}

    if not filtered_updates:
        log.warning(f"任务ID {task_id} 没有有效的更新字段")
        return

    assignments = [f"{column}=%s" for column in filtered_updates]
    # Stamp updated_at automatically only when the caller did not supply
    # one; otherwise the column would be assigned twice in the same SET
    # clause and the caller's value silently overridden.
    if 'updated_at' not in filtered_updates:
        assignments.append("updated_at=NOW()")
    sql = f"UPDATE main_task SET {', '.join(assignments)} WHERE task_id=%s"
    params = [*filtered_updates.values(), task_id]

    try:
        # execute_sql is expected to return the affected-row count.
        affected_rows = self.db.execute_sql(sql, params=params)
        if affected_rows != 1:
            log.warning(
                "任务状态更新异常",
                task_id=task_id,
                预期影响行数=1,
                实际影响行数=affected_rows
            )
    except SQLAlchemyError:
        log.error(f"任务状态更新失败(数据库错误),task_id: {task_id}", exc_info=True)
        raise
    except Exception:
        log.error(f"任务状态更新失败,task_id: {task_id}", exc_info=True)
        raise
|
||||
|
||||
def add_task(self,
|
||||
task_name: str,
|
||||
task_type: str,
|
||||
module_path: str,
|
||||
cron_expression: str,
|
||||
time_zone: str = 'Asia/Shanghai') -> int:
|
||||
"""添加新的Cron任务"""
|
||||
if not cron_expression:
|
||||
raise ValueError("Cron表达式不能为空")
|
||||
|
||||
# 验证模块路径可解析(提前检查,避免添加无效任务)
|
||||
try:
|
||||
_ = self._resolve_callable(module_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模块路径不可用: {module_path},错误: {str(e)}")
|
||||
|
||||
# 计算首次运行时间
|
||||
first_run_time = self._calculate_next_run_time(cron_expression, time_zone)
|
||||
|
||||
# 插入数据库
|
||||
sql = """
|
||||
INSERT INTO main_task
|
||||
(task_name, module_path, frequency_type, frequency_value, next_run_time)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
"""
|
||||
|
||||
next_run = self._calculate_next_run(frequency_type, frequency_value)
|
||||
params = (task_name, module_path, frequency_type, frequency_value, next_run)
|
||||
INSERT INTO main_task
|
||||
(task_name, task_type, module_path, cron_expression, time_zone,
|
||||
next_run_time, is_active, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, 1, NOW(), NOW()) \
|
||||
"""
|
||||
params = (task_name, task_type, module_path, cron_expression, time_zone, first_run_time)
|
||||
|
||||
try:
|
||||
self.db.execute_sql(sql, params=params)
|
||||
task_id = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id").iloc[0]['id']
|
||||
# 获取插入的任务ID
|
||||
result_df = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id")
|
||||
if result_df.empty or 'id' not in result_df.columns:
|
||||
raise ValueError("无法获取新添加任务的ID")
|
||||
|
||||
task_id = result_df.iloc[0]['id']
|
||||
log.info(
|
||||
"New task added",
|
||||
"新任务添加成功",
|
||||
task_id=task_id,
|
||||
task_name=task_name,
|
||||
next_run=next_run
|
||||
cron表达式=cron_expression,
|
||||
首次运行时间=first_run_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
)
|
||||
return task_id
|
||||
except SQLAlchemyError as e:
|
||||
log.error(f"添加任务失败(数据库错误): {task_name}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(
|
||||
"Failed to add new task",
|
||||
task_name=task_name,
|
||||
exc_info=True
|
||||
)
|
||||
log.error(f"添加任务失败: {task_name}", exc_info=True)
|
||||
raise
|
||||
|
||||
def get_task_status(self, active_only: bool = True) -> pd.DataFrame:
    """Return the status columns of tasks in ``main_task`` as a DataFrame.

    Args:
        active_only: When true, restrict the result to rows with
            ``is_active = 1``; otherwise return every row.

    Returns:
        Whatever ``self.db.query_to_df`` yields for the query, ordered by
        ``next_run_time``.
    """
    # The WHERE fragment is one of two fixed literals, never caller text.
    where = "WHERE is_active = 1" if active_only else ""
    log.debug("Fetching task status", active_only=active_only)
    return self.db.query_to_df(
        f"""
        SELECT
        task_id, task_name, module_path,
        frequency_type, frequency_value,
        last_run_time, next_run_time,
        last_run_status, run_count,
        is_active, is_running
        FROM main_task
        {where}
        ORDER BY next_run_time
        """
    )
|
||||
def get_pending_tasks_count(self) -> int:
    """Count tasks that are due but not yet running.

    A task counts as pending when it is active, not running, and its
    ``next_run_time`` is at or before the current Asia/Shanghai wall-clock
    time. The original docstring notes this is used for graceful shutdown.

    Returns:
        Number of pending tasks as a built-in ``int``; ``0`` when the query
        returns no rows or fails (failures are logged, not raised, so the
        shutdown flow is not disturbed).
    """
    try:
        tz = pytz.timezone('Asia/Shanghai')
        now = datetime.now(tz).replace(tzinfo=None)
        sql = """
            SELECT COUNT(*) as cnt
            FROM main_task
            WHERE is_active = 1
              AND next_run_time <= %s
              AND is_running = 0
            """
        df = self.db.query_to_df(sql, params=(now,))
        # Coerce the DataFrame scalar so the declared ``int`` return type holds.
        return int(df['cnt'].iloc[0]) if not df.empty else 0
    except Exception as e:
        log.error(f"查询待执行任务数量失败: {str(e)}", exc_info=True)
        return 0  # best effort: never block shutdown on a failed count
|
||||
|
||||
def get_pending_tasks(self) -> List[Dict[str, Any]]:
    """Fetch every task that is due for execution.

    Selects active, non-running rows of ``main_task`` whose
    ``next_run_time`` is at or before the current Asia/Shanghai wall-clock
    time, ordered by ``next_run_time``.

    Returns:
        One dict per due task (``DataFrame.to_dict('records')``), or an
        empty list when nothing is due or the query fails. Failures are
        logged rather than raised.
    """
    query = """
        SELECT *
        FROM main_task
        WHERE is_active = 1
          AND next_run_time <= %s
          AND is_running = 0
        ORDER BY next_run_time
        """
    try:
        shanghai = pytz.timezone('Asia/Shanghai')
        current = datetime.now(shanghai).replace(tzinfo=None)
        pending_df = self.db.query_to_df(query, params=(current,))

        if not pending_df.empty:
            log.info(f"查询到{len(pending_df)}个待执行任务")
            return pending_df.to_dict('records')

        log.info("当前任务列表为空,等待新任务加入...")
        return []
    except Exception as e:
        log.error(f"查询待执行任务失败,将重试: {str(e)}", exc_info=True)
        return []
|
||||
|
||||
def get_and_reset_hourly_stats(self) -> Dict[str, int]:
    """Return a snapshot of the hourly counters and zero them.

    The copy and the reset both happen while ``self.hourly_stats_lock`` is
    held.

    Returns:
        A copy of ``self.hourly_stats`` as it was before the reset.
    """
    zeroed = {'成功': 0, '失败': 0, '总数': 0}
    with self.hourly_stats_lock:
        snapshot = self.hourly_stats.copy()
        self.hourly_stats = zeroed
    return snapshot
|
||||
@@ -0,0 +1 @@
|
||||
# Minimal script: writes a fixed greeting to standard output.
print("Hello, World!")
|
||||
@@ -0,0 +1,171 @@
|
||||
import unittest
|
||||
import os
|
||||
import tempfile
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from utils.minio_agent import MinIOAgent # 导入之前的MinIO操作类
|
||||
|
||||
|
||||
class TestMinIOAgent(unittest.TestCase):
    """Integration tests for ``MinIOAgent`` against a local MinIO server.

    The test methods are numbered because unittest runs them in name order
    and ``test_04_list_objects`` counts the object uploaded by the earlier
    tests.
    """

    # Test configuration: local MinIO community edition.
    MINIO_CONFIG = {
        'endpoint': '127.0.0.1:9005',
        'access_key': 'admin',  # default account
        'secret_key': 'abc88888888',  # default password
        'secure': False  # community edition does not enable SSL by default
    }

    @classmethod
    def setUpClass(cls):
        """Create the client and a uniquely named bucket for this run."""
        # Timestamped bucket name avoids clashes between runs.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        cls.test_bucket = f"test-bucket-{timestamp}"
        cls.test_object = "test-data/sample.txt"
        cls.test_content = b"this is MinIO test data: 1234567890"

        cls.minio_agent = MinIOAgent(cls.MINIO_CONFIG)

        # Make sure the test bucket exists.
        cls.minio_agent.create_bucket(cls.test_bucket)

    @classmethod
    def tearDownClass(cls):
        """Delete every object in the test bucket, then the bucket itself."""
        try:
            objects = cls.minio_agent.list_objects(cls.test_bucket)
            for obj in objects:
                cls.minio_agent.delete_object(cls.test_bucket, obj['object_name'])

            # MinIO only removes a bucket once it is empty.
            cls.minio_agent._client.remove_bucket(cls.test_bucket)
            print(f"\n测试清理完成,已删除桶: {cls.test_bucket}")
        except Exception as e:
            # Best effort: a failed cleanup must not mask test results.
            print(f"清理测试环境失败: {str(e)}")

    def _remove_bucket_if_present(self, bucket):
        """Remove ``bucket`` when it exists; used as a test cleanup."""
        client = self.minio_agent._client
        if client.bucket_exists(bucket):
            client.remove_bucket(bucket)

    @staticmethod
    def _remove_file_if_present(path):
        """Delete ``path`` when it exists; used as a test cleanup."""
        if os.path.exists(path):
            os.unlink(path)

    def test_01_create_bucket(self):
        """Creating a bucket reports success and the bucket then exists."""
        new_bucket = f"temp-bucket-{datetime.now().microsecond}"
        # Registered before the assertions so a failing assertion cannot
        # leave the temporary bucket behind on the server.
        self.addCleanup(self._remove_bucket_if_present, new_bucket)

        result = self.minio_agent.create_bucket(new_bucket)
        self.assertTrue(result, "存储桶创建失败")

        exists = self.minio_agent._client.bucket_exists(new_bucket)
        self.assertTrue(exists, "存储桶创建后未检测到存在")

    def test_02_upload_download(self):
        """Uploaded bytes come back unchanged and the metadata matches."""
        upload_meta = self.minio_agent.upload_bytes(
            bucket=self.test_bucket,
            object_name=self.test_object,
            data=self.test_content
        )

        self.assertEqual(upload_meta['size'], len(self.test_content), "上传数据大小不匹配")
        self.assertEqual(upload_meta['local_hash'], hashlib.md5(self.test_content).hexdigest(), "本地哈希校验失败")

        # Download into a temporary file.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_path = temp_file.name
        # Cleanup is registered immediately so the file is removed even
        # when the download or an assertion below fails.
        self.addCleanup(self._remove_file_if_present, temp_path)

        download_meta = self.minio_agent.download_file(
            bucket=self.test_bucket,
            object_name=self.test_object,
            local_path=temp_path
        )

        with open(temp_path, 'rb') as f:
            downloaded_content = f.read()

        self.assertEqual(downloaded_content, self.test_content, "下载数据与原始数据不匹配")
        self.assertEqual(download_meta['size'], len(self.test_content), "下载文件大小不匹配")

    def test_03_presigned_url(self):
        """A presigned URL targets the configured endpoint and expiry."""
        # Upload the test object first.
        self.minio_agent.upload_bytes(
            self.test_bucket,
            self.test_object,
            self.test_content
        )

        # Generate a URL valid for 30 seconds.
        url_info = self.minio_agent.get_presigned_url(
            bucket=self.test_bucket,
            object_name=self.test_object,
            expires=30
        )

        # Derive the expected prefix from the config so the check follows
        # any change of endpoint or SSL setting.
        scheme = "https" if self.MINIO_CONFIG['secure'] else "http"
        expected_prefix = f"{scheme}://{self.MINIO_CONFIG['endpoint']}"
        self.assertIn(expected_prefix, url_info['presigned_url'], "预签名URL格式不正确")
        self.assertEqual(url_info['expires_in'], 30, "过期时间设置不正确")

    def test_04_list_objects(self):
        """Listing returns every object and honours the prefix filter."""
        test_objects = [
            "test-folder/file1.txt",
            "test-folder/file2.csv",
            "another-folder/image.jpg"
        ]

        for obj in test_objects:
            self.minio_agent.upload_bytes(
                self.test_bucket,
                obj,
                b"tese_list_obj"
            )

        all_objects = self.minio_agent.list_objects(self.test_bucket)
        # +1 accounts for test_object uploaded by the earlier tests.
        self.assertEqual(len(all_objects), len(test_objects) + 1, "列出对象数量不匹配")

        filtered_objects = self.minio_agent.list_objects(
            self.test_bucket,
            prefix="test-folder/"
        )
        self.assertEqual(len(filtered_objects), 2, "按前缀筛选结果不正确")

    def test_05_delete_object(self):
        """A deleted object no longer appears in listings."""
        delete_obj = "to-delete/temp.txt"
        self.minio_agent.upload_bytes(
            self.test_bucket,
            delete_obj,
            b"will be delete"
        )

        result = self.minio_agent.delete_object(self.test_bucket, delete_obj)
        self.assertTrue(result, "删除对象失败")

        objects = self.minio_agent.list_objects(self.test_bucket, prefix="to-delete/")
        self.assertEqual(len(objects), 0, "对象删除后仍存在")

    def test_06_upload_empty_data(self):
        """Uploading empty data is rejected with ``ValueError``."""
        with self.assertRaises(ValueError, msg="未捕获空数据上传异常"):
            self.minio_agent.upload_bytes(
                self.test_bucket,
                "empty.txt",
                b""
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 执行测试并显示详细结果
|
||||
unittest.main(verbosity=2)
|
||||
+104
-115
@@ -1,21 +1,22 @@
|
||||
import unittest
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
import tempfile
|
||||
import time
|
||||
import pymysql
|
||||
from storage.mysql_agent import MySQLAgent
|
||||
import platform
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from utils.mysql_agent import MySQLAgent
|
||||
|
||||
|
||||
class TestMySQLAgent(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""初始化测试环境和测试表"""
|
||||
# 创建唯一的测试数据库名
|
||||
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
# 创建唯一的测试数据库和表名(避免冲突)
|
||||
cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
|
||||
# 基础配置
|
||||
# 基础配置(根据实际环境修改)
|
||||
cls.base_config = {
|
||||
'host': 'localhost',
|
||||
'port': 3306,
|
||||
@@ -33,21 +34,19 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
'database': cls.test_db_name
|
||||
})
|
||||
|
||||
# 创建测试表
|
||||
# 创建测试表并插入初始数据
|
||||
test_data = pd.DataFrame({
|
||||
'id': [1, 2, 3],
|
||||
'name': ['Test1', 'Test2', 'Test3'],
|
||||
'value': [10.5, 20.3, 30.8],
|
||||
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
|
||||
})
|
||||
|
||||
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
|
||||
cls.db.insert_from_df(cls.test_table, test_data)
|
||||
|
||||
@classmethod
|
||||
def _create_test_database(cls):
|
||||
"""创建测试数据库"""
|
||||
# 使用临时连接创建数据库
|
||||
temp_conn = pymysql.connect(
|
||||
host=cls.base_config['host'],
|
||||
port=cls.base_config['port'],
|
||||
@@ -55,7 +54,6 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
password=cls.base_config['password'],
|
||||
charset='utf8mb4'
|
||||
)
|
||||
|
||||
try:
|
||||
with temp_conn.cursor() as cursor:
|
||||
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||
@@ -67,21 +65,14 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""清理测试数据库"""
|
||||
"""清理测试环境"""
|
||||
if hasattr(cls, 'db') and cls.db:
|
||||
# 删除测试表
|
||||
if cls.db.table_exists(cls.test_table):
|
||||
cls.db.drop_table(cls.test_table)
|
||||
|
||||
# 删除测试数据库
|
||||
temp_conn = pymysql.connect(
|
||||
host=cls.base_config['host'],
|
||||
port=cls.base_config['port'],
|
||||
user=cls.base_config['user'],
|
||||
password=cls.base_config['password'],
|
||||
charset='utf8mb4'
|
||||
)
|
||||
|
||||
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||
try:
|
||||
with temp_conn.cursor() as cursor:
|
||||
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||
@@ -89,22 +80,24 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
finally:
|
||||
temp_conn.close()
|
||||
|
||||
def test_01_connection(self):
|
||||
def test_connection(self):
|
||||
"""测试数据库连接"""
|
||||
version = self.db.query_to_df("SELECT VERSION() as version")
|
||||
self.assertIsNotNone(version)
|
||||
print(f"\nDatabase version: {version['version'].iloc[0]}")
|
||||
print(f"Running on: {platform.system()} {platform.release()}")
|
||||
version_df = self.db.query_to_df("SELECT VERSION() as version")
|
||||
self.assertIsNotNone(version_df)
|
||||
self.assertEqual(len(version_df), 1)
|
||||
print(f"数据库版本: {version_df['version'].iloc[0]}")
|
||||
|
||||
def test_02_query_to_df(self):
|
||||
def test_query_to_df(self):
|
||||
"""测试查询返回DataFrame"""
|
||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
|
||||
self.assertEqual(len(df), 2)
|
||||
df = self.db.query_to_df(
|
||||
f"SELECT * FROM {self.test_table} WHERE id > %s",
|
||||
params=(1,)
|
||||
)
|
||||
self.assertIsInstance(df, pd.DataFrame)
|
||||
print("\nQuery result sample:")
|
||||
print(df.head())
|
||||
self.assertEqual(len(df), 2) # id>1 的数据有2条
|
||||
self.assertIn('name', df.columns)
|
||||
|
||||
def test_03_insert_from_df(self):
|
||||
def test_insert_from_df(self):
|
||||
"""测试DataFrame插入"""
|
||||
new_data = pd.DataFrame({
|
||||
'id': [4, 5],
|
||||
@@ -113,55 +106,55 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
|
||||
})
|
||||
|
||||
rows = self.db.insert_from_df(self.test_table, new_data)
|
||||
self.assertEqual(rows, 2)
|
||||
inserted_rows = self.db.insert_from_df(self.test_table, new_data)
|
||||
self.assertEqual(inserted_rows, 2)
|
||||
|
||||
# 验证数据
|
||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
|
||||
self.assertEqual(len(df), 2)
|
||||
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
|
||||
# 验证插入结果
|
||||
result_df = self.db.query_to_df(
|
||||
f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
|
||||
)
|
||||
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
|
||||
|
||||
def test_04_update_from_df(self):
|
||||
def test_update_from_df(self):
|
||||
"""测试DataFrame更新"""
|
||||
update_data = pd.DataFrame({
|
||||
'id': [1, 2],
|
||||
'name': ['Updated1', 'Updated2']
|
||||
})
|
||||
|
||||
rows = self.db.update_from_df(self.test_table, update_data, 'id')
|
||||
self.assertGreaterEqual(rows, 2)
|
||||
updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
|
||||
self.assertGreaterEqual(updated_rows, 2)
|
||||
|
||||
# 验证更新
|
||||
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
|
||||
self.assertIn('Updated1', df['name'].values)
|
||||
self.assertIn('Updated2', df['name'].values)
|
||||
# 验证更新结果
|
||||
result_df = self.db.query_to_df(
|
||||
f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
|
||||
)
|
||||
self.assertIn('Updated1', result_df['name'].values)
|
||||
self.assertIn('Updated2', result_df['name'].values)
|
||||
|
||||
def test_05_transaction(self):
|
||||
def test_transaction(self):
|
||||
"""测试事务处理"""
|
||||
conn = self.db.begin_transaction()
|
||||
try:
|
||||
# 执行多个操作
|
||||
# 执行事务内操作
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
|
||||
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
|
||||
|
||||
# 验证事务内修改
|
||||
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
|
||||
self.assertEqual(cursor.fetchone()['value'], 99.9)
|
||||
|
||||
self.db.commit_transaction(conn)
|
||||
except Exception:
|
||||
self.db.rollback_transaction(conn)
|
||||
raise
|
||||
|
||||
# 验证提交后的修改
|
||||
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
|
||||
self.assertIn(99.9, df['value'].values)
|
||||
self.assertIn(88.8, df['value'].values)
|
||||
# 验证事务提交结果
|
||||
result_df = self.db.query_to_df(
|
||||
f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
|
||||
)
|
||||
self.assertIn(99.9, result_df['value'].values)
|
||||
self.assertIn(88.8, result_df['value'].values)
|
||||
|
||||
def test_06_large_data(self):
|
||||
"""测试大数据量操作"""
|
||||
# 生成测试数据
|
||||
def test_large_data_insert(self):
|
||||
"""测试大数据量插入"""
|
||||
# 生成1000行测试数据
|
||||
large_data = pd.DataFrame({
|
||||
'id': range(1000, 2000),
|
||||
'name': [f"Item_{i}" for i in range(1000, 2000)],
|
||||
@@ -169,59 +162,55 @@ class TestMySQLAgent(unittest.TestCase):
|
||||
'created_at': pd.date_range('2023-01-01', periods=1000)
|
||||
})
|
||||
|
||||
# Windows平台使用更小的批次
|
||||
# 根据平台自动调整批次大小
|
||||
chunk_size = 100 if platform.system() == 'Windows' else 500
|
||||
|
||||
start_time = time.time()
|
||||
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
|
||||
inserted_rows = self.db.insert_from_df(
|
||||
self.test_table,
|
||||
large_data,
|
||||
chunk_size=chunk_size
|
||||
)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
self.assertEqual(rows, 1000)
|
||||
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
|
||||
self.assertEqual(inserted_rows, 1000)
|
||||
print(f"插入1000行数据耗时: {elapsed:.2f}秒 (批次大小: {chunk_size})")
|
||||
|
||||
# 验证数据
|
||||
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
|
||||
self.assertEqual(df['cnt'].iloc[0], 1000)
|
||||
|
||||
def test_07_concurrent_access(self):
|
||||
def test_concurrent_access(self):
|
||||
"""测试并发访问"""
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
def worker(i):
|
||||
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
|
||||
def query_worker(i):
|
||||
"""并发查询工作函数"""
|
||||
df = self.db.query_to_df(
|
||||
f"SELECT * FROM {self.test_table} WHERE id = %s",
|
||||
params=(i % 3 + 1,) # 查询id=1,2,3循环
|
||||
)
|
||||
return len(df)
|
||||
|
||||
# 20个线程执行100次查询
|
||||
start_time = time.time()
|
||||
with ThreadPoolExecutor(max_workers=20) as executor:
|
||||
results = list(executor.map(worker, range(100)))
|
||||
|
||||
results = list(executor.map(query_worker, range(100)))
|
||||
elapsed = time.time() - start_time
|
||||
self.assertEqual(sum(results), 100)
|
||||
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
|
||||
|
||||
self.assertEqual(sum(results), 100) # 每次查询应返回1行
|
||||
print(f"100次并发查询耗时: {elapsed:.2f}秒")
|
||||
|
||||
|
||||
class TestPlatformSpecific(unittest.TestCase):
|
||||
"""平台特定功能测试"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""创建临时测试数据库"""
|
||||
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
cls.base_config = {
|
||||
'host': 'localhost',
|
||||
'port': 3306,
|
||||
'user': 'root',
|
||||
'password': '123123',
|
||||
'max_connections': 10
|
||||
'password': '123123'
|
||||
}
|
||||
|
||||
# 创建数据库
|
||||
temp_conn = pymysql.connect(
|
||||
host=cls.base_config['host'],
|
||||
port=cls.base_config['port'],
|
||||
user=cls.base_config['user'],
|
||||
password=cls.base_config['password'],
|
||||
charset='utf8mb4'
|
||||
)
|
||||
|
||||
# 创建测试数据库
|
||||
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||
try:
|
||||
with temp_conn.cursor() as cursor:
|
||||
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
|
||||
@@ -231,15 +220,8 @@ class TestPlatformSpecific(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""删除临时测试数据库"""
|
||||
temp_conn = pymysql.connect(
|
||||
host=cls.base_config['host'],
|
||||
port=cls.base_config['port'],
|
||||
user=cls.base_config['user'],
|
||||
password=cls.base_config['password'],
|
||||
charset='utf8mb4'
|
||||
)
|
||||
|
||||
"""清理测试数据库"""
|
||||
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
|
||||
try:
|
||||
with temp_conn.cursor() as cursor:
|
||||
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
|
||||
@@ -250,42 +232,49 @@ class TestPlatformSpecific(unittest.TestCase):
|
||||
def test_windows_timeout(self):
|
||||
"""测试Windows平台超时处理"""
|
||||
if platform.system() != 'Windows':
|
||||
self.skipTest("Only runs on Windows")
|
||||
self.skipTest("仅在Windows平台运行")
|
||||
|
||||
config = {
|
||||
**self.base_config,
|
||||
'database': self.test_db_name,
|
||||
'connect_timeout': 1,
|
||||
'read_timeout': 1
|
||||
'read_timeout': 1,
|
||||
'write_timeout': 1
|
||||
}
|
||||
|
||||
db = MySQLAgent(config)
|
||||
|
||||
# 测试短超时查询
|
||||
start_time = time.time()
|
||||
try:
|
||||
db.query_to_df("SELECT SLEEP(2)")
|
||||
self.fail("Should have timed out")
|
||||
except Exception as e:
|
||||
self.assertIn("timed out", str(e))
|
||||
print(f"\nWindows timeout test: {str(e)}")
|
||||
# 执行会超时的查询(SLEEP(2)超过1秒超时设置)
|
||||
with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
|
||||
try:
|
||||
db.query_to_df("SELECT SLEEP(2)")
|
||||
except Exception as e:
|
||||
# 提取底层异常信息(可能被包装)
|
||||
while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
|
||||
e = e.args[0]
|
||||
raise e
|
||||
|
||||
def test_macos_ssl(self):
|
||||
"""测试macOS SSL连接"""
|
||||
error_msg = str(ctx.exception)
|
||||
self.assertTrue(
|
||||
"timed out" in error_msg or
|
||||
"timeout" in error_msg or
|
||||
"HY000" in error_msg, # MySQL超时错误码
|
||||
f"未检测到超时异常,实际异常: {error_msg}"
|
||||
)
|
||||
|
||||
def test_macos_ssl_connection(self):
|
||||
"""测试macOS平台SSL连接"""
|
||||
if platform.system() != 'Darwin':
|
||||
self.skipTest("Only runs on macOS")
|
||||
self.skipTest("仅在macOS平台运行")
|
||||
|
||||
config = {
|
||||
**self.base_config,
|
||||
'database': self.test_db_name,
|
||||
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
|
||||
}
|
||||
|
||||
db = MySQLAgent(config)
|
||||
version = db.query_to_df("SELECT VERSION() as version")
|
||||
self.assertIsNotNone(version)
|
||||
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
|
||||
version_df = db.query_to_df("SELECT VERSION() as version")
|
||||
self.assertIsNotNone(version_df)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main(verbosity=2)
|
||||
@@ -0,0 +1,18 @@
|
||||
-- Main table of the task scheduler: one row per scheduled task.
CREATE TABLE IF NOT EXISTS main_task (
    task_id INT AUTO_INCREMENT PRIMARY KEY COMMENT '任务唯一ID',
    task_name VARCHAR(100) NOT NULL COMMENT '任务名称',
    task_type VARCHAR(50) NOT NULL COMMENT '任务类型(如processor、collector等)',
    module_path VARCHAR(255) NOT NULL COMMENT '任务模块路径(如processors.data_checker)',
    cron_expression VARCHAR(100) NOT NULL COMMENT 'Cron表达式(调度频率)',
    time_zone VARCHAR(50) DEFAULT 'Asia/Shanghai' COMMENT '时区', -- field added later
    next_run_time DATETIME NOT NULL COMMENT '下次运行时间',
    last_run_time DATETIME NULL COMMENT '上次运行时间',
    last_run_status ENUM('success', 'failed', 'pending') DEFAULT 'pending' COMMENT '上次运行状态',
    run_count INT DEFAULT 0 COMMENT '运行次数统计',
    is_active TINYINT(1) DEFAULT 1 COMMENT '是否活跃(1=启用,0=禁用)',
    is_running TINYINT(1) DEFAULT 0 COMMENT '是否正在运行(1=运行中,0=未运行)',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    INDEX idx_next_run (next_run_time) COMMENT '优化下次运行时间查询', -- keeping this index is recommended for performance
    INDEX idx_active (is_active) COMMENT '优化活跃任务查询' -- keeping this index is recommended for performance
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='任务调度主表';
|
||||
Binary file not shown.
@@ -0,0 +1,957 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "197b1b81f5528a50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. 初始化(所有操作前必须运行)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "initial_id",
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-29T02:25:08.582541Z",
|
||||
"start_time": "2025-10-29T02:25:08.473381Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 使 Notebook 可从项目根导入\n",
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"def add_project_root(marker_dirs=(\"utils\", \"system_management\")):\n",
|
||||
" p = Path.cwd()\n",
|
||||
" for _ in range(6):\n",
|
||||
" if all((p / d).exists() for d in marker_dirs):\n",
|
||||
" if str(p) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(p))\n",
|
||||
" return p\n",
|
||||
" p = p.parent\n",
|
||||
" raise RuntimeError(\"未找到项目根目录,请从项目根启动 Notebook 或手动设置 sys.path\")\n",
|
||||
"\n",
|
||||
"PROJECT_ROOT = add_project_root()\n",
|
||||
"print(f\"PROJECT_ROOT = {PROJECT_ROOT}\")\n",
|
||||
"\n",
|
||||
"# 依赖与日志\n",
|
||||
"import pandas as pd\n",
|
||||
"from IPython.display import display, HTML, Markdown\n",
|
||||
"from utils.logger import CrossPlatformLog\n",
|
||||
"log = CrossPlatformLog.get_logger(\"TaskNotebook\")\n",
|
||||
"\n",
|
||||
"# 配置与调度器\n",
|
||||
"from config import Config # 若你使用 ConfigManager,请改为: from config.config import ConfigManager\n",
|
||||
"from system_management.scheduler.task_scheduler import TaskScheduler\n",
|
||||
"\n",
|
||||
"# 初始化调度器(根据你的项目配置选一段)\n",
|
||||
"scheduler = TaskScheduler(Config.MYSQL_CONFIG)\n",
|
||||
"# 或使用 ConfigManager:\n",
|
||||
"# config = ConfigManager()\n",
|
||||
"# scheduler = TaskScheduler(config.get(\"database\"))\n",
|
||||
"\n",
|
||||
"# 在 Notebook 中实现一个最小可用的 TaskManager\n",
|
||||
"class TaskManager:\n",
|
||||
" def __init__(self, scheduler: TaskScheduler):\n",
|
||||
" self.scheduler = scheduler\n",
|
||||
" self.db = scheduler.db # 复用调度器里的 MySQLAgent\n",
|
||||
"\n",
|
||||
" def get_all_tasks(self, active_only: bool = False):\n",
|
||||
" sql = \"\"\"\n",
|
||||
" SELECT *\n",
|
||||
" FROM main_task\n",
|
||||
" {where}\n",
|
||||
" ORDER BY created_at DESC, task_id DESC\n",
|
||||
" \"\"\"\n",
|
||||
" where = \"WHERE is_active = 1\" if active_only else \"\"\n",
|
||||
" df = self.db.query_to_df(sql.format(where=where))\n",
|
||||
" return [] if df is None or df.empty else df.to_dict(\"records\")\n",
|
||||
"\n",
|
||||
" def get_task_by_id(self, task_id: int):\n",
|
||||
" df = self.db.query_to_df(\n",
|
||||
" \"SELECT * FROM main_task WHERE task_id = %s\",\n",
|
||||
" params=(task_id,)\n",
|
||||
" )\n",
|
||||
" return None if df is None or df.empty else df.iloc[0].to_dict()\n",
|
||||
"\n",
|
||||
" def update_task(self, task_id: int, updates: dict) -> bool:\n",
|
||||
" if not updates:\n",
|
||||
" return False\n",
|
||||
" # 允许更新的字段(与调度器一致)\n",
|
||||
" allowed = {\n",
|
||||
" \"task_name\", \"task_type\", \"module_path\",\n",
|
||||
" \"cron_expression\", \"time_zone\"\n",
|
||||
" }\n",
|
||||
" filtered = {k: v for k, v in updates.items() if k in allowed}\n",
|
||||
" if not filtered:\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" set_clause = \", \".join([f\"{k}=%s\" for k in filtered.keys()])\n",
|
||||
" params = list(filtered.values()) + [task_id]\n",
|
||||
" sql = f\"UPDATE main_task SET {set_clause}, updated_at=NOW() WHERE task_id=%s\"\n",
|
||||
" affected = self.db.execute_sql(sql, params=params)\n",
|
||||
" return affected == 1\n",
|
||||
"\n",
|
||||
" def toggle_task_status(self, task_id: int, activate: bool) -> bool:\n",
|
||||
" sql = \"UPDATE main_task SET is_active=%s, updated_at=NOW() WHERE task_id=%s\"\n",
|
||||
" affected = self.db.execute_sql(sql, params=(1 if activate else 0, task_id))\n",
|
||||
" return affected == 1\n",
|
||||
"\n",
|
||||
" def delete_task(self, task_id: int) -> bool:\n",
|
||||
" # 如果你更偏好软删除,可以改为: UPDATE main_task SET is_active=0, updated_at=NOW() WHERE task_id=%s\n",
|
||||
" affected = self.db.execute_sql(\"DELETE FROM main_task WHERE task_id=%s\", params=(task_id,))\n",
|
||||
" return affected == 1\n",
|
||||
"\n",
|
||||
" def run_task_manually(self, task_id: int) -> bool:\n",
|
||||
" # 读取任务,直接复用调度器的单任务执行逻辑\n",
|
||||
" task = self.get_task_by_id(task_id)\n",
|
||||
" if not task:\n",
|
||||
" return False\n",
|
||||
" # _process_single_task 期望 dict\n",
|
||||
" try:\n",
|
||||
" return bool(self.scheduler._process_single_task(task)) # 注意:使用了调度器的内部方法\n",
|
||||
" except Exception:\n",
|
||||
" log.exception(\"手动执行任务失败\")\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" def run_task_synchronously(self, task_id: int) -> dict:\n",
|
||||
" \"\"\"同步执行任务并返回详细结果(用于Notebook中查看执行过程)\"\"\"\n",
|
||||
" import time\n",
|
||||
" import sys\n",
|
||||
" from io import StringIO\n",
|
||||
" \n",
|
||||
" task = self.get_task_by_id(task_id)\n",
|
||||
" if not task:\n",
|
||||
" return {\n",
|
||||
" 'success': False,\n",
|
||||
" 'error': f'未找到任务ID: {task_id}',\n",
|
||||
" 'output': ''\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" # 捕获标准输出\n",
|
||||
" old_stdout = sys.stdout\n",
|
||||
" sys.stdout = output_buffer = StringIO()\n",
|
||||
" \n",
|
||||
" start_time = time.time()\n",
|
||||
" success = False\n",
|
||||
" error_msg = None\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" # 直接同步执行任务逻辑\n",
|
||||
" self.scheduler._execute_task_logic(task)\n",
|
||||
" success = True\n",
|
||||
" \n",
|
||||
" # 更新任务状态\n",
|
||||
" next_run_time = self.scheduler._calculate_next_run_time(\n",
|
||||
" cron_expr=task['cron_expression'],\n",
|
||||
" time_zone=task.get('time_zone', 'Asia/Shanghai')\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" self.scheduler._update_task_status(task['task_id'], {\n",
|
||||
" 'last_run_status': 'success',\n",
|
||||
" 'is_running': 0,\n",
|
||||
" 'run_count': task['run_count'] + 1,\n",
|
||||
" 'next_run_time': next_run_time\n",
|
||||
" })\n",
|
||||
" \n",
|
||||
" except Exception as e:\n",
|
||||
" success = False\n",
|
||||
" error_msg = str(e)\n",
|
||||
" log.exception(f\"任务执行失败: {task['task_name']}\")\n",
|
||||
" \n",
|
||||
" # 更新失败状态\n",
|
||||
" try:\n",
|
||||
" next_retry_time = datetime.now() + pd.Timedelta(minutes=15)\n",
|
||||
" self.scheduler._update_task_status(task['task_id'], {\n",
|
||||
" 'last_run_status': 'failed',\n",
|
||||
" 'is_running': 0,\n",
|
||||
" 'next_run_time': next_retry_time\n",
|
||||
" })\n",
|
||||
" except Exception:\n",
|
||||
" pass\n",
|
||||
" \n",
|
||||
" finally:\n",
|
||||
" # 恢复标准输出\n",
|
||||
" sys.stdout = old_stdout\n",
|
||||
" output_text = output_buffer.getvalue()\n",
|
||||
" \n",
|
||||
" execution_time = time.time() - start_time\n",
|
||||
" \n",
|
||||
" return {\n",
|
||||
" 'success': success,\n",
|
||||
" 'task_name': task['task_name'],\n",
|
||||
" 'task_id': task['task_id'],\n",
|
||||
" 'execution_time': execution_time,\n",
|
||||
" 'output': output_text,\n",
|
||||
" 'error': error_msg\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# 在这里创建 manager(供后续单元使用)\n",
|
||||
"manager = TaskManager(scheduler)\n",
|
||||
"\n",
|
||||
"# 常用辅助函数\n",
|
||||
"def format_datetime(dt):\n",
|
||||
" if dt is None:\n",
|
||||
" return \"-\"\n",
|
||||
" try:\n",
|
||||
" if isinstance(dt, pd.Timestamp) and pd.isna(dt):\n",
|
||||
" return \"-\"\n",
|
||||
" return dt.strftime(\"%Y-%m-%d %H:%M:%S\")\n",
|
||||
" except Exception:\n",
|
||||
" try:\n",
|
||||
" if pd.isna(dt):\n",
|
||||
" return \"-\"\n",
|
||||
" except Exception:\n",
|
||||
" pass\n",
|
||||
" return str(dt)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"PROJECT_ROOT = D:\\Idea Project\\intelligence_system\n",
|
||||
"\u001B[32m2025-10-29 10:25:08\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mtask_scheduler\u001B[0m - \u001B[1m任务调度器已初始化,最大工作线程数: 5\u001B[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": 8
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8271189cef3b5f17",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. 列出任务(对应命令行 list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "7b020af55972643",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-17T05:43:18.499582Z",
|
||||
"start_time": "2025-10-17T05:43:18.394863Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[32m2025-10-29 09:54:09\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/markdown": [
|
||||
"### 任务列表"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th>任务ID</th>\n",
|
||||
" <th>任务名称</th>\n",
|
||||
" <th>任务类型</th>\n",
|
||||
" <th>模块路径</th>\n",
|
||||
" <th>Cron表达式</th>\n",
|
||||
" <th>时区</th>\n",
|
||||
" <th>下次运行时间</th>\n",
|
||||
" <th>最后运行时间</th>\n",
|
||||
" <th>运行状态</th>\n",
|
||||
" <th>运行次数</th>\n",
|
||||
" <th>是否活跃</th>\n",
|
||||
" <th>is_running</th>\n",
|
||||
" <th>created_at</th>\n",
|
||||
" <th>updated_at</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>RSS基于规则数据处理</td>\n",
|
||||
" <td>processor</td>\n",
|
||||
" <td>processors.processor_rss_data</td>\n",
|
||||
" <td>0 8,20 * * *</td>\n",
|
||||
" <td>Asia/Shanghai</td>\n",
|
||||
" <td>2025-10-28 20:00:00</td>\n",
|
||||
" <td>2025-10-28 13:34:49</td>\n",
|
||||
" <td>success</td>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2025-10-22 16:06:42</td>\n",
|
||||
" <td>2025-10-28 13:34:50</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>RSS新闻订阅</td>\n",
|
||||
" <td>collector</td>\n",
|
||||
" <td>collectors.rss_subscriptions.NewsAPIClient</td>\n",
|
||||
" <td>*/5 * * * *</td>\n",
|
||||
" <td>Asia/Shanghai</td>\n",
|
||||
" <td>2025-10-28 13:40:00</td>\n",
|
||||
" <td>2025-10-28 13:35:09</td>\n",
|
||||
" <td>success</td>\n",
|
||||
" <td>495</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2025-10-16 15:47:34</td>\n",
|
||||
" <td>2025-10-28 13:35:09</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.HTML object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>任务ID</th>\n",
|
||||
" <th>任务名称</th>\n",
|
||||
" <th>任务类型</th>\n",
|
||||
" <th>模块路径</th>\n",
|
||||
" <th>Cron表达式</th>\n",
|
||||
" <th>时区</th>\n",
|
||||
" <th>下次运行时间</th>\n",
|
||||
" <th>最后运行时间</th>\n",
|
||||
" <th>运行状态</th>\n",
|
||||
" <th>运行次数</th>\n",
|
||||
" <th>是否活跃</th>\n",
|
||||
" <th>is_running</th>\n",
|
||||
" <th>created_at</th>\n",
|
||||
" <th>updated_at</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>2</td>\n",
|
||||
" <td>RSS基于规则数据处理</td>\n",
|
||||
" <td>processor</td>\n",
|
||||
" <td>processors.processor_rss_data</td>\n",
|
||||
" <td>0 8,20 * * *</td>\n",
|
||||
" <td>Asia/Shanghai</td>\n",
|
||||
" <td>2025-10-28 20:00:00</td>\n",
|
||||
" <td>2025-10-28 13:34:49</td>\n",
|
||||
" <td>success</td>\n",
|
||||
" <td>10</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2025-10-22 16:06:42</td>\n",
|
||||
" <td>2025-10-28 13:34:50</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>RSS新闻订阅</td>\n",
|
||||
" <td>collector</td>\n",
|
||||
" <td>collectors.rss_subscriptions.NewsAPIClient</td>\n",
|
||||
" <td>*/5 * * * *</td>\n",
|
||||
" <td>Asia/Shanghai</td>\n",
|
||||
" <td>2025-10-28 13:40:00</td>\n",
|
||||
" <td>2025-10-28 13:35:09</td>\n",
|
||||
" <td>success</td>\n",
|
||||
" <td>495</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" <td>2025-10-16 15:47:34</td>\n",
|
||||
" <td>2025-10-28 13:35:09</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" 任务ID 任务名称 任务类型 模块路径 \\\n",
|
||||
"0 2 RSS基于规则数据处理 processor processors.processor_rss_data \n",
|
||||
"1 1 RSS新闻订阅 collector collectors.rss_subscriptions.NewsAPIClient \n",
|
||||
"\n",
|
||||
" Cron表达式 时区 下次运行时间 最后运行时间 \\\n",
|
||||
"0 0 8,20 * * * Asia/Shanghai 2025-10-28 20:00:00 2025-10-28 13:34:49 \n",
|
||||
"1 */5 * * * * Asia/Shanghai 2025-10-28 13:40:00 2025-10-28 13:35:09 \n",
|
||||
"\n",
|
||||
" 运行状态 运行次数 是否活跃 is_running created_at updated_at \n",
|
||||
"0 success 10 1 0 2025-10-22 16:06:42 2025-10-28 13:34:50 \n",
|
||||
"1 success 495 1 0 2025-10-16 15:47:34 2025-10-28 13:35:09 "
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# List tasks (active only by default; pass active_only=False to include disabled ones)\n",
|
||||
"def list_tasks(active_only=True):\n",
|
||||
" tasks = manager.get_all_tasks(active_only)\n",
|
||||
" if not tasks:\n",
|
||||
" display(Markdown(\"### 没有找到任务\"))\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
" df = pd.DataFrame(tasks)\n",
|
||||
"\n",
|
||||
" # 格式化日期列\n",
|
||||
" if 'last_run_time' in df.columns:\n",
|
||||
" df['last_run_time'] = df['last_run_time'].apply(format_datetime)\n",
|
||||
" if 'next_run_time' in df.columns:\n",
|
||||
" df['next_run_time'] = df['next_run_time'].apply(format_datetime)\n",
|
||||
"\n",
|
||||
" # 重命名列名\n",
|
||||
" df = df.rename(columns={\n",
|
||||
" 'task_id': '任务ID',\n",
|
||||
" 'task_name': '任务名称',\n",
|
||||
" 'task_type': '任务类型',\n",
|
||||
" 'module_path': '模块路径',\n",
|
||||
" 'cron_expression': 'Cron表达式',\n",
|
||||
" 'time_zone': '时区',\n",
|
||||
" 'last_run_time': '最后运行时间',\n",
|
||||
" 'next_run_time': '下次运行时间',\n",
|
||||
" 'last_run_status': '运行状态',\n",
|
||||
" 'is_active': '是否活跃',\n",
|
||||
" 'run_count': '运行次数'\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
" display(Markdown(\"### 任务列表\"))\n",
|
||||
" display(HTML(df.to_html(index=False)))\n",
|
||||
" return df\n",
|
||||
"\n",
|
||||
"# 执行:列出所有任务(包括已禁用)\n",
|
||||
"list_tasks(active_only=False)\n",
|
||||
"\n",
|
||||
"# 或者:只列出活跃任务\n",
|
||||
"# list_tasks(active_only=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7780dcef67a0534c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. 查看任务详情(对应命令行 show)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "eab90de72c35429e",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-29T02:26:12.873536Z",
|
||||
"start_time": "2025-10-29T02:26:12.648420Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 查看指定任务的详情\n",
|
||||
"def show_task_details(task_id):\n",
|
||||
" task = manager.get_task_by_id(task_id)\n",
|
||||
" if not task:\n",
|
||||
" display(Markdown(f\"### 未找到任务ID为 {task_id} 的任务\"))\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
" details = [\"### 任务详情\"]\n",
|
||||
" details.append(f\"**任务ID**: {task.get('task_id')}\")\n",
|
||||
" details.append(f\"**任务名称**: {task.get('task_name')}\")\n",
|
||||
" details.append(f\"**任务类型**: {task.get('task_type')}\")\n",
|
||||
" details.append(f\"**模块路径**: {task.get('module_path')}\")\n",
|
||||
" details.append(f\"**Cron表达式**: {task.get('cron_expression')}\")\n",
|
||||
" details.append(f\"**时区**: {task.get('time_zone', 'Asia/Shanghai')}\")\n",
|
||||
" details.append(f\"**最后运行时间**: {format_datetime(task.get('last_run_time'))}\")\n",
|
||||
" details.append(f\"**下次运行时间**: {format_datetime(task.get('next_run_time'))}\")\n",
|
||||
" details.append(f\"**运行状态**: {task.get('last_run_status', '未运行')}\")\n",
|
||||
" details.append(f\"**是否活跃**: {'是' if task.get('is_active') else '否'}\")\n",
|
||||
" details.append(f\"**运行次数**: {task.get('run_count', 0)}\")\n",
|
||||
" details.append(f\"**创建时间**: {format_datetime(task.get('created_at'))}\")\n",
|
||||
"\n",
|
||||
" display(Markdown('\\n'.join(details)))\n",
|
||||
" return task\n",
|
||||
"\n",
|
||||
"# 执行:查看任务ID为1的详情(替换为实际ID)\n",
|
||||
"show_task_details(1)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[32m2025-10-29 10:26:12\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
],
|
||||
"text/markdown": "### 任务详情\n**任务ID**: 1\n**任务名称**: RSS新闻订阅\n**任务类型**: collector\n**模块路径**: processors.processor_rss_data.RSSDataProcessor\n**Cron表达式**: */5 * * * *\n**时区**: Asia/Shanghai\n**最后运行时间**: 2025-10-28 13:35:09\n**下次运行时间**: 2025-10-29 10:25:00\n**运行状态**: success\n**是否活跃**: 是\n**运行次数**: 496\n**创建时间**: 2025-10-16 15:47:34"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'task_id': 1,\n",
|
||||
" 'task_name': 'RSS新闻订阅',\n",
|
||||
" 'task_type': 'collector',\n",
|
||||
" 'module_path': 'processors.processor_rss_data.RSSDataProcessor',\n",
|
||||
" 'cron_expression': '*/5 * * * *',\n",
|
||||
" 'time_zone': 'Asia/Shanghai',\n",
|
||||
" 'next_run_time': Timestamp('2025-10-29 10:25:00'),\n",
|
||||
" 'last_run_time': Timestamp('2025-10-28 13:35:09'),\n",
|
||||
" 'last_run_status': 'success',\n",
|
||||
" 'run_count': 496,\n",
|
||||
" 'is_active': 1,\n",
|
||||
" 'is_running': 0,\n",
|
||||
" 'created_at': Timestamp('2025-10-16 15:47:34'),\n",
|
||||
" 'updated_at': Timestamp('2025-10-29 10:24:49')}"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 10
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a313f1524f5a54bc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. 添加新任务(对应命令行 add)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "2b2d723bb8e2784f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[32m2025-10-29 09:56:52\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n",
|
||||
"\u001B[32m2025-10-29 09:56:52\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mtask_scheduler\u001B[0m - \u001B[1m新任务添加成功\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/markdown": [
|
||||
"### 任务添加成功!"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/markdown": [
|
||||
"新任务ID: 0,任务名称: AI处理RSS新闻"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"np.int64(0)"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 添加新任务\n",
|
||||
"def add_new_task(name, task_type, module_path, cron_expression, timezone=\"Asia/Shanghai\"):\n",
|
||||
" try:\n",
|
||||
" task_id = scheduler.add_task(\n",
|
||||
" task_name=name,\n",
|
||||
" task_type=task_type,\n",
|
||||
" module_path=module_path,\n",
|
||||
" cron_expression=cron_expression,\n",
|
||||
" time_zone=timezone\n",
|
||||
" )\n",
|
||||
" display(Markdown(f\"### 任务添加成功!\"))\n",
|
||||
" display(Markdown(f\"新任务ID: {task_id},任务名称: {name}\"))\n",
|
||||
" return task_id\n",
|
||||
" except Exception as e:\n",
|
||||
" display(Markdown(f\"### 添加任务失败: {str(e)}\"))\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
"# Run: add a processor task that AI-processes RSS news\n",
|
||||
"add_new_task(\n",
|
||||
" name=\"AI处理RSS新闻\",\n",
|
||||
" task_type=\"processor\",\n",
|
||||
" module_path=\"processors.ai_processors.ai_processor_rss_data.RSSDataAIProcessor\",\n",
|
||||
" cron_expression=\"5 0 * * *\", # runs once a day at 00:05\n",
|
||||
" timezone=\"Asia/Shanghai\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "12373bcbb4a0b434",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. 更新任务属性(对应命令行 update)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "c892fd8ad2f0dd9d",
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2025-10-29T02:29:56.088085Z",
|
||||
"start_time": "2025-10-29T02:29:55.754298Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 更新任务属性\n",
|
||||
"def update_task(task_id, **kwargs):\n",
|
||||
" updates = {}\n",
|
||||
" if 'name' in kwargs and kwargs['name']:\n",
|
||||
" updates['task_name'] = kwargs['name']\n",
|
||||
" if 'type' in kwargs and kwargs['type']:\n",
|
||||
" updates['task_type'] = kwargs['type']\n",
|
||||
" if 'module' in kwargs and kwargs['module']:\n",
|
||||
" updates['module_path'] = kwargs['module']\n",
|
||||
" if 'cron' in kwargs and kwargs['cron']:\n",
|
||||
" updates['cron_expression'] = kwargs['cron']\n",
|
||||
" if 'timezone' in kwargs and kwargs['timezone']:\n",
|
||||
" updates['time_zone'] = kwargs['timezone']\n",
|
||||
"\n",
|
||||
" if not updates:\n",
|
||||
" display(Markdown(\"### 没有提供任何更新内容\"))\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" success = manager.update_task(task_id, updates)\n",
|
||||
" if success:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 更新成功\"))\n",
|
||||
" show_task_details(task_id) # 显示更新后的详情\n",
|
||||
" else:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 更新失败\"))\n",
|
||||
" return success\n",
|
||||
"\n",
|
||||
"# Run: update a task (example: set the module path of task 2)\n",
|
||||
"update_task(2, module = \"processors.processor_rss_data\")\n",
|
||||
"\n",
|
||||
"# 执行:同时更新多个属性(名称和Cron表达式)\n",
|
||||
"# update_task(1, name=\"每日早间新闻采集\", cron=\"0 8 * * *\")"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
],
|
||||
"text/markdown": "### 任务ID 2 更新成功"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[32m2025-10-29 10:29:56\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
],
|
||||
"text/markdown": "### 任务详情\n**任务ID**: 2\n**任务名称**: RSS基于规则数据处理\n**任务类型**: processor\n**模块路径**: processors.processor_rss_data\n**Cron表达式**: 0 8,20 * * *\n**时区**: Asia/Shanghai\n**最后运行时间**: 2025-10-28 13:34:49\n**下次运行时间**: 2025-10-28 20:00:00\n**运行状态**: success\n**是否活跃**: 是\n**运行次数**: 10\n**创建时间**: 2025-10-22 16:06:42"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"execution_count": 21
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "37564011cf5aa501",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. 启用 / 禁用任务(对应命令行 toggle)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "65388d10c5c8d407",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/markdown": [
|
||||
"### 任务ID 1 启用成功"
|
||||
],
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 启用或禁用任务\n",
|
||||
"def toggle_task_status(task_id, activate=True):\n",
|
||||
" success = manager.toggle_task_status(task_id, activate)\n",
|
||||
" action = \"启用\" if activate else \"禁用\"\n",
|
||||
" if success:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} {action}成功\"))\n",
|
||||
" else:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} {action}失败\"))\n",
|
||||
" return success\n",
|
||||
"\n",
|
||||
"# 执行:启用任务ID为1的任务\n",
|
||||
"toggle_task_status(1, activate=True)\n",
|
||||
"\n",
|
||||
"# 执行:禁用任务ID为1的任务\n",
|
||||
"# toggle_task_status(1, activate=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c554c748169d5ac8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. 手动执行任务(对应命令行 run)\n",
|
||||
"\n",
|
||||
"自动识别main,即main的上一级"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"id": "94892f4134316f8e",
|
||||
"metadata": {
|
||||
"jupyter": {
|
||||
"is_executing": true
|
||||
},
|
||||
"ExecuteTime": {
|
||||
"start_time": "2025-10-29T02:30:10.298891Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"# 手动执行任务(异步方式,快速返回)\n",
|
||||
"def run_task_manually(task_id):\n",
|
||||
" display(Markdown(f\"### 正在手动执行任务ID {task_id}...\"))\n",
|
||||
" success = manager.run_task_manually(task_id)\n",
|
||||
" if success:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 执行成功\"))\n",
|
||||
" else:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 执行失败\"))\n",
|
||||
" return success\n",
|
||||
"\n",
|
||||
"# 手动执行任务(同步方式,显示详细执行过程)\n",
|
||||
"def run_task_with_details(task_id):\n",
|
||||
" display(Markdown(f\"### 开始执行任务ID {task_id}\"))\n",
|
||||
" display(Markdown(\"---\"))\n",
|
||||
" \n",
|
||||
" result = manager.run_task_synchronously(task_id)\n",
|
||||
" \n",
|
||||
" if not result['success'] and result.get('error') and 'task_id' not in result:\n",
|
||||
" display(Markdown(f\"### ❌ 错误: {result['error']}\"))\n",
|
||||
" return result\n",
|
||||
" \n",
|
||||
" # 显示任务基本信息\n",
|
||||
" display(Markdown(f\"**任务名称**: {result['task_name']}\"))\n",
|
||||
" display(Markdown(f\"**任务ID**: {result['task_id']}\"))\n",
|
||||
" display(Markdown(f\"**执行时长**: {result['execution_time']:.2f} 秒\"))\n",
|
||||
" display(Markdown(\"---\"))\n",
|
||||
" \n",
|
||||
" # 显示执行输出\n",
|
||||
" if result['output']:\n",
|
||||
" display(Markdown(\"### 📋 执行输出:\"))\n",
|
||||
" print(result['output'])\n",
|
||||
" display(Markdown(\"---\"))\n",
|
||||
" \n",
|
||||
" # 显示执行结果\n",
|
||||
" if result['success']:\n",
|
||||
" display(Markdown(\"### ✅ 任务执行成功\"))\n",
|
||||
" else:\n",
|
||||
" display(Markdown(f\"### ❌ 任务执行失败\"))\n",
|
||||
" if result['error']:\n",
|
||||
" display(Markdown(f\"**错误信息**: {result['error']}\"))\n",
|
||||
" \n",
|
||||
" return result\n",
|
||||
"\n",
|
||||
"# Run: manually run task 3 (shows the detailed execution process)\n",
|
||||
"run_task_with_details(3)"
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
],
|
||||
"text/markdown": "### 开始执行任务ID 3"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<IPython.core.display.Markdown object>"
|
||||
],
|
||||
"text/markdown": "---"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data",
|
||||
"jetTransient": {
|
||||
"display_id": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\u001B[32m2025-10-29 10:30:10\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n",
|
||||
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1mRSS数据AI处理器初始化完成\u001B[0m\n",
|
||||
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1m开始批量处理数据,批次大小: 200, 延迟: 1.5秒\u001B[0m\n",
|
||||
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1m成功加载 3 条未处理的数据\u001B[0m\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"execution_count": null
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c3492a1af7dbf2b1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. 删除任务(对应命令行 delete)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6936dcc673933a8d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 删除任务\n",
|
||||
"def delete_task(task_id, confirm=False):\n",
|
||||
" if not confirm:\n",
|
||||
" display(Markdown(f\"### 警告:删除任务是不可逆操作!\"))\n",
|
||||
" display(Markdown(f\"请运行 `delete_task({task_id}, confirm=True)` 确认删除\"))\n",
|
||||
" return False\n",
|
||||
"\n",
|
||||
" success = manager.delete_task(task_id)\n",
|
||||
" if success:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 删除成功\"))\n",
|
||||
" else:\n",
|
||||
" display(Markdown(f\"### 任务ID {task_id} 删除失败\"))\n",
|
||||
" return success\n",
|
||||
"\n",
|
||||
"# 执行:第一步 - 确认删除(不会实际删除)\n",
|
||||
"delete_task(1)\n",
|
||||
"\n",
|
||||
"# 执行:第二步 - 实际删除(谨慎操作!)\n",
|
||||
"# delete_task(1, confirm=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "intelligence_system",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
use intelligence_system;
|
||||
SELECT * FROM main_task
|
||||
WHERE is_active = 1
|
||||
AND next_run_time <= %s
|
||||
AND is_running = 0
|
||||
ORDER BY next_run_time;
|
||||
@@ -0,0 +1 @@
|
||||
from .logger import CrossPlatformLog
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+35
-17
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
import pickle
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
from pathlib import Path, PurePath
|
||||
from typing import Union, Optional, List, Dict, Any
|
||||
from typing import Union, Optional, List, Dict, Any, Callable
|
||||
from utils.logger import log
|
||||
|
||||
class FileHandler:
|
||||
@@ -71,6 +72,17 @@ class FileHandler:
|
||||
df = pd.read_excel(file_path, **kwargs)
|
||||
elif ext == 'json':
|
||||
df = pd.read_json(file_path, encoding=encoding, **kwargs)
|
||||
elif ext in ['pkl', 'pickle']:
|
||||
# 统一将pickle内容转为DataFrame返回
|
||||
obj = pd.read_pickle(file_path)
|
||||
if isinstance(obj, pd.DataFrame):
|
||||
df = obj
|
||||
elif isinstance(obj, list):
|
||||
df = pd.DataFrame(obj)
|
||||
elif isinstance(obj, dict):
|
||||
df = pd.DataFrame([obj])
|
||||
else:
|
||||
df = pd.DataFrame({'content': [obj]})
|
||||
elif ext == 'parquet':
|
||||
df = pd.read_parquet(file_path, **kwargs)
|
||||
else:
|
||||
@@ -102,25 +114,31 @@ class FileHandler:
|
||||
if not parent_dir.exists():
|
||||
self.create_dir(parent_dir)
|
||||
|
||||
# 统一数据格式
|
||||
if isinstance(data, pd.DataFrame):
|
||||
df = data
|
||||
else:
|
||||
df = pd.DataFrame(data if isinstance(data, list) else [data])
|
||||
|
||||
# 根据扩展名选择写入方式
|
||||
ext = self.get_file_extension(file_path)
|
||||
if ext in ['csv', 'txt']:
|
||||
df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
|
||||
elif ext in ['xls', 'xlsx']:
|
||||
df.to_excel(file_path, index=False, **kwargs)
|
||||
elif ext == 'json':
|
||||
df.to_json(file_path, force_ascii=False, **kwargs)
|
||||
elif ext == 'parquet':
|
||||
df.to_parquet(file_path, **kwargs)
|
||||
|
||||
if ext in ['pkl', 'pickle']:
|
||||
# 直接按原始对象进行pickle序列化
|
||||
with open(file_path, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
else:
|
||||
with open(file_path, 'w', encoding=encoding) as f:
|
||||
f.write(str(data))
|
||||
# 统一数据格式到DataFrame
|
||||
if isinstance(data, pd.DataFrame):
|
||||
df = data
|
||||
else:
|
||||
df = pd.DataFrame(data if isinstance(data, list) else [data])
|
||||
|
||||
if ext in ['csv', 'txt']:
|
||||
df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
|
||||
elif ext in ['xls', 'xlsx']:
|
||||
df.to_excel(file_path, index=False, **kwargs)
|
||||
elif ext == 'json':
|
||||
df.to_json(file_path, force_ascii=False, **kwargs)
|
||||
elif ext == 'parquet':
|
||||
df.to_parquet(file_path, **kwargs)
|
||||
else:
|
||||
with open(file_path, 'w', encoding=encoding) as f:
|
||||
f.write(str(data))
|
||||
|
||||
# 返回成功结果
|
||||
return self._format_result(
|
||||
|
||||
+32
-4
@@ -35,6 +35,7 @@ class CrossPlatformLog:
|
||||
"""配置跨平台日志处理器"""
|
||||
logger.remove() # 清除默认配置
|
||||
|
||||
|
||||
# 统一控制台输出格式
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
@@ -58,20 +59,47 @@ class CrossPlatformLog:
|
||||
compression=self._compress_log,
|
||||
encoding="utf-8",
|
||||
level="DEBUG",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}",
|
||||
# 👇 增加 {extra} 输出,并美化结构
|
||||
# format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra_output}",
|
||||
retention="30 days",
|
||||
enqueue=True # 线程安全
|
||||
enqueue=True,
|
||||
# 👇 动态处理 extra 字段为可读格式
|
||||
format=self._format_with_extra, # 使用自定义格式函数
|
||||
)
|
||||
|
||||
def _format_with_extra(self, record):
|
||||
# 构造 extra 的可读字符串
|
||||
extra_str = ""
|
||||
if record["extra"]:
|
||||
extra_items = []
|
||||
for key, value in record["extra"].items():
|
||||
if key == "extra_output": # 跳过自己,避免递归
|
||||
continue
|
||||
value_repr = repr(value)
|
||||
# 对于错误信息,增加截断长度限制,避免丢失重要信息
|
||||
if key in ["error", "error_message", "sql", "params"]:
|
||||
if len(value_repr) > 500:
|
||||
value_repr = value_repr[:497] + "..."
|
||||
elif len(value_repr) > 200:
|
||||
value_repr = value_repr[:197] + "..."
|
||||
extra_items.append(f"\n → {key}: {value_repr}")
|
||||
extra_str = "".join(extra_items)
|
||||
|
||||
# 👉 直接将 extra_str 写入 message 或附加字段
|
||||
record["extra"]["extra_output"] = extra_str
|
||||
|
||||
# ✅ 关键:返回的 format 字符串不再引用 {extra_output},而是使用 {extra[extra_output]}
|
||||
return "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra[extra_output]}\n"
|
||||
def _add_error_log(self):
|
||||
"""错误日志专用配置"""
|
||||
error_log = self.log_dir / "errors.log"
|
||||
logger.add(
|
||||
str(error_log),
|
||||
level="ERROR",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - {message}\n{exception}",
|
||||
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - {message}{extra[extra_output]}\n{exception}",
|
||||
rotation="10 MB",
|
||||
retention="90 days"
|
||||
retention="90 days",
|
||||
enqueue=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -0,0 +1,383 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import threading
|
||||
from typing import List, Dict, Optional, BinaryIO, Tuple, Any
|
||||
from datetime import datetime, timedelta
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
from minio import Minio
|
||||
from minio.error import S3Error, MinioException
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MinIOAgent:
|
||||
"""
|
||||
全平台兼容的MinIO对象存储操作类
|
||||
支持Windows/macOS/Linux系统,提供对象存储的上传、下载、查询等功能
|
||||
专注于二进制数据处理,返回元数据用于与MySQL关联
|
||||
"""
|
||||
_instance = None # 单例模式实例
|
||||
_lock = threading.Lock() # 线程锁,保证单例线程安全
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式实现,确保全局只有一个实例"""
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict):
    """
    Configure the agent and connect to MinIO.

    Args:
        config (dict): MinIO settings with the keys:
            - endpoint: service endpoint (e.g. 'localhost:9000')
            - access_key: access key
            - secret_key: secret key
            - [optional] secure: whether to use SSL (default False)
            - [optional] region: region
            - [optional] timeout: timeout in seconds (default 30); stored in
              ``self.config`` by this method

    Raises:
        ValueError: if endpoint, access_key or secret_key is missing.
    """
    # __init__ runs again on every MinIOAgent(...) call even though __new__
    # returns the same object; keep the client that is already in place.
    if getattr(self, '_client', None):
        return

    required_keys = ['endpoint', 'access_key', 'secret_key']
    if any(key not in config for key in required_keys):
        raise ValueError(f"MinIO配置缺少必要参数,需要: {required_keys}")

    # Normalised copy of the settings, with defaults for the optional keys.
    self.config = dict(
        endpoint=config['endpoint'],
        access_key=config['access_key'],
        secret_key=config['secret_key'],
        secure=config.get('secure', False),
        region=config.get('region'),
        timeout=config.get('timeout', 30),
    )

    # Logger bound with the name of the platform we are running on.
    self.log = log.bind(module=f"MinIOAgent({platform.system()})")

    # Build the client, then verify the connection straight away.
    self._client = self._create_client()
    self._verify_connection()
|
||||
|
||||
def _create_client(self) -> Minio:
    """Build a Minio client from ``self.config``; log and re-raise on failure."""
    settings = self.config
    try:
        minio_client = Minio(
            endpoint=settings['endpoint'],
            access_key=settings['access_key'],
            secret_key=settings['secret_key'],
            secure=settings['secure'],
            region=settings['region'],
        )
        self.log.info("MinIO客户端创建成功")
        return minio_client
    except Exception as e:
        self.log.critical("创建MinIO客户端失败", 错误=str(e), exc_info=True)
        raise
|
||||
|
||||
def _verify_connection(self) -> None:
|
||||
"""验证与MinIO服务的连接是否正常"""
|
||||
try:
|
||||
# 通过列出存储桶来验证连接
|
||||
self._client.list_buckets()
|
||||
self.log.info(f"成功连接到MinIO服务:{self.config['endpoint']}")
|
||||
except Exception as e:
|
||||
self.log.critical("连接验证失败", 错误=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
def create_bucket(self, bucket_name: str) -> bool:
|
||||
"""
|
||||
创建存储桶(如不存在)
|
||||
|
||||
参数:
|
||||
bucket_name: 存储桶名称
|
||||
|
||||
返回:
|
||||
是否成功创建(或已存在)
|
||||
"""
|
||||
try:
|
||||
if not self._client.bucket_exists(bucket_name):
|
||||
self._client.make_bucket(bucket_name)
|
||||
self.log.info(f"存储桶创建成功:{bucket_name}")
|
||||
return True
|
||||
self.log.debug(f"存储桶已存在:{bucket_name}")
|
||||
return True
|
||||
except MinioException as e:
|
||||
self.log.error(f"创建存储桶 {bucket_name} 失败", 错误=str(e), exc_info=True)
|
||||
return False
|
||||
|
||||
def upload_bytes(self, bucket: str, object_name: str, data: bytes) -> Dict[str, Any]:
|
||||
"""
|
||||
上传二进制数据至MinIO
|
||||
|
||||
参数:
|
||||
bucket: 存储桶名称
|
||||
object_name: 对象名称(路径)
|
||||
data: 二进制数据
|
||||
|
||||
返回:
|
||||
包含元数据的字典:
|
||||
- bucket: 存储桶名称
|
||||
- object_name: 对象路径
|
||||
- size: 数据大小(字节)
|
||||
- etag: 服务器生成的哈希值
|
||||
- content_type: 内容类型
|
||||
- upload_time: 上传时间(UTC)
|
||||
- local_hash: 本地计算的MD5哈希
|
||||
"""
|
||||
if not data:
|
||||
raise ValueError("上传数据不能为空")
|
||||
|
||||
# 确保存储桶存在
|
||||
self.create_bucket(bucket)
|
||||
|
||||
try:
|
||||
# 计算本地哈希(用于数据完整性校验)
|
||||
local_hash = hashlib.md5(data).hexdigest()
|
||||
|
||||
# 上传数据
|
||||
result = self._client.put_object(
|
||||
bucket_name=bucket,
|
||||
object_name=object_name,
|
||||
data=BytesIO(data),
|
||||
length=len(data),
|
||||
content_type=self._guess_content_type(object_name)
|
||||
)
|
||||
|
||||
# 构建元数据
|
||||
metadata = {
|
||||
'bucket': bucket,
|
||||
'object_name': object_name,
|
||||
'size': len(data),
|
||||
'etag': result.etag,
|
||||
'content_type': result.content_type,
|
||||
'upload_time': datetime.utcfromtimestamp(result.last_modified.timestamp()),
|
||||
'local_hash': local_hash
|
||||
}
|
||||
|
||||
self.log.info(
|
||||
"文件上传成功",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
大小=len(data)
|
||||
)
|
||||
return metadata
|
||||
|
||||
except MinioException as e:
|
||||
self.log.error(
|
||||
"文件上传失败",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def download_file(self, bucket: str, object_name: str, local_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
从MinIO下载文件至本地
|
||||
|
||||
参数:
|
||||
bucket: 存储桶名称
|
||||
object_name: 对象名称(路径)
|
||||
local_path: 本地保存路径
|
||||
|
||||
返回:
|
||||
包含下载信息的字典:
|
||||
- local_path: 本地路径
|
||||
- size: 文件大小
|
||||
- download_time: 下载时间
|
||||
"""
|
||||
try:
|
||||
# 创建父目录(如果不存在)
|
||||
os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
||||
|
||||
# 下载文件
|
||||
start_time = datetime.now()
|
||||
self._client.fget_object(bucket, object_name, local_path)
|
||||
download_time = datetime.now() - start_time
|
||||
|
||||
# 获取文件信息
|
||||
stat = os.stat(local_path)
|
||||
|
||||
result = {
|
||||
'local_path': local_path,
|
||||
'size': stat.st_size,
|
||||
'download_time': download_time.total_seconds(),
|
||||
'downloaded_at': datetime.now()
|
||||
}
|
||||
|
||||
self.log.info(
|
||||
"文件下载成功",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
本地路径=local_path,
|
||||
大小=stat.st_size
|
||||
)
|
||||
return result
|
||||
|
||||
except MinioException as e:
|
||||
self.log.error(
|
||||
"文件下载失败",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
except IOError as e:
|
||||
self.log.error(
|
||||
"本地文件操作失败",
|
||||
本地路径=local_path,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def get_presigned_url(self, bucket: str, object_name: str, expires: int = 3600) -> Dict[str, str]:
|
||||
"""
|
||||
生成临时访问URL
|
||||
|
||||
参数:
|
||||
bucket: 存储桶名称
|
||||
object_name: 对象名称(路径)
|
||||
expires: 过期时间(秒),默认3600秒
|
||||
|
||||
返回:
|
||||
包含URL和过期信息的字典
|
||||
"""
|
||||
try:
|
||||
url = self._client.presigned_get_object(
|
||||
bucket_name=bucket,
|
||||
object_name=object_name,
|
||||
expires=expires
|
||||
)
|
||||
|
||||
result = {
|
||||
'presigned_url': url,
|
||||
'expires_in': expires,
|
||||
'expires_at': datetime.now() + timedelta(seconds=expires),
|
||||
'bucket': bucket,
|
||||
'object_name': object_name
|
||||
}
|
||||
|
||||
self.log.debug(
|
||||
"预签名URL生成成功",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
过期时间=expires
|
||||
)
|
||||
return result
|
||||
|
||||
except MinioException as e:
|
||||
self.log.error(
|
||||
"生成预签名URL失败",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def list_objects(self, bucket: str, prefix: str = "") -> List[Dict[str, Any]]:
|
||||
"""
|
||||
查询指定前缀的对象列表及元数据
|
||||
|
||||
参数:
|
||||
bucket: 存储桶名称
|
||||
prefix: 对象路径前缀
|
||||
|
||||
返回:
|
||||
对象信息列表,每个对象包含:
|
||||
- bucket: 存储桶
|
||||
- object_name: 对象名称
|
||||
- size: 大小
|
||||
- last_modified: 最后修改时间
|
||||
- etag: 哈希值
|
||||
- content_type: 内容类型
|
||||
"""
|
||||
try:
|
||||
objects = self._client.list_objects(
|
||||
bucket_name=bucket,
|
||||
prefix=prefix,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
result = []
|
||||
for obj in objects:
|
||||
# 获取详细元数据
|
||||
stat = self._client.stat_object(bucket, obj.object_name)
|
||||
|
||||
result.append({
|
||||
'bucket': bucket,
|
||||
'object_name': obj.object_name,
|
||||
'size': obj.size,
|
||||
'last_modified': obj.last_modified,
|
||||
'etag': stat.etag,
|
||||
'content_type': stat.content_type
|
||||
})
|
||||
|
||||
self.log.info(
|
||||
"对象列表查询成功",
|
||||
存储桶=bucket,
|
||||
前缀=prefix,
|
||||
数量=len(result)
|
||||
)
|
||||
return result
|
||||
|
||||
except MinioException as e:
|
||||
self.log.error(
|
||||
"查询对象列表失败",
|
||||
存储桶=bucket,
|
||||
前缀=prefix,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def delete_object(self, bucket: str, object_name: str) -> bool:
|
||||
"""
|
||||
删除指定对象
|
||||
|
||||
参数:
|
||||
bucket: 存储桶名称
|
||||
object_name: 对象名称(路径)
|
||||
|
||||
返回:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
self._client.remove_object(bucket, object_name)
|
||||
self.log.info(
|
||||
"对象删除成功",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name
|
||||
)
|
||||
return True
|
||||
except MinioException as e:
|
||||
self.log.error(
|
||||
"删除对象失败",
|
||||
存储桶=bucket,
|
||||
对象名称=object_name,
|
||||
错误=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _guess_content_type(object_name: str) -> str:
|
||||
"""根据文件名猜测内容类型"""
|
||||
from mimetypes import guess_type
|
||||
mime_type, _ = guess_type(object_name)
|
||||
return mime_type or 'application/octet-stream' # 默认二进制流类型
|
||||
@@ -0,0 +1,722 @@
|
||||
import os
|
||||
import sys
|
||||
import platform
|
||||
import pandas as pd
|
||||
import pymysql
|
||||
import json
|
||||
import numpy as np
|
||||
from pymysql import cursors
|
||||
from pymysql.err import MySQLError
|
||||
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
# 导入日志系统
|
||||
from utils.logger import log
|
||||
|
||||
|
||||
class MySQLAgent:
|
||||
"""
|
||||
全平台兼容的MySQL数据库操作类
|
||||
支持Windows/macOS/Linux系统
|
||||
配置参数从外部传入,不使用连接池和事务管理
|
||||
"""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict):
    """
    Store the MySQL connection settings; no connection is opened here.

    Args:
        config (dict): must contain host, port, user, password and
            database. Optional keys: charset (default 'utf8mb4'),
            connect_timeout (default 10), read_timeout (default 30),
            write_timeout (default 30) and ssl. ``autocommit`` is always
            set to True.

    Raises:
        ValueError: if any required key is missing.
    """
    # __init__ runs again on every MySQLAgent(...) call even though __new__
    # returns the same object; keep the settings already stored.
    if hasattr(self, 'config') and self.config:
        return

    required_keys = ['host', 'port', 'user', 'password', 'database']
    if not all(key in config for key in required_keys):
        # Never write credentials to the log: mask the password before
        # echoing the offending configuration.
        safe_config = {
            key: ('***' if key == 'password' else value)
            for key, value in config.items()
        }
        log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{safe_config}")
        raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

    self.config = {
        'host': config['host'],
        'port': config['port'],
        'user': config['user'],
        'password': config['password'],
        'database': config['database'],
        'charset': config.get('charset', 'utf8mb4'),
        'autocommit': True,
        'connect_timeout': config.get('connect_timeout', 10),
        'read_timeout': config.get('read_timeout', 30),
        'write_timeout': config.get('write_timeout', 30),
        'ssl': config.get('ssl')
    }

    # Logger bound with the name of the platform we are running on.
    current_platform = platform.system()
    self.log = log.bind(module=f"MySQLAgent({current_platform})")
|
||||
|
||||
def get_connection(self) -> pymysql.connections.Connection:
    """
    Open and return a new pymysql connection built from ``self.config``.

    On Windows, an error whose text contains "timed out" is handed to
    ``_retry_connection``; any other failure is logged and re-raised.
    """
    try:
        connection = pymysql.connect(**self.config)

        # Give the connection a character_set_name() method if it lacks one.
        if not hasattr(connection, 'character_set_name'):
            connection.character_set_name = lambda: self.config.get('charset', 'utf8mb4')

        # On macOS with SSL configured, ping (with reconnect) before use.
        macos_with_ssl = platform.system() == 'Darwin' and self.config.get('ssl')
        if macos_with_ssl:
            connection.ping(reconnect=True)

        self.log.trace("获取数据库连接成功")
        return connection

    except Exception as e:
        error_msg = str(e)
        windows_timeout = platform.system() == 'Windows' and "timed out" in error_msg
        if windows_timeout:
            self.log.warning("Windows连接超时,正在重试...")
            return self._retry_connection()

        self.log.error(
            "连接失败",
            error=error_msg,
            error_type=type(e).__name__,
            host=self.config.get('host'),
            port=self.config.get('port'),
            database=self.config.get('database'),
            exc_info=True,
        )
        raise
|
||||
|
||||
def _retry_connection(self, max_retries: int = 3) -> Any | None:
    """
    Retry ``pymysql.connect`` up to ``max_retries`` times, one second apart.

    Returns the first connection that succeeds. The exception from the
    final attempt is re-raised. Returns None when ``max_retries`` is not
    positive, because the loop body never runs.
    """
    import time

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            connection = pymysql.connect(**self.config)
            self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
            return connection
        except Exception:
            if attempt == final_attempt:
                raise
            time.sleep(1)
|
||||
|
||||
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
    """
    Run a SQL query and return the result as a DataFrame.

    Args:
        sql: the query text.
        params: parameters forwarded to ``pd.read_sql``.
        parse_dates: forwarded to ``pd.read_sql``.
        is_print: when truthy, log the row count after a successful query.

    Raises:
        Exception: any failure is logged and re-raised unchanged.
    """
    engine = None
    try:
        self.log.debug("执行SQL查询", sql=sql)

        raw_connection = self.get_connection()

        # Wrap the single pymysql connection in a SQLAlchemy engine.
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool
        engine = create_engine(
            "mysql+pymysql://",
            creator=lambda: raw_connection,
            poolclass=StaticPool,
            connect_args={'charset': self.config.get('charset', 'utf8mb4')},
        )

        frame = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
        if is_print:
            self.log.info("查询执行成功", 行数=len(frame))

        return frame

    except Exception as e:
        self.log.error(
            "SQL查询失败",
            sql=sql,
            params=params,
            error=str(e),
            error_type=type(e).__name__,
            exc_info=True,
        )
        raise
    finally:
        # Only dispose of an engine that was actually created.
        if engine is not None:
            engine.dispose()
|
||||
|
||||
def insert_from_df(self, table_name: str, df: pd.DataFrame,
                   chunk_size: int = 1000, replace: bool = False,
                   ignore_duplicates: Optional[bool] = None) -> int:
    """
    Insert the rows of a DataFrame into ``table_name``, one INSERT per row.

    DataFrame columns that the table does not have are dropped, dict/list
    cells are serialised to JSON, and each row is inserted on its own so a
    failure can be attributed to a single record. Failed rows are logged.

    Args:
        table_name: target table.
        df: rows to insert.
        chunk_size: kept for interface compatibility; this method inserts
            row by row and does not read it.
        replace: legacy flag, consulted only when ``ignore_duplicates`` is
            None, in which case ``ignore_duplicates = not replace``.
        ignore_duplicates: when truthy, a row that fails (duplicate key or
            any other MySQLError) is logged and skipped; when falsy, the
            first failing row aborts the call.

    Returns:
        Number of rows inserted successfully (0 for an empty DataFrame or
        when no column matches the table).

    Raises:
        Exception: any error that aborts the insert is logged and re-raised.
    """
    # Compatibility: derive ignore_duplicates from the legacy replace flag.
    if ignore_duplicates is None:
        ignore_duplicates = not replace

    if df.empty:
        self.log.warning("尝试插入空的DataFrame", table=table_name)
        return 0

    conn = None
    cursor = None
    total_inserted = 0
    total_duplicates = 0
    total_failed = 0
    failed_records = []  # every row that could not be inserted

    try:
        # 1. Connect.
        conn = self.get_connection()
        cursor = conn.cursor()
        self.log.debug(f"已建立连接,准备插入数据到 {table_name}")

        # 2. Read the table's actual column names.
        cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
        db_columns = [col[0] for col in cursor.fetchall()]
        self.log.debug(f"表 {table_name} 包含以下列:{db_columns}")

        # 3. Normalise the various "empty" markers to None.
        cleaned_df = df.replace(
            [None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
            None
        ).copy()

        # 4. Keep only the columns that exist in the table.
        known_columns = set(db_columns)
        df_columns = cleaned_df.columns.tolist()
        matched_columns = [col for col in df_columns if col in known_columns]
        unmatched_columns = [col for col in df_columns if col not in known_columns]

        if unmatched_columns:
            self.log.warning(
                f"表 {table_name} 中存在不匹配的列,已自动丢弃",
                unmatched_columns=unmatched_columns,
                count=len(unmatched_columns)
            )

        if not matched_columns:
            self.log.warning(f"表 {table_name} 没有匹配的列,终止插入操作")
            return 0

        filtered_df = cleaned_df[matched_columns].copy()
        total_to_insert = len(filtered_df)
        self.log.debug(
            f"表 {table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
        )

        # 5. Serialise columns holding dict/list values to JSON text.
        for col in filtered_df.columns:
            has_complex_type = filtered_df[col].apply(
                lambda x: isinstance(x, (dict, list))
            ).any()

            if has_complex_type:
                self.log.debug(f"表 {table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
                filtered_df.loc[:, col] = filtered_df[col].apply(
                    lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
                )

        # 6. Build the parameterised INSERT statement.
        insert_columns = list(filtered_df.columns)
        columns_str = ', '.join([f"`{col}`" for col in insert_columns])
        placeholders = ', '.join(['%s'] * len(insert_columns))
        insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
        self.log.trace(f"为表 {table_name} 生成的插入SQL:{insert_sql}")

        # 7. Insert row by row so a single bad row can be identified.
        records = filtered_df.to_dict('records')
        indices = filtered_df.index.tolist()

        for i, (record, idx) in enumerate(zip(records, indices)):
            try:
                data = tuple(record[col] for col in insert_columns)
                cursor.execute(insert_sql, data)
                total_inserted += 1

                if (i + 1) % 100 == 0:
                    self.log.trace(
                        f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
                    )

            except MySQLError as e:
                # Not every MySQLError carries (code, message); read both
                # defensively so the handler cannot fail with IndexError
                # and hide the real database error.
                error_code = e.args[0] if e.args else None
                error_message = e.args[1] if len(e.args) > 1 else str(e)

                # 8. MySQL error 1062 = duplicate entry.
                if error_code == 1062:
                    total_duplicates += 1
                    # Shortened copy for the log; mark only values that
                    # were actually cut.
                    short_record = {}
                    for key, value in record.items():
                        if isinstance(value, (str, dict, list)):
                            text = str(value)
                            value = text if len(text) <= 100 else text[:100] + '...'
                        short_record[key] = value
                    self.log.warning(
                        f"表 {table_name} 中跳过重复记录",
                        index=idx,
                        error_message=error_message,
                        record=short_record
                    )
                    failed_records.append({
                        'index': idx,
                        'type': 'duplicate',
                        'error_code': error_code,
                        'error_message': error_message,
                        'record': record
                    })
                    if not ignore_duplicates:
                        raise
                else:
                    # Any other database error.
                    total_failed += 1
                    failed_records.append({
                        'index': idx,
                        'type': 'error',
                        'error_code': error_code,
                        'error_message': error_message,
                        'record': record
                    })
                    self.log.error(
                        f"表 {table_name} 插入记录失败",
                        index=idx,
                        error_code=error_code,
                        error_message=error_message,
                        record=record  # the full record goes to the log
                    )
                    # NOTE(review): the same flag also decides whether
                    # non-duplicate errors abort; confirm this is intended.
                    if not ignore_duplicates:
                        raise

        conn.commit()

        # 9. Summary, including failure counts.
        self.log.info(
            f"表 {table_name} 插入结果汇总",
            total_to_insert=total_to_insert,
            total_inserted=total_inserted,
            total_duplicates=total_duplicates,
            total_failed=total_failed,
            failed_records_count=len(failed_records)
        )

        if failed_records:
            self.log.error(
                f"表 {table_name} 插入失败记录详情",
                failed_records_summary=[
                    {
                        'index': r['index'],
                        'type': r['type'],
                        'error_code': r['error_code'],
                        'error_message': r['error_message']
                    } for r in failed_records
                ],
                detailed_failed_records=failed_records
            )

        return total_inserted

    except Exception as e:
        if conn:
            conn.rollback()
        self.log.error(f"表 {table_name} 批量插入失败",
                       error=str(e),
                       error_type=type(e).__name__,
                       table_name=table_name,
                       total_records=len(df) if not df.empty else 0,
                       exc_info=True)
        # Also report the rows that had already failed.
        if failed_records:
            self.log.error(
                f"表 {table_name} 事务回滚,已失败的记录",
                failed_records=failed_records,
                failed_count=len(failed_records)
            )
        raise
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
|
||||
|
||||
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
|
||||
"""【新增辅助方法】获取表的主键(用于replace逻辑的去重)"""
|
||||
try:
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
|
||||
WHERE TABLE_SCHEMA = %s
|
||||
AND TABLE_NAME = %s
|
||||
AND CONSTRAINT_NAME = 'PRIMARY'
|
||||
""", (self.config['database'], table_name))
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
except Exception as e:
|
||||
self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
|
||||
return None
|
||||
|
||||
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取表的详细结构信息(原有逻辑完全保留,供其他方法调用)"""
|
||||
sql = """
|
||||
SELECT column_name, data_type, character_maximum_length
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = %s \
|
||||
AND table_name = %s \
|
||||
"""
|
||||
params = (self.config['database'], table_name)
|
||||
|
||||
try:
|
||||
conn = self.get_connection()
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, params)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# 强制转换为列表,避免游标类型导致的解析问题
|
||||
result_list = list(result)
|
||||
if not result_list:
|
||||
self.log.error("未在表中找到任何列", 表=table_name)
|
||||
return {}
|
||||
|
||||
schema = {}
|
||||
for row in result_list:
|
||||
# 确保正确提取字段名(兼容元组格式)
|
||||
col_name = str(row[0]).strip() # 强制转为字符串并去空格
|
||||
data_type = str(row[1]).strip()
|
||||
max_length = row[2] if row[2] else None
|
||||
|
||||
schema[col_name] = {
|
||||
'type': data_type,
|
||||
'max_length': max_length
|
||||
}
|
||||
|
||||
self.log.debug("成功获取表结构信息",
|
||||
表=table_name,
|
||||
列=list(schema.keys()))
|
||||
return schema
|
||||
finally:
|
||||
cursor.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
self.log.error("获取表详细信息失败",
|
||||
表=table_name,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
|
||||
table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
|
||||
"""数据校验与清洗(原有逻辑完全保留,供其他方法调用)"""
|
||||
# 1. 字段过滤:只保留表中存在的字段
|
||||
df_columns = df.columns.tolist()
|
||||
table_columns = list(table_schema.keys())
|
||||
|
||||
valid_columns = [col for col in df_columns if col in table_columns]
|
||||
invalid_columns = [col for col in df_columns if col not in table_columns]
|
||||
|
||||
if invalid_columns:
|
||||
self.log.warning("丢弃表中不存在的无效列",
|
||||
表=table_name,
|
||||
无效列=invalid_columns,
|
||||
数量=len(invalid_columns))
|
||||
|
||||
cleaned_df = df[valid_columns].copy()
|
||||
if cleaned_df.empty:
|
||||
return cleaned_df
|
||||
|
||||
# 2. 处理每个字段的数据
|
||||
for col in valid_columns:
|
||||
col_info = table_schema[col]
|
||||
data_type = col_info['type']
|
||||
max_length = col_info['max_length']
|
||||
|
||||
# 2.1 处理空值
|
||||
if cleaned_df[col].isnull().any():
|
||||
# 根据字段类型设置默认值
|
||||
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
|
||||
cleaned_df[col].fillna(default_value, inplace=True)
|
||||
self.log.debug("替换空值",
|
||||
表=table_name,
|
||||
列=col,
|
||||
默认值=default_value,
|
||||
数量=cleaned_df[col].isnull().sum())
|
||||
|
||||
# 2.2 处理字符串类型的超长字段
|
||||
if data_type in ['varchar', 'char'] and max_length:
|
||||
# 确保是字符串类型
|
||||
cleaned_df[col] = cleaned_df[col].astype(str)
|
||||
# 截断超长内容
|
||||
too_long_mask = cleaned_df[col].str.len() > max_length
|
||||
if too_long_mask.any():
|
||||
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
|
||||
self.log.warning("截断超长值",
|
||||
表=table_name,
|
||||
列=col,
|
||||
最大长度=max_length,
|
||||
数量=too_long_mask.sum())
|
||||
|
||||
# 2.3 处理日期时间类型
|
||||
if data_type in ['datetime', 'timestamp']:
|
||||
try:
|
||||
# 尝试转换为datetime类型
|
||||
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
|
||||
except Exception as e:
|
||||
self.log.warning("转换为datetime失败,使用当前时间替代",
|
||||
表=table_name,
|
||||
列=col,
|
||||
错误=str(e))
|
||||
# 转换失败的用当前时间替代
|
||||
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
|
||||
cleaned_df.loc[invalid_mask, col] = datetime.now()
|
||||
|
||||
return cleaned_df
|
||||
|
||||
def update_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
key_columns: Union[str, List[str]]) -> int:
|
||||
"""使用DataFrame数据更新数据库表(原有逻辑完全保留)"""
|
||||
if df.empty:
|
||||
self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
|
||||
return 0
|
||||
|
||||
self.log.debug("准备从DataFrame更新表数据",
|
||||
表=table_name,
|
||||
关键字列=key_columns,
|
||||
行数=len(df))
|
||||
|
||||
try:
|
||||
if isinstance(key_columns, str):
|
||||
key_columns = [key_columns]
|
||||
|
||||
总更新数 = 0
|
||||
with self.get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
# 获取表结构信息
|
||||
table_info = self._get_table_detailed_info(table_name)
|
||||
columns = [col for col in df.columns if col in table_info]
|
||||
|
||||
# 构建UPDATE语句模板
|
||||
set_clause = ', '.join([f"{col}=%s" for col in columns if col not in key_columns])
|
||||
where_clause = ' AND '.join([f"{col}=%s" for col in key_columns])
|
||||
|
||||
if not set_clause:
|
||||
self.log.warning("没有可更新的列", 表=table_name)
|
||||
return 0
|
||||
|
||||
update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
|
||||
self.log.trace("生成的更新SQL", sql=update_sql)
|
||||
|
||||
# 准备数据
|
||||
update_data = []
|
||||
for _, row in df.iterrows():
|
||||
set_values = [row[col] for col in columns if col not in key_columns]
|
||||
key_values = [row[col] for col in key_columns]
|
||||
update_data.append(tuple(set_values + key_values))
|
||||
|
||||
# 执行批量更新
|
||||
cursor.executemany(update_sql, update_data)
|
||||
总更新数 = cursor.rowcount
|
||||
conn.commit()
|
||||
|
||||
self.log.info("数据更新成功",
|
||||
表=table_name,
|
||||
更新行数=总更新数)
|
||||
return 总更新数
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("数据更新失败",
|
||||
表=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
|
||||
"""推断DataFrame各列的SQL类型(原有逻辑完全保留)"""
|
||||
type_mapping = {
|
||||
'int64': 'BIGINT',
|
||||
'float64': 'DOUBLE',
|
||||
'datetime64[ns]': 'DATETIME',
|
||||
'object': 'TEXT',
|
||||
'bool': 'TINYINT(1)',
|
||||
'category': 'VARCHAR(255)'
|
||||
}
|
||||
|
||||
sql_types = {}
|
||||
for col, dtype in df.dtypes.items():
|
||||
dtype_str = str(dtype)
|
||||
sql_types[col] = type_mapping.get(dtype_str, 'TEXT')
|
||||
|
||||
self.log.debug("将DataFrame类型映射为SQL类型",
|
||||
映射关系=sql_types)
|
||||
return sql_types
|
||||
|
||||
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
|
||||
primary_key: Union[str, List[str], None] = None) -> bool:
|
||||
"""根据DataFrame结构创建表(原有逻辑完全保留)"""
|
||||
if self.table_exists(table_name):
|
||||
self.log.warning("表已存在", 表=table_name)
|
||||
return False
|
||||
|
||||
self.log.debug("根据DataFrame结构创建新表",
|
||||
表=table_name,
|
||||
列=list(df.columns))
|
||||
|
||||
try:
|
||||
sql_types = self.df_to_sql_type(df)
|
||||
columns_sql = []
|
||||
|
||||
for col, sql_type in sql_types.items():
|
||||
col_def = f"{col} {sql_type}"
|
||||
columns_sql.append(col_def)
|
||||
|
||||
if primary_key:
|
||||
if isinstance(primary_key, str):
|
||||
primary_key = [primary_key]
|
||||
pk_columns = [col for col in primary_key if col in sql_types]
|
||||
if pk_columns:
|
||||
columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
|
||||
self.log.trace("设置主键",
|
||||
表=table_name,
|
||||
主键=pk_columns)
|
||||
|
||||
create_sql = f"CREATE TABLE {table_name} (\n {',\n '.join(columns_sql)}\n)"
|
||||
|
||||
self.execute_sql(create_sql)
|
||||
self.log.info("表创建成功", 表=table_name)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("创建表失败",
|
||||
表=table_name,
|
||||
error=str(e),
|
||||
exc_info=True)
|
||||
return False
|
||||
|
||||
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
|
||||
fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
|
||||
"""执行SQL语句(原有逻辑完全保留)"""
|
||||
try:
|
||||
with self.get_connection() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
# Linux/macOS需要更长的执行时间
|
||||
if platform.system() != 'Windows':
|
||||
cursor.execute("SET SESSION max_execution_time=600000")
|
||||
|
||||
cursor.execute(sql, params)
|
||||
|
||||
if fetch:
|
||||
result = cursor.fetchall()
|
||||
self.log.debug("查询执行完成", 行数=len(result))
|
||||
return result
|
||||
else:
|
||||
affected_rows = cursor.rowcount
|
||||
conn.commit() # 立即提交
|
||||
self.log.debug("更新执行完成", 受影响行数=affected_rows)
|
||||
return affected_rows
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("SQL执行失败",
|
||||
sql=sql,
|
||||
params=params,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
exc_info=True)
|
||||
raise
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
    """Report whether ``table_name`` exists in the configured database.

    Counts matching rows in ``information_schema.tables`` for the schema
    named by ``self.config['database']``.

    Args:
        table_name: Bare table name to look up.

    Returns:
        True when the count is greater than zero. False when no table
        matches, when the query returns no rows, or when the lookup itself
        fails (the failure is logged, not raised).
    """
    # The trailing backslashes are line continuations inside the literal,
    # so the WHERE clause reaches the server as a single line.
    sql = """
          SELECT COUNT(*) as count
          FROM `information_schema`.`tables`
          WHERE `table_schema` = %s \
            AND `table_name` = %s \
          """

    params = (self.config['database'], table_name)

    try:
        rows = self.execute_sql(sql, params, fetch=True)
        if not rows:
            exists = False
        else:
            first_row = rows[0]
            # Accept both tuple rows and mapping rows (dict-style cursors);
            # the single selected column is the count either way.
            if isinstance(first_row, dict):
                count = next(iter(first_row.values()))
            else:
                count = first_row[0]
            exists = count > 0
        self.log.debug("检查表是否存在",
                       表=table_name,
                       存在=exists)
        return exists
    except Exception as e:
        # Best-effort check: keep returning False, but leave a trace so a
        # broken lookup is distinguishable from a missing table.
        self.log.warning("检查表是否存在失败",
                         表=table_name,
                         error=str(e))
        return False
|
||||
|
||||
def drop_table(self, table_name: str) -> bool:
    """Drop ``table_name`` if it currently exists.

    Args:
        table_name: Bare table name. It is backtick-quoted before being
            placed in the statement, so reserved words and names with
            special characters are handled, and an embedded backtick
            cannot terminate the identifier early.

    Returns:
        True when the DROP statement ran without error. False when the
        table does not exist or when the statement failed (the failure is
        logged, not raised).
    """
    if not self.table_exists(table_name):
        self.log.warning("表不存在", 表=table_name)
        return False

    # Identifiers cannot be bound as parameters, so quote explicitly:
    # wrap in backticks and double any backtick inside the name.
    quoted_name = "`" + table_name.replace("`", "``") + "`"

    try:
        self.execute_sql(f"DROP TABLE {quoted_name}")
        self.log.info("表删除成功", 表=table_name)
        return True
    except Exception as e:
        self.log.error("删除表失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        return False
|
||||
|
||||
def validate_connection(self) -> bool:
    """Check that a connection can be opened and answers ``SELECT 1``.

    Returns:
        True when the probe query returns the value 1. False when the
        query returns no row, returns another value, or when connecting
        or querying raises (the failure is logged, not raised).
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1")
                row = cursor.fetchone()

        if not row:
            return False
        # Accept both tuple rows and mapping rows (dict-style cursors).
        if isinstance(row, dict):
            value = next(iter(row.values()))
        else:
            value = row[0]
        return value == 1
    except Exception as e:
        # Best-effort probe: never raise, but record why validation failed
        # so connection problems are diagnosable.
        self.log.warning("数据库连接验证失败",
                         error=str(e),
                         error_type=type(e).__name__)
        return False
|
||||
|
||||
|
||||
# Platform-specific default configuration
def get_default_config():
    """Return the default connection settings for the current platform.

    The result always contains the shared host/port/user/password/database
    entries. Timeouts are then added according to ``platform.system()``:
    Windows gets the shorter set, every other platform the longer set, and
    Darwin additionally gets an ``ssl`` entry pointing at a CA bundle path.
    A new dict is built on every call.
    """
    shared = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    # Timeouts used by Linux and any platform without its own entry below.
    fallback_extras = {
        'connect_timeout': 15,
        'read_timeout': 60,
        'write_timeout': 60,
    }

    extras_by_platform = {
        'Windows': {
            'connect_timeout': 10,
            'read_timeout': 30,
            'write_timeout': 30,
        },
        'Darwin': {
            'connect_timeout': 15,
            'read_timeout': 60,
            'write_timeout': 60,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'},
        },
    }

    extras = extras_by_platform.get(platform.system(), fallback_extras)
    return {**shared, **extras}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Usage example: connect with the platform defaults, then report the
    # server version when the connection check passes.
    agent = MySQLAgent(get_default_config())

    if not agent.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")

        # Fetch the database version
        version_df = agent.query_to_df("SELECT VERSION() as version")
        server_version = version_df['version'].iloc[0]
        print(f"数据库版本: {server_version}")
|
||||
Reference in New Issue
Block a user