🚀 Major Upgrade! Visual Workflow Orchestrator and AI-Powered Crawler Implemented. Added Model Arena Feature and Efficiency Optimizations (Two-Level Caching Architecture + End-to-End Performance Enhancements).
This commit is contained in:
+262
-77
@@ -1,116 +1,301 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import shutil
|
||||
from datetime import datetime, timedelta
|
||||
import threading
|
||||
import queue
|
||||
from collections import OrderedDict
|
||||
import pickle
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
# Module-level logger shared by every cache class in this file.
logger = logging.getLogger('cache_manager')
logger.setLevel(logging.INFO)
|
||||
|
||||
class LRUCache:
    """Least-recently-used cache backed by an ``OrderedDict``.

    The most recently used entry sits at the end of ``self.cache``; the
    eviction candidate is always the first entry.
    """

    def __init__(self, capacity):
        """Create an empty cache holding at most ``capacity`` entries."""
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        """Return the value stored under ``key`` or ``None`` when absent.

        A successful lookup marks the entry as most recently used.
        """
        if key not in self.cache:
            return None
        self.cache.move_to_end(key)
        return self.cache[key]

    def put(self, key, value):
        """Insert or update ``key``, evicting the oldest entries when full."""
        if key in self.cache:
            # Existing key: refresh the value and its recency.
            self.cache[key] = value
            self.cache.move_to_end(key)
            return

        # A non-positive capacity means "store nothing"; without this guard
        # popitem() would raise KeyError on the empty dict.
        if self.capacity <= 0:
            return

        # Evict from the front (least recently used) until there is room.
        while len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

        self.cache[key] = value

    def remove(self, key):
        """Delete ``key`` if present; missing keys are ignored."""
        self.cache.pop(key, None)

    def clear(self):
        """Drop every entry."""
        self.cache.clear()

    def __len__(self):
        return len(self.cache)

    def get_all_keys(self):
        """Return a snapshot list of keys, least recently used first."""
        return list(self.cache.keys())
|
||||
|
||||
|
||||
class CacheManager:
    """Two-level cache: an in-memory LRU layer plus pickle files on disk.

    Instances are singletons *per name*: ``CacheManager(name="x")`` always
    returns the same object for a given name, while different names get
    independent caches (own capacity, expiry and disk directory).

    Each instance starts two daemon threads: one drains ``disk_queue`` and
    writes entries to disk, the other periodically purges expired entries
    and flushes the memory layer to disk.
    """

    _instances = {}
    _instance = None  # first instance ever created; kept for older callers
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # The cache name is the first positional argument or the ``name``
        # keyword; it selects which registered instance to hand back.
        if "name" in kwargs:
            name = kwargs["name"]
        else:
            name = args[0] if args else "default"

        with cls._lock:
            instance = cls._instances.get((cls, name))
            if instance is None:
                instance = super().__new__(cls)
                cls._instances[(cls, name)] = instance
                if cls._instance is None:
                    cls._instance = instance
            return instance

    def __init__(self, name="default", memory_capacity=1000, cache_duration=24,
                 disk_cache_dir="cache", flush_interval=5):
        """Configure the cache and start its background threads.

        Args:
            name: Cache name; also the sub-directory used on disk.
            memory_capacity: Maximum number of entries kept in memory.
            cache_duration: Entry lifetime in hours.
            disk_cache_dir: Parent directory for the on-disk layer.
            flush_interval: Stored on the instance only.
                NOTE(review): the maintenance loop sleeps a fixed hour and
                never reads this value - confirm the intended behaviour.
        """
        # __new__ may return an already configured instance; configure once.
        if hasattr(self, 'initialized'):
            return

        self.name = name
        self.memory_cache = LRUCache(memory_capacity)
        self.disk_cache_dir = os.path.join(disk_cache_dir, name)
        self.cache_duration = timedelta(hours=cache_duration)
        self.flush_interval = flush_interval
        self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
        self.disk_queue = queue.Queue()
        self.initialized = True

        os.makedirs(self.disk_cache_dir, exist_ok=True)

        # Maintenance thread: expiry purge + periodic flush.
        self.cleanup_thread = threading.Thread(target=self._cleanup_and_flush_task, daemon=True)
        self.cleanup_thread.start()

        # Writer thread: persists entries queued by set().
        self.disk_writer_thread = threading.Thread(target=self._disk_writer_task, daemon=True)
        self.disk_writer_thread.start()

        logger.info(f"初始化缓存管理器: {name},内存容量: {memory_capacity}项,缓存时间: {cache_duration}小时")

    def _get_cache_key(self, key):
        """Normalise ``key``: strings pass through, anything else is hashed."""
        if isinstance(key, str):
            return key
        return hashlib.md5(str(key).encode()).hexdigest()

    @staticmethod
    def _is_safe_filename(name):
        """Return True when ``name`` can be used verbatim as a file name."""
        if not name or name in (".", ".."):
            return False
        if not all(ch.isalnum() or ch in "-_." for ch in name):
            return False
        return len(name.encode("utf-8")) <= 200

    def _get_disk_path(self, key):
        """Return the path of the cache file for ``key``.

        Keys containing path separators or other unsafe characters are
        replaced by their MD5 digest so an entry can never be written
        outside ``disk_cache_dir``.
        """
        safe_key = self._get_cache_key(key)
        if not self._is_safe_filename(safe_key):
            safe_key = hashlib.md5(safe_key.encode("utf-8", "backslashreplace")).hexdigest()
        return os.path.join(self.disk_cache_dir, f"{safe_key}.cache")

    def _is_cache_valid(self, timestamp):
        """Return True while an entry written at ``timestamp`` has not expired."""
        cache_time = datetime.fromtimestamp(timestamp)
        return datetime.now() - cache_time < self.cache_duration

    def get(self, key):
        """Look ``key`` up in memory first, then on disk.

        Returns:
            The cached data, or ``None`` on a miss or an expired entry.
        """
        cache_key = self._get_cache_key(key)

        # 1. Memory layer.
        cache_data = self.memory_cache.get(cache_key)
        if cache_data is not None:
            if self._is_cache_valid(cache_data['timestamp']):
                self.cache_stats["hits"] += 1
                logger.debug(f"内存缓存命中: {key}")
                return cache_data['data']
            # Expired: drop it and fall through to the disk layer.
            self.memory_cache.remove(cache_key)

        # 2. Disk layer.
        disk_path = self._get_disk_path(cache_key)
        if os.path.exists(disk_path):
            try:
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file; the cache directory must only be writable by trusted
                # users.
                with open(disk_path, 'rb') as f:
                    cache_data = pickle.load(f)

                if self._is_cache_valid(cache_data['timestamp']):
                    # Promote the entry so the next lookup is served from memory.
                    self.memory_cache.put(cache_key, cache_data)
                    self.cache_stats["disk_hits"] += 1
                    logger.debug(f"磁盘缓存命中: {key}")
                    return cache_data['data']
                # Expired on disk as well: remove the stale file.
                os.remove(disk_path)
            except Exception as e:
                logger.warning(f"读取磁盘缓存失败: {key}, 错误: {e}")

        self.cache_stats["misses"] += 1
        logger.debug(f"缓存未命中: {key}")
        return None

    def set(self, key, data, immediate_disk_write=False):
        """Store ``data`` under ``key`` in memory and schedule a disk write.

        Args:
            key: Cache key (non-strings are hashed).
            data: Any picklable value.
            immediate_disk_write: Write to disk synchronously instead of
                queueing the write for the background thread.

        Returns:
            Always ``True``.
        """
        cache_key = self._get_cache_key(key)
        cache_data = {
            'data': data,
            'timestamp': datetime.now().timestamp()
        }

        self.memory_cache.put(cache_key, cache_data)

        if immediate_disk_write:
            self._write_to_disk(cache_key, cache_data)
        else:
            self.disk_queue.put((cache_key, cache_data))

        logger.debug(f"缓存已设置: {key}")
        return True

    def invalidate(self, key):
        """Remove ``key`` from both layers. Always returns ``True``."""
        cache_key = self._get_cache_key(key)

        self.memory_cache.remove(cache_key)

        disk_path = self._get_disk_path(cache_key)
        if os.path.exists(disk_path):
            try:
                os.remove(disk_path)
                logger.debug(f"缓存已失效: {key}")
            except Exception as e:
                logger.warning(f"删除磁盘缓存失败: {key}, 错误: {e}")

        return True

    def clear_all(self):
        """Empty both layers and reset the statistics. Always returns ``True``."""
        self.memory_cache.clear()

        # Discard writes that are still queued; otherwise the writer thread
        # would recreate entries right after the directory is wiped.
        while True:
            try:
                self.disk_queue.get_nowait()
            except queue.Empty:
                break
            self.disk_queue.task_done()

        try:
            shutil.rmtree(self.disk_cache_dir)
            os.makedirs(self.disk_cache_dir, exist_ok=True)
            logger.info(f"所有缓存已清除: {self.name}")
        except Exception as e:
            logger.error(f"清除磁盘缓存失败: {e}")

        self.cache_stats = {"hits": 0, "misses": 0, "disk_hits": 0}
        return True

    def get_stats(self):
        """Return a dict of counters and hit rates (percentages).

        Every lookup ends in exactly one of memory hit, disk hit or miss, so
        the request total is the sum of the three counters.
        """
        memory_hits = self.cache_stats["hits"]
        disk_hits = self.cache_stats["disk_hits"]
        misses = self.cache_stats["misses"]
        total_requests = memory_hits + disk_hits + misses
        total_hits = memory_hits + disk_hits

        try:
            disk_size = len([f for f in os.listdir(self.disk_cache_dir) if f.endswith('.cache')])
        except FileNotFoundError:
            # The directory can be briefly missing while clear_all() runs.
            disk_size = 0

        return {
            "name": self.name,
            "memory_items": len(self.memory_cache),
            "disk_items": disk_size,
            "memory_hits": memory_hits,
            "disk_hits": disk_hits,
            "misses": misses,
            "total_requests": total_requests,
            "hit_rate": (memory_hits / total_requests * 100) if total_requests > 0 else 0,
            "two_level_hit_rate": (total_hits / total_requests * 100) if total_requests > 0 else 0
        }

    def _write_to_disk(self, cache_key, cache_data):
        """Pickle ``cache_data`` to the file for ``cache_key``.

        The data is written to a temporary file and renamed into place so a
        concurrent reader never sees a half-written file.

        Returns:
            ``True`` on success, ``False`` if the write failed (logged).
        """
        disk_path = self._get_disk_path(cache_key)
        tmp_path = f"{disk_path}.tmp"
        try:
            os.makedirs(self.disk_cache_dir, exist_ok=True)
            with open(tmp_path, 'wb') as f:
                pickle.dump(cache_data, f)
            os.replace(tmp_path, disk_path)
            return True
        except Exception as e:
            logger.warning(f"写入磁盘缓存失败: {cache_key}, 错误: {e}")
            return False

    def _disk_writer_task(self):
        """Background loop: persist entries queued by :meth:`set`."""
        while True:
            try:
                try:
                    # Block for up to a second so the loop stays responsive.
                    cache_key, cache_data = self.disk_queue.get(timeout=1)
                except queue.Empty:
                    continue
                try:
                    self._write_to_disk(cache_key, cache_data)
                finally:
                    self.disk_queue.task_done()
            except Exception as e:
                logger.error(f"磁盘写入线程出错: {e}")
                time.sleep(5)  # back off briefly after an unexpected error

    def _purge_expired_memory(self):
        """Drop expired entries from the memory layer."""
        for key in self.memory_cache.get_all_keys():
            # Peek at the underlying dict: LRUCache.get() would mark every
            # entry as recently used and destroy the eviction order.
            entry = self.memory_cache.cache.get(key)
            if entry is not None and not self._is_cache_valid(entry['timestamp']):
                self.memory_cache.remove(key)

    def _purge_expired_disk(self):
        """Delete expired or unreadable ``.cache`` files."""
        for filename in os.listdir(self.disk_cache_dir):
            if not filename.endswith('.cache'):
                continue
            filepath = os.path.join(self.disk_cache_dir, filename)
            try:
                with open(filepath, 'rb') as f:
                    cache_data = pickle.load(f)
                expired = not self._is_cache_valid(cache_data['timestamp'])
            except FileNotFoundError:
                continue  # removed by another thread in the meantime
            except Exception as e:
                # Corrupt file: treat it as expired.
                logger.warning(f"读取缓存文件失败,将删除: {filepath}, 错误: {e}")
                expired = True

            if expired:
                try:
                    os.remove(filepath)
                except OSError as e:
                    logger.warning(f"删除磁盘缓存失败: {filepath}, 错误: {e}")

    def _flush_memory_to_disk(self):
        """Rewrite every memory entry to disk so both layers agree."""
        for key in self.memory_cache.get_all_keys():
            entry = self.memory_cache.cache.get(key)
            if entry is not None:
                self._write_to_disk(key, entry)

    def _cleanup_and_flush_task(self):
        """Background loop: purge expired entries, then flush memory to disk."""
        while True:
            try:
                self._purge_expired_memory()
                self._purge_expired_disk()
                self._flush_memory_to_disk()
            except Exception as e:
                logger.error(f"缓存清理线程出错: {e}")
            # One pass per hour, whether or not the pass succeeded.
            time.sleep(3600)
|
||||
|
||||
# Domain-specific cache instances.
prediction_cache = CacheManager(name="predictions", memory_capacity=500, cache_duration=24)
sentiment_cache = CacheManager(name="sentiment", memory_capacity=1000, cache_duration=12)
topic_cache = CacheManager(name="topics", memory_capacity=200, cache_duration=6)
user_data_cache = CacheManager(name="user_data", memory_capacity=300, cache_duration=48)

