🚀 Major Upgrade! Visual Workflow Orchestrator and AI-Powered Crawler Implemented. Added Model Arena Feature and Efficiency Optimizations (Two-Level Caching Architecture + End-to-End Performance Enhancements).

This commit is contained in:
戒酒的李白
2025-03-13 13:14:35 +08:00
parent ee5372941a
commit 0c6a40b869
12 changed files with 5688 additions and 78 deletions
+262 -77
View File
@@ -1,116 +1,301 @@
import json
import os
import time
import shutil
from datetime import datetime, timedelta
import threading
import queue
from collections import OrderedDict
import pickle
import hashlib
import logging
class PredictionCache:
logger = logging.getLogger('cache_manager')
logger.setLevel(logging.INFO)
class LRUCache:
    """In-memory cache with a least-recently-used (LRU) eviction policy.

    Entries are kept in an ``OrderedDict`` ordered from least to most
    recently used: every read or write moves the touched key to the end,
    and eviction pops from the front.
    """

    def __init__(self, capacity):
        """Create an empty cache that holds at most ``capacity`` entries."""
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """Return the value stored under ``key`` and mark it most recently used.

        Returns ``None`` when the key is not cached.
        """
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        """Insert or update ``key`` and mark it as most recently used.

        When the cache is full, least recently used entries are evicted to
        make room. A cache whose capacity is zero or negative stores nothing.
        """
        if key in self.cache:
            self.cache[key] = value
            self.cache.move_to_end(key)
            return
        # Without this guard a capacity of 0 would try to evict from an
        # empty OrderedDict and raise KeyError.
        if self.capacity <= 0:
            return
        # A loop (rather than a single eviction) also copes with the
        # capacity attribute having been lowered after construction.
        while len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self.cache[key] = value

    def remove(self, key):
        """Delete ``key`` if present; do nothing otherwise."""
        if key in self.cache:
            del self.cache[key]

    def clear(self):
        """Drop every cached entry."""
        self.cache.clear()

    def __len__(self):
        """Return the number of cached entries."""
        return len(self.cache)

    def get_all_keys(self):
        """Return a snapshot list of keys, least recently used first."""
        return list(self.cache.keys())
class CacheManager:
"""两级缓存系统:内存LRU缓存 + 磁盘持久化缓存"""
_instance = None
_lock = threading.Lock()
def __new__(cls):
def __new__(cls, *args, **kwargs):
with cls._lock:
if cls._instance is None:
cls._instance = super(PredictionCache, cls).__new__(cls)
cls._instance = super(CacheManager, cls).__new__(cls)
return cls._instance
def __init__(self, name="default", memory_capacity=1000, cache_duration=24,
             disk_cache_dir="cache", flush_interval=5):
    """Set up the memory cache, the disk directory and the worker threads.

    Args:
        name: Cache name; also the sub-directory used under ``disk_cache_dir``.
        memory_capacity: Maximum number of entries held in the LRU memory cache.
        cache_duration: Entry lifetime in hours.
        disk_cache_dir: Root directory for the on-disk cache files.
        flush_interval: Intended flush interval in minutes. It is stored on
            the instance only; the cleanup loop sleeps a fixed hour.

    Re-running ``__init__`` on an already initialised instance is a no-op,
    because ``__new__`` may hand back an existing object.
    """
    if hasattr(self, 'initialized'):
        return

    self.name = name
    self.memory_cache = LRUCache(memory_capacity)
    self.disk_cache_dir = os.path.join(disk_cache_dir, name)
    self.cache_duration = timedelta(hours=cache_duration)
    self.flush_interval = flush_interval
    self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
    self.disk_queue = queue.Queue()

    # Make sure the cache directory exists before any thread touches it.
    os.makedirs(self.disk_cache_dir, exist_ok=True)

    # Mark the instance ready only once its directory exists, so a failed
    # setup is retried by the next constructor call instead of leaving a
    # half-built object behind.
    self.initialized = True

    # Background thread: expire old entries and flush memory to disk.
    self.cleanup_thread = threading.Thread(target=self._cleanup_and_flush_task, daemon=True)
    self.cleanup_thread.start()

    # Background thread: drain the queue of pending disk writes.
    self.disk_writer_thread = threading.Thread(target=self._disk_writer_task, daemon=True)
    self.disk_writer_thread.start()

    logger.info(f"初始化缓存管理器: {name},内存容量: {memory_capacity}项,缓存时间: {cache_duration}小时")
def _load_cache(self):
"""加载磁盘上的缓存文件"""
try:
for filename in os.listdir(self.cache_dir):
if filename.endswith('.json'):
filepath = os.path.join(self.cache_dir, filename)
with open(filepath, 'r', encoding='utf-8') as f:
cache_data = json.load(f)
# 检查缓存是否过期
if self._is_cache_valid(cache_data['timestamp']):
topic = filename[:-5] # 移除.json后缀
self.cache[topic] = cache_data
else:
# 删除过期缓存文件
os.remove(filepath)
except Exception as e:
print(f"加载缓存失败: {e}")
def _get_cache_key(self, key):
"""标准化缓存键"""
if isinstance(key, str):
return key
return hashlib.md5(str(key).encode()).hexdigest()
def _cleanup_old_cache(self):
    """Background loop that purges expired entries once an hour.

    Each pass removes expired topics from ``self.cache`` and deletes their
    ``<topic>.json`` file under ``self.cache_dir``. Errors are printed and
    the loop keeps running.

    NOTE(review): operates on ``self.cache`` / ``self.cache_dir`` from the
    earlier JSON-based cache; confirm whether it is still wired up.
    """
    while True:
        try:
            # Collect first, delete afterwards, so the dict is not mutated
            # while it is being iterated.
            stale = [
                topic
                for topic, entry in self.cache.items()
                if not self._is_cache_valid(entry['timestamp'])
            ]
            for topic in stale:
                del self.cache[topic]
                json_path = os.path.join(self.cache_dir, f"{topic}.json")
                if os.path.exists(json_path):
                    os.remove(json_path)

            # Wait one hour before the next pass.
            time.sleep(3600)
        except Exception as e:
            print(f"清理缓存时出错: {e}")
            # Also wait an hour after a failure.
            time.sleep(3600)
def _get_disk_path(self, key):
"""获取磁盘缓存路径"""
safe_key = self._get_cache_key(key)
return os.path.join(self.disk_cache_dir, f"{safe_key}.cache")
def _is_cache_valid(self, timestamp):
"""检查缓存是否有效"""
"""检查缓存是否过期"""
cache_time = datetime.fromtimestamp(timestamp)
return datetime.now() - cache_time < self.cache_duration
def get(self, key):
    """Return the cached data for ``key``, or ``None`` on a miss.

    Lookup order:
      1. the in-memory LRU cache;
      2. the on-disk cache file, which is promoted to memory on a hit.

    Expired entries found along the way are removed. Hit and miss counters
    in ``self.cache_stats`` are updated.
    """
    cache_key = self._get_cache_key(key)

    # 1. Memory cache.
    cache_data = self.memory_cache.get(cache_key)
    if cache_data is not None:
        if self._is_cache_valid(cache_data['timestamp']):
            self.cache_stats["hits"] += 1
            logger.debug(f"内存缓存命中: {key}")
            return cache_data['data']
        # Expired: drop it from memory and fall through to the disk lookup.
        self.memory_cache.remove(cache_key)

    # 2. Disk cache.
    disk_path = self._get_disk_path(cache_key)
    if os.path.exists(disk_path):
        try:
            # NOTE(security): pickle.load runs arbitrary code from the file;
            # the cache directory must only be writable by trusted users.
            with open(disk_path, 'rb') as f:
                cache_data = pickle.load(f)

            if self._is_cache_valid(cache_data['timestamp']):
                # Promote the entry so the next lookup is served from memory.
                self.memory_cache.put(cache_key, cache_data)
                self.cache_stats["disk_hits"] += 1
                logger.debug(f"磁盘缓存命中: {key}")
                return cache_data['data']
            # Expired: delete the file.
            os.remove(disk_path)
        except Exception as e:
            # An unreadable file is reported and treated as a miss.
            logger.warning(f"读取磁盘缓存失败: {key}, 错误: {e}")

    self.cache_stats["misses"] += 1
    logger.debug(f"缓存未命中: {key}")
    return None
def set(self, key, data, immediate_disk_write=False):
    """Store ``data`` under ``key`` in memory and schedule it for disk.

    Args:
        key: Cache key; normalised with ``_get_cache_key``.
        data: Value to cache.
        immediate_disk_write: When true the entry is written to disk
            before returning; otherwise it is queued for the disk writer
            thread.

    Returns:
        Always ``True``.
    """
    cache_key = self._get_cache_key(key)
    cache_data = {
        'data': data,
        'timestamp': datetime.now().timestamp()
    }

    # Memory first, so readers see the value immediately.
    self.memory_cache.put(cache_key, cache_data)

    if immediate_disk_write:
        self._write_to_disk(cache_key, cache_data)
    else:
        self.disk_queue.put((cache_key, cache_data))

    logger.debug(f"缓存已设置: {key}")
    return True
def _save_cache_to_disk(self):
    """Asynchronously save the cache to disk.

    NOTE(review): in this view only the header and docstring are present;
    no statements follow before the next ``def``. Confirm whether this
    method is still needed or is a leftover from the earlier cache.
    """
def invalidate(self, key):
"""使指定键的缓存失效"""
cache_key = self._get_cache_key(key)
# 从内存中删除
self.memory_cache.remove(cache_key)
# 从磁盘中删除
disk_path = self._get_disk_path(cache_key)
if os.path.exists(disk_path):
try:
os.remove(disk_path)
logger.debug(f"缓存已失效: {key}")
except Exception as e:
logger.warning(f"删除磁盘缓存失败: {key}, 错误: {e}")
return True
def clear_all(self):
    """Remove every entry from memory and disk and reset the statistics.

    Pending entries in the disk-write queue are discarded as well;
    otherwise the writer thread would recreate files for entries that
    were just cleared. (A write the thread has already dequeued can still
    land afterwards.)

    Returns:
        Always ``True``; a failure to clear the disk directory is logged.
    """
    # Memory cache.
    self.memory_cache.clear()

    # Discard writes that have been queued but not yet performed.
    while True:
        try:
            self.disk_queue.get_nowait()
        except queue.Empty:
            break
        self.disk_queue.task_done()

    # Disk cache: remove the directory and recreate it empty.
    try:
        shutil.rmtree(self.disk_cache_dir)
        os.makedirs(self.disk_cache_dir, exist_ok=True)
        logger.info(f"所有缓存已清除: {self.name}")
    except Exception as e:
        logger.error(f"清除磁盘缓存失败: {e}")

    # Reset the statistics.
    self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
    return True
