Compare commits

..

16 Commits

Author SHA1 Message Date
panda 1dfc5f1024 前端展示 2025-08-14 10:56:45 +08:00
panda 498ddca73c 前端展示 2025-08-14 10:40:51 +08:00
panda a0ea91c97b 其他 2025-08-12 13:39:36 +08:00
panda e9513ea6a4 log 2025-08-12 13:36:07 +08:00
panda d2b57bb21a 数据库操作说明 2025-08-07 17:58:41 +08:00
panda b33d61c774 数据库操作说明 2025-08-06 18:01:26 +08:00
panda c8456ce3f6 数据库操作 2025-08-06 17:29:46 +08:00
panda c8d268647f 数据库操作 2025-08-06 16:24:17 +08:00
panda aa0b71a90b 通用文件读取更新 2025-08-06 14:50:27 +08:00
panda 196df754bc 通用文件读取 2025-08-06 12:33:56 +08:00
panda 40f011c66c log测试更新 2025-08-06 11:06:16 +08:00
panda 69deb0cd39 log更新 2025-08-06 10:54:08 +08:00
panda c2a941d4f5 log更新 2025-08-06 09:27:44 +08:00
panda fad2b2d1c8 md文档更新 2025-08-05 17:13:19 +08:00
panda e5da1203c0 ai初期模板 2025-08-05 15:01:04 +08:00
panda 71e9c7c5bc ai初期模板 2025-08-05 15:00:46 +08:00
66 changed files with 2022 additions and 210001 deletions
-7
View File
@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$PROJECT_DIR$/tools/SQL.sql" value="36976640-4e4b-40d7-80c5-f77ff8c735e5" />
<file url="file://$PROJECT_DIR$/tools/情报收集.sql" value="36976640-4e4b-40d7-80c5-f77ff8c735e5" />
</component>
</project>
-15
View File
@@ -1,15 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" uploadOnCheckin="ee95b41f-4bcf-4810-b328-4a7a4f66093f" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="gitea">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myUploadOnCheckinName" value="gitea" />
</component>
</project>
-12
View File
@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MaterialThemeProjectNewConfig">
<option name="metadata">
<MTProjectMetadataState>
<option name="migrated" value="true" />
<option name="pristineConfig" value="false" />
<option name="userId" value="-2834c26c:198a1f98ccf:-7ffe" />
</MTProjectMetadataState>
</option>
</component>
</project>
+1 -1
View File
@@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.13 (intelligence_system)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="intelligence_system" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="intelligence" project-jdk-type="Python SDK" />
</project>
+1 -6
View File
@@ -1,12 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/tools/SQL.sql" dialect="MySQL" />
<file url="file://$PROJECT_DIR$/tools/情报收集.sql" dialect="MySQL" />
<file url="PROJECT" dialect="MySQL" />
</component>
<component name="SqlResolveMappings">
<file url="file://$PROJECT_DIR$/utils/mysql_agent.py" scope="{&quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot;, &quot;group&quot;:{ &quot;@kind&quot;:&quot;root&quot;, &quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot; } } }}" />
<file url="file://$PROJECT_DIR$/storage/mysql_agent.py" scope="{&quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot;, &quot;group&quot;:{ &quot;@kind&quot;:&quot;root&quot;, &quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot; } } }}" />
<file url="PROJECT" scope="{&quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot;, &quot;group&quot;:{ &quot;@kind&quot;:&quot;root&quot;, &quot;node&quot;:{ &quot;@negative&quot;:&quot;1&quot; } } }}" />
</component>
</project>
-14
View File
@@ -1,14 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="WebServers">
<option name="servers">
<webServer id="ee95b41f-4bcf-4810-b328-4a7a4f66093f" name="gitea" url="">
<fileTransfer accessType="WEBDAV" port="6180">
<advancedOptions>
<advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
</advancedOptions>
</fileTransfer>
</webServer>
</option>
</component>
</project>
Binary file not shown.
Binary file not shown.
Binary file not shown.
View File
Binary file not shown.
-326
View File
@@ -1,326 +0,0 @@
import feedparser
import requests
from datetime import datetime
import pandas as pd
import os
import pickle
import time
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
from typing import Dict, List, Optional, Any
# Add the parent directory to the Python path so the project-local `utils`
# package can be imported when this file is run directly as a script.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
from utils.mysql_agent import MySQLAgent
# Database connection settings handed to MySQLAgent.
# NOTE(review): the credentials are hard-coded here; the project design notes
# ask for keys to live in the config class — confirm this copy is intentional.
local_DB_Config = {
'host': "localhost",
'user': "root",
'password': "123123",
'database': "intelligence_system",
'port': 3306,
'charset': 'utf8mb4',
'connect_timeout': 10,
'read_timeout': 30,
'write_timeout': 30,
'autocommit': True
}
# Name of the table that receives the collected RSS entries.
table_name = "collector_rss_subscriptions"
class NewsAPIClient:
    """RSS client that downloads feeds, normalises entries and stores them in MySQL.

    The client keeps a single "last update" timestamp on disk (a pickle file
    under ``<cwd>/output``) and only writes entries newer than that timestamp.
    """

    def __init__(self):
        """Bind a module-tagged logger and create the database agent."""
        self.logger = logger.bind(module="NewsAPIClient")
        self.db_agent = MySQLAgent(local_DB_Config)
        self.logger.info("新闻API客户端初始化完成,已连接到数据库")

    def _format_result(self, success: bool, message: str = "", data: Optional[Any] = None) -> Dict[str, Any]:
        """Return the uniform ``{'success', 'message', 'data'}`` result dict."""
        return {
            'success': bool(success),
            'message': str(message),
            'data': data
        }

    @staticmethod
    def _column_name(row: Any) -> str:
        """Extract the column name from one ``DESCRIBE`` row.

        Accepts both tuple rows (name at index 0) and dict rows (``'Field'``
        key), so the check works whichever cursor class the agent uses.
        """
        if isinstance(row, dict):
            return row['Field']
        return row[0]

    def verify_database(self) -> bool:
        """Check that the target table exists and has every required column.

        Returns:
            True when the table exists and all required columns are present,
            False otherwise (including when the check itself raises).
        """
        try:
            # 1. The table must exist; an empty result means it does not.
            result = self.db_agent.execute_sql(
                f"SHOW TABLES LIKE '{table_name}'",
                fetch=True
            )
            if not result:
                self.logger.error(f"表 {table_name} 不存在,请先创建表结构")
                return False

            # 2. Every required column must be present.
            desc_result = self.db_agent.execute_sql(
                f"DESCRIBE {table_name}",
                fetch=True
            )
            columns = [self._column_name(col) for col in desc_result]
            required_columns = ['文章标题', '文章链接', '文章摘要', '发布时间',
                                '来源URL', '创建时间', '更新时间']
            missing_cols = [col for col in required_columns if col not in columns]
            if missing_cols:
                self.logger.error(f"表 {table_name} 缺少必要字段:{missing_cols}")
                return False

            self.logger.info(f"数据库表结构验证通过,当前字段:{columns}")
            return True
        except Exception as e:
            # logger.exception records the traceback; the error text is passed
            # as a format argument so braces inside it cannot break formatting.
            self.logger.exception("数据库验证失败: {}", e)
            return False

    def load_last_update_time(self) -> Optional[datetime]:
        """Load the cached timestamp of the previous run.

        Returns:
            The cached ``datetime``, or None when the cache is missing,
            unreadable, or does not contain a ``datetime``.
        """
        cache_file = os.path.join(os.getcwd(), 'output', 'last_update.pkl')
        if os.path.exists(cache_file):
            try:
                # NOTE(review): pickle.load executes arbitrary code if the cache
                # file is tampered with; only safe while the file is trusted.
                with open(cache_file, 'rb') as f:
                    last_update = pickle.load(f)
                if isinstance(last_update, datetime):
                    self.logger.debug(f"加载上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
                    return last_update
                self.logger.error(f"加载上次更新时间失败: 缓存内容不是 datetime ({type(last_update).__name__})")
            except Exception as e:
                self.logger.exception("加载上次更新时间失败: {}", e)
        self.logger.debug("未找到上次更新时间缓存,将获取全部数据")
        return None

    def save_last_update_time(self, last_update: datetime) -> None:
        """Persist the timestamp of this run; failures are logged, not raised."""
        try:
            cache_dir = os.path.join(os.getcwd(), 'output')
            os.makedirs(cache_dir, exist_ok=True)
            cache_file = os.path.join(cache_dir, 'last_update.pkl')
            with open(cache_file, 'wb') as f:
                pickle.dump(last_update, f)
            self.logger.debug(f"已保存本次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")
        except Exception as e:
            self.logger.exception("保存更新时间失败: {}", e)

    def fetch_single_rss(self, url: str, timeout: int = 15) -> "Optional[feedparser.FeedParserDict]":
        """Download and parse one RSS feed, retrying up to three times.

        Args:
            url: Feed URL.
            timeout: Per-request timeout in seconds.

        Returns:
            The parsed feed, or None when all three attempts failed.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }
        for attempt in range(3):
            try:
                response = requests.get(url, headers=headers, timeout=timeout)
                response.raise_for_status()
                response.encoding = response.apparent_encoding
                feed = feedparser.parse(response.text)
                if feed.bozo:
                    # bozo flags a malformed feed; the parsed entries are still used.
                    self.logger.warning(f"解析 {url} 存在潜在问题: {feed.bozo_exception}")
                self.logger.debug(f"成功获取 {url} 的RSS数据")
                return feed
            except requests.RequestException as e:
                self.logger.warning(f"第 {attempt + 1} 次获取 {url} 失败: {str(e)}")
                if attempt < 2:
                    # Linear backoff: wait 3 s, then 6 s, before the next attempt.
                    time.sleep(3 * (attempt + 1))
                continue
        self.logger.error(f"三次尝试后仍无法获取 {url} 的RSS数据")
        return None

    def fetch_all_rss(self, urls: List[str], timeout: int = 15) -> "Dict[str, feedparser.FeedParserDict]":
        """Fetch several feeds concurrently (three worker threads).

        Returns:
            Mapping of URL to parsed feed, containing only the feeds that
            were fetched successfully.
        """
        feeds = {}
        with ThreadPoolExecutor(max_workers=3) as executor:
            future_to_url = {executor.submit(self.fetch_single_rss, url, timeout): url for url in urls}
            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    feed = future.result()
                    if feed:
                        feeds[url] = feed
                except Exception as e:
                    self.logger.exception("处理 {} 时发生异常: {}", url, e)
        self.logger.info(f"RSS源获取完成,成功获取 {len(feeds)}/{len(urls)} 个源")
        return feeds

    def process_feed_entry(self, entry: Dict[str, Any], url: str) -> Dict[str, str]:
        """Convert one feed entry into a row dict keyed by the table's column names.

        Args:
            entry: A feed entry (any mapping with ``.get``).
            url: URL of the feed the entry came from.

        Returns:
            Dict with the seven column values, all as strings.
        """
        # Title, truncated to 255 characters including the ellipsis.
        title = entry.get('title', '无标题') or '无标题'
        if len(title) > 255:
            title = title[:252] + '...'

        # Link, truncated to 1024 characters including the ellipsis.
        link = entry.get('link', '无链接') or '无链接'
        if len(link) > 1024:
            link = link[:1021] + '...'

        # Summary: use the feed's summary; when it is missing or empty fall
        # back to the first 200 characters of the first content element.
        summary = entry.get('summary') or ''
        content_list = entry.get('content') or []
        content = content_list[0].value if (content_list and hasattr(content_list[0], 'value')) else ''
        if summary and summary != '无内容摘要':
            description = summary
        elif content:
            description = content[:200] + '...'
        else:
            description = '无内容摘要'

        # Publication time: prefer the pre-parsed struct, then the raw string,
        # and finally "now" when neither can be interpreted.
        published_parsed = entry.get('published_parsed') or entry.get('updated_parsed')
        if published_parsed:
            entry_time = datetime(*published_parsed[:6])
        else:
            pub_str = entry.get('published', entry.get('updated', ''))
            try:
                entry_time = datetime.strptime(pub_str, '%a, %d %b %Y %H:%M:%S %z').replace(tzinfo=None)
            except (ValueError, TypeError):
                # Unparseable or non-string date: treat the entry as just published.
                entry_time = datetime.now()

        # Source URL, truncated like the link.
        source_url = url or '未知来源'
        if len(source_url) > 1024:
            source_url = source_url[:1021] + '...'

        # Creation and update time are both "now" for a freshly collected row.
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        return {
            '文章标题': title,
            '文章链接': link,
            '文章摘要': description,
            '发布时间': entry_time.strftime('%Y-%m-%d %H:%M:%S'),
            '来源URL': source_url,
            '创建时间': current_time,
            '更新时间': current_time
        }

    def display_feed_info(self, feed: "feedparser.FeedParserDict", last_update: Optional[datetime] = None,
                          url: Optional[str] = None) -> Optional[datetime]:
        """Write the entries of one feed that are newer than ``last_update``.

        Args:
            feed: Parsed feed (must expose ``.entries``).
            last_update: Entries published at or before this time are skipped.
            url: Feed URL, stored with every row.

        Returns:
            The newest publication time seen among the written entries, or
            ``last_update`` unchanged when nothing new was written or the
            database write failed. None when ``feed`` is empty.
        """
        if not feed:
            self.logger.warning("无法处理空的RSS源数据")
            return None

        self.logger.info(f"开始处理 RSS 源: {url}")
        entries = feed.entries
        data_list = []
        new_last_update = last_update
        for i, entry in enumerate(entries, 1):
            entry_data = self.process_feed_entry(entry, url)
            entry_time = datetime.strptime(entry_data['发布时间'], '%Y-%m-%d %H:%M:%S')
            # Skip entries that were already covered by a previous run.
            if last_update and entry_time <= last_update:
                continue
            # Track the newest publication time.
            if new_last_update is None or entry_time > new_last_update:
                new_last_update = entry_time
            self.logger.debug(f"处理条目 {i}: {entry_data['文章标题']}")
            data_list.append(entry_data)

        if data_list:
            result = self.write_to_database(pd.DataFrame(data_list))
            if not result.get('success'):
                # Do not advance the timestamp for this feed: the rows were not
                # stored and must be picked up again on the next run.
                return last_update
        return new_last_update

    def write_to_database(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Insert the rows of ``df`` into the target table.

        Returns:
            A result dict from ``_format_result``; ``success`` is False when
            the insert raised.
        """
        if df.empty:
            self.logger.info("没有新数据需要写入数据库")
            return self._format_result(True, "没有新数据需要写入")
        try:
            inserted_rows = self.db_agent.insert_from_df(
                table_name=table_name,
                df=df,
                chunk_size=500,
                ignore_duplicates=True
            )
            self.logger.info(f"成功写入 {inserted_rows}/{len(df)} 条记录")
            return self._format_result(
                True,
                f"成功写入 {inserted_rows}/{len(df)} 条记录",
                {"success_count": inserted_rows, "total": len(df)}
            )
        except Exception as e:
            # opt(exception=True) attaches the traceback; the keyword values
            # are stored as structured context on the log record.
            self.logger.opt(exception=True).error(
                "数据库写入失败",
                error=str(e),
                error_type=type(e).__name__,
                table_name=table_name,
                record_count=len(df),
                sample_records=df.head(2).to_dict('records') if not df.empty else []
            )
            return self._format_result(False, f"数据库操作失败: {str(e)}")

    @classmethod
    def main(cls):
        """Entry point: verify the table, fetch all feeds and store new entries."""
        try:
            client = cls()

            # Abort early when the table is missing or incomplete.
            if not client.verify_database():
                client.logger.error("数据库验证失败,程序终止")
                return

            # Feeds to collect.
            rss_urls = [
                "https://www.chinanews.com.cn/rss/finance.xml",
                "https://www.chinanews.com.cn/rss/world.xml",
                "https://www.chinanews.com.cn/rss/china.xml",
                "https://www.chinanews.com.cn/rss/scroll-news.xml"
            ]

            last_update = client.load_last_update_time()
            if last_update:
                client.logger.info(f"上次更新时间: {last_update.strftime('%Y-%m-%d %H:%M:%S')}")

            client.logger.info("开始获取RSS源数据...")
            start_time = time.time()
            feeds = client.fetch_all_rss(rss_urls)
            client.logger.info(f"获取完成,耗时: {time.time() - start_time:.2f}秒")

            # Process every feed and keep the newest timestamp across all of them.
            new_last_update = None
            for url, feed in feeds.items():
                current_last_update = client.display_feed_info(feed, last_update, url)
                if current_last_update and (new_last_update is None or current_last_update > new_last_update):
                    new_last_update = current_last_update

            if new_last_update:
                client.save_last_update_time(new_last_update)
                client.logger.info(f"本次最新更新时间: {new_last_update.strftime('%Y-%m-%d %H:%M:%S')}")
            else:
                client.logger.info("没有获取到新内容")
        except Exception as e:
            logger.exception("程序运行出错: {}", e)
# Run the collector only when this file is executed directly, not on import.
if __name__ == "__main__":
NewsAPIClient.main()
-44
View File
@@ -1,44 +0,0 @@
import os
class Config:
    """Static settings for the databases, object storage and AI services.

    NOTE(review): passwords and an API-key fallback are committed in this
    class; confirm these are non-production values.
    """

    # Remote MySQL instance.
    MYSQL_CONFIG = dict(
        host='123.60.167.249',
        port=3306,
        user='intelligence',
        password='123123',
        database="intelligence_system",
        max_connections=10,
    )

    # Local MySQL instance used when working offline.
    OFFLINE_MYSQL_CONFIG = dict(
        host='localhost',
        port=3306,
        user='root',
        password='123123',
        database="intelligence_system",
        max_connections=10,
    )

    # MinIO object storage; SSL is off (community edition default).
    MINIO_CONFIG = dict(
        endpoint='127.0.0.1:9005',
        access_key='admin',
        secret_key='abc88888888',
        secure=False,
    )

    # Baidu AI (Qianfan platform) settings. The API key is read from the
    # BAIDU_API_KEY environment variable first, then falls back to the literal.
    BAIDU_AI_CONFIG = dict(
        api_key=os.environ.get(
            'BAIDU_API_KEY',
            'bce-v3/ALTAK-SFA4vEP3uBYLsyqCZcERg/1f43596d40d9a2c8318b13d5888a5e8e4e7a7f30',
        ),
        model='ernie-x1-turbo-32k',
    )

    # AI processor settings: default batch size, delay in seconds between
    # records (to stay under API rate limits), source and result tables.
    AI_PROCESSOR_CONFIG = dict(
        batch_size=10,
        delay=1.5,
        source_table='processed_rss_data',
        result_table='ai_processor_rss_analysis',
    )
+273
View File
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Configuration initialisation module.
Features:
1. Generates a default configuration file automatically
2. Multi-environment configuration support (dev/test/prod)
3. Encrypted storage of sensitive values
4. Configuration completeness check and repair
"""
import os
import json
import platform
from pathlib import Path
from typing import Dict, Any, Optional
import logging
from cryptography.fernet import Fernet
import hashlib
# Initialise logging: INFO level on the root logger, plus a named logger
# used by everything in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('config_init')
class ConfigInitializer:
    """Create, validate, repair and encrypt the application's JSON configuration.

    The configuration lives in a per-user, per-platform directory together
    with the Fernet key used to encrypt sensitive fields.
    """

    # Slash-separated paths of the fields that are stored encrypted.
    _SENSITIVE_FIELDS = (
        "api/newsapi/key",
        "api/weibo/access_token",
        "network/proxy",
    )

    def __init__(self, app_name: str = "intelligence_system"):
        """Resolve the config paths and make sure the config directory exists.

        Args:
            app_name: Name of the per-application sub-directory. A falsy value
                (e.g. None) falls back to ``"intelligence_system"``.
        """
        self.system = platform.system().lower()
        self.app_name = app_name or "intelligence_system"
        self.config_dir = self._get_config_dir()
        self.config_file = self.config_dir / "config.json"
        self.secret_key_file = self.config_dir / ".secret.key"
        self._fernet = None
        # Make sure the configuration directory exists.
        self.config_dir.mkdir(parents=True, exist_ok=True)
        # Restrict the directory to the owner (not applicable on Windows).
        if self.system != 'windows':
            os.chmod(self.config_dir, 0o700)

    def _get_config_dir(self) -> Path:
        """Return the configuration directory appropriate for this platform."""
        if self.system == 'windows':
            # Fall back to the conventional location when APPDATA is unset.
            appdata = os.environ.get('APPDATA')
            base = Path(appdata) if appdata else Path.home() / "AppData" / "Roaming"
            return base / self.app_name
        elif self.system == 'darwin':  # macOS
            return Path.home() / "Library" / "Application Support" / self.app_name
        else:  # Linux and other Unix-like systems
            xdg_config = os.getenv('XDG_CONFIG_HOME', '~/.config')
            return Path(xdg_config).expanduser() / self.app_name

    def _init_encryption(self):
        """Load the Fernet key, generating and saving one on first use."""
        if not self.secret_key_file.exists():
            self.secret_key_file.write_bytes(Fernet.generate_key())
            if self.system != 'windows':
                self.secret_key_file.chmod(0o600)  # owner read/write only
        self._fernet = Fernet(self.secret_key_file.read_bytes())

    def encrypt_value(self, plaintext: str) -> str:
        """Encrypt ``plaintext`` and return the token as text."""
        if not self._fernet:
            self._init_encryption()
        return self._fernet.encrypt(plaintext.encode()).decode()

    def decrypt_value(self, ciphertext: str) -> str:
        """Decrypt a token produced by ``encrypt_value``."""
        if not self._fernet:
            self._init_encryption()
        return self._fernet.decrypt(ciphertext.encode()).decode()

    def _is_encrypted(self, value: str) -> bool:
        """Return True when ``value`` already decrypts with the current key."""
        # Imported locally: only this check needs the exception type.
        from cryptography.fernet import InvalidToken
        try:
            self.decrypt_value(value)
        except (InvalidToken, TypeError, ValueError):
            return False
        return True

    def _get_default_config(self) -> Dict[str, Any]:
        """Return a fresh copy of the default configuration template."""
        return {
            "system": {
                "env": "dev",  # dev/test/prod
                "log_level": "INFO",
                "max_threads": max(1, os.cpu_count() or 4),
                "data_dir": str(self.config_dir / "data")
            },
            "api": {
                "newsapi": {
                    "endpoint": "https://newsapi.org/v2",
                    "key": ""  # stored encrypted once set
                },
                "weibo": {
                    "version": "2",
                    "access_token": ""  # stored encrypted once set
                }
            },
            "database": {
                "type": "sqlite",
                "path": str(self.config_dir / "data.db")
            },
            "network": {
                "timeout": 30,
                "retries": 3,
                "proxy": ""  # e.g. http://user:pass@proxy:port
            }
        }

    def _migrate_old_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Move legacy keys to their current location.

        Currently migrates a top-level ``api_key`` to ``api.newsapi.key``
        while keeping any other settings already stored under
        ``api.newsapi``.
        """
        if 'api_key' in config:
            api_section = config.setdefault('api', {})
            newsapi = api_section.get('newsapi')
            if not isinstance(newsapi, dict):
                newsapi = api_section['newsapi'] = {}
            newsapi['key'] = config.pop('api_key')
        return config

    def _validate_config(self, config: Dict[str, Any]) -> bool:
        """Return True when every required section and key is present."""
        required_keys = {
            "system": ["env", "log_level"],
            "api/newsapi": ["endpoint"]
        }
        for path, keys in required_keys.items():
            current = config
            for part in path.split('/'):
                current = current.get(part, {})
                if not isinstance(current, dict):
                    return False
            for key in keys:
                if key not in current:
                    return False
        return True

    def _repair_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in missing entries from the default template (in place).

        Existing values are kept. The one exception: where the template
        expects a section (dict) but the config holds something else, the
        broken value is replaced by the default section.
        """
        default_config = self._get_default_config()

        def _merge(current, default):
            for key, value in default.items():
                if key not in current:
                    current[key] = value
                elif isinstance(value, dict):
                    if isinstance(current[key], dict):
                        _merge(current[key], value)
                    else:
                        # A scalar where a section is expected cannot be merged.
                        current[key] = value
            return current

        return _merge(config, default_config)

    def _encrypt_sensitive_fields(self, config: Dict[str, Any]) -> None:
        """Encrypt every non-empty sensitive field that is still plain text."""
        self._init_encryption()
        for field in self._SENSITIVE_FIELDS:
            *parents, leaf = field.split('/')
            current = config
            for part in parents:
                current = current.setdefault(part, {})
            value = current.get(leaf)
            # Skip empty values and values encrypted by an earlier run, so
            # repeated initialisation does not stack encryption layers.
            if isinstance(value, str) and value and not self._is_encrypted(value):
                current[leaf] = self.encrypt_value(value)

    def init_config(self, force: bool = False) -> bool:
        """Create the configuration file, or migrate/repair an existing one.

        The file is (re)written in both cases.

        Args:
            force: Discard any existing file and regenerate the defaults.

        Returns:
            True when a new configuration was created, False when an
            existing one was loaded (and possibly migrated or repaired).
        """
        config = None
        # Reuse the existing file unless a reset was requested.
        if self.config_file.exists() and not force:
            try:
                with open(self.config_file, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                config = self._migrate_old_config(config)
                if not self._validate_config(config):
                    config = self._repair_config(config)
                    logger.warning("自动修复不完整的配置文件")
            except Exception as e:
                # An unreadable file is replaced by a fresh default config.
                logger.error(f"加载现有配置失败: {str(e)}")
                config = None

        created = config is None
        if created:
            config = self._get_default_config()
            logger.info("创建新的配置文件")

        self._encrypt_sensitive_fields(config)

        with open(self.config_file, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, ensure_ascii=False)
        # Owner read/write only (not applicable on Windows).
        if self.system != 'windows':
            os.chmod(self.config_file, 0o600)
        return created

    def get_config_hash(self) -> str:
        """Return the SHA-256 hex digest of the config file ('' if missing)."""
        if not self.config_file.exists():
            return ""
        with open(self.config_file, 'rb') as f:
            return hashlib.sha256(f.read()).hexdigest()

    def create_env_specific_config(self, env: Optional[str] = None) -> bool:
        """Write ``config.<env>.json`` with the overrides for one environment.

        Args:
            env: Environment name (dev/test/prod). Defaults to the ``env``
                recorded in the base configuration.

        Returns:
            Always True.
        """
        if not self.config_file.exists():
            self.init_config()
        with open(self.config_file, 'r', encoding='utf-8') as f:
            base_config = json.load(f)
        env = env or base_config['system']['env']
        env_config = {
            f"env_{env}": {
                "api": {
                    "newsapi": {"endpoint": self._get_env_endpoint(env)}
                },
                "database": {
                    "path": str(self.config_dir / f"data_{env}.db")
                }
            }
        }
        env_file = self.config_dir / f"config.{env}.json"
        with open(env_file, 'w', encoding='utf-8') as f:
            json.dump(env_config, f, indent=2)
        return True

    def _get_env_endpoint(self, env: str) -> str:
        """Return the API endpoint for ``env`` (unknown names map to dev)."""
        endpoints = {
            "dev": "http://dev-api.example.com",
            "test": "https://test-api.example.com",
            "prod": "https://api.example.com"
        }
        return endpoints.get(env, endpoints['dev'])
# Convenience wrapper around ConfigInitializer.
def init_app_config(app_name: Optional[str] = None, force: bool = False) -> bool:
    """Initialise the application configuration in one call.

    Args:
        app_name: Application name; when omitted, ConfigInitializer's own
            default name is used.
        force: Regenerate the configuration even if a file already exists.

    Returns:
        Whatever ``ConfigInitializer.init_config`` returns.
    """
    # Forwarding None would override the class default and break the path join.
    initializer = ConfigInitializer(app_name) if app_name else ConfigInitializer()
    return initializer.init_config(force)
# Manual smoke test, run only when the module is executed directly.
if __name__ == "__main__":
# Initialise the configuration
initializer = ConfigInitializer()
if initializer.init_config():
print("配置文件已生成:", initializer.config_file)
# Example: create the environment-specific configuration for "prod"
initializer.create_env_specific_config("prod")
print("生产环境配置已生成")
# Encryption round-trip demonstration
encrypted = initializer.encrypt_value("my_secret_key")
print("加密示例:", encrypted)
print("解密测试:", initializer.decrypt_value(encrypted))
+409
View File
@@ -0,0 +1,409 @@
import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入您的日志系统
from utils.logger import log as logger
class MySQLAgent:
    """Process-wide MySQL helper built on a pymysql connection pool.

    The class is a singleton: the first construction creates the pool and
    later constructions return the same object and ignore their ``config``
    argument. Connection defaults differ per platform (Windows/macOS/Linux).
    """

    _instance = None
    _lock = threading.Lock()

    # Per-platform defaults. ``socket_timeout`` is not a pymysql.connect()
    # argument; it is translated to read_timeout/write_timeout before use.
    PLATFORM_CONFIG = {
        'Windows': {
            'socket_timeout': 30,
            'connect_timeout': 10,
            'ssl': None
        },
        'Darwin': {  # macOS
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
        },
        'Linux': {
            'socket_timeout': 60,
            'connect_timeout': 15,
            'ssl': None
        }
    }

    def __new__(cls, *args, **kwargs):
        # Create the single instance once; the lock plus the second check
        # keeps two threads from both creating it.
        if not cls._instance:
            with cls._lock:
                if not cls._instance:
                    cls._instance = super().__new__(cls)
        return cls._instance

    @classmethod
    def _platform_connect_kwargs(cls, config: dict, platform_name: str) -> dict:
        """Build the platform-dependent pymysql.connect() keyword arguments.

        Platform defaults come first; ``connect_timeout``, ``read_timeout``
        and ``write_timeout`` supplied in ``config`` override them.
        """
        options = dict(cls.PLATFORM_CONFIG.get(platform_name, {}))
        socket_timeout = options.pop('socket_timeout', None)
        if socket_timeout is not None:
            options.setdefault('read_timeout', socket_timeout)
            options.setdefault('write_timeout', socket_timeout)
        for key in ('connect_timeout', 'read_timeout', 'write_timeout'):
            if key in config:
                options[key] = config[key]
        return options

    def __init__(self, config: dict = None):
        """Create the connection pool on first construction.

        Args:
            config: Connection settings (host, port, user, password,
                database, charset, max_connections and optional timeouts).
                When omitted, ``config.settings.DATABASE_CONFIG`` is used.
        """
        # The singleton is initialised only once.
        if getattr(self, '_pool', None):
            return
        if not config:
            from config.settings import DATABASE_CONFIG
            config = DATABASE_CONFIG

        current_platform = platform.system()
        # Bind the logger first: _create_pool() logs through self.logger.
        self.logger = logger.bind(module=f"MySQLAgent({current_platform})")

        self.config = {
            'host': config.get('host', 'localhost'),
            'port': config.get('port', 3306),
            'user': config.get('user', 'root'),
            'password': config.get('password', ''),
            'database': config.get('database', 'intelligence_system'),
            'charset': config.get('charset', 'utf8mb4'),
            'cursorclass': cursors.DictCursor,
            'autocommit': True,
            **self._platform_connect_kwargs(config, current_platform)
        }
        if current_platform == 'Windows':
            self.config['ssl'] = None  # no SSL settings are used on Windows
        elif current_platform == 'Darwin':
            # Disable SSL when the expected CA bundle is not installed.
            ssl_options = self.config.get('ssl')
            if ssl_options and not os.path.exists(ssl_options.get('ca', '')):
                self.config['ssl'] = None
                self.logger.warning("macOS SSL certificate not found, disabling SSL")

        self.pool_size = config.get('max_connections', 5)
        self._pool = self._create_pool()

    def _create_pool(self) -> "PooledDB":
        """Create the PooledDB connection pool from ``self.config``."""
        try:
            pool_config = {
                'creator': pymysql,
                'maxconnections': self.pool_size,
                'mincached': 1,
                'maxcached': 3,
                'blocking': True,
                'ping': 1,  # check the connection whenever it is taken from the pool
                **self.config
            }
            if platform.system() == 'Windows':
                pool_config['ping'] = 0  # liveness check disabled on Windows
            pool = PooledDB(**pool_config)
            self.logger.info(f"Connection pool created for {platform.system()}")
            return pool
        except Exception as e:
            self.logger.critical("Failed to create connection pool",
                                 error=str(e),
                                 exc_info=True)
            raise

    def _handle_path(self, path: str) -> str:
        """Convert forward slashes to backslashes on Windows."""
        if platform.system() == 'Windows':
            return path.replace('/', '\\')
        return path

    def get_connection(self) -> "pymysql.connections.Connection":
        """Take a connection from the pool.

        Returns:
            A pooled connection; calling ``close()`` returns it to the pool.

        Raises:
            Exception: whatever the pool raised when no connection could be
                obtained (after retries on Windows timeouts).
        """
        try:
            conn = self._pool.connection()
            # With SSL on macOS, make sure the link is alive before use.
            if platform.system() == 'Darwin' and self.config.get('ssl'):
                conn.ping(reconnect=True)
            self.logger.trace("Connection obtained")
            return conn
        except Exception as e:
            error_msg = str(e)
            if platform.system() == 'Windows' and "timed out" in error_msg:
                self.logger.warning("Windows connection timeout, retrying...")
                return self._retry_connection()
            self.logger.error("Connection failed",
                              error=error_msg,
                              exc_info=True)
            raise

    def _retry_connection(self, max_retries: int = 3) -> "pymysql.connections.Connection":
        """Retry taking a connection, sleeping one second between attempts."""
        import time
        for attempt in range(max_retries):
            try:
                conn = self._pool.connection()
                self.logger.info(f"Connection established after {attempt+1} attempts")
                return conn
            except Exception:
                if attempt == max_retries - 1:
                    raise
                time.sleep(1)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Backtick-quote a table or column name for MySQL."""
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _build_insert_sql(table_name: str, columns: List[str], ignore_duplicates: bool = False) -> str:
        """Build a parameterised ``INSERT [IGNORE] INTO ... VALUES (%s, ...)``."""
        verb = 'INSERT IGNORE' if ignore_duplicates else 'INSERT'
        column_sql = ', '.join(MySQLAgent._quote_identifier(col) for col in columns)
        placeholders = ', '.join(['%s'] * len(columns))
        return f"{verb} INTO {MySQLAgent._quote_identifier(table_name)} ({column_sql}) VALUES ({placeholders})"

    @staticmethod
    def _rows_for_insert(chunk: pd.DataFrame) -> List[tuple]:
        """Convert a DataFrame slice to row tuples, mapping NaN/NaT to None."""
        prepared = chunk.astype(object).where(pd.notna(chunk), None)
        return [tuple(row) for row in prepared.itertuples(index=False, name=None)]

    @staticmethod
    def _apply_parse_dates(df: pd.DataFrame, parse_dates: Any) -> pd.DataFrame:
        """Convert the requested columns to datetimes.

        A list/tuple names the columns to convert; a dict maps column name to
        a strftime format. Booleans and None leave the frame unchanged.
        """
        if isinstance(parse_dates, dict):
            for col, fmt in parse_dates.items():
                if col in df.columns:
                    df[col] = pd.to_datetime(df[col], format=fmt, errors='coerce')
        elif isinstance(parse_dates, (list, tuple)):
            for col in parse_dates:
                if col in df.columns:
                    df[col] = pd.to_datetime(df[col], errors='coerce')
        return df

    def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                    parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
        """Run a query and return the result as a DataFrame.

        The query is executed through a cursor and the frame is built from
        the fetched rows; the pool uses ``DictCursor``, whose dict rows
        pandas' DBAPI fallback does not handle.

        Args:
            sql: SQL statement, with ``%s`` / ``%(name)s`` placeholders.
            params: Values for the placeholders.
            parse_dates: Columns to convert to datetimes (see
                ``_apply_parse_dates``); True/False means no conversion.

        Returns:
            The query result.
        """
        try:
            with self.get_connection() as conn:
                cursor = conn.cursor()
                try:
                    # Kept from the original: raise the session's idle timeout
                    # on non-Windows platforms.
                    if platform.system() != 'Windows':
                        cursor.execute("SET SESSION wait_timeout=600")
                    cursor.execute(sql, params)
                    rows = cursor.fetchall()
                    columns = [col[0] for col in (cursor.description or [])]
                finally:
                    cursor.close()
            df = pd.DataFrame(list(rows), columns=columns)
            df = self._apply_parse_dates(df, parse_dates)
            self.logger.info("Query executed", rows=len(df))
            return df
        except Exception as e:
            self.logger.error("Query failed",
                              sql=sql,
                              params=params,
                              error=str(e),
                              exc_info=True)
            raise

    def insert_from_df(self, table_name: str, df: pd.DataFrame,
                       chunk_size: int = 1000, replace: bool = False,
                       ignore_duplicates: bool = False) -> int:
        """Insert the rows of a DataFrame into an existing table.

        All chunks are written in one transaction: either every row is
        stored or none is.

        Args:
            table_name: Target table; column names are taken from ``df``.
            df: Rows to insert.
            chunk_size: Rows per ``executemany`` batch (capped at 500 on
                Windows).
            replace: Replace the table's contents: existing rows are deleted
                (inside the same transaction) before inserting.
            ignore_duplicates: Use ``INSERT IGNORE`` so rows that violate a
                unique key are skipped instead of failing the insert.

        Returns:
            Number of rows the server reports as inserted.
        """
        if df.empty:
            self.logger.warning("Empty DataFrame", table=table_name)
            return 0
        try:
            if platform.system() == 'Windows':
                chunk_size = min(chunk_size, 500)  # smaller batches on Windows
            insert_sql = self._build_insert_sql(table_name, list(df.columns), ignore_duplicates)
            total_rows = 0
            with self.get_connection() as conn:
                cursor = conn.cursor()
                # Explicit transaction, independent of the autocommit setting.
                conn.begin()
                try:
                    if replace:
                        cursor.execute(f"DELETE FROM {self._quote_identifier(table_name)}")
                    for start in range(0, len(df), chunk_size):
                        chunk = df.iloc[start:start + chunk_size]
                        cursor.executemany(insert_sql, self._rows_for_insert(chunk))
                        total_rows += cursor.rowcount
                    conn.commit()
                except Exception:
                    conn.rollback()
                    raise
                finally:
                    cursor.close()
            self.logger.info("Data inserted", table=table_name, rows=total_rows)
            return total_rows
        except Exception as e:
            self.logger.error("Insert failed",
                              table=table_name,
                              error=str(e),
                              exc_info=True)
            raise

    def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                    fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
        """Execute one SQL statement.

        Args:
            sql: SQL statement.
            params: Values for the placeholders.
            fetch: Return the fetched rows instead of the affected-row count.

        Returns:
            The fetched rows when ``fetch`` is True, otherwise the number of
            affected rows.
        """
        conn = None
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            # NOTE(review): max_execution_time is a MySQL server variable;
            # confirm the target server supports it.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")
            cursor.execute(sql, params)
            if fetch:
                result = cursor.fetchall()
                self.logger.debug("Query executed", rows=len(result))
                return result
            else:
                affected_rows = cursor.rowcount
                self.logger.debug("Update executed", affected_rows=affected_rows)
                return affected_rows
        except Exception as e:
            self.logger.error("SQL execution failed",
                              sql=sql,
                              params=params,
                              error=str(e),
                              exc_info=True)
            raise
        finally:
            if cursor:
                cursor.close()
            if conn:
                conn.close()

    def begin_transaction(self) -> "pymysql.connections.Connection":
        """Take a connection with autocommit off and return it.

        The caller must finish with ``commit_transaction`` or
        ``rollback_transaction``, which also release the connection.
        """
        try:
            conn = self.get_connection()
            conn.autocommit(False)
            if platform.system() == 'Darwin':
                cursor = conn.cursor()
                try:
                    cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
                finally:
                    cursor.close()
            self.logger.debug("Transaction started")
            return conn
        except Exception as e:
            self.logger.error("Begin transaction failed", error=str(e))
            raise

    def _release_transaction_connection(self, conn: "pymysql.connections.Connection") -> None:
        """Restore autocommit and hand the connection back to the pool.

        Without the reset the pooled connection would keep autocommit off
        and silently drop the writes of its next user.
        """
        try:
            conn.autocommit(True)
        except Exception as e:
            self.logger.warning("Failed to restore autocommit", error=str(e))
        finally:
            conn.close()

    def commit_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Commit and release a connection from ``begin_transaction``."""
        try:
            conn.commit()
            self.logger.debug("Transaction committed")
        except Exception as e:
            self.logger.error("Commit failed", error=str(e))
            raise
        finally:
            self._release_transaction_connection(conn)

    def rollback_transaction(self, conn: "pymysql.connections.Connection") -> None:
        """Roll back and release a connection; rollback errors are only logged."""
        try:
            conn.rollback()
            self.logger.warning("Transaction rolled back")
        except Exception as e:
            self.logger.error("Rollback failed", error=str(e))
        finally:
            self._release_transaction_connection(conn)

    def __del__(self):
        """Close the pool when the agent is garbage-collected."""
        pool = getattr(self, '_pool', None)
        if not pool:
            return
        # The logger may be missing if __init__ failed part-way through.
        log = getattr(self, 'logger', None)
        try:
            pool.close()
            if log:
                log.info("Connection pool closed")
        except Exception as e:
            if log:
                log.error("Failed to close pool", error=str(e))
# Platform-specific default configuration
def get_default_config():
    """Return default connection settings tuned for the current platform.

    Every platform shares the same host/port/user/password/database/pool
    size; only the timeouts differ, and macOS additionally gets an ``ssl``
    entry pointing at the OpenSSL CA bundle.
    """
    system_name = platform.system()
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '',
        'database': 'intelligence_system',
        'max_connections': 5
    }
    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings
    # macOS, Linux and every other platform share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
# Usage example, run only when the module is executed directly.
if __name__ == "__main__":
# Pick the default configuration for the current platform
config = get_default_config()
# Create the database agent (this opens the connection pool)
db = MySQLAgent(config)
# Test query: print the server version and the local platform
try:
df = db.query_to_df("SELECT VERSION() as version")
print(f"Database version: {df['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
except Exception as e:
print(f"Error: {str(e)}")
+120
View File
@@ -0,0 +1,120 @@
## 情报收集系统设计
### 参考文档
https://alidocs.dingtalk.com/i/nodes/NZQYprEoWoexdo1ohPdxXvDbJ1waOeDk?utm_scene=team_space
### 程序框架
```text
intelligence_system/
├── config/ # 系统配置中心
│ ├── __init__.py # 配置包初始化
│ ├── settings.py # 主配置文件(数据库连接、API密钥等)
│ └── scheduler_rules.yaml # 任务调度规则
├── data_collection/ # 数据采集层
│ ├── spiders/ # 网络爬虫子系统
│ │ ├── weibo_spider.py # 黑猫爬虫
│ │
│ ├── api_integration/ # API接口子系统
│ │ ├── news_api.py # 新闻接口
│ │
│ └── internal/ # 内部数据收集
│ ├── jian_dao_cloud.py # 简道云表单收集器
├── data_processing/ # 数据处理层
│ ├── structured/ # 结构化数据处理
│ │ ├── data_cleaner.py # 数据清洗(去重/标准化)
│ │ └── schema_mapper.py # 数据结构转换器
│ │
│ ├── unstructured/ # 非结构化数据处理
│ │ ├── text_parser.py # 文本解析(PDF/HTML等)
│ │ ├── image_analyzer.py # 图像识别(OpenCV集成)
│ │ └── video_processor.py # 音视频分离分析
│ │
│ └── ai_engine/ # AI分析核心
│ ├── nlp_processor.py # 自然语言处理引擎
│ ├── sentiment_analyzer.py # 情感分析模型
│ └── topic_modeler.py # LDA主题建模工具
├── storage/ # 数据存储层
│ ├── mysql_agent.py # MySQL读写管理器
│ └── query_builder.py # SQL动态构建器
├── services/ # 应用服务层
│ ├── monitoring/ # 舆情监控
│ │ ├── opinion_monitor.py # 实时舆情追踪
│ │ └── brand_reputation.py # 品牌口碑分析
│ │
│ ├── analysis/ # 竞品分析
│ │ ├── competitor_tracker.py # 竞品动态监控
│ │ └── swot_generator.py # SWOT分析报告
│ │
│ ├── reporting/ # 报告服务
│ │ ├── daily_reporter.py # 自动化日报生成
│ │ └── weekly_digest.py # 周报汇编系统
│ │
│ └── alert/ # 预警服务
│ ├── alert_trigger.py # 动态阈值告警
│ └── notification_center.py # 邮件/短信通知
├── system_management/ # 系统管理层
│ ├── scheduler/ # 任务调度
│ │ └── task_scheduler.py # 任务调度器
│ │
│ └── monitor/ # 系统监控
│ ├── health_monitor.py # 服务健康检测
│ └── performance_watcher.py # 资源占用监控
├── utils/ # 工具库
│ ├── file_handler.py # 通用文件操作
│ ├── logger.py # 日志系统
│ └── datetime_parser.py # 时间格式处理
└── main.py # 系统入口(启动所有服务)
```
### 程序设计原则
1. 所有程序尽可能在py文件中运行,尽量避免使用命令行执行
2. 配置需要在配置类中定义
3. 密钥等信息直接放在配置类中
### 主程序设计
主程序需要一次启动,一直运行,启动时运行一次(在代码中可取消),之后每天定时生成一次报告
主程序包含爬虫/api调度器。该调度器通过查询mysql中的任务调度表(main_task)按需执行,该表应包含任务名称、
任务路径、任务执行频率(支持按天、按周、按分钟)、上次执行时间、下次执行时间等信息
主程序应包含数据处理调度器,根据数据类别分别处理,如文本数据处理调度器、图片数据处理调度器等,
每天定时拉取db获取到的原始数据,分别进行处理,处理完成后将结果保存到mysql中
主程序应包含日报、周报等生成,根据时间定时生成报告,报告需要存储
### 日志设计
日志系统应兼容多个平台,如win、mac和linux,日志需要保存为log文件,并且在日志大于20mb时自动压缩
### 数据库链接设计
数据存储放在数据库中,数据库类型为mysql,数据库名称为intelligence_system
数据库表的命名规则与目录一致:数据采集类以collector_为开头,数据处理类以processor_为开头,
数据存储类以storage_为开头,应用层类以application_为开头,
依此类推。
数据库链接为通用配置,要求数据采集或处理类等,可以直接调用封装好的数据库
链接,不必每次都重新写,
该链接包含表的增删改查功能,以及执行sql语句功能
数据库结构:
1. collector_news_api:新闻api数据表
2. collector_complaint_spider:投诉数据表
3. processor_text_processor:文本处理数据表
4. processor_image_processor:图片处理数据表
5. main_task 任务调度表
6. application_reporter_daily:日报数据表
7. application_reporter_weekly:周报数据表
### 数据采集设计
每一个数据采集均为独立python文件,里面执行主程序均为main,以方便调度
每一个数据采集均会根据规则创建数据库表,数据采集类的表以collector_为开头(或者统一维护到一个表中,按来源去区分)
### 数据处理
从多个数据库库表中获取数据,对数据进行处理,处理完成后将结果保存到数据库中,处理结果可能存储在多个表中
数据处理数据库表以processor_为开头
-27
View File
@@ -1,27 +0,0 @@
# 列出所有任务
python system_management/scheduler/task_management.py list
# 只显示活跃任务
python system_management/scheduler/task_management.py list --active-only
# 查看任务详情
python system_management/scheduler/task_management.py show 1
# 更新任务Cron表达式
python system_management/scheduler/task_management.py update <task_id> --cron "0 10 * * *"
# 启用任务
python system_management/scheduler/task_management.py toggle 1 --activate
# 禁用任务
python system_management/scheduler/task_management.py toggle 1 --deactivate
# 手动执行任务
python system_management/scheduler/task_management.py run 1
# 添加新任务
python system_management/scheduler/task_management.py add \
--name "hourly_data_check" \
--type "processor" \
--module "processors.data_checker" \
--cron "0 * * * *"
-292
View File
@@ -1,292 +0,0 @@
# 对象存储数据库操作.md
## 1. 类概述
`MinIOAgent` 是一个全平台兼容的对象存储操作类,支持 Windows/macOS/Linux 系统,提供对象存储的桶管理、对象操作、权限控制等功能。
### 核心特性:
- ✅ 连接池管理与自动重连
- ✅ 全平台兼容的对象操作接口
- ✅ 支持大文件分块上传/下载
- ✅ 预签名 URL 生成(临时访问)
- ✅ 完善的日志记录与错误处理
- ✅ 批量操作与前缀筛选
---
## 2. 初始化配置
### 基本配置参数
```python
Config = {
'endpoint': '127.0.0.1:9005', # 对象存储服务地址
'access_key': 'minioadmin', # 访问密钥
'secret_key': 'minioadmin', # 密钥
'secure': False, # 是否启用SSL(社区版默认False)
'region': 'us-east-1', # 区域(默认值)
'timeout': 300, # 超时时间(秒)
'max_pool_connections': 10 # 连接池最大连接数
}
```
### 各平台特殊配置
| 平台 | 超时设置(秒) | 分块大小建议 | 并发数建议 |
|---------|----------------|--------------|------------|
| Windows | 300 | 5MB-10MB | 2-4 |
| macOS | 300 | 10MB-20MB | 4-8 |
| Linux | 300 | 20MB-50MB | 8-16 |
### 初始化示例
```python
from utils.minio_agent import MinIOAgent
# 基础初始化
config = {
'endpoint': '127.0.0.1:9005',
'access_key': 'minioadmin',
'secret_key': 'minioadmin',
'secure': False
}
# 创建客户端实例
minio_client = MinIOAgent(config)
```
---
## 3. 桶(Bucket)管理
### 桶操作
```python
# 创建桶
if minio_client.create_bucket('my-bucket'):
print("桶创建成功")
# 检查桶是否存在
if minio_client.bucket_exists('my-bucket'):
print("桶已存在")
# 列出所有桶
buckets = minio_client.list_buckets()
for bucket in buckets:
print(f"桶名称: {bucket['name']}, 创建时间: {bucket['creation_date']}")
# 删除桶(需先清空桶内对象)
if minio_client.delete_bucket('my-bucket'):
print("桶删除成功")
```
### 桶策略管理
```python
# 获取桶策略
policy = minio_client.get_bucket_policy('my-bucket')
print(policy)
# 设置公共读策略
public_read_policy = {
"Version": "2012-10-17",
"Statement": [{
"Effect": "Allow",
"Principal": "*",
"Action": ["s3:GetObject"],
"Resource": ["arn:aws:s3:::my-bucket/*"]
}]
}
minio_client.set_bucket_policy('my-bucket', public_read_policy)
```
---
## 4. 对象(Object)操作
### 上传对象
```python
# 从文件上传
upload_meta = minio_client.upload_file(
bucket_name='my-bucket',
object_name='documents/report.pdf',
file_path='/local/path/to/report.pdf',
content_type='application/pdf' # MIME类型
)
print(f"上传成功,大小: {upload_meta['size']} bytes")
# 从字节流上传
data = b"test data"
upload_meta = minio_client.upload_bytes(
bucket_name='my-bucket',
object_name='test/data.bin',
data=data
)
# 大文件分块上传
upload_meta = minio_client.upload_large_file(
bucket_name='my-bucket',
object_name='videos/large_file.mp4',
file_path='/local/path/to/large.mp4',
part_size=5*1024*1024 # 5MB分块
)
```
### 下载对象
```python
# 下载到文件
download_meta = minio_client.download_file(
bucket_name='my-bucket',
object_name='documents/report.pdf',
file_path='/local/save/path/report.pdf'
)
# 下载为字节流
data = minio_client.download_bytes(
bucket_name='my-bucket',
object_name='test/data.bin'
)
print(f"下载数据: {data}")
```
### 查询与列举对象
```python
# 列举桶内所有对象
objects = minio_client.list_objects('my-bucket')
for obj in objects:
print(f"对象: {obj['object_name']}, 大小: {obj['size']}")
# 按前缀筛选(类似文件夹)
pdf_files = minio_client.list_objects(
bucket_name='my-bucket',
prefix='documents/', # 前缀(类似文件夹路径)
recursive=False # 是否递归查询子目录
)
# 获取对象元信息
meta = minio_client.get_object_metadata(
bucket_name='my-bucket',
object_name='documents/report.pdf'
)
print(f"内容类型: {meta['content_type']}, 最后修改: {meta['last_modified']}")
```
### 删除对象
```python
# 删除单个对象
if minio_client.delete_object('my-bucket', 'test/data.bin'):
print("对象删除成功")
# 批量删除对象
delete_count = minio_client.delete_objects(
bucket_name='my-bucket',
object_names=['file1.txt', 'file2.txt', 'docs/report.pdf']
)
print(f"成功删除 {delete_count} 个对象")
```
---
## 5. 高级功能
### 预签名 URL(临时访问)
```python
# 生成下载预签名URL(有效期30分钟)
download_url = minio_client.get_presigned_url(
bucket_name='my-bucket',
object_name='documents/report.pdf',
expires=1800, # 有效期(秒)
method='GET' # 访问方法(GET下载,PUT上传)
)
print(f"临时下载链接: {download_url}")
# 生成上传预签名URL(允许客户端直接上传)
upload_url = minio_client.get_presigned_url(
bucket_name='my-bucket',
object_name='user_uploads/image.jpg',
expires=3600,
method='PUT'
)
```
### 批量操作
```python
# 批量复制对象(同桶内)
copy_results = minio_client.copy_objects(
source_bucket='my-bucket',
dest_bucket='my-bucket',
object_mapping={
'documents/report.pdf': 'archive/report_2024.pdf',
'data/raw.csv': 'data/backup/raw_2024.csv'
}
)
# 批量移动对象(跨桶)
move_results = minio_client.move_objects(
source_bucket='my-bucket',
dest_bucket='archive-bucket',
object_prefix='2023/' # 移动所有以2023/为前缀的对象
)
```
### 生命周期管理
```python
# 设置对象生命周期规则(自动迁移/删除)
rule = {
"Rules": [{
"ID": "archive-old-files",
"Status": "Enabled",
"Prefix": "logs/",
"Expiration": {
"Days": 90 # 90天后自动删除
},
"Transition": {
"Days": 30, # 30天后迁移到低频存储
"StorageClass": "STANDARD_IA"
}
}]
}
minio_client.set_bucket_lifecycle('my-bucket', rule)
```
---
## 6. 异常处理
```python
from minio.error import S3Error
try:
# 尝试上传对象
minio_client.upload_file(
bucket_name='my-bucket',
object_name='critical/data.csv',
file_path='/local/data.csv'
)
except S3Error as e:
if e.code == 'NoSuchBucket':
print("桶不存在,创建后重试")
minio_client.create_bucket('my-bucket')
elif e.code == 'AccessDenied':
print("权限不足,请检查密钥")
else:
print(f"上传失败: {e}")
except Exception as e:
print(f"发生错误: {str(e)}")
```
---
## 7. 性能优化建议
1. **大文件处理**
- 超过100MB的文件建议使用分块上传(`upload_large_file`)
- 根据网络状况调整分块大小(5-50MB)
2. **批量操作**
- 列举对象时使用前缀筛选减少返回数据量
- 批量删除/复制时单次操作不超过1000个对象
3. **缓存策略**
- 对频繁访问的对象使用预签名URL并设置合理过期时间
- 客户端缓存对象元数据减少请求次数
4. **并发控制**
- 多线程操作时控制并发数(参考平台建议值)
- 避免同时对同一对象进行写操作
-2
View File
@@ -1,2 +0,0 @@
## 开发进度
###
+1 -1
View File
@@ -28,7 +28,7 @@
### 基本配置参数
```python
Config = {
{
'host': 'localhost', # 数据库主机
'port': 3306, # 端口
'user': 'root', # 用户名
+1 -1
View File
@@ -5,7 +5,7 @@
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/logs" />
</content>
<orderEntry type="jdk" jdkName="intelligence_system" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="intelligence" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+108 -133929
View File
File diff suppressed because it is too large Load Diff
-71015
View File
File diff suppressed because it is too large Load Diff
+74 -97
View File
@@ -1,134 +1,111 @@
# main.py
import signal
import time
from datetime import datetime
from system_management.scheduler.task_scheduler import TaskScheduler
from utils.logger import CrossPlatformLog
from config import Config
# 初始化日志
log = CrossPlatformLog.get_logger("Main")
class IntelligenceSystem:
def __init__(self, db_config=None, run_all_on_startup=False):
"""初始化系统(仅作为容器,不包含业务逻辑)
Args:
db_config: 数据库配置
run_all_on_startup: 启动时是否立即执行所有到期任务(默认False)
"""
self.scheduler = TaskScheduler(Config.MYSQL_CONFIG, max_workers=5)
def __init__(self, db_config=None):
self.scheduler = TaskScheduler(db_config)
self._running = False
self.run_all_on_startup = run_all_on_startup
log.info(f"情报系统已初始化(Cron模式),启动时执行任务: {run_all_on_startup}")
log.info("IntelligenceSystem initialized")
def start(self):
"""启动系统主入口"""
def run(self):
"""启动系统主循环"""
self._running = True
self._setup_signal_handlers()
log.info("系统启动 - 运行在Cron调度模式")
# 启动时执行所有到期任务(如果开关开启)
if self.run_all_on_startup:
print(f"\n{'='*60}")
print("🚀 启动时执行所有到期任务...")
print(f"{'='*60}\n")
log.info("启动时执行所有到期任务")
result = self.scheduler.check_and_run_tasks(print_empty_status=True)
print(f"\n启动任务执行完成: 总数={result['总任务数']}, 成功={result['成功']}, 失败={result['失败']}\n")
self._register_signal_handlers()
# 时间追踪变量
last_status_print_time = time.time() # 上次打印状态的时间
last_hourly_report_time = time.time() # 上次小时统计的时间
status_print_interval = 60 # 每分钟打印一次状态(60秒)
hourly_report_interval = 3600 # 每小时统计一次(3600秒)
log.info("Starting main loop")
try:
# 主循环 - 仅负责定期检查任务
while self._running:
current_time = time.time()
# 判断是否需要打印状态(每分钟一次)
should_print_status = (current_time - last_status_print_time) >= status_print_interval
# 检查并执行到期任务
self.scheduler.check_and_run_tasks(print_empty_status=should_print_status)
# 更新最后打印时间
if should_print_status:
last_status_print_time = current_time
# 检查是否需要进行小时统计(每小时一次)
if (current_time - last_hourly_report_time) >= hourly_report_interval:
self._print_hourly_stats()
last_hourly_report_time = current_time
start_time = time.time()
self._run_cycle()
# 短间隔轮询(每10秒检查一次,保证Cron时间精度
time.sleep(10)
# 精确控制循环间隔(扣除执行时间
elapsed = time.time() - start_time
sleep_time = max(0, 60 - elapsed)
time.sleep(sleep_time)
except KeyboardInterrupt:
log.info("Received keyboard interrupt")
except Exception as e:
log.critical("系统主循环崩溃", exc_info=True)
log.critical(
"System crashed",
exc_info=True
)
raise
finally:
self.shutdown()
def _setup_signal_handlers(self):
"""设置系统信号处理器"""
def _run_cycle(self):
"""单个运行周期"""
try:
# 1. 执行任务调度
result = self.scheduler.run_pending_tasks()
# 2. 每小时记录系统状态
if datetime.now().minute == 0:
self._log_system_status()
except Exception as e:
log.error(
"Cycle execution failed",
exc_info=True
)
raise
def _log_system_status(self):
"""记录系统状态"""
try:
status_df = self.scheduler.get_task_status()
pending = len(status_df[status_df['next_run_time'] <= datetime.now()])
log.info(
"System status",
pending_tasks=pending,
active_tasks=len(status_df),
last_success=status_df['last_run_time'].max()
)
except Exception as e:
log.error(
"Failed to log system status",
exc_info=True
)
def _register_signal_handlers(self):
"""注册信号处理"""
signal.signal(signal.SIGINT, self._handle_shutdown)
signal.signal(signal.SIGTERM, self._handle_shutdown)
log.debug("信号处理器已注册")
log.debug("Signal handlers registered")
def _handle_shutdown(self, signum, frame):
"""处理系统关闭信号"""
log.info(f"收到关闭信号 {signum},开始关闭系统")
"""处理关闭信号"""
log.info(
f"Processing shutdown signal {signum}",
signal=signum
)
self._running = False
def _print_hourly_stats(self):
"""打印并重置小时统计信息"""
stats = self.scheduler.get_and_reset_hourly_stats()
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"\n{'='*60}")
print(f"📊 小时任务统计报告 - {now}")
print(f"{'='*60}")
print(f" 总任务数: {stats['总数']}")
print(f" 成功: {stats['成功']}")
print(f" 失败: {stats['失败']}")
if stats['总数'] > 0:
success_rate = (stats['成功'] / stats['总数']) * 100
print(f" 成功率: {success_rate:.1f}%")
print(f"{'='*60}\n")
log.info(
"小时任务统计",
总任务数=stats['总数'],
成功=stats['成功'],
失败=stats['失败']
)
def shutdown(self):
"""优雅关闭系统"""
log.info("开始优雅关闭系统")
# 等待所有正在执行的任务完成
self.scheduler.executor.shutdown(wait=True, cancel_futures=False)
# 记录最终状态
pending_count = self.scheduler.get_pending_tasks_count()
log.info(
"系统关闭完成",
pending_tasks=pending_count,
shutdown_time=datetime.now()
)
"""关闭系统"""
log.info("Performing system shutdown")
# 此处可添加其他清理逻辑
log.success("System shutdown completed")
if __name__ == "__main__":
try:
# 启动系统 - 仅作为入口,不包含调度逻辑
# run_all_on_startup=True: 启动时立即执行所有到期任务
# run_all_on_startup=False: 启动时不执行任务,等待下次调度周期
system = IntelligenceSystem(run_all_on_startup=False)
system.start()
system = IntelligenceSystem()
system.run()
except Exception as e:
log.critical("情报系统启动失败", exc_info=True)
log.critical(
"System startup failed",
exc_info=True
)
raise
Binary file not shown.
@@ -1,453 +0,0 @@
# RSS数据AI处理模块
import os
import sys
import json
import time
import pandas as pd
from typing import List, Dict, Any, Optional
from datetime import datetime
from openai import OpenAI
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
from utils.mysql_agent import MySQLAgent
from utils.logger import log
from config import Config
class RSSDataAIProcessor:
"""RSS数据AI处理主类
负责:
- 从数据库加载未处理的RSS数据
- 调用AI进行分析
- 保存分析结果
- 更新处理状态
"""
def __init__(self):
"""初始化AI处理器"""
self.log = log.bind(module="RSSDataAIProcessor")
self.db_agent = MySQLAgent(Config.MYSQL_CONFIG)
# 从Config读取配置
self.source_table = Config.AI_PROCESSOR_CONFIG['source_table']
self.ai_table = Config.AI_PROCESSOR_CONFIG['result_table']
self.default_batch_size = Config.AI_PROCESSOR_CONFIG['batch_size']
self.default_delay = Config.AI_PROCESSOR_CONFIG['delay']
# 初始化百度千帆API客户端
self.api_key = Config.BAIDU_AI_CONFIG.get('api_key')
if self.api_key:
self.ai_client = OpenAI(
base_url='https://qianfan.baidubce.com/v2',
api_key=self.api_key
)
self.model = Config.BAIDU_AI_CONFIG.get('model', 'ernie-x1-turbo-32k')
self.log.info("RSS数据AI处理器初始化完成")
else:
self.ai_client = None
self.log.warning("百度AI未配置,AI处理功能将不可用")
self.log.warning("请在config.py中配置 BAIDU_AI_CONFIG['api_key']")
def is_configured(self) -> bool:
"""检查是否已配置API"""
return self.ai_client is not None
def main(self, batch_size: Optional[int] = 200, delay: Optional[float] = None) -> Dict[str, Any]:
"""主程序:批量处理RSS数据的完整流程
Args:
batch_size: 批量处理的记录数,None则使用配置的默认值
delay: 每条记录之间的延迟(秒),None则使用配置的默认值
Returns:
dict: 处理结果统计信息
"""
# 使用传入参数或默认配置
batch_size = batch_size or self.default_batch_size
delay = delay or self.default_delay
try:
# 1. 检查配置
if not self.is_configured():
error_msg = "百度AI未配置,请在config.py中配置 BAIDU_AI_CONFIG['api_key']"
self.log.error(error_msg)
return {
'success': False,
'message': error_msg,
'processed_count': 0,
'failed_count': 0
}
self.log.info(f"开始批量处理数据,批次大小: {batch_size}, 延迟: {delay}")
# 2. 准备数据库表结构
self.ensure_ai_processed_column()
if not self.db_agent.table_exists(self.ai_table):
self.create_ai_result_table()
# 3. 加载未处理的数据
df = self.load_unprocessed_data(batch_size)
if df.empty:
self.log.info("没有需要处理的数据")
return {
'success': True,
'message': '没有需要处理的数据',
'processed_count': 0,
'failed_count': 0
}
# 4. 处理每条记录
results = []
processed_ids = []
failed_count = 0
for idx, record in df.iterrows():
try:
self.log.debug(f"处理记录 {record['id']} ({idx + 1}/{len(df)})")
result = self.process_single_record(record.to_dict())
if result:
results.append(result)
processed_ids.append(record['id'])
else:
failed_count += 1
# 延迟,避免API限流
if delay > 0 and idx < len(df) - 1:
time.sleep(delay)
except Exception as e:
self.log.error(f"处理记录 {record['id']} 异常: {str(e)}", exc_info=True)
failed_count += 1
# 5. 保存结果
saved_count = 0
if results:
saved_count = self.save_ai_results(results)
# 6. 标记为已处理
if processed_ids:
self.mark_as_processed(processed_ids)
# 7. 返回统计信息
stats = {
'success': True,
'message': 'AI处理完成',
'total_count': len(df),
'processed_count': len(processed_ids),
'saved_count': saved_count,
'failed_count': failed_count,
'relevant_count': sum(1 for r in results if r.get('是否相关')),
'processing_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
self.log.info("批量处理完成", **stats)
return stats
except Exception as e:
error_msg = f"批量处理失败: {str(e)}"
self.log.error(error_msg, exc_info=True)
return {
'success': False,
'message': error_msg,
'processed_count': 0,
'failed_count': 0
}
def ensure_ai_processed_column(self):
"""确保processed_rss_data表有"是否ai处理"字段"""
try:
# 检查字段是否存在
check_sql = """
SELECT COUNT(*) as count
FROM information_schema.COLUMNS
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
AND COLUMN_NAME = '是否ai处理'
"""
result = self.db_agent.execute_sql(
check_sql,
params=(Config.MYSQL_CONFIG['database'], self.source_table),
fetch=True
)
if result[0][0] == 0:
# 字段不存在,添加字段
alter_sql = f"""
ALTER TABLE {self.source_table}
ADD COLUMN 是否ai处理 TINYINT(1) DEFAULT 0 COMMENT 'AI处理标记:0-未处理,1-已处理'
"""
self.db_agent.execute_sql(alter_sql)
self.log.info(f"成功为表 {self.source_table} 添加 '是否ai处理' 字段")
else:
self.log.debug(f"{self.source_table} 已存在 '是否ai处理' 字段")
except Exception as e:
self.log.error(f"检查/添加字段失败: {str(e)}", exc_info=True)
raise
def create_ai_result_table(self):
"""创建AI处理结果表"""
create_sql = f"""
CREATE TABLE IF NOT EXISTS {self.ai_table} (
id INT AUTO_INCREMENT PRIMARY KEY COMMENT '主键ID',
source_id INT NOT NULL COMMENT '来源数据IDprocessed_rss_data.id',
文章标题 TEXT COMMENT '文章标题',
文章摘要 TEXT COMMENT '文章摘要',
发布时间 DATETIME COMMENT '发布时间',
来源URL VARCHAR(1024) COMMENT '来源URL',
文章链接 VARCHAR(1024) COMMENT '文章链接',
是否相关 BOOLEAN COMMENT 'AI判断是否与汽车后市场相关',
相关度评分 INT COMMENT '相关度评分(0-100',
标签 TEXT COMMENT 'AI生成的标签(JSON数组)',
分类 VARCHAR(100) COMMENT 'AI判断的主要分类',
分析说明 TEXT COMMENT 'AI分析说明',
处理时间 DATETIME COMMENT 'AI处理时间',
创建时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '记录创建时间',
更新时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '记录更新时间',
INDEX idx_source_id (source_id),
INDEX idx_是否相关 (是否相关),
INDEX idx_分类 (分类),
INDEX idx_处理时间 (处理时间)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='RSS数据AI分析结果表'
"""
try:
self.db_agent.execute_sql(create_sql)
self.log.info(f"成功创建AI结果表: {self.ai_table}")
except Exception as e:
self.log.error(f"创建AI结果表失败: {str(e)}", exc_info=True)
raise
def load_unprocessed_data(self, limit: int = 100) -> pd.DataFrame:
"""加载未经AI处理的数据
Args:
limit: 每次处理的记录数量
Returns:
未处理的数据DataFrame
"""
try:
sql = f"""
SELECT id, 文章标题, 文章摘要, 发布时间, 来源URL, 文章链接
FROM {self.source_table}
WHERE 是否ai处理 = 0 OR 是否ai处理 IS NULL
ORDER BY 创建时间 DESC
LIMIT %s
"""
df = self.db_agent.query_to_df(sql, params=(limit,), is_print=False)
self.log.info(f"成功加载 {len(df)} 条未处理的数据")
return df
except Exception as e:
self.log.error(f"加载未处理数据失败: {str(e)}", exc_info=True)
return pd.DataFrame()
def analyze_news(self, title: str, summary: str) -> Dict[str, Any]:
"""调用AI分析新闻(保留原有提示词)"""
# 构建提示词(保留原有格式)
prompt = f"""分析以下新闻是否与汽车后市场相关,返回JSON格式:
标题:{title}
摘要:{summary}
返回格式:
{{
"is_relevant": true/false,
"relevance_score": 0-100,
"tags": ["标签1", "标签2"],
"category": "分类(配件/维修/保养/改装/美容/装饰/二手车/金融/保险/其他)",
"analysis": "简要说明"
}}
注意:只返回JSON格式的结果,不要包含其他说明文字。"""
try:
# 调用百度千帆API
response = self.ai_client.chat.completions.create(
model=self.model,
messages=[{
"role": "user",
"content": prompt
}]
)
# 获取响应内容
raw_content = response.choices[0].message.content
# 解析JSON(处理markdown包裹)
if '```json' in raw_content:
json_str = raw_content.split('```json')[1].split('```')[0].strip()
elif '```' in raw_content:
json_str = raw_content.split('```')[1].split('```')[0].strip()
else:
json_str = raw_content.strip()
result = json.loads(json_str)
# 补充缺失字段
return {
'is_relevant': result.get('is_relevant', False),
'relevance_score': result.get('relevance_score', 0),
'tags': result.get('tags', []),
'category': result.get('category', '其他'),
'analysis': result.get('analysis', '')
}
except json.JSONDecodeError as e:
self.log.warning(f"JSON解析失败: {str(e)}, 原始响应: {raw_content[:200]}")
return {
'is_relevant': False,
'relevance_score': 0,
'tags': [],
'category': '其他',
'analysis': f"解析失败: {raw_content[:100]}"
}
except Exception as e:
self.log.error(f"AI调用异常: {str(e)}", exc_info=True)
return {
'is_relevant': False,
'relevance_score': 0,
'tags': [],
'category': '其他',
'analysis': f"处理异常: {str(e)}"
}
def process_single_record(self, record: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""处理单条记录
Args:
record: 记录字典
Returns:
处理结果字典
"""
if not self.is_configured():
self.log.error("AI客户端未配置,无法处理数据")
return None
try:
title = str(record.get('文章标题', '')).strip()
summary = str(record.get('文章摘要', '')).strip()
if not title and not summary:
self.log.warning(f"记录 {record.get('id')} 标题和摘要均为空,跳过处理")
return None
# 调用AI分析
analysis_result = self.analyze_news(title, summary)
# 构建结果记录
result = {
'source_id': record['id'],
'文章标题': title,
'文章摘要': summary,
'发布时间': record.get('发布时间'),
'来源URL': record.get('来源URL'),
'文章链接': record.get('文章链接'),
'是否相关': analysis_result.get('is_relevant', False),
'相关度评分': analysis_result.get('relevance_score', 0),
'标签': json.dumps(analysis_result.get('tags', []), ensure_ascii=False),
'分类': analysis_result.get('category', '其他'),
'分析说明': analysis_result.get('analysis', ''),
'处理时间': datetime.now()
}
return result
except Exception as e:
self.log.error(f"处理记录 {record.get('id')} 失败: {str(e)}", exc_info=True)
return None
def save_ai_results(self, results: List[Dict[str, Any]]) -> int:
"""保存AI处理结果
Args:
results: 处理结果列表
Returns:
成功保存的记录数
"""
if not results:
return 0
try:
df = pd.DataFrame(results)
inserted = self.db_agent.insert_from_df(
table_name=self.ai_table,
df=df,
ignore_duplicates=True
)
self.log.info(f"成功保存 {inserted} 条AI处理结果")
return inserted
except Exception as e:
self.log.error(f"保存AI处理结果失败: {str(e)}", exc_info=True)
return 0
def mark_as_processed(self, ids: List[int]) -> bool:
"""标记记录为已处理
Args:
ids: 记录ID列表
Returns:
是否成功
"""
if not ids:
return True
try:
id_placeholders = ','.join(['%s'] * len(ids))
sql = f"""
UPDATE {self.source_table}
SET 是否ai处理 = 1
WHERE id IN ({id_placeholders})
"""
self.db_agent.execute_sql(sql, params=ids)
self.log.info(f"成功标记 {len(ids)} 条记录为已处理")
return True
except Exception as e:
self.log.error(f"标记记录为已处理失败: {str(e)}", exc_info=True)
return False
if __name__ == "__main__":
"""命令行直接运行"""
# 实例化处理器并调用main方法
processor = RSSDataAIProcessor()
result = processor.main()
# 输出结果
if result['success']:
print("\n" + "=" * 60)
print("✓ AI处理完成")
print("=" * 60)
print(f"总记录数: {result.get('total_count', 0)}")
print(f"成功处理: {result.get('processed_count', 0)}")
print(f"保存记录: {result.get('saved_count', 0)}")
print(f"失败记录: {result.get('failed_count', 0)}")
print(f"相关记录: {result.get('relevant_count', 0)}")
print(f"处理时间: {result.get('processing_time', '')}")
print("=" * 60 + "\n")
else:
print("\n" + "=" * 60)
print("✗ 处理失败")
print("=" * 60)
print(f"错误信息: {result['message']}")
print("\n提示: 请设置环境变量")
print(" Windows: $env:BAIDU_API_KEY = 'your_key'")
print(" Linux/Mac: export BAIDU_API_KEY='your_key'")
print("=" * 60 + "\n")
View File
-37
View File
@@ -1,37 +0,0 @@
汽车配件
汽车维修
汽车保养
汽车改装
汽车美容
汽车装饰
轮胎
机油
刹车片
火花塞
滤清器
蓄电池
车灯
保险杠
车门
座椅
方向盘
仪表盘
音响
导航
汽车用品
车载设备
汽车电子
汽车安全
汽车保险
二手车
汽车交易
汽车金融
汽车租赁
汽车服务
4S店
汽修店
汽车后市场
汽车产业链
汽车供应链
汽车
-409
View File
@@ -1,409 +0,0 @@
# RSS数据处理模块 - 汽车后市场新闻分词和过滤
import pandas as pd
import jieba
import jieba.posseg as pseg
import os
import sys
from typing import List, Dict, Any, Optional
from datetime import datetime
# 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
from utils.mysql_agent import MySQLAgent
from utils.logger import log
from config import Config
class RSSDataProcessor:
"""RSS数据处理器 - 专门处理汽车后市场相关新闻"""
def __init__(self):
"""初始化处理器"""
self.log = log.bind(module="RSSDataProcessor")
self.db_agent = MySQLAgent(Config.MYSQL_CONFIG)
self.table_name = "collector_rss_subscriptions"
self.processed_table_name = "processed_rss_data"
# 获取项目根目录
current_dir = os.path.dirname(os.path.abspath(__file__))
self.project_root = os.path.dirname(current_dir)
# 设置文件路径(相对于项目根目录)
self.keywords_file = os.path.join(self.project_root, "processors", "keywords.txt")
self.stopwords_file = os.path.join(self.project_root, "processors", "stopwords.txt")
# 汽车后市场相关关键词(默认值,实际从文件加载)
self.auto_aftermarket_keywords = {
'汽车配件', '汽车维修', '汽车保养', '汽车改装', '汽车美容', '汽车装饰',
'轮胎', '机油', '刹车片', '火花塞', '滤清器', '蓄电池', '车灯',
'保险杠', '车门', '座椅', '方向盘', '仪表盘', '音响', '导航',
'汽车用品', '车载设备', '汽车电子', '汽车安全', '汽车保险',
'二手车', '汽车交易', '汽车金融', '汽车租赁', '汽车服务',
'4S店', '汽修店', '汽车后市场', '汽车产业链', '汽车供应链', '汽车', ''
}
# 停用词表(默认值,实际从文件加载)
self.stopwords = {
'', '', '', '', '', '', '', '', '', '', '', '', '一个',
'', '', '', '', '', '', '', '', '', '', '没有', '', '',
'自己', '', '', '', '', '', '我们', '你们', '他们', '什么', '怎么',
'为什么', '因为', '所以', '但是', '然后', '如果', '虽然', '而且', '或者',
'可以', '应该', '必须', '需要', '想要', '希望', '觉得', '认为', '知道',
'了解', '明白', '清楚', '简单', '容易', '困难', '重要', '主要', '基本',
'一般', '特别', '非常', '十分', '相当', '比较', '更加', '', '',
'已经', '正在', '将要', '可能', '也许', '大概', '大约', '左右', '上下',
'今天', '明天', '昨天', '现在', '以前', '以后', '时候', '时间', '地方',
'这里', '那里', '这样', '那样', '如此', '这样', '那样', '如何', '怎样'
}
# 缓存关键词,避免重复加载
self._cached_keywords = None
self.log.info("RSS数据处理器初始化完成")
def load_keywords(self, keywords_file: Optional[str] = None) -> set:
"""从文件加载汽车后市场关键词(带缓存)"""
# 如果已经缓存,直接返回
if self._cached_keywords is not None:
return self._cached_keywords
# 使用默认路径(项目根目录下的文件)
if keywords_file is None:
keywords_file = self.keywords_file
keywords = set()
try:
if os.path.exists(keywords_file):
with open(keywords_file, 'r', encoding='utf-8') as f:
keywords = set(line.strip() for line in f if line.strip())
self.log.info(f"成功加载汽车后市场关键词,共 {len(keywords)}")
else:
self.log.warning(f"关键词文件不存在: {keywords_file}")
# 使用默认关键词
keywords = self.auto_aftermarket_keywords
except Exception as e:
self.log.error(f"加载关键词失败: {str(e)}")
keywords = self.auto_aftermarket_keywords
# 缓存关键词
self._cached_keywords = keywords
return keywords
def load_rss_data(self, limit: int = 1000) -> pd.DataFrame:
"""从数据库加载未处理的RSS数据"""
try:
sql = f"""
SELECT id, 文章标题, 文章摘要, 发布时间, 来源URL, 文章链接
FROM {self.table_name}
WHERE 是否已处理 = 0
ORDER BY 发布时间 DESC
LIMIT %s
"""
df = self.db_agent.query_to_df(sql, params=(limit,), is_print=False)
self.log.info(f"成功加载 {len(df)} 条未处理的RSS数据")
return df
except Exception as e:
self.log.error(f"加载RSS数据失败: {str(e)}", exc_info=True)
return pd.DataFrame()
def mark_as_processed(self, ids: List[int]) -> bool:
"""标记指定ID的数据为已处理"""
if not ids:
return True
try:
# 将ID列表转换为字符串格式用于SQL IN语句
id_placeholders = ','.join(['%s'] * len(ids))
sql = f"""
UPDATE {self.table_name}
SET 是否已处理 = 1
WHERE id IN ({id_placeholders})
"""
result = self.db_agent.execute_sql(sql, params=ids)
self.log.info(f"成功标记 {len(ids)} 条数据为已处理")
return True
except Exception as e:
self.log.error(f"标记数据为已处理失败: {str(e)}", exc_info=True)
return False
def load_stopwords(self, stopwords_file: Optional[str] = None) -> set:
"""加载停用词表"""
# 使用默认路径(项目根目录下的文件)
if stopwords_file is None:
stopwords_file = self.stopwords_file
try:
if os.path.exists(stopwords_file):
with open(stopwords_file, 'r', encoding='utf-8') as f:
stopwords = set(line.strip() for line in f if line.strip())
self.log.info(f"成功加载停用词表,共 {len(stopwords)} 个词")
return stopwords
else:
self.log.warning(f"停用词文件不存在: {stopwords_file},使用默认停用词")
return self.stopwords
except Exception as e:
self.log.error(f"加载停用词表失败: {str(e)}")
return self.stopwords
def add_custom_dict(self, custom_dict_file: Optional[str] = None):
"""添加自定义词典"""
if custom_dict_file and os.path.exists(custom_dict_file):
try:
jieba.load_userdict(custom_dict_file)
self.log.info("成功加载自定义词典")
except Exception as e:
self.log.warning(f"加载自定义词典失败: {str(e)}")
# 从文件加载汽车后市场关键词并添加到jieba词典
keywords = self.load_keywords()
for keyword in keywords:
jieba.add_word(keyword, freq=1000, tag='n')
def segment_and_pos(self, text: str, stopwords: set) -> List[str]:
"""分词并标注词性,过滤停用词"""
if not text or pd.isna(text):
return []
words = pseg.cut(str(text))
result = []
# 汽车后市场相关的词性标签
allowed_flags = {'n', 'vn', 'np', 'ns', 'nr', 'nt'} # 名词、动词、动名词、名词短语、处所词、人名、机构名
for word, flag in words:
word = word.strip()
if (len(word) >= 1 and
word not in stopwords and
flag in allowed_flags and
not word.isdigit()): # 过滤纯数字
result.append(word)
return result
def is_auto_aftermarket_related(self, text: str) -> bool:
"""判断文本是否与汽车后市场相关"""
if not text:
return False
text_lower = str(text).lower()
# 从文件加载关键词
keywords = self.load_keywords()
# 检查是否包含汽车后市场关键词
for keyword in keywords:
if keyword in text_lower:
return True
# 检查分词结果中是否包含相关词汇
words = self.segment_and_pos(text, self.stopwords)
for word in words:
if word in keywords:
return True
return False
def process_dataframe(self, df: pd.DataFrame, stopwords: set) -> pd.DataFrame:
"""处理整个DataFrame,进行分词和过滤"""
if df.empty:
self.log.warning("输入的DataFrame为空")
return df
# 确保所有文本都是字符串,并处理NaN值
df['文章标题'] = df['文章标题'].fillna('').astype(str)
df['文章摘要'] = df['文章摘要'].fillna('').astype(str)
# 合并标题和摘要进行分词
df['combined_text'] = df['文章标题'] + ' ' + df['文章摘要']
# 分词处理
df['segmented_words'] = df['combined_text'].apply(lambda x: self.segment_and_pos(x, stopwords))
# 判断是否与汽车后市场相关(只要出现关键词就入库)
df['is_auto_related'] = df['combined_text'].apply(self.is_auto_aftermarket_related)
df['is_filtered'] = df['is_auto_related']
# 添加处理时间
df['processed_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
self.log.info(f"数据处理完成,共处理 {len(df)} 条记录")
return df
def filter_auto_aftermarket_news(self, df: pd.DataFrame) -> pd.DataFrame:
"""过滤出汽车后市场相关的新闻"""
if df.empty:
return df
# 过滤出包含关键词的文章
filtered_df = df[df['is_filtered'] == True].copy()
self.log.info(f"过滤出 {len(filtered_df)} 条汽车后市场相关新闻")
return filtered_df
def save_to_database(self, df: pd.DataFrame) -> bool:
"""保存处理结果到数据库"""
if df.empty:
self.log.warning("没有数据需要保存")
return False
try:
# 准备保存的数据
save_df = df[['文章标题', '文章摘要', '发布时间', '来源URL', '文章链接',
'segmented_words', 'is_auto_related', 'processed_time']].copy()
# 将分词结果转换为字符串
save_df['分词结果'] = save_df['segmented_words'].apply(lambda x: ' '.join(x))
# 重命名列名为中文
save_df = save_df.rename(columns={
'is_auto_related': '是否汽车相关',
'processed_time': '处理时间'
})
# 删除不需要的列
save_df = save_df.drop('segmented_words', axis=1)
# 检查目标表是否存在,不存在则创建
if not self.db_agent.table_exists(self.processed_table_name):
self.create_processed_table()
# 插入数据
inserted_rows = self.db_agent.insert_from_df(
table_name=self.processed_table_name,
df=save_df,
ignore_duplicates=True
)
self.log.info(f"成功保存 {inserted_rows} 条处理结果到数据库")
return True
except Exception as e:
self.log.error(f"保存到数据库失败: {str(e)}", exc_info=True)
return False
def create_processed_table(self):
"""创建处理结果表"""
create_sql = f"""
CREATE TABLE IF NOT EXISTS {self.processed_table_name} (
id INT AUTO_INCREMENT PRIMARY KEY,
文章标题 TEXT,
文章摘要 TEXT,
发布时间 DATETIME,
来源URL VARCHAR(1024),
文章链接 VARCHAR(1024),
分词结果 TEXT,
是否汽车相关 BOOLEAN,
处理时间 DATETIME,
创建时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
更新时间 TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
"""
try:
self.db_agent.execute_sql(create_sql)
self.log.info(f"成功创建处理结果表: {self.processed_table_name}")
except Exception as e:
self.log.error(f"创建表失败: {str(e)}", exc_info=True)
raise
def get_processing_statistics(self, df: pd.DataFrame) -> Dict[str, Any]:
    """Summarise one processing run.

    Args:
        df: Processed articles carrying an ``is_filtered`` column.

    Returns:
        ``{}`` for an empty frame, otherwise a dict with
        ``total_articles``, ``filtered_articles``, ``filter_rate``
        (filtered / total) and ``processing_time`` (current local time as
        ``YYYY-MM-DD HH:MM:SS``).
    """
    if df.empty:
        return {}

    article_total = len(df)
    kept_rows = df[df['is_filtered'] == True]
    kept_total = len(kept_rows)
    rate = kept_total / article_total if article_total > 0 else 0
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    return {
        'total_articles': article_total,
        'filtered_articles': kept_total,
        'filter_rate': rate,
        'processing_time': timestamp
    }
def process_rss_data(self, limit: int = 1000, save_to_db: bool = True) -> Dict[str, Any]:
"""处理RSS数据的主函数"""
try:
self.log.info("开始处理RSS数据...")
# 1. 加载RSS数据
df = self.load_rss_data(limit)
if df.empty:
self.log.warning("没有加载到RSS数据")
return {'success': False, 'message': '没有数据可处理'}
# 2. 加载停用词表
stopwords = self.load_stopwords()
# 3. 添加自定义词典
self.add_custom_dict()
# 4. 处理数据
processed_df = self.process_dataframe(df, stopwords)
# 5. 过滤汽车后市场相关新闻
filtered_df = self.filter_auto_aftermarket_news(processed_df)
# 6. 获取统计信息
stats = self.get_processing_statistics(processed_df)
# 7. 保存到数据库
if save_to_db and not filtered_df.empty:
save_success = self.save_to_database(filtered_df)
stats['save_success'] = save_success
# 8. 标记数据为已处理
if not df.empty and 'id' in df.columns:
processed_ids = df['id'].tolist()
mark_success = self.mark_as_processed(processed_ids)
stats['mark_success'] = mark_success
if not mark_success:
self.log.warning("部分数据标记为已处理失败")
# 9. 输出结果
self.log.info("RSS数据处理完成", **stats)
return {
'success': True,
'message': 'RSS数据处理完成',
'statistics': stats,
'filtered_data': filtered_df
}
except Exception as e:
self.log.error(f"RSS数据处理失败: {str(e)}", exc_info=True)
return {'success': False, 'message': f'处理失败: {str(e)}'}
def main():
"""主函数入口"""
try:
# 创建处理器实例
processor = RSSDataProcessor()
# 处理RSS数据
result = processor.process_rss_data(
limit=5000, # 处理最近5000条数据
save_to_db=True # 保存到数据库
)
if result['success']:
print("RSS数据处理完成!")
print(f"处理统计: {result['statistics']}")
else:
print(f"处理失败: {result['message']}")
except Exception as e:
print(f"程序运行出错: {str(e)}")
if __name__ == "__main__":
main()
-100
View File
@@ -1,100 +0,0 @@
一个
没有
自己
我们
你们
他们
什么
怎么
为什么
因为
所以
但是
然后
如果
虽然
而且
或者
可以
应该
必须
需要
想要
希望
觉得
认为
知道
了解
明白
清楚
简单
容易
困难
重要
主要
基本
一般
特别
非常
十分
相当
比较
更加
已经
正在
将要
可能
也许
大概
大约
左右
上下
今天
明天
昨天
现在
以前
以后
时候
时间
地方
这里
那里
这样
那样
如此
这样
那样
如何
怎样
View File
-148
View File
@@ -1,148 +0,0 @@
## 情报收集系统设计
### 参考文档
https://alidocs.dingtalk.com/i/nodes/NZQYprEoWoexdo1ohPdxXvDbJ1waOeDk?utm_scene=team_space
### 程序框架(当前实现)
```angular2html
intelligence_system/
├── collectors/ # 数据采集层
│ ├── complaint_spider.py # 投诉信息爬虫(结构化入库/附件走MinIO)
│ ├── rss_subscriptions.py # RSS 订阅抓取
│ └── internal/ # 内部数据收集(保留)
│ └── jian_dao_cloud.py # 简道云表单收集器(示例/占位)
├── processors/ # 数据处理层
│ ├── processor_rss_data.py # RSS数据清洗、分词、过滤与入库
│ ├── keywords.txt # 行业关键词(用于分词/过滤)
│ ├── stopwords.txt # 停用词
│ └── ai_engine/
│ └── ai_proessor_rss_data # 预留(AI分析扩展占位)
├── services/ # 应用服务层(保留)
│ ├── monitoring/ # 舆情监控
│ │ ├── opinion_monitor.py # 实时舆情追踪(占位)
│ │ └── brand_reputation.py # 品牌口碑分析(占位)
│ ├── analysis/ # 竞品分析
│ │ ├── competitor_tracker.py # 竞品动态监控(占位)
│ │ └── swot_generator.py # SWOT分析报告(占位)
│ ├── reporting/ # 报告服务
│ │ ├── daily_reporter.py # 自动化日报生成(占位)
│ │ └── weekly_digest.py # 周报汇编系统(占位)
│ └── alert/ # 预警服务
│ ├── alert_trigger.py # 动态阈值告警(占位)
│ └── notification_center.py # 邮件/短信通知(占位)
├── applications/ # 应用层
│ ├── alert.py # 告警触发/通知(占位/实现中)
│ └── reporter/
│ ├── daily.py # 日报生成
│ └── monthly.py # 月报生成
├── system_management/ # 系统管理层
│ ├── scheduler/
│ │ ├── task_scheduler.py # 任务调度器(Cron表达式 + 线程池)
│ │ └── task_management.py # 任务管理辅助
│ └── monitor/ # 系统监控(目录占位)
├── utils/ # 工具库
│ ├── file_handler.py # 通用文件操作
│ ├── logger.py # 跨平台日志系统(Loguru)
│ ├── mysql_agent.py # MySQL读写管理器
│ └── minio_agent.py # MinIO对象存储客户端
├── config.py # 配置加载与管理(含数据库/存储配置)
├── main.py # 系统入口(Cron轮询 + 调度执行)
└── requirements.txt # 依赖清单
```
### 程序设计原则
1. 所有程序尽可能在py文件中运行,尽量避免使用命令行执行
2. 配置需要在配置类中定义
3. 密钥等信息直接放在配置类中
4. 数据存储遵循"结构化存MySQL,非结构化存MinIO"原则,通过元数据关联
### 主程序与调度设计(已实现)
主程序以长运行进程方式启动,进入轻量轮询循环(每10秒)。调度器按Cron表达式在`main_task`表中拉取到期任务,使用线程池异步执行,并在每分钟输出运行状态、每小时汇总统计。
- 调度器能力:
- 基于`croniter`解析Cron表达式,支持时区(默认`Asia/Shanghai`)
- 线程池并发执行,信号量限制最大并发(与`max_workers`一致)
- 任务入口动态解析:支持`package.module`、`package.module.ClassName.main`、`package.module.func` 等形式
- 成功/失败后自动计算`next_run_time`或设置15分钟后重试
- 关键字段自动更新:`is_running`、`last_run_time`、`last_run_status`、`run_count`、`next_run_time`
- 主循环:
- 每10秒检查一次待运行任务
- 每分钟打印当前周期统计;每小时写入累计统计日志
- 支持`SIGINT/SIGTERM`优雅关闭,等待正在运行的任务完成
### 日志设计(已实现)
跨平台日志系统(Loguru)输出至`logs/`目录:
- application.log:主日志,`rotation = 20MB`,达到阈值后压缩为`application.log.YYYYMMDD.zip`,`retention = 30天`
- errors.log:错误日志(ERROR及以上),`rotation = 10MB`,`retention = 90天`
- 结构化扩展字段:日志支持`extra`键值对,自动美化并对长字段(如`sql`、`params`)截断
建议记录的业务事件:
- MySQL读写操作要点(表名、影响行数、事务状态)
- MinIO对象操作(对象路径、大小、耗时、状态)
- 任务执行上下文(task_id、task_name、module_path、耗时、状态)
### 存储系统设计(MinIO+MySQL)
#### 核心存储分工
| 存储类型 | 适用数据 | 核心作用 |
|----------|----------|----------|
| MySQL | 结构化数据、元数据、关系型数据 | 存储业务逻辑数据、非结构化数据的索引信息、任务调度信息等 |
| MinIO | 非结构化数据 | 存储图片、视频、PDF文档、原始爬取文件等二进制/大文件数据 |
#### 核心存储配置
1. **MySQL配置**
- 数据库名称:`intelligence_system`
- 连接管理:通过`utils/mysql_agent.py`封装线程安全的连接池,提供结构化数据的增删改查及SQL执行能力
- 适配特性:支持多平台(Windows/macOS/Linux)的超时配置和批处理优化
2. **MinIO配置**
- 存储桶命名规则:按数据类型划分,如`collector-images`(采集层图片)、`processor-videos`(处理层视频)
- 连接管理:通过`utils/minio_agent.py`封装客户端,提供对象上传、下载、删除、查询URL等能力
- 路径规则:`{数据层}/{来源}/{时间戳}_{唯一ID}.{后缀}`(例:`collector/weibo_spider/20240520_12345.jpg`)
#### 表命名规则(扩展)
- 数据采集类:以`collector_`为前缀(存储采集到的结构化数据及MinIO对象元数据)
- 数据处理类:以`processor_`为前缀(存储处理结果的结构化数据及MinIO处理后对象的元数据)
- 数据存储类:以`storage_`为前缀(存储MinIO对象的索引信息,如哈希、大小、访问权限等)
- 应用层类:以`application_`为前缀(对应业务应用数据)
- 系统类:如任务调度表等采用功能命名(如`main_task`)
#### 核心表结构(当前落地)
1. `main_task`:任务调度表(`task_name`、`task_type`、`module_path`、`cron_expression`、`time_zone`、`run_count`、`is_running`、`last_run_time`、`last_run_status`、`next_run_time`、`is_active` 等)
2. `collector_rss_subscriptions`:RSS源采集数据(`文章标题`、`文章摘要`、`发布时间`、`来源URL`、`文章链接`、`是否已处理` 等)
3. `processed_rss_data`:RSS处理结果(`分词结果`、`是否汽车相关`、`处理时间` 等)
4. `collector_complaint_spider`:投诉信息爬虫数据(含文本与附件MinIO路径`attachment_minio_path`等)
5. 可选:`storage_object_index`(建议用于统一索引MinIO对象元数据)
### 数据采集设计
1. 结构化数据(RSS、投诉文本):写入`collector_`前缀表
2. 非结构化数据(附件/图片等):
- 使用`utils/minio_agent.py`上传至对应存储桶
- 将对象路径与元数据写入业务表或`storage_object_index`
3. 采集模块需同时处理MySQL与MinIO交互,确保关联完整
### 数据处理设计(RSS流程已实现)
`processors/processor_rss_data.py`流程:
- 从`collector_rss_subscriptions`加载未处理数据(可配置`limit`)
- 加载停用词与行业关键词(`stopwords.txt` / `keywords.txt`),并动态注入`jieba`词典
- 标注词性并过滤停用词,仅保留与汽车后市场相关的词汇
- 标记与过滤:出现任一行业关键词即视为相关,进入保存
- 将结果写入`processed_rss_data`,并回写源表`是否已处理 = 1`
- 输出处理统计(总量、命中量、命中率、时间)
### 依赖与运行
- 依赖:见`requirements.txt`(pandas、SQLAlchemy、PyMySQL、croniter、pytz、loguru、jieba、feedparser、beautifulsoup4、minio 等)
- 配置:在`config.py`中设置`MYSQL_CONFIG`与MinIO参数
- 运行:
- 启动主程序:`python main.py`
- 添加任务:向`main_task`插入记录,`module_path`可指向如`processors.processor_rss_data.main`
-18
View File
@@ -1,18 +0,0 @@
croniter==3.0.3
dbutils==3.1.2
loguru==0.7.3
minio==7.2.16
numpy==2.3.3
pandas==2.3.2
pymysql==1.1.2
pytest==8.4.2
pytz==2025.2
Requests==2.32.5
SQLAlchemy==2.0.43
tenacity==9.1.2
beautifulsoup4==4.13.5
feedparser==6.0.11
Markdown==3.9
openai==1.107.3
tqdm==4.67.1
jieba==0.42.1
+683
View File
@@ -0,0 +1,683 @@
import os
import sys
import platform
import pandas as pd
import pymysql
from pymysql import cursors
from pymysql.err import MySQLError
from dbutils.pooled_db import PooledDB
from typing import Union, List, Dict, Any, Optional, Tuple
import threading
from datetime import datetime
import numpy as np
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
"""
Cross-platform MySQL database helper.
Intended to run on Windows/macOS/Linux.
Connection settings are passed in by the caller.
"""
# The shared instance returned by __new__ (None until the first construction).
_instance = None
# Lock taken by __new__ while the shared instance is being created.
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
    """Return the shared instance, creating it on first use.

    The instance is re-checked after taking ``cls._lock`` so that only one
    object is ever created when several threads arrive at once.
    """
    if cls._instance:
        return cls._instance
    with cls._lock:
        if not cls._instance:
            cls._instance = super().__new__(cls)
    return cls._instance
def __init__(self, config: dict):
    """Validate the settings and build the connection pool.

    Because the class hands out one shared instance, a second construction
    finds ``_pool`` already set and returns without touching anything.

    Args:
        config (dict): Connection settings. Required keys: ``host``,
            ``port``, ``user``, ``password``, ``database``. Optional keys:
            ``charset`` (default ``utf8mb4``), ``max_connections``
            (default 5), ``connect_timeout`` (default 10 s),
            ``read_timeout`` (default 30 s), ``write_timeout``
            (default 30 s), ``ssl``.

    Raises:
        ValueError: If any required key is missing.
    """
    # Already initialised by an earlier construction of the shared instance.
    if getattr(self, '_pool', None):
        return

    # Required settings.
    required_keys = ['host', 'port', 'user', 'password', 'database']
    for key in required_keys:
        if key not in config:
            raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

    settings = {key: config[key] for key in required_keys}
    settings.update(
        charset=config.get('charset', 'utf8mb4'),
        cursorclass=cursors.DictCursor,  # rows come back as dicts
        autocommit=True,
        connect_timeout=config.get('connect_timeout', 10),
        read_timeout=config.get('read_timeout', 30),
        write_timeout=config.get('write_timeout', 30),
        ssl=config.get('ssl'),
    )
    self.config = settings

    # Logger tagged with the current platform name.
    self.log = log.bind(module=f"MySQLAgent({platform.system()})")

    # Connection pool.
    self.pool_size = config.get('max_connections', 5)
    self._pool = self._create_pool()
def _create_pool(self) -> PooledDB:
    """Build the ``PooledDB`` pool from ``self.config`` and ``self.pool_size``.

    Returns:
        PooledDB: The new pool (1 connection opened up front, at most 3
        kept idle, ``self.pool_size`` connections in total, blocking when
        exhausted, ping on checkout).

    Raises:
        Exception: Whatever pool creation raised, after logging it.
    """
    try:
        def _open_connection():
            # Factory handed to the pool as its creator.
            raw = pymysql.connect(**self.config)
            raw.threadsafety = 1  # explicit thread-safety level marker
            return raw

        pool_options = {
            'creator': _open_connection,
            'mincached': 1,
            'maxcached': 3,
            'maxconnections': self.pool_size,
            'blocking': True,
            'ping': 1,
        }
        created = PooledDB(**pool_options)
        self.log.info("Connection pool created")
        return created
    except Exception as e:
        self.log.critical("Failed to create connection pool",
                          error=str(e),
                          exc_info=True)
        raise
def get_connection(self) -> pymysql.connections.Connection:
    """Check a connection out of the pool.

    On macOS with SSL configured the connection is pinged (with reconnect)
    before being returned. On Windows, an error whose text contains
    "timed out" is retried through ``_retry_connection``.

    Returns:
        pymysql.connections.Connection: A pooled connection.

    Raises:
        Exception: Whatever the pool raised, after logging it.
    """
    try:
        pooled = self._pool.connection()
        # macOS needs special handling for SSL.
        ssl_on_macos = platform.system() == 'Darwin' and self.config.get('ssl')
        if ssl_on_macos:
            pooled.ping(reconnect=True)
        self.log.trace("Database connection obtained")
        return pooled
    except Exception as e:
        reason = str(e)
        # Windows-specific handling of timeouts.
        windows_timeout = platform.system() == 'Windows' and "timed out" in reason
        if not windows_timeout:
            self.log.error("Connection failed",
                           error=reason,
                           exc_info=True)
            raise
        self.log.warning("Windows connection timeout, retrying...")
        return self._retry_connection()
def _retry_connection(self, max_retries: int = 3) -> pymysql.connections.Connection:
    """Retry checking a connection out of the pool (used for Windows timeouts).

    Waits one second between attempts and re-raises the error of the final
    attempt.

    Args:
        max_retries: Maximum number of attempts.

    Returns:
        pymysql.connections.Connection: The connection from the first
        successful attempt.
    """
    import time

    attempt = 0
    while attempt < max_retries:
        try:
            conn = self._pool.connection()
            self.log.info(f"Connection established after {attempt + 1} attempts")
            return conn
        except Exception:
            out_of_attempts = attempt == max_retries - 1
            if out_of_attempts:
                raise
            time.sleep(1)
        attempt += 1
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True) -> pd.DataFrame:
    """Run a SQL query and return the result as a DataFrame.

    Args:
        sql (str): The query text.
        params (Union[tuple, dict, None]): Query parameters.
        parse_dates (Union[List[str], bool]): Forwarded to ``pd.read_sql``.

    Returns:
        pd.DataFrame: The query result.

    Raises:
        Exception: Whatever the driver or pandas raised, after logging it.
    """
    try:
        self.log.debug("Executing SQL query", sql=sql)
        with self.get_connection() as conn:
            # On Linux/macOS raise the session's wait_timeout first.
            # The cursor is closed as soon as the statement has run; the
            # original left it open for the lifetime of the connection.
            if platform.system() != 'Windows':
                with conn.cursor() as cursor:
                    cursor.execute("SET SESSION wait_timeout=600")
            df = pd.read_sql(sql, conn, params=params, parse_dates=parse_dates)
            # NOTE: the former Windows-only `conn.cursor().close()` opened
            # a brand-new cursor just to close it, so it has been dropped.
        self.log.info("Query executed successfully", rows=len(df))
        return df
    except Exception as e:
        self.log.error("SQL query failed",
                       sql=sql,
                       params=params,
                       error=str(e),
                       exc_info=True)
        raise
def insert_from_df(self, table_name: str, df: pd.DataFrame,
                   chunk_size: int = 1000, replace: bool = False) -> int:
    """Insert a DataFrame into a table in chunks via ``DataFrame.to_sql``.

    Args:
        table_name (str): Target table.
        df (pd.DataFrame): Rows to insert. The frame is not modified.
        chunk_size (int): Rows per ``to_sql`` call; must be at least 1.
        replace (bool): When true the first chunk is written with
            ``if_exists='replace'`` and the remaining chunks are appended.

    Returns:
        int: Number of rows written (0 for an empty frame).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
        Exception: Whatever the driver or pandas raised, after logging it.
    """
    if df.empty:
        self.log.warning("Attempted to insert empty DataFrame", table=table_name)
        return 0
    if chunk_size < 1:
        # A negative step would make the loop below a silent no-op.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    self.log.debug("Preparing to insert DataFrame",
                   table=table_name,
                   rows=len(df),
                   chunk_size=chunk_size)
    try:
        # Temporary SQLAlchemy engine wrapped around one pooled connection.
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool

        method = 'replace' if replace else 'append'
        total_rows = 0
        conn = self.get_connection()
        engine = None
        try:
            # Work around connection wrappers lacking character_set_name.
            if not hasattr(conn, 'character_set_name'):
                conn.character_set_name = lambda: self.config.get('charset', 'utf8mb4')
            engine = create_engine(
                "mysql+pymysql://",
                creator=lambda: conn,
                poolclass=StaticPool,  # reuse the one connection, open no others
                connect_args={
                    'charset': self.config.get('charset', 'utf8mb4'),
                    'autocommit': True
                }
            )
            for i in range(0, len(df), chunk_size):
                chunk = df.iloc[i:i + chunk_size]
                # macOS: send datetime columns as formatted strings.
                if platform.system() == 'Darwin':
                    datetime_cols = list(chunk.select_dtypes(include=['datetime64']).columns)
                    if datetime_cols:
                        # Copy first: assigning into an iloc slice would
                        # target a view of the caller's frame.
                        chunk = chunk.copy()
                        for col in datetime_cols:
                            chunk[col] = chunk[col].dt.strftime('%Y-%m-%d %H:%M:%S')
                chunk.to_sql(
                    table_name,
                    engine,
                    if_exists=method,
                    index=False,
                    method='multi'
                )
                total_rows += len(chunk)
                method = 'append'  # every chunk after the first is appended
                self.log.trace(f"Inserted chunk {i // chunk_size + 1}",
                               rows=len(chunk),
                               total_inserted=total_rows)
            self.log.info("Data inserted successfully",
                          table=table_name,
                          total_rows=total_rows)
            return total_rows
        finally:
            # Release the connection even when engine creation itself failed.
            if engine is not None:
                engine.dispose()
            conn.close()
    except Exception as e:
        self.log.error("Data insertion failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """Update table rows from a DataFrame, matching on the key columns.

    Only DataFrame columns that also exist in the table (per
    ``_get_table_info``) are written. Missing values (NaN/NaT/None) are
    sent as SQL NULL.

    Args:
        table_name (str): Target table.
        df (pd.DataFrame): New values plus the key columns.
        key_columns (Union[str, List[str]]): Column(s) used in the WHERE
            clause to find the row to update.

    Returns:
        int: ``rowcount`` reported by the cursor after the batch update
        (0 for an empty frame).

    Raises:
        KeyError: If a key column is not present in ``df``.
        ValueError: If no non-key column of ``df`` exists in the table.
        Exception: Whatever the driver raised, after rollback and logging.
    """
    if df.empty:
        self.log.warning("Attempted to update with empty DataFrame", table=table_name)
        return 0
    self.log.debug("Preparing to update table from DataFrame",
                   table=table_name,
                   key_columns=key_columns,
                   rows=len(df))

    def _to_param(value):
        # The driver has no usable encoding for NaN/NaT; NULL is what is meant.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return None
        # Unwrap numpy scalars so the driver sees plain Python numbers.
        if isinstance(value, np.generic):
            return value.item()
        return value

    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]
        missing_keys = [col for col in key_columns if col not in df.columns]
        if missing_keys:
            raise KeyError(f"key columns missing from DataFrame: {missing_keys}")

        # Look the schema up before opening the transaction so two pooled
        # connections are not held at once.
        table_info = self._get_table_info(table_name)
        set_columns = [col for col in df.columns
                       if col in table_info and col not in key_columns]
        if not set_columns:
            # Without this the statement would be "UPDATE t SET  WHERE ...".
            raise ValueError(f"no updatable columns for table {table_name}")

        # UPDATE statement template.
        set_clause = ', '.join(f"{col}=%s" for col in set_columns)
        where_clause = ' AND '.join(f"{col}=%s" for col in key_columns)
        update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
        self.log.trace("Generated update SQL", sql=update_sql)

        # One parameter tuple per row: SET values first, then key values.
        ordered_columns = set_columns + list(key_columns)
        update_data = [
            tuple(_to_param(value) for value in record)
            for record in df[ordered_columns].itertuples(index=False, name=None)
        ]

        conn = self.begin_transaction()
        try:
            with conn.cursor() as cursor:
                cursor.executemany(update_sql, update_data)
                total_updated = cursor.rowcount
            self.commit_transaction(conn)
            self.log.info("Data updated successfully",
                          table=table_name,
                          rows_updated=total_updated)
            return total_updated
        except Exception:
            self.rollback_transaction(conn)
            raise
    except Exception as e:
        self.log.error("Data update failed",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        raise
def _get_table_info(self, table_name: str) -> Dict[str, str]:
    """Look up a table's columns in ``information_schema.columns``.

    Args:
        table_name (str): Table to inspect, in the configured database.

    Returns:
        Dict[str, str]: Column name mapped to its ``data_type``; empty when
        the table has no entry.

    Raises:
        Exception: Whatever the driver raised, after logging it.
    """
    # Explicit aliases pin the result keys: MySQL 8 otherwise returns
    # INFORMATION_SCHEMA column labels in upper case, which breaks the
    # dict-row lookups below.
    sql = """
        SELECT column_name AS column_name, data_type AS data_type
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
    """
    params = (self.config['database'], table_name)
    try:
        with self.get_connection() as conn:
            # Close the cursor deterministically instead of leaving it open.
            with conn.cursor() as cursor:
                cursor.execute(sql, params)
                rows = cursor.fetchall()
        return {row['column_name']: row['data_type'] for row in rows}
    except Exception as e:
        self.log.error("Failed to get table info",
                       table=table_name,
                       error=str(e))
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
    """Infer a SQL column type for every DataFrame column.

    Columns are classified by dtype family rather than by exact dtype
    name, so every integer width, every float width, nullable
    integer/boolean dtypes and datetimes of any unit or timezone map
    sensibly instead of falling through to TEXT.

    Args:
        df (pd.DataFrame): Frame whose dtypes are inspected.

    Returns:
        Dict[str, str]: Column name mapped to one of ``VARCHAR(255)``
        (categorical), ``TINYINT(1)`` (boolean), ``BIGINT`` (integer),
        ``DOUBLE`` (float), ``DATETIME`` (datetime) or ``TEXT``
        (everything else).
    """
    dtype_checks = pd.api.types
    sql_types = {}
    for col, dtype in df.dtypes.items():
        # Categorical is tested first: its categories may themselves be
        # numeric or boolean.
        if isinstance(dtype, pd.CategoricalDtype):
            sql_type = 'VARCHAR(255)'
        elif dtype_checks.is_bool_dtype(dtype):
            sql_type = 'TINYINT(1)'
        elif dtype_checks.is_integer_dtype(dtype):
            sql_type = 'BIGINT'
        elif dtype_checks.is_float_dtype(dtype):
            sql_type = 'DOUBLE'
        elif dtype_checks.is_datetime64_any_dtype(dtype):
            sql_type = 'DATETIME'
        else:
            sql_type = 'TEXT'
        sql_types[col] = sql_type
    self.log.debug("Mapped DataFrame types to SQL types",
                   mappings=sql_types)
    return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create a table whose columns mirror a DataFrame's schema.

    Args:
        table_name (str): Name of the table to create.
        df (pd.DataFrame): Frame whose columns and dtypes define the table.
        primary_key (Union[str, List[str], None]): Column(s) to declare as
            primary key; names not present in the frame are ignored.

    Returns:
        bool: True when the table was created; False when it already
        exists or creation failed (the failure is logged, not raised).
    """
    if self.table_exists(table_name):
        self.log.warning("Table already exists", table=table_name)
        return False
    self.log.debug("Creating new table from DataFrame schema",
                   table=table_name,
                   columns=list(df.columns))
    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{col} {sql_type}" for col, sql_type in sql_types.items()]
        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                columns_sql.append(f"PRIMARY KEY ({', '.join(pk_columns)})")
                self.log.trace("Set primary key",
                               table=table_name,
                               primary_key=pk_columns)
        # Join outside the f-string: a backslash inside a replacement
        # field is a SyntaxError on Python versions before 3.12.
        column_block = ",\n    ".join(columns_sql)
        create_sql = f"CREATE TABLE {table_name} (\n    {column_block}\n)"
        self.execute_sql(create_sql)
        self.log.info("Table created successfully", table=table_name)
        return True
    except Exception as e:
        self.log.error("Failed to create table",
                       table=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Execute one SQL statement on a pooled connection.

    Args:
        sql (str): Statement text.
        params (Union[tuple, dict, None]): Statement parameters.
        fetch (bool): Return the fetched rows instead of the row count.

    Returns:
        Union[int, List[Dict[str, Any]]]: ``cursor.fetchall()`` when
        ``fetch`` is true, otherwise ``cursor.rowcount``.

    Raises:
        Exception: Whatever the driver raised, after logging it.
    """
    conn = cursor = None
    try:
        conn = self.get_connection()
        cursor = conn.cursor()
        # Linux/macOS: raise the session's execution time limit first.
        if platform.system() != 'Windows':
            cursor.execute("SET SESSION max_execution_time=600000")
        cursor.execute(sql, params)
        if not fetch:
            affected_rows = cursor.rowcount
            self.log.debug("Update executed", affected_rows=affected_rows)
            return affected_rows
        rows = cursor.fetchall()
        self.log.debug("Query executed", rows=len(rows))
        return rows
    except Exception as e:
        self.log.error("SQL execution failed",
                       sql=sql,
                       params=params,
                       error=str(e),
                       exc_info=True)
        raise
    finally:
        # Cursor first, then hand the connection back to the pool.
        for resource in (cursor, conn):
            if resource:
                resource.close()
def begin_transaction(self) -> pymysql.connections.Connection:
    """Check out a connection with autocommit disabled.

    The caller must finish with ``commit_transaction`` or
    ``rollback_transaction``; both close the connection.

    Returns:
        pymysql.connections.Connection: Connection with autocommit off.

    Raises:
        Exception: Whatever setup raised, after logging it. A connection
        that was already checked out is released first.
    """
    conn = None
    try:
        conn = self.get_connection()
        conn.autocommit(False)
        # macOS: use READ COMMITTED for the session.
        if platform.system() == 'Darwin':
            # Close the cursor once the statement has run.
            with conn.cursor() as cursor:
                cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED")
        self.log.debug("Transaction started")
        return conn
    except Exception as e:
        self.log.error("Begin transaction_failed", error=str(e))
        if conn is not None:
            # Do not leak the connection, and do not return it to the
            # pool with autocommit still disabled.
            try:
                conn.autocommit(True)
            except Exception as reset_error:
                self.log.warning("Failed to restore autocommit", error=str(reset_error))
            finally:
                conn.close()
        raise
def commit_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Commit the transaction and close the connection.

    Args:
        conn: Connection obtained from ``begin_transaction``.

    Raises:
        Exception: Whatever the commit raised, after logging it. The
        connection is closed in either case.
    """
    committed = False
    try:
        conn.commit()
        committed = True
        self.log.debug("Transaction committed")
    except Exception as e:
        self.log.error("Commit failed", error=str(e))
        raise
    finally:
        try:
            # Every other method relies on the pool's autocommit=True
            # setting. Handing the connection back with autocommit off
            # would let the pool's rollback-on-return discard their
            # writes. Only restore after a successful commit: enabling
            # autocommit with work still pending would commit it.
            if committed:
                conn.autocommit(True)
        except Exception as reset_error:
            self.log.warning("Failed to restore autocommit", error=str(reset_error))
        finally:
            conn.close()
def rollback_transaction(self, conn: pymysql.connections.Connection) -> None:
    """Roll the transaction back and close the connection.

    A failing rollback is logged and not raised, so this can be called
    from an ``except`` block without masking the original error.

    Args:
        conn: Connection obtained from ``begin_transaction``.
    """
    rolled_back = False
    try:
        conn.rollback()
        rolled_back = True
        self.log.warning("Transaction rolled back")
    except Exception as e:
        self.log.error("Rollback failed", error=str(e))
    finally:
        try:
            # Give the connection back in the autocommit state the rest
            # of the class relies on. Skip this when the rollback failed:
            # enabling autocommit would commit the pending work.
            if rolled_back:
                conn.autocommit(True)
        except Exception as reset_error:
            self.log.warning("Failed to restore autocommit", error=str(reset_error))
        finally:
            conn.close()
def table_exists(self, table_name: str) -> bool:
    """Report whether the table exists in the configured database.

    Args:
        table_name (str): Table to look for.

    Returns:
        bool: True when ``information_schema.tables`` lists the table;
        False when it does not or when the lookup itself fails.
    """
    sql = """
        SELECT COUNT(*) as count
        FROM `information_schema`.`tables`
        WHERE `table_schema` = %s AND `table_name` = %s
    """
    params = (self.config['database'], table_name)
    try:
        rows = self.execute_sql(sql, params, fetch=True)
        found = rows[0]['count'] > 0
    except Exception:
        return False
    try:
        self.log.debug("Checked table existence",
                       table=table_name,
                       exists=found)
    except Exception:
        return False
    return found
def drop_table(self, table_name: str) -> bool:
    """Drop a table if it exists.

    Args:
        table_name (str): Table to drop.

    Returns:
        bool: True when the table was dropped; False when it does not
        exist or the drop failed (the failure is logged, not raised).
    """
    if self.table_exists(table_name):
        try:
            self.execute_sql(f"DROP TABLE {table_name}")
            self.log.info("Table dropped successfully", table=table_name)
            return True
        except Exception as e:
            self.log.error("Failed to drop table",
                           table=table_name,
                           error=str(e),
                           exc_info=True)
            return False
    self.log.warning("Table does not exist", table=table_name)
    return False
def get_pool_status(self) -> Dict[str, int]:
    """Return a snapshot of the connection pool's counters.

    Reads the pool's private attributes, so the figures are best-effort.

    Returns:
        Dict[str, int]: ``max`` (configured limit), ``active``
        (connections currently counted by the pool), ``idle`` (size of
        the idle cache) and ``shared`` (size of the shared cache, 0 when
        the pool has none).
    """
    pool = self._pool
    return {
        'max': pool._maxconnections,
        'active': pool._connections,
        'idle': len(pool._idle_cache),
        # The pool only has a shared cache when connection sharing is
        # enabled; this class never enables it, so fall back to empty.
        'shared': len(getattr(pool, '_shared_cache', ()))
    }
def validate_connection(self) -> bool:
    """Check that a pooled connection can run a trivial query.

    Returns:
        bool: True when ``SELECT 1`` comes back as 1; False on any
        failure.
    """
    try:
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                # Alias the column: the pool is configured with
                # DictCursor, so rows are dicts keyed by column label and
                # positional access (row[0]) raises KeyError.
                cursor.execute("SELECT 1 AS ok")
                row = cursor.fetchone()
        if row is None:
            return False
        # Tolerate tuple rows too, in case a different cursor class is in use.
        value = row['ok'] if isinstance(row, dict) else row[0]
        return value == 1
    except Exception:
        return False
def __del__(self):
    """Close the connection pool when the agent is garbage-collected."""
    # __init__ may have failed before the pool was created.
    if not hasattr(self, '_pool'):
        return
    try:
        self._pool.close()
        self.log.info("Connection pool closed")
    except Exception as e:
        self.log.error("Failed to close pool", error=str(e))
# Platform-specific default configuration
def get_default_config():
    """Return the default connection settings for the current platform.

    All platforms share the same host/port/credentials. Windows uses
    timeouts of 10/30/30 seconds, every other platform 15/60/60; macOS
    additionally gets an ``ssl`` entry pointing at a CA bundle.
    """
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence',
        'max_connections': 5
    }
    system_name = platform.system()
    if system_name == 'Windows':
        settings.update(connect_timeout=10, read_timeout=30, write_timeout=30)
        return settings

    # macOS, Linux and everything else share the longer timeouts.
    settings.update(connect_timeout=15, read_timeout=60, write_timeout=60)
    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
if __name__ == "__main__":
    # Usage example: connect with the platform defaults and run a smoke test.
    db = MySQLAgent(get_default_config())
    if not db.validate_connection():
        print("Failed to connect to database")
    else:
        print("Database connection successful")
        # Database version.
        version_df = db.query_to_df("SELECT VERSION() as version")
        print(f"Database version: {version_df['version'].iloc[0]}")
        # Connection pool status.
        print("Connection pool status:", db.get_pool_status())
-2
View File
@@ -1,2 +0,0 @@
# Makes system_management a package
-3
View File
@@ -1,3 +0,0 @@
# Makes system_management.scheduler a package
from .task_scheduler import TaskScheduler
@@ -1,190 +0,0 @@
import argparse
from datetime import datetime
from system_management.scheduler.task_scheduler import TaskScheduler
from system_management.scheduler.task_scheduler import TaskManager
from config import Config
from utils.logger import CrossPlatformLog
# 初始化日志
log = CrossPlatformLog.get_logger("TaskManagement")
def _build_parser():
    """Build the argument parser with one sub-command per task operation."""
    parser = argparse.ArgumentParser(description="任务管理工具")
    subparsers = parser.add_subparsers(dest="command", help="可用命令")

    # list
    list_parser = subparsers.add_parser("list", help="列出所有任务")
    list_parser.add_argument("--active-only", action="store_true", help="只显示活跃任务")

    # show
    show_parser = subparsers.add_parser("show", help="显示任务详情")
    show_parser.add_argument("task_id", type=int, help="任务ID")

    # update
    update_parser = subparsers.add_parser("update", help="更新任务属性")
    update_parser.add_argument("task_id", type=int, help="任务ID")
    update_parser.add_argument("--name", help="任务名称")
    update_parser.add_argument("--type", help="任务类型")
    update_parser.add_argument("--module", help="模块路径")
    update_parser.add_argument("--cron", help="Cron表达式")
    update_parser.add_argument("--timezone", help="时区")

    # toggle
    toggle_parser = subparsers.add_parser("toggle", help="启用/禁用任务")
    toggle_parser.add_argument("task_id", type=int, help="任务ID")
    toggle_parser.add_argument("--activate", action="store_true", help="启用任务")
    toggle_parser.add_argument("--deactivate", action="store_true", help="禁用任务")

    # delete
    delete_parser = subparsers.add_parser("delete", help="删除任务")
    delete_parser.add_argument("task_id", type=int, help="任务ID")

    # run
    run_parser = subparsers.add_parser("run", help="手动执行任务")
    run_parser.add_argument("task_id", type=int, help="任务ID")

    # add
    add_parser = subparsers.add_parser("add", help="添加新任务")
    add_parser.add_argument("--name", required=True, help="任务名称")
    add_parser.add_argument("--type", required=True, help="任务类型")
    add_parser.add_argument("--module", required=True, help="模块路径")
    add_parser.add_argument("--cron", required=True, help="Cron表达式")
    add_parser.add_argument("--timezone", default="Asia/Shanghai", help="时区")
    return parser


def _handle_list(manager, scheduler, args):
    """Print the task table, optionally restricted to active tasks."""
    try:
        tasks = manager.get_all_tasks(args.active_only)
        manager.print_task_table(tasks)
        log.info(f"列出任务完成,共{len(tasks)}个任务")
    except Exception as e:
        log.error(f"列出任务失败: {str(e)}", exc_info=True)


def _handle_show(manager, scheduler, args):
    """Print every field of one task."""
    try:
        task = manager.get_task_by_id(args.task_id)
        if task:
            print("\n===== 任务详情 =====")
            for key, value in task.items():
                print(f"{key}: {value}")
            print("====================")
            log.info(f"显示任务详情成功,任务ID: {args.task_id}")
        else:
            log.warning(f"未找到任务ID: {args.task_id}")
            print(f"任务ID {args.task_id} 不存在")
    except Exception as e:
        log.error(f"显示任务详情失败,任务ID: {args.task_id}", exc_info=True)


def _handle_update(manager, scheduler, args):
    """Update the task fields that were supplied on the command line."""
    try:
        # CLI option name -> task field name, in the original order.
        field_for_option = (
            ('name', 'task_name'),
            ('type', 'task_type'),
            ('module', 'module_path'),
            ('cron', 'cron_expression'),
            ('timezone', 'time_zone'),
        )
        updates = {
            field: getattr(args, option)
            for option, field in field_for_option
            if getattr(args, option)
        }
        if not updates:
            log.warning("未提供任何更新字段")
            print("请至少指定一个更新字段")
            return
        if manager.update_task(args.task_id, updates):
            log.info(f"任务ID {args.task_id} 更新成功")
            print(f"任务ID {args.task_id} 更新成功")
        else:
            log.warning(f"任务ID {args.task_id} 更新失败")
            print(f"任务ID {args.task_id} 更新失败")
    except Exception as e:
        log.error(f"更新任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_toggle(manager, scheduler, args):
    """Enable or disable a task; exactly one of the two flags is required."""
    try:
        if args.activate and args.deactivate:
            log.warning("不能同时指定 --activate 和 --deactivate")
            print("不能同时指定 --activate 和 --deactivate")
            return
        if not args.activate and not args.deactivate:
            log.warning("请指定 --activate 或 --deactivate")
            print("请指定 --activate 或 --deactivate")
            return
        if args.activate:
            success = manager.toggle_task_status(args.task_id, True)
            action = "启用"
        else:
            success = manager.toggle_task_status(args.task_id, False)
            action = "禁用"
        if success:
            log.info(f"任务ID {args.task_id} {action}成功")
            print(f"任务ID {args.task_id} {action}成功")
        else:
            log.warning(f"任务ID {args.task_id} {action}失败")
            print(f"任务ID {args.task_id} {action}失败")
    except Exception as e:
        log.error(f"切换任务状态失败,任务ID: {args.task_id}", exc_info=True)


def _handle_delete(manager, scheduler, args):
    """Delete a task after an interactive y/n confirmation."""
    try:
        confirm = input(f"确定要删除任务ID {args.task_id} 吗? (y/n) ")
        if confirm.lower() == 'y':
            if manager.delete_task(args.task_id):
                log.info(f"任务ID {args.task_id} 删除成功")
                print(f"任务ID {args.task_id} 删除成功")
            else:
                log.warning(f"任务ID {args.task_id} 删除失败")
                print(f"任务ID {args.task_id} 删除失败")
        else:
            log.info(f"用户取消删除任务ID {args.task_id}")
            print("操作已取消")
    except Exception as e:
        log.error(f"删除任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_run(manager, scheduler, args):
    """Run one task immediately."""
    try:
        log.info(f"开始手动执行任务ID {args.task_id}")
        print(f"正在手动执行任务ID {args.task_id}...")
        if manager.run_task_manually(args.task_id):
            log.info(f"任务ID {args.task_id} 执行成功")
            print(f"任务ID {args.task_id} 执行成功")
        else:
            log.warning(f"任务ID {args.task_id} 执行失败")
            print(f"任务ID {args.task_id} 执行失败")
    except Exception as e:
        log.error(f"手动执行任务失败,任务ID: {args.task_id}", exc_info=True)


def _handle_add(manager, scheduler, args):
    """Register a new task through the scheduler."""
    try:
        task_id = scheduler.add_task(
            task_name=args.name,
            task_type=args.type,
            module_path=args.module,
            cron_expression=args.cron,
            time_zone=args.timezone
        )
        log.info(f"新任务添加成功,ID: {task_id}")
        print(f"新任务添加成功,ID: {task_id}")
    except Exception as e:
        log.error(f"添加任务失败: {str(e)}", exc_info=True)
        print(f"添加任务失败: {str(e)}")


def main():
    """Command-line entry point for managing scheduled tasks.

    Arguments are parsed before the scheduler (and its database
    connection) is created, so ``--help``, usage errors and a missing
    sub-command do not need a reachable database.
    """
    parser = _build_parser()
    args = parser.parse_args()

    handlers = {
        "list": _handle_list,
        "show": _handle_show,
        "update": _handle_update,
        "toggle": _handle_toggle,
        "delete": _handle_delete,
        "run": _handle_run,
        "add": _handle_add,
    }
    handler = handlers.get(args.command)
    if handler is None:
        # No sub-command was given.
        parser.print_help()
        return

    # Build the scheduler and manager only once there is work to do.
    scheduler = TaskScheduler(Config.MYSQL_CONFIG)
    manager = TaskManager(scheduler)
    handler(manager, scheduler, args)


if __name__ == "__main__":
    main()
+215 -422
View File
@@ -1,484 +1,277 @@
# system_management/scheduler/task_scheduler.py
import importlib
import threading
import time
from datetime import datetime
from typing import Dict, List, Optional, Any
import croniter
import pytz
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
from sqlalchemy.exc import SQLAlchemyError
from utils.mysql_agent import MySQLAgent
from utils.logger import CrossPlatformLog
from datetime import datetime, timedelta
from typing import Dict, List, Optional
from storage.mysql_agent import MySQLAgent
from pathlib import Path
# 初始化调度器日志
# 使用您的日志系统
from utils.logger import CrossPlatformLog
log = CrossPlatformLog.get_logger("TaskScheduler")
class TaskScheduler:
def __init__(self, db_config: Optional[Dict] = None, max_workers: int = 5):
"""初始化任务调度器(基于Cron表达式)"""
self.db = MySQLAgent(db_config or {})
self.executor = ThreadPoolExecutor(max_workers=max_workers)
# 并发容量控制:限制同时运行的后台任务不超过 max_workers
self._running_semaphore = threading.Semaphore(max_workers)
# 任务统计
self.hourly_stats = {'成功': 0, '失败': 0, '总数': 0}
self.hourly_stats_lock = threading.Lock()
log.info(f"任务调度器已初始化,最大工作线程数: {max_workers}")
def _resolve_callable(self, module_path: str):
"""解析模块路径,支持模块、模块内类/函数,并返回可调用对象
兼容以下形式:
- package.module -> 期望模块内存在 main()
- package.module.ClassName -> 调用 ClassName.main() 或实例化后调用 main()
- package.module.func_name -> 直接调用该函数
- package.module.ClassName.method_name -> 调用指定方法
def __init__(self, db_config: Optional[Dict] = None):
"""
if not module_path or not isinstance(module_path, str):
raise ImportError("无效的模块路径")
初始化任务调度器
parts = module_path.split('.')
last_import_error = None
# 从最长前缀开始尝试导入模块,逐步回退
for i in range(len(parts), 0, -1):
module_name = '.'.join(parts[:i])
try:
module = importlib.import_module(module_name)
attr_chain = parts[i:]
# 从模块开始逐级解析属性
target = module
for attr in attr_chain:
if not hasattr(target, attr):
raise AttributeError(f"{target} 中未找到属性: {attr}")
target = getattr(target, attr)
# 若目标是类,优先尝试类方法/实例方法 main
if isinstance(target, type):
# 类方法 main
if hasattr(target, 'main') and callable(getattr(target, 'main')):
return getattr(target, 'main')
# 实例方法 main
try:
instance = target()
if hasattr(instance, 'main') and callable(getattr(instance, 'main')):
return getattr(instance, 'main')
except Exception:
pass
# 不把“类本身”当作任务入口(否则只会构造实例不执行 main)
raise AttributeError(f"{target.__name__} 缺少可调用的 main() 作为任务入口")
# 目标非类:若本身可调用,则直接作为入口返回
if callable(target):
return target
# 否则尝试对象上的 main()
if hasattr(target, 'main') and callable(getattr(target, 'main')):
return getattr(target, 'main')
raise AttributeError(f"路径 {module_path} 未解析到可调用入口(缺少 main 或不可调用)")
except Exception as e:
last_import_error = e
continue
# 如果所有尝试均失败,则抛出最后的错误
raise ImportError(f"模块 {module_path} 导入/解析失败: {str(last_import_error)}")
def check_and_run_tasks(self, print_empty_status: bool = False) -> Dict[str, int]:
"""检查并执行所有到期的任务,优化空任务处理和异常容错
Args:
print_empty_status: 是否打印空任务状态(默认False,避免频繁输出)
db_config (Optional[Dict]): 可选的数据库配置,默认使用MySQLAgent默认配置
"""
result = {'总任务数': 0, '成功': 0, '失败': 0}
self.db = MySQLAgent(db_config or {}) # 使用您提供的MySQLAgent
self._init_task_table()
log.info("TaskScheduler initialized")
def _init_task_table(self):
"""确保任务表存在并包含必要字段"""
if not self.db.table_exists("main_task"):
log.info("Creating main_task table")
create_sql = """
CREATE TABLE main_task (
task_id INT AUTO_INCREMENT PRIMARY KEY,
task_name VARCHAR(100) NOT NULL,
module_path VARCHAR(255) NOT NULL COMMENT '例如data_collection.spiders.weibo_spider',
frequency_type ENUM('minute','hourly','daily','weekly','monthly') NOT NULL,
frequency_value INT DEFAULT NULL COMMENT '间隔数值',
last_run_time DATETIME DEFAULT NULL,
next_run_time DATETIME DEFAULT NULL,
last_run_status VARCHAR(20) DEFAULT NULL,
is_active TINYINT(1) DEFAULT 1,
is_running TINYINT(1) DEFAULT 0,
run_count INT DEFAULT 0,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_next_run (next_run_time),
INDEX idx_active (is_active)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
"""
self.db.execute_sql(create_sql)
log.success("main_task table created")
def run_pending_tasks(self) -> Dict[str, int]:
"""
执行所有到期的活跃任务
Returns:
Dict[str, int]: 包含执行结果的字典 {
'total': 总任务数,
'success': 成功数,
'failed': 失败数
}
"""
result = {'total': 0, 'success': 0, 'failed': 0}
try:
# 获取当前时间(带时区转换为本地时间)
tz = pytz.timezone('Asia/Shanghai')
now = datetime.now(tz).replace(tzinfo=None) # 移除时区信息,与数据库存储一致
log.debug(f"当前检查时间: {now.strftime('%Y-%m-%d %H:%M:%S')}")
# 使用您提供的query_to_df方法获取任务
tasks_df = self.db.query_to_df(
"SELECT * FROM main_task "
"WHERE is_active = 1 AND next_run_time <= %s "
"ORDER BY next_run_time",
params=(datetime.now(),)
)
# 查询所有到期的活跃任务(使用参数化查询防止注入)
tasks_df = self.db.query_to_df("""
SELECT *
FROM main_task
WHERE is_active = 1
AND next_run_time <= %s
AND is_running = 0
ORDER BY next_run_time
""", params=(now,),is_print=False)
result['total'] = len(tasks_df)
result['总任务数'] = len(tasks_df)
if tasks_df.empty:
# 空任务时根据参数决定是否输出
if print_empty_status:
print(f"当前没有到期的任务,等待新任务加入...{now.strftime('%Y-%m-%d %H:%M:%S')}")
log.debug("No pending tasks found")
return result
# 并发执行任务
futures = []
for _, task in tasks_df.iterrows():
# 传递任务字典的副本避免线程安全问题
task_copy = task.to_dict()
futures.append(self.executor.submit(self._process_single_task, task_copy))
task_id = task['task_id']
log.bind(task_id=task_id).info(
f"Starting task {task['task_name']}"
)
# 标记任务为执行中
self._update_task_status(
task_id,
{'is_running': 1, 'last_run_time': datetime.now()}
)
# 收集执行结果
for future in as_completed(futures):
try:
if future.result():
result['成功'] += 1
else:
result['失败'] += 1
self._execute_single_task(task)
self._update_task_status(
task_id,
{
'last_run_status': 'success',
'is_running': 0,
'run_count': task['run_count'] + 1,
'next_run_time': self._calculate_next_run(
task['frequency_type'],
task['frequency_value']
)
}
)
result['success'] += 1
log.bind(task_id=task_id).success("Task completed")
except Exception as e:
log.error(f"任务线程执行失败: {str(e)}", exc_info=True)
result['失败'] += 1
# 更新小时统计
with self.hourly_stats_lock:
self.hourly_stats['成功'] += result['成功']
self.hourly_stats['失败'] += result['失败']
self.hourly_stats['总数'] += result['总任务数']
log.bind(task_id=task_id).error(
f"Task failed: {str(e)}",
exc_info=True
)
self._update_task_status(
task_id,
{
'last_run_status': 'failed',
'is_running': 0,
'next_run_time': self._calculate_next_run(
task['frequency_type'],
task['frequency_value'],
retry=True
)
}
)
result['failed'] += 1
log.info(
"任务调度周期完成",
总任务数=result['总任务数'],
成功=result['成功'],
失败=result['失败']
"Scheduler cycle completed",
total_tasks=result['total'],
success=result['success'],
failed=result['failed']
)
return result
except SQLAlchemyError as e: # 数据库异常处理优化
log.error(f"数据库操作失败,将在下次轮询重试: {str(e)}", exc_info=True)
return result # 不中断,返回当前结果
except Exception as e:
log.error("调度器周期执行异常,将在下次轮询重试", exc_info=True)
return result # 不中断主循环,允许下次重试
log.critical(
"Scheduler main loop failed",
exc_info=True
)
raise
def _process_single_task(self, task: Dict[str, Any]) -> bool:
"""处理单个任务(线程安全)"""
task_id = task['task_id']
task_name = task['task_name']
task_log = log.bind(task_id=task_id, task_name=task_name)
task_log.info(f"开始执行任务: {task_name}")
def _execute_single_task(self, task: Dict) -> None:
"""执行单个任务模块"""
start_time = time.time()
task_log = log.bind(
task_id=task['task_id'],
module=task['module_path']
)
try:
# 阻塞等待可用的执行槽位,保证同时运行的任务不超过最大工作线程数
self._running_semaphore.acquire()
module = importlib.import_module(task['module_path'])
# 标记任务为运行中(使用当前时间的时区感知对象)
tz = pytz.timezone(task.get('time_zone', 'Asia/Shanghai'))
current_time = datetime.now(tz).replace(tzinfo=None)
if not hasattr(module, 'main'):
raise ImportError(f"Module has no main() function")
self._update_task_status(task_id, {
'is_running': 1,
'last_run_time': current_time
})
# 执行任务
task_log.debug("Task execution started")
module.main()
# 将任务主体放到后台线程执行,当前线程快速返回
self.executor.submit(self._run_task_async, task.copy())
task_log.debug("任务已提交至后台执行队列")
return True # 表示已成功提交
except Exception as e:
task_log.error(f"任务执行失败: {str(e)}", exc_info=True)
# 失败时计算下次重试时间(15分钟后)
next_retry_time = datetime.now() + pd.Timedelta(minutes=15)
# 即使任务执行失败,也要确保状态更新
try:
self._update_task_status(task_id, {
'last_run_status': 'failed',
'is_running': 0,
'next_run_time': next_retry_time
})
except Exception as update_err:
task_log.error(f"任务失败后状态更新失败: {str(update_err)}", exc_info=True)
# 若已占用并发槽位,释放之
try:
self._running_semaphore.release()
except Exception:
pass
return False
def _run_task_async(self, task: Dict[str, Any]) -> None:
"""在后台线程中执行任务主体,并在结束后更新状态"""
task_id = task['task_id']
task_name = task['task_name']
task_log = log.bind(task_id=task_id, task_name=task_name)
try:
# 如果 module_path 指向类,先实例化以触发初始化日志,然后执行 main
self._execute_task_logic(task)
# 成功后计算下次运行时间
next_run_time = self._calculate_next_run_time(
cron_expr=task['cron_expression'],
time_zone=task.get('time_zone', 'Asia/Shanghai')
elapsed = time.time() - start_time
task_log.info(
f"Task completed in {elapsed:.2f}s",
duration=elapsed
)
self._update_task_status(task_id, {
'last_run_status': 'success',
'is_running': 0,
'run_count': task['run_count'] + 1,
'next_run_time': next_run_time
})
task_log.info(f"任务执行成功: {task_name}")
except Exception:
task_log.error("任务后台执行失败", exc_info=True)
next_retry_time = datetime.now() + pd.Timedelta(minutes=15)
try:
self._update_task_status(task_id, {
'last_run_status': 'failed',
'is_running': 0,
'next_run_time': next_retry_time
})
except Exception:
task_log.error("任务失败后状态更新失败(后台)", exc_info=True)
finally:
# 释放并发槽位
try:
self._running_semaphore.release()
except Exception:
pass
def _execute_task_logic(self, task):
"""
执行任务逻辑的核心方法
支持类方法、静态方法和实例方法的调用
"""
module_path = task.get('module_path')
if not module_path:
raise ValueError("任务缺少 module_path 配置")
# 解析模块路径和类名
try:
path_parts = module_path.split('.')
if len(path_parts) < 2:
raise ValueError(f"无效的模块路径: {module_path}")
module_name = '.'.join(path_parts[:-1])
class_name = path_parts[-1]
method_name = 'main' # 默认方法名
except Exception as e:
raise ValueError(f"解析模块路径失败: {str(e)}")
# 动态导入模块
try:
import importlib
module = importlib.import_module(module_name)
except ImportError as e:
raise ImportError(f"无法导入模块 {module_name}: {str(e)}")
# 获取类和方法
if not hasattr(module, class_name):
raise AttributeError(f"模块 {module_name} 中未找到类 {class_name}")
cls = getattr(module, class_name)
# 检查是否存在指定方法
if not hasattr(cls, method_name):
raise AttributeError(f"{class_name} 中未找到方法 {method_name}")
method = getattr(cls, method_name)
# 根据方法类型决定如何调用
import inspect
callable_entry = None
# 判断是否为静态方法或类方法
if isinstance(method, staticmethod):
# 静态方法可以直接调用
callable_entry = method
elif isinstance(method, classmethod):
# 类方法需要传入类作为第一个参数
callable_entry = method
else:
# 实例方法或普通函数
try:
# 尝试检查方法签名
sig = inspect.signature(method)
params = list(sig.parameters.values())
# 如果第一个参数是self且没有默认值,则认为是实例方法
if params and params[0].name == 'self' and params[0].default == inspect.Parameter.empty:
# 创建实例并获取绑定方法
instance = cls()
callable_entry = getattr(instance, method_name)
else:
# 可能是普通函数或者是带有默认self参数的方法
callable_entry = method
except Exception:
# 如果检查签名失败,默认尝试创建实例
try:
instance = cls()
callable_entry = getattr(instance, method_name)
except Exception:
# 如果创建实例也失败,则直接调用方法(适用于不需要self的特殊情况)
callable_entry = method
# 执行任务
if not callable(callable_entry):
raise TypeError(f"{module_path}.{method_name} 不是可调用对象")
try:
# 执行任务逻辑
callable_entry()
except Exception as e:
self.logger.error(f"任务逻辑执行失败: {str(e)}")
task_log.error(
"Task execution failed",
exc_info=True
)
raise
def _calculate_next_run_time(self, cron_expr: str, time_zone: str = 'Asia/Shanghai') -> datetime:
    """Compute the next run time from a Cron expression.

    The current time is taken in ``time_zone``, the next Cron occurrence
    after it is computed, and the result is returned as a naive datetime
    (wall-clock time in ``time_zone``) for database storage.

    Args:
        cron_expr: Cron expression handed to ``croniter``.
        time_zone: Time zone name handed to ``pytz.timezone``.

    Returns:
        The next occurrence with ``tzinfo`` removed.

    Raises:
        ValueError: The time zone or the Cron expression cannot be used; the
            underlying error is attached as the cause.
    """
    # Resolve the time zone on its own so a bad zone name is not reported as
    # a bad Cron expression.
    try:
        tz = pytz.timezone(time_zone)
    except Exception as e:
        log.error(f"时区解析失败: {time_zone}, 错误: {str(e)}")
        raise ValueError(f"无效的时区: {time_zone}") from e

    try:
        now = datetime.now(tz)  # current time in the task's own time zone
        next_run = croniter.croniter(cron_expr, now).get_next(datetime)
        return next_run.replace(tzinfo=None)  # strip tzinfo for database storage
    except Exception as e:
        log.error(f"Cron表达式解析失败: {cron_expr}, 错误: {str(e)}")
        raise ValueError(f"无效的Cron表达式: {cron_expr}") from e
def _update_task_status(self, task_id: int, updates: Dict) -> None:
"""更新任务状态"""
set_clause = ", ".join([f"{k}=%s" for k in updates.keys()])
sql = f"UPDATE main_task SET {set_clause} WHERE task_id=%s"
def _update_task_status(self, task_id: int, updates: Dict[str, Any]) -> None:
"""更新任务状态到数据库(适配SQLAlchemy的参数传递方式)"""
if not updates:
log.warning(f"任务ID {task_id} 未提供任何更新字段")
return
# 构建UPDATE语句(确保字段名安全)
valid_fields = {'is_running', 'last_run_time', 'last_run_status',
'run_count', 'next_run_time', 'updated_at'}
filtered_updates = {k: v for k, v in updates.items() if k in valid_fields}
if not filtered_updates:
log.warning(f"任务ID {task_id} 没有有效的更新字段")
return
set_clause = ", ".join([f"{k}=%s" for k in filtered_updates.keys()])
sql = f"UPDATE main_task SET {set_clause}, updated_at=NOW() WHERE task_id=%s"
params = list(filtered_updates.values()) + [task_id]
params = list(updates.values()) + [task_id]
try:
# 执行更新并获取受影响的行数
affected_rows = self.db.execute_sql(sql, params=params)
if affected_rows != 1:
affected = self.db.execute_sql(sql, params=params)
if affected != 1:
log.warning(
"任务状态更新异常",
"Unexpected row count in update",
task_id=task_id,
预期影响行数=1,
实际影响行数=affected_rows
expected=1,
affected=affected
)
except SQLAlchemyError as e:
log.error(f"任务状态更新失败(数据库错误),task_id: {task_id}", exc_info=True)
raise
except Exception as e:
log.error(f"任务状态更新失败,task_id: {task_id}", exc_info=True)
log.error(
"Failed to update task status",
task_id=task_id,
exc_info=True
)
raise
def add_task(self,
task_name: str,
task_type: str,
module_path: str,
cron_expression: str,
time_zone: str = 'Asia/Shanghai') -> int:
"""添加新的Cron任务"""
if not cron_expression:
raise ValueError("Cron表达式不能为空")
def _calculate_next_run(self, freq_type: str, freq_value: Optional[int] = None,
retry: bool = False) -> datetime:
"""
计算下次执行时间(带重试逻辑)
"""
base_time = datetime.now()
# 验证模块路径可解析(提前检查,避免添加无效任务)
try:
_ = self._resolve_callable(module_path)
except Exception as e:
raise ValueError(f"模块路径不可用: {module_path},错误: {str(e)}")
if retry:
# 失败后15分钟重试
log.debug("Calculating retry time")
return base_time + timedelta(minutes=15)
# 计算首次运行时间
first_run_time = self._calculate_next_run_time(cron_expression, time_zone)
if freq_type == 'minute':
delta = timedelta(minutes=freq_value or 1)
elif freq_type == 'hourly':
delta = timedelta(hours=freq_value or 1)
elif freq_type == 'daily':
delta = timedelta(days=freq_value or 1)
elif freq_type == 'weekly':
delta = timedelta(weeks=freq_value or 1)
elif freq_type == 'monthly':
# 处理月末日期特殊情况
next_month = (base_time.replace(day=1) + timedelta(days=32)).replace(day=1)
last_day = (next_month - timedelta(days=1)).day
day = min(base_time.day, last_day)
return base_time.replace(day=1, month=next_month.month, day=day)
else:
raise ValueError(f"Unknown frequency type: {freq_type}")
# 插入数据库
return base_time + delta
def add_task(self, task_name: str, module_path: str, frequency_type: str,
frequency_value: Optional[int] = None) -> int:
"""
添加新任务到调度系统
"""
sql = """
INSERT INTO main_task
(task_name, task_type, module_path, cron_expression, time_zone,
next_run_time, is_active, created_at, updated_at)
VALUES (%s, %s, %s, %s, %s, %s, 1, NOW(), NOW()) \
"""
params = (task_name, task_type, module_path, cron_expression, time_zone, first_run_time)
INSERT INTO main_task
(task_name, module_path, frequency_type, frequency_value, next_run_time)
VALUES (%s, %s, %s, %s, %s)
"""
next_run = self._calculate_next_run(frequency_type, frequency_value)
params = (task_name, module_path, frequency_type, frequency_value, next_run)
try:
self.db.execute_sql(sql, params=params)
# 获取插入的任务ID
result_df = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id")
if result_df.empty or 'id' not in result_df.columns:
raise ValueError("无法获取新添加任务的ID")
task_id = result_df.iloc[0]['id']
task_id = self.db.query_to_df("SELECT LAST_INSERT_ID() AS id").iloc[0]['id']
log.info(
"新任务添加成功",
"New task added",
task_id=task_id,
task_name=task_name,
cron表达式=cron_expression,
首次运行时间=first_run_time.strftime('%Y-%m-%d %H:%M:%S')
next_run=next_run
)
return task_id
except SQLAlchemyError as e:
log.error(f"添加任务失败(数据库错误): {task_name}", exc_info=True)
raise
except Exception as e:
log.error(f"添加任务失败: {task_name}", exc_info=True)
log.error(
"Failed to add new task",
task_name=task_name,
exc_info=True
)
raise
def get_pending_tasks_count(self) -> int:
    """Count the tasks that are due to run (used for graceful shutdown).

    A task is counted when it is active, not currently running, and its
    ``next_run_time`` is not later than the current Asia/Shanghai wall-clock
    time.

    Returns:
        The number of due tasks as a built-in ``int``. ``0`` is returned
        when the query yields no rows or when any error occurs, so that a
        failing lookup never blocks the shutdown sequence.
    """
    try:
        tz = pytz.timezone('Asia/Shanghai')
        # Naive Shanghai wall-clock time, comparable with the stored values.
        now = datetime.now(tz).replace(tzinfo=None)
        sql = (
            "\n"
            "SELECT COUNT(*) as cnt\n"
            "FROM main_task\n"
            "WHERE is_active = 1\n"
            "AND next_run_time <= %s\n"
            "AND is_running = 0\n"
        )
        df = self.db.query_to_df(sql, params=(now,))
        if df.empty:
            return 0
        # The DataFrame cell is a NumPy scalar; convert it so the declared
        # return type actually holds.
        return int(df['cnt'].iloc[0])
    except Exception as e:
        log.error(f"查询待执行任务数量失败: {str(e)}", exc_info=True)
        return 0  # report zero on failure so shutdown is not affected
def get_pending_tasks(self) -> List[Dict[str, Any]]:
    """Fetch every task that is due to run, ordered by ``next_run_time``.

    A task is due when it is active, not currently running, and its
    ``next_run_time`` is not later than the current Asia/Shanghai wall-clock
    time.

    Returns:
        One dict per due task (all ``main_task`` columns). An empty list is
        returned when nothing is due or when the lookup fails; failures are
        logged and never raised.
    """
    try:
        shanghai = pytz.timezone('Asia/Shanghai')
        cutoff = datetime.now(shanghai).replace(tzinfo=None)
        query = (
            "\n"
            "SELECT *\n"
            "FROM main_task\n"
            "WHERE is_active = 1\n"
            "AND next_run_time <= %s\n"
            "AND is_running = 0\n"
            "ORDER BY next_run_time\n"
        )
        due_tasks = self.db.query_to_df(query, params=(cutoff,))

        if not due_tasks.empty:
            log.info(f"查询到{len(due_tasks)}个待执行任务")
            return due_tasks.to_dict('records')

        log.info("当前任务列表为空,等待新任务加入...")
        return []
    except Exception as e:
        log.error(f"查询待执行任务失败,将重试: {str(e)}", exc_info=True)
        return []
def get_and_reset_hourly_stats(self) -> Dict[str, int]:
    """Return the hourly counters collected so far and start a new period.

    The snapshot and the reset happen under ``hourly_stats_lock`` so no
    update can slip in between the two steps.

    Returns:
        A copy of the counters as they were before the reset.
    """
    zeroed = {'成功': 0, '失败': 0, '总数': 0}
    with self.hourly_stats_lock:
        snapshot = self.hourly_stats.copy()
        # Install a fresh mapping instead of mutating the old one in place.
        self.hourly_stats = zeroed
    return snapshot
def get_task_status(self, active_only: bool = True) -> pd.DataFrame:
    """Load the scheduling status of tasks, ordered by ``next_run_time``.

    Args:
        active_only: When true, restrict the result to rows with
            ``is_active = 1``; otherwise return every task.

    Returns:
        The DataFrame produced by ``self.db.query_to_df`` with the identity,
        frequency, run-time, status and flag columns of ``main_task``.
    """
    if active_only:
        where = "WHERE is_active = 1"
    else:
        where = ""
    log.debug("Fetching task status", active_only=active_only)

    query = (
        "\n"
        "SELECT\n"
        "task_id, task_name, module_path,\n"
        "frequency_type, frequency_value,\n"
        "last_run_time, next_run_time,\n"
        "last_run_status, run_count,\n"
        "is_active, is_running\n"
        "FROM main_task\n"
        f"{where}\n"
        "ORDER BY next_run_time\n"
    )
    return self.db.query_to_df(query)
-1
View File
@@ -1 +0,0 @@
print("Hello, World!")
-171
View File
@@ -1,171 +0,0 @@
import unittest
import os
import tempfile
import hashlib
from datetime import datetime
from utils.minio_agent import MinIOAgent # 导入之前的MinIO操作类
class TestMinIOAgent(unittest.TestCase):
    """Integration tests for ``MinIOAgent`` against a local MinIO server.

    Every test talks to the server described by ``MINIO_CONFIG``. One
    timestamped bucket is created for the whole class and removed again in
    ``tearDownClass``. Resources created by a single test are registered with
    ``addCleanup`` so they are released even when an assertion fails.
    """

    # Test configuration - local MinIO community edition
    MINIO_CONFIG = {
        'endpoint': '127.0.0.1:9005',
        'access_key': 'admin',  # default account
        'secret_key': 'abc88888888',  # default password
        'secure': False  # the community edition does not enable SSL by default
    }

    @classmethod
    def setUpClass(cls):
        """Create the agent and the bucket shared by all tests."""
        # A timestamped bucket name avoids clashes between runs.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        cls.test_bucket = f"test-bucket-{timestamp}"
        cls.test_object = "test-data/sample.txt"
        cls.test_content = b"this is MinIO test data: 1234567890"

        cls.minio_agent = MinIOAgent(cls.MINIO_CONFIG)

        # Make sure the shared bucket exists before any test runs.
        cls.minio_agent.create_bucket(cls.test_bucket)

    @classmethod
    def tearDownClass(cls):
        """Delete every object in the shared bucket, then the bucket itself."""
        try:
            objects = cls.minio_agent.list_objects(cls.test_bucket)
            for obj in objects:
                cls.minio_agent.delete_object(cls.test_bucket, obj['object_name'])

            # MinIO only removes a bucket once it is empty.
            cls.minio_agent._client.remove_bucket(cls.test_bucket)
            print(f"\n测试清理完成,已删除桶: {cls.test_bucket}")
        except Exception as e:
            # Best effort: a failed cleanup must not hide the test results.
            print(f"清理测试环境失败: {str(e)}")

    def _drop_bucket_if_present(self, bucket):
        """Remove ``bucket`` when it exists; used as a per-test cleanup."""
        client = self.minio_agent._client
        if client.bucket_exists(bucket):
            client.remove_bucket(bucket)

    @staticmethod
    def _remove_file_if_present(path):
        """Delete ``path`` when it exists; used as a per-test cleanup."""
        if os.path.exists(path):
            os.unlink(path)

    def test_01_create_bucket(self):
        """Creating a bucket succeeds and the bucket is visible afterwards."""
        new_bucket = f"temp-bucket-{datetime.now().microsecond}"
        # Registered before the call so the bucket never outlives the test.
        self.addCleanup(self._drop_bucket_if_present, new_bucket)

        result = self.minio_agent.create_bucket(new_bucket)
        self.assertTrue(result, "存储桶创建失败")

        exists = self.minio_agent._client.bucket_exists(new_bucket)
        self.assertTrue(exists, "存储桶创建后未检测到存在")

    def test_02_upload_download(self):
        """Uploaded bytes can be downloaded again unchanged."""
        upload_meta = self.minio_agent.upload_bytes(
            bucket=self.test_bucket,
            object_name=self.test_object,
            data=self.test_content
        )

        self.assertEqual(upload_meta['size'], len(self.test_content), "上传数据大小不匹配")
        self.assertEqual(upload_meta['local_hash'], hashlib.md5(self.test_content).hexdigest(), "本地哈希校验失败")

        # Download into a temporary file that is removed however the test ends.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_path = temp_file.name
        self.addCleanup(self._remove_file_if_present, temp_path)

        download_meta = self.minio_agent.download_file(
            bucket=self.test_bucket,
            object_name=self.test_object,
            local_path=temp_path
        )

        with open(temp_path, 'rb') as f:
            downloaded_content = f.read()

        self.assertEqual(downloaded_content, self.test_content, "下载数据与原始数据不匹配")
        self.assertEqual(download_meta['size'], len(self.test_content), "下载文件大小不匹配")

    def test_03_presigned_url(self):
        """A presigned URL points at the configured endpoint and expiry."""
        # The object must exist before a URL can be generated for it.
        self.minio_agent.upload_bytes(
            self.test_bucket,
            self.test_object,
            self.test_content
        )

        # Generate a URL that stays valid for 30 seconds.
        url_info = self.minio_agent.get_presigned_url(
            bucket=self.test_bucket,
            object_name=self.test_object,
            expires=30
        )

        self.assertIn("http://127.0.0.1:9005", url_info['presigned_url'], "预签名URL格式不正确")
        self.assertEqual(url_info['expires_in'], 30, "过期时间设置不正确")

    def test_04_list_objects(self):
        """Listing returns every uploaded object and honours the prefix filter."""
        test_objects = [
            "test-folder/file1.txt",
            "test-folder/file2.csv",
            "another-folder/image.jpg"
        ]

        # Count what earlier tests left behind so this test also passes when
        # it is run on its own.
        existing_count = len(self.minio_agent.list_objects(self.test_bucket))

        for obj in test_objects:
            self.minio_agent.upload_bytes(
                self.test_bucket,
                obj,
                b"tese_list_obj"
            )

        all_objects = self.minio_agent.list_objects(self.test_bucket)
        self.assertEqual(len(all_objects), existing_count + len(test_objects), "列出对象数量不匹配")

        filtered_objects = self.minio_agent.list_objects(
            self.test_bucket,
            prefix="test-folder/"
        )
        self.assertEqual(len(filtered_objects), 2, "按前缀筛选结果不正确")

    def test_05_delete_object(self):
        """A deleted object no longer shows up in a listing."""
        delete_obj = "to-delete/temp.txt"
        self.minio_agent.upload_bytes(
            self.test_bucket,
            delete_obj,
            b"will be delete"
        )

        result = self.minio_agent.delete_object(self.test_bucket, delete_obj)
        self.assertTrue(result, "删除对象失败")

        objects = self.minio_agent.list_objects(self.test_bucket, prefix="to-delete/")
        self.assertEqual(len(objects), 0, "对象删除后仍存在")

    def test_06_upload_empty_data(self):
        """Uploading empty data is rejected with ``ValueError``."""
        with self.assertRaises(ValueError, msg="未捕获空数据上传异常"):
            self.minio_agent.upload_bytes(
                self.test_bucket,
                "empty.txt",
                b""
            )
if __name__ == "__main__":
# 执行测试并显示详细结果
unittest.main(verbosity=2)
+115 -104
View File
@@ -1,22 +1,21 @@
import unittest
import pandas as pd
from datetime import datetime
import tempfile
import time
import pymysql
from storage.mysql_agent import MySQLAgent
import platform
from concurrent.futures import ThreadPoolExecutor
from utils.mysql_agent import MySQLAgent
class TestMySQLAgent(unittest.TestCase):
@classmethod
def setUpClass(cls):
"""初始化测试环境和测试表"""
# 创建唯一的测试数据库和表名(避免冲突)
cls.test_db_name = f"test_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
cls.test_table = f"test_table_{datetime.now().strftime('%Y%m%d%H%M%S')}"
# 创建唯一的测试数据库
cls.test_db_name = "test_db_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.test_table = "test_table_" + datetime.now().strftime("%Y%m%d%H%M%S")
# 基础配置(根据实际环境修改)
# 基础配置
cls.base_config = {
'host': 'localhost',
'port': 3306,
@@ -34,19 +33,21 @@ class TestMySQLAgent(unittest.TestCase):
'database': cls.test_db_name
})
# 创建测试表并插入初始数据
# 创建测试表
test_data = pd.DataFrame({
'id': [1, 2, 3],
'name': ['Test1', 'Test2', 'Test3'],
'value': [10.5, 20.3, 30.8],
'created_at': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])
})
cls.db.create_table_from_df(cls.test_table, test_data, primary_key='id')
cls.db.insert_from_df(cls.test_table, test_data)
@classmethod
def _create_test_database(cls):
"""创建测试数据库"""
# 使用临时连接创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
@@ -54,6 +55,7 @@ class TestMySQLAgent(unittest.TestCase):
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -65,14 +67,21 @@ class TestMySQLAgent(unittest.TestCase):
@classmethod
def tearDownClass(cls):
"""清理测试环境"""
"""清理测试数据库"""
if hasattr(cls, 'db') and cls.db:
# 删除测试表
if cls.db.table_exists(cls.test_table):
cls.db.drop_table(cls.test_table)
# 删除测试数据库
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -80,24 +89,22 @@ class TestMySQLAgent(unittest.TestCase):
finally:
temp_conn.close()
def test_connection(self):
def test_01_connection(self):
"""测试数据库连接"""
version_df = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version_df)
self.assertEqual(len(version_df), 1)
print(f"数据库版本: {version_df['version'].iloc[0]}")
version = self.db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nDatabase version: {version['version'].iloc[0]}")
print(f"Running on: {platform.system()} {platform.release()}")
def test_query_to_df(self):
def test_02_query_to_df(self):
"""测试查询返回DataFrame"""
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id > %s",
params=(1,)
)
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id > %s", (1,))
self.assertEqual(len(df), 2)
self.assertIsInstance(df, pd.DataFrame)
self.assertEqual(len(df), 2) # id>1 的数据有2条
self.assertIn('name', df.columns)
print("\nQuery result sample:")
print(df.head())
def test_insert_from_df(self):
def test_03_insert_from_df(self):
"""测试DataFrame插入"""
new_data = pd.DataFrame({
'id': [4, 5],
@@ -106,55 +113,55 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.to_datetime(['2023-01-04', '2023-01-05'])
})
inserted_rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(inserted_rows, 2)
rows = self.db.insert_from_df(self.test_table, new_data)
self.assertEqual(rows, 2)
# 验证插入结果
result_df = self.db.query_to_df(
f"SELECT name FROM {self.test_table} WHERE id IN (4,5)"
)
self.assertEqual(result_df['name'].tolist(), ['Test4', 'Test5'])
# 验证数据
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id >= 4")
self.assertEqual(len(df), 2)
self.assertEqual(df['name'].tolist(), ['Test4', 'Test5'])
def test_update_from_df(self):
def test_04_update_from_df(self):
"""测试DataFrame更新"""
update_data = pd.DataFrame({
'id': [1, 2],
'name': ['Updated1', 'Updated2']
})
updated_rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(updated_rows, 2)
rows = self.db.update_from_df(self.test_table, update_data, 'id')
self.assertGreaterEqual(rows, 2)
# 验证更新结果
result_df = self.db.query_to_df(
f"SELECT name FROM {self.test_table} WHERE id IN (1,2)"
)
self.assertIn('Updated1', result_df['name'].values)
self.assertIn('Updated2', result_df['name'].values)
# 验证更新
df = self.db.query_to_df(f"SELECT name FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn('Updated1', df['name'].values)
self.assertIn('Updated2', df['name'].values)
def test_transaction(self):
def test_05_transaction(self):
"""测试事务处理"""
conn = self.db.begin_transaction()
try:
# 执行事务内操作
# 执行多个操作
cursor = conn.cursor()
cursor.execute(f"UPDATE {self.test_table} SET value = 99.9 WHERE id = 1")
cursor.execute(f"UPDATE {self.test_table} SET value = 88.8 WHERE id = 2")
# 验证事务内修改
cursor.execute(f"SELECT value FROM {self.test_table} WHERE id = 1")
self.assertEqual(cursor.fetchone()['value'], 99.9)
self.db.commit_transaction(conn)
except Exception:
self.db.rollback_transaction(conn)
raise
# 验证事务提交结果
result_df = self.db.query_to_df(
f"SELECT value FROM {self.test_table} WHERE id IN (1,2)"
)
self.assertIn(99.9, result_df['value'].values)
self.assertIn(88.8, result_df['value'].values)
# 验证提交后的修改
df = self.db.query_to_df(f"SELECT value FROM {self.test_table} WHERE id IN (1,2)")
self.assertIn(99.9, df['value'].values)
self.assertIn(88.8, df['value'].values)
def test_large_data_insert(self):
"""测试大数据量插入"""
# 生成1000行测试数据
def test_06_large_data(self):
"""测试大数据量操作"""
# 生成测试数据
large_data = pd.DataFrame({
'id': range(1000, 2000),
'name': [f"Item_{i}" for i in range(1000, 2000)],
@@ -162,55 +169,59 @@ class TestMySQLAgent(unittest.TestCase):
'created_at': pd.date_range('2023-01-01', periods=1000)
})
# 根据平台自动调整批次大小
# Windows平台使用更小的批次
chunk_size = 100 if platform.system() == 'Windows' else 500
start_time = time.time()
inserted_rows = self.db.insert_from_df(
self.test_table,
large_data,
chunk_size=chunk_size
)
rows = self.db.insert_from_df(self.test_table, large_data, chunk_size=chunk_size)
elapsed = time.time() - start_time
self.assertEqual(inserted_rows, 1000)
print(f"插入1000行数据耗时: {elapsed:.2f} (批次大小: {chunk_size})")
self.assertEqual(rows, 1000)
print(f"\nInserted 1000 rows in {elapsed:.2f}s (chunk_size={chunk_size})")
def test_concurrent_access(self):
# 验证数据
df = self.db.query_to_df(f"SELECT COUNT(*) as cnt FROM {self.test_table} WHERE id >= 1000")
self.assertEqual(df['cnt'].iloc[0], 1000)
def test_07_concurrent_access(self):
"""测试并发访问"""
from concurrent.futures import ThreadPoolExecutor
def query_worker(i):
"""并发查询工作函数"""
df = self.db.query_to_df(
f"SELECT * FROM {self.test_table} WHERE id = %s",
params=(i % 3 + 1,) # 查询id=1,2,3循环
)
def worker(i):
df = self.db.query_to_df(f"SELECT * FROM {self.test_table} WHERE id = %s", (i % 5 + 1,))
return len(df)
# 20个线程执行100次查询
start_time = time.time()
with ThreadPoolExecutor(max_workers=20) as executor:
results = list(executor.map(query_worker, range(100)))
elapsed = time.time() - start_time
results = list(executor.map(worker, range(100)))
self.assertEqual(sum(results), 100) # 每次查询应返回1行
print(f"100次并发查询耗时: {elapsed:.2f}")
elapsed = time.time() - start_time
self.assertEqual(sum(results), 100)
print(f"\nCompleted 100 concurrent queries in {elapsed:.2f}s")
class TestPlatformSpecific(unittest.TestCase):
"""平台特定功能测试"""
@classmethod
def setUpClass(cls):
cls.test_db_name = f"test_platform_db_{datetime.now().strftime('%Y%m%d%H%M%S')}"
"""创建临时测试数据库"""
cls.test_db_name = "test_db_platform_" + datetime.now().strftime("%Y%m%d%H%M%S")
cls.base_config = {
'host': 'localhost',
'port': 3306,
'user': 'root',
'password': '123123'
'password': '123123',
'max_connections': 10
}
# 创建测试数据库
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
# 创建数据库
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {cls.test_db_name}")
@@ -220,8 +231,15 @@ class TestPlatformSpecific(unittest.TestCase):
@classmethod
def tearDownClass(cls):
"""清理测试数据库"""
temp_conn = pymysql.connect(**cls.base_config, charset='utf8mb4')
"""删除临时测试数据库"""
temp_conn = pymysql.connect(
host=cls.base_config['host'],
port=cls.base_config['port'],
user=cls.base_config['user'],
password=cls.base_config['password'],
charset='utf8mb4'
)
try:
with temp_conn.cursor() as cursor:
cursor.execute(f"DROP DATABASE IF EXISTS {cls.test_db_name}")
@@ -232,49 +250,42 @@ class TestPlatformSpecific(unittest.TestCase):
def test_windows_timeout(self):
"""测试Windows平台超时处理"""
if platform.system() != 'Windows':
self.skipTest("仅在Windows平台运行")
self.skipTest("Only runs on Windows")
config = {
**self.base_config,
'database': self.test_db_name,
'connect_timeout': 1,
'read_timeout': 1,
'write_timeout': 1
'read_timeout': 1
}
db = MySQLAgent(config)
# 执行会超时查询(SLEEP(2)超过1秒超时设置)
with self.assertRaises((pymysql.OperationalError, TimeoutError)) as ctx:
try:
db.query_to_df("SELECT SLEEP(2)")
except Exception as e:
# 提取底层异常信息(可能被包装)
while hasattr(e, 'args') and len(e.args) > 0 and isinstance(e.args[0], Exception):
e = e.args[0]
raise e
# 测试短超时查询
start_time = time.time()
try:
db.query_to_df("SELECT SLEEP(2)")
self.fail("Should have timed out")
except Exception as e:
self.assertIn("timed out", str(e))
print(f"\nWindows timeout test: {str(e)}")
error_msg = str(ctx.exception)
self.assertTrue(
"timed out" in error_msg or
"timeout" in error_msg or
"HY000" in error_msg, # MySQL超时错误码
f"未检测到超时异常,实际异常: {error_msg}"
)
def test_macos_ssl_connection(self):
"""测试macOS平台SSL连接"""
def test_macos_ssl(self):
"""测试macOS SSL连接"""
if platform.system() != 'Darwin':
self.skipTest("仅在macOS平台运行")
self.skipTest("Only runs on macOS")
config = {
**self.base_config,
'database': self.test_db_name,
'ssl': {'ca': '/usr/local/etc/openssl/cert.pem'}
}
db = MySQLAgent(config)
version_df = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version_df)
version = db.query_to_df("SELECT VERSION() as version")
self.assertIsNotNone(version)
print(f"\nmacOS SSL connection successful: {version['version'].iloc[0]}")
if __name__ == '__main__':
unittest.main(verbosity=2)
unittest.main()
-18
View File
@@ -1,18 +0,0 @@
-- Main scheduling table: one row per Cron-driven task, holding its schedule,
-- its last/next run times and its run-state flags.
CREATE TABLE IF NOT EXISTS main_task (
    task_id INT AUTO_INCREMENT PRIMARY KEY COMMENT '任务唯一ID',
    task_name VARCHAR(100) NOT NULL COMMENT '任务名称',
    task_type VARCHAR(50) NOT NULL COMMENT '任务类型(如processor、collector等)',
    module_path VARCHAR(255) NOT NULL COMMENT '任务模块路径(如processors.data_checker)',
    cron_expression VARCHAR(100) NOT NULL COMMENT 'Cron表达式(调度频率)',
    time_zone VARCHAR(50) DEFAULT 'Asia/Shanghai' COMMENT '时区', -- column added to the original design
    next_run_time DATETIME NOT NULL COMMENT '下次运行时间',
    last_run_time DATETIME NULL COMMENT '上次运行时间',
    last_run_status ENUM('success', 'failed', 'pending') DEFAULT 'pending' COMMENT '上次运行状态',
    run_count INT DEFAULT 0 COMMENT '运行次数统计',
    is_active TINYINT(1) DEFAULT 1 COMMENT '是否活跃(1=启用,0=禁用)',
    is_running TINYINT(1) DEFAULT 0 COMMENT '是否正在运行(1=运行中,0=未运行)',
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
    INDEX idx_next_run (next_run_time) COMMENT '优化下次运行时间查询', -- keep: speeds up the due-task lookup on next_run_time
    INDEX idx_active (is_active) COMMENT '优化活跃任务查询' -- keep: speeds up filtering on is_active
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='任务调度主表';
Binary file not shown.
-957
View File
@@ -1,957 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "197b1b81f5528a50",
"metadata": {},
"source": [
"## 1. 初始化(所有操作前必须运行)"
]
},
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-10-29T02:25:08.582541Z",
"start_time": "2025-10-29T02:25:08.473381Z"
}
},
"source": [
"# 使 Notebook 可从项目根导入\n",
"import sys\n",
"from pathlib import Path\n",
"\n",
"def add_project_root(marker_dirs=(\"utils\", \"system_management\")):\n",
" p = Path.cwd()\n",
" for _ in range(6):\n",
" if all((p / d).exists() for d in marker_dirs):\n",
" if str(p) not in sys.path:\n",
" sys.path.insert(0, str(p))\n",
" return p\n",
" p = p.parent\n",
" raise RuntimeError(\"未找到项目根目录,请从项目根启动 Notebook 或手动设置 sys.path\")\n",
"\n",
"PROJECT_ROOT = add_project_root()\n",
"print(f\"PROJECT_ROOT = {PROJECT_ROOT}\")\n",
"\n",
"# 依赖与日志\n",
"import pandas as pd\n",
"from IPython.display import display, HTML, Markdown\n",
"from utils.logger import CrossPlatformLog\n",
"log = CrossPlatformLog.get_logger(\"TaskNotebook\")\n",
"\n",
"# 配置与调度器\n",
"from config import Config # 若你使用 ConfigManager,请改为: from config.config import ConfigManager\n",
"from system_management.scheduler.task_scheduler import TaskScheduler\n",
"\n",
"# 初始化调度器(根据你的项目配置选一段)\n",
"scheduler = TaskScheduler(Config.MYSQL_CONFIG)\n",
"# 或使用 ConfigManager\n",
"# config = ConfigManager()\n",
"# scheduler = TaskScheduler(config.get(\"database\"))\n",
"\n",
"# 在 Notebook 中实现一个最小可用的 TaskManager\n",
"class TaskManager:\n",
" def __init__(self, scheduler: TaskScheduler):\n",
" self.scheduler = scheduler\n",
" self.db = scheduler.db # 复用调度器里的 MySQLAgent\n",
"\n",
" def get_all_tasks(self, active_only: bool = False):\n",
" sql = \"\"\"\n",
" SELECT *\n",
" FROM main_task\n",
" {where}\n",
" ORDER BY created_at DESC, task_id DESC\n",
" \"\"\"\n",
" where = \"WHERE is_active = 1\" if active_only else \"\"\n",
" df = self.db.query_to_df(sql.format(where=where))\n",
" return [] if df is None or df.empty else df.to_dict(\"records\")\n",
"\n",
" def get_task_by_id(self, task_id: int):\n",
" df = self.db.query_to_df(\n",
" \"SELECT * FROM main_task WHERE task_id = %s\",\n",
" params=(task_id,)\n",
" )\n",
" return None if df is None or df.empty else df.iloc[0].to_dict()\n",
"\n",
" def update_task(self, task_id: int, updates: dict) -> bool:\n",
" if not updates:\n",
" return False\n",
" # 允许更新的字段(与调度器一致)\n",
" allowed = {\n",
" \"task_name\", \"task_type\", \"module_path\",\n",
" \"cron_expression\", \"time_zone\"\n",
" }\n",
" filtered = {k: v for k, v in updates.items() if k in allowed}\n",
" if not filtered:\n",
" return False\n",
"\n",
" set_clause = \", \".join([f\"{k}=%s\" for k in filtered.keys()])\n",
" params = list(filtered.values()) + [task_id]\n",
" sql = f\"UPDATE main_task SET {set_clause}, updated_at=NOW() WHERE task_id=%s\"\n",
" affected = self.db.execute_sql(sql, params=params)\n",
" return affected == 1\n",
"\n",
" def toggle_task_status(self, task_id: int, activate: bool) -> bool:\n",
" sql = \"UPDATE main_task SET is_active=%s, updated_at=NOW() WHERE task_id=%s\"\n",
" affected = self.db.execute_sql(sql, params=(1 if activate else 0, task_id))\n",
" return affected == 1\n",
"\n",
" def delete_task(self, task_id: int) -> bool:\n",
" # 如果你更偏好软删除,可以改为: UPDATE main_task SET is_active=0, updated_at=NOW() WHERE task_id=%s\n",
" affected = self.db.execute_sql(\"DELETE FROM main_task WHERE task_id=%s\", params=(task_id,))\n",
" return affected == 1\n",
"\n",
" def run_task_manually(self, task_id: int) -> bool:\n",
" # 读取任务,直接复用调度器的单任务执行逻辑\n",
" task = self.get_task_by_id(task_id)\n",
" if not task:\n",
" return False\n",
" # _process_single_task 期望 dict\n",
" try:\n",
" return bool(self.scheduler._process_single_task(task)) # 注意:使用了调度器的内部方法\n",
" except Exception:\n",
" log.exception(\"手动执行任务失败\")\n",
" return False\n",
"\n",
" def run_task_synchronously(self, task_id: int) -> dict:\n",
" \"\"\"同步执行任务并返回详细结果(用于Notebook中查看执行过程)\"\"\"\n",
" import time\n",
" import sys\n",
" from io import StringIO\n",
" \n",
" task = self.get_task_by_id(task_id)\n",
" if not task:\n",
" return {\n",
" 'success': False,\n",
" 'error': f'未找到任务ID: {task_id}',\n",
" 'output': ''\n",
" }\n",
" \n",
" # 捕获标准输出\n",
" old_stdout = sys.stdout\n",
" sys.stdout = output_buffer = StringIO()\n",
" \n",
" start_time = time.time()\n",
" success = False\n",
" error_msg = None\n",
" \n",
" try:\n",
" # 直接同步执行任务逻辑\n",
" self.scheduler._execute_task_logic(task)\n",
" success = True\n",
" \n",
" # 更新任务状态\n",
" next_run_time = self.scheduler._calculate_next_run_time(\n",
" cron_expr=task['cron_expression'],\n",
" time_zone=task.get('time_zone', 'Asia/Shanghai')\n",
" )\n",
" \n",
" self.scheduler._update_task_status(task['task_id'], {\n",
" 'last_run_status': 'success',\n",
" 'is_running': 0,\n",
" 'run_count': task['run_count'] + 1,\n",
" 'next_run_time': next_run_time\n",
" })\n",
" \n",
" except Exception as e:\n",
" success = False\n",
" error_msg = str(e)\n",
" log.exception(f\"任务执行失败: {task['task_name']}\")\n",
" \n",
" # 更新失败状态\n",
" try:\n",
" next_retry_time = datetime.now() + pd.Timedelta(minutes=15)\n",
" self.scheduler._update_task_status(task['task_id'], {\n",
" 'last_run_status': 'failed',\n",
" 'is_running': 0,\n",
" 'next_run_time': next_retry_time\n",
" })\n",
" except Exception:\n",
" pass\n",
" \n",
" finally:\n",
" # 恢复标准输出\n",
" sys.stdout = old_stdout\n",
" output_text = output_buffer.getvalue()\n",
" \n",
" execution_time = time.time() - start_time\n",
" \n",
" return {\n",
" 'success': success,\n",
" 'task_name': task['task_name'],\n",
" 'task_id': task['task_id'],\n",
" 'execution_time': execution_time,\n",
" 'output': output_text,\n",
" 'error': error_msg\n",
" }\n",
"\n",
"# 在这里创建 manager(供后续单元使用)\n",
"manager = TaskManager(scheduler)\n",
"\n",
"# 常用辅助函数\n",
"def format_datetime(dt):\n",
" if dt is None:\n",
" return \"-\"\n",
" try:\n",
" if isinstance(dt, pd.Timestamp) and pd.isna(dt):\n",
" return \"-\"\n",
" return dt.strftime(\"%Y-%m-%d %H:%M:%S\")\n",
" except Exception:\n",
" try:\n",
" if pd.isna(dt):\n",
" return \"-\"\n",
" except Exception:\n",
" pass\n",
" return str(dt)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PROJECT_ROOT = D:\\Idea Project\\intelligence_system\n",
"\u001B[32m2025-10-29 10:25:08\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mtask_scheduler\u001B[0m - \u001B[1m任务调度器已初始化,最大工作线程数: 5\u001B[0m\n"
]
}
],
"execution_count": 8
},
{
"cell_type": "markdown",
"id": "8271189cef3b5f17",
"metadata": {},
"source": [
"## 2. 列出任务(对应命令行 list)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7b020af55972643",
"metadata": {
"ExecuteTime": {
"end_time": "2025-10-17T05:43:18.499582Z",
"start_time": "2025-10-17T05:43:18.394863Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[32m2025-10-29 09:54:09\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
]
},
{
"data": {
"text/markdown": [
"### 任务列表"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th>任务ID</th>\n",
" <th>任务名称</th>\n",
" <th>任务类型</th>\n",
" <th>模块路径</th>\n",
" <th>Cron表达式</th>\n",
" <th>时区</th>\n",
" <th>下次运行时间</th>\n",
" <th>最后运行时间</th>\n",
" <th>运行状态</th>\n",
" <th>运行次数</th>\n",
" <th>是否活跃</th>\n",
" <th>is_running</th>\n",
" <th>created_at</th>\n",
" <th>updated_at</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>RSS基于规则数据处理</td>\n",
" <td>processor</td>\n",
" <td>processors.processor_rss_data</td>\n",
" <td>0 8,20 * * *</td>\n",
" <td>Asia/Shanghai</td>\n",
" <td>2025-10-28 20:00:00</td>\n",
" <td>2025-10-28 13:34:49</td>\n",
" <td>success</td>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2025-10-22 16:06:42</td>\n",
" <td>2025-10-28 13:34:50</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>RSS新闻订阅</td>\n",
" <td>collector</td>\n",
" <td>collectors.rss_subscriptions.NewsAPIClient</td>\n",
" <td>*/5 * * * *</td>\n",
" <td>Asia/Shanghai</td>\n",
" <td>2025-10-28 13:40:00</td>\n",
" <td>2025-10-28 13:35:09</td>\n",
" <td>success</td>\n",
" <td>495</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2025-10-16 15:47:34</td>\n",
" <td>2025-10-28 13:35:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>任务ID</th>\n",
" <th>任务名称</th>\n",
" <th>任务类型</th>\n",
" <th>模块路径</th>\n",
" <th>Cron表达式</th>\n",
" <th>时区</th>\n",
" <th>下次运行时间</th>\n",
" <th>最后运行时间</th>\n",
" <th>运行状态</th>\n",
" <th>运行次数</th>\n",
" <th>是否活跃</th>\n",
" <th>is_running</th>\n",
" <th>created_at</th>\n",
" <th>updated_at</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>RSS基于规则数据处理</td>\n",
" <td>processor</td>\n",
" <td>processors.processor_rss_data</td>\n",
" <td>0 8,20 * * *</td>\n",
" <td>Asia/Shanghai</td>\n",
" <td>2025-10-28 20:00:00</td>\n",
" <td>2025-10-28 13:34:49</td>\n",
" <td>success</td>\n",
" <td>10</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2025-10-22 16:06:42</td>\n",
" <td>2025-10-28 13:34:50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>RSS新闻订阅</td>\n",
" <td>collector</td>\n",
" <td>collectors.rss_subscriptions.NewsAPIClient</td>\n",
" <td>*/5 * * * *</td>\n",
" <td>Asia/Shanghai</td>\n",
" <td>2025-10-28 13:40:00</td>\n",
" <td>2025-10-28 13:35:09</td>\n",
" <td>success</td>\n",
" <td>495</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>2025-10-16 15:47:34</td>\n",
" <td>2025-10-28 13:35:09</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 任务ID 任务名称 任务类型 模块路径 \\\n",
"0 2 RSS基于规则数据处理 processor processors.processor_rss_data \n",
"1 1 RSS新闻订阅 collector collectors.rss_subscriptions.NewsAPIClient \n",
"\n",
" Cron表达式 时区 下次运行时间 最后运行时间 \\\n",
"0 0 8,20 * * * Asia/Shanghai 2025-10-28 20:00:00 2025-10-28 13:34:49 \n",
"1 */5 * * * * Asia/Shanghai 2025-10-28 13:40:00 2025-10-28 13:35:09 \n",
"\n",
" 运行状态 运行次数 是否活跃 is_running created_at updated_at \n",
"0 success 10 1 0 2025-10-22 16:06:42 2025-10-28 13:34:50 \n",
"1 success 495 1 0 2025-10-16 15:47:34 2025-10-28 13:35:09 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 列出任务(默认只列出活跃任务;active_only=False 时包括已禁用的)\n",
"def list_tasks(active_only=True):\n",
" tasks = manager.get_all_tasks(active_only)\n",
" if not tasks:\n",
" display(Markdown(\"### 没有找到任务\"))\n",
" return None\n",
"\n",
" df = pd.DataFrame(tasks)\n",
"\n",
" # 格式化日期列\n",
" if 'last_run_time' in df.columns:\n",
" df['last_run_time'] = df['last_run_time'].apply(format_datetime)\n",
" if 'next_run_time' in df.columns:\n",
" df['next_run_time'] = df['next_run_time'].apply(format_datetime)\n",
"\n",
" # 重命名列名\n",
" df = df.rename(columns={\n",
" 'task_id': '任务ID',\n",
" 'task_name': '任务名称',\n",
" 'task_type': '任务类型',\n",
" 'module_path': '模块路径',\n",
" 'cron_expression': 'Cron表达式',\n",
" 'time_zone': '时区',\n",
" 'last_run_time': '最后运行时间',\n",
" 'next_run_time': '下次运行时间',\n",
" 'last_run_status': '运行状态',\n",
" 'is_active': '是否活跃',\n",
" 'run_count': '运行次数'\n",
" })\n",
"\n",
" display(Markdown(\"### 任务列表\"))\n",
" display(HTML(df.to_html(index=False)))\n",
" return df\n",
"\n",
"# 执行:列出所有任务(包括已禁用)\n",
"list_tasks(active_only=False)\n",
"\n",
"# 或者:只列出活跃任务\n",
"# list_tasks(active_only=True)"
]
},
{
"cell_type": "markdown",
"id": "7780dcef67a0534c",
"metadata": {},
"source": [
"## 3. 查看任务详情(对应命令行 show)"
]
},
{
"cell_type": "code",
"id": "eab90de72c35429e",
"metadata": {
"ExecuteTime": {
"end_time": "2025-10-29T02:26:12.873536Z",
"start_time": "2025-10-29T02:26:12.648420Z"
}
},
"source": [
"# 查看指定任务的详情\n",
"def show_task_details(task_id):\n",
" task = manager.get_task_by_id(task_id)\n",
" if not task:\n",
" display(Markdown(f\"### 未找到任务ID为 {task_id} 的任务\"))\n",
" return None\n",
"\n",
" details = [\"### 任务详情\"]\n",
" details.append(f\"**任务ID**: {task.get('task_id')}\")\n",
" details.append(f\"**任务名称**: {task.get('task_name')}\")\n",
" details.append(f\"**任务类型**: {task.get('task_type')}\")\n",
" details.append(f\"**模块路径**: {task.get('module_path')}\")\n",
" details.append(f\"**Cron表达式**: {task.get('cron_expression')}\")\n",
" details.append(f\"**时区**: {task.get('time_zone', 'Asia/Shanghai')}\")\n",
" details.append(f\"**最后运行时间**: {format_datetime(task.get('last_run_time'))}\")\n",
" details.append(f\"**下次运行时间**: {format_datetime(task.get('next_run_time'))}\")\n",
" details.append(f\"**运行状态**: {task.get('last_run_status', '未运行')}\")\n",
" details.append(f\"**是否活跃**: {'是' if task.get('is_active') else '否'}\")\n",
" details.append(f\"**运行次数**: {task.get('run_count', 0)}\")\n",
" details.append(f\"**创建时间**: {format_datetime(task.get('created_at'))}\")\n",
"\n",
" display(Markdown('\\n'.join(details)))\n",
" return task\n",
"\n",
"# 执行:查看任务ID为1的详情(替换为实际ID)\n",
"show_task_details(1)"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[32m2025-10-29 10:26:12\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "### 任务详情\n**任务ID**: 1\n**任务名称**: RSS新闻订阅\n**任务类型**: collector\n**模块路径**: processors.processor_rss_data.RSSDataProcessor\n**Cron表达式**: */5 * * * *\n**时区**: Asia/Shanghai\n**最后运行时间**: 2025-10-28 13:35:09\n**下次运行时间**: 2025-10-29 10:25:00\n**运行状态**: success\n**是否活跃**: 是\n**运行次数**: 496\n**创建时间**: 2025-10-16 15:47:34"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"data": {
"text/plain": [
"{'task_id': 1,\n",
" 'task_name': 'RSS新闻订阅',\n",
" 'task_type': 'collector',\n",
" 'module_path': 'processors.processor_rss_data.RSSDataProcessor',\n",
" 'cron_expression': '*/5 * * * *',\n",
" 'time_zone': 'Asia/Shanghai',\n",
" 'next_run_time': Timestamp('2025-10-29 10:25:00'),\n",
" 'last_run_time': Timestamp('2025-10-28 13:35:09'),\n",
" 'last_run_status': 'success',\n",
" 'run_count': 496,\n",
" 'is_active': 1,\n",
" 'is_running': 0,\n",
" 'created_at': Timestamp('2025-10-16 15:47:34'),\n",
" 'updated_at': Timestamp('2025-10-29 10:24:49')}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 10
},
{
"cell_type": "markdown",
"id": "a313f1524f5a54bc",
"metadata": {},
"source": [
"## 4. 添加新任务(对应命令行 add)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2b2d723bb8e2784f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[32m2025-10-29 09:56:52\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n",
"\u001B[32m2025-10-29 09:56:52\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mtask_scheduler\u001B[0m - \u001B[1m新任务添加成功\u001B[0m\n"
]
},
{
"data": {
"text/markdown": [
"### 任务添加成功!"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"新任务ID: 0,任务名称: AI处理RSS新闻"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"np.int64(0)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 添加新任务\n",
"def add_new_task(name, task_type, module_path, cron_expression, timezone=\"Asia/Shanghai\"):\n",
" try:\n",
" task_id = scheduler.add_task(\n",
" task_name=name,\n",
" task_type=task_type,\n",
" module_path=module_path,\n",
" cron_expression=cron_expression,\n",
" time_zone=timezone\n",
" )\n",
" display(Markdown(f\"### 任务添加成功!\"))\n",
" display(Markdown(f\"新任务ID: {task_id},任务名称: {name}\"))\n",
" return task_id\n",
" except Exception as e:\n",
" display(Markdown(f\"### 添加任务失败: {str(e)}\"))\n",
" return None\n",
"\n",
"# 执行:添加一个AI处理RSS新闻的任务\n",
"add_new_task(\n",
" name=\"AI处理RSS新闻\",\n",
" task_type=\"processor\",\n",
" module_path=\"processors.ai_processors.ai_processor_rss_data.RSSDataAIProcessor\",\n",
" cron_expression=\"5 0 * * *\", # 每天 00:05 执行1次\n",
" timezone=\"Asia/Shanghai\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "12373bcbb4a0b434",
"metadata": {},
"source": [
"## 5. 更新任务属性(对应命令行 update)"
]
},
{
"cell_type": "code",
"id": "c892fd8ad2f0dd9d",
"metadata": {
"ExecuteTime": {
"end_time": "2025-10-29T02:29:56.088085Z",
"start_time": "2025-10-29T02:29:55.754298Z"
}
},
"source": [
"# 更新任务属性\n",
"def update_task(task_id, **kwargs):\n",
" updates = {}\n",
" if 'name' in kwargs and kwargs['name']:\n",
" updates['task_name'] = kwargs['name']\n",
" if 'type' in kwargs and kwargs['type']:\n",
" updates['task_type'] = kwargs['type']\n",
" if 'module' in kwargs and kwargs['module']:\n",
" updates['module_path'] = kwargs['module']\n",
" if 'cron' in kwargs and kwargs['cron']:\n",
" updates['cron_expression'] = kwargs['cron']\n",
" if 'timezone' in kwargs and kwargs['timezone']:\n",
" updates['time_zone'] = kwargs['timezone']\n",
"\n",
" if not updates:\n",
" display(Markdown(\"### 没有提供任何更新内容\"))\n",
" return False\n",
"\n",
" success = manager.update_task(task_id, updates)\n",
" if success:\n",
" display(Markdown(f\"### 任务ID {task_id} 更新成功\"))\n",
" show_task_details(task_id) # 显示更新后的详情\n",
" else:\n",
" display(Markdown(f\"### 任务ID {task_id} 更新失败\"))\n",
" return success\n",
"\n",
"# 执行:更新任务(示例:修改任务2的模块路径)\n",
"update_task(2, module = \"processors.processor_rss_data\")\n",
"\n",
"# 执行:同时更新多个属性(名称和Cron表达式)\n",
"# update_task(1, name=\"每日早间新闻采集\", cron=\"0 8 * * *\")"
],
"outputs": [
{
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "### 任务ID 2 更新成功"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[32m2025-10-29 10:29:56\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n"
]
},
{
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "### 任务详情\n**任务ID**: 2\n**任务名称**: RSS基于规则数据处理\n**任务类型**: processor\n**模块路径**: processors.processor_rss_data\n**Cron表达式**: 0 8,20 * * *\n**时区**: Asia/Shanghai\n**最后运行时间**: 2025-10-28 13:34:49\n**下次运行时间**: 2025-10-28 20:00:00\n**运行状态**: success\n**是否活跃**: 是\n**运行次数**: 10\n**创建时间**: 2025-10-22 16:06:42"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 21
},
{
"cell_type": "markdown",
"id": "37564011cf5aa501",
"metadata": {},
"source": [
"## 6. 启用 / 禁用任务(对应命令行 toggle)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "65388d10c5c8d407",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"### 任务ID 1 启用成功"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 启用或禁用任务\n",
"def toggle_task_status(task_id, activate=True):\n",
" success = manager.toggle_task_status(task_id, activate)\n",
" action = \"启用\" if activate else \"禁用\"\n",
" if success:\n",
" display(Markdown(f\"### 任务ID {task_id} {action}成功\"))\n",
" else:\n",
" display(Markdown(f\"### 任务ID {task_id} {action}失败\"))\n",
" return success\n",
"\n",
"# 执行:启用任务ID为1的任务\n",
"toggle_task_status(1, activate=True)\n",
"\n",
"# 执行:禁用任务ID为1的任务\n",
"# toggle_task_status(1, activate=False)"
]
},
{
"cell_type": "markdown",
"id": "c554c748169d5ac8",
"metadata": {},
"source": [
"## 7. 手动执行任务(对应命令行 run)\n",
"\n",
"自动识别main,即main的上一级"
]
},
{
"cell_type": "code",
"id": "94892f4134316f8e",
"metadata": {
"jupyter": {
"is_executing": true
},
"ExecuteTime": {
"start_time": "2025-10-29T02:30:10.298891Z"
}
},
"source": [
"# 手动执行任务(异步方式,快速返回)\n",
"def run_task_manually(task_id):\n",
" display(Markdown(f\"### 正在手动执行任务ID {task_id}...\"))\n",
" success = manager.run_task_manually(task_id)\n",
" if success:\n",
" display(Markdown(f\"### 任务ID {task_id} 执行成功\"))\n",
" else:\n",
" display(Markdown(f\"### 任务ID {task_id} 执行失败\"))\n",
" return success\n",
"\n",
"# 手动执行任务(同步方式,显示详细执行过程)\n",
"def run_task_with_details(task_id):\n",
" display(Markdown(f\"### 开始执行任务ID {task_id}\"))\n",
" display(Markdown(\"---\"))\n",
" \n",
" result = manager.run_task_synchronously(task_id)\n",
" \n",
" if not result['success'] and result.get('error') and 'task_id' not in result:\n",
" display(Markdown(f\"### ❌ 错误: {result['error']}\"))\n",
" return result\n",
" \n",
" # 显示任务基本信息\n",
" display(Markdown(f\"**任务名称**: {result['task_name']}\"))\n",
" display(Markdown(f\"**任务ID**: {result['task_id']}\"))\n",
" display(Markdown(f\"**执行时长**: {result['execution_time']:.2f} 秒\"))\n",
" display(Markdown(\"---\"))\n",
" \n",
" # 显示执行输出\n",
" if result['output']:\n",
" display(Markdown(\"### 📋 执行输出:\"))\n",
" print(result['output'])\n",
" display(Markdown(\"---\"))\n",
" \n",
" # 显示执行结果\n",
" if result['success']:\n",
" display(Markdown(\"### ✅ 任务执行成功\"))\n",
" else:\n",
" display(Markdown(f\"### ❌ 任务执行失败\"))\n",
" if result['error']:\n",
" display(Markdown(f\"**错误信息**: {result['error']}\"))\n",
" \n",
" return result\n",
"\n",
"# 执行:手动运行任务ID为3的任务(显示详细执行过程)\n",
"run_task_with_details(3)"
],
"outputs": [
{
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "### 开始执行任务ID 3"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"data": {
"text/plain": [
"<IPython.core.display.Markdown object>"
],
"text/markdown": "---"
},
"metadata": {},
"output_type": "display_data",
"jetTransient": {
"display_id": null
}
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[32m2025-10-29 10:30:10\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mmysql_agent\u001B[0m - \u001B[1m查询执行成功\u001B[0m\n",
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1mRSS数据AI处理器初始化完成\u001B[0m\n",
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1m开始批量处理数据,批次大小: 200, 延迟: 1.5秒\u001B[0m\n",
"\u001B[32m2025-10-29 10:30:11\u001B[0m | \u001B[1mINFO \u001B[0m | \u001B[36mai_processor_rss_data\u001B[0m - \u001B[1m成功加载 3 条未处理的数据\u001B[0m\n"
]
}
],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "c3492a1af7dbf2b1",
"metadata": {},
"source": [
"## 8. 删除任务(对应命令行 delete)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6936dcc673933a8d",
"metadata": {},
"outputs": [],
"source": [
"# 删除任务\n",
"def delete_task(task_id, confirm=False):\n",
" if not confirm:\n",
" display(Markdown(f\"### 警告:删除任务是不可逆操作!\"))\n",
" display(Markdown(f\"请运行 `delete_task({task_id}, confirm=True)` 确认删除\"))\n",
" return False\n",
"\n",
" success = manager.delete_task(task_id)\n",
" if success:\n",
" display(Markdown(f\"### 任务ID {task_id} 删除成功\"))\n",
" else:\n",
" display(Markdown(f\"### 任务ID {task_id} 删除失败\"))\n",
" return success\n",
"\n",
"# 执行:第一步 - 确认删除(不会实际删除)\n",
"delete_task(1)\n",
"\n",
"# 执行:第二步 - 实际删除(谨慎操作!)\n",
"# delete_task(1, confirm=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "intelligence_system",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
-6
View File
@@ -1,6 +0,0 @@
-- Select the tasks that are due to run: enabled (is_active = 1), not
-- currently running (is_running = 0) and with next_run_time at or before
-- the bound parameter, ordered earliest first.
-- NOTE(review): %s looks like a Python DB-API placeholder supplied by the caller - confirm.
use intelligence_system;
SELECT * FROM main_task
WHERE is_active = 1
AND next_run_time <= %s
AND is_running = 0
ORDER BY next_run_time;
-1
View File
@@ -1 +0,0 @@
from .logger import CrossPlatformLog
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+17 -35
View File
@@ -1,11 +1,10 @@
import os
import shutil
import zipfile
import pickle
import pandas as pd
from datetime import datetime
from pathlib import Path, PurePath
from typing import Union, Optional, List, Dict, Any, Callable
from typing import Union, Optional, List, Dict, Any
from utils.logger import log
class FileHandler:
@@ -72,17 +71,6 @@ class FileHandler:
df = pd.read_excel(file_path, **kwargs)
elif ext == 'json':
df = pd.read_json(file_path, encoding=encoding, **kwargs)
elif ext in ['pkl', 'pickle']:
# 统一将pickle内容转为DataFrame返回
obj = pd.read_pickle(file_path)
if isinstance(obj, pd.DataFrame):
df = obj
elif isinstance(obj, list):
df = pd.DataFrame(obj)
elif isinstance(obj, dict):
df = pd.DataFrame([obj])
else:
df = pd.DataFrame({'content': [obj]})
elif ext == 'parquet':
df = pd.read_parquet(file_path, **kwargs)
else:
@@ -114,31 +102,25 @@ class FileHandler:
if not parent_dir.exists():
self.create_dir(parent_dir)
# 统一数据格式
if isinstance(data, pd.DataFrame):
df = data
else:
df = pd.DataFrame(data if isinstance(data, list) else [data])
# 根据扩展名选择写入方式
ext = self.get_file_extension(file_path)
if ext in ['pkl', 'pickle']:
# 直接按原始对象进行pickle序列化
with open(file_path, 'wb') as f:
pickle.dump(data, f)
if ext in ['csv', 'txt']:
df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
elif ext in ['xls', 'xlsx']:
df.to_excel(file_path, index=False, **kwargs)
elif ext == 'json':
df.to_json(file_path, force_ascii=False, **kwargs)
elif ext == 'parquet':
df.to_parquet(file_path, **kwargs)
else:
# 统一数据格式到DataFrame
if isinstance(data, pd.DataFrame):
df = data
else:
df = pd.DataFrame(data if isinstance(data, list) else [data])
if ext in ['csv', 'txt']:
df.to_csv(file_path, encoding=encoding, index=False, **kwargs)
elif ext in ['xls', 'xlsx']:
df.to_excel(file_path, index=False, **kwargs)
elif ext == 'json':
df.to_json(file_path, force_ascii=False, **kwargs)
elif ext == 'parquet':
df.to_parquet(file_path, **kwargs)
else:
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(data))
with open(file_path, 'w', encoding=encoding) as f:
f.write(str(data))
# 返回成功结果
return self._format_result(
+4 -32
View File
@@ -35,7 +35,6 @@ class CrossPlatformLog:
"""配置跨平台日志处理器"""
logger.remove() # 清除默认配置
# 统一控制台输出格式
logger.add(
sys.stdout,
@@ -59,47 +58,20 @@ class CrossPlatformLog:
compression=self._compress_log,
encoding="utf-8",
level="DEBUG",
# 👇 增加 {extra} 输出,并美化结构
# format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra_output}",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}",
retention="30 days",
enqueue=True,
# 👇 动态处理 extra 字段为可读格式
format=self._format_with_extra, # 使用自定义格式函数
enqueue=True # 线程安全
)
def _format_with_extra(self, record):
"""Build a format string that appends the record's ``extra`` fields, one ``key: repr(value)`` per line."""
# Build a readable string from the record's extra fields
extra_str = ""
if record["extra"]:
extra_items = []
for key, value in record["extra"].items():
if key == "extra_output": # skip our own output key so it is not rendered again
continue
value_repr = repr(value)
# Error-related keys get a larger truncation limit (500 chars instead of 200) so important details are not lost
if key in ["error", "error_message", "sql", "params"]:
if len(value_repr) > 500:
value_repr = value_repr[:497] + "..."
elif len(value_repr) > 200:
value_repr = value_repr[:197] + "..."
extra_items.append(f"\n{key}: {value_repr}")
extra_str = "".join(extra_items)
# Store the rendered text on the record itself so the format string below can reference it
record["extra"]["extra_output"] = extra_str
# Key point: the returned format string references {extra[extra_output]} rather than a top-level {extra_output}
return "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {module}:{line} - {message}{extra[extra_output]}\n"
def _add_error_log(self):
"""错误日志专用配置"""
error_log = self.log_dir / "errors.log"
logger.add(
str(error_log),
level="ERROR",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - {message}{extra[extra_output]}\n{exception}",
format="{time:YYYY-MM-DD HH:mm:ss.SSS} | ERROR | {module}:{line} - {message}\n{exception}",
rotation="10 MB",
retention="90 days",
enqueue=True
retention="90 days"
)
@staticmethod
-383
View File
@@ -1,383 +0,0 @@
import os
import sys
import platform
import threading
from typing import List, Dict, Optional, BinaryIO, Tuple, Any
from datetime import datetime, timedelta
import hashlib
from io import BytesIO
from minio import Minio
from minio.error import S3Error, MinioException
from utils.logger import log
class MinIOAgent:
"""
全平台兼容的MinIO对象存储操作类
支持Windows/macOS/Linux系统提供对象存储的上传下载查询等功能
专注于二进制数据处理返回元数据用于与MySQL关联
"""
_instance = None # 单例模式实例
_lock = threading.Lock() # 线程锁,保证单例线程安全
def __new__(cls, *args, **kwargs):
    """Return the shared instance, creating it on first use.

    The instance is stored in ``cls._instance``; creation is guarded by
    ``cls._lock`` with a check before and after acquiring the lock.
    """
    if cls._instance:
        return cls._instance
    with cls._lock:
        # Re-check under the lock: another thread may have created the
        # instance while this one was waiting.
        if not cls._instance:
            cls._instance = super().__new__(cls)
    return cls._instance
def __init__(self, config: dict):
    """Initialise the MinIO connection (only once for the shared instance).

    Args:
        config: MinIO settings. Required keys: ``endpoint`` (e.g.
            ``'localhost:9000'``), ``access_key``, ``secret_key``.
            Optional keys: ``secure`` (default ``False``), ``region``
            (default ``None``), ``timeout`` (default ``30``).

    Raises:
        ValueError: If any required key is missing from ``config``.
        Exception: Whatever client creation or connection verification
            raises is propagated unchanged.
    """
    # The class hands out one shared instance, so skip re-initialisation
    # once a client has been created and verified.
    if getattr(self, '_client', None):
        return

    # Validate the mandatory settings.
    required_keys = ['endpoint', 'access_key', 'secret_key']
    if any(key not in config for key in required_keys):
        raise ValueError(f"MinIO配置缺少必要参数,需要: {required_keys}")

    # Merge the supplied settings with defaults.
    self.config = {
        'endpoint': config['endpoint'],
        'access_key': config['access_key'],
        'secret_key': config['secret_key'],
        'secure': config.get('secure', False),
        'region': config.get('region'),
        'timeout': config.get('timeout', 30)
    }

    # Bind the current platform name into the logger context.
    current_platform = platform.system()
    self.log = log.bind(module=f"MinIOAgent({current_platform})")

    self._client = self._create_client()

    try:
        self._verify_connection()
    except Exception:
        # Drop the unverified client; otherwise the guard at the top would
        # treat this failed shared instance as initialised and every later
        # construction would silently reuse a client that never connected.
        self._client = None
        raise
def _create_client(self) -> Minio:
    """Build a ``Minio`` client from ``self.config``.

    Any failure is logged at critical level and re-raised.
    """
    option_names = ('endpoint', 'access_key', 'secret_key', 'secure', 'region')
    try:
        minio_client = Minio(**{name: self.config[name] for name in option_names})
        self.log.info("MinIO客户端创建成功")
        return minio_client
    except Exception as exc:
        self.log.critical("创建MinIO客户端失败", 错误=str(exc), exc_info=True)
        raise
def _verify_connection(self) -> None:
    """Check that the MinIO service is reachable.

    Listing the buckets is used as the probe; any failure is logged at
    critical level and re-raised.
    """
    try:
        self._client.list_buckets()
        endpoint = self.config['endpoint']
        self.log.info(f"成功连接到MinIO服务:{endpoint}")
    except Exception as exc:
        self.log.critical("连接验证失败", 错误=str(exc), exc_info=True)
        raise
def create_bucket(self, bucket_name: str) -> bool:
    """Create the bucket if it does not exist yet.

    Args:
        bucket_name: Name of the bucket.

    Returns:
        ``True`` when the bucket was created or already existed,
        ``False`` when a ``MinioException`` was raised (it is logged).
    """
    try:
        already_exists = self._client.bucket_exists(bucket_name)
        if already_exists:
            self.log.debug(f"存储桶已存在:{bucket_name}")
        else:
            self._client.make_bucket(bucket_name)
            self.log.info(f"存储桶创建成功:{bucket_name}")
        return True
    except MinioException as exc:
        self.log.error(f"创建存储桶 {bucket_name} 失败", 错误=str(exc), exc_info=True)
        return False
def upload_bytes(self, bucket: str, object_name: str, data: bytes) -> Dict[str, Any]:
    """Upload in-memory bytes to MinIO and return metadata about the object.

    Args:
        bucket: Bucket name; ``create_bucket`` is called on it before uploading.
        object_name: Object name (path inside the bucket).
        data: Payload to upload; must be non-empty.

    Returns:
        A dict with the keys ``bucket``, ``object_name``, ``size`` (bytes),
        ``etag`` (as reported by the server), ``content_type`` (the value
        sent with the upload), ``upload_time`` (naive UTC ``datetime``) and
        ``local_hash`` (MD5 hex digest computed locally).

    Raises:
        ValueError: If ``data`` is empty.
        MinioException: Logged and re-raised when the upload fails.
    """
    from datetime import timezone  # local import: only this method needs it

    if not data:
        raise ValueError("上传数据不能为空")

    # Make sure the bucket exists.
    self.create_bucket(bucket)

    try:
        # Local hash, usable for integrity checks.
        local_hash = hashlib.md5(data).hexdigest()

        # Remember the content type we send: the object returned by
        # put_object does not expose a content_type attribute.
        content_type = self._guess_content_type(object_name)

        result = self._client.put_object(
            bucket_name=bucket,
            object_name=object_name,
            data=BytesIO(data),
            length=len(data),
            content_type=content_type
        )

        # put_object may not report a modification time; fall back to the
        # current time. Either way the value is normalised to naive UTC.
        last_modified = getattr(result, 'last_modified', None)
        if last_modified is not None:
            upload_time = last_modified.astimezone(timezone.utc).replace(tzinfo=None)
        else:
            upload_time = datetime.now(timezone.utc).replace(tzinfo=None)

        metadata = {
            'bucket': bucket,
            'object_name': object_name,
            'size': len(data),
            'etag': result.etag,
            'content_type': content_type,
            'upload_time': upload_time,
            'local_hash': local_hash
        }

        self.log.info(
            "文件上传成功",
            存储桶=bucket,
            对象名称=object_name,
            大小=len(data)
        )
        return metadata
    except MinioException as e:
        self.log.error(
            "文件上传失败",
            存储桶=bucket,
            对象名称=object_name,
            错误=str(e),
            exc_info=True
        )
        raise
def download_file(self, bucket: str, object_name: str, local_path: str) -> Dict[str, Any]:
    """Download an object from MinIO to a local file.

    Args:
        bucket: Bucket name.
        object_name: Object key (path) inside the bucket.
        local_path: Destination path on the local filesystem. Missing parent
            directories are created; a bare filename is written relative to
            the current working directory.

    Returns:
        A dict with the keys:
            - local_path: the destination path
            - size: size of the written file in bytes
            - download_time: elapsed download time in seconds
            - downloaded_at: local timestamp taken after the download

    Raises:
        MinioException: if the download fails (logged, then re-raised).
        IOError: if a local filesystem operation fails (logged, then re-raised).
    """
    try:
        # os.makedirs('') raises, so only create a parent directory when the
        # path actually has one.
        parent_dir = os.path.dirname(local_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        start_time = datetime.now()
        self._client.fget_object(bucket, object_name, local_path)
        download_time = datetime.now() - start_time

        stat = os.stat(local_path)
        result = {
            'local_path': local_path,
            'size': stat.st_size,
            'download_time': download_time.total_seconds(),
            'downloaded_at': datetime.now()
        }
        self.log.info(
            "文件下载成功",
            存储桶=bucket,
            对象名称=object_name,
            本地路径=local_path,
            大小=stat.st_size
        )
        return result
    except MinioException as e:
        self.log.error(
            "文件下载失败",
            存储桶=bucket,
            对象名称=object_name,
            错误=str(e),
            exc_info=True
        )
        raise
    except IOError as e:
        self.log.error(
            "本地文件操作失败",
            本地路径=local_path,
            错误=str(e),
            exc_info=True
        )
        raise
def get_presigned_url(self, bucket: str, object_name: str, expires: int = 3600) -> Dict[str, Any]:
    """Generate a temporary (presigned) GET URL for an object.

    Args:
        bucket: Bucket name.
        object_name: Object key (path) inside the bucket.
        expires: Lifetime of the URL in seconds (default 3600). A
            ``timedelta`` is also accepted.

    Returns:
        A dict with the keys ``presigned_url``, ``expires_in`` (seconds),
        ``expires_at`` (local datetime), ``bucket`` and ``object_name``.

    Raises:
        MinioException: if the client cannot sign the URL (logged, re-raised).
    """
    # The MinIO client takes the lifetime as a timedelta, not a number of
    # seconds, so normalise before calling it.
    if isinstance(expires, timedelta):
        lifetime = expires
        expires_in = int(expires.total_seconds())
    else:
        lifetime = timedelta(seconds=expires)
        expires_in = expires

    try:
        url = self._client.presigned_get_object(
            bucket_name=bucket,
            object_name=object_name,
            expires=lifetime
        )
        result = {
            'presigned_url': url,
            'expires_in': expires_in,
            'expires_at': datetime.now() + lifetime,
            'bucket': bucket,
            'object_name': object_name
        }
        self.log.debug(
            "预签名URL生成成功",
            存储桶=bucket,
            对象名称=object_name,
            过期时间=expires_in
        )
        return result
    except MinioException as e:
        self.log.error(
            "生成预签名URL失败",
            存储桶=bucket,
            对象名称=object_name,
            错误=str(e),
            exc_info=True
        )
        raise
def list_objects(self, bucket: str, prefix: str = "") -> List[Dict[str, Any]]:
    """List objects under a prefix (recursively) together with their metadata.

    Args:
        bucket: Bucket name.
        prefix: Object key prefix to filter by; empty lists the whole bucket.

    Returns:
        One dict per object with the keys ``bucket``, ``object_name``,
        ``size``, ``last_modified``, ``etag`` and ``content_type``.

    Raises:
        MinioException: if listing or a stat call fails (logged, re-raised).
    """
    def describe(entry):
        # A stat call per object supplies the etag and content type fields.
        detail = self._client.stat_object(bucket, entry.object_name)
        return {
            'bucket': bucket,
            'object_name': entry.object_name,
            'size': entry.size,
            'last_modified': entry.last_modified,
            'etag': detail.etag,
            'content_type': detail.content_type
        }

    try:
        listing = self._client.list_objects(
            bucket_name=bucket,
            prefix=prefix,
            recursive=True
        )
        found = [describe(entry) for entry in listing]
        self.log.info(
            "对象列表查询成功",
            存储桶=bucket,
            前缀=prefix,
            数量=len(found)
        )
        return found
    except MinioException as exc:
        self.log.error(
            "查询对象列表失败",
            存储桶=bucket,
            前缀=prefix,
            错误=str(exc),
            exc_info=True
        )
        raise
def delete_object(self, bucket: str, object_name: str) -> bool:
    """Remove a single object from a bucket.

    Args:
        bucket: Bucket name.
        object_name: Object key (path) inside the bucket.

    Returns:
        True when the client call succeeded; False when it raised a
        ``MinioException`` (the error is logged, not raised).
    """
    # Shared log fields identifying the object being removed.
    target = {'存储桶': bucket, '对象名称': object_name}
    try:
        self._client.remove_object(bucket, object_name)
        self.log.info("对象删除成功", **target)
        return True
    except MinioException as exc:
        self.log.error("删除对象失败", **target, 错误=str(exc), exc_info=True)
        return False
@staticmethod
def _guess_content_type(object_name: str) -> str:
"""根据文件名猜测内容类型"""
from mimetypes import guess_type
mime_type, _ = guess_type(object_name)
return mime_type or 'application/octet-stream' # 默认二进制流类型
-722
View File
@@ -1,722 +0,0 @@
import os
import sys
import platform
import pandas as pd
import pymysql
import json
import numpy as np
from pymysql import cursors
from pymysql.err import MySQLError
from typing import Union, List, Dict, Any, Optional, Tuple, Literal
import threading
from datetime import datetime
from pathlib import Path
# 导入日志系统
from utils.logger import log
class MySQLAgent:
    """
    MySQL helper intended to work on Windows, macOS and Linux.

    Connection settings are supplied by the caller; the class does not use a
    connection pool or manage transactions. Only one instance is ever
    created: every construction returns the same object.
    """
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        # Fast path: instance already built, no locking required.
        if cls._instance:
            return cls._instance
        with cls._lock:
            # Re-check under the lock in case another thread won the race.
            if not cls._instance:
                cls._instance = super().__new__(cls)
        return cls._instance
def __init__(self, config: dict):
    """Store the connection settings and bind a platform-tagged logger.

    Because the class hands out a single shared instance, a second call on an
    already configured instance returns immediately and ``config`` is ignored.

    Args:
        config: Connection settings. ``host``, ``port``, ``user``,
            ``password`` and ``database`` are required; ``charset``,
            ``connect_timeout``, ``read_timeout``, ``write_timeout`` and
            ``ssl`` are optional.

    Raises:
        ValueError: if any required key is missing.
    """
    if hasattr(self, 'config') and self.config:
        return

    # Basic configuration validation.
    required_keys = ['host', 'port', 'user', 'password', 'database']
    if not all(key in config for key in required_keys):
        # Never write credentials to the log: mask the password before
        # echoing the received settings.
        masked = {
            key: ('******' if key == 'password' else value)
            for key, value in config.items()
        }
        log.warning(f"数据库配置缺少必要参数,当前数据库链接信息为:{masked}")
        raise ValueError(f"数据库配置缺少必要参数,需要: {required_keys}")

    self.config = {
        'host': config['host'],
        'port': config['port'],
        'user': config['user'],
        'password': config['password'],
        'database': config['database'],
        'charset': config.get('charset', 'utf8mb4'),
        # Every statement is committed immediately by the server session.
        'autocommit': True,
        'connect_timeout': config.get('connect_timeout', 10),
        'read_timeout': config.get('read_timeout', 30),
        'write_timeout': config.get('write_timeout', 30),
        'ssl': config.get('ssl')
    }

    # Bind a logger tagged with the current platform name.
    current_platform = platform.system()
    self.log = log.bind(module=f"MySQLAgent({current_platform})")
def get_connection(self) -> pymysql.connections.Connection:
    """Open and return a new connection built from ``self.config``.

    On Windows, a failure whose message contains "timed out" is handed to
    ``_retry_connection``; every other failure is logged and re-raised.
    """
    try:
        connection = pymysql.connect(**self.config)

        # Attach a character_set_name() accessor when the connection object
        # does not already provide one.
        if not hasattr(connection, 'character_set_name'):
            def _character_set_name():
                return self.config.get('charset', 'utf8mb4')
            connection.character_set_name = _character_set_name

        # On macOS with SSL configured, ping (with reconnect) before use.
        on_mac_with_ssl = platform.system() == 'Darwin' and self.config.get('ssl')
        if on_mac_with_ssl:
            connection.ping(reconnect=True)

        self.log.trace("获取数据库连接成功")
        return connection
    except Exception as exc:
        message = str(exc)
        if platform.system() == 'Windows' and "timed out" in message:
            self.log.warning("Windows连接超时,正在重试...")
            return self._retry_connection()

        failure_details = {
            'error': message,
            'error_type': type(exc).__name__,
            'host': self.config.get('host'),
            'port': self.config.get('port'),
            'database': self.config.get('database'),
        }
        self.log.error("连接失败", **failure_details, exc_info=True)
        raise
def _retry_connection(self, max_retries: int = 3) -> Any | None:
"""Windows平台连接重试机制(原有逻辑完全保留)"""
for attempt in range(max_retries):
try:
conn = pymysql.connect(**self.config)
self.log.info(f"经过 {attempt + 1} 次尝试后成功建立连接")
return conn
except Exception:
if attempt == max_retries - 1:
raise
import time
time.sleep(1)
def query_to_df(self, sql: str, params: Union[tuple, dict, None] = None,
                parse_dates: Union[List[str], bool] = True, is_print=True) -> pd.DataFrame:
    """Run a SQL query and return the result as a DataFrame.

    Args:
        sql: Query text.
        params: Optional parameters forwarded to ``pd.read_sql``.
        parse_dates: Forwarded to ``pd.read_sql``.
        is_print: When true, log the number of rows returned.

    Raises:
        Exception: any failure is logged with the SQL and parameters, then
            re-raised.
    """
    engine = None
    try:
        self.log.debug("执行SQL查询", sql=sql)
        connection = self.get_connection()

        # Wrap the single connection in a SQLAlchemy engine for pandas.
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool
        engine = create_engine(
            "mysql+pymysql://",
            creator=lambda: connection,
            poolclass=StaticPool,
            connect_args={'charset': self.config.get('charset', 'utf8mb4')}
        )

        frame = pd.read_sql(sql, engine, params=params, parse_dates=parse_dates)
        if is_print:
            self.log.info("查询执行成功", 行数=len(frame))
        return frame
    except Exception as exc:
        self.log.error("SQL查询失败",
                       sql=sql,
                       params=params,
                       error=str(exc),
                       error_type=type(exc).__name__,
                       exc_info=True)
        raise
    finally:
        # Dispose of the engine only if it was actually created.
        if engine is not None:
            engine.dispose()
def insert_from_df(self, table_name: str, df: pd.DataFrame,
chunk_size: int = 1000, replace: bool = False,
ignore_duplicates: Optional[bool] = None) -> int:
"""
Insert the rows of a DataFrame into a table, one row at a time.

Columns of ``df`` that the table does not have are dropped (with a warning),
a fixed set of empty markers is replaced via ``DataFrame.replace``, and
dict/list cells are serialised to JSON before insertion.

Args:
    table_name: Target table.
    df: Rows to insert.
    chunk_size: Not referenced anywhere in this body.
    replace: Legacy flag; only used to derive ``ignore_duplicates`` when the
        latter is None (``ignore_duplicates = not replace``).
    ignore_duplicates: When true, rows rejected with any ``MySQLError``
        (duplicate key or otherwise) are logged and skipped. When false, the
        first such error aborts the whole call.

Returns:
    Number of rows whose INSERT succeeded; 0 when ``df`` is empty or no
    column matches the table.

Raises:
    Exception: any error that escapes the per-row handling is logged (after
        a rollback attempt) and re-raised.
"""
# Backward compatibility: derive ignore_duplicates from replace when unset.
if ignore_duplicates is None:
ignore_duplicates = not replace # legacy meaning: replace=True => do not ignore duplicates
if df.empty:
self.log.warning("尝试插入空的DataFrame", table=table_name)
return 0
conn = None
cursor = None
total_inserted = 0
total_duplicates = 0
total_failed = 0
failed_records = [] # every rejected record (duplicates and other errors)
try:
# 1. Open the database connection.
conn = self.get_connection()
cursor = conn.cursor()
self.log.debug(f"已建立连接,准备插入数据到 {table_name}")
# 2. Read the table's actual column names.
cursor.execute(f"SHOW COLUMNS FROM `{table_name}`")
columns_info = cursor.fetchall()
db_columns = [col[0] for col in columns_info]
self.log.debug(f"{table_name} 包含以下列:{db_columns}")
# 3. Pre-processing: unify the different "empty" markers.
# NOTE(review): how replace() treats value=None differs between pandas
# versions — confirm the result against the pandas version in use.
cleaned_df = df.replace(
[None, np.nan, pd.NA, 'nan', 'NaN', 'NAN', ''],
None
).copy()
# 4. Column matching: keep only columns that exist in the table.
df_columns = cleaned_df.columns.tolist()
matched_columns = [col for col in df_columns if col in db_columns]
unmatched_columns = [col for col in df_columns if col not in db_columns]
if unmatched_columns:
self.log.warning(
f"{table_name} 中存在不匹配的列,已自动丢弃",
unmatched_columns=unmatched_columns,
count=len(unmatched_columns)
)
if not matched_columns:
self.log.warning(f"{table_name} 没有匹配的列,终止插入操作")
return 0
filtered_df = cleaned_df[matched_columns].copy()
total_to_insert = len(filtered_df)
self.log.debug(
f"{table_name} 的过滤后DataFrame:共 {total_to_insert} 行待插入"
)
# 5. Complex values: serialise dict/list cells to JSON text.
for col in filtered_df.columns:
has_complex_type = filtered_df[col].apply(
lambda x: isinstance(x, (dict, list)) if x is not None else False
).any()
if has_complex_type:
self.log.debug(f"{table_name} 中的 {col} 列包含复杂类型,正在转换为JSON")
# Every non-None cell of the column goes through json.dumps, not only
# the dict/list ones.
filtered_df.loc[:, col] = filtered_df[col].apply(
lambda x: json.dumps(x, ensure_ascii=False) if x is not None else x
)
# 6. Build the parameterised INSERT statement.
columns_str = ', '.join([f"`{col}`" for col in filtered_df.columns])
placeholders = ', '.join(['%s'] * len(filtered_df.columns))
insert_sql = f"INSERT INTO `{table_name}` ({columns_str}) VALUES ({placeholders})"
self.log.trace(f"为表 {table_name} 生成的插入SQL{insert_sql}")
# 7. Insert row by row so a rejection can be attributed to one record.
records = filtered_df.to_dict('records')
indices = filtered_df.index.tolist()
for i, (record, idx) in enumerate(zip(records, indices)):
try:
data = tuple(record[col] for col in filtered_df.columns)
cursor.execute(insert_sql, data)
total_inserted += 1
# Progress trace every 100 rows.
if (i + 1) % 100 == 0:
self.log.trace(
f"已向表 {table_name} 插入 {i + 1}/{total_to_insert} 行数据"
)
except MySQLError as e:
# 8. Duplicate-key rejection (MySQL error code 1062).
if e.args[0] == 1062:
total_duplicates += 1
# Shortened copy of the record for the warning message.
short_record = {
k: (str(v)[:100] + '...') if isinstance(v, (str, dict, list)) else v
for k, v in record.items()
}
self.log.warning(
f"{table_name} 中跳过重复记录",
index=idx,
error_message=e.args[1],
record=short_record
)
# Remember the duplicate record.
failed_records.append({
'index': idx,
'type': 'duplicate',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
if not ignore_duplicates:
raise
else:
# Any other database error.
total_failed += 1
# Remember the failed record with its error details.
# NOTE(review): e.args[1] assumes the error carries (code, message) — confirm
# for errors raised with a single argument.
failed_records.append({
'index': idx,
'type': 'error',
'error_code': e.args[0],
'error_message': e.args[1],
'record': record
})
self.log.error(
f"{table_name} 插入记录失败",
index=idx,
error_code=e.args[0],
error_message=e.args[1],
record=record # the full record is written to the log
)
# The same flag also decides whether non-duplicate errors abort the call.
if not ignore_duplicates:
raise
# Commit the transaction.
conn.commit()
# 9. Summary of the run, including the number of rejected records.
self.log.info(
f"{table_name} 插入结果汇总",
total_to_insert=total_to_insert,
total_inserted=total_inserted,
total_duplicates=total_duplicates,
total_failed=total_failed,
failed_records_count=len(failed_records)
)
# Separately log the details of every rejected record.
if failed_records:
self.log.error(
f"{table_name} 插入失败记录详情",
failed_records_summary=[
{
'index': r['index'],
'type': r['type'],
'error_code': r['error_code'],
'error_message': r['error_message']
} for r in failed_records
],
# Full records are passed as a separate field from the summary.
detailed_failed_records=failed_records
)
return total_inserted
except Exception as e:
# NOTE(review): the config built in __init__ sets autocommit=True, so this
# rollback presumably cannot undo rows already inserted — confirm.
if conn:
conn.rollback()
self.log.error(f"{table_name} 批量插入失败",
error=str(e),
error_type=type(e).__name__,
table_name=table_name,
total_records=len(df) if not df.empty else 0,
exc_info=True)
# Log the records that had already been rejected before the abort.
if failed_records:
self.log.error(
f"{table_name} 事务回滚,已失败的记录",
failed_records=failed_records,
failed_count=len(failed_records)
)
raise
finally:
# Always release the cursor and the connection.
if cursor:
cursor.close()
if conn:
conn.close()
def _get_primary_key(self, table_name: str, cursor) -> Optional[str]:
"""【新增辅助方法】获取表的主键(用于replace逻辑的去重)"""
try:
cursor.execute("""
SELECT COLUMN_NAME
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
WHERE TABLE_SCHEMA = %s
AND TABLE_NAME = %s
AND CONSTRAINT_NAME = 'PRIMARY'
""", (self.config['database'], table_name))
result = cursor.fetchone()
return result[0] if result else None
except Exception as e:
self.log.warning(f"获取表 {table_name} 的主键失败", error=str(e))
return None
def _get_table_detailed_info(self, table_name: str) -> Dict[str, Dict[str, Any]]:
"""获取表的详细结构信息(原有逻辑完全保留,供其他方法调用)"""
sql = """
SELECT column_name, data_type, character_maximum_length
FROM information_schema.columns
WHERE table_schema = %s \
AND table_name = %s \
"""
params = (self.config['database'], table_name)
try:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(sql, params)
result = cursor.fetchall()
# 强制转换为列表,避免游标类型导致的解析问题
result_list = list(result)
if not result_list:
self.log.error("未在表中找到任何列", =table_name)
return {}
schema = {}
for row in result_list:
# 确保正确提取字段名(兼容元组格式)
col_name = str(row[0]).strip() # 强制转为字符串并去空格
data_type = str(row[1]).strip()
max_length = row[2] if row[2] else None
schema[col_name] = {
'type': data_type,
'max_length': max_length
}
self.log.debug("成功获取表结构信息",
=table_name,
=list(schema.keys()))
return schema
finally:
cursor.close()
conn.close()
except Exception as e:
self.log.error("获取表详细信息失败",
=table_name,
error=str(e))
raise
def _validate_and_clean_data(self, df: pd.DataFrame, table_name: str,
table_schema: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
"""数据校验与清洗(原有逻辑完全保留,供其他方法调用)"""
# 1. 字段过滤:只保留表中存在的字段
df_columns = df.columns.tolist()
table_columns = list(table_schema.keys())
valid_columns = [col for col in df_columns if col in table_columns]
invalid_columns = [col for col in df_columns if col not in table_columns]
if invalid_columns:
self.log.warning("丢弃表中不存在的无效列",
=table_name,
无效列=invalid_columns,
数量=len(invalid_columns))
cleaned_df = df[valid_columns].copy()
if cleaned_df.empty:
return cleaned_df
# 2. 处理每个字段的数据
for col in valid_columns:
col_info = table_schema[col]
data_type = col_info['type']
max_length = col_info['max_length']
# 2.1 处理空值
if cleaned_df[col].isnull().any():
# 根据字段类型设置默认值
default_value = '' if data_type in ['varchar', 'char', 'text'] else None
cleaned_df[col].fillna(default_value, inplace=True)
self.log.debug("替换空值",
=table_name,
=col,
默认值=default_value,
数量=cleaned_df[col].isnull().sum())
# 2.2 处理字符串类型的超长字段
if data_type in ['varchar', 'char'] and max_length:
# 确保是字符串类型
cleaned_df[col] = cleaned_df[col].astype(str)
# 截断超长内容
too_long_mask = cleaned_df[col].str.len() > max_length
if too_long_mask.any():
cleaned_df.loc[too_long_mask, col] = cleaned_df.loc[too_long_mask, col].str.slice(0, max_length)
self.log.warning("截断超长值",
=table_name,
=col,
最大长度=max_length,
数量=too_long_mask.sum())
# 2.3 处理日期时间类型
if data_type in ['datetime', 'timestamp']:
try:
# 尝试转换为datetime类型
cleaned_df[col] = pd.to_datetime(cleaned_df[col])
except Exception as e:
self.log.warning("转换为datetime失败,使用当前时间替代",
=table_name,
=col,
错误=str(e))
# 转换失败的用当前时间替代
invalid_mask = pd.to_datetime(cleaned_df[col], errors='coerce').isna()
cleaned_df.loc[invalid_mask, col] = datetime.now()
return cleaned_df
def update_from_df(self, table_name: str, df: pd.DataFrame,
                   key_columns: Union[str, List[str]]) -> int:
    """Update table rows from a DataFrame, matching on key columns.

    One UPDATE is executed per DataFrame row: every DataFrame column that
    exists in the table and is not a key column is SET, and the key columns
    form the WHERE clause.

    Args:
        table_name: Target table.
        df: Rows holding the key values and the new column values.
        key_columns: Column name, or list of names, identifying the rows.

    Returns:
        The cursor's row count after the batch; 0 when ``df`` is empty or
        there is nothing to update.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    # NOTE(review): the log keyword name 表 below was reconstructed from
    # garbled source; confirm it against the logging schema.
    if df.empty:
        self.log.warning("尝试使用空的DataFrame进行更新", 表=table_name)
        return 0

    self.log.debug("准备从DataFrame更新表数据",
                   表=table_name,
                   关键字列=key_columns,
                   行数=len(df))

    def quote(identifier) -> str:
        # Backtick-quote so reserved words and unusual names stay valid SQL.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        if isinstance(key_columns, str):
            key_columns = [key_columns]

        # Work out the statement first; no connection is needed for that.
        table_info = self._get_table_detailed_info(table_name)
        columns = [col for col in df.columns if col in table_info]
        set_columns = [col for col in columns if col not in key_columns]
        if not set_columns:
            self.log.warning("没有可更新的列", 表=table_name)
            return 0

        set_clause = ', '.join(f"{quote(col)}=%s" for col in set_columns)
        where_clause = ' AND '.join(f"{quote(col)}=%s" for col in key_columns)
        update_sql = f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}"
        self.log.trace("生成的更新SQL", sql=update_sql)

        # itertuples keeps each column's own dtype; iterrows would upcast
        # mixed rows (e.g. integer keys turning into floats).
        ordered = df[set_columns + list(key_columns)]
        update_data = list(ordered.itertuples(index=False, name=None))

        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.executemany(update_sql, update_data)
                updated_rows = cursor.rowcount
            conn.commit()

        self.log.info("数据更新成功",
                      表=table_name,
                      更新行数=updated_rows)
        return updated_rows
    except Exception as e:
        self.log.error("数据更新失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        raise
def df_to_sql_type(self, df: pd.DataFrame) -> Dict[str, str]:
    """Map each DataFrame column to a SQL column type by its dtype name.

    Dtypes without an entry in the lookup table fall back to ``TEXT``.

    Returns:
        Mapping of column name to SQL type string.
    """
    sql_type_by_dtype = {
        'int64': 'BIGINT',
        'float64': 'DOUBLE',
        'datetime64[ns]': 'DATETIME',
        'object': 'TEXT',
        'bool': 'TINYINT(1)',
        'category': 'VARCHAR(255)'
    }
    sql_types = {
        column: sql_type_by_dtype.get(str(dtype), 'TEXT')
        for column, dtype in df.dtypes.items()
    }
    self.log.debug("将DataFrame类型映射为SQL类型",
                   映射关系=sql_types)
    return sql_types
def create_table_from_df(self, table_name: str, df: pd.DataFrame,
                         primary_key: Union[str, List[str], None] = None) -> bool:
    """Create a table whose columns mirror a DataFrame.

    Column types come from ``df_to_sql_type``. Nothing is created when the
    table already exists.

    Args:
        table_name: Name of the table to create.
        df: DataFrame whose columns define the table.
        primary_key: Optional column name or list of names; names that are
            not DataFrame columns are ignored.

    Returns:
        True when the CREATE TABLE statement ran; False when the table
        already existed or creation failed (the failure is logged).
    """
    # NOTE(review): the log keyword names 表 / 列 below were reconstructed
    # from garbled source; confirm them against the logging schema.
    if self.table_exists(table_name):
        self.log.warning("表已存在", 表=table_name)
        return False

    self.log.debug("根据DataFrame结构创建新表",
                   表=table_name,
                   列=list(df.columns))

    def quote(identifier) -> str:
        # Backtick-quote so names with spaces or reserved words stay valid.
        return "`" + str(identifier).replace("`", "``") + "`"

    try:
        sql_types = self.df_to_sql_type(df)
        columns_sql = [f"{quote(col)} {sql_type}" for col, sql_type in sql_types.items()]

        if primary_key:
            if isinstance(primary_key, str):
                primary_key = [primary_key]
            pk_columns = [col for col in primary_key if col in sql_types]
            if pk_columns:
                quoted_pk = ', '.join(quote(col) for col in pk_columns)
                columns_sql.append(f"PRIMARY KEY ({quoted_pk})")
                self.log.trace("设置主键",
                               表=table_name,
                               主键=pk_columns)

        # Join outside the f-string: a backslash inside an f-string
        # expression is a syntax error before Python 3.12.
        separator = ",\n "
        column_block = separator.join(columns_sql)
        create_sql = f"CREATE TABLE {table_name} (\n {column_block}\n)"
        self.execute_sql(create_sql)
        self.log.info("表创建成功", 表=table_name)
        return True
    except Exception as e:
        self.log.error("创建表失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def execute_sql(self, sql: str, params: Union[tuple, dict, None] = None,
                fetch: bool = False) -> Union[int, List[Dict[str, Any]]]:
    """Execute one SQL statement on a fresh connection.

    Args:
        sql: Statement text.
        params: Optional parameters passed to ``cursor.execute``.
        fetch: When true, return ``cursor.fetchall()``; otherwise commit and
            return the cursor's row count.

    Raises:
        Exception: any failure is logged with the SQL and parameters, then
            re-raised.
    """
    try:
        with self.get_connection() as conn, conn.cursor() as cursor:
            # On non-Windows platforms, set the session's max_execution_time first.
            if platform.system() != 'Windows':
                cursor.execute("SET SESSION max_execution_time=600000")
            cursor.execute(sql, params)

            if not fetch:
                affected = cursor.rowcount
                conn.commit()  # commit right away
                self.log.debug("更新执行完成", 受影响行数=affected)
                return affected

            rows = cursor.fetchall()
            self.log.debug("查询执行完成", 行数=len(rows))
            return rows
    except Exception as exc:
        self.log.error("SQL执行失败",
                       sql=sql,
                       params=params,
                       error=str(exc),
                       error_type=type(exc).__name__,
                       exc_info=True)
        raise
def table_exists(self, table_name: str) -> bool:
    """Report whether a table exists in the configured database.

    Returns:
        True when information_schema lists the table; False when it does not
        or when the lookup raised.
    """
    sql = """
SELECT COUNT(*) as count
FROM `information_schema`.`tables`
WHERE `table_schema` = %s \
AND `table_name` = %s \
"""
    params = (self.config['database'], table_name)
    try:
        result = self.execute_sql(sql, params, fetch=True)
        exists = result[0][0] > 0  # rows are read positionally (tuple rows)
        # NOTE(review): the log keyword name 表 was reconstructed from
        # garbled source; confirm it against the logging schema.
        self.log.debug("检查表是否存在",
                       表=table_name,
                       存在=exists)
        return exists
    except Exception:
        # A failed lookup is reported the same way as a missing table.
        return False
def drop_table(self, table_name: str) -> bool:
    """Drop a table if it exists.

    Returns:
        True when the DROP TABLE statement ran; False when the table was not
        found or the statement failed (the failure is logged).
    """
    # NOTE(review): the log keyword name 表 below was reconstructed from
    # garbled source; confirm it against the logging schema.
    if not self.table_exists(table_name):
        self.log.warning("表不存在", 表=table_name)
        return False
    try:
        self.execute_sql(f"DROP TABLE {table_name}")
        self.log.info("表删除成功", 表=table_name)
        return True
    except Exception as e:
        self.log.error("删除表失败",
                       表=table_name,
                       error=str(e),
                       exc_info=True)
        return False
def validate_connection(self) -> bool:
    """Check that a connection can be opened and answers ``SELECT 1``.

    Returns:
        True when the probe query returns 1; False on any exception.
    """
    try:
        with self.get_connection() as conn, conn.cursor() as cursor:
            cursor.execute("SELECT 1")
            probe_row = cursor.fetchone()
            return probe_row[0] == 1
    except Exception:
        return False
# Platform-specific default configuration.
def get_default_config():
    """Return the default connection settings for the current platform.

    Windows gets shorter timeouts than other platforms; macOS additionally
    gets an ``ssl`` entry pointing at a CA bundle.
    """
    # NOTE(review): credentials are hard-coded here; consider loading them
    # from the environment or a secrets store.
    settings = {
        'host': 'localhost',
        'port': 3306,
        'user': 'root',
        'password': '123123',
        'database': 'intelligence_system',
    }

    system_name = platform.system()
    if system_name == 'Windows':
        timeouts = (10, 30, 30)
    else:
        timeouts = (15, 60, 60)
    timeout_keys = ('connect_timeout', 'read_timeout', 'write_timeout')
    settings.update(zip(timeout_keys, timeouts))

    if system_name == 'Darwin':
        settings['ssl'] = {'ca': '/usr/local/etc/openssl/cert.pem'}
    return settings
if __name__ == "__main__":
    # Usage example: connect with the platform defaults and print the server version.
    db = MySQLAgent(get_default_config())

    if not db.validate_connection():
        print("连接数据库失败")
    else:
        print("数据库连接成功")
        # Query the server version.
        version_frame = db.query_to_df("SELECT VERSION() as version")
        print(f"数据库版本: {version_frame['version'].iloc[0]}")