# Backward-compatible alias for the former class name.
PredictionCache = CacheManager
# Backward-compatible alias for the original prediction cache object.
prediction_cache_old = prediction_cache
|
||||
@@ -0,0 +1,558 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import getpass
|
||||
import secrets
|
||||
import logging
|
||||
import platform
|
||||
import socket
|
||||
import hashlib
|
||||
import base64
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import pymysql
|
||||
from dotenv import load_dotenv, set_key, find_dotenv
|
||||
|
||||
# Logging setup: everything at INFO and above goes to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger('init_wizard')
|
||||
|
||||
class InitWizard:
|
||||
"""
|
||||
初始化向导 - 简化系统的初始配置流程,并提供安全加固功能
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 配置项
|
||||
self.config = {
|
||||
# 数据库配置
|
||||
'db': {
|
||||
'host': os.getenv('DB_HOST', 'localhost'),
|
||||
'port': int(os.getenv('DB_PORT', '3306')),
|
||||
'user': os.getenv('DB_USER', 'root'),
|
||||
'password': os.getenv('DB_PASSWORD', ''),
|
||||
'database': os.getenv('DB_NAME', 'Weibo_PublicOpinion_AnalysisSystem'),
|
||||
'ssl': bool(os.getenv('DB_SSL', 'false').lower() == 'true')
|
||||
},
|
||||
# Flask应用配置
|
||||
'app': {
|
||||
'host': os.getenv('FLASK_HOST', '127.0.0.1'),
|
||||
'port': int(os.getenv('FLASK_PORT', '5000')),
|
||||
'secret_key': os.getenv('FLASK_SECRET_KEY', ''),
|
||||
'enable_https': bool(os.getenv('ENABLE_HTTPS', 'false').lower() == 'true'),
|
||||
'debug': bool(os.getenv('FLASK_DEBUG', 'false').lower() == 'true')
|
||||
},
|
||||
# API密钥配置
|
||||
'api_keys': {
|
||||
'openai': os.getenv('OPENAI_API_KEY', ''),
|
||||
'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
|
||||
'deepseek': os.getenv('DEEPSEEK_API_KEY', '')
|
||||
},
|
||||
# 安全配置
|
||||
'security': {
|
||||
'enable_rate_limit': bool(os.getenv('ENABLE_RATE_LIMIT', 'true').lower() == 'true'),
|
||||
'enable_ip_blocking': bool(os.getenv('ENABLE_IP_BLOCKING', 'true').lower() == 'true'),
|
||||
'enable_sensitive_data_filter': bool(os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true'),
|
||||
'enable_mutual_auth': bool(os.getenv('ENABLE_MUTUAL_AUTH', 'false').lower() == 'true'),
|
||||
'min_password_length': int(os.getenv('MIN_PASSWORD_LENGTH', '8')),
|
||||
'session_timeout': int(os.getenv('SESSION_TIMEOUT', '120')), # 分钟
|
||||
},
|
||||
# 爬虫配置
|
||||
'crawler': {
|
||||
'interval': int(os.getenv('CRAWL_INTERVAL', '18000')), # 秒
|
||||
'max_retries': int(os.getenv('CRAWL_MAX_RETRIES', '3')),
|
||||
'timeout': int(os.getenv('CRAWL_TIMEOUT', '30')),
|
||||
'max_concurrent': int(os.getenv('CRAWL_MAX_CONCURRENT', '2')),
|
||||
'user_agent': os.getenv('CRAWL_USER_AGENT', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36')
|
||||
},
|
||||
# 系统配置
|
||||
'system': {
|
||||
'initialized': bool(os.getenv('SYSTEM_INITIALIZED', 'false').lower() == 'true'),
|
||||
'version': os.getenv('SYSTEM_VERSION', '2.0.0'),
|
||||
'log_level': os.getenv('LOG_LEVEL', 'INFO'),
|
||||
'data_dir': os.getenv('DATA_DIR', 'data'),
|
||||
'temp_dir': os.getenv('TEMP_DIR', 'temp'),
|
||||
'cache_dir': os.getenv('CACHE_DIR', 'cache'),
|
||||
'max_model_memory': float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0')), # GB
|
||||
}
|
||||
}
|
||||
|
||||
# 安全选项
|
||||
self.security_options = {
|
||||
'rate_limit': {
|
||||
'name': '请求速率限制',
|
||||
'description': '防止API被滥用,限制单个IP的请求频率',
|
||||
'default': True
|
||||
},
|
||||
'ip_blocking': {
|
||||
'name': 'IP黑名单',
|
||||
'description': '阻止可疑IP访问系统',
|
||||
'default': True
|
||||
},
|
||||
'sensitive_data_filter': {
|
||||
'name': '敏感信息过滤',
|
||||
'description': '自动识别并屏蔽输出内容中的敏感信息(如手机号、邮箱等)',
|
||||
'default': True
|
||||
},
|
||||
'mutual_auth': {
|
||||
'name': '双向认证',
|
||||
'description': '要求API调用方提供有效证书,增强API安全性(需要HTTPS)',
|
||||
'default': False
|
||||
}
|
||||
}
|
||||
|
||||
def start(self):
|
||||
"""启动初始化向导"""
|
||||
self._print_welcome()
|
||||
|
||||
if self.config['system']['initialized']:
|
||||
print("\n系统已经初始化过。您想重新配置吗? [y/N]: ", end='')
|
||||
choice = input().strip().lower()
|
||||
if choice != 'y':
|
||||
print("初始化向导已退出。如需重新配置,请设置环境变量 SYSTEM_INITIALIZED=false 或删除 .env 文件。")
|
||||
return
|
||||
|
||||
# 主配置流程
|
||||
try:
|
||||
self._configure_database()
|
||||
self._configure_app()
|
||||
self._configure_api_keys()
|
||||
self._configure_security()
|
||||
self._configure_crawler()
|
||||
self._configure_system()
|
||||
|
||||
# 保存配置
|
||||
self._save_config()
|
||||
|
||||
# 应用安全措施
|
||||
self._apply_security_measures()
|
||||
|
||||
print("\n✅ 初始化完成!系统已成功配置。")
|
||||
print("您现在可以运行 python app.py 启动应用。")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n初始化向导已取消。配置未保存。")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化过程中发生错误: {e}")
|
||||
print(f"\n❌ 初始化失败: {e}")
|
||||
print("请检查错误并重试。")
|
||||
|
||||
def _print_welcome(self):
|
||||
"""打印欢迎信息"""
|
||||
print("\n" + "="*80)
|
||||
print(" "*20 + "微博舆情分析预测系统 - 初始化向导 v2.0")
|
||||
print("="*80)
|
||||
print("\n欢迎使用微博舆情分析预测系统!此向导将引导您完成系统的初始配置。")
|
||||
print("按Ctrl+C可随时退出向导。")
|
||||
print("\n系统信息:")
|
||||
print(f" • 操作系统: {platform.system()} {platform.release()}")
|
||||
print(f" • Python版本: {platform.python_version()}")
|
||||
print(f" • 主机名: {socket.gethostname()}")
|
||||
print(f" • 当前时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
print("\n让我们开始配置吧!每个选项都有默认值,直接按回车即可使用默认值。")
|
||||
print("-"*80)
|
||||
|
||||
def _configure_database(self):
|
||||
"""配置数据库连接"""
|
||||
print("\n📦 数据库配置")
|
||||
print("-"*50)
|
||||
|
||||
# 询问数据库连接信息
|
||||
self.config['db']['host'] = self._prompt(
|
||||
"数据库主机", self.config['db']['host'])
|
||||
|
||||
port_str = self._prompt(
|
||||
"数据库端口", str(self.config['db']['port']))
|
||||
try:
|
||||
self.config['db']['port'] = int(port_str)
|
||||
except ValueError:
|
||||
print(f"端口号无效,使用默认值 {self.config['db']['port']}")
|
||||
|
||||
self.config['db']['user'] = self._prompt(
|
||||
"数据库用户名", self.config['db']['user'])
|
||||
|
||||
# 密码使用getpass以避免明文显示
|
||||
default_pass = '*' * len(self.config['db']['password']) if self.config['db']['password'] else ''
|
||||
password = getpass.getpass(f"数据库密码 [{default_pass}]: ")
|
||||
if password:
|
||||
self.config['db']['password'] = password
|
||||
|
||||
self.config['db']['database'] = self._prompt(
|
||||
"数据库名", self.config['db']['database'])
|
||||
|
||||
ssl_str = self._prompt(
|
||||
"使用SSL连接 (true/false)", str(self.config['db']['ssl']).lower())
|
||||
self.config['db']['ssl'] = ssl_str.lower() == 'true'
|
||||
|
||||
# 测试数据库连接
|
||||
print("\n正在测试数据库连接...")
|
||||
try:
|
||||
self._test_db_connection()
|
||||
print("✅ 数据库连接成功!")
|
||||
except Exception as e:
|
||||
print(f"❌ 数据库连接失败: {e}")
|
||||
retry = input("是否重新配置数据库连接? [Y/n]: ").strip().lower()
|
||||
if retry != 'n':
|
||||
return self._configure_database()
|
||||
else:
|
||||
print("跳过数据库连接测试,但配置可能不正确。")
|
||||
|
||||
def _configure_app(self):
|
||||
"""配置Flask应用"""
|
||||
print("\n🚀 应用配置")
|
||||
print("-"*50)
|
||||
|
||||
self.config['app']['host'] = self._prompt(
|
||||
"监听地址 (0.0.0.0表示所有网络接口)", self.config['app']['host'])
|
||||
|
||||
port_str = self._prompt(
|
||||
"监听端口", str(self.config['app']['port']))
|
||||
try:
|
||||
self.config['app']['port'] = int(port_str)
|
||||
except ValueError:
|
||||
print(f"端口号无效,使用默认值 {self.config['app']['port']}")
|
||||
|
||||
# 自动生成密钥
|
||||
if not self.config['app']['secret_key']:
|
||||
self.config['app']['secret_key'] = secrets.token_hex(32)
|
||||
print(f"已自动生成应用密钥: {self.config['app']['secret_key'][:8]}...")
|
||||
else:
|
||||
regenerate = input("应用密钥已存在。是否重新生成? [y/N]: ").strip().lower()
|
||||
if regenerate == 'y':
|
||||
self.config['app']['secret_key'] = secrets.token_hex(32)
|
||||
print(f"已重新生成应用密钥: {self.config['app']['secret_key'][:8]}...")
|
||||
|
||||
https_str = self._prompt(
|
||||
"启用HTTPS (true/false)", str(self.config['app']['enable_https']).lower())
|
||||
self.config['app']['enable_https'] = https_str.lower() == 'true'
|
||||
|
||||
debug_str = self._prompt(
|
||||
"启用调试模式 (true/false, 生产环境建议false)", str(self.config['app']['debug']).lower())
|
||||
self.config['app']['debug'] = debug_str.lower() == 'true'
|
||||
|
||||
def _configure_api_keys(self):
|
||||
"""配置API密钥"""
|
||||
print("\n🔑 API密钥配置")
|
||||
print("-"*50)
|
||||
print("系统支持多个大语言模型,至少需要配置一个API密钥。")
|
||||
|
||||
# 配置OpenAI API密钥
|
||||
has_openai = self._prompt(
|
||||
"是否配置OpenAI API密钥? (y/n)", "y" if self.config['api_keys']['openai'] else "n")
|
||||
if has_openai.lower() == 'y':
|
||||
self.config['api_keys']['openai'] = self._prompt(
|
||||
"OpenAI API密钥", self.config['api_keys']['openai'])
|
||||
|
||||
# 配置Anthropic API密钥
|
||||
has_anthropic = self._prompt(
|
||||
"是否配置Anthropic (Claude) API密钥? (y/n)", "y" if self.config['api_keys']['anthropic'] else "n")
|
||||
if has_anthropic.lower() == 'y':
|
||||
self.config['api_keys']['anthropic'] = self._prompt(
|
||||
"Anthropic API密钥", self.config['api_keys']['anthropic'])
|
||||
|
||||
# 配置DeepSeek API密钥
|
||||
has_deepseek = self._prompt(
|
||||
"是否配置DeepSeek API密钥? (y/n)", "y" if self.config['api_keys']['deepseek'] else "n")
|
||||
if has_deepseek.lower() == 'y':
|
||||
self.config['api_keys']['deepseek'] = self._prompt(
|
||||
"DeepSeek API密钥", self.config['api_keys']['deepseek'])
|
||||
|
||||
# 检查是否至少配置了一个API密钥
|
||||
if not (self.config['api_keys']['openai'] or self.config['api_keys']['anthropic'] or self.config['api_keys']['deepseek']):
|
||||
print("⚠️ 警告: 您未配置任何API密钥,系统的AI分析功能将不可用。")
|
||||
confirm = input("是否继续? [Y/n]: ").strip().lower()
|
||||
if confirm == 'n':
|
||||
return self._configure_api_keys()
|
||||
|
||||
def _configure_security(self):
|
||||
"""配置安全设置"""
|
||||
print("\n🔒 安全配置")
|
||||
print("-"*50)
|
||||
|
||||
for key, option in self.security_options.items():
|
||||
current_value = self.config['security'][f'enable_{key}']
|
||||
print(f"\n{option['name']}: {option['description']}")
|
||||
enable_str = self._prompt(
|
||||
f"启用{option['name']} (true/false)", str(current_value).lower())
|
||||
self.config['security'][f'enable_{key}'] = enable_str.lower() == 'true'
|
||||
|
||||
# 密码安全策略
|
||||
min_len_str = self._prompt(
|
||||
"最小密码长度 (推荐不低于8)", str(self.config['security']['min_password_length']))
|
||||
try:
|
||||
self.config['security']['min_password_length'] = int(min_len_str)
|
||||
if self.config['security']['min_password_length'] < 6:
|
||||
print("⚠️ 警告: 短密码容易被暴力破解,建议设置更长的密码。")
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['security']['min_password_length']}")
|
||||
|
||||
# 会话超时设置
|
||||
timeout_str = self._prompt(
|
||||
"会话超时时间 (分钟)", str(self.config['security']['session_timeout']))
|
||||
try:
|
||||
self.config['security']['session_timeout'] = int(timeout_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['security']['session_timeout']}")
|
||||
|
||||
def _configure_crawler(self):
|
||||
"""配置爬虫设置"""
|
||||
print("\n🕷️ 爬虫配置")
|
||||
print("-"*50)
|
||||
|
||||
interval_str = self._prompt(
|
||||
"爬取间隔 (秒)", str(self.config['crawler']['interval']))
|
||||
try:
|
||||
self.config['crawler']['interval'] = int(interval_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['crawler']['interval']}")
|
||||
|
||||
retries_str = self._prompt(
|
||||
"最大重试次数", str(self.config['crawler']['max_retries']))
|
||||
try:
|
||||
self.config['crawler']['max_retries'] = int(retries_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['crawler']['max_retries']}")
|
||||
|
||||
timeout_str = self._prompt(
|
||||
"超时时间 (秒)", str(self.config['crawler']['timeout']))
|
||||
try:
|
||||
self.config['crawler']['timeout'] = int(timeout_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['crawler']['timeout']}")
|
||||
|
||||
concurrent_str = self._prompt(
|
||||
"最大并发数", str(self.config['crawler']['max_concurrent']))
|
||||
try:
|
||||
self.config['crawler']['max_concurrent'] = int(concurrent_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['crawler']['max_concurrent']}")
|
||||
|
||||
self.config['crawler']['user_agent'] = self._prompt(
|
||||
"User-Agent", self.config['crawler']['user_agent'])
|
||||
|
||||
def _configure_system(self):
|
||||
"""配置系统设置"""
|
||||
print("\n⚙️ 系统配置")
|
||||
print("-"*50)
|
||||
|
||||
# 日志级别
|
||||
log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
|
||||
current_level = self.config['system']['log_level']
|
||||
print(f"可选日志级别: {', '.join(log_levels)}")
|
||||
log_level = self._prompt("日志级别", current_level).upper()
|
||||
if log_level in log_levels:
|
||||
self.config['system']['log_level'] = log_level
|
||||
else:
|
||||
print(f"无效的日志级别,使用默认值 {current_level}")
|
||||
|
||||
# 数据目录
|
||||
data_dir = self._prompt("数据目录", self.config['system']['data_dir'])
|
||||
if data_dir:
|
||||
self.config['system']['data_dir'] = data_dir
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
print(f"已创建数据目录: {data_dir}")
|
||||
|
||||
# 缓存目录
|
||||
cache_dir = self._prompt("缓存目录", self.config['system']['cache_dir'])
|
||||
if cache_dir:
|
||||
self.config['system']['cache_dir'] = cache_dir
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
print(f"已创建缓存目录: {cache_dir}")
|
||||
|
||||
# 临时目录
|
||||
temp_dir = self._prompt("临时文件目录", self.config['system']['temp_dir'])
|
||||
if temp_dir:
|
||||
self.config['system']['temp_dir'] = temp_dir
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
print(f"已创建临时文件目录: {temp_dir}")
|
||||
|
||||
# 模型内存限制
|
||||
memory_str = self._prompt(
|
||||
"最大模型内存使用量 (GB)", str(self.config['system']['max_model_memory']))
|
||||
try:
|
||||
self.config['system']['max_model_memory'] = float(memory_str)
|
||||
except ValueError:
|
||||
print(f"无效输入,使用默认值 {self.config['system']['max_model_memory']}")
|
||||
|
||||
# 标记系统已初始化
|
||||
self.config['system']['initialized'] = True
|
||||
|
||||
def _save_config(self):
|
||||
"""保存配置到.env文件"""
|
||||
print("\n正在保存配置...")
|
||||
|
||||
# 构建.env文件内容
|
||||
env_content = [
|
||||
"# 微博舆情分析预测系统配置文件",
|
||||
f"# 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"",
|
||||
"# 数据库配置",
|
||||
f"DB_HOST={self.config['db']['host']}",
|
||||
f"DB_PORT={self.config['db']['port']}",
|
||||
f"DB_USER={self.config['db']['user']}",
|
||||
f"DB_PASSWORD={self.config['db']['password']}",
|
||||
f"DB_NAME={self.config['db']['database']}",
|
||||
f"DB_SSL={str(self.config['db']['ssl']).lower()}",
|
||||
"",
|
||||
"# 应用配置",
|
||||
f"FLASK_HOST={self.config['app']['host']}",
|
||||
f"FLASK_PORT={self.config['app']['port']}",
|
||||
f"FLASK_SECRET_KEY={self.config['app']['secret_key']}",
|
||||
f"ENABLE_HTTPS={str(self.config['app']['enable_https']).lower()}",
|
||||
f"FLASK_DEBUG={str(self.config['app']['debug']).lower()}",
|
||||
"",
|
||||
"# API密钥",
|
||||
f"OPENAI_API_KEY={self.config['api_keys']['openai']}",
|
||||
f"ANTHROPIC_API_KEY={self.config['api_keys']['anthropic']}",
|
||||
f"DEEPSEEK_API_KEY={self.config['api_keys']['deepseek']}",
|
||||
"",
|
||||
"# 安全配置",
|
||||
f"ENABLE_RATE_LIMIT={str(self.config['security']['enable_rate_limit']).lower()}",
|
||||
f"ENABLE_IP_BLOCKING={str(self.config['security']['enable_ip_blocking']).lower()}",
|
||||
f"ENABLE_SENSITIVE_DATA_FILTER={str(self.config['security']['enable_sensitive_data_filter']).lower()}",
|
||||
f"ENABLE_MUTUAL_AUTH={str(self.config['security']['enable_mutual_auth']).lower()}",
|
||||
f"MIN_PASSWORD_LENGTH={self.config['security']['min_password_length']}",
|
||||
f"SESSION_TIMEOUT={self.config['security']['session_timeout']}",
|
||||
"",
|
||||
"# 爬虫配置",
|
||||
f"CRAWL_INTERVAL={self.config['crawler']['interval']}",
|
||||
f"CRAWL_MAX_RETRIES={self.config['crawler']['max_retries']}",
|
||||
f"CRAWL_TIMEOUT={self.config['crawler']['timeout']}",
|
||||
f"CRAWL_MAX_CONCURRENT={self.config['crawler']['max_concurrent']}",
|
||||
f"CRAWL_USER_AGENT={self.config['crawler']['user_agent']}",
|
||||
"",
|
||||
"# 系统配置",
|
||||
f"SYSTEM_INITIALIZED={str(self.config['system']['initialized']).lower()}",
|
||||
f"SYSTEM_VERSION={self.config['system']['version']}",
|
||||
f"LOG_LEVEL={self.config['system']['log_level']}",
|
||||
f"DATA_DIR={self.config['system']['data_dir']}",
|
||||
f"TEMP_DIR={self.config['system']['temp_dir']}",
|
||||
f"CACHE_DIR={self.config['system']['cache_dir']}",
|
||||
f"MAX_MODEL_MEMORY_USAGE={self.config['system']['max_model_memory']}",
|
||||
]
|
||||
|
||||
# 写入.env文件
|
||||
with open('.env', 'w') as f:
|
||||
f.write('\n'.join(env_content))
|
||||
|
||||
print("✅ 配置已保存到 .env 文件")
|
||||
|
||||
# 创建备份
|
||||
backup_path = f".env.backup.{datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
shutil.copy2('.env', backup_path)
|
||||
print(f"✅ 配置备份已保存到 {backup_path}")
|
||||
|
||||
def _test_db_connection(self):
|
||||
"""测试数据库连接"""
|
||||
connection = pymysql.connect(
|
||||
host=self.config['db']['host'],
|
||||
port=self.config['db']['port'],
|
||||
user=self.config['db']['user'],
|
||||
password=self.config['db']['password'],
|
||||
database=self.config['db']['database'],
|
||||
charset='utf8mb4',
|
||||
ssl={'ssl': {'ca': None}} if self.config['db']['ssl'] else None
|
||||
)
|
||||
connection.close()
|
||||
|
||||
def _apply_security_measures(self):
|
||||
"""应用安全措施"""
|
||||
print("\n正在应用安全措施...")
|
||||
|
||||
# 创建相关目录
|
||||
security_dir = os.path.join(self.config['system']['data_dir'], 'security')
|
||||
os.makedirs(security_dir, exist_ok=True)
|
||||
|
||||
# 设置文件权限
|
||||
try:
|
||||
# 仅在类Unix系统上设置文件权限
|
||||
if platform.system() != "Windows":
|
||||
os.chmod('.env', 0o600) # 只有所有者可读写
|
||||
print("✅ 已设置.env文件权限为600 (只有所有者可读写)")
|
||||
except Exception as e:
|
||||
logger.warning(f"设置文件权限失败: {e}")
|
||||
|
||||
# 生成密钥对(如果启用了双向认证)
|
||||
if self.config['security']['enable_mutual_auth']:
|
||||
cert_dir = os.path.join(security_dir, 'certs')
|
||||
os.makedirs(cert_dir, exist_ok=True)
|
||||
|
||||
try:
|
||||
# 检查是否有OpenSSL可用
|
||||
subprocess.run(['openssl', 'version'], check=True, capture_output=True)
|
||||
|
||||
# 生成自签名证书
|
||||
key_file = os.path.join(cert_dir, 'server.key')
|
||||
cert_file = os.path.join(cert_dir, 'server.crt')
|
||||
|
||||
if not os.path.exists(key_file) or not os.path.exists(cert_file):
|
||||
print("正在生成SSL证书...")
|
||||
subprocess.run([
|
||||
'openssl', 'req', '-x509', '-newkey', 'rsa:4096',
|
||||
'-keyout', key_file, '-out', cert_file,
|
||||
'-days', '365', '-nodes',
|
||||
'-subj', '/CN=localhost'
|
||||
], check=True)
|
||||
print(f"✅ SSL证书已生成: {cert_file}")
|
||||
except subprocess.CalledProcessError:
|
||||
print("⚠️ OpenSSL不可用,无法生成SSL证书。如需使用HTTPS,请手动配置证书。")
|
||||
except Exception as e:
|
||||
logger.warning(f"生成SSL证书失败: {e}")
|
||||
|
||||
# 创建敏感信息过滤器配置
|
||||
if self.config['security']['enable_sensitive_data_filter']:
|
||||
filter_config = {
|
||||
'enabled': True,
|
||||
'patterns': {
|
||||
'phone': r'\b1[3-9]\d{9}\b',
|
||||
'email': r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b',
|
||||
'id_card': r'\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
|
||||
'credit_card': r'\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
|
||||
'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)'
|
||||
},
|
||||
'replacements': {
|
||||
'phone': '***********',
|
||||
'email': '******@*****',
|
||||
'id_card': '******************',
|
||||
'credit_card': '****************',
|
||||
'address': '[地址已隐藏]'
|
||||
}
|
||||
}
|
||||
|
||||
filter_path = os.path.join(security_dir, 'sensitive_filter.json')
|
||||
with open(filter_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(filter_config, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"✅ 敏感信息过滤器配置已保存到 {filter_path}")
|
||||
|
||||
# 创建IP黑名单文件
|
||||
if self.config['security']['enable_ip_blocking']:
|
||||
blacklist_path = os.path.join(security_dir, 'ip_blacklist.txt')
|
||||
if not os.path.exists(blacklist_path):
|
||||
with open(blacklist_path, 'w') as f:
|
||||
f.write("# 每行一个IP地址\n")
|
||||
print(f"✅ IP黑名单文件已创建: {blacklist_path}")
|
||||
|
||||
def _prompt(self, prompt, default=""):
    """Ask the user for a value on stdin, falling back to ``default``.

    Args:
        prompt: Text shown to the user.
        default: Value returned when the user enters nothing (or only
            whitespace). When non-empty it is also shown in brackets
            after the prompt text.

    Returns:
        The stripped user input, or ``default`` if that input is empty.
    """
    label = f"{prompt} [{default}]: " if default else f"{prompt}: "
    answer = input(label).strip()
    return answer or default
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the interactive initialization wizard when executed as a script.
    InitWizard().start()
|
||||
@@ -0,0 +1,285 @@
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import marshal
|
||||
import types
|
||||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Module-level logger shared by every loader in this file; its level is
# set to INFO at import time.
logger = logging.getLogger('model_loader')
logger.setLevel(logging.INFO)
|
||||
|
||||
def load_sentiment_model(model_path, device=None):
    """Load a marshal-serialized sentiment model and wrap it as a function.

    Args:
        model_path: Path (``str`` or ``os.PathLike``) of the model file.
            Its name must end in ``.marshal`` or ``.marshal.3``.
        device: Ignored; marshal models do not depend on a device. It is
            accepted so the function fits the
            ``load_function(model_path, device)`` calling convention.

    Returns:
        A function object named ``sentiment_func`` built from the
        unmarshalled code object, using this module's globals.

    Raises:
        ValueError: If the file name has an unsupported extension.
        TypeError: If the file does not contain a marshalled code object.
        OSError: If the file cannot be opened or read.
    """
    try:
        logger.info(f"加载情感分析模型: {model_path}")

        # Accept pathlib.Path and other path-like objects, not only str.
        path_text = os.fspath(model_path)
        if not path_text.endswith(('.marshal', '.marshal.3')):
            raise ValueError(f"不支持的情感模型格式: {model_path}")

        # NOTE(security): unmarshalling and executing code from a file is
        # only safe for trusted model files; never point this at
        # user-supplied uploads.
        with open(path_text, 'rb') as f:
            model_data = marshal.load(f)

        # types.FunctionType only accepts code objects. Fail with a clear
        # message instead of its generic TypeError when the payload is
        # plain data (for example a dict of model parameters).
        if not isinstance(model_data, types.CodeType):
            raise TypeError(
                f"情感模型文件内容不是可执行代码对象 "
                f"(实际类型: {type(model_data).__name__}): {model_path}"
            )

        sentiment_func = types.FunctionType(model_data, globals(), "sentiment_func")
        logger.info("情感分析模型加载成功")
        return sentiment_func
    except Exception as e:
        logger.error(f"加载情感分析模型失败: {e}")
        raise
|
||||
|
||||
def load_bert_ctm_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BERT-CTM model weights together with its tokenizer.

    Args:
        model_dir: Either a directory containing ``final_model.pt`` or a
            direct path to a ``.pt`` weights file.
        device: Device used for ``map_location`` and ``model.to``. The
            default is evaluated once, when the function is defined.

    Returns:
        A dict with keys ``'model'`` (in eval mode), ``'tokenizer'`` and
        ``'device'``.

    Raises:
        Exception: Any import, file or deserialization error is logged
            and re-raised unchanged.
    """
    try:
        logger.info(f"加载BERT-CTM模型: {model_dir}")

        # Make the project-local model package importable. Add the entry
        # only once so repeated loads do not grow sys.path without bound.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BERT_CTM import BERT_CTM
        from transformers import BertTokenizer

        # Resolve the weights file: model_dir may already point at a .pt file.
        if model_dir.endswith('.pt'):
            model_path = model_dir
        else:
            model_path = os.path.join(model_dir, 'final_model.pt')

        model = BERT_CTM()
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        # The tokenizer is read from a 'bert_model' directory that sits
        # next to model_dir (i.e. under model_dir's parent directory).
        tokenizer_path = os.path.join(os.path.dirname(model_dir), 'bert_model')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BERT-CTM模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device
        }
    except Exception as e:
        logger.error(f"加载BERT-CTM模型失败: {e}")
        raise