def get_stats(self):
"""获取缓存统计信息"""
total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
hit_rate = (self.cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0
total_hits = self.cache_stats["hits"] + self.cache_stats["disk_hits"]
memory_size = len(self.memory_cache)
disk_size = len([f for f in os.listdir(self.disk_cache_dir) if f.endswith('.cache')])
return {
"name": self.name,
"memory_items": memory_size,
"disk_items": disk_size,
"memory_hits": self.cache_stats["hits"],
"disk_hits": self.cache_stats["disk_hits"],
"misses": self.cache_stats["misses"],
"total_requests": total_requests,
"hit_rate": hit_rate,
"two_level_hit_rate": (total_hits / total_requests * 100) if total_requests > 0 else 0
}
def _write_to_disk(self, cache_key, cache_data):
"""将缓存写入磁盘"""
disk_path = self._get_disk_path(cache_key)
try:
with open(disk_path, 'wb') as f:
pickle.dump(cache_data, f)
return True
except Exception as e:
logger.warning(f"写入磁盘缓存失败: {cache_key}, 错误: {e}")
return False
def _disk_writer_task(self):
    """Background loop that writes queued cache entries to disk.

    Waits up to one second for a queued ``(cache_key, cache_data)`` pair,
    writes it, and marks the queue task done. An empty queue causes a
    short pause; any other error is logged and followed by a longer pause.
    """
    while True:
        try:
            pending_key, pending_data = self.disk_queue.get(timeout=1)
            self._write_to_disk(pending_key, pending_data)
            self.disk_queue.task_done()
        except queue.Empty:
            # Nothing queued within the timeout; pause briefly and poll again.
            time.sleep(0.1)
        except Exception as e:
            logger.error(f"磁盘写入线程出错: {e}")
            # Back off for a while after an error.
            time.sleep(5)
def _cleanup_and_flush_task(self):
    """Background loop: expire old entries, then flush memory to disk, hourly.

    Each pass:
      1. removes expired entries from the memory cache;
      2. deletes expired or unreadable ``.cache`` files from disk;
      3. rewrites every memory entry to disk so both levels agree.

    This runs on its own thread while callers keep using the cache, so
    entries and files may vanish mid-pass; those cases are skipped rather
    than aborting the whole pass.
    """
    while True:
        try:
            # 1. Expire memory entries.
            for key in self.memory_cache.get_all_keys():
                cache_data = self.memory_cache.get(key)
                if cache_data is None:
                    # Evicted or invalidated since the key snapshot was taken.
                    continue
                if not self._is_cache_valid(cache_data['timestamp']):
                    self.memory_cache.remove(key)

            # 2. Expire disk entries.
            for filename in os.listdir(self.disk_cache_dir):
                if not filename.endswith('.cache'):
                    continue
                filepath = os.path.join(self.disk_cache_dir, filename)
                try:
                    with open(filepath, 'rb') as f:
                        cache_data = pickle.load(f)
                    should_delete = not self._is_cache_valid(cache_data['timestamp'])
                except FileNotFoundError:
                    # Already removed by another thread.
                    continue
                except Exception as e:
                    # Corrupt cache files are deleted.
                    logger.warning(f"读取缓存文件失败,将删除: {filepath}, 错误: {e}")
                    should_delete = True
                if should_delete:
                    try:
                        os.remove(filepath)
                    except FileNotFoundError:
                        pass
                    except OSError as e:
                        logger.warning(f"删除磁盘缓存失败: {filepath}, 错误: {e}")

            # 3. Flush memory to disk. This rewrites files that are already
            # on disk, but keeps memory and disk in sync.
            for key in self.memory_cache.get_all_keys():
                cache_data = self.memory_cache.get(key)
                if cache_data is not None:
                    self._write_to_disk(key, cache_data)

            # Run once an hour.
            time.sleep(3600)
        except Exception as e:
            logger.error(f"缓存清理线程出错: {e}")
            # Also wait an hour after a failure.
            time.sleep(3600)
# Create one cache per domain, each with its own capacity and lifetime (hours).
# NOTE(review): these are only independent objects if CacheManager hands out
# one instance per name; verify against CacheManager.__new__.
prediction_cache = CacheManager(name="predictions", memory_capacity=500, cache_duration=24)
sentiment_cache = CacheManager(name="sentiment", memory_capacity=1000, cache_duration=12)
topic_cache = CacheManager(name="topics", memory_capacity=200, cache_duration=6)
user_data_cache = CacheManager(name="user_data", memory_capacity=300, cache_duration=48)

# Backward-compatible alias for the old class name.
PredictionCache = CacheManager

# Backward-compatible alias for the original module-level cache object.
prediction_cache_old = prediction_cache
+558
View File
@@ -0,0 +1,558 @@
import os
import sys
import json
import getpass
import secrets
import logging
import platform
import socket
import hashlib
import base64
import re
import shutil
import subprocess
from pathlib import Path
from datetime import datetime
import pymysql
from dotenv import load_dotenv, set_key, find_dotenv
# 设置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger('init_wizard')
class InitWizard:
"""
初始化向导 - 简化系统的初始配置流程,并提供安全加固功能
"""
def __init__(self):
    """Load the current configuration from the environment / ``.env`` file.

    Every setting falls back to a built-in default. A numeric variable
    holding an unparsable value also falls back to its default (with a
    warning) rather than raising, so the wizard can still start and be
    used to repair a broken configuration.
    """
    # Load environment variables from .env, if present.
    load_dotenv()

    def env_flag(name, default):
        """Return True when the variable equals 'true' (case-insensitive)."""
        return os.getenv(name, default).lower() == 'true'

    def env_number(name, default, cast=int):
        """Parse a numeric variable, falling back to ``default`` if malformed."""
        raw = os.getenv(name, default)
        try:
            return cast(raw)
        except ValueError:
            logger.warning(f"环境变量 {name} 的值无效: {raw!r},使用默认值 {default}")
            return cast(default)

    self.config = {
        # Database connection.
        'db': {
            'host': os.getenv('DB_HOST', 'localhost'),
            'port': env_number('DB_PORT', '3306'),
            'user': os.getenv('DB_USER', 'root'),
            'password': os.getenv('DB_PASSWORD', ''),
            'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
            'ssl': env_flag('DB_SSL', 'false')
        },
        # Flask application.
        'app': {
            'host': os.getenv('FLASK_HOST', '127.0.0.1'),
            'port': env_number('FLASK_PORT', '5000'),
            'secret_key': os.getenv('FLASK_SECRET_KEY', ''),
            'enable_https': env_flag('ENABLE_HTTPS', 'false'),
            'debug': env_flag('FLASK_DEBUG', 'false')
        },
        # API keys.
        'api_keys': {
            'openai': os.getenv('OPENAI_API_KEY', ''),
            'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
            'deepseek': os.getenv('DEEPSEEK_API_KEY', '')
        },
        # Security switches and policies.
        'security': {
            'enable_rate_limit': env_flag('ENABLE_RATE_LIMIT', 'true'),
            'enable_ip_blocking': env_flag('ENABLE_IP_BLOCKING', 'true'),
            'enable_sensitive_data_filter': env_flag('ENABLE_SENSITIVE_DATA_FILTER', 'true'),
            'enable_mutual_auth': env_flag('ENABLE_MUTUAL_AUTH', 'false'),
            'min_password_length': env_number('MIN_PASSWORD_LENGTH', '8'),
            'session_timeout': env_number('SESSION_TIMEOUT', '120'),  # minutes
        },
        # Crawler.
        'crawler': {
            'interval': env_number('CRAWL_INTERVAL', '18000'),  # seconds
            'max_retries': env_number('CRAWL_MAX_RETRIES', '3'),
            'timeout': env_number('CRAWL_TIMEOUT', '30'),
            'max_concurrent': env_number('CRAWL_MAX_CONCURRENT', '2'),
            'user_agent': os.getenv('CRAWL_USER_AGENT', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36')
        },
        # System.
        'system': {
            'initialized': env_flag('SYSTEM_INITIALIZED', 'false'),
            'version': os.getenv('SYSTEM_VERSION', '2.0.0'),
            'log_level': os.getenv('LOG_LEVEL', 'INFO'),
            'data_dir': os.getenv('DATA_DIR', 'data'),
            'temp_dir': os.getenv('TEMP_DIR', 'temp'),
            'cache_dir': os.getenv('CACHE_DIR', 'cache'),
            'max_model_memory': env_number('MAX_MODEL_MEMORY_USAGE', '4.0', float),  # GB
        }
    }

    # Display metadata for the security switches; each key ``k`` maps to
    # the config entry ``security['enable_<k>']``.
    self.security_options = {
        'rate_limit': {
            'name': '请求速率限制',
            'description': '防止API被滥用,限制单个IP的请求频率',
            'default': True
        },
        'ip_blocking': {
            'name': 'IP黑名单',
            'description': '阻止可疑IP访问系统',
            'default': True
        },
        'sensitive_data_filter': {
            'name': '敏感信息过滤',
            'description': '自动识别并屏蔽输出内容中的敏感信息(如手机号、邮箱等)',
            'default': True
        },
        'mutual_auth': {
            'name': '双向认证',
            'description': '要求API调用方提供有效证书,增强API安全性(需要HTTPS)',
            'default': False
        }
    }
def start(self):
    """Run the wizard: show the banner, collect settings, save and harden.

    If the system is already marked initialised the user must confirm
    with ``y`` before reconfiguring. Ctrl+C during configuration cancels
    without saving; any other error is logged and reported.
    """
    self._print_welcome()

    already_initialized = self.config['system']['initialized']
    if already_initialized:
        print("\n系统已经初始化过。您想重新配置吗? [y/N]: ", end='')
        answer = input().strip().lower()
        if answer != 'y':
            print("初始化向导已退出。如需重新配置,请设置环境变量 SYSTEM_INITIALIZED=false 或删除 .env 文件。")
            return

    # Steps run in this order: the configuration sections first, then
    # saving, then the security measures.
    steps = (
        self._configure_database,
        self._configure_app,
        self._configure_api_keys,
        self._configure_security,
        self._configure_crawler,
        self._configure_system,
        self._save_config,
        self._apply_security_measures,
    )
    try:
        for step in steps:
            step()
        print("\n✅ 初始化完成!系统已成功配置。")
        print("您现在可以运行 python app.py 启动应用。")
    except KeyboardInterrupt:
        print("\n\n初始化向导已取消。配置未保存。")
    except Exception as e:
        logger.error(f"初始化过程中发生错误: {e}")
        print(f"\n❌ 初始化失败: {e}")
        print("请检查错误并重试。")
def _print_welcome(self):
"""打印欢迎信息"""
print("\n" + "="*80)
print(" "*20 + "微博舆情分析预测系统 - 初始化向导 v2.0")
print("="*80)
print("\n欢迎使用微博舆情分析预测系统!此向导将引导您完成系统的初始配置。")
print("按Ctrl+C可随时退出向导。")
print("\n系统信息:")
print(f" • 操作系统: {platform.system()} {platform.release()}")
print(f" • Python版本: {platform.python_version()}")
print(f" • 主机名: {socket.gethostname()}")
print(f" • 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("\n让我们开始配置吧!每个选项都有默认值,直接按回车即可使用默认值。")
print("-"*80)
def _configure_database(self):
    """Prompt for the database settings and test the connection.

    On a failed connection test the user may re-enter all settings (the
    default) or keep them untested. Retries use a loop, so repeated
    failures do not deepen the call stack.
    """
    db = self.config['db']
    while True:
        print("\n📦 数据库配置")
        print("-"*50)

        db['host'] = self._prompt("数据库主机", db['host'])

        port_str = self._prompt("数据库端口", str(db['port']))
        try:
            db['port'] = int(port_str)
        except ValueError:
            print(f"端口号无效,使用默认值 {db['port']}")

        db['user'] = self._prompt("数据库用户名", db['user'])

        # getpass keeps the password off the screen. The placeholder has a
        # fixed width so the stored password's length is not revealed.
        default_pass = '*' * 8 if db['password'] else ''
        password = getpass.getpass(f"数据库密码 [{default_pass}]: ")
        if password:
            db['password'] = password

        db['database'] = self._prompt("数据库名", db['database'])

        ssl_str = self._prompt("使用SSL连接 (true/false)", str(db['ssl']).lower())
        db['ssl'] = ssl_str.lower() == 'true'

        # Verify the settings before moving on.
        print("\n正在测试数据库连接...")
        try:
            self._test_db_connection()
        except Exception as e:
            print(f"❌ 数据库连接失败: {e}")
            retry = input("是否重新配置数据库连接? [Y/n]: ").strip().lower()
            if retry != 'n':
                continue
            print("跳过数据库连接测试,但配置可能不正确。")
        else:
            print("✅ 数据库连接成功!")
        return
def _configure_app(self):
    """Prompt for the Flask application settings.

    A secret key is generated when none exists; an existing key is only
    regenerated if the user answers ``y``.
    """
    print("\n🚀 应用配置")
    print("-"*50)
    app = self.config['app']

    app['host'] = self._prompt("监听地址 (0.0.0.0表示所有网络接口)", app['host'])

    entered_port = self._prompt("监听端口", str(app['port']))
    try:
        app['port'] = int(entered_port)
    except ValueError:
        print(f"端口号无效,使用默认值 {app['port']}")

    # Secret key: create one if missing, otherwise offer to replace it.
    if app['secret_key']:
        answer = input("应用密钥已存在。是否重新生成? [y/N]: ").strip().lower()
        if answer == 'y':
            app['secret_key'] = secrets.token_hex(32)
            print(f"已重新生成应用密钥: {app['secret_key'][:8]}...")
    else:
        app['secret_key'] = secrets.token_hex(32)
        print(f"已自动生成应用密钥: {app['secret_key'][:8]}...")

    entered_https = self._prompt("启用HTTPS (true/false)", str(app['enable_https']).lower())
    app['enable_https'] = entered_https.lower() == 'true'

    entered_debug = self._prompt(
        "启用调试模式 (true/false, 生产环境建议false)", str(app['debug']).lower())
    app['debug'] = entered_debug.lower() == 'true'
def _configure_api_keys(self):
    """Prompt for the LLM provider API keys.

    For each provider the user is asked whether to configure a key and,
    if so, for the key itself. If no key ends up configured a warning is
    shown and the user may restart this step.
    """
    print("\n🔑 API密钥配置")
    print("-"*50)
    print("系统支持多个大语言模型,至少需要配置一个API密钥。")

    keys = self.config['api_keys']
    # (config key, name used in the yes/no question, name used in the key prompt)
    providers = (
        ('openai', "OpenAI", "OpenAI"),
        ('anthropic', "Anthropic (Claude)", "Anthropic"),
        ('deepseek', "DeepSeek", "DeepSeek"),
    )
    for slot, question_name, prompt_name in providers:
        wants_key = self._prompt(
            f"是否配置{question_name} API密钥? (y/n)", "y" if keys[slot] else "n")
        if wants_key.lower() == 'y':
            keys[slot] = self._prompt(f"{prompt_name} API密钥", keys[slot])

    # At least one key is needed for the AI analysis features.
    if not any(keys[slot] for slot, _, _ in providers):
        print("⚠️ 警告: 您未配置任何API密钥,系统的AI分析功能将不可用。")
        confirm = input("是否继续? [Y/n]: ").strip().lower()
        if confirm == 'n':
            return self._configure_api_keys()
def _configure_security(self):
    """Prompt for the security switches, minimum password length and session timeout."""
    print("\n🔒 安全配置")
    print("-"*50)
    security = self.config['security']

    # One true/false question per switch described in security_options.
    for key, option in self.security_options.items():
        setting = f'enable_{key}'
        current_value = security[setting]
        print(f"\n{option['name']}: {option['description']}")
        answer = self._prompt(
            f"启用{option['name']} (true/false)", str(current_value).lower())
        security[setting] = answer.lower() == 'true'

    # Password policy.
    entered_length = self._prompt(
        "最小密码长度 (推荐不低于8)", str(security['min_password_length']))
    try:
        min_length = int(entered_length)
    except ValueError:
        print(f"无效输入,使用默认值 {security['min_password_length']}")
    else:
        security['min_password_length'] = min_length
        if min_length < 6:
            print("⚠️ 警告: 短密码容易被暴力破解,建议设置更长的密码。")

    # Session timeout.
    entered_timeout = self._prompt(
        "会话超时时间 (分钟)", str(security['session_timeout']))
    try:
        security['session_timeout'] = int(entered_timeout)
    except ValueError:
        print(f"无效输入,使用默认值 {security['session_timeout']}")
def _configure_crawler(self):
    """Prompt for the crawler settings.

    Integer settings keep their current value when the input is not a
    valid integer.
    """
    print("\n🕷️ 爬虫配置")
    print("-"*50)
    crawler = self.config['crawler']

    # (config key, prompt text) for each integer setting, in prompt order.
    integer_fields = (
        ('interval', "爬取间隔 (秒)"),
        ('max_retries', "最大重试次数"),
        ('timeout', "超时时间 (秒)"),
        ('max_concurrent', "最大并发数"),
    )
    for field, label in integer_fields:
        entered = self._prompt(label, str(crawler[field]))
        try:
            crawler[field] = int(entered)
        except ValueError:
            print(f"无效输入,使用默认值 {crawler[field]}")

    crawler['user_agent'] = self._prompt("User-Agent", crawler['user_agent'])
def _configure_system(self):
    """Prompt for log level, working directories and the model memory limit.

    Chosen directories are created immediately. Finally the system is
    marked as initialised in the in-memory config.
    """
    print("\n⚙️ 系统配置")
    print("-"*50)
    system = self.config['system']

    # Log level.
    log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
    current_level = system['log_level']
    print(f"可选日志级别: {', '.join(log_levels)}")
    chosen_level = self._prompt("日志级别", current_level).upper()
    if chosen_level not in log_levels:
        print(f"无效的日志级别,使用默认值 {current_level}")
    else:
        system['log_level'] = chosen_level

    # Data, cache and temp directories, prompted in that order.
    directories = (
        ('data_dir', "数据目录"),
        ('cache_dir', "缓存目录"),
        ('temp_dir', "临时文件目录"),
    )
    for field, label in directories:
        chosen_dir = self._prompt(label, system[field])
        if chosen_dir:
            system[field] = chosen_dir
            os.makedirs(chosen_dir, exist_ok=True)
            print(f"已创建{label}: {chosen_dir}")

    # Model memory limit.
    entered_memory = self._prompt(
        "最大模型内存使用量 (GB)", str(system['max_model_memory']))
    try:
        system['max_model_memory'] = float(entered_memory)
    except ValueError:
        print(f"无效输入,使用默认值 {system['max_model_memory']}")

    # Mark the system as initialised.
    system['initialized'] = True
def _save_config(self):
"""保存配置到.env文件"""
print("\n正在保存配置...")
# 构建.env文件内容
env_content = [
"# 微博舆情分析预测系统配置文件",
f"# 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"",
"# 数据库配置",
f"DB_HOST={self.config['db']['host']}",
f"DB_PORT={self.config['db']['port']}",
f"DB_USER={self.config['db']['user']}",
f"DB_PASSWORD={self.config['db']['password']}",
f"DB_NAME={self.config['db']['database']}",
f"DB_SSL={str(self.config['db']['ssl']).lower()}",
"",
"# 应用配置",
f"FLASK_HOST={self.config['app']['host']}",
f"FLASK_PORT={self.config['app']['port']}",
f"FLASK_SECRET_KEY={self.config['app']['secret_key']}",
f"ENABLE_HTTPS={str(self.config['app']['enable_https']).lower()}",
f"FLASK_DEBUG={str(self.config['app']['debug']).lower()}",
"",
"# API密钥",
f"OPENAI_API_KEY={self.config['api_keys']['openai']}",
f"ANTHROPIC_API_KEY={self.config['api_keys']['anthropic']}",
f"DEEPSEEK_API_KEY={self.config['api_keys']['deepseek']}",
"",
"# 安全配置",
f"ENABLE_RATE_LIMIT={str(self.config['security']['enable_rate_limit']).lower()}",
f"ENABLE_IP_BLOCKING={str(self.config['security']['enable_ip_blocking']).lower()}",
f"ENABLE_SENSITIVE_DATA_FILTER={str(self.config['security']['enable_sensitive_data_filter']).lower()}",
f"ENABLE_MUTUAL_AUTH={str(self.config['security']['enable_mutual_auth']).lower()}",
f"MIN_PASSWORD_LENGTH={self.config['security']['min_password_length']}",
f"SESSION_TIMEOUT={self.config['security']['session_timeout']}",
"",
"# 爬虫配置",
f"CRAWL_INTERVAL={self.config['crawler']['interval']}",
f"CRAWL_MAX_RETRIES={self.config['crawler']['max_retries']}",
f"CRAWL_TIMEOUT={self.config['crawler']['timeout']}",
f"CRAWL_MAX_CONCURRENT={self.config['crawler']['max_concurrent']}",
f"CRAWL_USER_AGENT={self.config['crawler']['user_agent']}",
"",
"# 系统配置",
f"SYSTEM_INITIALIZED={str(self.config['system']['initialized']).lower()}",
f"SYSTEM_VERSION={self.config['system']['version']}",
f"LOG_LEVEL={self.config['system']['log_level']}",
f"DATA_DIR={self.config['system']['data_dir']}",
f"TEMP_DIR={self.config['system']['temp_dir']}",
f"CACHE_DIR={self.config['system']['cache_dir']}",
f"MAX_MODEL_MEMORY_USAGE={self.config['system']['max_model_memory']}",
]
# 写入.env文件
with open('.env', 'w') as f:
f.write('\n'.join(env_content))
print("✅ 配置已保存到 .env 文件")
# 创建备份
backup_path = f".env.backup.{datetime.now().strftime('%Y%m%d%H%M%S')}"
shutil.copy2('.env', backup_path)
print(f"✅ 配置备份已保存到 {backup_path}")
def _test_db_connection(self):
    """Open and immediately close a MySQL connection using the current settings.

    Any exception raised by ``pymysql.connect`` propagates to the caller.

    NOTE(review): with SSL enabled the ``ssl`` argument carries no CA
    file, so the server certificate is presumably not verified; confirm
    against the PyMySQL documentation.
    """
    db = self.config['db']
    ssl_options = {'ssl': {'ca': None}} if db['ssl'] else None
    connect_kwargs = dict(
        host=db['host'],
        port=db['port'],
        user=db['user'],
        password=db['password'],
        database=db['database'],
        charset='utf8mb4',
        ssl=ssl_options,
    )
    connection = pymysql.connect(**connect_kwargs)
    connection.close()
def _apply_security_measures(self):
"""应用安全措施"""
print("\n正在应用安全措施...")
# 创建相关目录
security_dir = os.path.join(self.config['system']['data_dir'], 'security')
os.makedirs(security_dir, exist_ok=True)
# 设置文件权限
try:
# 仅在类Unix系统上设置文件权限
if platform.system() != "Windows":
os.chmod('.env', 0o600) # 只有所有者可读写
print("✅ 已设置.env文件权限为600 (只有所有者可读写)")
except Exception as e:
logger.warning(f"设置文件权限失败: {e}")
# 生成密钥对(如果启用了双向认证)
if self.config['security']['enable_mutual_auth']:
cert_dir = os.path.join(security_dir, 'certs')
os.makedirs(cert_dir, exist_ok=True)
try:
# 检查是否有OpenSSL可用
subprocess.run(['openssl', 'version'], check=True, capture_output=True)
# 生成自签名证书
key_file = os.path.join(cert_dir, 'server.key')
cert_file = os.path.join(cert_dir, 'server.crt')
if not os.path.exists(key_file) or not os.path.exists(cert_file):
print("正在生成SSL证书...")
subprocess.run([
'openssl', 'req', '-x509', '-newkey', 'rsa:4096',
'-keyout', key_file, '-out', cert_file,
'-days', '365', '-nodes',
'-subj', '/CN=localhost'
], check=True)
print(f"✅ SSL证书已生成: {cert_file}")
except subprocess.CalledProcessError:
print("⚠️ OpenSSL不可用,无法生成SSL证书。如需使用HTTPS,请手动配置证书。")
except Exception as e:
logger.warning(f"生成SSL证书失败: {e}")
# 创建敏感信息过滤器配置
if self.config['security']['enable_sensitive_data_filter']:
filter_config = {
'enabled': True,
'patterns': {
'phone': r'\b1[3-9]\d{9}\b',
'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
'id_card': r'\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
'credit_card': r'\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
},
'replacements': {
'phone': '***********',
'email': '******@*****',
'id_card': '******************',
'credit_card': '****************',
'address': '[地址已隐藏]'
}
}
filter_path = os.path.join(security_dir, 'sensitive_filter.json')
with open(filter_path, 'w', encoding='utf-8') as f:
json.dump(filter_config, f, ensure_ascii=False, indent=2)
print(f"✅ 敏感信息过滤器配置已保存到 {filter_path}")
# 创建IP黑名单文件
if self.config['security']['enable_ip_blocking']:
blacklist_path = os.path.join(security_dir, 'ip_blacklist.txt')
if not os.path.exists(blacklist_path):
with open(blacklist_path, 'w') as f:
f.write("# 每行一个IP地址\n")
print(f"✅ IP黑名单文件已创建: {blacklist_path}")
def _prompt(self, prompt, default=""):
"""提示用户输入,如果用户直接按回车则返回默认值"""
if default:
user_input = input(f"{prompt} [{default}]: ").strip()
else:
user_input = input(f"{prompt}: ").strip()
return user_input if user_input else default
if __name__ == "__main__":
    # Run the interactive wizard when this file is executed as a script.
    InitWizard().start()
+285
View File
@@ -0,0 +1,285 @@
import os
import sys
import pickle
import marshal
import types
import logging
import torch
import numpy as np
import json
from pathlib import Path
logger = logging.getLogger('model_loader')
logger.setLevel(logging.INFO)
def load_sentiment_model(model_path, device=None):
    """Load a marshal-serialised sentiment function.

    Args:
        model_path: Path (``str`` or ``os.PathLike``) to a ``.marshal`` or
            ``.marshal.3`` file containing a marshalled code object.
        device: Unused; accepted for signature parity with the other loaders.

    Returns:
        A function object built from the loaded code object, using this
        module's globals.

    Raises:
        ValueError: If the file name has an unsupported extension.
        Exception: Any error from reading or unmarshalling is logged and
            re-raised.

    NOTE(security): the returned function executes whatever code the file
    contains. Only load model files from trusted sources.
    """
    try:
        # Accept pathlib.Path and other path-like objects as well as str.
        model_path = os.fspath(model_path)
        logger.info(f"加载情感分析模型: {model_path}")

        if not model_path.endswith(('.marshal', '.marshal.3')):
            raise ValueError(f"不支持的情感模型格式: {model_path}")

        with open(model_path, 'rb') as f:
            model_data = marshal.load(f)

        # Wrap the unmarshalled code object in a callable function.
        sentiment_func = types.FunctionType(model_data, globals(), "sentiment_func")
        logger.info("情感分析模型加载成功")
        return sentiment_func
    except Exception as e:
        logger.error(f"加载情感分析模型失败: {e}")
        raise
def load_bert_ctm_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BERT-CTM model and its tokenizer.

    Args:
        model_dir: Model directory containing ``final_model.pt``, or the
            path of a ``.pt`` weights file itself.
        device: Device the weights are mapped and moved to.

    Returns:
        A dict with keys ``'model'`` (in eval mode), ``'tokenizer'`` and
        ``'device'``.

    Raises:
        Exception: Any loading error is logged and re-raised.
    """
    try:
        logger.info(f"加载BERT-CTM模型: {model_dir}")

        # Make the project's model package importable, without growing
        # sys.path on every call.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BERT_CTM import BERT_CTM
        from transformers import BertTokenizer

        # Model weights.
        model_path = os.path.join(model_dir, 'final_model.pt') if not model_dir.endswith('.pt') else model_dir
        model = BERT_CTM()
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # Tokenizer: a 'bert_model' directory next to model_dir's parent.
        tokenizer_path = os.path.join(os.path.dirname(model_dir), 'bert_model')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BERT-CTM模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device
        }
    except Exception as e:
        logger.error(f"加载BERT-CTM模型失败: {e}")
        raise
def load_bcat_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BCAT model, its configuration and its tokenizer.

    Args:
        model_dir: Directory containing ``config.json``, ``model.pt`` and
            a ``tokenizer`` sub-directory.
        device: Device the weights are mapped and moved to.

    Returns:
        A dict with keys ``'model'`` (in eval mode), ``'tokenizer'``,
        ``'device'`` and ``'config'``.

    Raises:
        Exception: Any loading error is logged and re-raised.
    """
    try:
        logger.info(f"加载BCAT模型: {model_dir}")

        # Make the project's model package importable, without growing
        # sys.path on every call.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BCAT import BCAT
        from transformers import BertTokenizer

        # Configuration: its entries are passed to BCAT as keyword arguments.
        config_path = os.path.join(model_dir, 'config.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model = BCAT(**config)

        # Weights.
        model_path = os.path.join(model_dir, 'model.pt')
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # Tokenizer.
        tokenizer_path = os.path.join(model_dir, 'tokenizer')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BCAT模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device,
            'config': config
        }
    except Exception as e:
        logger.error(f"加载BCAT模型失败: {e}")
        raise
def load_topic_classifier(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the topic classification model, tokenizer and label mapping.

    Two formats are tried in order:
      1. a ``transformers`` sequence-classification checkpoint in
         ``model_dir``;
      2. a plain PyTorch object in ``model_dir/model.pt`` with an optional
         pickled tokenizer in ``tokenizer.pkl``.

    Labels come from ``labels.json`` when present; otherwise from the
    transformers config's ``id2label`` (format 1) or an empty dict.

    Returns:
        A dict with keys ``'model'``, ``'tokenizer'``, ``'labels_map'``
        and ``'device'``.

    Raises:
        ValueError: If neither format can be loaded (logged, re-raised).

    NOTE(security): ``torch.load`` and ``pickle.load`` execute code from
    the files they read; only load trusted model directories.
    """
    def read_labels(fallback):
        """Return the contents of labels.json, or ``fallback`` if absent."""
        labels_path = os.path.join(model_dir, 'labels.json')
        if os.path.exists(labels_path):
            with open(labels_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        return fallback

    try:
        logger.info(f"加载话题分类模型: {model_dir}")

        # First choice: a transformers checkpoint.
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            model.to(device)
            model.eval()

            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            labels_map = read_labels(getattr(model.config, 'id2label', {}))

            logger.info("话题分类模型加载成功 (transformers)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }
        except Exception as e:
            logger.warning(f"使用transformers加载失败,尝试其他方法: {e}")

        # Fallback: a plain PyTorch object.
        model_path = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_path):
            model = torch.load(model_path, map_location=device)
            # Put the model into inference mode, like the other loaders.
            # The loaded object may not be a module, so check first.
            if callable(getattr(model, 'eval', None)):
                model.eval()

            tokenizer_path = os.path.join(model_dir, 'tokenizer.pkl')
            if os.path.exists(tokenizer_path):
                with open(tokenizer_path, 'rb') as f:
                    tokenizer = pickle.load(f)
            else:
                tokenizer = None

            labels_map = read_labels({})

            logger.info("话题分类模型加载成功 (PyTorch)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }

        raise ValueError(f"无法加载模型: {model_dir}")
    except Exception as e:
        logger.error(f"加载话题分类模型失败: {e}")
        raise
def load_echarts_optimizer():
    """Build an optimizer that tunes ECharts options for large datasets.

    Returns:
        An ``EChartsOptimizer`` instance, or ``None`` if construction
        fails (the error is logged).
    """
    try:
        class EChartsOptimizer:
            """Adjusts ECharts option dicts to render large datasets faster."""

            def __init__(self):
                self.chunk_size = 1000  # items per chunk in chunk_process_data
                logger.info("ECharts优化器初始化成功")

            def optimize_option(self, option):
                """Return an optimised deep copy of ``option``.

                Falsy input is returned unchanged. Progressive-rendering
                defaults are added when absent, and scatter/line series
                with more than 5000 data points get large-data settings.
                """
                if not option:
                    return option

                # Work on a deep copy so the caller's object is not modified.
                import copy
                tuned = copy.deepcopy(option)

                # Progressive rendering: points per frame, and the size at
                # which progressive rendering starts.
                progressive_defaults = (
                    ('progressive', 300),
                    ('progressiveThreshold', 5000),
                )
                for setting, value in progressive_defaults:
                    if setting not in tuned:
                        tuned[setting] = value

                if 'series' in tuned and isinstance(tuned['series'], list):
                    for series in tuned['series']:
                        has_big_data = (
                            'data' in series
                            and isinstance(series['data'], list)
                            and len(series['data']) > 5000
                        )
                        if not has_big_data:
                            continue
                        if series.get('type') in ['scatter', 'line']:
                            self._optimize_large_data_series(series)

                return tuned

            def _optimize_large_data_series(self, series):
                """Enable large-data mode on ``series`` in place and return it.

                Series with more than 50000 points are thinned by taking
                every ``step``-th point and marked for average sampling.
                """
                series['large'] = True
                series['largeThreshold'] = 2000

                point_count = len(series['data'])
                if point_count > 50000:
                    step = max(1, point_count // 50000)
                    series['data'] = series['data'][::step]
                    series['sampling'] = 'average'

                return series

            def chunk_process_data(self, data, process_func):
                """Apply ``process_func`` to ``data`` in chunks and concatenate the results."""
                processed = []
                for start in range(0, len(data), self.chunk_size):
                    processed.extend(process_func(data[start:start + self.chunk_size]))
                return processed

        return EChartsOptimizer()
    except Exception as e:
        logger.error(f"加载ECharts优化器失败: {e}")
        return None
# Public API: the loader functions exported by this module.
__all__ = [
    'load_sentiment_model',
    'load_bert_ctm_model',
    'load_bcat_model',
    'load_topic_classifier',
    'load_echarts_optimizer'
]
+364
View File
@@ -0,0 +1,364 @@
import os
import time
import threading
import logging
import gc
import torch
import numpy as np
from collections import OrderedDict
from datetime import datetime, timedelta
logger = logging.getLogger('model_manager')
logger.setLevel(logging.INFO)
class ModelManager:
    """
    Model manager: preloads models and unloads them on demand.

    Features:
    1. Preload frequently used models to cut load latency.
    2. Keep loaded models in an OrderedDict managed with an LRU policy.
    3. Load models in background threads and watch them from a monitor thread.
    4. Release non-preloaded models that have been idle past a timeout.
    5. Expose per-model usage statistics.

    The class is a process-wide singleton: every ``ModelManager()`` call
    returns the same instance and ``__init__`` initializes it only once.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ModelManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        if hasattr(self, 'initialized'):
            return
        # Loaded models, ordered from least to most recently used.
        self.loaded_models = OrderedDict()
        # Per-model usage statistics, keyed by model id.
        self.model_stats = {}
        # Registration data (path, size, loader, device, preload flag).
        self.preload_config = {}
        # Budget for the summed *declared* sizes of loaded models, in GB.
        self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0'))
        # One lock per model id so a model is never loaded twice concurrently.
        self.loading_locks = {}
        # Idle time, in minutes, after which a model may be unloaded.
        self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30'))
        # Background thread that unloads idle models.
        self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True)
        self.monitor_thread.start()
        self.initialized = True
        logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB")

    def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5,
                       load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Register a model and optionally start preloading it in the background.

        Args:
            model_id: unique identifier of the model.
            model_path: path handed to the loader.
            preload: when True, a daemon thread starts loading the model now.
            model_size_gb: estimated size in GB, used for memory budgeting.
            load_function: optional loader called as
                ``load_function(model_path, device)`` and returning the model.
            device: device the model is loaded onto.

        Returns:
            True.
        """
        self.preload_config[model_id] = {
            'model_path': model_path,
            'preload': preload,
            'model_size_gb': model_size_gb,
            'load_function': load_function,
            'device': device
        }
        self.model_stats[model_id] = {
            'load_count': 0,
            'use_count': 0,
            'total_load_time': 0,
            'last_used': None,
            'avg_load_time': 0
        }
        if preload:
            logger.info(f"模型 {model_id} 已注册并标记为预加载")
            threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
        else:
            logger.info(f"模型 {model_id} 已注册")
        return True

    def get_model(self, model_id):
        """
        Return the model for ``model_id``, loading it first if necessary.

        Raises:
            ValueError: if ``model_id`` was never registered.
        """
        if model_id not in self.preload_config:
            raise ValueError(f"模型 {model_id} 未注册")
        stats = self.model_stats[model_id]
        stats['last_used'] = datetime.now()
        stats['use_count'] += 1
        if model_id in self.loaded_models:
            # Mark as most recently used.
            self.loaded_models.move_to_end(model_id)
            logger.debug(f"使用已加载的模型: {model_id}")
            return self.loaded_models[model_id]
        return self._load_and_store(model_id, "加载")

    def unload_model(self, model_id):
        """Unload ``model_id`` if it is loaded; return whether anything was unloaded."""
        if model_id in self.loaded_models:
            logger.info(f"手动卸载模型: {model_id}")
            del self.loaded_models[model_id]
            self._release_memory()
            return True
        return False

    def get_model_stats(self):
        """Return usage statistics plus registration info for every registered model."""
        result = {}
        for model_id, stats in self.model_stats.items():
            config = self.preload_config[model_id]
            result[model_id] = {
                **stats,
                'is_loaded': model_id in self.loaded_models,
                'preload': config['preload'],
                'model_size_gb': config['model_size_gb'],
                'device': config['device'],
            }
        return result

    def preload_all(self):
        """Start a background load for every preload-flagged model that is not loaded yet."""
        for model_id, config in list(self.preload_config.items()):
            if config['preload'] and model_id not in self.loaded_models:
                threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()

    def _preload_model(self, model_id):
        """Thread target: load one model in the background, logging any failure."""
        try:
            logger.info(f"开始预加载模型: {model_id}")
            self._load_and_store(model_id, "预加载")
        except Exception as e:
            logger.error(f"预加载模型 {model_id} 失败: {e}")

    def _load_and_store(self, model_id, action):
        """
        Load ``model_id`` under its per-model lock, record stats and cache it.

        Shared by ``get_model`` and ``_preload_model`` so a foreground request
        and a background preload can never load the same model twice.
        ``action`` is the verb used in the completion log line.
        """
        # setdefault keeps lock creation a single step; a separate
        # "check, then assign" can hand two threads two different locks.
        lock = self.loading_locks.setdefault(model_id, threading.Lock())
        with lock:
            # Another thread may have finished loading while we waited.
            if model_id in self.loaded_models:
                return self.loaded_models[model_id]
            self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])
            start_time = time.time()
            model = self._load_model(model_id)
            load_time = time.time() - start_time
            stats = self.model_stats[model_id]
            stats['load_count'] += 1
            stats['total_load_time'] += load_time
            stats['avg_load_time'] = stats['total_load_time'] / stats['load_count']
            self.loaded_models[model_id] = model
            logger.info(f"模型 {model_id} {action}完成,耗时: {load_time:.2f}秒")
            return model

    def _load_model(self, model_id):
        """
        Load and return the model object for ``model_id``.

        Uses the registered ``load_function`` when present; otherwise picks a
        loader from the path: ``.pt`` / ``.pth`` via ``torch.load``, ``.pkl``
        via ``pickle``, a directory via ``transformers`` Auto classes.

        Raises:
            ValueError: if no loader matches the path.
        """
        config = self.preload_config[model_id]
        if config['load_function'] is not None:
            return config['load_function'](config['model_path'], config['device'])
        model_path = config['model_path']
        device = config['device']
        # SECURITY: torch.load and pickle.load execute code embedded in the
        # file; only point these at model files from a trusted source.
        if model_path.endswith(('.pt', '.pth')):
            return torch.load(model_path, map_location=device)
        if model_path.endswith('.pkl'):
            import pickle
            with open(model_path, 'rb') as f:
                return pickle.load(f)
        if os.path.isdir(model_path):
            try:
                from transformers import AutoModel, AutoTokenizer
                model = AutoModel.from_pretrained(model_path)
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                return {'model': model.to(device), 'tokenizer': tokenizer}
            except ImportError:
                logger.error("transformers库未安装,无法加载预训练模型")
                raise
            except Exception as e:
                logger.error(f"加载预训练模型失败: {e}")
                raise
        raise ValueError(f"无法确定如何加载模型: {model_path}")

    def _ensure_memory_available(self, required_gb):
        """
        Evict least-recently-used models until ``required_gb`` fits the budget.

        Preloaded models used within ``unload_timeout`` minutes are protected
        and never evicted. If only protected models remain, the budget is
        allowed to be exceeded.
        """
        if not self.loaded_models:
            return
        current_usage = sum(
            self.preload_config[model_id]['model_size_gb']
            for model_id in self.loaded_models
        )
        protect_window = timedelta(minutes=self.unload_timeout)
        now = datetime.now()
        # Walk a snapshot, oldest first, visiting each model exactly once.
        # Rotating protected models to the back and retrying never terminates
        # once two or more protected models are loaded.
        for candidate_id in list(self.loaded_models):
            if current_usage + required_gb <= self.max_memory_usage:
                break
            last_used = self.model_stats[candidate_id]['last_used']
            if (self.preload_config[candidate_id]['preload'] and last_used
                    and now - last_used < protect_window):
                continue
            if candidate_id not in self.loaded_models:
                continue  # removed by another thread in the meantime
            model_size = self.preload_config[candidate_id]['model_size_gb']
            del self.loaded_models[candidate_id]
            current_usage -= model_size
            logger.info(f"自动卸载模型以释放内存: {candidate_id} ({model_size}GB)")
        self._release_memory()

    def _unload_idle_models(self):
        """Unload non-preloaded models idle for more than ``unload_timeout`` minutes."""
        current_time = datetime.now()
        idle_limit = timedelta(minutes=self.unload_timeout)
        for model_id in list(self.loaded_models.keys()):
            last_used = self.model_stats[model_id]['last_used']
            if (not self.preload_config[model_id]['preload'] and last_used
                    and current_time - last_used > idle_limit):
                logger.info(f"卸载长时间未使用的模型: {model_id}")
                self.loaded_models.pop(model_id, None)
        self._release_memory()

    def _monitor_models(self):
        """Monitor-thread body: sweep idle models every five minutes, forever."""
        while True:
            try:
                self._unload_idle_models()
            except Exception as e:
                logger.error(f"模型监控线程出错: {e}")
            time.sleep(300)

    @staticmethod
    def _release_memory():
        """Run the garbage collector and, when CUDA is available, empty its cache."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Module-level shared ModelManager instance (the class is a singleton).
model_manager = ModelManager()
# 注册示例函数
def register_sentiment_model():
    """
    Register the basic sentiment-analysis model (example registration).

    Returns:
        True on success, False if the loader cannot be imported or the
        registration raises (the error is logged).
    """
    try:
        # Imported inside the try so a missing or broken loader module is
        # reported as a failed registration instead of escaping the function.
        from utils.model_loader import load_sentiment_model
        model_path = os.path.join('model', 'sentiment.marshal.3')
        model_manager.register_model(
            model_id='sentiment_basic',
            model_path=model_path,
            preload=True,
            model_size_gb=0.2,
            load_function=load_sentiment_model
        )
        return True
    except Exception as e:
        logger.error(f"注册情感分析模型失败: {e}")
        return False
def register_bert_model():
    """
    Register the BERT classifier (example registration).

    The model is registered for preloading with the manager's default
    loader. Returns True on success, False if registration raises.
    """
    try:
        model_manager.register_model(
            model_id='bert_classifier',
            model_path=os.path.join('model_pro', 'bert_model'),
            preload=True,
            model_size_gb=0.8,
        )
    except Exception as e:
        logger.error(f"注册BERT模型失败: {e}")
        return False
    return True
# Register the commonly used models automatically; this runs at import time.
try:
    register_sentiment_model()
    register_bert_model()
except Exception as e:
    # Never let a registration problem break importing this module.
    logger.error(f"自动注册模型失败: {e}")
+494
View File
@@ -0,0 +1,494 @@
import os
import json
import logging
import re
from collections import defaultdict
import random
import torch
import numpy as np
from datetime import datetime
from typing import Dict, List, Any, Tuple, Optional, Union, Callable
logger = logging.getLogger('model_router')
logger.setLevel(logging.INFO)
class ModelRouter:
    """
    Model router: picks the best AI model for a piece of content.

    Features:
    1. Choose a model from the content's features and the task's needs.
    2. Support several providers and model types.
    3. Weigh capability and cost when routing.
    4. Record which model was chosen for each task type.
    5. Offer a uniform API, including registration of custom models.

    The class is a singleton: ``ModelRouter()`` always returns the same
    instance, initialized once.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelRouter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @staticmethod
    def _model_spec(provider, cost_per_1k, max_tokens, avg_latency, api_key_env,
                    **capabilities):
        """Build one model-definition dict; ``capabilities`` are 0-1 scores."""
        return {
            'provider': provider,
            'capabilities': capabilities,
            'cost_per_1k': cost_per_1k,
            'max_tokens': max_tokens,
            'avg_latency': avg_latency,  # seconds
            'requires_api_key': api_key_env,
        }

    def __init__(self):
        if self._initialized:
            return
        spec = self._model_spec
        # Built-in model definitions.
        self.models = {
            # OpenAI models
            'gpt-4o-latest': spec(
                'openai', 0.01, 128000, 2.5, 'OPENAI_API_KEY',
                text_analysis=0.95, sentiment_analysis=0.92, keyword_extraction=0.90,
                summarization=0.93, classification=0.89, chinese_text=0.88),
            'gpt-4o-mini': spec(
                'openai', 0.00015, 4000, 1.2, 'OPENAI_API_KEY',
                text_analysis=0.85, sentiment_analysis=0.82, keyword_extraction=0.80,
                summarization=0.84, classification=0.81, chinese_text=0.79),
            'gpt-3.5-turbo': spec(
                'openai', 0.0015, 16000, 0.8, 'OPENAI_API_KEY',
                text_analysis=0.75, sentiment_analysis=0.72, keyword_extraction=0.70,
                summarization=0.77, classification=0.73, chinese_text=0.65),
            # Claude models
            'claude-3.5-sonnet': spec(
                'anthropic', 0.015, 200000, 2.8, 'ANTHROPIC_API_KEY',
                text_analysis=0.90, sentiment_analysis=0.91, keyword_extraction=0.85,
                summarization=0.92, classification=0.89, chinese_text=0.80),
            'claude-3.5-haiku': spec(
                'anthropic', 0.0025, 200000, 1.5, 'ANTHROPIC_API_KEY',
                text_analysis=0.84, sentiment_analysis=0.83, keyword_extraction=0.79,
                summarization=0.85, classification=0.80, chinese_text=0.72),
            # DeepSeek models (rated highest on Chinese text)
            'deepseek-chat': spec(
                'deepseek', 0.002, 4000, 1.0, 'DEEPSEEK_API_KEY',
                text_analysis=0.82, sentiment_analysis=0.79, keyword_extraction=0.77,
                summarization=0.80, classification=0.77, chinese_text=0.90),
            'deepseek-reasoner': spec(
                'deepseek', 0.003, 4000, 1.8, 'DEEPSEEK_API_KEY',
                text_analysis=0.87, sentiment_analysis=0.75, keyword_extraction=0.76,
                summarization=0.78, classification=0.85, chinese_text=0.88),
        }
        # Task types and the capabilities each one relies on.
        self.task_types = {
            'sentiment_analysis': {
                'description': '情感分析',
                'key_capabilities': ['sentiment_analysis', 'text_analysis'],
                'example_prompt': '分析以下文本的情感倾向(积极、消极或中性)'
            },
            'topic_classification': {
                'description': '话题分类',
                'key_capabilities': ['classification', 'text_analysis'],
                'example_prompt': '将以下文本分类到最合适的话题类别'
            },
            'keyword_extraction': {
                'description': '关键词提取',
                'key_capabilities': ['keyword_extraction', 'text_analysis'],
                'example_prompt': '从以下文本中提取5个最重要的关键词'
            },
            'text_summarization': {
                'description': '文本摘要',
                'key_capabilities': ['summarization', 'text_analysis'],
                'example_prompt': '为以下文本生成一个简短的摘要'
            },
            'comprehensive_analysis': {
                'description': '综合分析',
                'key_capabilities': ['text_analysis', 'sentiment_analysis',
                                     'keyword_extraction', 'summarization'],
                'example_prompt': '对以下文本进行全面分析,包括情感、关键词和主要观点'
            }
        }
        # Selection history, keyed by task type.
        self.usage_history = defaultdict(list)
        # Models that are currently usable (see _update_available_models).
        self.available_models = {}
        self._update_available_models()
        self._initialized = True
        logger.info("模型路由器初始化完成")

    def _update_available_models(self):
        """
        Recompute ``available_models``.

        A model is available when the environment variable named by its
        ``requires_api_key`` is set, or when it declares no key at all
        (e.g. a local or private model added via ``register_custom_model``).
        """
        self.available_models = {}
        for model_id, model_info in self.models.items():
            api_key_env = model_info.get('requires_api_key')
            if not api_key_env or os.getenv(api_key_env):
                self.available_models[model_id] = model_info
        if not self.available_models:
            logger.warning("未找到可用的模型,请检查API密钥配置")
        else:
            logger.info(f"找到 {len(self.available_models)} 个可用模型")

    def detect_content_type(self, text: str) -> Dict[str, float]:
        """
        Estimate content features of ``text``.

        Returns:
            Dict with ``chinese_text`` (share of CJK characters), ``length``
            (character count / 10000, capped at 1) and ``complexity``
            (mean of sentence-length/50 and lexical diversity, capped at 1).
            All values are 0.0 for empty input.
        """
        features = {
            'chinese_text': 0.0,
            'length': 0.0,
            'complexity': 0.0
        }
        if not text:
            return features
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        total_chars = len(text)
        chinese_ratio = chinese_chars / total_chars if total_chars > 0 else 0
        length_score = min(1.0, len(text) / 10000)
        # Rough complexity estimate from sentence length and word variety.
        sentences = re.split(r'[.!?。!?]', text)
        avg_sentence_len = sum(len(s) for s in sentences) / len(sentences) if sentences else 0
        words = re.findall(r'\w+', text.lower())
        lexical_diversity = len(set(words)) / len(words) if words else 0
        complexity_score = min(1.0, (avg_sentence_len / 50 + lexical_diversity) / 2)
        features['chinese_text'] = chinese_ratio
        features['length'] = length_score
        features['complexity'] = complexity_score
        return features

    @staticmethod
    def _mean_capability(model_info: Dict[str, Any]) -> float:
        """Average of a model's capability scores; 0.0 when it declares none."""
        scores = model_info['capabilities'].values()
        return sum(scores) / len(scores) if len(scores) else 0.0

    @staticmethod
    def _score_model(model_info, task_capabilities, content_features, optimize_for):
        """Score one model for a task given the content features and goal."""
        capability_score = 0
        for capability in task_capabilities:
            capability_score += model_info['capabilities'].get(capability, 0)
        capability_score /= len(task_capabilities)
        content_score = 1.0
        # Mostly-Chinese text: weight in the model's Chinese capability.
        if content_features['chinese_text'] > 0.5:
            chinese_capability = model_info['capabilities'].get('chinese_text', 0)
            content_score *= (1.0 + chinese_capability) / 2
        # Long text: penalize short-context models.
        if content_features['length'] > 0.7:
            if model_info.get('max_tokens', 4000) < 10000:
                content_score *= 0.7
        # Complex text: favor models that score higher on the task.
        if content_features['complexity'] > 0.7:
            content_score *= (1.0 + capability_score) / 2
        final_score = capability_score * content_score
        if optimize_for == 'cost':
            # Cheaper is better; 0.03 per 1k tokens maps to a cost factor of 0.
            cost_factor = 1 - min(1.0, model_info.get('cost_per_1k', 0) / 0.03)
            final_score = final_score * 0.3 + cost_factor * 0.7
        elif optimize_for == 'performance':
            final_score = capability_score * 0.8 + content_score * 0.2
        # 'balanced' keeps capability_score * content_score unchanged.
        return final_score

    def select_model(self, text: str, task_type: str,
                     optimize_for: str = 'balanced',
                     exclude_models: Optional[List[str]] = None) -> Optional[str]:
        """
        Choose the most suitable available model for ``text`` and ``task_type``.

        Args:
            text: the text to process.
            task_type: a key of ``task_types``; unknown values fall back to
                'comprehensive_analysis'.
            optimize_for: 'cost', 'performance' or 'balanced'.
            exclude_models: model ids that must not be scored.

        Returns:
            The chosen model id, or None when no model is available. If every
            available model is excluded, the first available model is returned
            as a last resort.
        """
        if not self.available_models:
            logger.error("没有可用的模型,请检查API密钥配置")
            return None
        if task_type not in self.task_types:
            logger.warning(f"未知的任务类型: {task_type},使用默认任务类型: 'comprehensive_analysis'")
            task_type = 'comprehensive_analysis'
        content_features = self.detect_content_type(text)
        task_capabilities = self.task_types[task_type]['key_capabilities']
        excluded = exclude_models or []
        model_scores = {
            model_id: self._score_model(model_info, task_capabilities,
                                        content_features, optimize_for)
            for model_id, model_info in self.available_models.items()
            if model_id not in excluded
        }
        if not model_scores:
            logger.warning("没有符合条件的可用模型")
            return list(self.available_models.keys())[0]
        selected_model = max(model_scores, key=model_scores.get)
        self.usage_history[task_type].append({
            'model': selected_model,
            'timestamp': datetime.now().timestamp(),
            'score': model_scores[selected_model],
            'optimize_for': optimize_for
        })
        logger.info(f"为任务 '{task_type}' 选择了模型: {selected_model} (得分: {model_scores[selected_model]:.4f})")
        return selected_model

    def get_model_info(self, model_id: str) -> Optional[Dict]:
        """Return the definition of ``model_id``, or None if it is unknown."""
        if model_id in self.models:
            return self.models[model_id]
        return None

    def get_available_models(self, refresh: bool = False) -> Dict[str, Dict]:
        """Return the available models, re-checking availability when ``refresh``."""
        if refresh:
            self._update_available_models()
        return self.available_models

    def get_model_by_provider(self, provider: str,
                              optimize_for: str = 'balanced') -> Optional[str]:
        """
        Recommend one available model from ``provider``.

        'cost' picks the cheapest, 'performance' the highest mean capability,
        anything else an even blend of both. Returns None when the provider
        has no available model.
        """
        provider_models = {
            model_id: info for model_id, info in self.available_models.items()
            if info['provider'] == provider
        }
        if not provider_models:
            logger.warning(f"未找到提供商 '{provider}' 的可用模型")
            return None
        if optimize_for == 'cost':
            return min(provider_models.items(),
                       key=lambda x: x[1].get('cost_per_1k', float('inf')))[0]
        if optimize_for == 'performance':
            return max(provider_models.items(),
                       key=lambda x: self._mean_capability(x[1]))[0]
        scores = {}
        for model_id, info in provider_models.items():
            perf_score = self._mean_capability(info)
            cost_score = 1 - min(1.0, info.get('cost_per_1k', 0) / 0.03)
            scores[model_id] = perf_score * 0.5 + cost_score * 0.5
        return max(scores, key=scores.get)

    def get_task_types(self) -> Dict[str, Dict]:
        """Return the supported task types."""
        return self.task_types

    def register_custom_model(self, model_id: str, model_info: Dict[str, Any]) -> bool:
        """
        Register a custom model.

        Args:
            model_id: unique identifier of the model.
            model_info: definition dict. Required keys: ``provider``,
                ``capabilities`` (dict of scores), ``cost_per_1k``,
                ``max_tokens``. Optional: ``avg_latency`` (seconds) and
                ``requires_api_key`` (environment variable name; when omitted
                the model is treated as always available).

        Returns:
            True if registered, False if validation failed.
        """
        required_fields = ['provider', 'capabilities', 'cost_per_1k', 'max_tokens']
        for field in required_fields:
            if field not in model_info:
                logger.error(f"注册模型失败: 缺少必要字段 '{field}'")
                return False
        if not isinstance(model_info['capabilities'], dict):
            logger.error("注册模型失败: 'capabilities' 必须是字典")
            return False
        self.models[model_id] = model_info
        self._update_available_models()
        logger.info(f"成功注册自定义模型: {model_id}")
        return True
# Module-level shared ModelRouter instance (the class is a singleton).
model_router = ModelRouter()
def select_model(text, task_type, optimize_for='balanced', exclude_models=None):
    """Pick the most suitable model for ``text`` through the shared router."""
    return model_router.select_model(
        text,
        task_type,
        optimize_for=optimize_for,
        exclude_models=exclude_models,
    )
def get_available_models(refresh=False):
    """Return the shared router's available models, optionally refreshed first."""
    return model_router.get_available_models(refresh=refresh)
def get_model_by_provider(provider, optimize_for='balanced'):
    """Return the shared router's recommended model for ``provider``."""
    return model_router.get_model_by_provider(provider, optimize_for=optimize_for)
def get_task_types():
    """Return the task types supported by the shared router."""
    supported = model_router.get_task_types()
    return supported
def register_custom_model(model_id, model_info):
    """Register a custom model with the shared router; returns its success flag."""
    return model_router.register_custom_model(model_id=model_id, model_info=model_info)
# Example usage: run this module directly to see which models get selected.
if __name__ == "__main__":
    # Sample texts: one Chinese, one English.
    chinese_text = """
近日,人工智能技术的发展引发广泛关注。
专家指出,大型语言模型在自然语言处理领域取得了显著进展,
但同时也带来了诸多伦理和安全问题。对此,业界呼吁加强监管,
确保人工智能的发展能够造福人类社会。
"""
    english_text = """
Recent developments in artificial intelligence technology have drawn widespread attention.
Experts point out that large language models have made significant progress in the field of natural language processing,
but also bring many ethical and security issues. In response, the industry calls for strengthened regulation
to ensure that the development of artificial intelligence can benefit human society.
"""
    # Model selection with the default ('balanced') goal.
    print("中文文本任务测试:")
    model_for_chinese = select_model(chinese_text, 'sentiment_analysis')
    print(f"选择的模型: {model_for_chinese}")
    print("\n英文文本任务测试:")
    model_for_english = select_model(english_text, 'sentiment_analysis')
    print(f"选择的模型: {model_for_english}")
    # Same task, optimized for cost and then for performance.
    print("\n成本优化测试:")
    model_for_cost = select_model(chinese_text, 'sentiment_analysis', optimize_for='cost')
    print(f"选择的模型: {model_for_cost}")
    print("\n性能优化测试:")
    model_for_perf = select_model(chinese_text, 'sentiment_analysis', optimize_for='performance')
    print(f"选择的模型: {model_for_perf}")
    # Recommended model per provider.
    print("\n根据提供商获取模型:")
    for provider in ['openai', 'anthropic', 'deepseek']:
        model = get_model_by_provider(provider)
        if model:
            print(f"{provider}: {model}")
        else:
            print(f"{provider}: 无可用模型")
+357
View File
@@ -0,0 +1,357 @@
import re
import json
import os
import logging
from pathlib import Path
logger = logging.getLogger('sensitive_filter')
logger.setLevel(logging.INFO)
class SensitiveDataFilter:
    """
    Sensitive-data filter: detects and masks sensitive information in output.

    Features:
    1. Masks phone numbers, e-mail addresses, ID-card numbers, credit-card
       numbers and street addresses out of the box.
    2. Supports custom patterns and replacement texts.
    3. Filters plain strings as well as nested dicts and lists.

    The class is a singleton: ``SensitiveDataFilter()`` always returns the
    same instance, initialized once.
    """

    _instance = None

    # ASCII-only stand-ins for ``\b``. In str patterns ``\b`` treats CJK
    # characters as word characters, so it never matches between Chinese text
    # and an adjacent digit or letter ("电话13812345678" would go unmasked).
    _NOT_AFTER_WORD = r'(?<![0-9A-Za-z_])'
    _NOT_BEFORE_WORD = r'(?![0-9A-Za-z_])'

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SensitiveDataFilter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        start, end = self._NOT_AFTER_WORD, self._NOT_BEFORE_WORD
        # Default configuration; a config file may override parts of it.
        self.config = {
            'enabled': os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true',
            'patterns': {
                'phone': start + r'1[3-9]\d{9}' + end,
                # TLD class is [A-Za-z]: a '|' inside [...] is a literal pipe.
                'email': start + r'[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}' + end,
                'id_card': start + r'[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]' + end,
                'credit_card': start + r'\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}' + end,
                'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
            },
            'replacements': {
                'phone': '***********',
                'email': '******@*****',
                'id_card': '******************',
                'credit_card': '****************',
                'address': '[地址已隐藏]'
            }
        }
        self._load_config()
        self._compile_patterns()
        self._initialized = True
        logger.info("敏感数据过滤器初始化完成")
        if self.config['enabled']:
            logger.info(f"已启用以下类型的敏感数据过滤: {', '.join(self.config['patterns'].keys())}")
        else:
            logger.info("敏感数据过滤器已禁用")

    @staticmethod
    def _config_path():
        """Path of the JSON config file under ``$DATA_DIR`` (default 'data')."""
        data_dir = os.getenv('DATA_DIR', 'data')
        return os.path.join(data_dir, 'security', 'sensitive_filter.json')

    def _load_config(self):
        """Merge the JSON config file, when it exists, into ``self.config``."""
        config_path = self._config_path()
        if not os.path.exists(config_path):
            return
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                custom_config = json.load(f)
            if 'enabled' in custom_config:
                self.config['enabled'] = custom_config['enabled']
            # Per-key merge: the file overrides or adds entries, defaults stay.
            for section in ('patterns', 'replacements'):
                if section in custom_config:
                    for key, value in custom_config[section].items():
                        self.config[section][key] = value
            logger.info(f"已加载自定义敏感数据过滤配置: {config_path}")
        except Exception as e:
            logger.error(f"加载敏感数据过滤配置失败: {e}")

    def _compile_patterns(self):
        """Compile every configured pattern; invalid ones are logged and skipped."""
        self.compiled_patterns = {}
        for key, pattern in self.config['patterns'].items():
            try:
                self.compiled_patterns[key] = re.compile(pattern)
                logger.debug(f"已编译敏感数据模式: {key} - {pattern}")
            except re.error as e:
                logger.error(f"编译敏感数据模式失败: {key} - {pattern}: {e}")

    def filter_text(self, text):
        """
        Return ``text`` with every configured sensitive pattern masked.

        Input is returned unchanged when the filter is disabled or ``text``
        is empty. Patterns without a replacement use '[FILTERED]'.
        """
        if not self.config['enabled'] or not text:
            return text
        filtered_text = text
        for key, pattern in self.compiled_patterns.items():
            replacement = self.config['replacements'].get(key, '[FILTERED]')
            filtered_text = pattern.sub(replacement, filtered_text)
        return filtered_text

    def _filter_value(self, value, skip_keys):
        """Filter one value of any supported type; other types pass through."""
        if isinstance(value, dict):
            return self.filter_dict(value, *skip_keys)
        if isinstance(value, list):
            return self.filter_list(value, *skip_keys)
        if isinstance(value, str):
            return self.filter_text(value)
        return value

    def filter_dict(self, data, *skip_keys):
        """
        Return a filtered copy of ``data``, recursing into dicts and lists.

        Args:
            data: the dict to filter (a str is filtered as text; any other
                non-dict value is returned unchanged).
            skip_keys: keys whose values are copied through unfiltered, at
                every nesting level.
        """
        if not self.config['enabled'] or not data:
            return data
        if not isinstance(data, dict):
            if isinstance(data, str):
                return self.filter_text(data)
            return data
        return {
            key: value if key in skip_keys else self._filter_value(value, skip_keys)
            for key, value in data.items()
        }

    def filter_list(self, data, *skip_keys):
        """
        Return a filtered copy of ``data``, recursing into dicts and lists.

        Args:
            data: the list to filter (a dict or str is filtered accordingly;
                any other non-list value is returned unchanged).
            skip_keys: keys to leave unfiltered inside dict items.
        """
        if not self.config['enabled'] or not data:
            return data
        if not isinstance(data, list):
            if isinstance(data, dict):
                return self.filter_dict(data, *skip_keys)
            if isinstance(data, str):
                return self.filter_text(data)
            return data
        return [self._filter_value(item, skip_keys) for item in data]

    def is_sensitive_info(self, text, info_type=None):
        """
        Tell whether ``text`` contains sensitive information.

        Args:
            text: the text to check.
            info_type: a single pattern key to check, or None for all.

        Returns:
            True on a match; False otherwise, including when the filter is
            disabled, ``text`` is empty or ``info_type`` is unknown.
        """
        if not self.config['enabled'] or not text:
            return False
        if info_type:
            if info_type not in self.compiled_patterns:
                logger.warning(f"未知的敏感信息类型: {info_type}")
                return False
            return bool(self.compiled_patterns[info_type].search(text))
        return any(pattern.search(text) for pattern in self.compiled_patterns.values())

    def get_sensitive_info_types(self, text):
        """Return the keys of all patterns found in ``text`` (empty when disabled)."""
        if not self.config['enabled'] or not text:
            return []
        return [key for key, pattern in self.compiled_patterns.items()
                if pattern.search(text)]

    def enable(self):
        """Turn the filter on."""
        self.config['enabled'] = True
        logger.info("敏感数据过滤器已启用")

    def disable(self):
        """Turn the filter off."""
        self.config['enabled'] = False
        logger.info("敏感数据过滤器已禁用")

    def is_enabled(self):
        """Return whether the filter is on."""
        return self.config['enabled']

    def add_pattern(self, key, pattern, replacement='[FILTERED]'):
        """
        Add or replace a sensitive-information pattern.

        Args:
            key: identifier of the pattern.
            pattern: regular-expression string.
            replacement: text substituted for each match.

        Returns:
            True if added, False if ``pattern`` is not a valid regex.
        """
        try:
            compiled = re.compile(pattern)
        except re.error as e:
            logger.error(f"添加敏感信息模式失败: {key} - {pattern}: {e}")
            return False
        self.config['patterns'][key] = pattern
        self.config['replacements'][key] = replacement
        self.compiled_patterns[key] = compiled
        logger.info(f"已添加敏感信息模式: {key}")
        return True

    def remove_pattern(self, key):
        """Remove the pattern ``key``; return whether it existed."""
        if key in self.config['patterns']:
            del self.config['patterns'][key]
            self.config['replacements'].pop(key, None)
            self.compiled_patterns.pop(key, None)
            logger.info(f"已移除敏感信息模式: {key}")
            return True
        logger.warning(f"未找到敏感信息模式: {key}")
        return False

    def save_config(self):
        """Write the current configuration to the JSON config file; return success."""
        config_path = self._config_path()
        os.makedirs(os.path.dirname(config_path), exist_ok=True)
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, ensure_ascii=False, indent=2)
            logger.info(f"敏感数据过滤配置已保存到: {config_path}")
            return True
        except Exception as e:
            logger.error(f"保存敏感数据过滤配置失败: {e}")
            return False
# Module-level shared SensitiveDataFilter instance (the class is a singleton).
sensitive_filter = SensitiveDataFilter()
# 提供便捷的过滤函数
def filter_text(text):
    """Mask sensitive information in ``text`` using the shared filter."""
    masked = sensitive_filter.filter_text(text)
    return masked
def filter_dict(data, *skip_keys):
    """Mask sensitive information in a dict; ``skip_keys`` are left unfiltered."""
    masked = sensitive_filter.filter_dict(data, *skip_keys)
    return masked
def filter_list(data, *skip_keys):
    """Mask sensitive information in a list; ``skip_keys`` apply to dict items."""
    masked = sensitive_filter.filter_list(data, *skip_keys)
    return masked
def is_sensitive_info(text, info_type=None):
    """Report whether ``text`` contains sensitive information, per the shared filter."""
    return sensitive_filter.is_sensitive_info(text, info_type=info_type)
# Example usage: run this module directly to see the filter in action.
if __name__ == "__main__":
    # Sample text containing one item of each built-in sensitive type.
    test_text = """
联系人: 张三
电话: 13812345678
邮箱: zhangsan@example.com
身份证: 110101199001011234
地址: 北京市海淀区中关村大街20号
信用卡: 6225 1234 5678 9012
"""
    # Mask the sensitive information and show before/after.
    filtered_text = filter_text(test_text)
    print("原始文本:")
    print(test_text)
    print("\n过滤后:")
    print(filtered_text)
    # List which sensitive types were detected.
    types = sensitive_filter.get_sensitive_info_types(test_text)
    print(f"\n包含的敏感信息类型: {types}")
+837
View File
@@ -0,0 +1,837 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import time
import uuid
import logging
import traceback
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from utils.db_manager import DatabaseManager
from utils.cache_manager import CacheManager
from utils.model_router import ModelRouter
from utils.sensitive_filter import SensitiveDataFilter
from spider.weibo_crawler import WeiboCrawler
from utils.ai_analyzer import AIAnalyzer
# 配置日志
from utils.logger import setup_logger
logger = setup_logger('workflow_engine', 'logs/workflow_engine.log')
class WorkflowEngine:
"""工作流引擎 - 负责执行数据爬取和分析工作流"""
_instance = None
_initialized = False
def __new__(cls):
    # Singleton: allocate the instance on first use, then keep returning it.
    if cls._instance is None:
        cls._instance = super(WorkflowEngine, cls).__new__(cls)
    return cls._instance
def __init__(self):
    # The singleton's __init__ runs on every WorkflowEngine() call; only the
    # first call performs the setup below.
    if self._initialized:
        return
    self.db = DatabaseManager()
    # NOTE(review): cache_duration=3600 reads like seconds, but CacheManager's
    # own default is 24 — confirm which unit the parameter expects.
    self.cache = CacheManager(memory_capacity=50, cache_duration=3600)
    self.model_router = ModelRouter()
    self.sensitive_filter = SensitiveDataFilter()
    # Thread pool that runs workflows in the background (at most 5 at once).
    self.executor = ThreadPoolExecutor(max_workers=5)
    # task_id -> Future of each workflow submitted to the executor.
    self.running_tasks = {}
    # Create the directories the engine needs.
    self.data_dir = Path('data/workflow')
    self.data_dir.mkdir(parents=True, exist_ok=True)
    self._initialized = True
    logger.info("工作流引擎初始化完成")
def execute_crawler_workflow(self, task_id, config):
    """
    Run a crawler workflow to completion.

    Args:
        task_id: id of the task whose status and progress are updated.
        config: crawler settings; reads 'source', 'crawl_depth', 'interval'
            and 'filters', each with a default.

    Returns:
        The crawler's result, or None if anything raised (the task is then
        marked 'failed' with the error message).
    """
    logger.info(f"开始执行爬虫工作流: {task_id}")

    def report_progress(progress):
        # Forward crawler progress to the task record.
        return self._update_task_progress(task_id, progress)

    try:
        self._update_task_status(task_id, 'running', 0)
        crawler = WeiboCrawler()
        crawl_options = {
            'source': config.get('source', 'hot_topics'),
            'depth': config.get('crawl_depth', 1),
            'interval': config.get('interval', 5),
            'filters': config.get('filters', {}),
        }
        result = crawler.crawl(callback=report_progress, **crawl_options)
        self._update_task_status(task_id, 'completed', 100, result=result)
        logger.info(f"爬虫工作流完成: {task_id}")
        return result
    except Exception as e:
        logger.error(f"爬虫工作流出错: {str(e)}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=str(e))
        return None
def execute_analysis_workflow(self, task_id, workflow):
    """
    Run an analysis workflow: execute its components in dependency order.

    Args:
        task_id: id of the task whose status and progress are updated.
        workflow: dict with non-empty 'components' and 'connections' lists.

    Returns:
        The final outputs (masked by the sensitive-data filter when it is
        enabled), or None if anything raised (the task is then marked
        'failed' with the error message).
    """
    logger.info(f"开始执行分析工作流: {task_id}")
    try:
        self._update_task_status(task_id, 'running', 0)

        components = workflow.get('components', [])
        connections = workflow.get('connections', [])
        if not (components and connections):
            raise ValueError("工作流配置不完整,缺少组件或连接")

        # Dependency graph first, then a topological order to execute in.
        component_map, dependency_graph = self._build_dependency_graph(components, connections)
        execution_order = self._topological_sort(dependency_graph)

        result_map = {}
        total_components = len(execution_order)
        for position, component_id in enumerate(execution_order):
            component = component_map.get(component_id)
            if not component:
                continue
            # Progress reflects how many components were reached so far.
            self._update_task_progress(task_id, int((position / total_components) * 100))
            input_data = self._get_component_input_data(component_id, connections, result_map)
            result_map[component_id] = self._execute_component(component, input_data)

        final_outputs = self._get_final_outputs(dependency_graph, result_map)

        # Mask sensitive information before the result is stored or returned.
        if final_outputs and self.sensitive_filter.is_enabled():
            if isinstance(final_outputs, dict):
                final_outputs = self.sensitive_filter.filter_dict(final_outputs)
            elif isinstance(final_outputs, list):
                final_outputs = self.sensitive_filter.filter_list(final_outputs)

        self._update_task_status(task_id, 'completed', 100, result=final_outputs)
        logger.info(f"分析工作流完成: {task_id}")
        return final_outputs
    except Exception as e:
        logger.error(f"分析工作流出错: {str(e)}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=str(e))
        return None
def start_workflow(self, workflow_type, config, template_id=None):
    """Persist a new workflow task and schedule it on the executor.

    The workflow type is validated before anything is written, so an
    unsupported type is rejected without leaving a 'pending' row in
    ``workflow_tasks`` that no worker would ever pick up.

    Args:
        workflow_type: Kind of workflow to run, 'crawler' or 'analysis'.
        config: Workflow configuration; stored as JSON and handed to the
            workflow runner.
        template_id: Optional id of the template this task was created from.

    Returns:
        The generated task id (a UUID4 string), or None when the workflow
        type is unknown or persisting/scheduling the task failed.
    """
    # Resolve the runner first: nothing must be inserted for an unknown type.
    if workflow_type == 'crawler':
        runner = self.execute_crawler_workflow
    elif workflow_type == 'analysis':
        runner = self.execute_analysis_workflow
    else:
        logger.error(f"未知的工作流类型: {workflow_type}")
        return None

    task_id = str(uuid.uuid4())

    conn = self.db.get_connection()
    cursor = conn.cursor()
    try:
        now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        cursor.execute(
            """
            INSERT INTO workflow_tasks
            (id, template_id, type, status, progress, config, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (
                task_id,
                template_id,
                workflow_type,
                'pending',
                0,
                json.dumps(config, ensure_ascii=False),
                now,
                now,
            ),
        )
        conn.commit()

        # Run in the background; keep the future so the task can be cancelled.
        self.running_tasks[task_id] = self.executor.submit(runner, task_id, config)
        return task_id
    except Exception as e:
        logger.error(f"启动工作流失败: {str(e)}")
        conn.rollback()
        return None
    finally:
        cursor.close()
def get_task_status(self, task_id):
    """Return the stored information for one task, using the cache first.

    On a cache miss the row is read from ``workflow_tasks``; its 'config'
    and 'result' columns are decoded from JSON when non-empty, and the
    decoded row is cached under ``task_status:<task_id>``.

    Args:
        task_id: Identifier of the task to look up.

    Returns:
        The task row (expected to be dict-like, since it is read with
        ``.get``), the falsy value returned by ``fetchone`` when no row
        matches, or None when the lookup raised (the error is logged).
    """
    key = f"task_status:{task_id}"
    hit = self.cache.get(key)
    if hit:
        return hit

    # Cache miss: fall back to the database.
    connection = self.db.get_connection()
    cur = connection.cursor()
    try:
        cur.execute(
            "SELECT * FROM workflow_tasks WHERE id = %s",
            (task_id,)
        )
        row = cur.fetchone()
        if not row:
            return row

        # These columns hold JSON text; expose them as Python objects.
        for column in ('config', 'result'):
            raw = row.get(column)
            if raw:
                row[column] = json.loads(raw)

        self.cache.set(key, row)
        return row
    except Exception as e:
        logger.error(f"获取任务状态失败: {str(e)}")
        return None
    finally:
        cur.close()
def cancel_task(self, task_id):
    """Cancel a task and mark it as 'cancelled' in the database.

    If a future is tracked for the task it is dropped from
    ``running_tasks`` and, unless it already finished, asked to cancel.
    The database row is updated regardless of whether a future was
    tracked, and the cached status for the task is invalidated.

    Args:
        task_id: Identifier of the task to cancel.

    Returns:
        True when the status update was committed, False when it raised
        (the error is logged and the transaction rolled back).
    """
    if task_id in self.running_tasks:
        handle = self.running_tasks.pop(task_id)
        # NOTE(review): Future.cancel() cannot stop a call that is already
        # executing, so a running workflow may still finish afterwards.
        if not handle.done():
            handle.cancel()

    connection = self.db.get_connection()
    cur = connection.cursor()
    try:
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        cur.execute(
            """
            UPDATE workflow_tasks
            SET status = %s, updated_at = %s
            WHERE id = %s
            """,
            ('cancelled', stamp, task_id)
        )
        connection.commit()

        # The cached status is stale now; force the next read to hit the DB.
        self.cache.delete(f"task_status:{task_id}")
        return True
    except Exception as e:
        logger.error(f"取消任务失败: {str(e)}")
        connection.rollback()
        return False
    finally:
        cur.close()
def _update_task_status(self, task_id, status, progress, result=None, error=None):
"""更新任务状态"""
conn = self.db.get_connection()
cursor = conn.cursor()
try:
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
update_fields = ["status = %s", "progress = %s", "updated_at = %s"]
params = [status, progress, now]
# 添加开始时间
if status == 'running' and progress == 0:
update_fields.append("started_at = %s")
params.append(now)
# 添加完成时间
if status in ['completed', 'failed']:
update_fields.append("completed_at = %s")
params.append(now)
# 添加结果
if result is not None:
update_fields.append("result = %s")
params.append(json.dumps(result, ensure_ascii=False))
# 添加错误
if error is not None:
update_fields.append("error = %s")
params.append(error)
# 构建SQL
sql = f"""
UPDATE workflow_tasks
SET {', '.join(update_fields)}
WHERE id = %s
"""
params.append(task_id)
cursor.execute(sql, tuple(params))
conn.commit()
# 清理缓存
cache_key = f"task_status:{task_id}"
self.cache.delete(cache_key)
except Exception as e:
logger.error(f"更新任务状态失败: {str(e)}")
conn.rollback()
finally:
cursor.close()
def _update_task_progress(self, task_id, progress):
"""更新任务进度"""
self._update_task_status(task_id, 'running', progress)
def _build_dependency_graph(self, components, connections):
"""构建组件依赖图"""
component_map = {comp['id']: comp for comp in components}
dependency_graph = {comp['id']: [] for comp in components}
# 构建依赖关系
for conn in connections:
source = conn.get('source')
target = conn.get('target')
if source and target and source in component_map and target in component_map:
dependency_graph[target].append(source)
return component_map, dependency_graph
def _topological_sort(self, graph):
"""拓扑排序,确定组件执行顺序"""
visited = set()
temp = set()
order = []
def visit(node):
if node in temp:
raise ValueError(f"工作流存在循环依赖: {node}")
if node in visited:
return
temp.add(node)
for neighbor in graph.get(node, []):
visit(neighbor)
temp.remove(node)
visited.add(node)
order.append(node)
for node in graph:
if node not in visited:
visit(node)
return list(reversed(order))
def _get_component_input_data(self, component_id, connections, result_map):
"""获取组件的输入数据"""
input_data = {}
for conn in connections:
if conn.get('target') == component_id:
source_id = conn.get('source')
if source_id in result_map:
input_name = conn.get('targetInput', 'default')
input_data[input_name] = result_map[source_id]
return input_data
def _execute_component(self, component, input_data):
"""执行单个组件"""
component_type = component.get('type')
config = component.get('config', {})
if component_type == 'data_source':
return self._execute_data_source(config, input_data)
elif component_type == 'preprocessing':
return self._execute_preprocessing(config, input_data)
elif component_type == 'model':
return self._execute_model(config, input_data)
elif component_type == 'visualization':
return self._execute_visualization(config, input_data)
else:
logger.warning(f"未知的组件类型: {component_type}")
return None
def _execute_data_source(self, config, input_data):
"""执行数据源组件"""
source_type = config.get('source_type')
if source_type == 'database':
# 从数据库获取数据
table = config.get('table')
filters = config.get('filters', {})
limit = config.get('limit', 1000)
query_conditions = []
query_params = []
for key, value in filters.items():
if value:
query_conditions.append(f"{key} = %s")
query_params.append(value)
where_clause = f"WHERE {' AND '.join(query_conditions)}" if query_conditions else ""
sql = f"SELECT * FROM {table} {where_clause} LIMIT {limit}"
conn = self.db.get_connection()
cursor = conn.cursor()
try:
cursor.execute(sql, tuple(query_params))
return cursor.fetchall()
except Exception as e:
logger.error(f"数据库查询出错: {str(e)}")
return []
finally:
cursor.close()
elif source_type == 'file':
# 从文件加载数据
file_path = config.get('file_path')
if not file_path or not os.path.exists(file_path):
return []
try:
with open(file_path, 'r', encoding='utf-8') as f:
if file_path.endswith('.json'):
return json.load(f)
else:
return f.read()
except Exception as e:
logger.error(f"文件读取出错: {str(e)}")
return []
elif source_type == 'api':
# 这里需要实现API调用逻辑
# 由于涉及复杂的HTTP请求,暂不实现
logger.warning("API数据源暂未实现")
return []
else:
logger.warning(f"未知的数据源类型: {source_type}")
return []
def _execute_preprocessing(self, config, input_data):
"""执行数据预处理组件"""
preprocessing_type = config.get('preprocessing_type')
data = input_data.get('default', [])
if not data:
return []
if preprocessing_type == 'filter':
# 过滤数据
field = config.get('field')
value = config.get('value')
operator = config.get('operator', 'eq')
if not field:
return data
result = []
for item in data:
if operator == 'eq' and item.get(field) == value:
result.append(item)
elif operator == 'neq' and item.get(field) != value:
result.append(item)
elif operator == 'contains' and value in str(item.get(field, '')):
result.append(item)
elif operator == 'not_contains' and value not in str(item.get(field, '')):
result.append(item)
return result
elif preprocessing_type == 'sort':
# 排序数据
field = config.get('field')
order = config.get('order', 'asc')
if not field:
return data
return sorted(
data,
key=lambda x: x.get(field, ''),
reverse=(order == 'desc')
)
elif preprocessing_type == 'aggregate':
# 聚合数据
group_by = config.get('group_by')
aggregate_field = config.get('aggregate_field')
aggregate_type = config.get('aggregate_type', 'count')
if not group_by:
return data
result = {}
for item in data:
key = item.get(group_by)
if key not in result:
result[key] = {
'count': 0,
'sum': 0,
'values': []
}
result[key]['count'] += 1
if aggregate_field:
value = item.get(aggregate_field, 0)
if isinstance(value, (int, float)):
result[key]['sum'] += value
result[key]['values'].append(value)
# 计算最终结果
final_result = []
for key, values in result.items():
item = {group_by: key}
if aggregate_type == 'count':
item['value'] = values['count']
elif aggregate_type == 'sum':
item['value'] = values['sum']
elif aggregate_type == 'avg':
item['value'] = values['sum'] / values['count'] if values['count'] > 0 else 0
final_result.append(item)
return final_result
else:
logger.warning(f"未知的预处理类型: {preprocessing_type}")
return data
def _execute_model(self, config, input_data):
"""执行模型组件"""
model_type = config.get('model_type')
data = input_data.get('default', [])
if not data:
return []
analyzer = AIAnalyzer()
if model_type == 'sentiment':
# 情感分析
texts = []
if isinstance(data, list):
# 如果是列表,从指定字段获取文本
field = config.get('text_field', 'content')
texts = [item.get(field, '') for item in data if item.get(field)]
elif isinstance(data, str):
# 如果是字符串,直接使用
texts = [data]
# 获取合适的模型
model = self.model_router.select_model_for_text(texts[0] if texts else "", "sentiment")
# 执行分析
results = []
for text in texts:
result = analyzer.analyze_sentiment(text, model=model)
results.append(result)
# 如果输入是列表,将结果合并回原始数据
if isinstance(data, list):
field = config.get('text_field', 'content')
for i, item in enumerate(data):
if i < len(results) and item.get(field):
item['sentiment'] = results[i]
return data
else:
return results[0] if results else None
elif model_type == 'topic':
# 主题分类
texts = []
if isinstance(data, list):
field = config.get('text_field', 'content')
texts = [item.get(field, '') for item in data if item.get(field)]
elif isinstance(data, str):
texts = [data]
# 获取合适的模型
model = self.model_router.select_model_for_text(texts[0] if texts else "", "topic")
# 执行分析
results = []
for text in texts:
result = analyzer.analyze_topic(text, model=model)
results.append(result)
# 如果输入是列表,将结果合并回原始数据
if isinstance(data, list):
field = config.get('text_field', 'content')
for i, item in enumerate(data):
if i < len(results) and item.get(field):
item['topic'] = results[i]
return data
else:
return results[0] if results else None
elif model_type == 'keywords':
# 关键词提取
texts = []
if isinstance(data, list):
field = config.get('text_field', 'content')
texts = [item.get(field, '') for item in data if item.get(field)]
elif isinstance(data, str):
texts = [data]
# 获取合适的模型
model = self.model_router.select_model_for_text(texts[0] if texts else "", "keyword")
# 执行分析
results = []
for text in texts:
result = analyzer.extract_keywords(text, model=model)
results.append(result)
# 如果输入是列表,将结果合并回原始数据
if isinstance(data, list):
field = config.get('text_field', 'content')
for i, item in enumerate(data):
if i < len(results) and item.get(field):
item['keywords'] = results[i]
return data
else:
return results[0] if results else None
elif model_type == 'summarize':
# 文本摘要
texts = []
if isinstance(data, list):
field = config.get('text_field', 'content')
texts = [item.get(field, '') for item in data if item.get(field)]
elif isinstance(data, str):
texts = [data]
# 获取合适的模型
model = self.model_router.select_model_for_text(texts[0] if texts else "", "summarization")
# 执行分析
results = []
for text in texts:
result = analyzer.summarize_text(text, model=model)
results.append(result)
# 如果输入是列表,将结果合并回原始数据
if isinstance(data, list):
field = config.get('text_field', 'content')
for i, item in enumerate(data):
if i < len(results) and item.get(field):
item['summary'] = results[i]
return data
else:
return results[0] if results else None
else:
logger.warning(f"未知的模型类型: {model_type}")
return data
def _execute_visualization(self, config, input_data):
"""执行可视化组件"""
visualization_type = config.get('visualization_type')
data = input_data.get('default', [])
if not data:
return {}
if visualization_type == 'chart':
# 图表可视化
chart_type = config.get('chart_type', 'bar')
x_field = config.get('x_field')
y_field = config.get('y_field')
title = config.get('title', '数据可视化')
if not x_field or not y_field:
return {'error': '缺少x或y字段'}
# 提取数据
chart_data = {
'type': chart_type,
'title': title,
'xAxis': {'type': 'category', 'data': []},
'yAxis': {'type': 'value'},
'series': [{'data': []}]
}
for item in data:
x_value = item.get(x_field)
y_value = item.get(y_field)
if x_value is not None and y_value is not None:
chart_data['xAxis']['data'].append(x_value)
chart_data['series'][0]['data'].append(y_value)
return chart_data
elif visualization_type == 'table':
# 表格可视化
columns = config.get('columns', [])
title = config.get('title', '数据表格')
# 如果没有指定列,使用数据中的所有字段
if not columns and isinstance(data, list) and data:
columns = list(data[0].keys())
# 构建表格数据
table_data = {
'type': 'table',
'title': title,
'columns': columns,
'data': data
}
return table_data
elif visualization_type == 'wordcloud':
# 词云可视化
word_field = config.get('word_field')
value_field = config.get('value_field')
title = config.get('title', '词云图')
if not word_field:
return {'error': '缺少词字段'}
# 构建词云数据
wordcloud_data = {
'type': 'wordcloud',
'title': title,
'data': []
}
for item in data:
word = item.get(word_field)
value = item.get(value_field, 1)
if word:
wordcloud_data['data'].append({
'name': word,
'value': value
})
return wordcloud_data
else:
logger.warning(f"未知的可视化类型: {visualization_type}")
return {}
def _get_final_outputs(self, dependency_graph, result_map):
"""获取最终输出结果"""
# 找出没有后继节点的叶子节点
leaf_nodes = []
all_targets = set()
for node, deps in dependency_graph.items():
all_targets.update(deps)
for node in dependency_graph:
if node not in all_targets:
leaf_nodes.append(node)
# 收集所有叶子节点的结果
outputs = {}
for node in leaf_nodes:
if node in result_map:
outputs[node] = result_map[node]
return outputs