commit 71e9c7c5bc9144593e1cc5d97defd9446809f5c4 Author: Administrator <1415243231@qq.com> Date: Tue Aug 5 15:00:46 2025 +0800 ai初期模板 diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..af54cf8 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,85 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..cf3a8aa --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,9 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..97be862 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/applications/alert.py b/applications/alert.py new file mode 100644 index 0000000..e69de29 diff --git a/applications/reporter/daily.py b/applications/reporter/daily.py new file mode 100644 index 0000000..e69de29 diff --git a/applications/reporter/monthly.py b/applications/reporter/monthly.py new file mode 100644 index 0000000..e69de29 diff --git a/collectors/base.py b/collectors/base.py new file mode 100644 index 0000000..e69de29 diff --git a/collectors/complaint_spider.py b/collectors/complaint_spider.py new file mode 100644 index 0000000..e69de29 diff --git a/collectors/news_api.py b/collectors/news_api.py new file mode 100644 index 0000000..e69de29 diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..6c574fc --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +配置初始化模块 +功能: +1. 自动生成默认配置文件 +2. 多环境配置支持(dev/test/prod) +3. 敏感信息加密存储 +4. 配置完整性检查与修复 +""" + +import os +import json +import platform +from pathlib import Path +from typing import Dict, Any, Optional +import logging +from cryptography.fernet import Fernet +import hashlib + +# 初始化日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger('config_init') + +class ConfigInitializer: + """配置初始化工具类""" + + def __init__(self, app_name: str = "intelligence_system"): + self.system = platform.system().lower() + self.app_name = app_name + self.config_dir = self._get_config_dir() + self.config_file = self.config_dir / "config.json" + self.secret_key_file = self.config_dir / ".secret.key" + self._fernet = None + + # 确保配置目录存在 + self.config_dir.mkdir(parents=True, exist_ok=True) + + # 设置文件权限(非Windows) + if self.system != 'windows': + os.chmod(self.config_dir, 0o700) + + def _get_config_dir(self) -> Path: + """获取适合当前平台的配置目录路径""" + if self.system == 'windows': + return Path(os.environ['APPDATA']) / self.app_name + elif self.system == 'darwin': # macOS + return Path.home() / "Library" / "Application Support" / self.app_name + else: # Linux及其他Unix-like + xdg_config = os.getenv('XDG_CONFIG_HOME', '~/.config') + return Path(xdg_config).expanduser() / self.app_name + + def _init_encryption(self): + """初始化加密模块""" + if not self.secret_key_file.exists(): + self.secret_key_file.write_bytes(Fernet.generate_key()) + if self.system != 'windows': + self.secret_key_file.chmod(0o600) # 仅用户可读写 + + self._fernet = Fernet(self.secret_key_file.read_bytes()) + + def encrypt_value(self, plaintext: str) -> str: + """加密敏感信息""" + if not self._fernet: + self._init_encryption() + return self._fernet.encrypt(plaintext.encode()).decode() + + def decrypt_value(self, ciphertext: str) -> str: + """解密信息""" + if not self._fernet: + self._init_encryption() + return self._fernet.decrypt(ciphertext.encode()).decode() + + def _get_default_config(self) -> Dict[str, Any]: + """获取默认配置模板""" + return { + "system": { + "env": "dev", # dev/test/prod + "log_level": "INFO", + "max_threads": max(1, os.cpu_count() or 4), + "data_dir": str(self.config_dir / "data") + }, + "api": { + "newsapi": { + "endpoint": "https://newsapi.org/v2", + "key": "" # 需加密存储 + }, + "weibo": { + "version": "2", + "access_token": "" # 需加密存储 + } + }, + "database": { + "type": "sqlite", + "path": str(self.config_dir / "data.db") + }, + "network": { + "timeout": 30, + "retries": 3, + "proxy": "" # 示例: http://user:pass@proxy:port + } + } + + def _migrate_old_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + """旧配置迁移(兼容性处理)""" + # 示例:将旧版api_key迁移到新版结构 + if 'api_key' in config: + config.setdefault('api', {})['newsapi'] = { + 'key': config.pop('api_key') + } + return config + + def _validate_config(self, config: Dict[str, Any]) -> bool: + """验证配置完整性""" + required_keys = { + "system": ["env", "log_level"], + "api/newsapi": ["endpoint"] + } + + for path, keys in required_keys.items(): + current = config + for part in path.split('/'): + current = current.get(part, {}) + if not isinstance(current, dict): + return False + + for key in keys: + if key not in current: + return False + return True + + def _repair_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + """自动修复缺失的配置项""" + default_config = self._get_default_config() + + def _merge(current, default): + for key, value in default.items(): + if key not in current: + current[key] = value + elif isinstance(value, dict): + _merge(current[key], value) + return current + + return _merge(config, default_config) + + def init_config(self, force: bool = False) -> bool: + """ + 初始化配置文件 + 参数: + force: 是否强制重新生成配置 + 返回: + bool: 是否创建了新配置 + """ + config = None + + # 已有配置文件且不强制重置 + if self.config_file.exists() and not force: + try: + with open(self.config_file, 'r', encoding='utf-8') as f: + config = json.load(f) + + # 配置迁移和修复 + config = self._migrate_old_config(config) + if not self._validate_config(config): + config = self._repair_config(config) + logger.warning("自动修复不完整的配置文件") + + except Exception as e: + logger.error(f"加载现有配置失败: {str(e)}") + config = None + + # 需要创建新配置 + if config is None: + config = self._get_default_config() + logger.info("创建新的配置文件") + + # 加密敏感字段 + self._init_encryption() + for field in [ + "api/newsapi/key", + "api/weibo/access_token", + "network/proxy" + ]: + parts = field.split('/') + current = config + for part in parts[:-1]: + current = current.setdefault(part, {}) + + if parts[-1] in current and current[parts[-1]]: + current[parts[-1]] = self.encrypt_value(current[parts[-1]]) + + # 保存配置 + with open(self.config_file, 'w', encoding='utf-8') as f: + json.dump(config, f, indent=2, ensure_ascii=False) + + # 设置文件权限(非Windows) + if self.system != 'windows': + os.chmod(self.config_file, 0o600) + + return True + + def get_config_hash(self) -> str: + """获取配置文件哈希值(用于检测变更)""" + if not self.config_file.exists(): + return "" + + with open(self.config_file, 'rb') as f: + return hashlib.sha256(f.read()).hexdigest() + + def create_env_specific_config(self, env: str = None) -> bool: + """ + 创建环境特定配置 + 参数: + env: 环境类型(dev/test/prod) + """ + if not self.config_file.exists(): + self.init_config() + + with open(self.config_file, 'r', encoding='utf-8') as f: + base_config = json.load(f) + + env = env or base_config['system']['env'] + env_config = { + f"env_{env}": { + "api": { + "newsapi": {"endpoint": self._get_env_endpoint(env)} + }, + "database": { + "path": str(self.config_dir / f"data_{env}.db") + } + } + } + + env_file = self.config_dir / f"config.{env}.json" + with open(env_file, 'w', encoding='utf-8') as f: + json.dump(env_config, f, indent=2) + + return True + + def _get_env_endpoint(self, env: str) -> str: + """获取环境特定的API端点""" + endpoints = { + "dev": "http://dev-api.example.com", + "test": "https://test-api.example.com", + "prod": "https://api.example.com" + } + return endpoints.get(env, endpoints['dev']) + +# 快捷初始化函数 +def init_app_config(app_name: str = None, force: bool = False) -> bool: + """ + 快速初始化应用配置 + 参数: + app_name: 应用名称 + force: 是否强制重新初始化 + """ + return ConfigInitializer(app_name).init_config(force) + +# 测试代码 +if __name__ == "__main__": + # 初始化配置 + initializer = ConfigInitializer() + if initializer.init_config(): + print("配置文件已生成:", initializer.config_file) + + # 创建环境配置示例 + initializer.create_env_specific_config("prod") + print("生产环境配置已生成") + + # 加密演示 + encrypted = initializer.encrypt_value("my_secret_key") + print("加密示例:", encrypted) + print("解密测试:", initializer.decrypt_value(encrypted)) \ No newline at end of file diff --git a/config/logging.conf b/config/logging.conf new file mode 100644 index 0000000..5adc1ba --- /dev/null +++ b/config/logging.conf @@ -0,0 +1,56 @@ +[loggers] +keys=root,data_collector,api_client,alert + +[handlers] +keys=consoleHandler,fileHandler,errorFileHandler + +[formatters] +keys=standardFormatter,detailedFormatter + +[logger_root] +level=INFO +handlers=consoleHandler,fileHandler + +[logger_data_collector] +level=DEBUG +handlers=fileHandler +qualname=data_collector +propagate=0 + +[logger_api_client] +level=INFO +handlers=fileHandler,errorFileHandler +qualname=api_client +propagate=0 + +[logger_alert] +level=WARNING +handlers=consoleHandler,errorFileHandler +qualname=alert +propagate=0 + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=standardFormatter +args=(sys.stdout,) + +[handler_fileHandler] +class=logging.handlers.TimedRotatingFileHandler +level=DEBUG +formatter=detailedFormatter +args=('%(log_dir)s/application.log', 'midnight', 1, 30, 'utf-8') + +[handler_errorFileHandler] +class=logging.handlers.TimedRotatingFileHandler +level=WARNING +formatter=detailedFormatter +args=('%(log_dir)s/error.log', 'midnight', 1, 90, 'utf-8') + +[formatter_standardFormatter] +format=%(asctime)s [%(levelname)-5s] %(name)s - %(message)s +datefmt=%Y-%m-%d %H:%M:%S + +[formatter_detailedFormatter] +format=%(asctime)s [%(levelname)-5s] %(name)s (%(filename)s:%(lineno)d) - %(message)s +datefmt=%Y-%m-%d %H:%M:%S \ No newline at end of file diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..b37aef8 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +系统配置模块 +功能: +1. 支持多平台路径适配 +2. 环境变量与配置文件优先级管理 +3. 敏感信息加密存储 +4. 配置热更新检测 +""" + +import os +import json +import platform +from pathlib import Path +from typing import Dict, Any, Optional +import dotenv +from cryptography.fernet import Fernet + +class ConfigManager: + def __init__(self, app_name: str = "intelligence_system"): + """ + 初始化配置管理器 + + 参数: + app_name: 应用名称(用于生成配置目录) + """ + self.system = platform.system().lower() + self.app_name = app_name + self._config = {} + self._secret_key = None + + # 初始化配置路径 + self.config_dir = self._get_config_dir() + os.makedirs(self.config_dir, exist_ok=True) + + # 加载配置顺序 + self._load_defaults() + self._load_env_file() + self._load_user_config() + + # 初始化加密模块 + self._init_encryption() + + def _get_config_dir(self) -> str: + """获取适合当前平台的配置目录""" + if self.system == 'windows': + return os.path.join(os.environ['APPDATA'], self.app_name) + elif self.system == 'darwin': # macOS + return os.path.expanduser(f"~/Library/Application Support/{self.app_name}") + else: # Linux及其他Unix-like + return os.path.expanduser(f"~/.config/{self.app_name}") + + def _load_defaults(self): + """加载默认配置""" + self._config = { + "system": { + "log_level": "INFO", + "max_threads": os.cpu_count() or 4 + }, + "api": { + "newsapi": {"endpoint": "https://newsapi.org/v2"}, + "weibo": {"version": "2"} + }, + "paths": { + "data_dir": os.path.join(self.config_dir, "data"), + "cache_dir": os.path.join(self.config_dir, "cache") + } + } + + def _load_env_file(self): + """加载.env环境变量文件""" + env_path = Path(self.config_dir) / ".env" + if env_path.exists(): + dotenv.load_dotenv(env_path) + + # 环境变量覆盖配置 + if os.getenv("LOG_LEVEL"): + self._config["system"]["log_level"] = os.getenv("LOG_LEVEL") + + def _load_user_config(self): + """加载用户自定义配置""" + config_file = Path(self.config_dir) / "config.json" + try: + if config_file.exists(): + with open(config_file, 'r', encoding='utf-8') as f: + user_config = json.load(f) + self._deep_update(self._config, user_config) + except Exception as e: + print(f"加载用户配置失败: {str(e)}") + + def _init_encryption(self): + """初始化配置加密模块""" + key_file = Path(self.config_dir) / ".secret.key" + if key_file.exists(): + with open(key_file, 'rb') as f: + self._secret_key = f.read() + else: + self._secret_key = Fernet.generate_key() + with open(key_file, 'wb') as f: + f.write(self._secret_key) + key_file.chmod(0o600) # 设置密钥文件权限 + + def _deep_update(self, original: Dict, update: Dict) -> Dict: + """深度合并字典""" + for key, value in update.items(): + if isinstance(value, dict) and key in original: + original[key] = self._deep_update(original.get(key, {}), value) + else: + original[key] = value + return original + + def get(self, key: str, default: Any = None) -> Any: + """ + 获取配置项(支持点分路径) + 示例: get("api.newsapi.endpoint") + """ + keys = key.split('.') + value = self._config + try: + for k in keys: + value = value[k] + return value + except (KeyError, TypeError): + return default + + def set(self, key: str, value: Any, persist: bool = False): + """ + 设置配置项 + 参数: + persist: 是否保存到用户配置文件 + """ + keys = key.split('.') + config_ref = self._config + + for k in keys[:-1]: + if k not in config_ref: + config_ref[k] = {} + config_ref = config_ref[k] + + config_ref[keys[-1]] = value + + if persist: + self._save_user_config() + + def encrypt_value(self, plaintext: str) -> str: + """加密敏感信息""" + fernet = Fernet(self._secret_key) + return fernet.encrypt(plaintext.encode()).decode() + + def decrypt_value(self, ciphertext: str) -> str: + """解密敏感信息""" + fernet = Fernet(self._secret_key) + return fernet.decrypt(ciphertext.encode()).decode() + + def _save_user_config(self): + """保存用户配置到文件""" + config_file = Path(self.config_dir) / "config.json" + with open(config_file, 'w', encoding='utf-8') as f: + json.dump(self._config, f, indent=2, ensure_ascii=False) + + def reload(self): + """重新加载所有配置""" + self._config = {} + self._load_defaults() + self._load_env_file() + self._load_user_config() + +# 全局配置实例 +config = ConfigManager() + +# 快捷访问方法(兼容旧代码) +def get_config(key: str, default: Optional[Any] = None) -> Any: + return config.get(key, default) + +def set_config(key: str, value: Any, persist: bool = False): + config.set(key, value, persist) + +# 测试代码 +if __name__ == "__main__": + # 设置并保存API密钥(自动加密) + api_key = "your_newsapi_key_here" + encrypted_key = config.encrypt_value(api_key) + config.set("api.newsapi.key", encrypted_key, persist=True) + + # 获取配置示例 + print(f"日志级别: {config.get('system.log_level')}") + print(f"NewsAPI端点: {config.get('api.newsapi.endpoint')}") + + # 解密敏感信息 + stored_key = config.get("api.newsapi.key") + if stored_key: + print(f"解密后的API密钥: {config.decrypt_value(stored_key)}") \ No newline at end of file diff --git a/intelligence_system.iml b/intelligence_system.iml new file mode 100644 index 0000000..599fc7c --- /dev/null +++ b/intelligence_system.iml @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..4d10b3a --- /dev/null +++ b/main.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +情报收集系统主程序 +功能: +1. 调度数据采集、处理、存储流程 +2. 生成日报/月报 +3. 异常监控和报警 +""" + +import sys +import logging +from datetime import datetime, timedelta +from typing import List, Dict, Any + +# 自定义模块 +from config.settings import API_KEYS, DATA_SOURCES +from collectors.news_api import NewsAPICollector +from collectors.complaint_spider import ComplaintSpider +from processors.data_processor import DataProcessor +from storage.database import IntelligenceDB +from applications.reporter import ReportGenerator +from applications.alert import AlertService +from utils.logger import setup_logging +from utils.mail import send_email + +class IntelligenceSystem: + def __init__(self): + # 初始化核心组件 + setup_logging() + self.logger = logging.getLogger(__name__) + + self.db = IntelligenceDB() + self.processor = DataProcessor() + self.alert = AlertService() + + # 数据采集器注册 + self.collectors = { + "news": NewsAPICollector(API_KEYS['newsapi']), + "complaint": ComplaintSpider( + base_url=DATA_SOURCES['blackcat'], + rate_limit=30 # 30秒爬取间隔 + ) + } + + def run_daily_pipeline(self): + """每日数据采集处理流程""" + try: + # 阶段1:数据采集 + raw_data = self._collect_data() + + # 阶段2:数据处理 + processed_data = self._process_data(raw_data) + + # 阶段3:数据存储 + self._store_data(processed_data) + + # 阶段4:生成日报 + self._generate_reports() + + # 阶段5:异常检测 + self._check_alerts() + + except Exception as e: + self.logger.error(f"主流程执行失败: {str(e)}", exc_info=True) + self.alert.send_critical(f"系统异常: {str(e)}") + + def _collect_data(self) -> Dict[str, List[Dict]]: + """执行所有数据采集任务""" + collected = {} + for name, collector in self.collectors.items(): + try: + self.logger.info(f"开始采集 {name} 数据...") + data = collector.fetch_data({ + 'keywords': '汽车后市场', + 'max_results': 100 + }) + collected[name] = data + self.logger.info(f"{name} 采集完成,共 {len(data)} 条数据") + except Exception as e: + self.logger.error(f"{name} 采集器异常: {str(e)}") + continue + return collected + + def _process_data(self, raw_data: Dict) -> Dict: + """处理原始数据""" + processed = {} + for data_type, items in raw_data.items(): + processed[data_type] = [] + for item in items: + try: + # 文本数据标准处理 + if data_type in ['news', 'complaint']: + result = self.processor.process_text(item['content']) + processed_item = { + **item, + 'keywords': result['keywords'], + 'category': result['category'] + } + processed[data_type].append(processed_item) + + # 图像处理(预留接口) + elif data_type == 'images': + processed[data_type].append( + self.processor.image_to_text(item) + ) + except Exception as e: + self.logger.warning(f"数据处理失败: {item.get('id', '')} - {str(e)}") + continue + return processed + + def _store_data(self, processed_data: Dict): + """存储到数据库""" + for data_type, items in processed_data.items(): + success_count = 0 + for item in items: + try: + if self.db.insert_data(data_type, item): + success_count += 1 + except Exception as e: + self.logger.error(f"数据存储失败: {str(e)}") + + self.logger.info( + f"{data_type} 数据存储完成,成功 {success_count}/{len(items)} 条" + ) + + def _generate_reports(self): + """生成报告并发送""" + try: + # 日报生成 + report_html = ReportGenerator(self.db).generate_daily() + with open(f"reports/daily_{datetime.now().date()}.html", 'w') as f: + f.write(report_html) + + # 每月1号生成月报 + if datetime.now().day == 1: + monthly_report = ReportGenerator(self.db).generate_monthly() + send_email( + to="team@example.com", + subject=f"{datetime.now().strftime('%Y-%m')} 情报月报", + content=monthly_report + ) + + except Exception as e: + self.logger.error(f"报告生成失败: {str(e)}") + + def _check_alerts(self): + """检查预警信息""" + # 负面舆情监测 + negative_keywords = ['投诉', '造假', '违规'] + alerts = self.alert.check_negative(negative_keywords) + + if alerts: + self.alert.send_urgent( + "负面舆情警报", + "\n".join([f"[{a['source']}] {a['content']}" for a in alerts]) + ) + + def cleanup(self): + """资源清理""" + self.db.close() + self.logger.info("系统资源已释放") + +if __name__ == "__main__": + system = IntelligenceSystem() + + try: + # 执行每日任务 + if len(sys.argv) > 1 and sys.argv[1] == "--manual": + system.logger.info("手动执行模式启动") + system.run_daily_pipeline() + else: + # 定时任务模式(实际部署时改用crontab或APScheduler) + system.logger.info("定时任务模式启动") + while True: + now = datetime.now() + if now.hour == 9 and now.minute == 0: # 每天9点执行 + system.run_daily_pipeline() + time.sleep(60) # 避免重复执行 + time.sleep(30) + + except KeyboardInterrupt: + system.logger.info("用户中断执行") + finally: + system.cleanup() \ No newline at end of file diff --git a/processors/image_processor.py b/processors/image_processor.py new file mode 100644 index 0000000..e69de29 diff --git a/processors/text_processor.py b/processors/text_processor.py new file mode 100644 index 0000000..e69de29 diff --git a/storage/database.py b/storage/database.py new file mode 100644 index 0000000..db151ff --- /dev/null +++ b/storage/database.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +数据库存储模块 +功能: +1. 统一数据库接口(SQLite/MySQL/PostgreSQL) +2. 自动处理多平台路径问题 +3. 连接池管理 +4. 数据加密存储 +""" + +import os +import platform +import sqlite3 +import threading +from pathlib import Path +from typing import Optional, Union, Dict, Any, List +from threading import Lock +import logging +from cryptography.fernet import Fernet + +# 类型别名 +QueryParams = Union[tuple, Dict[str, Any]] + + +class DatabaseManager: + """数据库统一管理类""" + + def __init__(self, db_config: Dict[str, Any]): + """ + 初始化数据库连接 + + 参数: + db_config: 配置字典,包含: + - type: 'sqlite'|'mysql'|'postgresql' + - database: 数据库名/路径 + - [可选] host, port, user, password + """ + self.config = db_config + self._lock = Lock() + self._connection_pool = {} + self._setup_crypto() + + # 自动创建SQLite目录 + if db_config['type'] == 'sqlite': + self._ensure_sqlite_dir() + + def _ensure_sqlite_dir(self): + """确保SQLite数据库目录存在""" + db_path = Path(self.config['database']) + if not db_path.parent.exists(): + try: + db_path.parent.mkdir(parents=True, mode=0o755) + except Exception as e: + logging.error(f"创建数据库目录失败: {str(e)}") + + def _setup_crypto(self): + """初始化加密模块""" + key_file = Path(os.path.expanduser("~/.db_encryption.key")) + if key_file.exists(): + with open(key_file, 'rb') as f: + self._fernet = Fernet(f.read()) + else: + self._fernet = Fernet.generate_key() + with open(key_file, 'wb') as f: + f.write(self._fernet) + key_file.chmod(0o600) # 仅限当前用户读写 + + def get_connection(self, reuse=True): + """ + 获取数据库连接(线程安全) + + 参数: + reuse: 是否复用现有连接(默认True) + """ + thread_id = threading.get_ident() + + with self._lock: + if reuse and thread_id in self._connection_pool: + conn = self._connection_pool[thread_id] + try: + # 检查连接是否有效 + conn.execute("SELECT 1") + return conn + except: + del self._connection_pool[thread_id] + + # 创建新连接 + if self.config['type'] == 'sqlite': + conn = self._create_sqlite_connection() + elif self.config['type'] == 'mysql': + conn = self._create_mysql_connection() + elif self.config['type'] == 'postgresql': + conn = self._create_pg_connection() + else: + raise ValueError("不支持的数据库类型") + + self._connection_pool[thread_id] = conn + return conn + + def _create_sqlite_connection(self) -> sqlite3.Connection: + """创建SQLite连接(兼容多平台路径)""" + db_path = self.config['database'] + + # Windows路径处理 + if platform.system() == 'Windows' and not db_path.startswith(('\\\\', '/')): + db_path = os.path.abspath(db_path) + + conn = sqlite3.connect(db_path, timeout=15) + conn.execute("PRAGMA journal_mode=WAL") # 写前日志提升并发 + conn.execute("PRAGMA synchronous=NORMAL") + conn.row_factory = sqlite3.Row # 支持字典式访问 + return conn + + def _create_mysql_connection(self): + """创建MySQL连接(需安装PyMySQL)""" + import pymysql + return pymysql.connect( + host=self.config.get('host', 'localhost'), + port=self.config.get('port', 3306), + user=self.config.get('user', 'root'), + password=self.config.get('password', ''), + database=self.config['database'], + charset='utf8mb4', + cursorclass=pymysql.cursors.DictCursor + ) + + def _create_pg_connection(self): + """创建PostgreSQL连接(需安装psycopg2)""" + import psycopg2 + return psycopg2.connect( + host=self.config.get('host', 'localhost'), + port=self.config.get('port', 5432), + user=self.config.get('user', 'postgres'), + password=self.config.get('password', ''), + dbname=self.config['database'] + ) + + def execute( + self, + query: str, + params: Optional[QueryParams] = None, + return_lastrowid: bool = False + ) -> Union[int, None]: + conn = self.get_connection() + try: + with conn: + cursor = conn.cursor() + if params is not None: + cursor.execute(query, params) + else: + cursor.execute(query) # ✅ 无参数时不要传 params + return cursor.lastrowid if return_lastrowid else None + finally: + if not return_lastrowid: + self._release_connection() + + def query( + self, + query: str, + params: Optional[QueryParams] = None, + fetchall: bool = True + ) -> Union[List[Dict], Dict]: + conn = self.get_connection() + try: + cursor = conn.cursor() + if params is not None: + cursor.execute(query, params) + else: + cursor.execute(query) # ✅ 无参数时不要传 None + + if self.config['type'] == 'sqlite': + result = cursor.fetchall() + if not fetchall and result: + return dict(result[0]) + return [dict(row) for row in result] + else: + return cursor.fetchall() if fetchall else cursor.fetchone() + finally: + self._release_connection() + + def _release_connection(self): + """释放当前线程的连接(SQLite除外)""" + if self.config['type'] != 'sqlite': + thread_id = threading.get_ident() + with self._lock: + if thread_id in self._connection_pool: + self._connection_pool[thread_id].close() + del self._connection_pool[thread_id] + + def encrypt_data(self, plaintext: str) -> str: + """加密敏感数据""" + return self._fernet.encrypt(plaintext.encode()).decode() + + def decrypt_data(self, ciphertext: str) -> str: + """解密数据""" + return self._fernet.decrypt(ciphertext.encode()).decode() + + def close_all(self): + """关闭所有数据库连接""" + with self._lock: + for conn in self._connection_pool.values(): + try: + conn.close() + except: + pass + self._connection_pool.clear() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_all() + + +# 全局SQLite实例(默认配置) +def get_default_db() -> DatabaseManager: + """获取默认SQLite数据库(跨平台路径处理)""" + system = platform.system().lower() + if system == 'windows': + db_path = os.path.join(os.getenv('APPDATA'), 'app_name/data.db') + elif system == 'darwin': + db_path = os.path.expanduser('~/Library/Application Support/app_name/data.db') + else: + db_path = '/var/lib/app_name/data.db' if os.access('/var/lib', os.W_OK) \ + else os.path.expanduser('~/.local/share/app_name/data.db') + + return DatabaseManager({ + 'type': 'sqlite', + 'database': db_path + }) + + +# 测试代码 +if __name__ == "__main__": + with get_default_db() as db: + # 创建测试表 + db.execute(""" + CREATE TABLE IF NOT EXISTS test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + secret TEXT + ) + """) + + # 插入加密数据 + secret = db.encrypt_data("新敏感信息") + db.execute( + "INSERT INTO test_table (name, secret) VALUES (?, ?)", + ("测试记录", secret) + ) + + # 查询并解密 + row = db.query("SELECT * FROM test_table", fetchall=False) + print(f"解密数据: {db.decrypt_data(row['name'])}") diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..a89b623 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +跨平台日志工具模块 +功能: +1. 自动适配不同操作系统的日志路径 +2. 支持中文等非ASCII字符 +3. 日志文件自动按日期分割 +4. 控制台与文件双输出 +""" + +import os +import sys +import logging +from logging.handlers import TimedRotatingFileHandler +from datetime import datetime +import platform + +class CrossPlatformLogger: + def __init__(self, name="intelligence_system"): + """ + 初始化跨平台日志系统 + + 参数: + name: 日志名称(用于创建日志文件夹) + """ + self.system = platform.system().lower() + self.logger = logging.getLogger(name) + self.logger.setLevel(logging.INFO) + + # 确保日志目录存在 + self.log_dir = self._get_log_dir(name) + os.makedirs(self.log_dir, exist_ok=True) + + # 配置日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 控制台处理器 + self._setup_console_handler(formatter) + + # 文件处理器(按天分割) + self._setup_file_handler(formatter) + + # 处理未捕获的异常 + sys.excepthook = self.handle_uncaught_exception + + def _get_log_dir(self, name: str) -> str: + """获取适合当前平台的日志目录路径""" + if self.system == 'windows': + base_dir = os.path.join(os.environ['APPDATA'], name) + elif self.system == 'darwin': # macOS + base_dir = os.path.expanduser(f"~/Library/Logs/{name}") + else: # Linux及其他Unix-like系统 + base_dir = f"/var/log/{name}" if os.access("/var/log", os.W_OK) \ + else os.path.expanduser(f"~/.local/share/{name}") + + return base_dir + + def _setup_console_handler(self, formatter: logging.Formatter): + """配置控制台输出(兼容不同终端的编码)""" + console = logging.StreamHandler() + + # Windows终端特殊处理 + if self.system == 'windows' and not sys.stdout.isatty(): + try: + import colorama + colorama.init() + except ImportError: + pass + + # 解决Windows控制台编码问题 + if sys.stdout.encoding != 'utf-8': + import io + sys.stdout = io.TextIOWrapper( + sys.stdout.buffer, + encoding='utf-8', + errors='replace' + ) + + console.setFormatter(formatter) + self.logger.addHandler(console) + + def _setup_file_handler(self, formatter: logging.Formatter): + """配置日志文件输出(UTF-8编码)""" + log_file = os.path.join( + self.log_dir, + f"{datetime.now().strftime('%Y%m%d')}.log" + ) + + # 使用TimedRotatingFileHandler实现日志分割 + file_handler = TimedRotatingFileHandler( + filename=log_file, + when='midnight', # 每天午夜分割 + encoding='utf-8', + backupCount=30 # 保留30天日志 + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + def handle_uncaught_exception(self, exc_type, exc_value, exc_traceback): + """全局异常捕获""" + self.logger.error( + "未捕获的异常:", + exc_info=(exc_type, exc_value, exc_traceback) + ) + + @staticmethod + def get_logger(name: str = None) -> logging.Logger: + """获取配置好的日志实例""" + return CrossPlatformLogger(name).logger + +def setup_logging(name: str = "intelligence_system"): + """快速配置日志(兼容旧代码)""" + return CrossPlatformLogger(name).logger + +# 测试代码 +if __name__ == "__main__": + logger = CrossPlatformLogger().logger + logger.info("这是一条info日志(包含中文测试)") + try: + 1 / 0 + except Exception as e: + logger.error("除零错误示例", exc_info=True) \ No newline at end of file diff --git a/utils/network.py b/utils/network.py new file mode 100644 index 0000000..66d4373 --- /dev/null +++ b/utils/network.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +网络工具模块 +功能: +1. 自动重试的HTTP请求 +2. 多平台代理配置支持 +3. DNS缓存优化 +4. 连接超时与SSL验证 +5. 用户代理轮换 +""" + +import socket +import time +import random +import platform +from typing import Optional, Dict, Any, Union +from urllib.parse import urlparse +from functools import lru_cache +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +from urllib3.connection import HTTPConnection +import os + +# 类型别名 +TimeoutType = Union[float, tuple] + + +class NetworkUtils: + """跨平台网络操作工具类""" + + def __init__(self): + self.system = platform.system().lower() + self._dns_cache = {} + self._setup_platform_specific() + + def _setup_platform_specific(self): + """平台相关初始化""" + if self.system == 'windows': + # Windows默认关闭TCP_NODELAY + HTTPConnection.default_socket_options = ( + HTTPConnection.default_socket_options + [ + (socket.IPPROTO_TCP, socket.TCP_NODELAY, 0) + ] + ) + elif self.system == 'linux': + # Linux启用TCP快速打开 + HTTPConnection.default_socket_options = ( + HTTPConnection.default_socket_options + [ + (socket.IPPROTO_TCP, socket.TCP_FASTOPEN, 5) + ] + ) + + @staticmethod + @lru_cache(maxsize=512) + def _resolve_hostname(hostname: str) -> str: + """DNS缓存解析(跨线程安全)""" + try: + return socket.gethostbyname(hostname) + except socket.gaierror: + return hostname # 失败时返回原始域名 + + def get_session( + self, + retries: int = 3, + backoff_factor: float = 0.5, + timeout: TimeoutType = (3.05, 30), + proxy: Optional[str] = None, + verify_ssl: bool = True + ) -> requests.Session: + """ + 获取配置好的请求会话 + + 参数: + retries: 重试次数 + backoff_factor: 重试间隔系数 + timeout: (连接超时, 读取超时)秒数 + proxy: 代理地址(如 'http://user:pass@proxy:port') + verify_ssl: 是否验证SSL证书 + """ + session = requests.Session() + + # 重试策略 + retry = Retry( + total=retries, + backoff_factor=backoff_factor, + status_forcelist=[500, 502, 503, 504], + allowed_methods=frozenset(['GET', 'POST', 'PUT', 'DELETE']) + ) + + # 适配器配置 + adapter = HTTPAdapter( + max_retries=retry, + pool_connections=20, + pool_maxsize=100 + ) + + # 挂载适配器 + session.mount('http://', adapter) + session.mount('https://', adapter) + + # 代理配置 + if proxy: + session.proxies = { + 'http': proxy, + 'https': proxy + } + + # 请求默认配置 + session.request = self._wrap_request( + session.request, + timeout=timeout, + verify=verify_ssl + ) + + return session + + def _wrap_request(self, original_request, **defaults): + """包装请求方法添加默认参数""" + + def wrapped(method, url, **kwargs): + # 处理DNS缓存 + parsed = urlparse(url) + if parsed.hostname: + kwargs['hooks'] = kwargs.get('hooks', {}) + kwargs['hooks']['pre_request'] = lambda r: setattr( + r, '_orig_host', r.url + ) + url = url.replace( + parsed.hostname, + self._resolve_hostname(parsed.hostname), + 1 + ) + + # 合并默认参数 + for k, v in defaults.items(): + kwargs.setdefault(k, v) + + return original_request(method, url, **kwargs) + + return wrapped + + def get_user_agent(self) -> str: + """获取随机用户代理(兼容各平台)""" + agents = [ + # Windows + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + # macOS + "Mozilla/5.0 (Macintosh; Intel Mac OS X 12_4) AppleWebKit/605.1.15", + # Linux + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36", + # Mobile + "Mozilla/5.0 (iPhone; CPU iPhone OS 15_5 like Mac OS X) AppleWebKit/605.1.15" + ] + return random.choice(agents) + + def check_connection( + self, + url: str = "https://www.baidu.com", + timeout: float = 5.0 + ) -> bool: + """ + 检查网络连接状态 + + 参数: + url: 测试用的URL + timeout: 超时时间(秒) + """ + try: + session = self.get_session(retries=0, timeout=timeout) + session.head(url) + return True + except Exception: + return False + + def download_file( + self, + url: str, + save_path: str, + chunk_size: int = 8192, + progress_callback: Optional[callable] = None + ) -> bool: + """ + 下载大文件支持断点续传 + + 参数: + url: 文件URL + save_path: 本地保存路径 + chunk_size: 分块大小(字节) + progress_callback: 进度回调函数(bytes_downloaded, total_size) + """ + session = self.get_session() + headers = {} + + # 检查本地文件部分下载 + if os.path.exists(save_path): + downloaded = os.path.getsize(save_path) + headers['Range'] = f'bytes={downloaded}-' + + try: + with session.get(url, headers=headers, stream=True) as r: + r.raise_for_status() + + # 获取文件总大小 + total_size = int(r.headers.get('content-length', 0)) + downloaded + + # 追加模式写入 + mode = 'ab' if headers.get('Range') else 'wb' + with open(save_path, mode) as f: + for chunk in r.iter_content(chunk_size=chunk_size): + if chunk: # 过滤keep-alive chunks + f.write(chunk) + downloaded += len(chunk) + if progress_callback: + progress_callback(downloaded, total_size) + return True + except Exception as e: + print(f"下载失败: {str(e)}") + return False + + +# 全局实例(线程安全) +network_utils = NetworkUtils() + + +# 快捷方法(兼容旧代码) +def get_session(*args, **kwargs): + return network_utils.get_session(*args, **kwargs) + + +def check_connection(*args, **kwargs): + return network_utils.check_connection(*args, **kwargs)