|
||||
|
||||
def load_bcat_model(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the BCAT model, its configuration and its tokenizer.

    Args:
        model_dir: Directory expected to contain ``config.json``,
            ``model.pt`` and a ``tokenizer`` sub-directory.
        device: Device used for ``map_location`` and ``model.to``. The
            default is evaluated once, when the function is defined.

    Returns:
        A dict with keys ``'model'`` (in eval mode), ``'tokenizer'``,
        ``'device'`` and ``'config'`` (the parsed ``config.json``).

    Raises:
        Exception: Any import, file or deserialization error is logged
            and re-raised unchanged.
    """
    try:
        logger.info(f"加载BCAT模型: {model_dir}")

        # Make the project-local model package importable. Add the entry
        # only once so repeated loads do not grow sys.path without bound.
        if 'model_pro' not in sys.path:
            sys.path.append('model_pro')
        from BCAT import BCAT
        from transformers import BertTokenizer

        # The JSON config supplies the keyword arguments for the model class.
        config_path = os.path.join(model_dir, 'config.json')
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        model = BCAT(**config)

        model_path = os.path.join(model_dir, 'model.pt')
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        tokenizer_path = os.path.join(model_dir, 'tokenizer')
        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

        logger.info("BCAT模型加载成功")
        return {
            'model': model,
            'tokenizer': tokenizer,
            'device': device,
            'config': config
        }
    except Exception as e:
        logger.error(f"加载BCAT模型失败: {e}")
        raise
|
||||
|
||||
def _load_labels_json(model_dir, default):
    """Return the parsed ``labels.json`` from model_dir, or ``default`` if absent."""
    labels_path = os.path.join(model_dir, 'labels.json')
    if not os.path.exists(labels_path):
        return default
    with open(labels_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_topic_classifier(model_dir, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the topic classification model, tokenizer and label mapping.

    Two strategies are tried in order:

    1. Load ``model_dir`` as a transformers sequence-classification model.
    2. If that fails for any reason, fall back to ``model.pt`` (loaded
       with ``torch.load``) plus an optional pickled ``tokenizer.pkl``.

    Args:
        model_dir: Directory holding the model files.
        device: Device the model is moved/mapped to. The default is
            evaluated once, when the function is defined.

    Returns:
        A dict with keys ``'model'``, ``'tokenizer'``, ``'labels_map'``
        and ``'device'``. ``'tokenizer'`` is ``None`` in the fallback path
        when ``tokenizer.pkl`` does not exist. ``'labels_map'`` comes from
        ``labels.json`` when present; otherwise from the model config's
        ``id2label`` (transformers path only) or an empty dict.

    Raises:
        ValueError: If neither strategy can load a model.
        Exception: Errors from the fallback path are logged and re-raised.
    """
    try:
        logger.info(f"加载话题分类模型: {model_dir}")

        # Strategy 1: transformers checkpoint directory.
        try:
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            model.to(device)
            model.eval()

            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # labels.json wins; otherwise use the labels stored in the config.
            labels_map = _load_labels_json(
                model_dir, getattr(model.config, 'id2label', {})
            )

            logger.info("话题分类模型加载成功 (transformers)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }
        except Exception as e:
            # Deliberate best effort: any failure here triggers the fallback.
            logger.warning(f"使用transformers加载失败,尝试其他方法: {e}")

        # Strategy 2: plain PyTorch file.
        model_path = os.path.join(model_dir, 'model.pt')
        if os.path.exists(model_path):
            model = torch.load(model_path, map_location=device)
            # Match the transformers branch: a loaded module must be in
            # inference mode. A bare state dict has no eval() and is
            # returned untouched.
            if isinstance(model, torch.nn.Module):
                model.eval()

            tokenizer_path = os.path.join(model_dir, 'tokenizer.pkl')
            if os.path.exists(tokenizer_path):
                # NOTE(security): pickle.load can execute arbitrary code;
                # only load tokenizer.pkl from trusted model directories.
                with open(tokenizer_path, 'rb') as f:
                    tokenizer = pickle.load(f)
            else:
                tokenizer = None

            labels_map = _load_labels_json(model_dir, {})

            logger.info("话题分类模型加载成功 (PyTorch)")
            return {
                'model': model,
                'tokenizer': tokenizer,
                'labels_map': labels_map,
                'device': device
            }

        raise ValueError(f"无法加载模型: {model_dir}")
    except Exception as e:
        logger.error(f"加载话题分类模型失败: {e}")
        raise
|
||||
|
||||
def load_echarts_optimizer():
    """Build an optimizer that tunes ECharts options for large datasets.

    Returns:
        An ``EChartsOptimizer`` instance exposing ``optimize_option`` and
        ``chunk_process_data``, or ``None`` if construction fails (the
        error is logged).
    """
    try:
        class EChartsOptimizer:
            """Applies progressive rendering and down-sampling to ECharts options."""

            def __init__(self):
                self.chunk_size = 1000  # items per chunk in chunk_process_data
                # Series with more points than this get the "large" flags.
                self.large_series_min = 5000
                # Upper bound on points kept in a series after down-sampling.
                self.max_points = 50000
                logger.info("ECharts优化器初始化成功")

            def optimize_option(self, option):
                """Return an optimized deep copy of ``option``.

                Falsy input (``None``, ``{}``) is returned unchanged. The
                caller's object is never mutated. Existing ``progressive``
                and ``progressiveThreshold`` values are kept.
                """
                if not option:
                    return option

                # Work on a deep copy so the caller's option stays untouched.
                import copy
                option = copy.deepcopy(option)

                if 'progressive' not in option:
                    option['progressive'] = 300  # data points rendered per frame
                if 'progressiveThreshold' not in option:
                    option['progressiveThreshold'] = 5000  # size that enables progressive rendering

                series_list = option['series'] if 'series' in option else None
                if isinstance(series_list, list):
                    for series in series_list:
                        # Skip malformed entries instead of raising on them.
                        if not isinstance(series, dict):
                            continue
                        data = series.get('data')
                        if (isinstance(data, list)
                                and len(data) > self.large_series_min
                                and series.get('type') in ('scatter', 'line')):
                            self._optimize_large_data_series(series)

                return option

            def _optimize_large_data_series(self, series):
                """Flag a series as large and thin it to at most ``max_points``."""
                series['large'] = True
                series['largeThreshold'] = 2000

                point_count = len(series['data'])
                if point_count > self.max_points:
                    # Ceiling division: floor division would give step 1
                    # (no sampling at all) for anything below 2 * max_points.
                    step = -(-point_count // self.max_points)
                    series['data'] = series['data'][::step]
                    series['sampling'] = 'average'

                return series

            def chunk_process_data(self, data, process_func):
                """Apply ``process_func`` to ``chunk_size`` slices and join the results.

                ``process_func`` must return an iterable for each chunk.
                """
                result = []
                for start in range(0, len(data), self.chunk_size):
                    result.extend(process_func(data[start:start + self.chunk_size]))
                return result

        return EChartsOptimizer()
    except Exception as e:
        logger.error(f"加载ECharts优化器失败: {e}")
        return None
|
||||
|
||||
# Public API of this module: the loader functions defined above.
__all__ = [
    'load_sentiment_model', 'load_bert_ctm_model', 'load_bcat_model',
    'load_topic_classifier', 'load_echarts_optimizer',
]
|
||||
@@ -0,0 +1,364 @@
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
import gc
|
||||
import torch
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Module-level logger for the model manager; its level is set to INFO
# at import time.
logger = logging.getLogger('model_manager')
logger.setLevel(logging.INFO)
|
||||
|
||||
class ModelManager:
    """Singleton registry that loads models on demand and evicts them LRU-style.

    Features:
    1. Optional background preloading of registered models.
    2. LRU bookkeeping of loaded models through an ``OrderedDict``
       (oldest entry first).
    3. A daemon thread that unloads non-preload models idle for longer
       than ``unload_timeout`` minutes.
    4. A soft memory budget (``max_memory_usage`` GB) based on the
       *declared* ``model_size_gb`` of each model, not on measured usage.
    5. Per-model usage statistics.

    Configuration is read from the environment in ``__init__``:
    ``MAX_MODEL_MEMORY_USAGE`` (GB, default 4.0) and
    ``MODEL_UNLOAD_TIMEOUT`` (minutes, default 30).
    """

    _instance = None
    _lock = threading.Lock()
    # Seconds between two idle sweeps of the monitor thread.
    _MONITOR_INTERVAL_SECONDS = 300

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(ModelManager, cls).__new__(cls)
            return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; set up state only once.
        if hasattr(self, 'initialized'):
            return

        # Loaded models in LRU order: first item = least recently used.
        self.loaded_models = OrderedDict()
        # Usage statistics per model_id.
        self.model_stats = {}
        # Registration data per model_id (path, preload flag, size, ...).
        self.preload_config = {}
        # Soft memory budget in GB.
        self.max_memory_usage = float(os.getenv('MAX_MODEL_MEMORY_USAGE', '4.0'))
        # One lock per model_id so a model is loaded by one thread at a time.
        self.loading_locks = {}
        # Idle time in minutes after which a model may be unloaded.
        self.unload_timeout = int(os.getenv('MODEL_UNLOAD_TIMEOUT', '30'))

        self.monitor_thread = threading.Thread(target=self._monitor_models, daemon=True)
        self.monitor_thread.start()

        self.initialized = True
        logger.info(f"模型管理器初始化完成,最大内存占用: {self.max_memory_usage}GB")

    def register_model(self, model_id, model_path, preload=False, model_size_gb=0.5,
                       load_function=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Register a model and optionally start preloading it.

        Re-registering an existing ``model_id`` replaces its configuration
        and resets its statistics.

        Args:
            model_id: Unique identifier of the model.
            model_path: File or directory the model is loaded from.
            preload: If true, loading starts immediately on a daemon thread.
            model_size_gb: Declared size used for the memory budget.
            load_function: Optional loader called as
                ``load_function(model_path, device)``.
            device: Device passed to the loader.

        Returns:
            Always ``True``.
        """
        self.preload_config[model_id] = {
            'model_path': model_path,
            'preload': preload,
            'model_size_gb': model_size_gb,
            'load_function': load_function,
            'device': device
        }

        self.model_stats[model_id] = {
            'load_count': 0,
            'use_count': 0,
            'total_load_time': 0,
            'last_used': None,
            'avg_load_time': 0
        }

        if preload:
            logger.info(f"模型 {model_id} 已注册并标记为预加载")
            threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()
        else:
            logger.info(f"模型 {model_id} 已注册")

        return True

    def get_model(self, model_id):
        """Return the model for ``model_id``, loading it first if necessary.

        Args:
            model_id: Identifier previously passed to ``register_model``.

        Returns:
            Whatever the loader produced for this model.

        Raises:
            ValueError: If ``model_id`` was never registered.
        """
        if model_id not in self.preload_config:
            raise ValueError(f"模型 {model_id} 未注册")

        stats = self.model_stats[model_id]
        stats['last_used'] = datetime.now()
        stats['use_count'] += 1

        if model_id in self.loaded_models:
            # Mark as most recently used.
            self.loaded_models.move_to_end(model_id)
            logger.debug(f"使用已加载的模型: {model_id}")
            return self.loaded_models[model_id]

        model, load_time = self._load_and_cache(model_id)
        if load_time is not None:
            logger.info(f"模型 {model_id} 加载完成,耗时: {load_time:.2f}秒")
        return model

    def unload_model(self, model_id):
        """Unload a model manually.

        Returns:
            ``True`` if the model was loaded and has been removed,
            ``False`` if it was not loaded.
        """
        if model_id in self.loaded_models:
            logger.info(f"手动卸载模型: {model_id}")
            del self.loaded_models[model_id]
            self._release_memory()
            return True
        return False

    def get_model_stats(self):
        """Return usage statistics plus registration info for every model."""
        result = {}
        for model_id, stats in self.model_stats.items():
            config = self.preload_config[model_id]
            result[model_id] = {
                **stats,
                'is_loaded': model_id in self.loaded_models,
                'preload': config['preload'],
                'model_size_gb': config['model_size_gb'],
                'device': config['device'],
            }
        return result

    def preload_all(self):
        """Start a preload thread for every preload-flagged model not yet loaded."""
        for model_id, config in self.preload_config.items():
            if config['preload'] and model_id not in self.loaded_models:
                threading.Thread(target=self._preload_model, args=(model_id,), daemon=True).start()

    def _preload_model(self, model_id):
        """Load one model in the background; failures are logged, not raised."""
        try:
            logger.info(f"开始预加载模型: {model_id}")
            _, load_time = self._load_and_cache(model_id)
            if load_time is not None:
                logger.info(f"模型 {model_id} 预加载完成,耗时: {load_time:.2f}秒")
        except Exception as e:
            logger.error(f"预加载模型 {model_id} 失败: {e}")

    def _load_and_cache(self, model_id):
        """Load ``model_id`` under its per-model lock and store it.

        Shared by ``get_model`` and ``_preload_model`` so a preload and a
        concurrent request never load the same model twice.

        Returns:
            ``(model, load_time)``; ``load_time`` is ``None`` when the
            model was already loaded and nothing had to be done.
        """
        # setdefault avoids the check-then-set race in which two threads
        # would each create (and hold) a different lock for the same model.
        lock = self.loading_locks.setdefault(model_id, threading.Lock())
        with lock:
            # Another thread may have finished loading while we waited.
            if model_id in self.loaded_models:
                return self.loaded_models[model_id], None

            self._ensure_memory_available(self.preload_config[model_id]['model_size_gb'])

            start_time = time.time()
            model = self._load_model(model_id)
            load_time = time.time() - start_time

            stats = self.model_stats[model_id]
            stats['load_count'] += 1
            stats['total_load_time'] += load_time
            stats['avg_load_time'] = stats['total_load_time'] / stats['load_count']

            self.loaded_models[model_id] = model
            return model, load_time

    def _load_model(self, model_id):
        """Deserialize the model using its custom loader or the file type.

        Without a custom ``load_function``: ``.pt``/``.pth`` files go
        through ``torch.load``, ``.pkl`` through ``pickle.load``, and a
        directory is loaded as a transformers ``AutoModel`` plus tokenizer.

        Raises:
            ValueError: If no loading strategy matches ``model_path``.
        """
        config = self.preload_config[model_id]

        if config['load_function'] is not None:
            return config['load_function'](config['model_path'], config['device'])

        model_path = config['model_path']
        device = config['device']

        if model_path.endswith(('.pt', '.pth')):
            return torch.load(model_path, map_location=device)

        if model_path.endswith('.pkl'):
            # NOTE(security): pickle.load can execute arbitrary code; only
            # register model files that come from a trusted source.
            import pickle
            with open(model_path, 'rb') as f:
                return pickle.load(f)

        if os.path.isdir(model_path):
            try:
                from transformers import AutoModel, AutoTokenizer
                model = AutoModel.from_pretrained(model_path)
                tokenizer = AutoTokenizer.from_pretrained(model_path)
                return {'model': model.to(device), 'tokenizer': tokenizer}
            except ImportError:
                logger.error("transformers库未安装,无法加载预训练模型")
                raise
            except Exception as e:
                logger.error(f"加载预训练模型失败: {e}")
                raise

        raise ValueError(f"无法确定如何加载模型: {model_path}")

    def _ensure_memory_available(self, required_gb):
        """Evict least recently used models until ``required_gb`` fits.

        Models that are flagged ``preload`` *and* were used within the
        last ``unload_timeout`` minutes are protected from eviction. If
        only protected models remain, the budget is allowed to be
        exceeded. Each loaded model is examined at most once, so the
        method always terminates.
        """
        if not self.loaded_models:
            return

        current_usage = sum(
            self.preload_config[model_id]['model_size_gb']
            for model_id in self.loaded_models
        )

        now = datetime.now()
        protect_window = timedelta(minutes=self.unload_timeout)
        evicted_any = False

        # Snapshot in LRU order (oldest first); the dict is mutated below.
        for candidate_id in list(self.loaded_models):
            if current_usage + required_gb <= self.max_memory_usage:
                break

            last_used = self.model_stats[candidate_id]['last_used']
            is_protected = (
                self.preload_config[candidate_id]['preload']
                and last_used
                and (now - last_used) < protect_window
            )
            if is_protected:
                continue

            model_size = self.preload_config[candidate_id]['model_size_gb']
            if self.loaded_models.pop(candidate_id, None) is None and candidate_id in self.loaded_models:
                # Entry held a literal None value; remove it explicitly.
                del self.loaded_models[candidate_id]
            current_usage -= model_size
            evicted_any = True
            logger.info(f"自动卸载模型以释放内存: {candidate_id} ({model_size}GB)")

        if evicted_any:
            self._release_memory()

    def _unload_idle_models(self):
        """Unload non-preload models unused for longer than ``unload_timeout`` minutes."""
        current_time = datetime.now()
        idle_limit = timedelta(minutes=self.unload_timeout)
        for model_id in list(self.loaded_models.keys()):
            last_used = self.model_stats[model_id]['last_used']
            if (not self.preload_config[model_id]['preload']
                    and last_used
                    and (current_time - last_used) > idle_limit):
                logger.info(f"卸载长时间未使用的模型: {model_id}")
                # pop() tolerates a concurrent unload of the same model.
                self.loaded_models.pop(model_id, None)
                self._release_memory()

    def _release_memory(self):
        """Force garbage collection and release cached CUDA memory if present."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _monitor_models(self):
        """Monitor-thread body: sweep idle models forever, once per interval."""
        while True:
            try:
                self._unload_idle_models()
            except Exception as e:
                # Keep the monitor alive; a failed sweep is retried next round.
                logger.error(f"模型监控线程出错: {e}")
            time.sleep(self._MONITOR_INTERVAL_SECONDS)
|
||||
|
||||
# Module-level ModelManager instance, created at import time. ModelManager
# is a singleton, so later ModelManager() calls return this same object.
model_manager = ModelManager()
|
||||
|
||||
# 注册示例函数
|
||||
def register_sentiment_model():
    """Register the basic sentiment model with the global model manager.

    The model is registered as ``'sentiment_basic'``, flagged for
    preloading, and loaded through ``utils.model_loader.load_sentiment_model``.

    Returns:
        ``True`` on success, ``False`` if the loader cannot be imported
        or the registration fails (the error is logged).
    """
    try:
        # Imported inside the try block so a missing or broken loader
        # module yields the documented False result. If it raised here,
        # the module-level auto-registration would abort before the
        # remaining models are registered.
        from utils.model_loader import load_sentiment_model

        model_path = os.path.join('model', 'sentiment.marshal.3')
        model_manager.register_model(
            model_id='sentiment_basic',
            model_path=model_path,
            preload=True,
            model_size_gb=0.2,
            load_function=load_sentiment_model
        )
        return True
    except Exception as e:
        logger.error(f"注册情感分析模型失败: {e}")
        return False
|
||||
|
||||
def register_bert_model():
    """Register the BERT model directory with the global model manager.

    The model is registered as ``'bert_classifier'`` and flagged for
    preloading. No custom loader is supplied, so the manager's default
    loading logic is used.

    Returns:
        ``True`` on success, ``False`` if registration raised (the error
        is logged).
    """
    settings = {
        'model_id': 'bert_classifier',
        'model_path': os.path.join('model_pro', 'bert_model'),
        'preload': True,
        'model_size_gb': 0.8,
    }
    try:
        model_manager.register_model(**settings)
    except Exception as e:
        logger.error(f"注册BERT模型失败: {e}")
        return False
    return True
|
||||
|
||||
# Register the commonly used models as an import-time side effect.
try:
    for _register in (register_sentiment_model, register_bert_model):
        _register()
except Exception as e:
    logger.error(f"自动注册模型失败: {e}")
|
||||
@@ -0,0 +1,494 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Tuple, Optional, Union, Callable
|
||||
|
||||
# Module-level logger for the model router; its level is set to INFO
# at import time.
logger = logging.getLogger('model_router')
logger.setLevel(logging.INFO)
|
||||
|
||||
class ModelRouter:
    """Singleton that picks the most suitable AI model for a text and task.

    The router keeps a catalogue of models (capability scores, cost,
    context size, latency, required API key), determines which of them
    are usable in the current environment, and scores them for a given
    text and task type. Custom models can be added at runtime through
    ``register_custom_model``.
    """

    _instance = None
    # Cost per 1k tokens that maps to a cost factor of 0 (the worst score).
    _COST_CEILING_PER_1K = 0.03

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelRouter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelRouter() call; set up state only once.
        if self._initialized:
            return

        # Built-in model catalogue.
        self.models = {
            # OpenAI models
            'gpt-4o-latest': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.95,
                    'sentiment_analysis': 0.92,
                    'keyword_extraction': 0.90,
                    'summarization': 0.93,
                    'classification': 0.89,
                    'chinese_text': 0.88
                },
                'cost_per_1k': 0.01,
                'max_tokens': 128000,
                'avg_latency': 2.5,  # seconds
                'requires_api_key': 'OPENAI_API_KEY'
            },
            'gpt-4o-mini': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.85,
                    'sentiment_analysis': 0.82,
                    'keyword_extraction': 0.80,
                    'summarization': 0.84,
                    'classification': 0.81,
                    'chinese_text': 0.79
                },
                'cost_per_1k': 0.00015,
                'max_tokens': 4000,
                'avg_latency': 1.2,
                'requires_api_key': 'OPENAI_API_KEY'
            },
            'gpt-3.5-turbo': {
                'provider': 'openai',
                'capabilities': {
                    'text_analysis': 0.75,
                    'sentiment_analysis': 0.72,
                    'keyword_extraction': 0.70,
                    'summarization': 0.77,
                    'classification': 0.73,
                    'chinese_text': 0.65
                },
                'cost_per_1k': 0.0015,
                'max_tokens': 16000,
                'avg_latency': 0.8,
                'requires_api_key': 'OPENAI_API_KEY'
            },

            # Claude models
            'claude-3.5-sonnet': {
                'provider': 'anthropic',
                'capabilities': {
                    'text_analysis': 0.90,
                    'sentiment_analysis': 0.91,
                    'keyword_extraction': 0.85,
                    'summarization': 0.92,
                    'classification': 0.89,
                    'chinese_text': 0.80
                },
                'cost_per_1k': 0.015,
                'max_tokens': 200000,
                'avg_latency': 2.8,
                'requires_api_key': 'ANTHROPIC_API_KEY'
            },
            'claude-3.5-haiku': {
                'provider': 'anthropic',
                'capabilities': {
                    'text_analysis': 0.84,
                    'sentiment_analysis': 0.83,
                    'keyword_extraction': 0.79,
                    'summarization': 0.85,
                    'classification': 0.80,
                    'chinese_text': 0.72
                },
                'cost_per_1k': 0.0025,
                'max_tokens': 200000,
                'avg_latency': 1.5,
                'requires_api_key': 'ANTHROPIC_API_KEY'
            },

            # DeepSeek models
            'deepseek-chat': {
                'provider': 'deepseek',
                'capabilities': {
                    'text_analysis': 0.82,
                    'sentiment_analysis': 0.79,
                    'keyword_extraction': 0.77,
                    'summarization': 0.80,
                    'classification': 0.77,
                    'chinese_text': 0.90  # notably strong on Chinese
                },
                'cost_per_1k': 0.002,
                'max_tokens': 4000,
                'avg_latency': 1.0,
                'requires_api_key': 'DEEPSEEK_API_KEY'
            },
            'deepseek-reasoner': {
                'provider': 'deepseek',
                'capabilities': {
                    'text_analysis': 0.87,
                    'sentiment_analysis': 0.75,
                    'keyword_extraction': 0.76,
                    'summarization': 0.78,
                    'classification': 0.85,
                    'chinese_text': 0.88
                },
                'cost_per_1k': 0.003,
                'max_tokens': 4000,
                'avg_latency': 1.8,
                'requires_api_key': 'DEEPSEEK_API_KEY'
            }
        }

        # Supported task types and the capabilities each one relies on.
        self.task_types = {
            'sentiment_analysis': {
                'description': '情感分析',
                'key_capabilities': ['sentiment_analysis', 'text_analysis'],
                'example_prompt': '分析以下文本的情感倾向(积极、消极或中性)'
            },
            'topic_classification': {
                'description': '话题分类',
                'key_capabilities': ['classification', 'text_analysis'],
                'example_prompt': '将以下文本分类到最合适的话题类别'
            },
            'keyword_extraction': {
                'description': '关键词提取',
                'key_capabilities': ['keyword_extraction', 'text_analysis'],
                'example_prompt': '从以下文本中提取5个最重要的关键词'
            },
            'text_summarization': {
                'description': '文本摘要',
                'key_capabilities': ['summarization', 'text_analysis'],
                'example_prompt': '为以下文本生成一个简短的摘要'
            },
            'comprehensive_analysis': {
                'description': '综合分析',
                'key_capabilities': ['text_analysis', 'sentiment_analysis', 'keyword_extraction', 'summarization'],
                'example_prompt': '对以下文本进行全面分析,包括情感、关键词和主要观点'
            }
        }

        # Selection history, keyed by task type.
        self.usage_history = defaultdict(list)

        # Subset of self.models usable in the current environment.
        self.available_models = {}
        self._update_available_models()

        self._initialized = True
        logger.info("模型路由器初始化完成")

    def _update_available_models(self):
        """Recompute which catalogue models are usable right now.

        A model is available when the environment variable named by its
        ``requires_api_key`` entry is set to a non-empty value, or when it
        declares no ``requires_api_key`` at all (for example a private or
        local model registered through ``register_custom_model``).
        """
        available = {}
        for model_id, model_info in self.models.items():
            api_key_env = model_info.get('requires_api_key')
            # Models that need no API key must not be filtered out.
            if not api_key_env or os.getenv(api_key_env):
                available[model_id] = model_info
        self.available_models = available

        if not self.available_models:
            logger.warning("未找到可用的模型,请检查API密钥配置")
        else:
            logger.info(f"找到 {len(self.available_models)} 个可用模型")

    def detect_content_type(self, text: str) -> Dict[str, float]:
        """Estimate simple content features of ``text``.

        Args:
            text: The text to analyse.

        Returns:
            A dict with three scores in ``[0, 1]``:
            ``'chinese_text'`` (share of CJK unified ideographs),
            ``'length'`` (character count / 10000, capped at 1) and
            ``'complexity'`` (mean of average-sentence-length / 50 and
            lexical diversity, capped at 1). All are 0.0 for empty input.
        """
        features = {
            'chinese_text': 0.0,
            'length': 0.0,
            'complexity': 0.0
        }
        if not text:
            return features

        total_chars = len(text)
        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
        features['chinese_text'] = chinese_chars / total_chars

        features['length'] = min(1.0, total_chars / 10000)

        # re.split yields empty segments (e.g. after the final full stop);
        # drop them so they do not drag the average sentence length down.
        sentences = [s for s in re.split(r'[.!?。!?]', text) if s.strip()]
        avg_sentence_len = sum(len(s) for s in sentences) / len(sentences) if sentences else 0

        words = re.findall(r'\w+', text.lower())
        lexical_diversity = len(set(words)) / len(words) if words else 0

        features['complexity'] = min(1.0, (avg_sentence_len / 50 + lexical_diversity) / 2)
        return features

    def _score_model(self, model_info: Dict[str, Any], task_capabilities: List[str],
                     content_features: Dict[str, float], optimize_for: str) -> float:
        """Return the selection score of one model for the given task and content."""
        capabilities = model_info['capabilities']
        capability_score = sum(
            capabilities.get(capability, 0) for capability in task_capabilities
        ) / len(task_capabilities)

        content_score = 1.0
        # Mostly-Chinese text: weight in the model's Chinese capability.
        if content_features['chinese_text'] > 0.5:
            content_score *= (1.0 + capabilities.get('chinese_text', 0)) / 2
        # Long text: penalize models with a short context window.
        if content_features['length'] > 0.7 and model_info.get('max_tokens', 4000) < 10000:
            content_score *= 0.7
        # Complex text: favour models with a higher capability score.
        if content_features['complexity'] > 0.7:
            content_score *= (1.0 + capability_score) / 2

        if optimize_for == 'cost':
            # Cheaper models get a higher cost factor.
            cost_factor = 1 - min(1.0, model_info.get('cost_per_1k', 0) / self._COST_CEILING_PER_1K)
            return capability_score * content_score * 0.3 + cost_factor * 0.7
        if optimize_for == 'performance':
            return capability_score * 0.8 + content_score * 0.2
        # 'balanced' (and any unknown value): plain product.
        return capability_score * content_score

    def select_model(self, text: str, task_type: str,
                     optimize_for: str = 'balanced',
                     exclude_models: Optional[List[str]] = None) -> Optional[str]:
        """Select the best available model for ``text`` and ``task_type``.

        Args:
            text: The text to be processed.
            task_type: Task type such as ``'sentiment_analysis'``; unknown
                values fall back to ``'comprehensive_analysis'``.
            optimize_for: ``'cost'``, ``'performance'`` or ``'balanced'``.
            exclude_models: Model ids that should not be chosen.

        Returns:
            The id of the selected model, or ``None`` when no model is
            available. If every available model is excluded, the first
            available model is returned anyway (a warning is logged).
        """
        if not self.available_models:
            logger.error("没有可用的模型,请检查API密钥配置")
            return None

        if task_type not in self.task_types:
            logger.warning(f"未知的任务类型: {task_type},使用默认任务类型: 'comprehensive_analysis'")
            task_type = 'comprehensive_analysis'

        content_features = self.detect_content_type(text)
        task_capabilities = self.task_types[task_type]['key_capabilities']
        excluded = exclude_models or []

        model_scores = {
            model_id: self._score_model(model_info, task_capabilities, content_features, optimize_for)
            for model_id, model_info in self.available_models.items()
            if model_id not in excluded
        }

        if not model_scores:
            logger.warning("没有符合条件的可用模型")
            return list(self.available_models.keys())[0]

        selected_model = max(model_scores, key=model_scores.get)

        self.usage_history[task_type].append({
            'model': selected_model,
            'timestamp': datetime.now().timestamp(),
            'score': model_scores[selected_model],
            'optimize_for': optimize_for
        })

        logger.info(f"为任务 '{task_type}' 选择了模型: {selected_model} (得分: {model_scores[selected_model]:.4f})")
        return selected_model

    def get_model_info(self, model_id: str) -> Optional[Dict]:
        """Return the catalogue entry for ``model_id``, or ``None`` if unknown."""
        if model_id in self.models:
            return self.models[model_id]
        return None

    def get_available_models(self, refresh: bool = False) -> Dict[str, Dict]:
        """Return the available models, re-checking the environment if ``refresh``."""
        if refresh:
            self._update_available_models()
        return self.available_models

    def get_model_by_provider(self, provider: str, optimize_for: str = 'balanced') -> Optional[str]:
        """Return the recommended available model of ``provider``.

        ``'cost'`` picks the cheapest model, ``'performance'`` the one with
        the highest mean capability, anything else an equal blend of both.
        Returns ``None`` when the provider has no available model.
        """
        provider_models = {
            model_id: info for model_id, info in self.available_models.items()
            if info['provider'] == provider
        }

        if not provider_models:
            logger.warning(f"未找到提供商 '{provider}' 的可用模型")
            return None

        if optimize_for == 'cost':
            return min(provider_models.items(), key=lambda x: x[1].get('cost_per_1k', float('inf')))[0]
        elif optimize_for == 'performance':
            return max(provider_models.items(),
                       key=lambda x: sum(x[1]['capabilities'].values()) / len(x[1]['capabilities']))[0]
        else:
            scores = {}
            for model_id, info in provider_models.items():
                perf_score = sum(info['capabilities'].values()) / len(info['capabilities'])
                cost_score = 1 - min(1.0, info.get('cost_per_1k', 0) / self._COST_CEILING_PER_1K)
                scores[model_id] = perf_score * 0.5 + cost_score * 0.5
            return max(scores, key=scores.get)

    def get_task_types(self) -> Dict[str, Dict]:
        """Return the supported task types."""
        return self.task_types

    def register_custom_model(self, model_id: str, model_info: Dict[str, Any]) -> bool:
        """Add a custom model to the catalogue.

        Args:
            model_id: Unique identifier of the model.
            model_info: Model description. Required keys: ``provider``,
                ``capabilities`` (dict of capability scores),
                ``cost_per_1k`` and ``max_tokens``. Optional keys:
                ``avg_latency`` (seconds) and ``requires_api_key`` (name of
                the environment variable holding the key). A model without
                ``requires_api_key`` becomes available immediately.

        Returns:
            ``True`` if the model was registered, ``False`` if validation
            failed (the reason is logged).
        """
        required_fields = ['provider', 'capabilities', 'cost_per_1k', 'max_tokens']
        for field in required_fields:
            if field not in model_info:
                logger.error(f"注册模型失败: 缺少必要字段 '{field}'")
                return False

        if not isinstance(model_info['capabilities'], dict):
            logger.error("注册模型失败: 'capabilities' 必须是字典")
            return False

        self.models[model_id] = model_info
        self._update_available_models()

        logger.info(f"成功注册自定义模型: {model_id}")
        return True
|
||||
|
||||
# Module-level ModelRouter instance, created at import time. ModelRouter is
# a singleton, so later ModelRouter() calls return this same object.
model_router = ModelRouter()
|
||||
|
||||
def select_model(text, task_type, optimize_for='balanced', exclude_models=None):
    """Pick a model for ``text`` by delegating to the module-level router.

    All four arguments are forwarded positionally, unchanged, to
    ``model_router.select_model`` and its result is returned as is.
    """
    chosen = model_router.select_model(text, task_type, optimize_for, exclude_models)
    return chosen
|
||||
|
||||
def get_available_models(refresh=False):
    """Return the available models reported by the module-level router.

    ``refresh`` is forwarded unchanged to ``model_router.get_available_models``.
    """
    available = model_router.get_available_models(refresh)
    return available
|
||||
|
||||
def get_model_by_provider(provider, optimize_for='balanced'):
    """Return the router's recommended model for ``provider``.

    Delegates to ``model_router.get_model_by_provider``; that method returns
    None when the provider has no usable model.
    """
    recommended = model_router.get_model_by_provider(provider, optimize_for)
    return recommended
|
||||
|
||||
def get_task_types():
    """Return the task types known to the module-level router."""
    task_types = model_router.get_task_types()
    return task_types
|
||||
|
||||
def register_custom_model(model_id, model_info):
    """Register a custom model on the module-level router.

    Returns whatever ``model_router.register_custom_model`` returns
    (True on success, False on validation failure).
    """
    registered = model_router.register_custom_model(model_id, model_info)
    return registered
|
||||
|
||||
# Example usage
if __name__ == "__main__":
    # Sample inputs: the same short paragraph in Chinese and in English.
    chinese_text = """
    近日,人工智能技术的发展引发广泛关注。
    专家指出,大型语言模型在自然语言处理领域取得了显著进展,
    但同时也带来了诸多伦理和安全问题。对此,业界呼吁加强监管,
    确保人工智能的发展能够造福人类社会。
    """

    english_text = """
    Recent developments in artificial intelligence technology have drawn widespread attention.
    Experts point out that large language models have made significant progress in the field of natural language processing,
    but also bring many ethical and security issues. In response, the industry calls for strengthened regulation
    to ensure that the development of artificial intelligence can benefit human society.
    """

    # (heading, sample text, extra keyword arguments for select_model)
    scenarios = (
        ("中文文本任务测试:", chinese_text, {}),
        ("\n英文文本任务测试:", english_text, {}),
        ("\n成本优化测试:", chinese_text, {'optimize_for': 'cost'}),
        ("\n性能优化测试:", chinese_text, {'optimize_for': 'performance'}),
    )
    for heading, sample, options in scenarios:
        print(heading)
        chosen = select_model(sample, 'sentiment_analysis', **options)
        print(f"选择的模型: {chosen}")

    # Ask for a recommended model per API provider.
    print("\n根据提供商获取模型:")
    for provider in ['openai', 'anthropic', 'deepseek']:
        model = get_model_by_provider(provider)
        print(f"{provider}: {model}" if model else f"{provider}: 无可用模型")
|
||||
@@ -0,0 +1,357 @@
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Module logger for the sensitive-data filter; INFO and above are emitted.
logger = logging.getLogger('sensitive_filter')
logger.setLevel(logging.INFO)
|
||||
|
||||
class SensitiveDataFilter:
    """Detect and mask sensitive information in outgoing content.

    Features:
        1. Recognises phone numbers, e-mail addresses, ID-card numbers,
           credit-card numbers and street addresses out of the box.
        2. Supports custom patterns and replacement texts.
        3. Filters plain strings as well as nested dict/list structures.

    The class is a singleton: every construction returns the same object
    and ``__init__`` only runs its body once.
    """

    _instance = None

    # Built-in patterns. The ``(?a)`` prefix makes ``\b``/``\d`` ASCII-only:
    # with Unicode-aware ``\b`` a CJK character counts as a word character,
    # so a number or address glued to Chinese text would never match.
    DEFAULT_PATTERNS = {
        'phone': r'(?a)\b1[3-9]\d{9}\b',
        # TLD class is [A-Za-z]; a '|' inside a character class is a literal.
        'email': r'(?a)\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        'id_card': r'(?a)\b[1-9]\d{5}(19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]\b',
        'credit_card': r'(?a)\b\d{4}[ -]?\d{4}[ -]?\d{4}[ -]?\d{4}\b',
        'address': r'(北京|上海|广州|深圳|天津|重庆|南京|杭州|武汉|成都|西安)市.*?(路|街|道|巷).*?(号)',
    }

    # Replacement text per built-in pattern.
    DEFAULT_REPLACEMENTS = {
        'phone': '***********',
        'email': '******@*****',
        'id_card': '******************',
        'credit_card': '****************',
        'address': '[地址已隐藏]',
    }

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(SensitiveDataFilter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Defaults; the on/off switch comes from ENABLE_SENSITIVE_DATA_FILTER.
        self.config = {
            'enabled': os.getenv('ENABLE_SENSITIVE_DATA_FILTER', 'true').lower() == 'true',
            'patterns': dict(self.DEFAULT_PATTERNS),
            'replacements': dict(self.DEFAULT_REPLACEMENTS),
        }

        # Overlay the optional JSON config file, then compile everything.
        self._load_config()
        self._compile_patterns()

        self._initialized = True

        logger.info("敏感数据过滤器初始化完成")
        if self.config['enabled']:
            logger.info(f"已启用以下类型的敏感数据过滤: {', '.join(self.config['patterns'].keys())}")
        else:
            logger.info("敏感数据过滤器已禁用")

    def _load_config(self):
        """Overlay settings from ``<DATA_DIR>/security/sensitive_filter.json``.

        A missing file is not an error. Any failure while reading or
        applying the file is logged and otherwise ignored.
        """
        data_dir = os.getenv('DATA_DIR', 'data')
        config_path = os.path.join(data_dir, 'security', 'sensitive_filter.json')

        if not os.path.exists(config_path):
            return

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                custom_config = json.load(f)

            if 'enabled' in custom_config:
                self.config['enabled'] = custom_config['enabled']

            # Per-key merge: entries in the file override or extend defaults.
            for section in ('patterns', 'replacements'):
                if section in custom_config:
                    for key, value in custom_config[section].items():
                        self.config[section][key] = value

            logger.info(f"已加载自定义敏感数据过滤配置: {config_path}")
        except Exception as e:
            # A broken config file must not prevent the filter from starting.
            logger.error(f"加载敏感数据过滤配置失败: {e}")

    def _compile_patterns(self):
        """Rebuild ``self.compiled_patterns`` from ``self.config['patterns']``.

        Patterns that fail to compile are logged and left out.
        """
        self.compiled_patterns = {}
        for key, pattern in self.config['patterns'].items():
            try:
                self.compiled_patterns[key] = re.compile(pattern)
                logger.debug(f"已编译敏感数据模式: {key} - {pattern}")
            except re.error as e:
                logger.error(f"编译敏感数据模式失败: {key} - {pattern}: {e}")

    def _filter_value(self, value, skip_keys):
        """Recursively mask strings inside ``value``.

        Strings are filtered, dicts and lists are rebuilt with filtered
        contents (dict entries whose key is in ``skip_keys`` are copied
        untouched at every depth), anything else is returned unchanged.
        """
        if isinstance(value, str):
            return self.filter_text(value)
        if isinstance(value, dict):
            return {
                key: item if key in skip_keys else self._filter_value(item, skip_keys)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [self._filter_value(item, skip_keys) for item in value]
        return value

    def filter_text(self, text):
        """Mask every sensitive match in ``text``.

        Args:
            text: The text to filter.

        Returns:
            The filtered text; ``text`` itself when the filter is disabled
            or ``text`` is empty/None.
        """
        if not self.config['enabled'] or not text:
            return text

        filtered_text = text
        for key, pattern in self.compiled_patterns.items():
            replacement = self.config['replacements'].get(key, '[FILTERED]')
            filtered_text = pattern.sub(replacement, filtered_text)

        return filtered_text

    def filter_dict(self, data, *skip_keys):
        """Mask sensitive strings inside a dict, recursing into nested data.

        Args:
            data: The dict to filter. A str or list is accepted and
                filtered too; other types are returned unchanged.
            skip_keys: Dict keys whose values are copied without filtering.

        Returns:
            A new filtered structure; ``data`` itself when the filter is
            disabled or ``data`` is empty.
        """
        if not self.config['enabled'] or not data:
            return data
        # Nested lists (including lists of lists) go through the shared
        # recursive helper so they are no longer passed through unfiltered.
        return self._filter_value(data, skip_keys)

    def filter_list(self, data, *skip_keys):
        """Mask sensitive strings inside a list, recursing into nested data.

        Args:
            data: The list to filter. A dict or str is accepted and
                filtered too; other types are returned unchanged.
            skip_keys: Keys to leave untouched in any dict items.

        Returns:
            A new filtered structure; ``data`` itself when the filter is
            disabled or ``data`` is empty.
        """
        if not self.config['enabled'] or not data:
            return data
        return self._filter_value(data, skip_keys)

    def is_sensitive_info(self, text, info_type=None):
        """Tell whether ``text`` contains sensitive information.

        Args:
            text: The text to inspect.
            info_type: Restrict the check to this pattern key; None checks
                every compiled pattern.

        Returns:
            True on a match. False when nothing matches, the filter is
            disabled, ``text`` is empty, or ``info_type`` is unknown.
        """
        if not self.config['enabled'] or not text:
            return False

        if info_type:
            if info_type not in self.compiled_patterns:
                logger.warning(f"未知的敏感信息类型: {info_type}")
                return False
            return bool(self.compiled_patterns[info_type].search(text))

        return any(pattern.search(text) for pattern in self.compiled_patterns.values())

    def get_sensitive_info_types(self, text):
        """List the pattern keys that match somewhere in ``text``.

        Returns:
            Matching keys in pattern order; an empty list when the filter
            is disabled or ``text`` is empty.
        """
        if not self.config['enabled'] or not text:
            return []
        return [key for key, pattern in self.compiled_patterns.items() if pattern.search(text)]

    def enable(self):
        """Turn the filter on."""
        self.config['enabled'] = True
        logger.info("敏感数据过滤器已启用")

    def disable(self):
        """Turn the filter off."""
        self.config['enabled'] = False
        logger.info("敏感数据过滤器已禁用")

    def is_enabled(self):
        """Return the current value of the enabled flag."""
        return self.config['enabled']

    def add_pattern(self, key, pattern, replacement='[FILTERED]'):
        """Add or replace a custom pattern.

        Args:
            key: Identifier of the sensitive-information type.
            pattern: Regular expression source string.
            replacement: Text substituted for each match.

        Returns:
            True when the pattern was stored; False when it is not a valid
            regular expression (the configuration is left unchanged).
        """
        try:
            # Validate first so an invalid pattern never reaches the config.
            re.compile(pattern)

            self.config['patterns'][key] = pattern
            self.config['replacements'][key] = replacement

            self._compile_patterns()

            logger.info(f"已添加敏感信息模式: {key}")
            return True
        except re.error as e:
            logger.error(f"添加敏感信息模式失败: {key} - {pattern}: {e}")
            return False

    def remove_pattern(self, key):
        """Remove a pattern together with its replacement text.

        Args:
            key: Identifier of the sensitive-information type.

        Returns:
            True when the pattern existed and was removed, False otherwise.
        """
        if key not in self.config['patterns']:
            logger.warning(f"未找到敏感信息模式: {key}")
            return False

        del self.config['patterns'][key]
        self.config['replacements'].pop(key, None)
        self.compiled_patterns.pop(key, None)

        logger.info(f"已移除敏感信息模式: {key}")
        return True

    def save_config(self):
        """Write the current configuration to the JSON config file.

        The target is ``<DATA_DIR>/security/sensitive_filter.json``; the
        directory is created when missing.

        Returns:
            True on success, False when writing failed (error is logged).
        """
        data_dir = os.getenv('DATA_DIR', 'data')
        security_dir = os.path.join(data_dir, 'security')
        os.makedirs(security_dir, exist_ok=True)

        config_path = os.path.join(security_dir, 'sensitive_filter.json')

        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(self.config, f, ensure_ascii=False, indent=2)

            logger.info(f"敏感数据过滤配置已保存到: {config_path}")
            return True
        except Exception as e:
            logger.error(f"保存敏感数据过滤配置失败: {e}")
            return False
|
||||
|
||||
# Module-level filter instance used by the convenience functions below.
sensitive_filter = SensitiveDataFilter()
|
||||
|
||||
# 提供便捷的过滤函数
|
||||
def filter_text(text):
    """Mask sensitive data in ``text`` using the module-level filter."""
    masked = sensitive_filter.filter_text(text)
    return masked
|
||||
|
||||
def filter_dict(data, *skip_keys):
    """Mask sensitive data in a dict using the module-level filter.

    ``skip_keys`` are forwarded as the keys whose values stay unfiltered.
    """
    masked = sensitive_filter.filter_dict(data, *skip_keys)
    return masked
|
||||
|
||||
def filter_list(data, *skip_keys):
    """Mask sensitive data in a list using the module-level filter.

    ``skip_keys`` are forwarded as the dict keys to leave unfiltered.
    """
    masked = sensitive_filter.filter_list(data, *skip_keys)
    return masked
|
||||
|
||||
def is_sensitive_info(text, info_type=None):
    """Report whether ``text`` contains sensitive data per the module-level filter.

    ``info_type`` optionally restricts the check to one pattern key.
    """
    found = sensitive_filter.is_sensitive_info(text, info_type)
    return found
|
||||
|
||||
# Example usage
if __name__ == "__main__":
    # Sample record containing one instance of each built-in category.
    test_text = """
    联系人: 张三
    电话: 13812345678
    邮箱: zhangsan@example.com
    身份证: 110101199001011234
    地址: 北京市海淀区中关村大街20号
    信用卡: 6225 1234 5678 9012
    """

    # Run the filter first, then show the text before and after.
    filtered_text = filter_text(test_text)
    for heading, body in (("原始文本:", test_text), ("\n过滤后:", filtered_text)):
        print(heading)
        print(body)

    # Report which categories were detected in the unfiltered text.
    types = sensitive_filter.get_sensitive_info_types(test_text)
    print(f"\n包含的敏感信息类型: {types}")
|
||||
@@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from utils.db_manager import DatabaseManager
|
||||
from utils.cache_manager import CacheManager
|
||||
from utils.model_router import ModelRouter
|
||||
from utils.sensitive_filter import SensitiveDataFilter
|
||||
from spider.weibo_crawler import WeiboCrawler
|
||||
from utils.ai_analyzer import AIAnalyzer
|
||||
|
||||
# 配置日志
|
||||
from utils.logger import setup_logger
|
||||
logger = setup_logger('workflow_engine', 'logs/workflow_engine.log')
|
||||
|
||||
class WorkflowEngine:
|
||||
"""工作流引擎 - 负责执行数据爬取和分析工作流"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(WorkflowEngine, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
    """Create the engine's collaborators the first time it is constructed.

    Because the class is a singleton, later constructions find
    ``_initialized`` already set and leave the existing state untouched.
    """
    if not self._initialized:
        self.db = DatabaseManager()
        self.cache = CacheManager(memory_capacity=50, cache_duration=3600)
        self.model_router = ModelRouter()
        self.sensitive_filter = SensitiveDataFilter()
        self.executor = ThreadPoolExecutor(max_workers=5)
        # task_id -> Future of the submitted workflow
        self.running_tasks = {}

        # Make sure the working directory exists.
        workflow_dir = Path('data/workflow')
        self.data_dir = workflow_dir
        workflow_dir.mkdir(parents=True, exist_ok=True)

        self._initialized = True
        logger.info("工作流引擎初始化完成")
|
||||
|
||||
def execute_crawler_workflow(self, task_id, config):
    """Run a crawler workflow to completion and record its outcome.

    Args:
        task_id: Identifier of the task row to update.
        config: Crawler settings; reads ``source``, ``crawl_depth``,
            ``interval`` and ``filters`` (each with a default).

    Returns:
        The crawler's result, or None when any step raised (the task is
        then marked ``failed`` with the error text).
    """
    logger.info(f"开始执行爬虫工作流: {task_id}")

    def report_progress(progress):
        # Progress callback handed to the crawler.
        return self._update_task_progress(task_id, progress)

    try:
        self._update_task_status(task_id, 'running', 0)

        crawler = WeiboCrawler()
        crawl_options = {
            'source': config.get('source', 'hot_topics'),
            'depth': config.get('crawl_depth', 1),
            'interval': config.get('interval', 5),
            'filters': config.get('filters', {}),
        }
        result = crawler.crawl(callback=report_progress, **crawl_options)

        self._update_task_status(task_id, 'completed', 100, result=result)
        logger.info(f"爬虫工作流完成: {task_id}")
        return result
    except Exception as e:
        message = str(e)
        logger.error(f"爬虫工作流出错: {message}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=message)
        return None
|
||||
|
||||
def execute_analysis_workflow(self, task_id, workflow):
    """Run an analysis workflow component by component.

    Args:
        task_id: Identifier of the task row to update.
        workflow: Workflow definition with ``components`` and
            ``connections`` lists; both must be non-empty.

    Returns:
        The (optionally sensitive-data-filtered) final outputs, or None
        when any step raised (the task is then marked ``failed``).
    """
    logger.info(f"开始执行分析工作流: {task_id}")

    try:
        self._update_task_status(task_id, 'running', 0)

        components = workflow.get('components', [])
        connections = workflow.get('connections', [])
        if not (components and connections):
            raise ValueError("工作流配置不完整,缺少组件或连接")

        # Resolve dependencies and derive an execution order from them.
        component_map, dependency_graph = self._build_dependency_graph(components, connections)
        execution_order = self._topological_sort(dependency_graph)

        result_map = {}
        step_count = len(execution_order)
        for position, component_id in enumerate(execution_order):
            component = component_map.get(component_id)
            if component:
                # Progress reflects the steps finished before this one.
                self._update_task_progress(task_id, int((position / step_count) * 100))
                inputs = self._get_component_input_data(component_id, connections, result_map)
                result_map[component_id] = self._execute_component(component, inputs)

        final_outputs = self._get_final_outputs(dependency_graph, result_map)

        # Mask sensitive data before the result is stored or returned.
        if final_outputs and self.sensitive_filter.is_enabled():
            if isinstance(final_outputs, dict):
                final_outputs = self.sensitive_filter.filter_dict(final_outputs)
            elif isinstance(final_outputs, list):
                final_outputs = self.sensitive_filter.filter_list(final_outputs)

        self._update_task_status(task_id, 'completed', 100, result=final_outputs)
        logger.info(f"分析工作流完成: {task_id}")
        return final_outputs
    except Exception as e:
        message = str(e)
        logger.error(f"分析工作流出错: {message}")
        logger.error(traceback.format_exc())
        self._update_task_status(task_id, 'failed', 0, error=message)
        return None
|
||||
|
||||
def start_workflow(self, workflow_type, config, template_id=None):
|
||||
"""
|
||||
异步启动工作流
|
||||
|
||||
Args:
|
||||
workflow_type: 工作流类型 (crawler/analysis)
|
||||
config: 工作流配置
|
||||
template_id: 关联的模板ID
|
||||
|
||||
Returns:
|
||||
task_id: 工作流任务ID
|
||||
"""
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 保存任务信息到数据库
|
||||
conn = self.db.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO workflow_tasks
|
||||
(id, template_id, type, status, progress, config, created_at, updated_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(
|
||||
task_id,
|
||||
template_id,
|
||||
workflow_type,
|
||||
'pending',
|
||||
0,
|
||||
json.dumps(config, ensure_ascii=False),
|
||||
now,
|
||||
now
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# 异步执行工作流
|
||||
if workflow_type == 'crawler':
|
||||
self.running_tasks[task_id] = self.executor.submit(
|
||||
self.execute_crawler_workflow, task_id, config
|
||||
)
|
||||
elif workflow_type == 'analysis':
|
||||
self.running_tasks[task_id] = self.executor.submit(
|
||||
self.execute_analysis_workflow, task_id, config
|
||||
)
|
||||
else:
|
||||
logger.error(f"未知的工作流类型: {workflow_type}")
|
||||
return None
|
||||
|
||||
return task_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动工作流失败: {str(e)}")
|
||||
conn.rollback()
|
||||
return None
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def get_task_status(self, task_id):
|
||||
"""
|
||||
获取任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
task: 任务信息
|
||||
"""
|
||||
# 先检查缓存
|
||||
cache_key = f"task_status:{task_id}"
|
||||
cached_task = self.cache.get(cache_key)
|
||||
if cached_task:
|
||||
return cached_task
|
||||
|
||||
# 从数据库获取
|
||||
conn = self.db.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(
|
||||
"SELECT * FROM workflow_tasks WHERE id = %s",
|
||||
(task_id,)
|
||||
)
|
||||
task = cursor.fetchone()
|
||||
|
||||
if task:
|
||||
# 将JSON字符串转为Python对象
|
||||
if task.get('config'):
|
||||
task['config'] = json.loads(task['config'])
|
||||
if task.get('result'):
|
||||
task['result'] = json.loads(task['result'])
|
||||
|
||||
# 缓存结果
|
||||
self.cache.set(cache_key, task)
|
||||
|
||||
return task
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务状态失败: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def cancel_task(self, task_id):
|
||||
"""
|
||||
取消任务
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
success: 是否成功
|
||||
"""
|
||||
# 检查任务是否存在并正在运行
|
||||
if task_id in self.running_tasks:
|
||||
# 尝试取消任务
|
||||
future = self.running_tasks[task_id]
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
|
||||
# 从运行列表中移除
|
||||
del self.running_tasks[task_id]
|
||||
|
||||
# 更新数据库状态
|
||||
conn = self.db.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
cursor.execute(
|
||||
"""
|
||||
UPDATE workflow_tasks
|
||||
SET status = %s, updated_at = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
('cancelled', now, task_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# 清理缓存
|
||||
cache_key = f"task_status:{task_id}"
|
||||
self.cache.delete(cache_key)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"取消任务失败: {str(e)}")
|
||||
conn.rollback()
|
||||
return False
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def _update_task_status(self, task_id, status, progress, result=None, error=None):
|
||||
"""更新任务状态"""
|
||||
conn = self.db.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
update_fields = ["status = %s", "progress = %s", "updated_at = %s"]
|
||||
params = [status, progress, now]
|
||||
|
||||
# 添加开始时间
|
||||
if status == 'running' and progress == 0:
|
||||
update_fields.append("started_at = %s")
|
||||
params.append(now)
|
||||
|
||||
# 添加完成时间
|
||||
if status in ['completed', 'failed']:
|
||||
update_fields.append("completed_at = %s")
|
||||
params.append(now)
|
||||
|
||||
# 添加结果
|
||||
if result is not None:
|
||||
update_fields.append("result = %s")
|
||||
params.append(json.dumps(result, ensure_ascii=False))
|
||||
|
||||
# 添加错误
|
||||
if error is not None:
|
||||
update_fields.append("error = %s")
|
||||
params.append(error)
|
||||
|
||||
# 构建SQL
|
||||
sql = f"""
|
||||
UPDATE workflow_tasks
|
||||
SET {', '.join(update_fields)}
|
||||
WHERE id = %s
|
||||
"""
|
||||
params.append(task_id)
|
||||
|
||||
cursor.execute(sql, tuple(params))
|
||||
conn.commit()
|
||||
|
||||
# 清理缓存
|
||||
cache_key = f"task_status:{task_id}"
|
||||
self.cache.delete(cache_key)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新任务状态失败: {str(e)}")
|
||||
conn.rollback()
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def _update_task_progress(self, task_id, progress):
|
||||
"""更新任务进度"""
|
||||
self._update_task_status(task_id, 'running', progress)
|
||||
|
||||
def _build_dependency_graph(self, components, connections):
|
||||
"""构建组件依赖图"""
|
||||
component_map = {comp['id']: comp for comp in components}
|
||||
dependency_graph = {comp['id']: [] for comp in components}
|
||||
|
||||
# 构建依赖关系
|
||||
for conn in connections:
|
||||
source = conn.get('source')
|
||||
target = conn.get('target')
|
||||
|
||||
if source and target and source in component_map and target in component_map:
|
||||
dependency_graph[target].append(source)
|
||||
|
||||
return component_map, dependency_graph
|
||||
|
||||
def _topological_sort(self, graph):
|
||||
"""拓扑排序,确定组件执行顺序"""
|
||||
visited = set()
|
||||
temp = set()
|
||||
order = []
|
||||
|
||||
def visit(node):
|
||||
if node in temp:
|
||||
raise ValueError(f"工作流存在循环依赖: {node}")
|
||||
if node in visited:
|
||||
return
|
||||
|
||||
temp.add(node)
|
||||
for neighbor in graph.get(node, []):
|
||||
visit(neighbor)
|
||||
|
||||
temp.remove(node)
|
||||
visited.add(node)
|
||||
order.append(node)
|
||||
|
||||
for node in graph:
|
||||
if node not in visited:
|
||||
visit(node)
|
||||
|
||||
return list(reversed(order))
|
||||
|
||||
def _get_component_input_data(self, component_id, connections, result_map):
|
||||
"""获取组件的输入数据"""
|
||||
input_data = {}
|
||||
|
||||
for conn in connections:
|
||||
if conn.get('target') == component_id:
|
||||
source_id = conn.get('source')
|
||||
if source_id in result_map:
|
||||
input_name = conn.get('targetInput', 'default')
|
||||
input_data[input_name] = result_map[source_id]
|
||||
|
||||
return input_data
|
||||
|
||||
def _execute_component(self, component, input_data):
|
||||
"""执行单个组件"""
|
||||
component_type = component.get('type')
|
||||
config = component.get('config', {})
|
||||
|
||||
if component_type == 'data_source':
|
||||
return self._execute_data_source(config, input_data)
|
||||
elif component_type == 'preprocessing':
|
||||
return self._execute_preprocessing(config, input_data)
|
||||
elif component_type == 'model':
|
||||
return self._execute_model(config, input_data)
|
||||
elif component_type == 'visualization':
|
||||
return self._execute_visualization(config, input_data)
|
||||
else:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return None
|
||||
|
||||
def _execute_data_source(self, config, input_data):
|
||||
"""执行数据源组件"""
|
||||
source_type = config.get('source_type')
|
||||
|
||||
if source_type == 'database':
|
||||
# 从数据库获取数据
|
||||
table = config.get('table')
|
||||
filters = config.get('filters', {})
|
||||
limit = config.get('limit', 1000)
|
||||
|
||||
query_conditions = []
|
||||
query_params = []
|
||||
|
||||
for key, value in filters.items():
|
||||
if value:
|
||||
query_conditions.append(f"{key} = %s")
|
||||
query_params.append(value)
|
||||
|
||||
where_clause = f"WHERE {' AND '.join(query_conditions)}" if query_conditions else ""
|
||||
|
||||
sql = f"SELECT * FROM {table} {where_clause} LIMIT {limit}"
|
||||
|
||||
conn = self.db.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
try:
|
||||
cursor.execute(sql, tuple(query_params))
|
||||
return cursor.fetchall()
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询出错: {str(e)}")
|
||||
return []
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
elif source_type == 'file':
|
||||
# 从文件加载数据
|
||||
file_path = config.get('file_path')
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
if file_path.endswith('.json'):
|
||||
return json.load(f)
|
||||
else:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
logger.error(f"文件读取出错: {str(e)}")
|
||||
return []
|
||||
|
||||
elif source_type == 'api':
|
||||
# 这里需要实现API调用逻辑
|
||||
# 由于涉及复杂的HTTP请求,暂不实现
|
||||
logger.warning("API数据源暂未实现")
|
||||
return []
|
||||
|
||||
else:
|
||||
logger.warning(f"未知的数据源类型: {source_type}")
|
||||
return []
|
||||
|
||||
def _execute_preprocessing(self, config, input_data):
|
||||
"""执行数据预处理组件"""
|
||||
preprocessing_type = config.get('preprocessing_type')
|
||||
data = input_data.get('default', [])
|
||||
|
||||
if not data:
|
||||
return []
|
||||
|
||||
if preprocessing_type == 'filter':
|
||||
# 过滤数据
|
||||
field = config.get('field')
|
||||
value = config.get('value')
|
||||
operator = config.get('operator', 'eq')
|
||||
|
||||
if not field:
|
||||
return data
|
||||
|
||||
result = []
|
||||
for item in data:
|
||||
if operator == 'eq' and item.get(field) == value:
|
||||
result.append(item)
|
||||
elif operator == 'neq' and item.get(field) != value:
|
||||
result.append(item)
|
||||
elif operator == 'contains' and value in str(item.get(field, '')):
|
||||
result.append(item)
|
||||
elif operator == 'not_contains' and value not in str(item.get(field, '')):
|
||||
result.append(item)
|
||||
|
||||
return result
|
||||
|
||||
elif preprocessing_type == 'sort':
|
||||
# 排序数据
|
||||
field = config.get('field')
|
||||
order = config.get('order', 'asc')
|
||||
|
||||
if not field:
|
||||
return data
|
||||
|
||||
return sorted(
|
||||
data,
|
||||
key=lambda x: x.get(field, ''),
|
||||
reverse=(order == 'desc')
|
||||
)
|
||||
|
||||
elif preprocessing_type == 'aggregate':
|
||||
# 聚合数据
|
||||
group_by = config.get('group_by')
|
||||
aggregate_field = config.get('aggregate_field')
|
||||
aggregate_type = config.get('aggregate_type', 'count')
|
||||
|
||||
if not group_by:
|
||||
return data
|
||||
|
||||
result = {}
|
||||
for item in data:
|
||||
key = item.get(group_by)
|
||||
if key not in result:
|
||||
result[key] = {
|
||||
'count': 0,
|
||||
'sum': 0,
|
||||
'values': []
|
||||
}
|
||||
|
||||
result[key]['count'] += 1
|
||||
|
||||
if aggregate_field:
|
||||
value = item.get(aggregate_field, 0)
|
||||
if isinstance(value, (int, float)):
|
||||
result[key]['sum'] += value
|
||||
result[key]['values'].append(value)
|
||||
|
||||
# 计算最终结果
|
||||
final_result = []
|
||||
for key, values in result.items():
|
||||
item = {group_by: key}
|
||||
|
||||
if aggregate_type == 'count':
|
||||
item['value'] = values['count']
|
||||
elif aggregate_type == 'sum':
|
||||
item['value'] = values['sum']
|
||||
elif aggregate_type == 'avg':
|
||||
item['value'] = values['sum'] / values['count'] if values['count'] > 0 else 0
|
||||
|
||||
final_result.append(item)
|
||||
|
||||
return final_result
|
||||
|
||||
else:
|
||||
logger.warning(f"未知的预处理类型: {preprocessing_type}")
|
||||
return data
|
||||
|
||||
def _execute_model(self, config, input_data):
    """Run an AI analysis (sentiment/topic/keywords/summary) over the input.

    Args:
        config: Component settings; ``model_type`` selects the analysis and
            ``text_field`` (default ``'content'``) names the row field
            holding the text.
        input_data: Component inputs; data is read from ``'default'`` and
            may be a list of dict-like rows or a single string.

    Returns:
        For a list input: the same list, with each row that has a
        non-empty text field annotated in place under the analysis key.
        For a string input: the single analysis result.
        For any other non-empty input: None.
        An empty list when there is no input; the input unchanged when
        ``model_type`` is unknown.
    """
    model_type = config.get('model_type')
    data = input_data.get('default', [])

    if not data:
        return []

    analyzer = AIAnalyzer()

    # model_type -> (task name given to the router,
    #                analyzer method to call,
    #                key the result is stored under on each row)
    analyses = {
        'sentiment': ('sentiment', 'analyze_sentiment', 'sentiment'),
        'topic': ('topic', 'analyze_topic', 'topic'),
        'keywords': ('keyword', 'extract_keywords', 'keywords'),
        'summarize': ('summarization', 'summarize_text', 'summary'),
    }
    spec = analyses.get(model_type) if isinstance(model_type, str) else None
    if spec is None:
        logger.warning(f"未知的模型类型: {model_type}")
        return data

    router_task, analyzer_method, output_key = spec
    analyze = getattr(analyzer, analyzer_method)

    # Rows that actually carry text; only these are analysed and annotated.
    rows = []
    if isinstance(data, list):
        field = config.get('text_field', 'content')
        rows = [item for item in data if item.get(field)]
        texts = [item.get(field, '') for item in rows]
    elif isinstance(data, str):
        texts = [data]
    else:
        texts = []

    # The first text is used as the sample for model selection.
    # NOTE(review): confirm that ModelRouter exposes select_model_for_text
    # and accepts these task names.
    model = self.model_router.select_model_for_text(texts[0] if texts else "", router_task)

    results = [analyze(text, model=model) for text in texts]

    if isinstance(data, list):
        # Pair each result with the row its text came from. Indexing the
        # results by position in `data` misaligns (and drops) annotations
        # as soon as one row has no text.
        for row, result in zip(rows, results):
            row[output_key] = result
        return data

    return results[0] if results else None
|
||||
|
||||
def _execute_visualization(self, config, input_data):
|
||||
"""执行可视化组件"""
|
||||
visualization_type = config.get('visualization_type')
|
||||
data = input_data.get('default', [])
|
||||
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
if visualization_type == 'chart':
|
||||
# 图表可视化
|
||||
chart_type = config.get('chart_type', 'bar')
|
||||
x_field = config.get('x_field')
|
||||
y_field = config.get('y_field')
|
||||
title = config.get('title', '数据可视化')
|
||||
|
||||
if not x_field or not y_field:
|
||||
return {'error': '缺少x或y字段'}
|
||||
|
||||
# 提取数据
|
||||
chart_data = {
|
||||
'type': chart_type,
|
||||
'title': title,
|
||||
'xAxis': {'type': 'category', 'data': []},
|
||||
'yAxis': {'type': 'value'},
|
||||
'series': [{'data': []}]
|
||||
}
|
||||
|
||||
for item in data:
|
||||
x_value = item.get(x_field)
|
||||
y_value = item.get(y_field)
|
||||
|
||||
if x_value is not None and y_value is not None:
|
||||
chart_data['xAxis']['data'].append(x_value)
|
||||
chart_data['series'][0]['data'].append(y_value)
|
||||
|
||||
return chart_data
|
||||
|
||||
elif visualization_type == 'table':
|
||||
# 表格可视化
|
||||
columns = config.get('columns', [])
|
||||
title = config.get('title', '数据表格')
|
||||
|
||||
# 如果没有指定列,使用数据中的所有字段
|
||||
if not columns and isinstance(data, list) and data:
|
||||
columns = list(data[0].keys())
|
||||
|
||||
# 构建表格数据
|
||||
table_data = {
|
||||
'type': 'table',
|
||||
'title': title,
|
||||
'columns': columns,
|
||||
'data': data
|
||||
}
|
||||
|
||||
return table_data
|
||||
|
||||
elif visualization_type == 'wordcloud':
|
||||
# 词云可视化
|
||||
word_field = config.get('word_field')
|
||||
value_field = config.get('value_field')
|
||||
title = config.get('title', '词云图')
|
||||
|
||||
if not word_field:
|
||||
return {'error': '缺少词字段'}
|
||||
|
||||
# 构建词云数据
|
||||
wordcloud_data = {
|
||||
'type': 'wordcloud',
|
||||
'title': title,
|
||||
'data': []
|
||||
}
|
||||
|
||||
for item in data:
|
||||
word = item.get(word_field)
|
||||
value = item.get(value_field, 1)
|
||||
|
||||
if word:
|
||||
wordcloud_data['data'].append({
|
||||
'name': word,
|
||||
'value': value
|
||||
})
|
||||
|
||||
return wordcloud_data
|
||||
|
||||
else:
|
||||
logger.warning(f"未知的可视化类型: {visualization_type}")
|
||||
return {}
|
||||
|
||||
def _get_final_outputs(self, dependency_graph, result_map):
|
||||
"""获取最终输出结果"""
|
||||
# 找出没有后继节点的叶子节点
|
||||
leaf_nodes = []
|
||||
all_targets = set()
|
||||
|
||||
for node, deps in dependency_graph.items():
|
||||
all_targets.update(deps)
|
||||
|
||||
for node in dependency_graph:
|
||||
if node not in all_targets:
|
||||
leaf_nodes.append(node)
|
||||
|
||||
# 收集所有叶子节点的结果
|
||||
outputs = {}
|
||||
for node in leaf_nodes:
|
||||
if node in result_map:
|
||||
outputs[node] = result_map[node]
|
||||
|
||||
return outputs
|
||||
Reference in New Issue
Block a user